Almost finished with the TCP Handshake procedure, need to properly handle disconnects (currently pretty forceful)

This commit is contained in:
2025-11-06 22:32:32 +01:00
parent 0f7191ad54
commit c7c3b1c54c
12 changed files with 408 additions and 41 deletions

View File

@@ -9,6 +9,10 @@
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/net/tcp/net_helper.hpp>
#include <columnlynx/common/utils.hpp>
#include <columnlynx/common/libsodium_wrapper.hpp>
#include <array>
#include <algorithm>
#include <vector>
using asio::ip::tcp;
@@ -17,8 +21,9 @@ namespace ColumnLynx::Net::TCP {
public:
TCPClient(asio::io_context& ioContext,
const std::string& host,
const std::string& port)
: mResolver(ioContext), mSocket(ioContext), mHost(host), mPort(port) {}
const std::string& port,
Utils::LibSodiumWrapper* sodiumWrapper)
: mResolver(ioContext), mSocket(ioContext), mHost(host), mPort(port), mLibSodiumWrapper(sodiumWrapper) {}
void start() {
auto self = shared_from_this();
@@ -35,10 +40,10 @@ namespace ColumnLynx::Net::TCP {
mHandleMessage(static_cast<ServerMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->start();
// Init connection handshake
Utils::log("Sending handshake init to server.");
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, "Hello, I am " + Utils::getHostname());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(mLibSodiumWrapper->getXPublicKey(), crypto_box_PUBLICKEYBYTES));
} else {
Utils::error("Client connect failed: " + ec.message());
}
@@ -88,7 +93,81 @@ namespace ColumnLynx::Net::TCP {
switch (type) {
case ServerMessageType::HANDSHAKE_IDENTIFY:
Utils::log("Received server identity: " + data);
//std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
// Generate and send challenge
{
Utils::log("Sending challenge to server.");
mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later
mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge));
}
break;
case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE:
Utils::log("Received challenge response from server.");
{
// Verify the signature
Signature sig{};
std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size()));
if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) {
Utils::log("Challenge response verified successfully.");
// Convert the server's public key to Curve25519 for encryption
AsymPublicKey serverXPubKey{};
crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey);
// Generate AES key and send confirmation
mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit();
AsymNonce nonce{};
randombytes_buf(nonce.data(), nonce.size());
// TODO: This is pretty redundant, it should return the required type directly
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
std::vector<uint8_t> encr = Utils::LibSodiumWrapper::encryptAsymmetric(
mConnectionAESKey.data(), mConnectionAESKey.size(),
nonce,
serverXPubKey,
arrayPrivateKey
);
std::vector<uint8_t> payload;
payload.reserve(nonce.size() + encr.size());
payload.insert(payload.end(), nonce.begin(), nonce.end());
payload.insert(payload.end(), encr.begin(), encr.end());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size()));
} else {
Utils::error("Challenge response verification failed. Terminating connection.");
disconnect();
}
}
break;
case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM:
Utils::log("Received handshake exchange key confirmation from server.");
// Decrypt the session ID using the established AES key
{
Nonce symNonce{}; // All zeros
std::vector<uint8_t> ciphertext(data.begin(), data.end());
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(),
mConnectionAESKey, symNonce
);
if (decrypted.size() != sizeof(mConnectionSessionID)) {
Utils::error("Decrypted session ID has invalid size. Terminating connection.");
disconnect();
return;
}
std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID));
Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID));
}
break;
case ServerMessageType::GRACEFUL_DISCONNECT:
Utils::log("Server is disconnecting: " + data);
@@ -108,5 +187,9 @@ namespace ColumnLynx::Net::TCP {
std::shared_ptr<MessageHandler> mHandler;
std::string mHost, mPort;
uint8_t mServerPublicKey[32]; // Assuming 256-bit public key
std::array<uint8_t, 32> mSubmittedChallenge{};
Utils::LibSodiumWrapper* mLibSodiumWrapper;
uint64_t mConnectionSessionID;
SymmetricKey mConnectionAESKey;
};
}