Version b0.1 (Beta 0.1) - Refactor some stuff and move session_registry to .cpp file

This commit is contained in:
2025-12-04 15:48:18 +01:00
parent 668b96e8e0
commit 842752cd88
7 changed files with 117 additions and 85 deletions

View File

@@ -0,0 +1,280 @@
// tcp_connection.cpp - TCP Connection for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/server/net/tcp/tcp_connection.hpp>
namespace ColumnLynx::Net::TCP {
void TCPConnection::start() {
mHandler->onMessage([this](AnyMessageType type, const std::string& data) {
mHandleMessage(static_cast<ClientMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->onDisconnect([this](const asio::error_code& ec) {
Utils::log("Client disconnected: " + mHandler->socket().remote_endpoint().address().to_string() + " - " + ec.message());
disconnect();
});
mHandler->start();
mStartHeartbeat();
// Placeholder for message handling setup
Utils::log("Client connected: " + mHandler->socket().remote_endpoint().address().to_string());
}
void TCPConnection::sendMessage(ServerMessageType type, const std::string& data) {
if (mHandler) {
mHandler->sendMessage(type, data);
}
}
void TCPConnection::setDisconnectCallback(std::function<void(std::shared_ptr<TCPConnection>)> cb) {
mOnDisconnect = std::move(cb);
}
void TCPConnection::disconnect() {
std::string ip = mHandler->socket().remote_endpoint().address().to_string();
mHandler->sendMessage(ServerMessageType::GRACEFUL_DISCONNECT, "Server initiated disconnect.");
mHeartbeatTimer.cancel();
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
SessionRegistry::getInstance().erase(mConnectionSessionID);
SessionRegistry::getInstance().deallocIP(mConnectionSessionID);
Utils::log("Closed connection to " + ip);
if (mOnDisconnect) {
mOnDisconnect(shared_from_this());
}
}
uint64_t TCPConnection::getSessionID() const {
return mConnectionSessionID;
}
std::array<uint8_t, 32> TCPConnection::getAESKey() const {
return mConnectionAESKey;
}
void TCPConnection::mStartHeartbeat() {
auto self = shared_from_this();
mHeartbeatTimer.expires_after(std::chrono::seconds(5));
mHeartbeatTimer.async_wait([this, self](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted) {
return; // Timer was cancelled
}
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - self->mLastHeartbeatReceived).count();
if (elapsed >= 15) { // 3 missed heartbeats
Utils::error("Missed 3 heartbeats. I think the other party (client " + std::to_string(self->mConnectionSessionID) + ") might have died! Disconnecting.");
// Remove socket forcefully, client is dead
asio::error_code ec;
mHandler->socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
mHandler->socket().close(ec);
SessionRegistry::getInstance().erase(self->mConnectionSessionID);
return;
}
self->sendMessage(ServerMessageType::HEARTBEAT);
Utils::debug("Sent HEARTBEAT to client " + std::to_string(self->mConnectionSessionID));
self->mLastHeartbeatSent = now;
self->mStartHeartbeat(); // Recursive
});
}
void TCPConnection::mHandleMessage(ClientMessageType type, const std::string& data) {
std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string();
switch (type) {
case ClientMessageType::HANDSHAKE_INIT: {
Utils::log("Received HANDSHAKE_INIT from " + reqAddr);
if (data.size() < 1 + crypto_box_PUBLICKEYBYTES) {
Utils::warn("HANDSHAKE_INIT from " + reqAddr + " is too short.");
disconnect();
return;
}
uint8_t clientProtoVer = static_cast<uint8_t>(data[0]);
if (clientProtoVer != Utils::protocolVersion()) {
Utils::warn("Client protocol version mismatch from " + reqAddr + ". Expected " +
std::to_string(Utils::protocolVersion()) + ", got " + std::to_string(clientProtoVer) + ".");
disconnect();
return;
}
Utils::log("Client protocol version " + std::to_string(clientProtoVer) + " accepted from " + reqAddr + ".");
PublicKey signPk;
std::memcpy(signPk.data(), data.data() + 1, std::min(data.size() - 1, sizeof(signPk)));
// We can safely store this without further checking, the client will need to send the encrypted AES key in a way where they must possess the corresponding private key anyways.
int r = crypto_sign_ed25519_pk_to_curve25519(mConnectionPublicKey.data(), signPk.data()); // Store the client's public encryption key key (for identification)
if (r != 0) {
Utils::error("Conversion of client signing key to encryption key failed! Killing connection from " + reqAddr);
disconnect();
return;
}
Utils::debug("Client " + reqAddr + " converted public encryption key: " + Utils::bytesToHexString(mConnectionPublicKey.data(), 32));
Utils::debug("Key attempted connect: " + Utils::bytesToHexString(signPk.data(), signPk.size()));
std::vector<std::string> whitelistedKeys = Utils::getWhitelistedKeys();
if (std::find(whitelistedKeys.begin(), whitelistedKeys.end(), Utils::bytesToHexString(signPk.data(), signPk.size())) == whitelistedKeys.end()) {
Utils::warn("Non-whitelisted client attempted to connect, terminating. Client IP: " + reqAddr);
disconnect();
return;
}
Utils::debug("Client " + reqAddr + " passed authorized_keys");
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist
break;
}
case ClientMessageType::HANDSHAKE_CHALLENGE: {
Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr);
// Convert to byte array
uint8_t challengeData[32];
std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData)));
// Sign the challenge
Signature sig = Utils::LibSodiumWrapper::signMessage(
challengeData, sizeof(challengeData),
mLibSodiumWrapper->getPrivateKey()
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response
break;
}
case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: {
Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr);
// Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext)
if (data.size() < 24) { // Minimum size check (nonce)
Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short.");
disconnect();
return;
}
AsymNonce nonce{};
std::memcpy(nonce.data(), data.data(), nonce.size());
std::vector<uint8_t> ciphertext(data.size() - nonce.size());
std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size());
try {
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
// Decrypt the AES key using the client's public key and server's private key
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptAsymmetric(
ciphertext.data(), ciphertext.size(),
nonce,
mConnectionPublicKey,
arrayPrivateKey
);
if (decrypted.size() != 32) {
Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size.");
disconnect();
return;
}
std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size());
// Make a Session ID
randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID));
// Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose)
Nonce symNonce{}; // All zeros
std::string networkString = mRawServerConfig->find("NETWORK")->second; // The load check guarantees that this value exists
uint8_t configMask = std::stoi(mRawServerConfig->find("SUBNET_MASK")->second); // Same deal here
uint32_t baseIP = Net::VirtualInterface::stringToIpv4(networkString);
if (baseIP == 0) {
Utils::warn("Your NETWORK value in the server configuration is malformed! I will not be able to accept connections! (Connection " + reqAddr + " was killed)");
disconnect();
return;
}
uint32_t clientIP = SessionRegistry::getInstance().getFirstAvailableIP(baseIP, configMask);
if (clientIP == 0) {
Utils::warn("Out of available IPs! Disconnecting client " + reqAddr);
disconnect();
return;
}
Protocol::TunConfig tunConfig{};
tunConfig.version = Utils::protocolVersion();
tunConfig.prefixLength = 24;
tunConfig.mtu = 1420;
tunConfig.serverIP = htonl(baseIP + 1); // e.g. 10.10.0.1
tunConfig.clientIP = htonl(clientIP); // e.g. 10.10.0.X
tunConfig.dns1 = htonl(0x08080808); // 8.8.8.8
tunConfig.dns2 = 0;
uint64_t sessionIDNet = Utils::chtobe64(mConnectionSessionID);
std::vector<uint8_t> payload(sizeof(uint64_t) + sizeof(tunConfig));
std::memcpy(payload.data(), &sessionIDNet, sizeof(uint64_t));
std::memcpy(payload.data() + sizeof(uint64_t), &tunConfig, sizeof(tunConfig));
std::vector<uint8_t> encryptedPayload = Utils::LibSodiumWrapper::encryptMessage(
payload.data(), payload.size(),
mConnectionAESKey, symNonce
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedPayload.data(), encryptedPayload.size()));
// Add to session registry
Utils::log("Handshake with " + reqAddr + " completed successfully. Session ID assigned (" + std::to_string(mConnectionSessionID) + ").");
auto session = std::make_shared<SessionState>(mConnectionAESKey, std::chrono::hours(12), clientIP, htonl(0x0A0A0001), mConnectionSessionID);
SessionRegistry::getInstance().put(mConnectionSessionID, std::move(session));
SessionRegistry::getInstance().lockIP(mConnectionSessionID, clientIP);
} catch (const std::exception& e) {
Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what());
disconnect();
}
break;
}
case ClientMessageType::HEARTBEAT: {
Utils::debug("Received HEARTBEAT from " + reqAddr);
mHandler->sendMessage(ServerMessageType::HEARTBEAT_ACK, ""); // Send ACK
break;
}
case ClientMessageType::HEARTBEAT_ACK: {
Utils::debug("Received HEARTBEAT_ACK from " + reqAddr);
mLastHeartbeatReceived = std::chrono::steady_clock::now();
mMissedHeartbeats = 0; // Reset missed heartbeat count
break;
}
case ClientMessageType::GRACEFUL_DISCONNECT: {
Utils::log("Received GRACEFUL_DISCONNECT from " + reqAddr + ": " + data);
disconnect();
break;
}
default:
Utils::warn("Unhandled message type from " + reqAddr);
break;
}
}
}

View File

@@ -0,0 +1,76 @@
// tcp_server.cpp - TCP Server for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/server/net/tcp/tcp_server.hpp>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include <unordered_set>
#include <asio.hpp>
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/net/tcp/net_helper.hpp>
#include <columnlynx/common/utils.hpp>
namespace ColumnLynx::Net::TCP {
void TCPServer::mStartAccept() {
mAcceptor.async_accept(
[this](asio::error_code ec, asio::ip::tcp::socket socket) {
if (ec) {
if (ec == asio::error::operation_aborted) {
// Acceptor was cancelled/closed during shutdown
return;
}
Utils::error("Accept failed: " + ec.message());
// Try again only if still running
if (mHostRunning && *mHostRunning && mAcceptor.is_open())
mStartAccept();
return;
}
auto client = TCPConnection::create(
std::move(socket),
mSodiumWrapper,
&mRawServerConfig,
[this](std::shared_ptr<TCPConnection> c) {
mClients.erase(c);
Utils::log("Client removed.");
}
);
mClients.insert(client);
client->start();
Utils::log("Accepted new client connection.");
if (mHostRunning && *mHostRunning && mAcceptor.is_open())
mStartAccept();
}
);
}
void TCPServer::stop() {
// Stop accepting
if (mAcceptor.is_open()) {
asio::error_code ec;
mAcceptor.cancel(ec);
mAcceptor.close(ec);
Utils::log("TCP Acceptor closed.");
}
// Snapshot to avoid iterator invalidation while callbacks erase()
std::vector<std::shared_ptr<TCPConnection>> snapshot(mClients.begin(), mClients.end());
for (auto &client : snapshot) {
try {
client->disconnect(); // should shutdown+close the socket
Utils::log("GRACEFUL_DISCONNECT sent to session: " + std::to_string(client->getSessionID()));
} catch (const std::exception& e) {
Utils::error(std::string("Error disconnecting client: ") + e.what());
}
}
// Let the erase callback run as sockets close
// Do NOT destroy server while io handlers may still reference it
}
}

View File

@@ -0,0 +1,128 @@
// udp_server.cpp - UDP Server for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/server/net/udp/udp_server.hpp>
#include <columnlynx/common/libsodium_wrapper.hpp>
#include <columnlynx/common/net/session_registry.hpp>
#include <sodium.h>
#include <memory>
namespace ColumnLynx::Net::UDP {
void UDPServer::mStartReceive() {
mSocket.async_receive_from(
asio::buffer(mRecvBuffer), mRemoteEndpoint,
[this](asio::error_code ec, std::size_t bytes) {
if (ec) {
if (ec == asio::error::operation_aborted) return; // Socket closed
// Other recv error
if (mHostRunning && *mHostRunning) mStartReceive();
return;
}
if (bytes > 0) mHandlePacket(bytes);
if (mHostRunning && *mHostRunning) mStartReceive();
}
);
}
void UDPServer::mHandlePacket(std::size_t bytes) {
if (bytes < sizeof(UDPPacketHeader))
return;
const auto* hdr = reinterpret_cast<UDPPacketHeader*>(mRecvBuffer.data());
// Get plaintext session ID (assuming first 8 bytes after nonce (header))
uint64_t sessionID = 0;
std::memcpy(&sessionID, mRecvBuffer.data() + sizeof(UDPPacketHeader), sizeof(uint64_t));
auto it = mRecvBuffer.begin() + sizeof(UDPPacketHeader) + sizeof(uint64_t);
std::vector<uint8_t> encryptedPayload(it, mRecvBuffer.begin() + bytes);
// Get associated session state
std::shared_ptr<const SessionState> session = SessionRegistry::getInstance().get(sessionID);
if (!session) {
Utils::warn("UDP: Unknown or invalid session from " + mRemoteEndpoint.address().to_string());
return;
}
// Decrypt the actual payload
try {
//Utils::debug("Using AES key " + Utils::bytesToHexString(session->aesKey.data(), 32));
auto plaintext = Utils::LibSodiumWrapper::decryptMessage(
encryptedPayload.data(), encryptedPayload.size(),
session->aesKey,
hdr->nonce, "udp-data"
//std::string(reinterpret_cast<const char*>(&sessionID), sizeof(uint64_t))
);
Utils::debug("Passed decryption");
const_cast<SessionState*>(session.get())->setUDPEndpoint(mRemoteEndpoint); // Update endpoint after confirming decryption
// Update recv counter
const_cast<SessionState*>(session.get())->recv_ctr.fetch_add(1, std::memory_order_relaxed);
// For now, just log the decrypted payload
std::string payloadStr(plaintext.begin(), plaintext.end());
Utils::debug("UDP: Received packet from " + mRemoteEndpoint.address().to_string() + " - Payload: " + payloadStr);
if (mTun) {
mTun->writePacket(plaintext); // Send to virtual interface
}
} catch (const std::exception &ex) {
Utils::warn("UDP: Failed to process payload from " + mRemoteEndpoint.address().to_string() + " Raw Error: '" + ex.what() + "'");
return;
}
}
void UDPServer::sendData(const uint64_t sessionID, const std::string& data) {
// Find the IPv4/IPv6 endpoint for the session
std::shared_ptr<const SessionState> session = SessionRegistry::getInstance().get(sessionID);
if (!session) {
Utils::warn("UDP: Cannot send data, unknown session ID " + std::to_string(sessionID));
return;
}
asio::ip::udp::endpoint endpoint = session->udpEndpoint;
if (endpoint.address().is_unspecified()) {
Utils::warn("UDP: Cannot send data, session ID " + std::to_string(sessionID) + " has no known UDP endpoint.");
return;
}
// Prepare packet
UDPPacketHeader hdr{};
randombytes_buf(hdr.nonce.data(), hdr.nonce.size());
auto encryptedPayload = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<const uint8_t*>(data.data()), data.size(),
session->aesKey, hdr.nonce, "udp-data"
//std::string(reinterpret_cast<const char*>(&sessionID), sizeof(uint64_t))
);
std::vector<uint8_t> packet;
packet.reserve(sizeof(UDPPacketHeader) + sizeof(uint64_t) + encryptedPayload.size());
packet.insert(packet.end(),
reinterpret_cast<uint8_t*>(&hdr),
reinterpret_cast<uint8_t*>(&hdr) + sizeof(UDPPacketHeader)
);
packet.insert(packet.end(),
reinterpret_cast<const uint8_t*>(&sessionID),
reinterpret_cast<const uint8_t*>(&sessionID) + sizeof(sessionID)
);
packet.insert(packet.end(), encryptedPayload.begin(), encryptedPayload.end());
// Send packet
mSocket.send_to(asio::buffer(packet), endpoint);
Utils::debug("UDP: Sent packet of size " + std::to_string(packet.size()) + " to " + std::to_string(sessionID) + " (" + endpoint.address().to_string() + ":" + std::to_string(endpoint.port()) + ")");
}
void UDPServer::stop() {
if (mSocket.is_open()) {
asio::error_code ec;
mSocket.cancel(ec);
mSocket.close(ec);
Utils::log("UDP Socket closed.");
}
}
}