/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: TLSProtocolRecv.cpp
    Date: 2022-8-24
    Author: Reece
***/
#include "TLS.hpp"
#include "TLSContext.hpp"

namespace Aurora::IO::TLS
{
    /**
     * Binds this receive-side protocol handler to its owning TLS context.
     *
     * @param pParent owning context; stored as-is and dereferenced by every
     *                other method of this class.
     */
    TLSProtocolRecv::TLSProtocolRecv(TLSContext *pParent) :
        pParent_(pParent)
    {
    }

    /**
     * Handles newly available inbound bytes.
     *
     * Records the input buffer and its remaining byte count, advances the
     * handshake if it has not completed yet, and then runs one read tick that
     * writes into pWriteOutByteBuffer.
     *
     * @param pReadInByteBuffer   inbound buffer; remembered for the duration of the call.
     * @param pWriteOutByteBuffer buffer that DoOneTick() writes into.
     * @return false only when the parent context is flagged dead; true otherwise.
     */
    bool TLSProtocolRecv::OnDataAvailable(const AuSPtr &pReadInByteBuffer, const AuSPtr &pWriteOutByteBuffer)
    {
        this->bHasRead          = false;
        this->pReadInByteBuffer = pReadInByteBuffer;
        this->uBytesReadAvail   = pReadInByteBuffer->RemainingBytes();
        this->uBytesRead        = 0;

        if (this->pParent_->bIsDead)
        {
            return false;
        }

        // A handshake that could not advance is not reported as a failure to
        // the caller; the stored input buffer reference is left in place.
        if (!TryHandshake())
        {
            return true;
        }

        // The tick's result never influenced the outcome: both the success and
        // the failure path dropped the buffer reference and returned true.
        this->DoOneTick(pWriteOutByteBuffer);

        this->pReadInByteBuffer.reset();
        return true;
    }

    /**
     * Drives the handshake until it has completed once.
     *
     * On completion the handshake flag is latched, the parent is marked alive
     * and the parent's pin lock flag is cleared.
     *
     * @return true when the handshake has already completed, has just
     *         completed, or DoHandshake() returned true without completing;
     *         false when DoHandshake() returned false.
     */
    bool TLSProtocolRecv::TryHandshake()
    {
        if (this->bHasCompletedHandshake_)
        {
            return true;
        }

        bool bComplete {};
        if (!this->DoHandshake(bComplete))
        {
            return false;
        }

        if (bComplete)
        {
            this->bHasCompletedHandshake_ = true;
            this->pParent_->bIsAlive      = true;
            this->pParent_->bPinLock_     = false;
        }

        return true;
    }

    /**
     * Performs a single mbedtls_ssl_handshake() step.
     *
     * @param bComplete set to true only when mbedtls reports success (0);
     *                  false in every other case.
     * @return true  on success, on MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED
     *               (after notifying the parent of the close), or when a
     *               previous attempt failed and there is no pending input;
     *         false on want-read / want-write / EOF, on an X509 fatal error
     *               (after OnClose()), and on any other error (after
     *               recording it and calling OnFatal()).
     */
    bool TLSProtocolRecv::DoHandshake(bool &bComplete)
    {
        bComplete = false;

        // After a first unsuccessful attempt, only call into mbedtls again
        // when there are input bytes to consume.
        if (this->bHasFailedOnce)
        {
            auto pBuffer = this->pReadInByteBuffer.lock();
            if (!pBuffer)
            {
                return true;
            }

            if (!pBuffer->RemainingBytes())
            {
                return true;
            }
        }

        const int iRet = ::mbedtls_ssl_handshake(&this->pParent_->ssl);
        switch (iRet)
        {
        case MBEDTLS_ERR_SSL_WANT_READ:
        case MBEDTLS_ERR_SSL_WANT_WRITE:
        case MBEDTLS_ERR_SSL_CONN_EOF:
        {
            this->bHasFailedOnce = true;
            return false;
        }
        case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
        {
            this->bHasFailedOnce = true;
            this->pParent_->OnClose();
            return true;
        }
        case MBEDTLS_ERR_X509_FATAL_ERROR:
        {
            this->bHasFailedOnce = true;
            this->pParent_->OnClose();
            return false;
        }
        case 0:
        {
            bComplete = true;
            return true;
        }
        default:
        {
            this->pParent_->iFatalError = iRet;
            SysPushErrorNet("Error during handshake: {:x}", iRet);
            this->pParent_->OnFatal();
            return false;
        }
        }
    }

    /**
     * Repeatedly calls mbedtls_ssl_read() into the output buffer's next linear
     * write region, advancing writePtr by the number of bytes produced, until
     * the region is empty or mbedtls stops producing data.
     *
     * @param pWriteOutByteBuffer destination buffer for the bytes read.
     * @return always true; close and fatal conditions are reported through the
     *         parent context (OnClose() / OnFatal()) rather than the return value.
     */
    bool TLSProtocolRecv::DoOneTick(const AuSPtr &pWriteOutByteBuffer)
    {
        while (true)
        {
            auto pDest = pWriteOutByteBuffer->GetNextLinearWrite();
            AuUInt8 *pBase { pDest.ToPointer() };
            AuUInt uCount { pDest.length };

            // No room left in the output buffer.
            if (!uCount)
            {
                return true;
            }

            // mbedtls tick
            const int iRet = ::mbedtls_ssl_read(&this->pParent_->ssl, pBase, uCount);

            if ((iRet == MBEDTLS_ERR_SSL_WANT_READ) ||
                (iRet == MBEDTLS_ERR_SSL_WANT_WRITE) ||
                (iRet == MBEDTLS_ERR_SSL_CONN_EOF))
            {
                // mbedtls doesn't know about peeking. their os doesn't support it. wont be added for linux+nt.
                return true;
            }

            if (iRet == 0)
            {
                // mbedtls_ssl_read() documents 0 as "the read end of the
                // transport was closed". No bytes were produced, so looping
                // again could spin without making progress. End the tick the
                // same way as for MBEDTLS_ERR_SSL_CONN_EOF above.
                return true;
            }

            if (iRet == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
            {
                this->pParent_->OnClose();
                return true;
            }

            if (iRet < 0)
            {
                this->pParent_->iFatalError = iRet;
                SysPushErrorNet("TLS Error: {}", iRet);
                this->pParent_->OnFatal();
                return true;
            }

            // iRet > 0: that many bytes were written at pBase.
            pWriteOutByteBuffer->writePtr += iRet;
        }
    }

    /**
     * @return a mutable reference to the handshake-completed flag.
     */
    bool &TLSProtocolRecv::HasCompletedHandshake()
    {
        return this->bHasCompletedHandshake_;
    }
}