/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: AuNetSocketChannel.cpp
    Date: 2022-8-16
    Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocketChannel.hpp"
#include "AuNetSocket.hpp"
#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#else
#include "AuNetStream.Linux.hpp"
#endif
// NOTE(review): in this copy of the file the header names of the bare #include
// directives below appear to be missing, as do the <...> template argument lists
// throughout (e.g. `AuSPtr>`, `AuStaticCast(AuMakeShared(pParent))`). They are
// reproduced exactly as found; restore them from version control before building.
#include
#include "AuNetWorker.hpp"
#include
#include

namespace Aurora::IO::Net
{
    // Initial value of bTcpNoDelay, which is forwarded to UpdateNagleAnyThread
    // (true presumably disables Nagle's algorithm on TCP sockets).
    static const bool kDefaultFuckNagle = true;

    // Replaces `buffer` with a fresh AuByteBuffer of uBytes that has been filled
    // from `buffer` via WriteFrom.
    // On allocation or copy failure a memory error is pushed, `buffer`'s read and
    // write heads are restored to their prior values, and false is returned.
    // bCircular is forwarded as the second AuByteBuffer constructor argument
    // (NOTE(review): presumably the circular/ring-buffer flag - confirm).
    static bool ReallocPreservingUnread(AuByteBuffer &buffer, AuUInt uBytes, bool bCircular)
    {
        AuByteBuffer newBuffer(uBytes, bCircular, false);
        if (!newBuffer.IsValid())
        {
            SysPushErrorMemory();
            return false;
        }

        // WriteFrom may move `buffer`'s heads; remember them so a failed copy
        // leaves `buffer` as it was.
        auto oldReadHead = buffer.readPtr;
        auto oldWriteHead = buffer.writePtr;

        if (!newBuffer.WriteFrom(buffer))
        {
            SysPushErrorMemory();
            buffer.readPtr = oldReadHead;
            buffer.writePtr = oldWriteHead;
            return false;
        }

        buffer = AuMove(newBuffer);
        return true;
    }

    // Completes an optional resize callback with OnSuccess or OnFailure.
    // A null callback is ignored.
    template <typename CallbackPtr_t>
    static void NotifyResizeResult(const CallbackPtr_t &pCallback, bool bSuccess)
    {
        if (!pCallback)
        {
            return;
        }

        if (bSuccess)
        {
            pCallback->OnSuccess((void *)nullptr);
        }
        else
        {
            pCallback->OnFailure((void *)nullptr);
        }
    }

    // Builds the input and output channels around platform-specific transactions
    // for pParent, and seeds the buffer sizes and the no-delay flag with their defaults.
    SocketChannel::SocketChannel(SocketBase *pParent) :
        pParent_(pParent),
#if defined(AURORA_IS_MODERNNT_DERIVED)
        outputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
        inputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
#else
        outputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
        inputChannel(pParent, AuStaticCast(AuMakeShared(pParent))),
#endif
        uBytesInputBuffer(kDefaultStreamSize),
        uBytesOutputBuffer(kDefaultStreamSize),
        bTcpNoDelay(kDefaultFuckNagle)
    {
    }

    // Finishes channel setup once the socket is connected/bound.
    // TCP: applies the configured no-delay setting.
    // UDP: switches both transactions into datagram mode.
    // Finally notifies the input channel.
    void SocketChannel::Establish()
    {
        if (this->pParent_->GetLocalEndpoint().transportProtocol == ETransportProtocol::eProtocolTCP)
        {
            // Honour the configured value (defaults to kDefaultFuckNagle) rather than
            // forcing `true`, which would undo an earlier SpecifyTCPNoDelay(false).
            this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
        }

        // Switch over to extended send/recv (its just a flag to use a different branch - fetch last addr is just a hack)
        if (this->pParent_->GetLocalEndpoint().transportProtocol == ETransportProtocol::eProtocolUDP)
        {
#if defined(AURORA_IS_MODERNNT_DERIVED)
            AuStaticCast(outputChannel.ToWriteTransaction())->bDatagramMode = true;
            AuStaticCast(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#else
            AuStaticCast(outputChannel.ToWriteTransaction())->bDatagramMode = true;
            AuStaticCast(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#endif
        }

        this->inputChannel.OnEstablish();
    }

    // Returns a shared handle to the owning socket.
    AuSPtr SocketChannel::ToParent()
    {
        return this->pParent_->SharedFromThis();
    }

    // Stream-reader view of received data.
    // Uses the receive protocol stack's reader when one is installed; otherwise
    // wraps AsReadableByteBuffer().
    AuSPtr SocketChannel::AsStreamReader()
    {
        if (this->pRecvProtocol)
        {
            return this->pRecvProtocol->AsStreamReader();
        }
        else
        {
            return AuMakeShared(this->AsReadableByteBuffer());
        }
    }

    // Byte buffer holding received data.
    // Comes from the receive protocol stack when one is installed, otherwise from
    // the raw input channel.
    AuSPtr SocketChannel::AsReadableByteBuffer()
    {
        if (this->pRecvProtocol)
        {
            return this->pRecvProtocol->AsReadableByteBuffer();
        }
        else
        {
            return this->inputChannel.AsReadableByteBuffer();
        }
    }

    // Stream-writer view for outgoing data.
    // Uses the send protocol stack's writer when one is installed; otherwise wraps
    // AsWritableByteBuffer().
    AuSPtr SocketChannel::AsStreamWriter()
    {
        if (this->pSendProtocol)
        {
            return this->pSendProtocol->AsStreamWriter();
        }
        else
        {
            return AuMakeSharedThrow(this->AsWritableByteBuffer());
        }
    }

    // Byte buffer into which outgoing data is written.
    // With a send protocol stack installed this is the stack's writable (input)
    // end, mirroring AsStreamWriter, so data written here passes through the
    // protocol instead of bypassing it. Otherwise it is the raw output channel buffer.
    AuSPtr SocketChannel::AsWritableByteBuffer()
    {
        if (this->pSendProtocol)
        {
            return this->pSendProtocol->AsWritableByteBuffer();
        }
        else
        {
            return this->outputChannel.AsWritableByteBuffer();
        }
    }

    // Creates a receive-side protocol stack sourcing from the input channel's
    // buffer. Returns null on allocation failure. The stack is not installed;
    // see SpecifyRecvProtocol.
    AuSPtr SocketChannel::NewProtocolRecvStack()
    {
        auto pProtocol = AuMakeShared();
        if (!pProtocol)
        {
            return {};
        }

        pProtocol->pSourceBufer = this->inputChannel.AsReadableByteBuffer();
        return pProtocol;
    }

    // Creates a buffered send-side protocol stack sized to the current output
    // buffer size, draining into the output channel's writable buffer. Returns
    // null on failure. The stack is not installed; see SpecifySendProtocol.
    AuSPtr SocketChannel::NewProtocolSendStack()
    {
        auto pBaseProtocol = Protocol::NewBufferedProtocolStack(this->uBytesOutputBuffer);
        if (!pBaseProtocol)
        {
            return {};
        }

        auto pProtocol = AuStaticCast(pBaseProtocol);
        if (!pProtocol)
        {
            return {};
        }

        pProtocol->pDrainBuffer = this->outputChannel.AsWritableByteBuffer();
        return pProtocol;
    }

    // Installs (or, with null, removes) the receive protocol stack used by the AsStreamReader/AsReadableByteBuffer views.
    void SocketChannel::SpecifyRecvProtocol(const AuSPtr &pRecvProtocol)
    {
        this->pRecvProtocol = pRecvProtocol;
    }

    // Installs (or, with null, removes) the send protocol stack used by the writer views and ScheduleOutOfFrameWrite.
    void SocketChannel::SpecifySendProtocol(const AuSPtr &pSendProtocol)
    {
        this->pSendProtocol = pSendProtocol;
    }

    // Receive statistics, returned as a pointer to recvStats_ that shares the parent socket's ownership.
    AuSPtr SocketChannel::GetRecvStats()
    {
        return AuSPtr(this->pParent_->SharedFromThis(), &this->recvStats_);
    }

    // Send statistics, returned as a pointer to sendStats_ that shares the parent socket's ownership.
    AuSPtr SocketChannel::GetSendStats()
    {
        return AuSPtr(this->pParent_->SharedFromThis(), &this->sendStats_);
    }

    // Direct mutable access to the send statistics.
    SocketStats &SocketChannel::GetSendStatsEx()
    {
        return this->sendStats_;
    }

    // Direct mutable access to the receive statistics.
    SocketStats &SocketChannel::GetRecvStatsEx()
    {
        return this->recvStats_;
    }

    // Registers a data listener.
    // When called on the worker thread while data is already buffered, the
    // listener gets an immediate OnData(). If that returns false, the listener is
    // completed (OnComplete) and not registered.
    // NOTE(review): listener callbacks run while spinLock is held.
    void SocketChannel::AddEventListener(const AuSPtr &pListener)
    {
        AU_LOCK_GUARD(this->spinLock);

        auto pWorker = this->pParent_->ToWorkerEx();

        if (pWorker->IsOnThread() &&
            this->AsReadableByteBuffer()->GetNextLinearRead().length)
        {
            if (!pListener->OnData())
            {
                pListener->OnComplete();
                return;
            }
        }

        this->eventListeners.push_back(pListener);
    }

    // Unregisters a listener.
    // When called on the worker thread and the listener was registered, it
    // receives OnRejected().
    // NOTE(review): the callback runs while spinLock is held.
    void SocketChannel::RemoveEventListener(const AuSPtr &pListener)
    {
        AU_LOCK_GUARD(this->spinLock);

        bool bSuccess = AuTryRemove(this->eventListeners, pListener);
        auto pWorker = this->pParent_->ToWorkerEx();

        if (pWorker->IsOnThread() && bSuccess)
        {
            pListener->OnRejected();
        }
    }

    // True when both the output and the input channel report valid.
    bool SocketChannel::IsValid()
    {
        return bool(this->outputChannel.IsValid()) &&
               bool(this->inputChannel.IsValid());
    }

    // Requests that both buffers be resized to uBytes.
    // The resize runs inline on the socket's worker thread, or is scheduled onto
    // it. Returns false when there is no worker or scheduling fails. Completion is
    // reported through pCallbackOptional (see DoBufferResizeOnThread).
    bool SocketChannel::SpecifyBufferSize(AuUInt uBytes,
                                          const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
            return true;
        }

        // Pointer to this channel sharing the parent's ownership, captured so the
        // scheduled work can reach the channel.
        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
            {
                that->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
            }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // Output-only counterpart of SpecifyBufferSize.
    bool SocketChannel::SpecifyOutputBufferSize(AuUInt uBytes,
                                                const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
            return true;
        }

        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
            {
                that->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
            }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // Input-only counterpart of SpecifyBufferSize.
    bool SocketChannel::SpecifyInputBufferSize(AuUInt uBytes,
                                               const AuSPtr> &pCallbackOptional)
    {
        auto pWorker = this->pParent_->ToWorkerEx();
        if (!pWorker)
        {
            return false;
        }

        if (pWorker->IsOnThread())
        {
            this->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
            return true;
        }

        auto that = AuSPtr(this->pParent_->SharedFromThis(), this);
        if (!pWorker->TryScheduleInternalTemplate([=](const AuSPtr> &info)
            {
                that->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
            }, AuSPtr>{}))
        {
            return false;
        }

        return true;
    }

    // Ticks the send protocol stack (if any) so pending data reaches the output
    // channel, then asks the output channel to schedule a write.
    void SocketChannel::ScheduleOutOfFrameWrite()
    {
        if (this->pSendProtocol)
        {
            this->pSendProtocol->DoTick();
        }

        this->outputChannel.ScheduleOutOfFrameWrite();
    }

    // The read transaction, exposed only in manual-read mode; null otherwise.
    AuSPtr SocketChannel::ToReadTransaction()
    {
        return this->bIsManualRead ? this->inputChannel.pNetReadTransaction : AuSPtr {};
    }

    // The write transaction, exposed only in manual-write mode; null otherwise.
    AuSPtr SocketChannel::ToWriteTransaction()
    {
        return this->bIsManualWrite ? this->outputChannel.ToWriteTransaction() : AuSPtr {};
    }

    // Records the no-delay setting and applies it to the socket immediately.
    // Returns false (and changes nothing) unless manual-write mode is enabled.
    // NOTE(review): the manual-write gate is surprising for a socket option - confirm intent.
    bool SocketChannel::SpecifyTCPNoDelay(bool bFuckNagle)
    {
        if (!this->bIsManualWrite)
        {
            return false;
        }

        this->bTcpNoDelay = bFuckNagle;
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
        return true;
    }

    // Not supported by this implementation; always returns false.
    bool SocketChannel::SpecifyTransactionsHaveIOFence(bool bAllocateFence)
    {
        return false;
    }

    // Enables or disables manual (direct AIO) write mode.
    // Only possible before the channel is established; returns false afterwards.
    bool SocketChannel::SpecifyManualWrite(bool bEnableDirectAIOWrite)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        this->bIsManualWrite = bEnableDirectAIOWrite;
        return true;
    }

    // Enables or disables manual (direct AIO) read mode.
    // Only possible before the channel is established; returns false afterwards.
    bool SocketChannel::SpecifyManualRead(bool bEnableDirectAIORead)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        this->bIsManualRead = bEnableDirectAIORead;
        return true;
    }

    // Records the per-tick async read limit (in bytes).
    // Only possible before the channel is established; returns false afterwards.
    bool SocketChannel::SpecifyPerTickAsyncReadLimit(AuUInt uBytes)
    {
        if (this->bIsEstablished)
        {
            return false;
        }

        // Store the requested byte count (was previously the constant `true`, i.e. 1).
        this->uBytesToFlip = uBytes;
        return true;
    }

    // Worker-thread half of the Specify*BufferSize calls.
    //  * Before establishment: the requested sizes are recorded and
    //    pCallbackOptional succeeds immediately.
    //  * After establishment, output: the resize is applied immediately when
    //    outputChannel.CanResize() returns false. Otherwise it is queued for
    //    DoReallocWriteTick.
    //  * After establishment, input: the resize is always queued for
    //    DoReallocReadTick.
    //  * When both buffers are requested, pCallbackOptional tracks the output
    //    resize and the input resize is queued without a callback.
    //  * A newly queued request replaces any pending request for the same buffer,
    //    and the replaced request's callback is failed.
    // Callbacks are invoked after spinLock has been released.
    void SocketChannel::DoBufferResizeOnThread(bool bInput,
                                               bool bOutput,
                                               AuUInt uBytes,
                                               const AuSPtr> &pCallbackOptional)
    {
        if (!this->bIsEstablished)
        {
            if (bInput)
            {
                this->uBytesInputBuffer = uBytes;
            }

            if (bOutput)
            {
                this->uBytesOutputBuffer = uBytes;
            }

            NotifyResizeResult(pCallbackOptional, true);
            return;
        }

        // Set once pCallbackOptional belongs to (or has been completed for) the output resize.
        bool bCallbackClaimed {};
        bool bQueueOutput = bOutput;

        // NOTE(review): the immediate path runs when CanResize() is *false*; confirm
        // that this negation matches CanResize()'s contract.
        if (bOutput && !this->outputChannel.CanResize())
        {
            bool bResized = ReallocPreservingUnread(this->outputChannel.GetByteBuffer(), uBytes, false);
            if (bResized)
            {
                this->uBytesOutputBuffer = uBytes;
            }

            NotifyResizeResult(pCallbackOptional, bResized);

            if (!bResized || !bInput)
            {
                return;
            }

            // Output done; fall through so the input half is not lost.
            bCallbackClaimed = true;
            bQueueOutput = false;
        }

        // Callbacks of superseded requests; failed once the lock is released.
        decltype(this->pRetargetOutput) pSupersededOutput;
        decltype(this->pRetargetInput) pSupersededInput;

        {
            AU_LOCK_GUARD(this->spinLock);

            if (bQueueOutput)
            {
                if (this->uBytesOutputBufferRetarget)
                {
                    pSupersededOutput = this->pRetargetOutput;
                }

                this->uBytesOutputBufferRetarget = uBytes;
                this->pRetargetOutput = pCallbackOptional;
                bCallbackClaimed = true;
            }

            if (bInput)
            {
                if (this->uBytesInputBufferRetarget)
                {
                    pSupersededInput = this->pRetargetInput;
                }

                this->uBytesInputBufferRetarget = uBytes;

                if (bCallbackClaimed)
                {
                    this->pRetargetInput = decltype(this->pRetargetInput) {};
                }
                else
                {
                    this->pRetargetInput = pCallbackOptional;
                }
            }
        }

        NotifyResizeResult(pSupersededOutput, false);
        NotifyResizeResult(pSupersededInput, false);
    }

    // Applies a pending output-buffer resize, if any.
    // The request is consumed whether it succeeds or fails, so its callback is
    // completed exactly once, after spinLock has been released. On success the
    // recorded output buffer size is updated.
    void SocketChannel::DoReallocWriteTick()
    {
        decltype(this->pRetargetOutput) pCallback;
        bool bResized {};

        {
            AU_LOCK_GUARD(this->spinLock);

            auto uBytes = this->uBytesOutputBufferRetarget;
            if (!uBytes)
            {
                return;
            }

            pCallback = this->pRetargetOutput;
            this->pRetargetOutput = decltype(this->pRetargetOutput) {};
            this->uBytesOutputBufferRetarget = 0;

            bResized = ReallocPreservingUnread(this->outputChannel.GetByteBuffer(), uBytes, true);
            if (bResized)
            {
                this->uBytesOutputBuffer = uBytes;
            }
        }

        NotifyResizeResult(pCallback, bResized);
    }

    // Applies a pending input-buffer resize, if any.
    // The request is consumed whether it succeeds or fails, so its callback is
    // completed exactly once, after spinLock has been released. On success the
    // recorded input buffer size is updated. A null input buffer counts as failure.
    void SocketChannel::DoReallocReadTick()
    {
        decltype(this->pRetargetInput) pCallback;
        bool bResized {};

        {
            AU_LOCK_GUARD(this->spinLock);

            auto uBytes = this->uBytesInputBufferRetarget;
            if (!uBytes)
            {
                return;
            }

            pCallback = this->pRetargetInput;
            this->pRetargetInput = decltype(this->pRetargetInput) {};
            this->uBytesInputBufferRetarget = 0;

            auto pByteBuffer = this->inputChannel.AsReadableByteBuffer();
            if (pByteBuffer)
            {
                bResized = ReallocPreservingUnread(*pByteBuffer, uBytes, true);
            }

            if (bResized)
            {
                this->uBytesInputBuffer = uBytes;
            }
        }

        NotifyResizeResult(pCallback, bResized);
    }

    // Last recorded no-delay setting.
    bool SocketChannel::GetCurrentTCPNoDelay()
    {
        return this->bTcpNoDelay;
    }

    // Last recorded input buffer size in bytes.
    AuUInt SocketChannel::GetInputBufferSize()
    {
        return this->uBytesInputBuffer;
    }

    // Last recorded output buffer size in bytes.
    AuUInt SocketChannel::GetOutputBufferSize()
    {
        return this->uBytesOutputBuffer;
    }
}