From 4b4451d1a90eec167adab4fbf2037cd451eccb0b Mon Sep 17 00:00:00 2001 From: DcruBro Date: Mon, 10 Nov 2025 23:19:39 +0100 Subject: [PATCH] Refactoring: Moved some code from headers to dedicated source files --- .../columnlynx/client/net/tcp/tcp_client.hpp | 235 +---------------- .../columnlynx/client/net/udp/udp_client.hpp | 107 +------- .../server/net/tcp/tcp_connection.hpp | 215 +--------------- src/client/main.cpp | 3 +- src/client/net/tcp/tcp_client.cpp | 238 ++++++++++++++++++ src/client/net/udp/udp_client.cpp | 111 ++++++++ src/server/server/net/tcp/tcp_connection.cpp | 217 ++++++++++++++++ 7 files changed, 588 insertions(+), 538 deletions(-) create mode 100644 src/client/net/tcp/tcp_client.cpp create mode 100644 src/client/net/udp/udp_client.cpp create mode 100644 src/server/server/net/tcp/tcp_connection.cpp diff --git a/include/columnlynx/client/net/tcp/tcp_client.hpp b/include/columnlynx/client/net/tcp/tcp_client.hpp index 9b1a17b..28001ee 100644 --- a/include/columnlynx/client/net/tcp/tcp_client.hpp +++ b/include/columnlynx/client/net/tcp/tcp_client.hpp @@ -38,237 +38,16 @@ namespace ColumnLynx::Net::TCP { mLastHeartbeatSent(std::chrono::steady_clock::now()) {} - void start() { - auto self = shared_from_this(); - mResolver.async_resolve(mHost, mPort, - [this, self](asio::error_code ec, tcp::resolver::results_type endpoints) { - if (!ec) { - asio::async_connect(mSocket, endpoints, - [this, self](asio::error_code ec, const tcp::endpoint&) { - if (!NetHelper::isExpectedDisconnect(ec)) { - mConnected = true; - Utils::log("Client connected."); - mHandler = std::make_shared(std::move(mSocket)); - mHandler->onMessage([this](AnyMessageType type, const std::string& data) { - mHandleMessage(static_cast(MessageHandler::toUint8(type)), data); - }); - mHandler->start(); - - // Init connection handshake - Utils::log("Sending handshake init to server."); + void start(); + void sendMessage(ClientMessageType type, const std::string& data = ""); + void disconnect(bool echo = true); - std::vector payload; - payload.reserve(1 + crypto_box_PUBLICKEYBYTES); - payload.push_back(Utils::protocolVersion()); - payload.insert(payload.end(), - mLibSodiumWrapper->getXPublicKey(), - mLibSodiumWrapper->getXPublicKey() + crypto_box_PUBLICKEYBYTES - ); - - mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(payload.data(), payload.size())); - - mStartHeartbeat(); - } else { - Utils::error("Client connect failed: " + ec.message()); - } - }); - } else { - Utils::error("Client resolve failed: " + ec.message()); - } - }); - } - - void sendMessage(ClientMessageType type, const std::string& data = "") { - if (!mConnected) { - Utils::error("Cannot send message, client not connected."); - return; - } - - if (mHandler) { - asio::post(mHandler->socket().get_executor(), [self = shared_from_this(), type, data]() { - self->mHandler->sendMessage(type, data); - }); - } - } - - void disconnect(bool echo = true) { - if (mConnected && mHandler) { - if (echo) { - mHandler->sendMessage(ClientMessageType::GRACEFUL_DISCONNECT, "Goodbye"); - } - - asio::error_code ec; - mHeartbeatTimer.cancel(); - - mHandler->socket().shutdown(tcp::socket::shutdown_both, ec); - if (ec) { - Utils::error("Error during socket shutdown: " + ec.message()); - } - - mHandler->socket().close(ec); - if (ec) { - Utils::error("Error during socket close: " + ec.message()); - } - - mConnected = false; - Utils::log("Client disconnected."); - } - } - - bool isHandshakeComplete() const { - return mHandshakeComplete; - } - - bool isConnected() const { - return mConnected; - } + bool isHandshakeComplete() const; + bool isConnected() const; private: - void mStartHeartbeat() { - auto self = shared_from_this(); - mHeartbeatTimer.expires_after(std::chrono::seconds(5)); - mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) { - if (ec == asio::error::operation_aborted) { - return; // Timer was cancelled - } - - auto now = std::chrono::steady_clock::now(); - auto elapsed = std::chrono::duration_cast(now - self->mLastHeartbeatReceived).count(); - - if (elapsed >= 15) { // 3 missed heartbeats - Utils::error("Missed 3 heartbeats. I think the other party might have died! Disconnecting."); - - // Close sockets forcefully, server is dead - asio::error_code ec; - mHandler->socket().shutdown(tcp::socket::shutdown_both, ec); - mHandler->socket().close(ec); - mConnected = false; - - mGlobalKeyRef = nullptr; - if (mSessionIDRef) { - *mSessionIDRef = 0; - } - - return; - } - - self->sendMessage(ClientMessageType::HEARTBEAT); - Utils::log("Sent HEARTBEAT to server."); - self->mLastHeartbeatSent = std::chrono::steady_clock::now(); - - self->mStartHeartbeat(); // Recursive - }); - } - - void mHandleMessage(ServerMessageType type, const std::string& data) { - switch (type) { - case ServerMessageType::HANDSHAKE_IDENTIFY: - Utils::log("Received server identity: " + data); - std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey))); - - // Generate and send challenge - { - Utils::log("Sending challenge to server."); - mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later - mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge)); - } - - break; - case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE: - Utils::log("Received challenge response from server."); - { - // Verify the signature - Signature sig{}; - std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size())); - if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) { - Utils::log("Challenge response verified successfully."); - - // Convert the server's public key to Curve25519 for encryption - AsymPublicKey serverXPubKey{}; - crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey); - - // Generate AES key and send confirmation - mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit(); - if (mGlobalKeyRef) { // Copy to the global reference - std::copy(mConnectionAESKey.begin(), mConnectionAESKey.end(), mGlobalKeyRef->begin()); - } - AsymNonce nonce{}; - randombytes_buf(nonce.data(), nonce.size()); - - // TODO: This is pretty redundant, it should return the required type directly - std::array arrayPrivateKey; - std::copy(mLibSodiumWrapper->getXPrivateKey(), - mLibSodiumWrapper->getXPrivateKey() + 32, - arrayPrivateKey.begin()); - - std::vector encr = Utils::LibSodiumWrapper::encryptAsymmetric( - mConnectionAESKey.data(), mConnectionAESKey.size(), - nonce, - serverXPubKey, - arrayPrivateKey - ); - - std::vector payload; - payload.reserve(nonce.size() + encr.size()); - payload.insert(payload.end(), nonce.begin(), nonce.end()); - payload.insert(payload.end(), encr.begin(), encr.end()); - - mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size())); - } else { - Utils::error("Challenge response verification failed. Terminating connection."); - disconnect(); - } - } - - break; - case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM: - Utils::log("Received handshake exchange key confirmation from server."); - // Decrypt the session ID using the established AES key - { - Nonce symNonce{}; // All zeros - std::vector ciphertext(data.begin(), data.end()); - std::vector decrypted = Utils::LibSodiumWrapper::decryptMessage( - ciphertext.data(), ciphertext.size(), - mConnectionAESKey, symNonce - ); - - if (decrypted.size() != sizeof(mConnectionSessionID)) { - Utils::error("Decrypted session ID has invalid size. Terminating connection."); - disconnect(); - return; - } - - std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID)); - Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID)); - - if (mSessionIDRef) { // Copy to the global reference - *mSessionIDRef = mConnectionSessionID; - } - - mHandshakeComplete = true; - } - - break; - case ServerMessageType::HEARTBEAT: - Utils::log("Received HEARTBEAT from server."); - mHandler->sendMessage(ClientMessageType::HEARTBEAT_ACK, ""); // Send ACK - break; - case ServerMessageType::HEARTBEAT_ACK: - Utils::log("Received HEARTBEAT_ACK from server."); - mLastHeartbeatReceived = std::chrono::steady_clock::now(); - mMissedHeartbeats = 0; // Reset missed heartbeat count - break; - case ServerMessageType::GRACEFUL_DISCONNECT: - Utils::log("Server is disconnecting: " + data); - if (mConnected) { // Prevent Recursion - disconnect(false); - } - break; - default: - Utils::log("Received unknown message type from server."); - break; - } - } + void mStartHeartbeat(); + void mHandleMessage(ServerMessageType type, const std::string& data); bool mConnected = false; bool mHandshakeComplete = false; diff --git a/include/columnlynx/client/net/udp/udp_client.hpp b/include/columnlynx/client/net/udp/udp_client.hpp index f49d043..0d804ec 100644 --- a/include/columnlynx/client/net/udp/udp_client.hpp +++ b/include/columnlynx/client/net/udp/udp_client.hpp @@ -20,110 +20,13 @@ namespace ColumnLynx::Net::UDP { uint64_t* sessionIDRef) : mSocket(ioContext), mResolver(ioContext), mHost(host), mPort(port), mAesKeyRef(aesKeyRef), mSessionIDRef(sessionIDRef) { mStartReceive(); } - void start() { - auto endpoints = mResolver.resolve(asio::ip::udp::v4(), mHost, mPort); - mRemoteEndpoint = *endpoints.begin(); - mSocket.open(asio::ip::udp::v4()); - Utils::log("UDP Client ready to send to " + mRemoteEndpoint.address().to_string() + ":" + std::to_string(mRemoteEndpoint.port())); - } - - void sendMessage(const std::string& data = "") { - UDPPacketHeader hdr{}; - randombytes_buf(hdr.nonce.data(), hdr.nonce.size()); - - if (mAesKeyRef == nullptr || mSessionIDRef == nullptr) { - Utils::error("UDP Client AES key or Session ID reference is null!"); - return; - } - - auto encryptedPayload = Utils::LibSodiumWrapper::encryptMessage( - reinterpret_cast(data.data()), data.size(), - *mAesKeyRef, hdr.nonce, "udp-data" - ); - - std::vector packet; - packet.reserve(sizeof(UDPPacketHeader) + sizeof(uint64_t) + encryptedPayload.size()); - packet.insert(packet.end(), - reinterpret_cast(&hdr), - reinterpret_cast(&hdr) + sizeof(UDPPacketHeader) - ); - uint64_t sid = *mSessionIDRef; - packet.insert(packet.end(), - reinterpret_cast(&sid), - reinterpret_cast(&sid) + sizeof(sid) - ); - packet.insert(packet.end(), encryptedPayload.begin(), encryptedPayload.end()); - - mSocket.send_to(asio::buffer(packet), mRemoteEndpoint); - Utils::log("Sent UDP packet of size " + std::to_string(packet.size())); - } - - void stop() { - if (mSocket.is_open()) { - asio::error_code ec; - mSocket.cancel(ec); - mSocket.close(ec); - Utils::log("UDP Client socket closed."); - } - } + void start(); + void sendMessage(const std::string& data = ""); + void stop(); private: - void mStartReceive() { - mSocket.async_receive_from( - asio::buffer(mRecvBuffer), mRemoteEndpoint, - [this](asio::error_code ec, std::size_t bytes) { - if (ec) { - if (ec == asio::error::operation_aborted) return; // Socket closed - // Other recv error - mStartReceive(); - return; - } - - if (bytes > 0) { - mHandlePacket(bytes); - } - - mStartReceive(); - } - ); - } - - void mHandlePacket(std::size_t bytes) { - if (bytes < sizeof(UDPPacketHeader) + sizeof(uint64_t)) { - Utils::warn("UDP Client received packet too small to process."); - return; - } - - // Parse header - UDPPacketHeader hdr; - std::memcpy(&hdr, mRecvBuffer.data(), sizeof(UDPPacketHeader)); - - // Parse session ID - uint64_t sessionID; - std::memcpy(&sessionID, mRecvBuffer.data() + sizeof(UDPPacketHeader), sizeof(uint64_t)); - - // Decrypt payload - std::vector ciphertext( - mRecvBuffer.begin() + sizeof(UDPPacketHeader) + sizeof(uint64_t), - mRecvBuffer.begin() + bytes - ); - - if (mAesKeyRef == nullptr) { - Utils::error("UDP Client AES key reference is null!"); - return; - } - - std::vector plaintext = Utils::LibSodiumWrapper::decryptMessage( - ciphertext.data(), ciphertext.size(), *mAesKeyRef, hdr.nonce, "udp-data" - ); - - if (plaintext.empty()) { - Utils::warn("UDP Client failed to decrypt received packet."); - return; - } - - Utils::log("UDP Client received packet from " + mRemoteEndpoint.address().to_string() + " - Packet size: " + std::to_string(bytes)); - } + void mStartReceive(); + void mHandlePacket(std::size_t bytes); asio::ip::udp::socket mSocket; asio::ip::udp::resolver mResolver; diff --git a/include/columnlynx/server/net/tcp/tcp_connection.hpp b/include/columnlynx/server/net/tcp/tcp_connection.hpp index acc7969..7681b97 100644 --- a/include/columnlynx/server/net/tcp/tcp_connection.hpp +++ b/include/columnlynx/server/net/tcp/tcp_connection.hpp @@ -32,56 +32,13 @@ namespace ColumnLynx::Net::TCP { return conn; } - void start() { - mHandler->onMessage([this](AnyMessageType type, const std::string& data) { - mHandleMessage(static_cast(MessageHandler::toUint8(type)), data); - }); + void start(); + void sendMessage(ServerMessageType type, const std::string& data = ""); + void setDisconnectCallback(std::function)> cb); + void disconnect(); - mHandler->onDisconnect([this](const asio::error_code& ec) { - Utils::log("Client disconnected: " + mHandler->socket().remote_endpoint().address().to_string() + " - " + ec.message()); - disconnect(); - }); - - mHandler->start(); - mStartHeartbeat(); - - // Placeholder for message handling setup - Utils::log("Client connected: " + mHandler->socket().remote_endpoint().address().to_string()); - } - - void sendMessage(ServerMessageType type, const std::string& data = "") { - if (mHandler) { - mHandler->sendMessage(type, data); - } - } - - void setDisconnectCallback(std::function)> cb) { - mOnDisconnect = std::move(cb); - } - - void disconnect() { - std::string ip = mHandler->socket().remote_endpoint().address().to_string(); - - mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect."); - mHeartbeatTimer.cancel(); - asio::error_code ec; - mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - mHandler->socket().close(ec); - - Utils::log("Closed connection to " + ip); - - if (mOnDisconnect) { - mOnDisconnect(shared_from_this()); - } - } - - uint64_t getSessionID() const { - return mConnectionSessionID; - } - - std::array getAESKey() const { - return mConnectionAESKey; - } + uint64_t getSessionID() const; + std::array getAESKey() const; private: TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper* sodiumWrapper) @@ -93,164 +50,8 @@ namespace ColumnLynx::Net::TCP { mLastHeartbeatSent(std::chrono::steady_clock::now()) {} - void mStartHeartbeat() { - auto self = shared_from_this(); - mHeartbeatTimer.expires_after(std::chrono::seconds(5)); - mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) { - if (ec == asio::error::operation_aborted) { - return; // Timer was cancelled - } - - auto now = std::chrono::steady_clock::now(); - auto elapsed = std::chrono::duration_cast(now - self->mLastHeartbeatReceived).count(); - - if (elapsed >= 15) { // 3 missed heartbeats - Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting."); - - // Remove socket forcefully, client is dead - asio::error_code ec; - mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - mHandler->socket().close(ec); - - SessionRegistry::getInstance().erase(self->mConnectionSessionID); - - return; - } - - self->sendMessage(ServerMessageType::HEARTBEAT); - Utils::log("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID)); - self->mLastHeartbeatSent = now; - - self->mStartHeartbeat(); // Recursive - }); - } - - void mHandleMessage(ClientMessageType type, const std::string& data) { - std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string(); - - switch (type) { - case ClientMessageType::HANDSHAKE_INIT: { - Utils::log("Received HANDSHAKE_INIT from " + reqAddr); - - if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) { - Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short."); - disconnect(); - return; - } - - uint8_t clientProtoVer = static_cast(data[0]); - if (clientProtoVer != Utils::protocolVersion()) { - Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " + - std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + "."); - disconnect(); - return; - } - - Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + "."); - - std::memcpy(mConnectionPublicKey.data(), data.data() + 1, std::min(data.size() - 1, sizeof(mConnectionPublicKey))); // Store the client's public key (for identification) - mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist - break; - } - case ClientMessageType::HANDSHAKE_CHALLENGE: { - Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr); - - // Convert to byte array - uint8_t challengeData[32]; - std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData))); - - // Sign the challenge - Signature sig = Utils::LibSodiumWrapper::signMessage( - challengeData, sizeof(challengeData), - mLibSodiumWrapper->getPrivateKey() - ); - - mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response - break; - } - case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: { - Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr); - - // Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext) - if (data.size() < 24) { // Minimum size check (nonce) - Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short."); - disconnect(); - return; - } - - AsymNonce nonce{}; - std::memcpy(nonce.data(), data.data(), nonce.size()); - std::vector ciphertext(data.size() - nonce.size()); - std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size()); - try { - std::array arrayPrivateKey; - std::copy(mLibSodiumWrapper->getXPrivateKey(), - mLibSodiumWrapper->getXPrivateKey() + 32, - arrayPrivateKey.begin()); - - // Decrypt the AES key using the client's public key and server's private key - std::vector decrypted = Utils::LibSodiumWrapper::decryptAsymmetric( - ciphertext.data(), ciphertext.size(), - nonce, - mConnectionPublicKey, - arrayPrivateKey - ); - - if (decrypted.size() != 32) { - Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size."); - disconnect(); - return; - } - - std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size()); - - // Make a Session ID - randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID)); - - // TODO: Make the session ID little-endian for network transmission - - // Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose) - Nonce symNonce{}; // All zeros - std::vector encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage( - reinterpret_cast(&mConnectionSessionID), sizeof(mConnectionSessionID), - mConnectionAESKey, symNonce - ); - - mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size())); - - // Add to session registry - Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned."); - auto session = std::make_shared(mConnectionAESKey, std::chrono::hours(12)); - SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session)); - - } catch (const std::exception& e) { - Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what()); - disconnect(); - } - - break; - } - case ClientMessageType::HEARTBEAT: { - Utils::log("Received HEARTBEAT from " + reqAddr); - mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK - break; - } - case ClientMessageType::HEARTBEAT_ACK: { - Utils::log("Received HEARTBEAT_ACK from " + reqAddr); - mLastHeartbeatReceived = std::chrono::steady_clock::now(); - mMissedHeartbeats = 0; // Reset missed heartbeat count - break; - } - case ClientMessageType::GRACEFUL_DISCONNECT: { - Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data); - disconnect(); - break; - } - default: - Utils::warn("Unhandled message type from " + reqAddr); - break; - } - } + void mStartHeartbeat(); + void mHandleMessage(ClientMessageType type, const std::string& data); std::shared_ptr mHandler; std::function)> mOnDisconnect; diff --git a/src/client/main.cpp b/src/client/main.cpp index 1aff47b..7c6edd1 100644 --- a/src/client/main.cpp +++ b/src/client/main.cpp @@ -74,7 +74,8 @@ int main(int argc, char** argv) { log("Client connected to " + host + ":" + port); // Client is running - while ((!done && client->isConnected()) || !client->isHandshakeComplete()) { + // TODO: SIGINT or SIGTERM seems to not kill this instantly! + while (client->isConnected() || !client->isHandshakeComplete() || !done) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Temp wait if (client->isHandshakeComplete()) { diff --git a/src/client/net/tcp/tcp_client.cpp b/src/client/net/tcp/tcp_client.cpp new file mode 100644 index 0000000..7532bc8 --- /dev/null +++ b/src/client/net/tcp/tcp_client.cpp @@ -0,0 +1,238 @@ +// tcp_client.cpp - TCP Client for ColumnLynx +// Copyright (C) 2025 DcruBro +// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details. + +#include + +namespace ColumnLynx::Net::TCP { + void TCPClient::start() { + auto self = shared_from_this(); + mResolver.async_resolve(mHost, mPort, + [this, self](asio::error_code ec, tcp::resolver::results_type endpoints) { + if (!ec) { + asio::async_connect(mSocket, endpoints, + [this, self](asio::error_code ec, const tcp::endpoint&) { + if (!NetHelper::isExpectedDisconnect(ec)) { + mConnected = true; + Utils::log("Client connected."); + mHandler = std::make_shared(std::move(mSocket)); + mHandler->onMessage([this](AnyMessageType type, const std::string& data) { + mHandleMessage(static_cast(MessageHandler::toUint8(type)), data); + }); + mHandler->start(); + + // Init connection handshake + Utils::log("Sending handshake init to server."); + + std::vector payload; + payload.reserve(1 + crypto_box_PUBLICKEYBYTES); + payload.push_back(Utils::protocolVersion()); + payload.insert(payload.end(), + mLibSodiumWrapper->getXPublicKey(), + mLibSodiumWrapper->getXPublicKey() + crypto_box_PUBLICKEYBYTES + ); + + mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(payload.data(), payload.size())); + + mStartHeartbeat(); + } else { + Utils::error("Client connect failed: " + ec.message()); + } + }); + } else { + Utils::error("Client resolve failed: " + ec.message()); + } + }); + } + + void TCPClient::sendMessage(ClientMessageType type, const std::string& data) { + if (!mConnected) { + Utils::error("Cannot send message, client not connected."); + return; + } + + if (mHandler) { + asio::post(mHandler->socket().get_executor(), [self = shared_from_this(), type, data]() { + self->mHandler->sendMessage(type, data); + }); + } + } + + void TCPClient::disconnect(bool echo) { + if (mConnected && mHandler) { + if (echo) { + mHandler->sendMessage(ClientMessageType::GRACEFUL_DISCONNECT, "Goodbye"); + } + + asio::error_code ec; + mHeartbeatTimer.cancel(); + + mHandler->socket().shutdown(tcp::socket::shutdown_both, ec); + if (ec) { + Utils::error("Error during socket shutdown: " + ec.message()); + } + + mHandler->socket().close(ec); + if (ec) { + Utils::error("Error during socket close: " + ec.message()); + } + + mConnected = false; + Utils::log("Client disconnected."); + } + } + + bool TCPClient::isHandshakeComplete() const { + return mHandshakeComplete; + } + + bool TCPClient::isConnected() const { + return mConnected; + } + + void TCPClient::mStartHeartbeat() { + auto self = shared_from_this(); + mHeartbeatTimer.expires_after(std::chrono::seconds(5)); + mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) { + if (ec == asio::error::operation_aborted) { + return; // Timer was cancelled + } + + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(now - self->mLastHeartbeatReceived).count(); + + if (elapsed >= 15) { // 3 missed heartbeats + Utils::error("Missed 3 heartbeats. I think the other party might have died! Disconnecting."); + + // Close sockets forcefully, server is dead + asio::error_code ec; + mHandler->socket().shutdown(tcp::socket::shutdown_both, ec); + mHandler->socket().close(ec); + mConnected = false; + + mGlobalKeyRef = nullptr; + if (mSessionIDRef) { + *mSessionIDRef = 0; + } + + return; + } + + self->sendMessage(ClientMessageType::HEARTBEAT); + Utils::log("Sent HEARTBEAT to server."); + self->mLastHeartbeatSent = std::chrono::steady_clock::now(); + + self->mStartHeartbeat(); // Recursive + }); + } + + void TCPClient::mHandleMessage(ServerMessageType type, const std::string& data) { + switch (type) { + case ServerMessageType::HANDSHAKE_IDENTIFY: + Utils::log("Received server identity: " + data); + std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey))); + + // Generate and send challenge + { + Utils::log("Sending challenge to server."); + mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later + mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge)); + } + + break; + case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE: + Utils::log("Received challenge response from server."); + { + // Verify the signature + Signature sig{}; + std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size())); + if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) { + Utils::log("Challenge response verified successfully."); + + // Convert the server's public key to Curve25519 for encryption + AsymPublicKey serverXPubKey{}; + crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey); + + // Generate AES key and send confirmation + mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit(); + if (mGlobalKeyRef) { // Copy to the global reference + std::copy(mConnectionAESKey.begin(), mConnectionAESKey.end(), mGlobalKeyRef->begin()); + } + AsymNonce nonce{}; + randombytes_buf(nonce.data(), nonce.size()); + + // TODO: This is pretty redundant, it should return the required type directly + std::array arrayPrivateKey; + std::copy(mLibSodiumWrapper->getXPrivateKey(), + mLibSodiumWrapper->getXPrivateKey() + 32, + arrayPrivateKey.begin()); + + std::vector encr = Utils::LibSodiumWrapper::encryptAsymmetric( + mConnectionAESKey.data(), mConnectionAESKey.size(), + nonce, + serverXPubKey, + arrayPrivateKey + ); + + std::vector payload; + payload.reserve(nonce.size() + encr.size()); + payload.insert(payload.end(), nonce.begin(), nonce.end()); + payload.insert(payload.end(), encr.begin(), encr.end()); + + mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size())); + } else { + Utils::error("Challenge response verification failed. Terminating connection."); + disconnect(); + } + } + + break; + case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM: + Utils::log("Received handshake exchange key confirmation from server."); + // Decrypt the session ID using the established AES key + { + Nonce symNonce{}; // All zeros + std::vector ciphertext(data.begin(), data.end()); + std::vector decrypted = Utils::LibSodiumWrapper::decryptMessage( + ciphertext.data(), ciphertext.size(), + mConnectionAESKey, symNonce + ); + + if (decrypted.size() != sizeof(mConnectionSessionID)) { + Utils::error("Decrypted session ID has invalid size. Terminating connection."); + disconnect(); + return; + } + + std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID)); + Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID)); + + if (mSessionIDRef) { // Copy to the global reference + *mSessionIDRef = mConnectionSessionID; + } + + mHandshakeComplete = true; + } + + break; + case ServerMessageType::HEARTBEAT: + Utils::log("Received HEARTBEAT from server."); + mHandler->sendMessage(ClientMessageType::HEARTBEAT_ACK, ""); // Send ACK + break; + case ServerMessageType::HEARTBEAT_ACK: + Utils::log("Received HEARTBEAT_ACK from server."); + mLastHeartbeatReceived = std::chrono::steady_clock::now(); + mMissedHeartbeats = 0; // Reset missed heartbeat count + break; + case ServerMessageType::GRACEFUL_DISCONNECT: + Utils::log("Server is disconnecting: " + data); + if (mConnected) { // Prevent Recursion + disconnect(false); + } + break; + default: + Utils::log("Received unknown message type from server."); + break; + } + } +} \ No newline at end of file diff --git a/src/client/net/udp/udp_client.cpp b/src/client/net/udp/udp_client.cpp new file mode 100644 index 0000000..c5cdc5c --- /dev/null +++ b/src/client/net/udp/udp_client.cpp @@ -0,0 +1,111 @@ +// udp_client.cpp - UDP Client for ColumnLynx +// Copyright (C) 2025 DcruBro +// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details. + +#include + +namespace ColumnLynx::Net::UDP { + void UDPClient::start() { + auto endpoints = mResolver.resolve(asio::ip::udp::v4(), mHost, mPort); + mRemoteEndpoint = *endpoints.begin(); + mSocket.open(asio::ip::udp::v4()); + Utils::log("UDP Client ready to send to " + mRemoteEndpoint.address().to_string() + ":" + std::to_string(mRemoteEndpoint.port())); + } + + void UDPClient::sendMessage(const std::string& data) { + UDPPacketHeader hdr{}; + randombytes_buf(hdr.nonce.data(), hdr.nonce.size()); + + if (mAesKeyRef == nullptr || mSessionIDRef == nullptr) { + Utils::error("UDP Client AES key or Session ID reference is null!"); + return; + } + + auto encryptedPayload = Utils::LibSodiumWrapper::encryptMessage( + reinterpret_cast(data.data()), data.size(), + *mAesKeyRef, hdr.nonce, "udp-data" + ); + + std::vector packet; + packet.reserve(sizeof(UDPPacketHeader) + sizeof(uint64_t) + encryptedPayload.size()); + packet.insert(packet.end(), + reinterpret_cast(&hdr), + reinterpret_cast(&hdr) + sizeof(UDPPacketHeader) + ); + uint64_t sid = *mSessionIDRef; + packet.insert(packet.end(), + reinterpret_cast(&sid), + reinterpret_cast(&sid) + sizeof(sid) + ); + packet.insert(packet.end(), encryptedPayload.begin(), encryptedPayload.end()); + + mSocket.send_to(asio::buffer(packet), mRemoteEndpoint); + Utils::log("Sent UDP packet of size " + std::to_string(packet.size())); + } + + void UDPClient::stop() { + if (mSocket.is_open()) { + asio::error_code ec; + mSocket.cancel(ec); + mSocket.close(ec); + Utils::log("UDP Client socket closed."); + } + } + + void UDPClient::mStartReceive() { + mSocket.async_receive_from( + asio::buffer(mRecvBuffer), mRemoteEndpoint, + [this](asio::error_code ec, std::size_t bytes) { + if (ec) { + if (ec == asio::error::operation_aborted) return; // Socket closed + // Other recv error + mStartReceive(); + return; + } + + if (bytes > 0) { + mHandlePacket(bytes); + } + + mStartReceive(); + } + ); + } + + void UDPClient::mHandlePacket(std::size_t bytes) { + if (bytes < sizeof(UDPPacketHeader) + sizeof(uint64_t)) { + Utils::warn("UDP Client received packet too small to process."); + return; + } + + // Parse header + UDPPacketHeader hdr; + std::memcpy(&hdr, mRecvBuffer.data(), sizeof(UDPPacketHeader)); + + // Parse session ID + uint64_t sessionID; + std::memcpy(&sessionID, mRecvBuffer.data() + sizeof(UDPPacketHeader), sizeof(uint64_t)); + + // Decrypt payload + std::vector ciphertext( + mRecvBuffer.begin() + sizeof(UDPPacketHeader) + sizeof(uint64_t), + mRecvBuffer.begin() + bytes + ); + + if (mAesKeyRef == nullptr) { + Utils::error("UDP Client AES key reference is null!"); + return; + } + + std::vector plaintext = Utils::LibSodiumWrapper::decryptMessage( + ciphertext.data(), ciphertext.size(), *mAesKeyRef, hdr.nonce, "udp-data" + ); + + if (plaintext.empty()) { + Utils::warn("UDP Client failed to decrypt received packet."); + return; + } + + Utils::log("UDP Client received packet from " + mRemoteEndpoint.address().to_string() + " - Packet size: " + std::to_string(bytes)); + } +} \ No newline at end of file diff --git a/src/server/server/net/tcp/tcp_connection.cpp b/src/server/server/net/tcp/tcp_connection.cpp new file mode 100644 index 0000000..196a7cd --- /dev/null +++ b/src/server/server/net/tcp/tcp_connection.cpp @@ -0,0 +1,217 @@ +// tcp_connection.cpp - TCP Connection for ColumnLynx +// Copyright (C) 2025 DcruBro +// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details. + +#include + +namespace ColumnLynx::Net::TCP { + void TCPConnection::start() { + mHandler->onMessage([this](AnyMessageType type, const std::string& data) { + mHandleMessage(static_cast(MessageHandler::toUint8(type)), data); + }); + + mHandler->onDisconnect([this](const asio::error_code& ec) { + Utils::log("Client disconnected: " + mHandler->socket().remote_endpoint().address().to_string() + " - " + ec.message()); + disconnect(); + }); + + mHandler->start(); + mStartHeartbeat(); + + // Placeholder for message handling setup + Utils::log("Client connected: " + mHandler->socket().remote_endpoint().address().to_string()); + } + + void TCPConnection::sendMessage(ServerMessageType type, const std::string& data) { + if (mHandler) { + mHandler->sendMessage(type, data); + } + } + + void TCPConnection::setDisconnectCallback(std::function)> cb) { + mOnDisconnect = std::move(cb); + } + + void TCPConnection::disconnect() { + std::string ip = mHandler->socket().remote_endpoint().address().to_string(); + + mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect."); + mHeartbeatTimer.cancel(); + asio::error_code ec; + mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + mHandler->socket().close(ec); + + Utils::log("Closed connection to " + ip); + + if (mOnDisconnect) { + mOnDisconnect(shared_from_this()); + } + } + + uint64_t TCPConnection::getSessionID() const { + return mConnectionSessionID; + } + + std::array TCPConnection::getAESKey() const { + return mConnectionAESKey; + } + + void TCPConnection::mStartHeartbeat() { + auto self = shared_from_this(); + mHeartbeatTimer.expires_after(std::chrono::seconds(5)); + mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) { + if (ec == asio::error::operation_aborted) { + return; // Timer was cancelled + } + + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(now - self->mLastHeartbeatReceived).count(); + + if (elapsed >= 15) { // 3 missed heartbeats + Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting."); + + // Remove socket forcefully, client is dead + asio::error_code ec; + mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + mHandler->socket().close(ec); + + SessionRegistry::getInstance().erase(self->mConnectionSessionID); + + return; + } + + self->sendMessage(ServerMessageType::HEARTBEAT); + Utils::log("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID)); + self->mLastHeartbeatSent = now; + + self->mStartHeartbeat(); // Recursive + }); + } + + void TCPConnection::mHandleMessage(ClientMessageType type, const std::string& data) { + std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string(); + + switch (type) { + case ClientMessageType::HANDSHAKE_INIT: { + Utils::log("Received HANDSHAKE_INIT from " + reqAddr); + + if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) { + Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short."); + disconnect(); + return; + } + + uint8_t clientProtoVer = static_cast(data[0]); + if (clientProtoVer != Utils::protocolVersion()) { + Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " + + std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + "."); + disconnect(); + return; + } + + Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + "."); + + std::memcpy(mConnectionPublicKey.data(), data.data() + 1, std::min(data.size() - 1, sizeof(mConnectionPublicKey))); // Store the client's public key (for identification) + mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist + break; + } + case ClientMessageType::HANDSHAKE_CHALLENGE: { + Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr); + + // Convert to byte array + uint8_t challengeData[32]; + std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData))); + + // Sign the challenge + Signature sig = Utils::LibSodiumWrapper::signMessage( + challengeData, sizeof(challengeData), + mLibSodiumWrapper->getPrivateKey() + ); + + mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response + break; + } + case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: { + Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr); + + // Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext) + if (data.size() < 24) { // Minimum size check (nonce) + Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short."); + disconnect(); + return; + } + + AsymNonce nonce{}; + std::memcpy(nonce.data(), data.data(), nonce.size()); + std::vector ciphertext(data.size() - nonce.size()); + std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size()); + try { + std::array arrayPrivateKey; + std::copy(mLibSodiumWrapper->getXPrivateKey(), + mLibSodiumWrapper->getXPrivateKey() + 32, + arrayPrivateKey.begin()); + + // Decrypt the AES key using the client's public key and server's private key + std::vector decrypted = Utils::LibSodiumWrapper::decryptAsymmetric( + ciphertext.data(), ciphertext.size(), + nonce, + mConnectionPublicKey, + arrayPrivateKey + ); + + if (decrypted.size() != 32) { + Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size."); + disconnect(); + return; + } + + std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size()); + + // Make a Session ID + randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID)); + + // TODO: Make the session ID little-endian for network transmission + + // Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose) + Nonce symNonce{}; // All zeros + std::vector encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage( + reinterpret_cast(&mConnectionSessionID), sizeof(mConnectionSessionID), + mConnectionAESKey, symNonce + ); + + mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size())); + + // Add to session registry + Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned."); + auto session = std::make_shared(mConnectionAESKey, std::chrono::hours(12)); + SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session)); + + } catch (const std::exception& e) { + Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what()); + disconnect(); + } + + break; + } + case ClientMessageType::HEARTBEAT: { + Utils::log("Received HEARTBEAT from " + reqAddr); + mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK + break; + } + case ClientMessageType::HEARTBEAT_ACK: { + Utils::log("Received HEARTBEAT_ACK from " + reqAddr); + mLastHeartbeatReceived = std::chrono::steady_clock::now(); + mMissedHeartbeats = 0; // Reset missed heartbeat count + break; + } + case ClientMessageType::GRACEFUL_DISCONNECT: { + Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data); + disconnect(); + break; + } + default: + Utils::warn("Unhandled message type from " + reqAddr); + break; + } + } +} \ No newline at end of file