Almost finished with the TCP Handshake procedure, need to properly handle disconnects (currently pretty forceful)

This commit is contained in:
2025-11-06 22:32:32 +01:00
parent 0f7191ad54
commit c7c3b1c54c
12 changed files with 408 additions and 41 deletions

View File

@@ -9,6 +9,10 @@
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/net/tcp/net_helper.hpp>
#include <columnlynx/common/utils.hpp>
#include <columnlynx/common/libsodium_wrapper.hpp>
#include <array>
#include <algorithm>
#include <vector>
using asio::ip::tcp;
@@ -17,8 +21,9 @@ namespace ColumnLynx::Net::TCP {
public:
TCPClient(asio::io_context& ioContext,
const std::string& host,
const std::string& port)
: mResolver(ioContext), mSocket(ioContext), mHost(host), mPort(port) {}
const std::string& port,
Utils::LibSodiumWrapper* sodiumWrapper)
: mResolver(ioContext), mSocket(ioContext), mHost(host), mPort(port), mLibSodiumWrapper(sodiumWrapper) {}
void start() {
auto self = shared_from_this();
@@ -35,10 +40,10 @@ namespace ColumnLynx::Net::TCP {
mHandleMessage(static_cast<ServerMessageType>(MessageHandler::toUint8(type)), data);
});
mHandler->start();
// Init connection handshake
Utils::log("Sending handshake init to server.");
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, "Hello, I am " + Utils::getHostname());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_INIT, Utils::uint8ArrayToString(mLibSodiumWrapper->getXPublicKey(), crypto_box_PUBLICKEYBYTES));
} else {
Utils::error("Client connect failed: " + ec.message());
}
@@ -88,7 +93,81 @@ namespace ColumnLynx::Net::TCP {
switch (type) {
case ServerMessageType::HANDSHAKE_IDENTIFY:
Utils::log("Received server identity: " + data);
//std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
std::memcpy(mServerPublicKey, data.data(), std::min(data.size(), sizeof(mServerPublicKey)));
// Generate and send challenge
{
Utils::log("Sending challenge to server.");
mSubmittedChallenge = Utils::LibSodiumWrapper::generateRandom256Bit(); // Temporarily store the challenge to verify later
mHandler->sendMessage(ClientMessageType::HANDSHAKE_CHALLENGE, Utils::uint8ArrayToString(mSubmittedChallenge));
}
break;
case ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE:
Utils::log("Received challenge response from server.");
{
// Verify the signature
Signature sig{};
std::memcpy(sig.data(), data.data(), std::min(data.size(), sig.size()));
if (Utils::LibSodiumWrapper::verifyMessage(mSubmittedChallenge.data(), mSubmittedChallenge.size(), sig, mServerPublicKey)) {
Utils::log("Challenge response verified successfully.");
// Convert the server's public key to Curve25519 for encryption
AsymPublicKey serverXPubKey{};
crypto_sign_ed25519_pk_to_curve25519(serverXPubKey.data(), mServerPublicKey);
// Generate AES key and send confirmation
mConnectionAESKey = Utils::LibSodiumWrapper::generateRandom256Bit();
AsymNonce nonce{};
randombytes_buf(nonce.data(), nonce.size());
// TODO: This is pretty redundant, it should return the required type directly
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
std::vector<uint8_t> encr = Utils::LibSodiumWrapper::encryptAsymmetric(
mConnectionAESKey.data(), mConnectionAESKey.size(),
nonce,
serverXPubKey,
arrayPrivateKey
);
std::vector<uint8_t> payload;
payload.reserve(nonce.size() + encr.size());
payload.insert(payload.end(), nonce.begin(), nonce.end());
payload.insert(payload.end(), encr.begin(), encr.end());
mHandler->sendMessage(ClientMessageType::HANDSHAKE_EXCHANGE_KEY, Utils::uint8ArrayToString(payload.data(), payload.size()));
} else {
Utils::error("Challenge response verification failed. Terminating connection.");
disconnect();
}
}
break;
case ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM:
Utils::log("Received handshake exchange key confirmation from server.");
// Decrypt the session ID using the established AES key
{
Nonce symNonce{}; // All zeros
std::vector<uint8_t> ciphertext(data.begin(), data.end());
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptMessage(
ciphertext.data(), ciphertext.size(),
mConnectionAESKey, symNonce
);
if (decrypted.size() != sizeof(mConnectionSessionID)) {
Utils::error("Decrypted session ID has invalid size. Terminating connection.");
disconnect();
return;
}
std::memcpy(&mConnectionSessionID, decrypted.data(), sizeof(mConnectionSessionID));
Utils::log("Connection established with Session ID: " + std::to_string(mConnectionSessionID));
}
break;
case ServerMessageType::GRACEFUL_DISCONNECT:
Utils::log("Server is disconnecting: " + data);
@@ -108,5 +187,9 @@ namespace ColumnLynx::Net::TCP {
std::shared_ptr<MessageHandler> mHandler;
std::string mHost, mPort;
uint8_t mServerPublicKey[32]; // Assuming 256-bit public key
std::array<uint8_t, 32> mSubmittedChallenge{};
Utils::LibSodiumWrapper* mLibSodiumWrapper;
uint64_t mConnectionSessionID;
SymmetricKey mConnectionAESKey;
};
}

View File

@@ -9,19 +9,179 @@
#include <string>
#include <cstdint>
#include <columnlynx/common/utils.hpp>
#include <array>
#include <vector>
namespace ColumnLynx {
using PublicKey = std::array<uint8_t, crypto_sign_PUBLICKEYBYTES>; // Ed25519
using PrivateKey = std::array<uint8_t, crypto_sign_SECRETKEYBYTES>; // Ed25519
using Signature = std::array<uint8_t, crypto_sign_BYTES>; // 64 bytes
using SymmetricKey = std::array<uint8_t, crypto_aead_chacha20poly1305_ietf_KEYBYTES>; // 32 bytes
using Nonce = std::array<uint8_t, crypto_aead_chacha20poly1305_ietf_NPUBBYTES>; // 12 bytes
using AsymPublicKey = std::array<uint8_t, crypto_box_PUBLICKEYBYTES>; // 32 bytes
using AsymSecretKey = std::array<uint8_t, crypto_box_SECRETKEYBYTES>; // 32 bytes
using AsymNonce = std::array<uint8_t, crypto_box_NONCEBYTES>; // 24 bytes
}
namespace ColumnLynx::Utils {
class LibSodiumWrapper {
public:
LibSodiumWrapper();
uint8_t* getPublicKey();
uint8_t* getPrivateKey();
uint8_t generateRandomAESKey();
uint8_t* getXPublicKey() { return mXPublicKey.data(); }
uint8_t* getXPrivateKey() { return mXPrivateKey.data(); }
// Helper section
// Generates a random 256-bit (32-byte) array
static std::array<uint8_t, 32> generateRandom256Bit();
static inline Signature signMessage(const uint8_t* msg, size_t len, const PrivateKey& sk) {
Signature sig{};
crypto_sign_detached(sig.data(), nullptr, msg, len, sk.data());
return sig;
}
static inline bool verifyMessage(const uint8_t* msg, size_t len,
const Signature& sig, const PublicKey& pk) {
return crypto_sign_verify_detached(sig.data(), msg, len, pk.data()) == 0;
}
// Overloads for std::string / std::array
static inline Signature signMessage(const std::string& msg, const PrivateKey& sk) {
return signMessage(reinterpret_cast<const uint8_t*>(msg.data()), msg.size(), sk);
}
template <size_t N>
static inline Signature signMessage(const std::array<uint8_t, N>& msg, const PrivateKey& sk) {
return signMessage(msg.data(), msg.size(), sk);
}
static inline Signature signMessage(const uint8_t* msg, size_t len, const uint8_t* sk_raw) {
Signature sig{};
crypto_sign_detached(sig.data(), nullptr, msg, len, sk_raw);
return sig;
}
static inline bool verifyMessage(const std::string& msg, const Signature& sig, const PublicKey& pk) {
return verifyMessage(reinterpret_cast<const uint8_t*>(msg.data()), msg.size(), sig, pk);
}
template <size_t N>
static inline bool verifyMessage(const std::array<uint8_t, N>& msg, const Signature& sig, const PublicKey& pk) {
return verifyMessage(msg.data(), msg.size(), sig, pk);
}
static inline bool verifyMessage(const uint8_t* msg, size_t len,
const Signature& sig, const uint8_t* pk_raw) {
return crypto_sign_verify_detached(sig.data(), msg, len, pk_raw) == 0;
}
// Encrypt with ChaCha20-Poly1305 (returns ciphertext as bytes)
static inline std::vector<uint8_t> encryptMessage(
const uint8_t* plaintext, size_t len,
const SymmetricKey& key, const Nonce& nonce,
const std::string& aad = "")
{
std::vector<uint8_t> ciphertext(len + crypto_aead_chacha20poly1305_ietf_ABYTES);
unsigned long long clen = 0;
if (crypto_aead_chacha20poly1305_ietf_encrypt(
ciphertext.data(), &clen,
plaintext, len,
reinterpret_cast<const unsigned char*>(aad.data()), aad.size(),
nullptr, // no additional secret data
nonce.data(), key.data()) != 0)
{
throw std::runtime_error("Encryption failed");
}
ciphertext.resize(static_cast<size_t>(clen));
return ciphertext;
}
// Decrypt with ChaCha20-Poly1305 (returns plaintext as bytes)
static inline std::vector<uint8_t> decryptMessage(
const uint8_t* ciphertext, size_t len,
const SymmetricKey& key, const Nonce& nonce,
const std::string& aad = "")
{
if (len < crypto_aead_chacha20poly1305_ietf_ABYTES)
throw std::runtime_error("Ciphertext too short");
std::vector<uint8_t> plaintext(len - crypto_aead_chacha20poly1305_ietf_ABYTES);
unsigned long long plen = 0;
if (crypto_aead_chacha20poly1305_ietf_decrypt(
plaintext.data(), &plen,
nullptr,
ciphertext, len,
reinterpret_cast<const unsigned char*>(aad.data()), aad.size(),
nonce.data(), key.data()) != 0)
{
throw std::runtime_error("Decryption failed or authentication tag invalid");
}
plaintext.resize(static_cast<size_t>(plen));
return plaintext;
}
static inline Nonce generateNonce() {
Nonce n{};
randombytes_buf(n.data(), n.size());
return n;
}
static inline std::vector<uint8_t> encryptAsymmetric(
const uint8_t* plaintext, size_t len,
const AsymNonce& nonce,
const AsymPublicKey& recipient_pk,
const AsymSecretKey& sender_sk)
{
std::vector<uint8_t> ciphertext(len + crypto_box_MACBYTES);
if (crypto_box_easy(
ciphertext.data(),
plaintext, len,
nonce.data(),
recipient_pk.data(), sender_sk.data()) != 0)
{
throw std::runtime_error("Asymmetric encryption failed");
}
return ciphertext;
}
static inline std::vector<uint8_t> decryptAsymmetric(
const uint8_t* ciphertext, size_t len,
const AsymNonce& nonce,
const AsymPublicKey& sender_pk,
const AsymSecretKey& recipient_sk)
{
if (len < crypto_box_MACBYTES)
throw std::runtime_error("Ciphertext too short");
std::vector<uint8_t> plaintext(len - crypto_box_MACBYTES);
if (crypto_box_open_easy(
plaintext.data(),
ciphertext, len,
nonce.data(),
sender_pk.data(), recipient_sk.data()) != 0)
{
throw std::runtime_error("Asymmetric decryption failed");
}
return plaintext;
}
private:
uint8_t mPublicKey[crypto_kx_PUBLICKEYBYTES];
uint8_t mPrivateKey[crypto_kx_SECRETKEYBYTES];
std::array<uint8_t, crypto_sign_PUBLICKEYBYTES> mPublicKey;
std::array<uint8_t, crypto_sign_SECRETKEYBYTES> mPrivateKey;
std::array<uint8_t, crypto_scalarmult_curve25519_BYTES> mXPublicKey;
std::array<uint8_t, crypto_scalarmult_curve25519_BYTES> mXPrivateKey;
};
}

View File

@@ -11,7 +11,7 @@ namespace ColumnLynx::Net::TCP {
enum class ServerMessageType : uint8_t { // Server to Client
HANDSHAKE_IDENTIFY = 0x02, // Send server identity (public key, server name, etc)
HANDSHAKE_CHALLENGE_RESPONSE = 0x04, // Response to client's challenge
HANDSHAKE_EXCHANGE_KEY = 0x06, // If accepted, send encrypted AES key and session ID
HANDSHAKE_EXCHANGE_KEY_CONFIRM = 0x06, // If accepted, send encrypted AES key and session ID
GRACEFUL_DISCONNECT = 0xFE, // Notify client of impending disconnection
KILL_CONNECTION = 0xFF, // Forecefully terminate the connection (with cleanup if possible), reserved for unrecoverable errors
@@ -20,8 +20,7 @@ namespace ColumnLynx::Net::TCP {
enum class ClientMessageType : uint8_t { // Client to Server
HANDSHAKE_INIT = 0xA1, // Request connection
HANDSHAKE_CHALLENGE = 0xA3, // Challenge ownership of private key
HANDSHAKE_CONFIRM = 0xA5, // Accept or reject identity, can kill the connection
HANDSHAKE_EXCHANGE_KEY_CONFIRM = 0xA7, // Confirm receipt of AES key and session ID
HANDSHAKE_EXCHANGE_KEY = 0xA5, // Accept or reject identity, can kill the connection, also sends the AES key
GRACEFUL_DISCONNECT = 0xFE, // Notify server of impending disconnection
KILL_CONNECTION = 0xFF, // Forecefully terminate the connection (with cleanup if possible), reserved for unrecoverable errors

View File

@@ -22,4 +22,17 @@ namespace ColumnLynx::Utils {
std::string getHostname();
std::string getVersion();
unsigned short serverPort();
// Raw byte to hex string conversion helper
std::string bytesToHexString(const uint8_t* bytes, size_t length);
// uint8_t to raw string conversion helper
template <size_t N>
inline std::string uint8ArrayToString(const std::array<uint8_t, N>& arr) {
return std::string(reinterpret_cast<const char*>(arr.data()), N);
}
inline std::string uint8ArrayToString(const uint8_t* data, size_t length) {
return std::string(reinterpret_cast<const char*>(data), length);
}
};

View File

@@ -9,6 +9,7 @@
#include <string>
#include <ctime>
#include <cstdint>
#include <new>
#include <asio/asio.hpp>
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/net/tcp/tcp_message_handler.hpp>
@@ -22,10 +23,10 @@ namespace ColumnLynx::Net::TCP {
static pointer create(
asio::ip::tcp::socket socket,
Utils::LibSodiumWrapper *libsodium,
Utils::LibSodiumWrapper* sodiumWrapper,
std::function<void(pointer)> onDisconnect)
{
auto conn = pointer(new TCPConnection(std::move(socket), libsodium));
auto conn = pointer(new TCPConnection(std::move(socket), sodiumWrapper));
conn->mOnDisconnect = std::move(onDisconnect);
return conn;
}
@@ -71,16 +72,88 @@ namespace ColumnLynx::Net::TCP {
}
private:
TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper *libsodium)
: mHandler(std::make_shared<MessageHandler>(std::move(socket))), mLibSodiumWrapper(libsodium) {}
TCPConnection(asio::ip::tcp::socket socket, Utils::LibSodiumWrapper* sodiumWrapper)
: mHandler(std::make_shared<MessageHandler>(std::move(socket))), mLibSodiumWrapper(sodiumWrapper) {}
void mHandleMessage(ClientMessageType type, const std::string& data) {
std::string reqAddr = mHandler->socket().remote_endpoint().address().to_string();
switch (type) {
case ClientMessageType::HANDSHAKE_INIT: {
Utils::log("Received HANDSHAKE_INIT from " + reqAddr + ": " + data);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, std::string(reinterpret_cast<const char*>(mLibSodiumWrapper->getPublicKey())));
Utils::log("Received HANDSHAKE_INIT from " + reqAddr);
std::memcpy(mConnectionPublicKey.data(), data.data(), std::min(data.size(), sizeof(mConnectionPublicKey))); // Store the client's public key (for identification)
mHandler->sendMessage(ServerMessageType::HANDSHAKE_IDENTIFY, Utils::uint8ArrayToString(mLibSodiumWrapper->getPublicKey(), crypto_sign_PUBLICKEYBYTES)); // This public key should always exist
break;
}
case ClientMessageType::HANDSHAKE_CHALLENGE: {
Utils::log("Received HANDSHAKE_CHALLENGE from " + reqAddr);
// Convert to byte array
uint8_t challengeData[32];
std::memcpy(challengeData, data.data(), std::min(data.size(), sizeof(challengeData)));
// Sign the challenge
Signature sig = Utils::LibSodiumWrapper::signMessage(
challengeData, sizeof(challengeData),
mLibSodiumWrapper->getPrivateKey()
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_CHALLENGE_RESPONSE, Utils::uint8ArrayToString(sig.data(), sig.size())); // Placeholder response
break;
}
case ClientMessageType::HANDSHAKE_EXCHANGE_KEY: {
Utils::log("Received HANDSHAKE_EXCHANGE_KEY from " + reqAddr);
// Extract encrypted AES key and nonce (nonce is the first 24 bytes, rest is the ciphertext)
if (data.size() < 24) { // Minimum size check (nonce)
Utils::warn("HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " is too short.");
disconnect();
return;
}
AsymNonce nonce{};
std::memcpy(nonce.data(), data.data(), nonce.size());
std::vector<uint8_t> ciphertext(data.size() - nonce.size());
std::memcpy(ciphertext.data(), data.data() + nonce.size(), ciphertext.size());
try {
std::array<uint8_t, 32> arrayPrivateKey;
std::copy(mLibSodiumWrapper->getXPrivateKey(),
mLibSodiumWrapper->getXPrivateKey() + 32,
arrayPrivateKey.begin());
// Decrypt the AES key using the client's public key and server's private key
std::vector<uint8_t> decrypted = Utils::LibSodiumWrapper::decryptAsymmetric(
ciphertext.data(), ciphertext.size(),
nonce,
mConnectionPublicKey,
arrayPrivateKey
);
if (decrypted.size() != 32) {
Utils::warn("Decrypted HANDSHAKE_EXCHANGE_KEY from " + reqAddr + " has invalid size.");
disconnect();
return;
}
std::memcpy(mConnectionAESKey.data(), decrypted.data(), decrypted.size());
// Make a Session ID
randombytes_buf(&mConnectionSessionID, sizeof(mConnectionSessionID));
// Encrypt the Session ID with the established AES key (using symmetric encryption, nonce can be all zeros for this purpose)
Nonce symNonce{}; // All zeros
std::vector<uint8_t> encryptedSessionID = Utils::LibSodiumWrapper::encryptMessage(
reinterpret_cast<uint8_t*>(&mConnectionSessionID), sizeof(mConnectionSessionID),
mConnectionAESKey, symNonce
);
mHandler->sendMessage(ServerMessageType::HANDSHAKE_EXCHANGE_KEY_CONFIRM, Utils::uint8ArrayToString(encryptedSessionID.data(), encryptedSessionID.size()));
} catch (const std::exception& e) {
Utils::error("Failed to decrypt HANDSHAKE_EXCHANGE_KEY from " + reqAddr + ": " + e.what());
disconnect();
}
break;
}
case ClientMessageType::GRACEFUL_DISCONNECT: {
@@ -97,5 +170,8 @@ namespace ColumnLynx::Net::TCP {
std::shared_ptr<MessageHandler> mHandler;
std::function<void(std::shared_ptr<TCPConnection>)> mOnDisconnect;
Utils::LibSodiumWrapper *mLibSodiumWrapper;
std::array<uint8_t, 32> mConnectionAESKey;
uint64_t mConnectionSessionID;
AsymPublicKey mConnectionPublicKey;
};
}

View File

@@ -10,6 +10,7 @@
#include <vector>
#include <memory>
#include <unordered_set>
#include <new>
#include <asio/asio.hpp>
#include <columnlynx/common/net/tcp/tcp_message_type.hpp>
#include <columnlynx/common/utils.hpp>