Files
columnlynx/src/client/net/tcp/tcp_client.cpp
2025-12-29 18:02:42 +01:00

294 lines
14 KiB
C++

// tcp_client.cpp - TCP Client for ColumnLynx
// Copyright (C) 2025 DcruBro
// Distributed under the terms of the GNU General Public License, either version 2 only or version 3. See LICENSES/ for details.
#include <columnlynx/client/net/tcp/tcp_client.hpp>
//#include <arpa/inet.h>
namespace ColumnLynx::Net::TCP {
void TCPClient::start() {
auto self = shared_from_this();
mResolver.async_resolve(mHost, mPort,
[this, self](asio::error_code ec, tcp::resolver::results_type endpoints) {
if (!ec) {
asio::async_connect(mSocket, endpoints,
[this, self](asio::error_code ec, const tcp::endpoint&) {
if (!ec) {
mConnected = true;
Utils::log("Client connected.");
mHandler = std::make_shared<MessageHandler>(std::move(mSocket));
mHandler->onMessage([this](AnyMessageType type, const std::string& data) {
mHandleMessage(static_cast<ServerMessageType>(MessageHandler::toUint8(type)), data);
});
// Close only after peer FIN to avoid RSTs
mHandler->onDisconnect([this](const asio::error_code& ec) {
asio::error_code ec2;
if (mHandler) {
mHandler->socket().close(ec2);
}
mConnected = false;
Utils::log(std::string("Server disconnected: ") + ec.message());
});
mHandler->start();
// Init connection handshake
Utils::log("Sending handshake init to server.");
// Check if hostname or IPv4/IPv6
try {
asio::ip::make_address(mHost);
self->mIsHostDomain = false; // IPv4 or IPv6 literal
} catch (const asio::system_error&) {
self->mIsHostDomain = true; // hostname / domain
}
std::vector<uint8_t> payload;
payload.reserve(1 + crypto_box_PUBLICKEYBYTES);
payload.push_back(Utils::protocolVersion());
/*payload.insert(payload.end(),
mLibSodiumWrapper->getXPublicKey(),
mLibSodiumWrapper->getXPublicKey() + crypto_box_PUBLICKEYBYTES
);*/
payload.insert(payload.end(),
mLibSodiumWrapper->getPublicKey(),
mLibSodiumWrapper->getPublicKey() + crypto_sign_PUBLICKEYBYTES
);
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(payload.data(), payload.size()));
mStartHeartbeat();
} else {
if (!NetHelper::isExpectedDisconnect(ec)) {
Utils::error("Client connect failed: " + ec.message());
}
}
});
} else {
Utils::error("Client resolve failed: " + ec.message());
}
});
}
// Sends a message of the given type and payload to the server.
//
// Logs an error and does nothing when the client is not connected; silently
// does nothing when no handler exists. Otherwise the actual send is posted to
// the executor of the handler's socket, and a shared_ptr to this client is
// held by the posted task until it has run.
void TCPClient::sendMessage(ClientMessageType type, const std::string& data) {
    if (!mConnected) {
        Utils::error("Cannot send message, client not connected.");
        return;
    }

    if (!mHandler) {
        return;
    }

    // `data` is copied into the task because the caller's string may be gone
    // by the time the executor runs it.
    auto keepAlive = shared_from_this();
    asio::post(mHandler->socket().get_executor(),
        [keepAlive, type, data]() {
            keepAlive->mHandler->sendMessage(type, data);
        });
}
void TCPClient::disconnect(bool echo) {
if (mConnected && mHandler) {
if (echo) {
mHandler->sendMessage(ClientMessageType::GRACEFUL_DISCONNECT, "Goodbye");
}
asio::error_code ec;
mHeartbeatTimer.cancel();
// Half-close: stop sending, keep reading until peer FIN
mHandler->socket().shutdown(tcp::socket::shutdown_send, ec);
if (ec) {
Utils::error("Error during socket shutdown: " + ec.message());
}
// Do not close immediately; rely on onDisconnect to finalize
Utils::log("Client initiated graceful disconnect (half-close).");
}
}
// Returns mHandshakeComplete, which mHandleMessage() sets to true after it has
// processed a valid HANDSHAKE_EXCHANGE_KEY_CONFIRM message.
// NOTE(review): none of the code in this file resets the flag on disconnect -
// confirm whether callers should also check isConnected().
bool TCPClient::isHandshakeComplete() const {
return mHandshakeComplete;
}
// Returns mConnected, which start() sets to true once the TCP connect succeeds
// and which is cleared by the onDisconnect callback and by the heartbeat
// timeout path in mStartHeartbeat().
bool TCPClient::isConnected() const {
return mConnected;
}
// Arms the heartbeat timer for one 5-second tick and re-arms itself from the
// completion handler.
//
// On each tick:
//   - a cancelled timer (operation_aborted) ends the loop;
//   - if at least 15 seconds have passed since mLastHeartbeatReceived (which
//     mHandleMessage() refreshes on HEARTBEAT_ACK), the server is treated as
//     dead: the socket is shut down and closed, mConnected is cleared,
//     mGlobalKeyRef is set to nullptr and the shared session ID (if any) is
//     zeroed, and the loop ends;
//   - otherwise a HEARTBEAT is sent, mLastHeartbeatSent is updated and the
//     timer is armed again.
void TCPClient::mStartHeartbeat() {
    const auto heartbeatInterval = std::chrono::seconds(5);

    auto self = shared_from_this();
    mHeartbeatTimer.expires_after(heartbeatInterval);
    mHeartbeatTimer.async_wait([this, self](const asio::error_code& waitEc) {
        if (waitEc == asio::error::operation_aborted) {
            return; // Timer was cancelled
        }

        // 3 missed heartbeats at the 5-second interval
        const auto serverTimeout = std::chrono::seconds(15);
        const auto tickTime = std::chrono::steady_clock::now();
        const auto silence = std::chrono::duration_cast<std::chrono::seconds>(tickTime - mLastHeartbeatReceived);

        if (silence >= serverTimeout) {
            Utils::error("Missed 3 heartbeats. I think the other party might have died! Disconnecting.");

            // Close sockets forcefully, server is dead
            asio::error_code closeEc;
            mHandler->socket().shutdown(tcp::socket::shutdown_both, closeEc);
            mHandler->socket().close(closeEc);
            mConnected = false;

            // Drop the shared key reference and zero the shared session ID.
            mGlobalKeyRef = nullptr;
            if (mSessionIDRef) {
                *mSessionIDRef = 0;
            }
            return;
        }

        sendMessage(ClientMessageType::HEARTBEAT);
        Utils::debug("Sent HEARTBEAT to server.");
        mLastHeartbeatSent = std::chrono::steady_clock::now();

        mStartHeartbeat(); // Recursive
    });
}
// Dispatches one message received from the server.
//
// `type` is the server message kind and `data` its raw payload bytes.
// Handshake flow as implemented here:
//   HANDSHAKE_IDENTIFY             -> store the server public key, check it against
//                                     the whitelist, send a random challenge.
//   HANDSHAKE_CHALLENGE_RESPONSE   -> verify the signature over that challenge,
//                                     generate the connection key and send it
//                                     encrypted to the server.
//   HANDSHAKE_EXCHANGE_KEY_CONFIRM -> decrypt session ID + tunnel config, configure
//                                     the TUN device, mark the handshake complete.
// HEARTBEAT / HEARTBEAT_ACK maintain liveness; GRACEFUL_DISCONNECT and
// KILL_CONNECTION trigger a local disconnect without echoing a goodbye.
//
// Payloads come from the network, so lengths are validated before copying:
// a payload shorter than the expected key or signature terminates the connection
// instead of being partially copied.
void TCPClient::mHandleMessage(ServerMessageType type, const std::string& data) {
    switch (type) {
        case ServerMessageType::HANDSHAKE_IDENTIFY: {
            // Reject truncated keys: a partial copy would leave stale bytes in
            // mServerPublicKey and the handshake would continue with a bogus key.
            if (data.size() < sizeof(mServerPublicKey)) {
                Utils::error("Server identity payload is too short. Terminating connection.");
                disconnect();
                return;
            }
            std::memcpy(mServerPublicKey, data.data(), sizeof(mServerPublicKey));

            // The payload is raw key bytes, so log its hex form rather than the
            // binary string.
            const auto serverKeyHex = Utils::bytesToHexString(mServerPublicKey, 32);
            Utils::log("Received server identity: " + serverKeyHex);

            // Verify pubkey against whitelisted_keys. Proof that the server actually
            // holds this key comes later, via the signed challenge.
            std::vector<std::string> whitelistedKeys = Utils::getWhitelistedKeys();
            const bool keyIsWhitelisted =
                std::find(whitelistedKeys.begin(), whitelistedKeys.end(), serverKeyHex) != whitelistedKeys.end();
            if (!keyIsWhitelisted) {
                if (!mInsecureMode) {
                    Utils::error("Server public key not in whitelisted_keys. Terminating connection.");
                    disconnect();
                    return;
                }
                Utils::warn("Server public key not in whitelisted_keys, but continuing due to insecure mode.");
            }

            // Generate and send challenge
            Utils::log("Sending challenge to server.");
            mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later
            mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge));
            break;
        }

        case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE: {
            Utils::log("Received challenge response from server.");

            // Verify the signature
            Signature sig{};
            if (data.size() < sig.size()) {
                // A short payload can never be a valid signature; fail explicitly
                // instead of verifying a zero-padded buffer.
                Utils::error("Challenge response is too short to be a signature. Terminating connection.");
                disconnect();
                return;
            }
            std::memcpy(sig.data(), data.data(), sig.size());

            if (!Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) {
                Utils::error("Challenge response verification failed. Terminating connection.");
                disconnect();
                break;
            }
            Utils::log("Challenge response verified successfully.");

            // Convert the server's public key to Curve25519 for encryption
            AsymPublicKey serverXPubKey{};
            int r = crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey);
            if (r != 0) {
                Utils::error("Failed to convert server signing key to encryption key! Killing connection.");
                disconnect();
                return;
            }

            // Generate AES key and send confirmation
            mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit();
            if (mGlobalKeyRef) { // Copy to the global reference
                std::copy(mConnectionAESKey.begin(), mConnectionAESKey.end(), mGlobalKeyRef->begin());
            }

            AsymNonce nonce{};
            randombytes_buf(nonce.data(), nonce.size());

            // TODO: This is pretty redundant, it should return the required type directly
            std::array<uint8_t, 32> arrayPrivateKey;
            std::copy(mLibSodiumWrapper->getXPrivateKey(),
                      mLibSodiumWrapper->getXPrivateKey() + 32,
                      arrayPrivateKey.begin());

            std::vector<uint8_t> encr = Utils::LibSodiumWrapper::encryptAsymmetric(
                mConnectionAESKey.data(), mConnectionAESKey.size(),
                nonce,
                serverXPubKey,
                arrayPrivateKey
            );

            // Wire format: [nonce][ciphertext]
            std::vector<uint8_t> payload;
            payload.reserve(nonce.size() + encr.size());
            payload.insert(payload.end(), nonce.begin(), nonce.end());
            payload.insert(payload.end(), encr.begin(), encr.end());

            mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size()));
            break;
        }

        case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM: {
            Utils::log("Received handshake exchange key confirmation from server.");

            // Decrypt the session ID using the established AES key
            Nonce symNonce{}; // All zeros
            std::vector<uint8_t> ciphertext(data.begin(), data.end());
            std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptMessage(
                ciphertext.data(), ciphertext.size(),
                mConnectionAESKey, symNonce
            );

            // Plaintext layout: [session ID][Protocol::TunConfig]; anything else is invalid.
            if (decrypted.size() != sizeof(mConnectionSessionID) + sizeof(Protocol::TunConfig)) {
                Utils::error("Decrypted config has invalid size. Terminating connection.");
                disconnect();
                return;
            }

            std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID));
            std::memcpy(&mTunConfig, decrypted.data() + sizeof(mConnectionSessionID), sizeof(Protocol::TunConfig));

            mConnectionSessionID = Utils::cbe64toh(mConnectionSessionID);

            Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID));

            if (mSessionIDRef) { // Copy to the global reference
                *mSessionIDRef = mConnectionSessionID;
            }

            uint32_t clientIP = ntohl(mTunConfig.clientIP);
            uint32_t serverIP = ntohl(mTunConfig.serverIP);
            uint8_t prefixLen = mTunConfig.prefixLength;
            // NOTE(review): the IPs are byte-swapped with ntohl but mtu is used as-is;
            // confirm the server sends mtu in host byte order.
            uint16_t mtu = mTunConfig.mtu;

            if (mTun) {
                mTun->configureIP(clientIP, serverIP, prefixLen, mtu);
            }

            mHandshakeComplete = true;
            break;
        }

        case ServerMessageType::HEARTBEAT:
            Utils::debug("Received HEARTBEAT from server.");
            mHandler->sendMessage(ClientMessageType::HEARTBEAT_ACK, ""); // Send ACK
            break;

        case ServerMessageType::HEARTBEAT_ACK:
            Utils::debug("Received HEARTBEAT_ACK from server.");
            mLastHeartbeatReceived = std::chrono::steady_clock::now();
            mMissedHeartbeats = 0; // Reset missed heartbeat count
            break;

        case ServerMessageType::GRACEFUL_DISCONNECT:
            Utils::log("Server is disconnecting: " + data);
            if (mConnected) { // Prevent Recursion
                disconnect(false);
            }
            break;

        case ServerMessageType::KILL_CONNECTION:
            Utils::warn("Server is killing the connection: " + data);
            if (mConnected) {
                disconnect(false);
            }
            break;

        default:
            Utils::log("Received unknown message type from server.");
            break;
    }
}
}