Refactoring: Moved some code from headers to dedicated source files

This commit is contained in:
DcruBro
2025-11-10 23:19:39 +01:00
parent 9252425bdf
commit 4b4451d1a9
7 changed files with 588 additions and 538 deletions

View File

@@ -38,237 +38,16 @@ namespace ColumnLynx::Net::TCP {
mLastHeartbeatSent(std::chrono::steady_clock::now()) mLastHeartbeatSent(std::chrono::steady_clock::now())
{} {}
void start() { void start();
auto self = shared_from_this(); void sendMessage(ClientMessageType type, const std::string& data = "");
mResolver.async_resolve(mHost, mPort, void disconnect(bool echo = true);
[this, self](asio::error_code ec, tcp::resolver::results_type endpoints) {
if (!ec) {
asio::async_connect(mSocket, endpoints,
[this, self](asio::error_code ec, const tcp::endpoint&) {
if (!NetHelper::isExpectedDisconnect(ec)) {
mConnected = true;
Utils::log("Client connected.");
mHandler = std::make_shared<MessageHandler>(std::move(mSocket));
mHandler->onMessage([this](AnyMessageType type, const std::string& data) {
mHandleMessage(static_cast<ServerMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->start();
// Init connection handshake bool isHandshakeComplete() const;
Utils::log("Sending handshake init to server."); bool isConnected() const;
std::vector<uint8_t> payload;
payload.reserve(1 + crypto_box_PUBLICKEYBYTES);
payload.push_back(Utils::protocolVersion());
payload.insert(payload.end(),
mLibSodiumWrapper->getXPublicKey(),
mLibSodiumWrapper->getXPublicKey() + crypto_box_PUBLICKEYBYTES
);
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(payload.data(), payload.size()));
mStartHeartbeat();
} else {
Utils::error("Client connect failed: " + ec.message());
}
});
} else {
Utils::error("Client resolve failed: " + ec.message());
}
});
}
void sendMessage(ClientMessageType type, const std::string& data = "") {
if (!mConnected) {
Utils::error("Cannot send message, client not connected.");
return;
}
if (mHandler) {
asio::post(mHandler->socket().get_executor(), [self = shared_from_this(), type, data]() {
self->mHandler->sendMessage(type, data);
});
}
}
void disconnect(bool echo = true) {
if (mConnected && mHandler) {
if (echo) {
mHandler->sendMessage(ClientMessageType::GRACEFUL_DISCONNECT, "Goodbye");
}
asio::error_code ec;
mHeartbeatTimer.cancel();
mHandler->socket().shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
Utils::error("Error during socket shutdown: " + ec.message());
}
mHandler->socket().close(ec);
if (ec) {
Utils::error("Error during socket close: " + ec.message());
}
mConnected = false;
Utils::log("Client disconnected.");
}
}
bool isHandshakeComplete() const {
return mHandshakeComplete;
}
bool isConnected() const {
return mConnected;
}
private: private:
void mStartHeartbeat() { void mStartHeartbeat();
auto self = shared_from_this(); void mHandleMessage(ServerMessageType type, const std::string& data);
mHeartbeatTimer.expires_after(std::chrono::seconds(5));
mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted) {
return; // Timer was cancelled
}
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - self->mLastHeartbeatReceived).count();
if (elapsed >= 15) { // 3 missed heartbeats
Utils::error("Missed 3 heartbeats. I think the other party might have died! Disconnecting.");
// Close sockets forcefully, server is dead
asio::error_code ec;
mHandler->socket().shutdown(tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
mConnected = false;
mGlobalKeyRef = nullptr;
if (mSessionIDRef) {
*mSessionIDRef = 0;
}
return;
}
self->sendMessage(ClientMessageType::HEARTBEAT);
Utils::log("Sent HEARTBEAT to server.");
self->mLastHeartbeatSent = std::chrono::steady_clock::now();
self->mStartHeartbeat(); // Recursive
});
}
void mHandleMessage(ServerMessageType type, const std::string& data) {
switch (type) {
case ServerMessageType::HANDSHAKE_IDENTIFY:
Utils::log("Received server identity: " + data);
std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
// Generate and send challenge
{
Utils::log("Sending challenge to server.");
mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later
mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge));
}
break;
case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE:
Utils::log("Received challenge response from server.");
{
// Verify the signature
Signature sig{};
std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size()));
if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) {
Utils::log("Challenge response verified successfully.");
// Convert the server's public key to Curve25519 for encryption
AsymPublicKey serverXPubKey{};
crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey);
// Generate AES key and send confirmation
mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit();
if (mGlobalKeyRef) { // Copy to the global reference
std::copy(mConnectionAESKey.begin(), mConnectionAESKey.end(), mGlobalKeyRef->begin());
}
AsymNonce nonce{};
randombytes_buf(nonce.data(), nonce.size());
// TODO: This is pretty redundant, it should return the required type directly
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
std::vector<uint8_t> encr = Utils::LibSodiumWrapper::encryptAsymmetric(
mConnectionAESKey.data(), mConnectionAESKey.size(),
nonce,
serverXPubKey,
arrayPrivateKey
);
std::vector<uint8_t> payload;
payload.reserve(nonce.size() + encr.size());
payload.insert(payload.end(), nonce.begin(), nonce.end());
payload.insert(payload.end(), encr.begin(), encr.end());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size()));
} else {
Utils::error("Challenge response verification failed. Terminating connection.");
disconnect();
}
}
break;
case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM:
Utils::log("Received handshake exchange key confirmation from server.");
// Decrypt the session ID using the established AES key
{
Nonce symNonce{}; // All zeros
std::vector<uint8_t> ciphertext(data.begin(), data.end());
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(),
mConnectionAESKey, symNonce
);
if (decrypted.size() != sizeof(mConnectionSessionID)) {
Utils::error("Decrypted session ID has invalid size. Terminating connection.");
disconnect();
return;
}
std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID));
Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID));
if (mSessionIDRef) { // Copy to the global reference
*mSessionIDRef = mConnectionSessionID;
}
mHandshakeComplete = true;
}
break;
case ServerMessageType::HEARTBEAT:
Utils::log("Received HEARTBEAT from server.");
mHandler->sendMessage(ClientMessageType::HEARTBEAT_ACK, ""); // Send ACK
break;
case ServerMessageType::HEARTBEAT_ACK:
Utils::log("Received HEARTBEAT_ACK from server.");
mLastHeartbeatReceived = std::chrono::steady_clock::now();
mMissedHeartbeats = 0; // Reset missed heartbeat count
break;
case ServerMessageType::GRACEFUL_DISCONNECT:
Utils::log("Server is disconnecting: " + data);
if (mConnected) { // Prevent Recursion
disconnect(false);
}
break;
default:
Utils::log("Received unknown message type from server.");
break;
}
}
bool mConnected = false; bool mConnected = false;
bool mHandshakeComplete = false; bool mHandshakeComplete = false;

View File

@@ -20,110 +20,13 @@ namespace ColumnLynx::Net::UDP {
uint64_t* sessionIDRef) uint64_t* sessionIDRef)
: mSocket(ioContext), mResolver(ioContext), mHost(host), mPort(port), mAesKeyRef(aesKeyRef), mSessionIDRef(sessionIDRef) { mStartReceive(); } : mSocket(ioContext), mResolver(ioContext), mHost(host), mPort(port), mAesKeyRef(aesKeyRef), mSessionIDRef(sessionIDRef) { mStartReceive(); }
void start() { void start();
auto endpoints = mResolver.resolve(asio::ip::udp::v4(), mHost, mPort); void sendMessage(const std::string& data = "");
mRemoteEndpoint = *endpoints.begin(); void stop();
mSocket.open(asio::ip::udp::v4());
Utils::log("UDP Client ready to send to " + mRemoteEndpoint.address().to_string() + ":" + std::to_string(mRemoteEndpoint.port()));
}
void sendMessage(const std::string& data = "") {
UDPPacketHeader hdr{};
randombytes_buf(hdr.nonce.data(), hdr.nonce.size());
if (mAesKeyRef == nullptr || mSessionIDRef == nullptr) {
Utils::error("UDP Client AES key or Session ID reference is null!");
return;
}
auto encryptedPayload = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<const uint8_t*>(data.data()), data.size(),
*mAesKeyRef, hdr.nonce, "udp-data"
);
std::vector<uint8_t> packet;
packet.reserve(sizeof(UDPPacketHeader) + sizeof(uint64_t) + encryptedPayload.size());
packet.insert(packet.end(),
reinterpret_cast<uint8_t*>(&hdr),
reinterpret_cast<uint8_t*>(&hdr) + sizeof(UDPPacketHeader)
);
uint64_t sid = *mSessionIDRef;
packet.insert(packet.end(),
reinterpret_cast<uint8_t*>(&sid),
reinterpret_cast<uint8_t*>(&sid) + sizeof(sid)
);
packet.insert(packet.end(), encryptedPayload.begin(), encryptedPayload.end());
mSocket.send_to(asio::buffer(packet), mRemoteEndpoint);
Utils::log("Sent UDP packet of size " + std::to_string(packet.size()));
}
void stop() {
if (mSocket.is_open()) {
asio::error_code ec;
mSocket.cancel(ec);
mSocket.close(ec);
Utils::log("UDP Client socket closed.");
}
}
private: private:
void mStartReceive() { void mStartReceive();
mSocket.async_receive_from( void mHandlePacket(std::size_t bytes);
asio::buffer(mRecvBuffer), mRemoteEndpoint,
[this](asio::error_code ec, std::size_t bytes) {
if (ec) {
if (ec == asio::error::operation_aborted) return; // Socket closed
// Other recv error
mStartReceive();
return;
}
if (bytes > 0) {
mHandlePacket(bytes);
}
mStartReceive();
}
);
}
void mHandlePacket(std::size_t bytes) {
if (bytes < sizeof(UDPPacketHeader) + sizeof(uint64_t)) {
Utils::warn("UDP Client received packet too small to process.");
return;
}
// Parse header
UDPPacketHeader hdr;
std::memcpy(&hdr, mRecvBuffer.data(), sizeof(UDPPacketHeader));
// Parse session ID
uint64_t sessionID;
std::memcpy(&sessionID, mRecvBuffer.data() + sizeof(UDPPacketHeader), sizeof(uint64_t));
// Decrypt payload
std::vector<uint8_t> ciphertext(
mRecvBuffer.begin() + sizeof(UDPPacketHeader) + sizeof(uint64_t),
mRecvBuffer.begin() + bytes
);
if (mAesKeyRef == nullptr) {
Utils::error("UDP Client AES key reference is null!");
return;
}
std::vector<uint8_t> plaintext = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(), *mAesKeyRef, hdr.nonce, "udp-data"
);
if (plaintext.empty()) {
Utils::warn("UDP Client failed to decrypt received packet.");
return;
}
Utils::log("UDP Client received packet from " + mRemoteEndpoint.address().to_string() + " - Packet size: " + std::to_string(bytes));
}
asio::ip::udp::socket mSocket; asio::ip::udp::socket mSocket;
asio::ip::udp::resolver mResolver; asio::ip::udp::resolver mResolver;

View File

@@ -32,56 +32,13 @@ namespace ColumnLynx::Net::TCP {
return conn; return conn;
} }
void start() { void start();
mHandler->onMessage([this](AnyMessageType type, const std::string& data) { void sendMessage(ServerMessageType type, const std::string& data = "");
mHandleMessage(static_cast<ClientMessageType>(MessageHandler::toUint8(type)), data); void setDisconnectCallback(std::function<void(std::shared_ptr<TCPConnection>)> cb);
}); void disconnect();
mHandler->onDisconnect([this](const asio::error_code& ec) { uint64_t getSessionID() const;
Utils::log("Client disconnected: " + mHandler->socket().remote_endpoint().address().to_string() + " - " + ec.message()); std::array<uint8_t, 32> getAESKey() const;
disconnect();
});
mHandler->start();
mStartHeartbeat();
// Placeholder for message handling setup
Utils::log("Client connected: " + mHandler->socket().remote_endpoint().address().to_string());
}
void sendMessage(ServerMessageType type, const std::string& data = "") {
if (mHandler) {
mHandler->sendMessage(type, data);
}
}
void setDisconnectCallback(std::function<void(std::shared_ptr<TCPConnection>)> cb) {
mOnDisconnect = std::move(cb);
}
void disconnect() {
std::string ip = mHandler->socket().remote_endpoint().address().to_string();
mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect.");
mHeartbeatTimer.cancel();
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
Utils::log("Closed connection to " + ip);
if (mOnDisconnect) {
mOnDisconnect(shared_from_this());
}
}
uint64_t getSessionID() const {
return mConnectionSessionID;
}
std::array<uint8_t, 32> getAESKey() const {
return mConnectionAESKey;
}
private: private:
TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper* sodiumWrapper) TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper* sodiumWrapper)
@@ -93,164 +50,8 @@ namespace ColumnLynx::Net::TCP {
mLastHeartbeatSent(std::chrono::steady_clock::now()) mLastHeartbeatSent(std::chrono::steady_clock::now())
{} {}
void mStartHeartbeat() { void mStartHeartbeat();
auto self = shared_from_this(); void mHandleMessage(ClientMessageType type, const std::string& data);
mHeartbeatTimer.expires_after(std::chrono::seconds(5));
mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted) {
return; // Timer was cancelled
}
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - self->mLastHeartbeatReceived).count();
if (elapsed >= 15) { // 3 missed heartbeats
Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting.");
// Remove socket forcefully, client is dead
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
SessionRegistry::getInstance().erase(self->mConnectionSessionID);
return;
}
self->sendMessage(ServerMessageType::HEARTBEAT);
Utils::log("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID));
self->mLastHeartbeatSent = now;
self->mStartHeartbeat(); // Recursive
});
}
void mHandleMessage(ClientMessageType type, const std::string& data) {
std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string();
switch (type) {
case ClientMessageType::HANDSHAKE_INIT: {
Utils::log("Received HANDSHAKE_INIT from " + reqAddr);
if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) {
Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short.");
disconnect();
return;
}
uint8_t clientProtoVer = static_cast<uint8_t>(data[0]);
if (clientProtoVer != Utils::protocolVersion()) {
Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " +
std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + ".");
disconnect();
return;
}
Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + ".");
std::memcpy(mConnectionPublicKey.data(), data.data() + 1, std::min(data.size() - 1, sizeof(mConnectionPublicKey))); // Store the client's public key (for identification)
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist
break;
}
case ClientMessageType::HANDSHAKE_CHALLENGE: {
Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr);
// Convert to byte array
uint8_t challengeData[32];
std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData)));
// Sign the challenge
Signature sig = Utils::LibSodiumWrapper::signMessage(
challengeData, sizeof(challengeData),
mLibSodiumWrapper->getPrivateKey()
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response
break;
}
case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: {
Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr);
// Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext)
if (data.size() < 24) { // Minimum size check (nonce)
Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short.");
disconnect();
return;
}
AsymNonce nonce{};
std::memcpy(nonce.data(), data.data(), nonce.size());
std::vector<uint8_t> ciphertext(data.size() - nonce.size());
std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size());
try {
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
// Decrypt the AES key using the client's public key and server's private key
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptAsymmetric(
ciphertext.data(), ciphertext.size(),
nonce,
mConnectionPublicKey,
arrayPrivateKey
);
if (decrypted.size() != 32) {
Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size.");
disconnect();
return;
}
std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size());
// Make a Session ID
randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID));
// TODO: Make the session ID little-endian for network transmission
// Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose)
Nonce symNonce{}; // All zeros
std::vector<uint8_t> encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<uint8_t*>(&mConnectionSessionID), sizeof(mConnectionSessionID),
mConnectionAESKey, symNonce
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size()));
// Add to session registry
Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned.");
auto session = std::make_shared<SessionState>(mConnectionAESKey, std::chrono::hours(12));
SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session));
} catch (const std::exception& e) {
Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what());
disconnect();
}
break;
}
case ClientMessageType::HEARTBEAT: {
Utils::log("Received HEARTBEAT from " + reqAddr);
mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK
break;
}
case ClientMessageType::HEARTBEAT_ACK: {
Utils::log("Received HEARTBEAT_ACK from " + reqAddr);
mLastHeartbeatReceived = std::chrono::steady_clock::now();
mMissedHeartbeats = 0; // Reset missed heartbeat count
break;
}
case ClientMessageType::GRACEFUL_DISCONNECT: {
Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data);
disconnect();
break;
}
default:
Utils::warn("Unhandled message type from " + reqAddr);
break;
}
}
std::shared_ptr<MessageHandler> mHandler; std::shared_ptr<MessageHandler> mHandler;
std::function<void(std::shared_ptr<TCPConnection>)> mOnDisconnect; std::function<void(std::shared_ptr<TCPConnection>)> mOnDisconnect;

View File

@@ -74,7 +74,8 @@ int main(int argc, char** argv) {
log("Client connected to " + host + ":" + port); log("Client connected to " + host + ":" + port);
// Client is running // Client is running
while ((!done && client->isConnected()) || !client->isHandshakeComplete()) { // TODO: SIGINT or SIGTERM seems to not kill this instantly!
while (client->isConnected() || !client->isHandshakeComplete() || !done) {
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Temp wait std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Temp wait
if (client->isHandshakeComplete()) { if (client->isHandshakeComplete()) {

View File

@@ -0,0 +1,238 @@
// tcp_client.cpp - TCP Client for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/client/net/tcp/tcp_client.hpp>
namespace ColumnLynx::Net::TCP {
void TCPClient::start() {
auto self = shared_from_this();
mResolver.async_resolve(mHost, mPort,
[this, self](asio::error_code ec, tcp::resolver::results_type endpoints) {
if (!ec) {
asio::async_connect(mSocket, endpoints,
[this, self](asio::error_code ec, const tcp::endpoint&) {
if (!NetHelper::isExpectedDisconnect(ec)) {
mConnected = true;
Utils::log("Client connected.");
mHandler = std::make_shared<MessageHandler>(std::move(mSocket));
mHandler->onMessage([this](AnyMessageType type, const std::string& data) {
mHandleMessage(static_cast<ServerMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->start();
// Init connection handshake
Utils::log("Sending handshake init to server.");
std::vector<uint8_t> payload;
payload.reserve(1 + crypto_box_PUBLICKEYBYTES);
payload.push_back(Utils::protocolVersion());
payload.insert(payload.end(),
mLibSodiumWrapper->getXPublicKey(),
mLibSodiumWrapper->getXPublicKey() + crypto_box_PUBLICKEYBYTES
);
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(payload.data(), payload.size()));
mStartHeartbeat();
} else {
Utils::error("Client connect failed: " + ec.message());
}
});
} else {
Utils::error("Client resolve failed: " + ec.message());
}
});
}
void TCPClient::sendMessage(ClientMessageType type, const std::string& data) {
if (!mConnected) {
Utils::error("Cannot send message, client not connected.");
return;
}
if (mHandler) {
asio::post(mHandler->socket().get_executor(), [self = shared_from_this(), type, data]() {
self->mHandler->sendMessage(type, data);
});
}
}
void TCPClient::disconnect(bool echo) {
if (mConnected && mHandler) {
if (echo) {
mHandler->sendMessage(ClientMessageType::GRACEFUL_DISCONNECT, "Goodbye");
}
asio::error_code ec;
mHeartbeatTimer.cancel();
mHandler->socket().shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
Utils::error("Error during socket shutdown: " + ec.message());
}
mHandler->socket().close(ec);
if (ec) {
Utils::error("Error during socket close: " + ec.message());
}
mConnected = false;
Utils::log("Client disconnected.");
}
}
bool TCPClient::isHandshakeComplete() const {
return mHandshakeComplete;
}
bool TCPClient::isConnected() const {
return mConnected;
}
void TCPClient::mStartHeartbeat() {
auto self = shared_from_this();
mHeartbeatTimer.expires_after(std::chrono::seconds(5));
mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted) {
return; // Timer was cancelled
}
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - self->mLastHeartbeatReceived).count();
if (elapsed >= 15) { // 3 missed heartbeats
Utils::error("Missed 3 heartbeats. I think the other party might have died! Disconnecting.");
// Close sockets forcefully, server is dead
asio::error_code ec;
mHandler->socket().shutdown(tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
mConnected = false;
mGlobalKeyRef = nullptr;
if (mSessionIDRef) {
*mSessionIDRef = 0;
}
return;
}
self->sendMessage(ClientMessageType::HEARTBEAT);
Utils::log("Sent HEARTBEAT to server.");
self->mLastHeartbeatSent = std::chrono::steady_clock::now();
self->mStartHeartbeat(); // Recursive
});
}
void TCPClient::mHandleMessage(ServerMessageType type, const std::string& data) {
switch (type) {
case ServerMessageType::HANDSHAKE_IDENTIFY:
Utils::log("Received server identity: " + data);
std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
// Generate and send challenge
{
Utils::log("Sending challenge to server.");
mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later
mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge));
}
break;
case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE:
Utils::log("Received challenge response from server.");
{
// Verify the signature
Signature sig{};
std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size()));
if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) {
Utils::log("Challenge response verified successfully.");
// Convert the server's public key to Curve25519 for encryption
AsymPublicKey serverXPubKey{};
crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey);
// Generate AES key and send confirmation
mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit();
if (mGlobalKeyRef) { // Copy to the global reference
std::copy(mConnectionAESKey.begin(), mConnectionAESKey.end(), mGlobalKeyRef->begin());
}
AsymNonce nonce{};
randombytes_buf(nonce.data(), nonce.size());
// TODO: This is pretty redundant, it should return the required type directly
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
std::vector<uint8_t> encr = Utils::LibSodiumWrapper::encryptAsymmetric(
mConnectionAESKey.data(), mConnectionAESKey.size(),
nonce,
serverXPubKey,
arrayPrivateKey
);
std::vector<uint8_t> payload;
payload.reserve(nonce.size() + encr.size());
payload.insert(payload.end(), nonce.begin(), nonce.end());
payload.insert(payload.end(), encr.begin(), encr.end());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size()));
} else {
Utils::error("Challenge response verification failed. Terminating connection.");
disconnect();
}
}
break;
case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM:
Utils::log("Received handshake exchange key confirmation from server.");
// Decrypt the session ID using the established AES key
{
Nonce symNonce{}; // All zeros
std::vector<uint8_t> ciphertext(data.begin(), data.end());
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(),
mConnectionAESKey, symNonce
);
if (decrypted.size() != sizeof(mConnectionSessionID)) {
Utils::error("Decrypted session ID has invalid size. Terminating connection.");
disconnect();
return;
}
std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID));
Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID));
if (mSessionIDRef) { // Copy to the global reference
*mSessionIDRef = mConnectionSessionID;
}
mHandshakeComplete = true;
}
break;
case ServerMessageType::HEARTBEAT:
Utils::log("Received HEARTBEAT from server.");
mHandler->sendMessage(ClientMessageType::HEARTBEAT_ACK, ""); // Send ACK
break;
case ServerMessageType::HEARTBEAT_ACK:
Utils::log("Received HEARTBEAT_ACK from server.");
mLastHeartbeatReceived = std::chrono::steady_clock::now();
mMissedHeartbeats = 0; // Reset missed heartbeat count
break;
case ServerMessageType::GRACEFUL_DISCONNECT:
Utils::log("Server is disconnecting: " + data);
if (mConnected) { // Prevent Recursion
disconnect(false);
}
break;
default:
Utils::log("Received unknown message type from server.");
break;
}
}
}

View File

@@ -0,0 +1,111 @@
// udp_client.cpp - UDP Client for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/client/net/udp/udp_client.hpp>
namespace ColumnLynx::Net::UDP {
void UDPClient::start() {
auto endpoints = mResolver.resolve(asio::ip::udp::v4(), mHost, mPort);
mRemoteEndpoint = *endpoints.begin();
mSocket.open(asio::ip::udp::v4());
Utils::log("UDP Client ready to send to " + mRemoteEndpoint.address().to_string() + ":" + std::to_string(mRemoteEndpoint.port()));
}
void UDPClient::sendMessage(const std::string& data) {
UDPPacketHeader hdr{};
randombytes_buf(hdr.nonce.data(), hdr.nonce.size());
if (mAesKeyRef == nullptr || mSessionIDRef == nullptr) {
Utils::error("UDP Client AES key or Session ID reference is null!");
return;
}
auto encryptedPayload = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<const uint8_t*>(data.data()), data.size(),
*mAesKeyRef, hdr.nonce, "udp-data"
);
std::vector<uint8_t> packet;
packet.reserve(sizeof(UDPPacketHeader) + sizeof(uint64_t) + encryptedPayload.size());
packet.insert(packet.end(),
reinterpret_cast<uint8_t*>(&hdr),
reinterpret_cast<uint8_t*>(&hdr) + sizeof(UDPPacketHeader)
);
uint64_t sid = *mSessionIDRef;
packet.insert(packet.end(),
reinterpret_cast<uint8_t*>(&sid),
reinterpret_cast<uint8_t*>(&sid) + sizeof(sid)
);
packet.insert(packet.end(), encryptedPayload.begin(), encryptedPayload.end());
mSocket.send_to(asio::buffer(packet), mRemoteEndpoint);
Utils::log("Sent UDP packet of size " + std::to_string(packet.size()));
}
void UDPClient::stop() {
if (mSocket.is_open()) {
asio::error_code ec;
mSocket.cancel(ec);
mSocket.close(ec);
Utils::log("UDP Client socket closed.");
}
}
void UDPClient::mStartReceive() {
mSocket.async_receive_from(
asio::buffer(mRecvBuffer), mRemoteEndpoint,
[this](asio::error_code ec, std::size_t bytes) {
if (ec) {
if (ec == asio::error::operation_aborted) return; // Socket closed
// Other recv error
mStartReceive();
return;
}
if (bytes > 0) {
mHandlePacket(bytes);
}
mStartReceive();
}
);
}
void UDPClient::mHandlePacket(std::size_t bytes) {
if (bytes < sizeof(UDPPacketHeader) + sizeof(uint64_t)) {
Utils::warn("UDP Client received packet too small to process.");
return;
}
// Parse header
UDPPacketHeader hdr;
std::memcpy(&hdr, mRecvBuffer.data(), sizeof(UDPPacketHeader));
// Parse session ID
uint64_t sessionID;
std::memcpy(&sessionID, mRecvBuffer.data() + sizeof(UDPPacketHeader), sizeof(uint64_t));
// Decrypt payload
std::vector<uint8_t> ciphertext(
mRecvBuffer.begin() + sizeof(UDPPacketHeader) + sizeof(uint64_t),
mRecvBuffer.begin() + bytes
);
if (mAesKeyRef == nullptr) {
Utils::error("UDP Client AES key reference is null!");
return;
}
std::vector<uint8_t> plaintext = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(), *mAesKeyRef, hdr.nonce, "udp-data"
);
if (plaintext.empty()) {
Utils::warn("UDP Client failed to decrypt received packet.");
return;
}
Utils::log("UDP Client received packet from " + mRemoteEndpoint.address().to_string() + " - Packet size: " + std::to_string(bytes));
}
}

View File

@@ -0,0 +1,217 @@
// tcp_connection.cpp - TCP Connection for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/server/net/tcp/tcp_connection.hpp>
namespace ColumnLynx::Net::TCP {
void TCPConnection::start() {
mHandler->onMessage([this](AnyMessageType type, const std::string& data) {
mHandleMessage(static_cast<ClientMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->onDisconnect([this](const asio::error_code& ec) {
Utils::log("Client disconnected: " + mHandler->socket().remote_endpoint().address().to_string() + " - " + ec.message());
disconnect();
});
mHandler->start();
mStartHeartbeat();
// Placeholder for message handling setup
Utils::log("Client connected: " + mHandler->socket().remote_endpoint().address().to_string());
}
void TCPConnection::sendMessage(ServerMessageType type, const std::string& data) {
if (mHandler) {
mHandler->sendMessage(type, data);
}
}
void TCPConnection::setDisconnectCallback(std::function<void(std::shared_ptr<TCPConnection>)> cb) {
mOnDisconnect = std::move(cb);
}
void TCPConnection::disconnect() {
std::string ip = mHandler->socket().remote_endpoint().address().to_string();
mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect.");
mHeartbeatTimer.cancel();
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
Utils::log("Closed connection to " + ip);
if (mOnDisconnect) {
mOnDisconnect(shared_from_this());
}
}
uint64_t TCPConnection::getSessionID() const {
return mConnectionSessionID;
}
std::array<uint8_t, 32> TCPConnection::getAESKey() const {
return mConnectionAESKey;
}
void TCPConnection::mStartHeartbeat() {
auto self = shared_from_this();
mHeartbeatTimer.expires_after(std::chrono::seconds(5));
mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted) {
return; // Timer was cancelled
}
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - self->mLastHeartbeatReceived).count();
if (elapsed >= 15) { // 3 missed heartbeats
Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting.");
// Remove socket forcefully, client is dead
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
SessionRegistry::getInstance().erase(self->mConnectionSessionID);
return;
}
self->sendMessage(ServerMessageType::HEARTBEAT);
Utils::log("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID));
self->mLastHeartbeatSent = now;
self->mStartHeartbeat(); // Recursive
});
}
void TCPConnection::mHandleMessage(ClientMessageType type, const std::string& data) {
std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string();
switch (type) {
case ClientMessageType::HANDSHAKE_INIT: {
Utils::log("Received HANDSHAKE_INIT from " + reqAddr);
if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) {
Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short.");
disconnect();
return;
}
uint8_t clientProtoVer = static_cast<uint8_t>(data[0]);
if (clientProtoVer != Utils::protocolVersion()) {
Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " +
std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + ".");
disconnect();
return;
}
Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + ".");
std::memcpy(mConnectionPublicKey.data(), data.data() + 1, std::min(data.size() - 1, sizeof(mConnectionPublicKey))); // Store the client's public key (for identification)
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist
break;
}
case ClientMessageType::HANDSHAKE_CHALLENGE: {
Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr);
// Convert to byte array
uint8_t challengeData[32];
std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData)));
// Sign the challenge
Signature sig = Utils::LibSodiumWrapper::signMessage(
challengeData, sizeof(challengeData),
mLibSodiumWrapper->getPrivateKey()
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response
break;
}
case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: {
Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr);
// Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext)
if (data.size() < 24) { // Minimum size check (nonce)
Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short.");
disconnect();
return;
}
AsymNonce nonce{};
std::memcpy(nonce.data(), data.data(), nonce.size());
std::vector<uint8_t> ciphertext(data.size() - nonce.size());
std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size());
try {
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
// Decrypt the AES key using the client's public key and server's private key
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptAsymmetric(
ciphertext.data(), ciphertext.size(),
nonce,
mConnectionPublicKey,
arrayPrivateKey
);
if (decrypted.size() != 32) {
Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size.");
disconnect();
return;
}
std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size());
// Make a Session ID
randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID));
// TODO: Make the session ID little-endian for network transmission
// Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose)
Nonce symNonce{}; // All zeros
std::vector<uint8_t> encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<uint8_t*>(&mConnectionSessionID), sizeof(mConnectionSessionID),
mConnectionAESKey, symNonce
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size()));
// Add to session registry
Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned.");
auto session = std::make_shared<SessionState>(mConnectionAESKey, std::chrono::hours(12));
SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session));
} catch (const std::exception& e) {
Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what());
disconnect();
}
break;
}
case ClientMessageType::HEARTBEAT: {
Utils::log("Received HEARTBEAT from " + reqAddr);
mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK
break;
}
case ClientMessageType::HEARTBEAT_ACK: {
Utils::log("Received HEARTBEAT_ACK from " + reqAddr);
mLastHeartbeatReceived = std::chrono::steady_clock::now();
mMissedHeartbeats = 0; // Reset missed heartbeat count
break;
}
case ClientMessageType::GRACEFUL_DISCONNECT: {
Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data);
disconnect();
break;
}
default:
Utils::warn("Unhandled message type from " + reqAddr);
break;
}
}
}