Almost finished with the TCP Handshake procedure, need to properly handle disconnects (currently pretty forceful)

This commit is contained in:
2025-11-06 22:32:32 +01:00
parent 0f7191ad54
commit c7c3b1c54c
12 changed files with 408 additions and 41 deletions

View File

@@ -9,6 +9,7 @@
#include <string>
#include <ctime>
#include <cstdint>
#include <new>
#include <asio/asio.hpp>
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/net/tcp/tcp_message_handler.hpp>
@@ -22,10 +23,10 @@ namespace ColumnLynx::Net::TCP {
static pointer create(
asio::ip::tcp::socket socket,
Utils::LibSodiumWrapper *libsodium,
Utils::LibSodiumWrapper* sodiumWrapper,
std::function<void(pointer)> onDisconnect)
{
auto conn = pointer(new TCPConnection(std::move(socket), libsodium));
auto conn = pointer(new TCPConnection(std::move(socket), sodiumWrapper));
conn->mOnDisconnect = std::move(onDisconnect);
return conn;
}
@@ -71,16 +72,88 @@ namespace ColumnLynx::Net::TCP {
}
private:
TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper *libsodium)
: mHandler(std::make_shared<MessageHandler>(std::move(socket))), mLibSodiumWrapper(libsodium) {}
TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper* sodiumWrapper)
: mHandler(std::make_shared<MessageHandler>(std::move(socket))), mLibSodiumWrapper(sodiumWrapper) {}
void mHandleMessage(ClientMessageType type, const std::string& data) {
std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string();
switch (type) {
case ClientMessageType::HANDSHAKE_INIT: {
Utils::log("Received HANDSHAKE_INIT from " + reqAddr + ": " + data);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, std::string(reinterpret_cast<const char*>(mLibSodiumWrapper->getPublicKey())));
Utils::log("Received HANDSHAKE_INIT from " + reqAddr);
std::memcpy(mConnectionPublicKey.data(), data.data(), std::min(data.size(), sizeof(mConnectionPublicKey))); // Store the client's public key (for identification)
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist
break;
}
case ClientMessageType::HANDSHAKE_CHALLENGE: {
Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr);
// Convert to byte array
uint8_t challengeData[32];
std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData)));
// Sign the challenge
Signature sig = Utils::LibSodiumWrapper::signMessage(
challengeData, sizeof(challengeData),
mLibSodiumWrapper->getPrivateKey()
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response
break;
}
case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: {
Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr);
// Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext)
if (data.size() < 24) { // Minimum size check (nonce)
Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short.");
disconnect();
return;
}
AsymNonce nonce{};
std::memcpy(nonce.data(), data.data(), nonce.size());
std::vector<uint8_t> ciphertext(data.size() - nonce.size());
std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size());
try {
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
// Decrypt the AES key using the client's public key and server's private key
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptAsymmetric(
ciphertext.data(), ciphertext.size(),
nonce,
mConnectionPublicKey,
arrayPrivateKey
);
if (decrypted.size() != 32) {
Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size.");
disconnect();
return;
}
std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size());
// Make a Session ID
randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID));
// Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose)
Nonce symNonce{}; // All zeros
std::vector<uint8_t> encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<uint8_t*>(&mConnectionSessionID), sizeof(mConnectionSessionID),
mConnectionAESKey, symNonce
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size()));
} catch (const std::exception& e) {
Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what());
disconnect();
}
break;
}
case ClientMessageType::GRACEFUL_DISCONNECT: {
@@ -97,5 +170,8 @@ namespace ColumnLynx::Net::TCP {
std::shared_ptr<MessageHandler> mHandler;
std::function<void(std::shared_ptr<TCPConnection>)> mOnDisconnect;
Utils::LibSodiumWrapper *mLibSodiumWrapper;
std::array<uint8_t, 32> mConnectionAESKey;
uint64_t mConnectionSessionID;
AsymPublicKey mConnectionPublicKey;
};
}

View File

@@ -10,6 +10,7 @@
#include <vector>
#include <memory>
#include <unordered_set>
#include <new>
#include <asio/asio.hpp>
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/utils.hpp>