// tcp_connection.cpp - TCP Connection for ColumnLynx // Copyright (C) 2025 DcruBro // Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details. #include namespace ColumnLynx::Net::TCP { void TCPConnection::start() { try { // Cache the remote IP early to avoid calling remote_endpoint() on closed sockets later mRemoteIP = mHandler->socket().remote_endpoint().address().to_string(); } catch (const std::exception& e) { mRemoteIP = "unknown"; Utils::warn("Failed to get remote endpoint: " + std::string(e.what())); } mHandler->onMessage([this](AnyMessageType type, const std::string& data) { mHandleMessage(static_cast(MessageHandler::toUint8(type)), data); }); mHandler->onDisconnect([this](const asio::error_code& ec) { // Peer has closed; finalize locally without sending RST Utils::log("Client disconnected: " + mRemoteIP + " - " + ec.message()); asio::error_code ec2; mHandler->socket().close(ec2); SessionRegistry::getInstance().erase(mConnectionSessionID); SessionRegistry::getInstance().deallocIP(mConnectionSessionID); Utils::log("Closed connection to " + mRemoteIP); if (mOnDisconnect) { mOnDisconnect(shared_from_this()); } }); mHandler->start(); mStartHeartbeat(); // Placeholder for message handling setup Utils::log("Client connected: " + mRemoteIP); } void TCPConnection::sendMessage(ServerMessageType type, const std::string& data) { if (mHandler) { mHandler->sendMessage(type, data); } } void TCPConnection::setDisconnectCallback(std::function)> cb) { mOnDisconnect = std::move(cb); } void TCPConnection::disconnect(bool echo) { if (echo) { mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect."); } mHeartbeatTimer.cancel(); asio::error_code ec; // Half-close: stop sending, keep reading until peer FIN mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_send, ec); if (ec) { Utils::error("Error during socket shutdown: " + ec.message()); } // Do not close immediately; final cleanup happens in onDisconnect Utils::log("Initiated graceful disconnect (half-close) to " + mRemoteIP); } uint64_t TCPConnection::getSessionID() const { return mConnectionSessionID; } std::array TCPConnection::getAESKey() const { return mConnectionAESKey; } void TCPConnection::mStartHeartbeat() { auto self = shared_from_this(); mHeartbeatTimer.expires_after(std::chrono::seconds(5)); mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) { if (ec == asio::error::operation_aborted) { return; // Timer was cancelled } auto now = std::chrono::steady_clock::now(); auto elapsed = std::chrono::duration_cast(now - self->mLastHeartbeatReceived).count(); if (elapsed >= 15) { // 3 missed heartbeats Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting."); // Remove socket forcefully, client is dead asio::error_code ec; mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); mHandler->socket().close(ec); SessionRegistry::getInstance().erase(self->mConnectionSessionID); return; } self->sendMessage(ServerMessageType::HEARTBEAT); Utils::debug("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID)); self->mLastHeartbeatSent = now; self->mStartHeartbeat(); // Recursive }); } void TCPConnection::mHandleMessage(ClientMessageType type, const std::string& data) { std::string& reqAddr = mRemoteIP; switch (type) { case ClientMessageType::HANDSHAKE_INIT: { Utils::log("Received HANDSHAKE_INIT from " + reqAddr); if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) { Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short."); disconnect(); return; } uint8_t clientProtoVer = static_cast(data[0]); if (clientProtoVer != Utils::protocolVersion()) { Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " + std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + "."); disconnect(); return; } Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + "."); PublicKey signPk; std::memcpy(signPk.data(), data.data() + 1, std::min(data.size() - 1, sizeof(signPk))); // We can safely store this without further checking, the client will need to send the encrypted AES key in a way where they must possess the corresponding private key anyways. int r = crypto_sign_ed25519_pk_to_curve25519(mConnectionPublicKey.data(), signPk.data()); // Store the client's public encryption key key (for identification) if (r != 0) { Utils::error("Conversion of client signing key to encryption key failed! Killing connection from " + reqAddr); disconnect(); return; } Utils::debug("Client " + reqAddr + " converted public encryption key: " + Utils::bytesToHexString(mConnectionPublicKey.data(), 32)); Utils::debug("Key attempted connect: " + Utils::bytesToHexString(signPk.data(), signPk.size())); std::vector whitelistedKeys = Utils::getWhitelistedKeys(mConfigDirPath); if (std::find(whitelistedKeys.begin(), whitelistedKeys.end(), Utils::bytesToHexString(signPk.data(), signPk.size())) == whitelistedKeys.end()) { Utils::warn("Non-whitelisted client attempted to connect, terminating. Client IP: " + reqAddr); disconnect(); return; } Utils::debug("Client " + reqAddr + " passed authorized_keys"); mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist break; } case ClientMessageType::HANDSHAKE_CHALLENGE: { Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr); // Convert to byte array uint8_t challengeData[32]; std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData))); // Sign the challenge Signature sig = Utils::LibSodiumWrapper::signMessage( challengeData, sizeof(challengeData), mLibSodiumWrapper->getPrivateKey() ); mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response break; } case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: { Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr); // Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext) if (data.size() < 24) { // Minimum size check (nonce) Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short."); disconnect(); return; } AsymNonce nonce{}; std::memcpy(nonce.data(), data.data(), nonce.size()); std::vector ciphertext(data.size() - nonce.size()); std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size()); try { std::array arrayPrivateKey; std::copy(mLibSodiumWrapper->getXPrivateKey(), mLibSodiumWrapper->getXPrivateKey() + 32, arrayPrivateKey.begin()); // Decrypt the AES key using the client's public key and server's private key std::vector decrypted = Utils::LibSodiumWrapper::decryptAsymmetric( ciphertext.data(), ciphertext.size(), nonce, mConnectionPublicKey, arrayPrivateKey ); if (decrypted.size() != 32) { Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size."); disconnect(); return; } std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size()); // Make a Session ID randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID)); // Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose) Nonce symNonce{}; // All zeros std::string networkString = mRawServerConfig->find("NETWORK")->second; // The load check guarantees that this value exists uint8_t configMask = std::stoi(mRawServerConfig->find("SUBNET_MASK")->second); // Same deal here uint32_t baseIP = Net::VirtualInterface::stringToIpv4(networkString); if (baseIP == 0) { Utils::warn("Your NETWORK value in the server configuration is malformed! I will not be able to accept connections! (Connection " + reqAddr + " was killed)"); disconnect(); return; } uint32_t clientIP = SessionRegistry::getInstance().getFirstAvailableIP(baseIP, configMask); if (clientIP == 0) { Utils::warn("Out of available IPs! Disconnecting client " + reqAddr); disconnect(); return; } Protocol::TunConfig tunConfig{}; tunConfig.version = Utils::protocolVersion(); tunConfig.prefixLength = 24; tunConfig.mtu = 1420; tunConfig.serverIP = htonl(baseIP + 1); // e.g. 10.10.0.1 tunConfig.clientIP = htonl(clientIP); // e.g. 10.10.0.X tunConfig.dns1 = htonl(0x08080808); // 8.8.8.8 tunConfig.dns2 = 0; uint64_t sessionIDNet = Utils::chtobe64(mConnectionSessionID); std::vector payload(sizeof(uint64_t) + sizeof(tunConfig)); std::memcpy(payload.data(), &sessionIDNet, sizeof(uint64_t)); std::memcpy(payload.data() + sizeof(uint64_t), &tunConfig, sizeof(tunConfig)); std::vector encryptedPayload = Utils::LibSodiumWrapper::encryptMessage( payload.data(), payload.size(), mConnectionAESKey, symNonce ); mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedPayload.data(), encryptedPayload.size())); // Add to session registry Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned (" + std::to_string(mConnectionSessionID) + ")."); auto session = std::make_shared(mConnectionAESKey, std::chrono::hours(12), clientIP, htonl(0x0A0A0001), mConnectionSessionID); SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session)); SessionRegistry::getInstance().lockIP(mConnectionSessionID, clientIP); } catch (const std::exception& e) { Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what()); disconnect(); } break; } case ClientMessageType::HEARTBEAT: { Utils::debug("Received HEARTBEAT from " + reqAddr); mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK break; } case ClientMessageType::HEARTBEAT_ACK: { Utils::debug("Received HEARTBEAT_ACK from " + reqAddr); mLastHeartbeatReceived = std::chrono::steady_clock::now(); mMissedHeartbeats = 0; // Reset missed heartbeat count break; } case ClientMessageType::GRACEFUL_DISCONNECT: { Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data); disconnect(); break; } case ClientMessageType::KILL_CONNECTION: { Utils::warn("Received KILL_CONNECTION from " + reqAddr + ": " + data); disconnect(); break; } default: Utils::warn("Unhandled message type from " + reqAddr); break; } } }