/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: AuNetSocketChannel.cpp
    Date: 2022-8-16
    Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocketChannel.hpp"
#include "AuNetSocket.hpp"

#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#else
#include "AuNetStream.Linux.hpp"
#endif

// NOTE(review): the targets of the bare #include directives below, and the template argument
// lists used throughout this file, are empty in the copy under review - confirm against VCS.
#include
#include "AuNetWorker.hpp"
#include
#include
#include "AuNetSocketServer.hpp"

namespace Aurora::IO::Net
{
    // Initial value of bTcpNoDelay for every new channel.
    static const bool kDefaultFuckNagle = true;

    // Binds the channel to its owning socket and creates the platform-specific
    // input/output stream channels. Buffer sizes start at kDefaultStreamSize.
    SocketChannel::SocketChannel(SocketBase *pParent) :
        pParent_(pParent),
#if defined(AURORA_IS_MODERNNT_DERIVED)
        outputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
        inputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
#else
        outputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
        inputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
#endif
        uBytesInputBuffer(kDefaultStreamSize),
        uBytesOutputBuffer(kDefaultStreamSize),
        bTcpNoDelay(kDefaultFuckNagle)
    {

    }

    // Called once the socket is connected/bound:
    //  * TCP: applies the channel's TCP_NODELAY preference to the socket.
    //  * UDP: flips the read/write transactions into datagram mode.
    // Finally notifies the input channel.
    void SocketChannel::Establish()
    {
        if (this->pParent_->GetLocalEndpoint().transportProtocol == ETransportProtocol::eProtocolTCP)
        {
            // Honour the stored preference (set via SpecifyTCPNoDelay, reported by GetCurrentTCPNoDelay)
            // rather than unconditionally forcing Nagle off; the default is still kDefaultFuckNagle.
            this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
        }

        // Switch over to extended send/recv (its just a flag to use a different branch - fetch last addr is just a hack)
        if (this->pParent_->GetLocalEndpoint().transportProtocol == ETransportProtocol::eProtocolUDP)
        {
#if defined(AURORA_IS_MODERNNT_DERIVED)
            AuStaticCast(outputChannel.ToWriteTransaction())->bDatagramMode = true;
            AuStaticCast(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#else
            AuStaticCast(outputChannel.ToWriteTransaction())->bDatagramMode = true;
            AuStaticCast(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#endif
        }

        this->inputChannel.OnEstablish();
    }

    // Returns a shared reference to the owning socket.
    AuSPtr SocketChannel::ToParent()
    {
        return this->pParent_->SharedFromThis();
    }

    // Returns the reader for inbound data: the recv protocol stack's reader when one is
    // installed; otherwise a lazily-created (and cached) BlobReader over AsReadableByteBuffer()
    // whose Close() shuts the owning socket down.
    AuSPtr SocketChannel::AsStreamReader()
    {
        if (this->pRecvProtocol)
        {
            return this->pRecvProtocol->AsStreamReader();
        }
        else if (this->pCachedReader)
        {
            return this->pCachedReader;
        }
        else
        {
            struct A : AuIO::Buffered::BlobReader
            {
                // Weak, so the reader does not keep the socket alive.
                AuWPtr swpChannel;

                A(AuSPtr pBuffer, AuWPtr swpChannel) :
                    AuIO::Buffered::BlobReader(pBuffer),
                    swpChannel(swpChannel)
                {

                }

                virtual void Close() override
                {
                    if (auto pChannel = AuTryLockMemoryType(this->swpChannel))
                    {
                        pChannel->pParent_->Shutdown(false);
                    }
                }
            };

            // Aliasing pointer: points at this channel but shares the parent socket's ownership.
            return this->pCachedReader = AuMakeSharedThrow(this->AsReadableByteBuffer(), AuSPtr(this->pParent_->SharedFromThis(), this));
        }
    }

    // Returns the buffer inbound data is read from: the recv protocol stack's output when one
    // is installed, otherwise the raw input channel buffer.
    AuSPtr SocketChannel::AsReadableByteBuffer()
    {
        if (this->pRecvProtocol)
        {
            return this->pRecvProtocol->AsReadableByteBuffer();
        }
        else
        {
            return this->inputChannel.AsReadableByteBuffer();
        }
    }

    // Returns the (cached) writer for outbound data.
    //  * With a send protocol: a thin wrapper over the protocol's writer that reports
    //    eErrorHandleClosed once that writer is gone, and schedules a socket write on Flush().
    //  * Without: a BlobWriter over AsWritableByteBuffer() that schedules a write on Flush().
    // In both cases Close() shuts the owning socket down.
    AuSPtr SocketChannel::AsStreamWriter()
    {
        if (this->pCachedWriter)
        {
            return this->pCachedWriter;
        }
        else if (this->pSendProtocol)
        {
            struct C : IO::IStreamWriter
            {
                AuWPtr swpIStreamWriter;
                AuWPtr swpChannel;

                C(AuWPtr swpChannel, AuWPtr swpIStreamWriter) :
                    swpIStreamWriter(swpIStreamWriter),
                    swpChannel(swpChannel)
                {

                }

                virtual EStreamError IsOpen() override
                {
                    if (auto pStream = AuTryLockMemoryType(this->swpIStreamWriter))
                    {
                        return EStreamError::eErrorNone;
                    }
                    else
                    {
                        return EStreamError::eErrorHandleClosed;
                    }
                }

                virtual EStreamError Write(const AuMemoryViewStreamRead &read) override
                {
                    if (auto pStream = AuTryLockMemoryType(this->swpIStreamWriter))
                    {
                        return pStream->Write(read);
                    }
                    else
                    {
                        return EStreamError::eErrorHandleClosed;
                    }
                }

                virtual void Close() override
                {
                    if (auto pChannel = AuTryLockMemoryType(this->swpChannel))
                    {
                        pChannel->pParent_->Shutdown(false);
                    }
                }

                virtual void Flush() override
                {
                    if (auto pStream = AuTryLockMemoryType(this->swpIStreamWriter))
                    {
                        pStream->Flush();
                    }

                    if (auto pChannel = AuTryLockMemoryType(this->swpChannel))
                    {
                        pChannel->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
                    }
                }
            };

            return this->pCachedWriter = AuMakeSharedThrow(AuSPtr(this->pParent_->SharedFromThis(), this), this->pSendProtocol->AsStreamWriter());
        }
        else
        {
            struct A : IO::Buffered::BlobWriter
            {
                AuWPtr swpChannel;

                A(AuSPtr pBuffer, AuWPtr swpChannel) :
                    IO::Buffered::BlobWriter(pBuffer),
                    swpChannel(swpChannel)
                {

                }

                virtual void Close() override
                {
                    if (auto pChannel = AuTryLockMemoryType(this->swpChannel))
                    {
                        pChannel->pParent_->Shutdown(false);
                    }
                }

                virtual void Flush() override
                {
                    if (auto pChannel = AuTryLockMemoryType(this->swpChannel))
                    {
                        pChannel->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
                    }
                }
            };

            return this->pCachedWriter = AuMakeSharedThrow(this->AsWritableByteBuffer(), AuSPtr(this->pParent_->SharedFromThis(), this));
        }
    }

    // Returns the buffer outbound data is written into.
    AuSPtr SocketChannel::AsWritableByteBuffer()
    {
        if (this->pSendProtocol)
        {
            // NOTE(review): this returns the send stack's *readable* buffer from a function named
            // AsWritableByteBuffer - confirm this is the intended end of the send pipeline.
            return this->pSendProtocol->AsReadableByteBuffer();
        }
        else
        {
            return this->outputChannel.AsWritableByteBuffer();
        }
    }

    // Creates a receive protocol stack fed from the input channel buffer.
    // Returns null (after pushing a memory error) if allocation fails.
    AuSPtr SocketChannel::NewProtocolRecvStack()
    {
        auto pProtocol = AuMakeShared();
        if (!pProtocol)
        {
            SysPushErrorMemory();
            return {};
        }

        pProtocol->pSourceBufer = this->inputChannel.AsReadableByteBuffer();
        pProtocol->bKillPipeOnFirstRootLevelFalse = true; // harden
        return pProtocol;
    }

    // Creates a buffered send protocol stack (sized by uBytesOutputBuffer) that drains into the
    // output channel buffer. Returns null (after pushing a memory error) if allocation fails.
    AuSPtr SocketChannel::NewProtocolSendStack()
    {
        auto pBaseProtocol = Protocol::NewBufferedProtocolStack(this->uBytesOutputBuffer);
        if (!pBaseProtocol)
        {
            SysPushErrorMemory();
            return {};
        }

        auto pProtocol = AuStaticCast(pBaseProtocol);
        if (!pProtocol)
        {
            return {};
        }

        pProtocol->pDrainBuffer = this->outputChannel.AsWritableByteBuffer();
        pProtocol->bKillPipeOnFirstRootLevelFalse = true; // harden
        return pProtocol;
    }

    // Sets the number of bytes per frame (uBytesPerFrame).
    void SocketChannel::SpecifyPageLength(AuUInt uPageSize)
    {
        this->uBytesPerFrame = uPageSize;
    }

    // Returns the number of bytes per frame (uBytesPerFrame).
    AuUInt SocketChannel::GetPageLength()
    {
        return this->uBytesPerFrame;
    }

    // Installs (or clears) the receive protocol stack and drops the cached reader so the next
    // AsStreamReader() call reflects the change.
    void SocketChannel::SpecifyRecvProtocol(const AuSPtr &pRecvProtocol)
    {
        this->pRecvProtocol = pRecvProtocol;
        this->pCachedReader.reset();
    }

    // Installs (or clears) the send protocol stack and drops the cached writer so the next
    // AsStreamWriter() call reflects the change.
    void SocketChannel::SpecifySendProtocol(const AuSPtr &pSendProtocol)
    {
        this->pSendProtocol = pSendProtocol;
        this->pCachedWriter.reset();
    }

    // Forwards the next frame target length to the net reader. No-op once Release() has
    // dropped the reader.
    void SocketChannel::SetNextFrameTargetLength(AuUInt uNextFrameSize)
    {
        if (!this->inputChannel.pNetReader)
        {
            return;
        }

        (void)this->inputChannel.pNetReader->SetNextFrameTargetLength(uNextFrameSize);
    }

    // Returns the net reader's next frame target length, or 0 once Release() has dropped the reader.
    AuUInt SocketChannel::GetNextFrameTargetLength()
    {
        if (!this->inputChannel.pNetReader)
        {
            return 0;
        }

        return this->inputChannel.pNetReader->GetNextFrameTargetLength();
    }

    // Receive statistics, as a pointer that shares the owning socket's lifetime.
    AuSPtr SocketChannel::GetRecvStats()
    {
        return AuSPtr(this->pParent_->SharedFromThis(), &this->recvStats_);
    }

    // Send statistics, as a pointer that shares the owning socket's lifetime.
    AuSPtr SocketChannel::GetSendStats()
    {
        return AuSPtr(this->pParent_->SharedFromThis(), &this->sendStats_);
    }

    // Direct access to this channel's send statistics.
    SocketStats &SocketChannel::GetSendStatsEx()
    {
        return this->sendStats_;
    }

    // Direct access to this channel's receive statistics.
    SocketStats &SocketChannel::GetRecvStatsEx()
    {
        return this->recvStats_;
    }

    // Send statistics of the socket's wpParent2_, available only when the socket belongs to a
    // socket server; nullptr otherwise.
    SocketStats *SocketChannel::GetSendStatsEx2()
    {
        if (auto pTest = this->pParent_->GetSocketServer())
        {
            if (auto pWP2 = this->pParent_->wpParent2_)
            {
                return &pWP2->sendStats_;
            }
            else
            {
                return {};
            }
        }
        else
        {
            return {};
        }
    }

    // Receive statistics of the socket's wpParent2_, available only when the socket belongs to a
    // socket server; nullptr otherwise.
    SocketStats *SocketChannel::GetRecvStatsEx2()
    {
        if (auto pTest = this->pParent_->GetSocketServer())
        {
            if (auto pWP2 = this->pParent_->wpParent2_)
            {
                return &pWP2->recvStats_;
            }
            else
            {
                return {};
            }
        }
        else
        {
            return {};
        }
    }

    // Registers a data listener under spinLock. When called on the worker thread while data is
    // already buffered, the listener is offered that data first; if its OnData() returns false
    // it is completed and NOT registered.
    void SocketChannel::AddEventListener(const AuSPtr &pListener)
    {
        AU_LOCK_GUARD(this->spinLock);

        auto pWorker = this->pParent_->ToWorkerEx();
        if (pWorker->IsOnThread() &&
            this->AsReadableByteBuffer()->GetNextLinearRead().length)
        {
            if (!pListener->OnData())
            {
                pListener->OnComplete();
                return;
            }
        }

        this->eventListeners.push_back(pListener);
    }

    // Unregisters a data listener under spinLock; if it was registered and we are on the
    // worker thread, the listener is told via OnRejected().
    void SocketChannel::RemoveEventListener(const AuSPtr &pListener)
    {
        AU_LOCK_GUARD(this->spinLock);

        bool bSuccess = AuTryRemove(this->eventListeners, pListener);

        auto pWorker = this->pParent_->ToWorkerEx();
        if (pWorker->IsOnThread() && bSuccess)
        {
            pListener->OnRejected();
        }
    }

    // True only when both the output and input channels are valid.
    bool SocketChannel::IsValid()
    {
        return bool(this->outputChannel.IsValid()) &&
               bool(this->inputChannel.IsValid());
    }

    // Channel limits, as a pointer that shares the owning socket's lifetime.
    AuSPtr SocketChannel::GetChannelLimits()
    {
        return AuSPtr(this->pParent_->SharedFromThis(), &this->channelSecurity_);
    }

    // Requests that both the input and output buffers be resized to uBytes. Runs immediately on
    // the worker thread, otherwise it is scheduled there. Returns false if there is no worker or
    // scheduling failed; completion is reported through pCallbackOptional (if provided).
    bool SocketChannel::SpecifyBufferSize(AuUInt uBytes,
                                          const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
            return true;
        }

        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
        {
            that->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
        }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // As SpecifyBufferSize, but only the output buffer is resized.
    bool SocketChannel::SpecifyOutputBufferSize(AuUInt uBytes,
                                                const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
            return true;
        }

        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
        {
            that->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
        }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // As SpecifyBufferSize, but only the input buffer is resized.
    bool SocketChannel::SpecifyInputBufferSize(AuUInt uBytes,
                                               const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
            return true;
        }

        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
        {
            that->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
        }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // Ticks the send protocol stack (if any), then schedules an out-of-frame write on the
    // output channel.
    void SocketChannel::ScheduleOutOfFrameWrite()
    {
        if (this->pSendProtocol)
        {
            this->pSendProtocol->DoTick();
        }

        this->outputChannel.ScheduleOutOfFrameWrite();
    }

    // Exposes the read transaction only in manual-read mode; null otherwise.
    AuSPtr SocketChannel::ToReadTransaction()
    {
        return this->bIsManualRead ? this->inputChannel.pNetReadTransaction : AuSPtr {};
    }

    // Exposes the write transaction only in manual-write mode; null otherwise.
    AuSPtr SocketChannel::ToWriteTransaction()
    {
        return this->bIsManualWrite ? this->outputChannel.ToWriteTransaction() : AuSPtr {};
    }

    // Stores the TCP_NODELAY preference and applies it to the socket. Returns false (and
    // changes nothing) unless manual write is enabled.
    // NOTE(review): confirm the manual-write restriction is intentional.
    bool SocketChannel::SpecifyTCPNoDelay(bool bFuckNagle)
    {
        if (!this->bIsManualWrite)
        {
            return false;
        }

        this->bTcpNoDelay = bFuckNagle;
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
        return true;
    }

    // Not supported by this implementation; always returns false.
    bool SocketChannel::SpecifyTransactionsHaveIOFence(bool bAllocateFence)
    {
        return false;
    }

    // Enables manual write mode. Only allowed before the channel is established.
    // NOTE(review): bEnableDirectAIOWrite is currently unused; manual write is enabled regardless.
    bool SocketChannel::SpecifyManualWrite(bool bEnableDirectAIOWrite)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        this->bIsManualWrite = true;
        return true;
    }

    // Enables manual read mode. Only allowed before the channel is established.
    // NOTE(review): bEnableDirectAIORead is currently unused; manual read is enabled regardless.
    bool SocketChannel::SpecifyManualRead(bool bEnableDirectAIORead)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        this->bIsManualRead = true;
        return true;
    }

    // Sets the per-tick async read limit (uBytesToFlip). Only allowed before the channel is
    // established.
    bool SocketChannel::SpecifyPerTickAsyncReadLimit(AuUInt uBytes)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        // Previously assigned `true` (i.e. a 1-byte limit) and discarded the caller's value.
        this->uBytesToFlip = uBytes;
        return true;
    }

    // Ends the timing of both statistics blocks.
    void SocketChannel::StopTime()
    {
        this->sendStats_.End();
        this->recvStats_.End();
    }

    // Starts the timing of both statistics blocks.
    void SocketChannel::StartTime()
    {
        this->sendStats_.Start();
        this->recvStats_.Start();
    }

    // Tears the channel down: stops stats timing, ends the net reader, and drops user data,
    // cached reader/writer, both protocol stacks and the net reader itself.
    void SocketChannel::Release()
    {
        this->StopTime();

        if (this->inputChannel.pNetReader)
        {
            this->inputChannel.pNetReader->End();
        }

        this->PrivateUserDataClear();

        this->pCachedReader.reset();
        this->pCachedWriter.reset();
        this->pRecvProtocol.reset();
        this->pSendProtocol.reset();
        this->inputChannel.pNetReader.reset();
    }

    // Resizes `buffer` even if it is flagged no-realloc; the flag is restored afterwards.
    static bool ForceResize(AuByteBuffer &buffer, AuUInt uLength)
    {
        auto old = buffer.flagNoRealloc;
        buffer.flagNoRealloc = false;
        bool bRet = buffer.Resize(uLength);
        buffer.flagNoRealloc = old;
        return bRet;
    }

    // Worker-thread implementation of the Specify*BufferSize family.
    //  * Before establishment: only records the requested size(s) and reports success.
    //  * After establishment: the output buffer is resized immediately when CanResize() allows;
    //    whatever remains is queued as a retarget applied by DoReallocWriteTick /
    //    DoReallocReadTick. A newer request fails any pending one for the same direction, and a
    //    callback shared by both directions is cleared from the other slot when it is failed.
    void SocketChannel::DoBufferResizeOnThread(bool bInput,
                                               bool bOutput,
                                               AuUInt uBytes,
                                               const AuSPtr> &pCallbackOptional)
    {
        if (!this->bIsEstablished)
        {
            if (bInput)
            {
                this->uBytesInputBuffer = uBytes;
            }

            if (bOutput)
            {
                this->uBytesOutputBuffer = uBytes;
            }

            if (pCallbackOptional)
            {
                try
                {
                    pCallbackOptional->OnSuccess((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }
            }

            return;
        }

        bool bHasCurrentSuccess {};

        if (bOutput)
        {
            if (this->outputChannel.CanResize())
            {
                if (ForceResize(this->outputChannel.GetByteBuffer(), uBytes))
                {
                    // Keep GetOutputBufferSize() in sync with the actual buffer.
                    this->uBytesOutputBuffer = uBytes;
                    bOutput = false;
                    bHasCurrentSuccess = !bInput;
                }
            }
        }

        if (bHasCurrentSuccess)
        {
            // The callback is optional; it must be null-checked here just like above.
            if (pCallbackOptional)
            {
                try
                {
                    pCallbackOptional->OnSuccess((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }
            }

            return;
        }

        {
            AU_LOCK_GUARD(this->spinLock);

            if (bOutput)
            {
                if (this->uBytesOutputBufferRetarget)
                {
                    if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
                    {
                        try
                        {
                            pOutput->OnFailure((void *)nullptr);
                        }
                        catch (...)
                        {
                            SysPushErrorCatch();
                        }

                        if (this->pRetargetInput == pOutput)
                        {
                            AuResetMember(this->pRetargetInput);
                        }
                    }
                }

                this->uBytesOutputBufferRetarget = uBytes;
                this->pRetargetOutput = pCallbackOptional;
            }

            if (bInput)
            {
                if (this->uBytesInputBufferRetarget)
                {
                    if (auto pInput = AuExchange(this->pRetargetInput, {}))
                    {
                        try
                        {
                            pInput->OnFailure((void *)nullptr);
                        }
                        catch (...)
                        {
                            SysPushErrorCatch();
                        }

                        if (this->pRetargetOutput == pInput)
                        {
                            AuResetMember(this->pRetargetOutput);
                        }
                    }
                }

                this->uBytesInputBufferRetarget = uBytes;
                this->pRetargetInput = pCallbackOptional;
            }
        }
    }

    // Applies a pending output-buffer retarget (under spinLock).
    // On failure the request is failed and abandoned. On success its callback is notified,
    // unless the same callback is still waiting on the input retarget (the read tick reports it).
    void SocketChannel::DoReallocWriteTick()
    {
        AU_LOCK_GUARD(this->spinLock);

        if (!this->uBytesOutputBufferRetarget)
        {
            return;
        }

        if (!ForceResize(this->outputChannel.GetByteBuffer(), this->uBytesOutputBufferRetarget))
        {
            SysPushErrorMemory();

            // The request is reported as failed below; leaving it pending would re-push this error
            // every tick and could later apply a resize the caller was told had failed.
            this->uBytesOutputBufferRetarget = 0;

            if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
            {
                try
                {
                    pOutput->OnFailure((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }

                if (this->pRetargetInput == pOutput)
                {
                    AuResetMember(this->pRetargetInput);
                }
            }

            return;
        }

        // Keep GetOutputBufferSize() in sync with the actual buffer.
        this->uBytesOutputBuffer = this->uBytesOutputBufferRetarget;
        this->uBytesOutputBufferRetarget = 0;

        if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
        {
            if (this->pRetargetInput != pOutput)
            {
                try
                {
                    pOutput->OnSuccess((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }
            }
        }
    }

    // Applies a pending input-buffer retarget (under spinLock). Mirror image of
    // DoReallocWriteTick.
    void SocketChannel::DoReallocReadTick()
    {
        AU_LOCK_GUARD(this->spinLock);

        if (!this->uBytesInputBufferRetarget)
        {
            return;
        }

        // Use the *input* retarget size (this previously, and wrongly, used the output one).
        if (!ForceResize(*this->inputChannel.AsReadableByteBuffer(), this->uBytesInputBufferRetarget))
        {
            SysPushErrorMemory();

            // Abandon the failed request rather than retrying it silently on every tick.
            this->uBytesInputBufferRetarget = 0;

            if (auto pInput = AuExchange(this->pRetargetInput, {}))
            {
                try
                {
                    pInput->OnFailure((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }

                if (this->pRetargetOutput == pInput)
                {
                    AuResetMember(this->pRetargetOutput);
                }
            }

            return;
        }

        // Keep GetInputBufferSize() in sync with the actual buffer.
        this->uBytesInputBuffer = this->uBytesInputBufferRetarget;
        this->uBytesInputBufferRetarget = 0;

        if (auto pInput = AuExchange(this->pRetargetInput, {}))
        {
            if (this->pRetargetOutput != pInput)
            {
                try
                {
                    pInput->OnSuccess((void *)nullptr);
                }
                catch (...)
                {
                    SysPushErrorCatch();
                }
            }
        }
    }

    // Returns the stored TCP_NODELAY preference.
    bool SocketChannel::GetCurrentTCPNoDelay()
    {
        return this->bTcpNoDelay;
    }

    // Returns the recorded input buffer size.
    AuUInt SocketChannel::GetInputBufferSize()
    {
        return this->uBytesInputBuffer;
    }

    // Returns the recorded output buffer size.
    AuUInt SocketChannel::GetOutputBufferSize()
    {
        return this->uBytesOutputBuffer;
    }
}