AuroraRuntime/Source/IO/TLS/TLSContext.cpp

586 lines
18 KiB
C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: TLSContext.cpp
Date: 2022-8-24
Author: Reece
***/
#include "TLS.hpp"
#include "TLSContext.hpp"
#include <Source/IO/Protocol/AuProtocolStack.hpp>
#include <Aurora/IO/Net/NetExperimental.hpp>
#include <Source/IO/Net/AuNetSocket.hpp>
#include <Source/Crypto/X509/x509.hpp>
#include "TLSCertificateChain.hpp"
namespace Aurora::IO::TLS
{
// Process-wide entropy source; TLSInit() feeds it into gCtrDrbg via mbedtls_entropy_func.
mbedtls_entropy_context gEntropy;
// Process-wide CTR-DRBG, seeded by TLSInit(). TLSContext::Init() hands it to mbedtls
// as the RNG for the SSL config, DTLS cookies and session tickets.
mbedtls_ctr_drbg_context gCtrDrbg;
// Set to true by TLSInit() only after the DRBG was seeded successfully.
static bool gTlsReady {};
AuString TLSErrorToString(int iError)
{
    // Render an mbedtls error code as human-readable text.
    // The buffer is zero-initialised so it is a valid (empty) C string even before
    // mbedtls_strerror writes into it.
    char szBuffer[1024] {};
    ::mbedtls_strerror(iError, szBuffer, sizeof(szBuffer));
    return AuString(szBuffer);
}
void TLSInit()
{
    // Personalization bytes mixed into the DRBG seed (12 bytes, terminator excluded).
    static const unsigned char kPersonalization[] = "ReeceWasHere";

    // One-time setup of the shared RNG used by every TLS context.
    ::mbedtls_ctr_drbg_init(&gCtrDrbg);
    ::mbedtls_entropy_init(&gEntropy);

    const int iRet = ::mbedtls_ctr_drbg_seed(&gCtrDrbg,
                                             ::mbedtls_entropy_func,
                                             &gEntropy,
                                             kPersonalization,
                                             sizeof(kPersonalization) - 1);
    if (iRet != 0)
    {
        // Leave gTlsReady false: the DRBG is not usable.
        SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
        return;
    }

    gTlsReady = true;
}
static int TLSContextRecv(void *ctx,
unsigned char *buf,
size_t len)
{
return ((TLSContext *)ctx)->Read(buf, len);
}
static int TLSContextSend(void *ctx,
const unsigned char *buf,
size_t len)
{
return ((TLSContext *)ctx)->Write(buf, len);
}
TLSContext::TLSContext(const TLSMeta &meta) :
channelRecv_(this),
channelSend_(this),
meta_(meta)
{
this->pRecvStack_ = AuMakeShared<Protocol::ProtocolStack>();
this->pSendStack_ = AuMakeShared<Protocol::ProtocolStack>();
}
TLSContext::TLSContext(const AuSPtr<Protocol::IProtocolStack> &pSendStack,
const AuSPtr<Protocol::IProtocolStack> &pRecvStack,
const TLSMeta &meta) :
channelRecv_(this),
channelSend_(this),
pRecvStack_(AuStaticCast<Protocol::ProtocolStack>(pRecvStack)),
pSendStack_(AuStaticCast<Protocol::ProtocolStack>(pSendStack)),
meta_(meta)
{
}
TLSContext::~TLSContext()
{
    // Free the mbedtls state and detach from any attached socket.
    Destroy();
}
//
// mbedtls nonblocking interface
//
//
int TLSContext::Write(const void *pIn, AuUInt length)
{
if (auto pPiece = this->pPiece_.lock())
{
AuUInt count {};
if (Aurora::IO::EStreamError::eErrorNone !=
pPiece->ToNextWriter()->Write(AuMemoryViewStreamRead { AuMemoryViewRead { pIn, length }, count }))
{
SysPushErrorIO("TLS couldn't flush write into next protocol layer or drain");
return -1;
}
return count;
}
return this->pSendStack_->pDrainBuffer->Write(pIn, length);
}
// mbedtls receive callback body: copies up to `length` bytes of inbound ciphertext
// from the receive channel's current input buffer into pOut.
//
// Returns the number of bytes copied, or MBEDTLS_ERR_SSL_WANT_READ when no input
// buffer is bound or it is empty (so mbedtls retries once more data arrives).
int TLSContext::Read(void *pOut, AuUInt length)
{
    auto pReadBuffer = this->channelRecv_.pReadInByteBuffer.lock();
    if (!pReadBuffer)
    {
        return MBEDTLS_ERR_SSL_WANT_READ;
    }

    auto uBytesRead = pReadBuffer->Read(pOut, length);
    if (!uBytesRead)
    {
        return MBEDTLS_ERR_SSL_WANT_READ;
    }

    // Let the receive channel know how much of its buffer mbedtls consumed.
    this->channelRecv_.bHasRead = true;
    this->channelRecv_.uBytesRead += uBytesRead;
    return uBytesRead;
}
bool TLSContext::CheckCertificate(mbedtls_x509_crt const *child, const AuMemoryViewRead &read)
{
    // Without a user-supplied pin there is nothing to check: accept.
    if (!this->meta_.pCertPin)
    {
        return true;
    }

    auto pChain = AuMakeShared<CertificateChain>();
    if (!pChain)
    {
        SysPushErrorMemory();
        return false;
    }

    // Wrap the mbedtls certificate and ask the pin whether it is acceptable.
    pChain->Init(child);
    bool bAccepted = this->meta_.pCertPin->CheckCertificate(pChain, read);

    // Detach the borrowed certificate before the wrapper goes away.
    // NOTE(review): presumably so CertificateChain does not free mbedtls-owned memory; confirm.
    pChain->pCertificate = nullptr;
    return bAccepted;
}
//
// tls context
//
//
bool TLSContext::Init()
{
int iRet;
if (!this->pSendStack_)
{
return false;
}
if (!this->pRecvStack_)
{
return false;
}
auto pPiece = this->pSendStack_->AddInterceptorEx(AuIO::Protocol::EProtocolWhere::eAppend, this->GetSendInterceptor(), this->meta_.uOutPageSize);
if (!pPiece)
{
SysPushErrorNet("Couldn't add TLS interceptor");
return false;
}
this->pPiece_ = pPiece;
if (!this->pRecvStack_->AddInterceptorEx(AuIO::Protocol::EProtocolWhere::eAppend, this->GetRecvInterceptor(), this->meta_.uOutPageSize))
{
SysPushErrorNet("Couldn't add TLS interceptor");
return false;
}
::mbedtls_ssl_init(&this->ssl);
::mbedtls_ssl_config_init(&this->conf);
if ((iRet = ::mbedtls_ssl_config_defaults(&this->conf,
this->meta_.bIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
this->meta_.transportProtocol == AuNet::ETransportProtocol::eProtocolUDP ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
{
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
return false;
}
if (this->meta_.bIsClient)
{
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
}
else
{
if (this->meta_.server.bPinServerPeers)
{
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
}
else
{
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
}
}
::mbedtls_ssl_conf_verify(&this->conf, [](void *p_ctx, mbedtls_x509_crt *crt,
int depth, uint32_t *flags)
{
*flags &= ~MBEDTLS_X509_BADCERT_NOT_TRUSTED;
if (depth != 0)
{
return 0;
}
((TLSContext *)p_ctx)->CheckCertificate(crt, { crt->raw.p, crt->raw.len }) ? 0 : -1;
return 0;
}, this);
//
::mbedtls_ssl_conf_ca_cb(&this->conf, [](void *p_ctx,
mbedtls_x509_crt const *child,
mbedtls_x509_crt **candidate_cas) -> int
{
return 0;// ((TLSContext *)p_ctx)->CheckCertificate(child, { child->raw.p, child->raw.len }) ? 0 : -1;
}, this);
::mbedtls_ssl_conf_rng(&this->conf, mbedtls_ctr_drbg_random, &gCtrDrbg);
::mbedtls_ssl_conf_dbg(&this->conf, [](void *, int, const char *as, int, const char *ad)
{
//AuLogDbg("{} <--> {}", as, ad);
}, nullptr);
if ((iRet = ::mbedtls_ssl_setup(&this->ssl, &this->conf)) != 0)
{
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
return false;
}
if (this->meta_.bIsClient)
{
if (this->meta_.client.sSNIServerName.size())
{
if ((iRet = ::mbedtls_ssl_set_hostname(&this->ssl, this->meta_.client.sSNIServerName.c_str())) != 0)
{
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
return false;
}
}
}
if (!this->meta_.bIsClient)
{
if (this->meta_.server.bSessionCache)
{
#if defined(MBEDTLS_SSL_CACHE_C)
::mbedtls_ssl_cache_init(&this->cache_);
if (this->meta_.server.iCacheMax != -1)
{
::mbedtls_ssl_cache_set_max_entries(&this->cache_, this->meta_.server.iCacheMax);
}
if (this->meta_.server.iCacheTimeout)
{
::mbedtls_ssl_cache_set_timeout(&this->cache_, this->meta_.server.iCacheTimeout);
}
::mbedtls_ssl_conf_session_cache(&this->conf,
&this->cache_,
mbedtls_ssl_cache_get,
mbedtls_ssl_cache_set);
#endif
}
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
::mbedtls_ssl_ticket_init(&this->ticketCtx_);
#endif
if (this->meta_.transportProtocol == Net::ETransportProtocol::eProtocolUDP)
{
#if defined(MBEDTLS_SSL_COOKIE_C)
::mbedtls_ssl_cookie_init(&this->cookieCtx_);
#endif
#if defined(MBEDTLS_SSL_COOKIE_C)
if (this->meta_.dtls.iServerCookies > 0)
{
if ((iRet = ::mbedtls_ssl_cookie_setup(&this->cookieCtx_,
mbedtls_ctr_drbg_random,
&gCtrDrbg)) != 0)
{
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
return false;
}
::mbedtls_ssl_conf_dtls_cookies(&conf,
mbedtls_ssl_cookie_write,
mbedtls_ssl_cookie_check,
&this->cookieCtx_);
}
else
#endif
{
#if defined(MBEDTLS_SSL_DTLS_HELLO_VERIFY)
if (this->meta_.dtls.iServerCookies == 0)
{
::mbedtls_ssl_conf_dtls_cookies(&conf, NULL, NULL, NULL);
}
#endif
}
#if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
::mbedtls_ssl_conf_dtls_anti_replay(&this->conf,
this->meta_.dtls.iServerCookies ? MBEDTLS_SSL_ANTI_REPLAY_ENABLED : MBEDTLS_SSL_ANTI_REPLAY_DISABLED);
#endif
::mbedtls_ssl_conf_dtls_badmac_limit(&this->conf,
this->meta_.dtls.iServerBadMacLimit);
}
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
if (this->meta_.server.bEnableTickets)
{
if ((iRet = ::mbedtls_ssl_ticket_setup(&this->ticketCtx_,
mbedtls_ctr_drbg_random,
&gCtrDrbg,
MBEDTLS_CIPHER_AES_256_GCM,
this->meta_.server.iTicketTimeout)) != 0)
{
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
return false;
}
::mbedtls_ssl_conf_session_tickets_cb(&this->conf,
mbedtls_ssl_ticket_write,
mbedtls_ssl_ticket_parse,
&this->ticketCtx_);
}
#endif
}
if (this->meta_.transportProtocol == Net::ETransportProtocol::eProtocolUDP)
{
if (this->meta_.dtls.iMTUSize)
{
::mbedtls_ssl_set_mtu(&this->ssl,
this->meta_.dtls.iMTUSize);
}
::mbedtls_ssl_set_timer_cb(&this->ssl,
&this->timer_,
mbedtls_timing_set_delay,
mbedtls_timing_get_delay);
}
::mbedtls_ssl_set_bio(&this->ssl,
this,
TLSContextSend,
TLSContextRecv,
nullptr);
if (this->meta_.cipherSuites.size())
{
this->cipherSuites_.reserve(this->meta_.cipherSuites.size());
for (const auto &cipher : this->meta_.cipherSuites)
{
if (!AuTryInsert(this->cipherSuites_, cipher))
{
SysPushErrorMemory();
return false;
}
}
}
else
{
auto &defaultCiphers = GetDefaultCipherSuites();
this->cipherSuites_.reserve(defaultCiphers.size());
for (const auto &cipher : defaultCiphers)
{
if (!AuTryInsert(this->cipherSuites_, cipher))
{
SysPushErrorMemory();
return false;
}
}
}
if (!AuTryInsert(this->cipherSuites_, 0))
{
SysPushErrorMemory();
return false;
}
((mbedtls_ssl_config *)this->ssl.private_conf/*fuck yourself*/)->private_ciphersuite_list = this->cipherSuites_.data();
return true;
}
void TLSContext::Destroy()
{
::mbedtls_ssl_free(&this->ssl);
::mbedtls_ssl_config_free(&this->conf);
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
::mbedtls_ssl_ticket_free(&this->ticketCtx_);
#endif
#if defined(MBEDTLS_SSL_CACHE_C)
::mbedtls_ssl_cache_free(&this->cache_);
#endif
#if defined(MBEDTLS_SSL_COOKIE_C)
::mbedtls_ssl_cookie_free(&this->cookieCtx_);
#endif
this->Attach({});
}
void TLSContext::OnClose()
{
    // Orderly end of the TLS session: mark dead, shut the socket down, drop config.
    this->bIsDead = true;

    if (auto pAttached = this->wpSocket_.lock())
    {
        pAttached->Shutdown(false);
    }

    AuResetMember(this->meta_);
}
void TLSContext::OnFatal()
{
this->bIsDead = true;
this->bIsFatal = true;
if (auto pSocket = this->wpSocket_.lock())
{
AuDynamicCast<AuNet::Socket, AuNet::ISocket>(pSocket)->SendErrorBeginShutdown(AuNet::ENetworkError::eTLSError);
}
AuResetMember(this->meta_);
}
//
// public api
//
//
AuSPtr<Protocol::IProtocolStack> TLSContext::ToReadStack()
{
    // Inbound (ciphertext -> plaintext) protocol stack.
    return pRecvStack_;
}

AuSPtr<Protocol::IProtocolStack> TLSContext::ToWriteStack()
{
    // Outbound (plaintext -> ciphertext) protocol stack.
    return pSendStack_;
}
AuSPtr<Protocol::IProtocolInterceptorEx> TLSContext::GetRecvInterceptor()
{
    // Points at the member receive channel while sharing ownership with this context,
    // so the channel cannot outlive the TLSContext that contains it.
    AuSPtr<Protocol::IProtocolInterceptorEx> pInterceptor(AuSharedFromThis(), &this->channelRecv_);
    return pInterceptor;
}

AuSPtr<Protocol::IProtocolInterceptorEx> TLSContext::GetSendInterceptor()
{
    // Same ownership scheme as GetRecvInterceptor, for the send channel.
    AuSPtr<Protocol::IProtocolInterceptorEx> pInterceptor(AuSharedFromThis(), &this->channelSend_);
    return pInterceptor;
}
void TLSContext::Attach(const AuSPtr<Net::ISocket> &pSocket)
{
    // Attaching a socket: remember it weakly and install our stacks on its channel.
    if (pSocket)
    {
        this->wpSocket_ = pSocket;
        pSocket->ToChannel()->SpecifyRecvProtocol(ToReadStack());
        pSocket->ToChannel()->SpecifySendProtocol(ToWriteStack());
        // TODO (Reece): Shutdown hook
        return;
    }

    // Attaching "nothing": forget the current socket (if it is still alive) and
    // strip our protocol stacks from its channel.
    auto pPrevious = this->wpSocket_.lock();
    if (!pPrevious)
    {
        return;
    }

    this->wpSocket_.reset();
    pPrevious->ToChannel()->SpecifyRecvProtocol({});
    pPrevious->ToChannel()->SpecifySendProtocol({});
    // TODO (Reece): Shutdown hook
}
void TLSContext::StartHandshake()
{
this->bIsAlive = false;
this->bIsDead = false;
this->bIsFatal = false;
this->iFatalError = 0;
this->bPinLock_ = false;
this->channelRecv_.HasCompletedHandshake() = false;
if (::mbedtls_ssl_session_reset(&this->ssl) != 0)
{
this->OnFatal();
}
else
{
this->channelRecv_.TryHandshake();
}
}
AuUInt16 TLSContext::GetCurrentCipherSuite()
{
    // IANA id of the negotiated suite, as reported by mbedtls for this session.
    auto iSuiteId = ::mbedtls_ssl_get_ciphersuite_id_from_ssl(&this->ssl);
    return static_cast<AuUInt16>(iSuiteId);
}
// NOTE(review): currently a no-op. No close_notify is sent and no state changes here;
// presumably meant to begin an orderly TLS shutdown. Confirm whether closing is
// driven elsewhere (e.g. OnClose) or this is unfinished.
void TLSContext::StartClose()
{
}
bool TLSContext::HasCompletedHandshake()
{
return this->channelRecv_.HasCompletedHandshake();
}
bool TLSContext::HasEnded()
{
return this->bIsDead;
}
bool TLSContext::HasFailed()
{
return this->bIsFatal;
}
int TLSContext::GetFatalErrorCode()
{
return this->iFatalError;
}
AuString TLSContext::GetFatalErrorCodeAsString()
{
return TLSErrorToString(this->GetFatalErrorCode());
}
// Create a TLS context that owns its own send/receive protocol stacks.
// Returns an empty pointer if allocation or initialization fails, matching
// NewTLSContextEx (previously a context whose Init() failed was still returned).
AUKN_SYM AuSPtr<ITLSContext> NewTLSContext(const TLSMeta &meta)
{
    auto pTlsContext = AuMakeShared<TLSContext>(meta);
    if (!pTlsContext)
    {
        return {};
    }

    if (!pTlsContext->Init())
    {
        return {};
    }

    return pTlsContext;
}
AUKN_SYM AuSPtr<ITLSContext> NewTLSContextEx(const AuSPtr<Protocol::IProtocolStack> &pSendStack,
                                             const AuSPtr<Protocol::IProtocolStack> &pRecvStack,
                                             const TLSMeta &meta)
{
    // Create a TLS context on caller-supplied stacks; empty on allocation or Init() failure.
    auto pContext = AuMakeShared<TLSContext>(pSendStack, pRecvStack, meta);
    if (pContext && pContext->Init())
    {
        return pContext;
    }

    return {};
}
}