586 lines
18 KiB
C++
586 lines
18 KiB
C++
/***
|
|
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
|
|
|
|
File: TLSContext.cpp
|
|
Date: 2022-8-24
|
|
Author: Reece
|
|
***/
|
|
#include "TLS.hpp"
|
|
#include "TLSContext.hpp"
|
|
#include <Source/IO/Protocol/AuProtocolStack.hpp>
|
|
#include <Aurora/IO/Net/NetExperimental.hpp>
|
|
#include <Source/IO/Net/AuNetSocket.hpp>
|
|
#include <Source/Crypto/X509/x509.hpp>
|
|
#include "TLSCertificateChain.hpp"
|
|
|
|
namespace Aurora::IO::TLS
|
|
{
|
|
mbedtls_entropy_context gEntropy;
|
|
mbedtls_ctr_drbg_context gCtrDrbg;
|
|
|
|
static bool gTlsReady {};
|
|
|
|
/**
 * Converts an mbedtls error code into mbedtls' own textual description.
 *
 * @param iError error code returned by an mbedtls_* API.
 * @return the text mbedtls_strerror writes for that code.
 */
AuString TLSErrorToString(int iError)
{
    char szMessage[1024] {};
    ::mbedtls_strerror(iError, szMessage, AuArraySize(szMessage));
    return AuString(szMessage);
}
|
|
|
|
void TLSInit()
|
|
{
|
|
int iRet;
|
|
|
|
::mbedtls_ctr_drbg_init(&gCtrDrbg);
|
|
::mbedtls_entropy_init(&gEntropy);
|
|
|
|
if ((iRet = ::mbedtls_ctr_drbg_seed(&gCtrDrbg,
|
|
::mbedtls_entropy_func,
|
|
&gEntropy,
|
|
(const unsigned char *)"ReeceWasHere",
|
|
12)) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return;
|
|
}
|
|
|
|
gTlsReady = true;
|
|
}
|
|
|
|
static int TLSContextRecv(void *ctx,
|
|
unsigned char *buf,
|
|
size_t len)
|
|
{
|
|
return ((TLSContext *)ctx)->Read(buf, len);
|
|
}
|
|
|
|
static int TLSContextSend(void *ctx,
|
|
const unsigned char *buf,
|
|
size_t len)
|
|
{
|
|
return ((TLSContext *)ctx)->Write(buf, len);
|
|
}
|
|
|
|
/**
 * Creates a context that owns a fresh receive stack and a fresh send stack.
 * Init() returns false if either stack pointer is null.
 */
TLSContext::TLSContext(const TLSMeta &meta) :
    channelRecv_(this),
    channelSend_(this),
    meta_(meta)
{
    pRecvStack_ = AuMakeShared<Protocol::ProtocolStack>();
    pSendStack_ = AuMakeShared<Protocol::ProtocolStack>();
}
|
|
|
|
/**
 * Creates a context that layers TLS onto caller-supplied protocol stacks.
 *
 * NOTE(review): both stacks are AuStaticCast to the concrete
 * Protocol::ProtocolStack. Passing any other IProtocolStack implementation
 * would be undefined behaviour — confirm callers only pass runtime-created
 * stacks.
 */
TLSContext::TLSContext(const AuSPtr<Protocol::IProtocolStack> &pSendStack,
                       const AuSPtr<Protocol::IProtocolStack> &pRecvStack,
                       const TLSMeta &meta) :
    channelRecv_(this),
    channelSend_(this),
    pRecvStack_(AuStaticCast<Protocol::ProtocolStack>(pRecvStack)),
    pSendStack_(AuStaticCast<Protocol::ProtocolStack>(pSendStack)),
    meta_(meta)
{ }
|
|
|
|
/// Tears the context down through Destroy(), which frees the mbedtls state
/// and detaches from any socket.
TLSContext::~TLSContext()
{
    Destroy();
}
|
|
|
|
//
|
|
// mbedtls nonblocking interface
|
|
//
|
|
//
|
|
|
|
int TLSContext::Write(const void *pIn, AuUInt length)
|
|
{
|
|
if (auto pPiece = this->pPiece_.lock())
|
|
{
|
|
AuUInt count {};
|
|
if (Aurora::IO::EStreamError::eErrorNone !=
|
|
pPiece->ToNextWriter()->Write(AuMemoryViewStreamRead { AuMemoryViewRead { pIn, length }, count }))
|
|
{
|
|
SysPushErrorIO("TLS couldn't flush write into next protocol layer or drain");
|
|
return -1;
|
|
}
|
|
|
|
return count;
|
|
}
|
|
return this->pSendStack_->pDrainBuffer->Write(pIn, length);
|
|
}
|
|
|
|
int TLSContext::Read(void *pOut, AuUInt length)
|
|
{
|
|
auto tempReadBuffer = this->channelRecv_.pReadInByteBuffer.lock();
|
|
if (!tempReadBuffer)
|
|
{
|
|
//SysPushErrorNet();
|
|
return MBEDTLS_ERR_SSL_WANT_READ;
|
|
}
|
|
|
|
auto toRead = length;// AuMin<AuUInt>(length, this->channelRecv_.uBytesReadAvail - this->channelRecv_.uBytesRead);
|
|
auto uBytesRead = tempReadBuffer->Read(pOut, toRead);
|
|
if (!uBytesRead)
|
|
{
|
|
return MBEDTLS_ERR_SSL_WANT_READ;
|
|
}
|
|
|
|
this->channelRecv_.bHasRead = true;
|
|
auto old = this->channelRecv_.uBytesRead;
|
|
this->channelRecv_.uBytesRead += uBytesRead;
|
|
return uBytesRead;
|
|
}
|
|
|
|
/**
 * Runs the user-supplied certificate pin (meta_.pCertPin), if any, against a
 * peer certificate.
 *
 * @param child mbedtls certificate under verification.
 * @param read  the certificate's raw bytes.
 * @return true when no pin is configured or the pin accepts the certificate;
 *         false when the pin rejects it or the chain wrapper can't be allocated.
 */
bool TLSContext::CheckCertificate(mbedtls_x509_crt const *child, const AuMemoryViewRead &read)
{
    auto &pPin = this->meta_.pCertPin;
    if (!pPin)
    {
        return true;
    }

    auto pChain = AuMakeShared<CertificateChain>();
    if (!pChain)
    {
        SysPushErrorMemory();
        return false;
    }

    // Wrap mbedtls' certificate for the pin, then clear the wrapper's pointer
    // before it is released — presumably so CertificateChain never frees a
    // certificate owned by mbedtls (TODO confirm against CertificateChain).
    pChain->Init(child);
    const bool bAccepted = pPin->CheckCertificate(pChain, read);
    pChain->pCertificate = nullptr;

    return bAccepted;
}
|
|
|
|
//
|
|
// tls context
|
|
//
|
|
//
|
|
|
|
bool TLSContext::Init()
|
|
{
|
|
int iRet;
|
|
|
|
if (!this->pSendStack_)
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!this->pRecvStack_)
|
|
{
|
|
return false;
|
|
}
|
|
|
|
auto pPiece = this->pSendStack_->AddInterceptorEx(AuIO::Protocol::EProtocolWhere::eAppend, this->GetSendInterceptor(), this->meta_.uOutPageSize);
|
|
if (!pPiece)
|
|
{
|
|
SysPushErrorNet("Couldn't add TLS interceptor");
|
|
return false;
|
|
}
|
|
|
|
this->pPiece_ = pPiece;
|
|
|
|
if (!this->pRecvStack_->AddInterceptorEx(AuIO::Protocol::EProtocolWhere::eAppend, this->GetRecvInterceptor(), this->meta_.uOutPageSize))
|
|
{
|
|
SysPushErrorNet("Couldn't add TLS interceptor");
|
|
return false;
|
|
}
|
|
|
|
::mbedtls_ssl_init(&this->ssl);
|
|
::mbedtls_ssl_config_init(&this->conf);
|
|
|
|
if ((iRet = ::mbedtls_ssl_config_defaults(&this->conf,
|
|
this->meta_.bIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
|
|
this->meta_.transportProtocol == AuNet::ETransportProtocol::eProtocolUDP ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM,
|
|
MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return false;
|
|
}
|
|
|
|
if (this->meta_.bIsClient)
|
|
{
|
|
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
|
}
|
|
else
|
|
{
|
|
if (this->meta_.server.bPinServerPeers)
|
|
{
|
|
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
|
}
|
|
else
|
|
{
|
|
::mbedtls_ssl_conf_authmode(&this->conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
|
|
}
|
|
}
|
|
|
|
::mbedtls_ssl_conf_verify(&this->conf, [](void *p_ctx, mbedtls_x509_crt *crt,
|
|
int depth, uint32_t *flags)
|
|
{
|
|
*flags &= ~MBEDTLS_X509_BADCERT_NOT_TRUSTED;
|
|
if (depth != 0)
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
((TLSContext *)p_ctx)->CheckCertificate(crt, { crt->raw.p, crt->raw.len }) ? 0 : -1;
|
|
return 0;
|
|
}, this);
|
|
|
|
//
|
|
::mbedtls_ssl_conf_ca_cb(&this->conf, [](void *p_ctx,
|
|
mbedtls_x509_crt const *child,
|
|
mbedtls_x509_crt **candidate_cas) -> int
|
|
{
|
|
return 0;// ((TLSContext *)p_ctx)->CheckCertificate(child, { child->raw.p, child->raw.len }) ? 0 : -1;
|
|
}, this);
|
|
|
|
::mbedtls_ssl_conf_rng(&this->conf, mbedtls_ctr_drbg_random, &gCtrDrbg);
|
|
|
|
::mbedtls_ssl_conf_dbg(&this->conf, [](void *, int, const char *as, int, const char *ad)
|
|
{
|
|
//AuLogDbg("{} <--> {}", as, ad);
|
|
}, nullptr);
|
|
|
|
if ((iRet = ::mbedtls_ssl_setup(&this->ssl, &this->conf)) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return false;
|
|
}
|
|
|
|
if (this->meta_.bIsClient)
|
|
{
|
|
if (this->meta_.client.sSNIServerName.size())
|
|
{
|
|
if ((iRet = ::mbedtls_ssl_set_hostname(&this->ssl, this->meta_.client.sSNIServerName.c_str())) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (!this->meta_.bIsClient)
|
|
{
|
|
if (this->meta_.server.bSessionCache)
|
|
{
|
|
#if defined(MBEDTLS_SSL_CACHE_C)
|
|
::mbedtls_ssl_cache_init(&this->cache_);
|
|
|
|
if (this->meta_.server.iCacheMax != -1)
|
|
{
|
|
::mbedtls_ssl_cache_set_max_entries(&this->cache_, this->meta_.server.iCacheMax);
|
|
}
|
|
|
|
if (this->meta_.server.iCacheTimeout)
|
|
{
|
|
::mbedtls_ssl_cache_set_timeout(&this->cache_, this->meta_.server.iCacheTimeout);
|
|
}
|
|
|
|
::mbedtls_ssl_conf_session_cache(&this->conf,
|
|
&this->cache_,
|
|
mbedtls_ssl_cache_get,
|
|
mbedtls_ssl_cache_set);
|
|
#endif
|
|
}
|
|
|
|
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
|
|
::mbedtls_ssl_ticket_init(&this->ticketCtx_);
|
|
#endif
|
|
|
|
if (this->meta_.transportProtocol == Net::ETransportProtocol::eProtocolUDP)
|
|
{
|
|
#if defined(MBEDTLS_SSL_COOKIE_C)
|
|
::mbedtls_ssl_cookie_init(&this->cookieCtx_);
|
|
#endif
|
|
|
|
#if defined(MBEDTLS_SSL_COOKIE_C)
|
|
if (this->meta_.dtls.iServerCookies > 0)
|
|
{
|
|
if ((iRet = ::mbedtls_ssl_cookie_setup(&this->cookieCtx_,
|
|
mbedtls_ctr_drbg_random,
|
|
&gCtrDrbg)) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return false;
|
|
}
|
|
|
|
::mbedtls_ssl_conf_dtls_cookies(&conf,
|
|
mbedtls_ssl_cookie_write,
|
|
mbedtls_ssl_cookie_check,
|
|
&this->cookieCtx_);
|
|
}
|
|
else
|
|
#endif
|
|
{
|
|
#if defined(MBEDTLS_SSL_DTLS_HELLO_VERIFY)
|
|
if (this->meta_.dtls.iServerCookies == 0)
|
|
{
|
|
::mbedtls_ssl_conf_dtls_cookies(&conf, NULL, NULL, NULL);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
#if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
|
|
::mbedtls_ssl_conf_dtls_anti_replay(&this->conf,
|
|
this->meta_.dtls.iServerCookies ? MBEDTLS_SSL_ANTI_REPLAY_ENABLED : MBEDTLS_SSL_ANTI_REPLAY_DISABLED);
|
|
#endif
|
|
|
|
::mbedtls_ssl_conf_dtls_badmac_limit(&this->conf,
|
|
this->meta_.dtls.iServerBadMacLimit);
|
|
}
|
|
|
|
|
|
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
|
|
if (this->meta_.server.bEnableTickets)
|
|
{
|
|
if ((iRet = ::mbedtls_ssl_ticket_setup(&this->ticketCtx_,
|
|
mbedtls_ctr_drbg_random,
|
|
&gCtrDrbg,
|
|
MBEDTLS_CIPHER_AES_256_GCM,
|
|
this->meta_.server.iTicketTimeout)) != 0)
|
|
{
|
|
SysPushErrorNet("{} ({})", TLSErrorToString(iRet), iRet);
|
|
return false;
|
|
}
|
|
|
|
::mbedtls_ssl_conf_session_tickets_cb(&this->conf,
|
|
mbedtls_ssl_ticket_write,
|
|
mbedtls_ssl_ticket_parse,
|
|
&this->ticketCtx_);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
if (this->meta_.transportProtocol == Net::ETransportProtocol::eProtocolUDP)
|
|
{
|
|
if (this->meta_.dtls.iMTUSize)
|
|
{
|
|
::mbedtls_ssl_set_mtu(&this->ssl,
|
|
this->meta_.dtls.iMTUSize);
|
|
}
|
|
|
|
::mbedtls_ssl_set_timer_cb(&this->ssl,
|
|
&this->timer_,
|
|
mbedtls_timing_set_delay,
|
|
mbedtls_timing_get_delay);
|
|
|
|
}
|
|
|
|
::mbedtls_ssl_set_bio(&this->ssl,
|
|
this,
|
|
TLSContextSend,
|
|
TLSContextRecv,
|
|
nullptr);
|
|
|
|
if (this->meta_.cipherSuites.size())
|
|
{
|
|
this->cipherSuites_.reserve(this->meta_.cipherSuites.size());
|
|
for (const auto &cipher : this->meta_.cipherSuites)
|
|
{
|
|
if (!AuTryInsert(this->cipherSuites_, cipher))
|
|
{
|
|
SysPushErrorMemory();
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{
|
|
auto &defaultCiphers = GetDefaultCipherSuites();
|
|
this->cipherSuites_.reserve(defaultCiphers.size());
|
|
for (const auto &cipher : defaultCiphers)
|
|
{
|
|
if (!AuTryInsert(this->cipherSuites_, cipher))
|
|
{
|
|
SysPushErrorMemory();
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (!AuTryInsert(this->cipherSuites_, 0))
|
|
{
|
|
SysPushErrorMemory();
|
|
return false;
|
|
}
|
|
|
|
((mbedtls_ssl_config *)this->ssl.private_conf/*fuck yourself*/)->private_ciphersuite_list = this->cipherSuites_.data();
|
|
return true;
|
|
}
|
|
|
|
void TLSContext::Destroy()
|
|
{
|
|
::mbedtls_ssl_free(&this->ssl);
|
|
::mbedtls_ssl_config_free(&this->conf);
|
|
|
|
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
|
|
::mbedtls_ssl_ticket_free(&this->ticketCtx_);
|
|
#endif
|
|
|
|
#if defined(MBEDTLS_SSL_CACHE_C)
|
|
::mbedtls_ssl_cache_free(&this->cache_);
|
|
#endif
|
|
|
|
#if defined(MBEDTLS_SSL_COOKIE_C)
|
|
::mbedtls_ssl_cookie_free(&this->cookieCtx_);
|
|
#endif
|
|
|
|
this->Attach({});
|
|
}
|
|
|
|
void TLSContext::OnClose()
|
|
{
|
|
this->bIsDead = true;
|
|
|
|
if (auto pSocket = this->wpSocket_.lock())
|
|
{
|
|
pSocket->Shutdown(false);
|
|
}
|
|
|
|
AuResetMember(this->meta_);
|
|
}
|
|
|
|
void TLSContext::OnFatal()
|
|
{
|
|
this->bIsDead = true;
|
|
this->bIsFatal = true;
|
|
|
|
if (auto pSocket = this->wpSocket_.lock())
|
|
{
|
|
AuDynamicCast<AuNet::Socket, AuNet::ISocket>(pSocket)->SendErrorBeginShutdown(AuNet::ENetworkError::eTLSError);
|
|
}
|
|
|
|
AuResetMember(this->meta_);
|
|
}
|
|
|
|
//
|
|
// public api
|
|
//
|
|
//
|
|
|
|
/// Protocol stack that inbound (ciphertext) data is fed through.
AuSPtr<Protocol::IProtocolStack> TLSContext::ToReadStack()
{
    return pRecvStack_;
}
|
|
|
|
/// Protocol stack that outbound data is fed through.
AuSPtr<Protocol::IProtocolStack> TLSContext::ToWriteStack()
{
    return pSendStack_;
}
|
|
|
|
/// Receive-side interceptor: a shared pointer that points at channelRecv_ but
/// shares ownership of this whole TLSContext (aliasing constructor).
AuSPtr<Protocol::IProtocolInterceptorEx> TLSContext::GetRecvInterceptor()
{
    auto pOwner = AuSharedFromThis();
    return AuSPtr<Protocol::IProtocolInterceptorEx>(pOwner, &this->channelRecv_);
}
|
|
|
|
/// Send-side interceptor: a shared pointer that points at channelSend_ but
/// shares ownership of this whole TLSContext (aliasing constructor).
AuSPtr<Protocol::IProtocolInterceptorEx> TLSContext::GetSendInterceptor()
{
    auto pOwner = AuSharedFromThis();
    return AuSPtr<Protocol::IProtocolInterceptorEx>(pOwner, &this->channelSend_);
}
|
|
|
|
void TLSContext::Attach(const AuSPtr<Net::ISocket> &pSocket)
|
|
{
|
|
if (!pSocket)
|
|
{
|
|
if (auto pOldSocket = this->wpSocket_.lock())
|
|
{
|
|
this->wpSocket_.reset();
|
|
pOldSocket->ToChannel()->SpecifyRecvProtocol({});
|
|
pOldSocket->ToChannel()->SpecifySendProtocol({});
|
|
// TODO (Reece): Shutdown hook
|
|
}
|
|
return;
|
|
}
|
|
|
|
this->wpSocket_ = pSocket;
|
|
pSocket->ToChannel()->SpecifyRecvProtocol(ToReadStack());
|
|
pSocket->ToChannel()->SpecifySendProtocol(ToWriteStack());
|
|
// TODO (Reece): Shutdown hook
|
|
}
|
|
|
|
void TLSContext::StartHandshake()
|
|
{
|
|
this->bIsAlive = false;
|
|
this->bIsDead = false;
|
|
this->bIsFatal = false;
|
|
this->iFatalError = 0;
|
|
|
|
this->bPinLock_ = false;
|
|
|
|
this->channelRecv_.HasCompletedHandshake() = false;
|
|
|
|
if (::mbedtls_ssl_session_reset(&this->ssl) != 0)
|
|
{
|
|
this->OnFatal();
|
|
}
|
|
else
|
|
{
|
|
this->channelRecv_.TryHandshake();
|
|
}
|
|
}
|
|
|
|
/// Id of the ciphersuite in use on this connection, as reported by
/// mbedtls_ssl_get_ciphersuite_id_from_ssl and narrowed to 16 bits.
AuUInt16 TLSContext::GetCurrentCipherSuite()
{
    const auto iSuiteId = ::mbedtls_ssl_get_ciphersuite_id_from_ssl(&this->ssl);
    return static_cast<AuUInt16>(iSuiteId);
}
|
|
|
|
/// Graceful-close entry point — currently a no-op.
/// TODO(review): presumably meant to send a close_notify
/// (mbedtls_ssl_close_notify); not implemented here.
void TLSContext::StartClose()
{
    // Intentionally empty.
}
|
|
|
|
bool TLSContext::HasCompletedHandshake()
|
|
{
|
|
return this->channelRecv_.HasCompletedHandshake();
|
|
}
|
|
|
|
bool TLSContext::HasEnded()
|
|
{
|
|
return this->bIsDead;
|
|
}
|
|
|
|
bool TLSContext::HasFailed()
|
|
{
|
|
return this->bIsFatal;
|
|
}
|
|
|
|
int TLSContext::GetFatalErrorCode()
|
|
{
|
|
return this->iFatalError;
|
|
}
|
|
|
|
/// GetFatalErrorCode() rendered through TLSErrorToString (mbedtls_strerror).
AuString TLSContext::GetFatalErrorCodeAsString()
{
    const int iCode = GetFatalErrorCode();
    return TLSErrorToString(iCode);
}
|
|
|
|
/**
 * Creates a TLS context that owns its own send/receive protocol stacks.
 *
 * @param meta connection configuration (role, transport, pinning, ciphers...).
 * @return the initialised context, or an empty pointer when allocation or
 *         TLSContext::Init() fails (same contract as NewTLSContextEx).
 */
AUKN_SYM AuSPtr<ITLSContext> NewTLSContext(const TLSMeta &meta)
{
    auto pTlsContext = AuMakeShared<TLSContext>(meta);
    if (!pTlsContext)
    {
        return {};
    }

    // A context whose Init() failed has no usable mbedtls/session state;
    // never hand it to the caller.
    if (!pTlsContext->Init())
    {
        return {};
    }

    return pTlsContext;
}
|
|
|
|
/**
 * Creates a TLS context layered onto caller-supplied protocol stacks.
 *
 * @return the initialised context, or an empty pointer when allocation or
 *         TLSContext::Init() fails.
 */
AUKN_SYM AuSPtr<ITLSContext> NewTLSContextEx(const AuSPtr<Protocol::IProtocolStack> &pSendStack,
                                             const AuSPtr<Protocol::IProtocolStack> &pRecvStack,
                                             const TLSMeta &meta)
{
    auto pContext = AuMakeShared<TLSContext>(pSendStack, pRecvStack, meta);
    if (!pContext || !pContext->Init())
    {
        return {};
    }

    return pContext;
}
|
|
} |