AuroraRuntime/Source/IO/Net/AuNetSocketChannel.cpp
Jamie Reece Wilson 087bac4085 [+] AuByteBuffer::flagNoRealloc
[*] Fix IO regression / Critical Bug / Leak and stupid double free
2023-10-29 20:36:11 +00:00

740 lines
21 KiB
C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuNetSocketChannel.cpp
Date: 2022-8-16
Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocketChannel.hpp"
#include "AuNetSocket.hpp"
#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#else
#include "AuNetStream.Linux.hpp"
#endif
#include <Source/IO/AuIOPipeProcessor.hpp>
#include "AuNetWorker.hpp"
#include <Source/IO/Protocol/Protocol.hpp>
#include <Source/IO/Protocol/AuProtocolStack.hpp>
namespace Aurora::IO::Net
{
// Initial value of SocketChannel::bTcpNoDelay (see the constructor's initializer list).
static constexpr bool kDefaultFuckNagle { true };
// Constructs the channel for the given parent socket.
// Both the output and the input sub-channel receive the parent plus their own freshly
// allocated platform transaction object; the buffer sizes start at kDefaultStreamSize and
// the TCP no-delay preference starts at kDefaultFuckNagle.
SocketChannel::SocketChannel(SocketBase *pParent) :
pParent_(pParent),
#if defined(AURORA_IS_MODERNNT_DERIVED)
// NT-derived builds use NtAsyncNetworkTransaction for both directions
outputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<NtAsyncNetworkTransaction>(pParent))),
inputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<NtAsyncNetworkTransaction>(pParent))),
#else
// Every other build uses LinuxAsyncNetworkTransaction for both directions
outputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<LinuxAsyncNetworkTransaction>(pParent))),
inputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<LinuxAsyncNetworkTransaction>(pParent))),
#endif
uBytesInputBuffer(kDefaultStreamSize),
uBytesOutputBuffer(kDefaultStreamSize),
bTcpNoDelay(kDefaultFuckNagle)
{
}
// Performs the per-connection setup of the channel:
//  * TCP sockets get the channel's configured no-delay preference applied,
//  * UDP sockets have both transactions flagged as datagram mode,
//  * finally the input channel is told the connection is established.
void SocketChannel::Establish()
{
    const auto eProtocol = this->pParent_->GetLocalEndpoint().transportProtocol;

    if (eProtocol == ETransportProtocol::eProtocolTCP)
    {
        // Forward the stored preference instead of a hard-coded `true`, so that a value set
        // through SpecifyTCPNoDelay() before establishment is not silently overridden and
        // GetCurrentTCPNoDelay() keeps matching what was handed to the socket.
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
    }

    // Switch over to extended send/recv (it's just a flag to use a different branch - fetch last addr is just a hack)
    if (eProtocol == ETransportProtocol::eProtocolUDP)
    {
    #if defined(AURORA_IS_MODERNNT_DERIVED)
        AuStaticCast<NtAsyncNetworkTransaction>(outputChannel.ToWriteTransaction())->bDatagramMode = true;
        AuStaticCast<NtAsyncNetworkTransaction>(inputChannel.pNetReadTransaction)->bDatagramMode = true;
    #else
        AuStaticCast<LinuxAsyncNetworkTransaction>(outputChannel.ToWriteTransaction())->bDatagramMode = true;
        AuStaticCast<LinuxAsyncNetworkTransaction>(inputChannel.pNetReadTransaction)->bDatagramMode = true;
    #endif
    }

    this->inputChannel.OnEstablish();
}
// Returns the owning socket, obtained through the parent's SharedFromThis().
AuSPtr<ISocket> SocketChannel::ToParent()
{
return this->pParent_->SharedFromThis();
}
// Returns a stream reader for inbound data.
// A configured receive protocol stack always wins; otherwise a blob reader over the
// channel's readable byte buffer is created on first use and cached in pCachedReader.
AuSPtr<IStreamReader> SocketChannel::AsStreamReader()
{
    // Blob reader whose Close() shuts the parent socket down, provided the channel is still alive.
    struct ChannelBlobReader : AuIO::Buffered::BlobReader
    {
        AuWPtr<SocketChannel> swpChannel;

        ChannelBlobReader(AuSPtr<AuByteBuffer> pBuffer,
                          AuWPtr<SocketChannel> swpChannel) :
            AuIO::Buffered::BlobReader(pBuffer),
            swpChannel(swpChannel)
        {
        }

        virtual void Close() override
        {
            auto pChannel = AuTryLockMemoryType(this->swpChannel);
            if (pChannel)
            {
                pChannel->pParent_->Shutdown(false);
            }
        }
    };

    if (this->pRecvProtocol)
    {
        return this->pRecvProtocol->AsStreamReader();
    }

    if (this->pCachedReader)
    {
        return this->pCachedReader;
    }

    // The channel pointer is built from the parent's shared pointer and `this`,
    // so the reader only ever holds a weak reference to the channel.
    auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);
    this->pCachedReader = AuMakeSharedThrow<ChannelBlobReader>(this->AsReadableByteBuffer(), pSelf);
    return this->pCachedReader;
}
// Returns the buffer inbound data should be read from: the receive protocol stack's
// readable buffer when a stack is installed, otherwise the input channel's own buffer.
AuSPtr<Memory::ByteBuffer> SocketChannel::AsReadableByteBuffer()
{
    if (!this->pRecvProtocol)
    {
        return this->inputChannel.AsReadableByteBuffer();
    }

    return this->pRecvProtocol->AsReadableByteBuffer();
}
// Returns a stream writer for outbound data, creating and caching it on first use.
// With a send protocol stack installed, writes are forwarded to the stack's stream writer;
// without one, a blob writer over the channel's writable byte buffer is used.
// In both cases Close() shuts the parent socket down and Flush() schedules an
// out-of-frame write on the parent's channel.
AuSPtr<IStreamWriter> SocketChannel::AsStreamWriter()
{
    // Forwards to the protocol stack's writer through a weak reference.
    struct ProtocolStreamWriter : IO::IStreamWriter
    {
        AuWPtr<IO::IStreamWriter> swpIStreamWriter;
        AuWPtr<SocketChannel> swpChannel;

        ProtocolStreamWriter(AuWPtr<SocketChannel> swpChannel,
                             AuWPtr<IStreamWriter> swpIStreamWriter) :
            swpIStreamWriter(swpIStreamWriter),
            swpChannel(swpChannel)
        {
        }

        virtual EStreamError IsOpen() override
        {
            // Open for as long as the underlying writer can still be locked
            return AuTryLockMemoryType(this->swpIStreamWriter) ?
                   EStreamError::eErrorNone :
                   EStreamError::eErrorHandleClosed;
        }

        virtual EStreamError Write(const AuMemoryViewStreamRead &read) override
        {
            auto pStream = AuTryLockMemoryType(this->swpIStreamWriter);
            if (!pStream)
            {
                return EStreamError::eErrorHandleClosed;
            }
            return pStream->Write(read);
        }

        virtual void Close() override
        {
            auto pChannel = AuTryLockMemoryType(this->swpChannel);
            if (pChannel)
            {
                pChannel->pParent_->Shutdown(false);
            }
        }

        virtual void Flush() override
        {
            // Flush the protocol writer first, then kick the socket's output
            auto pStream = AuTryLockMemoryType(this->swpIStreamWriter);
            if (pStream)
            {
                pStream->Flush();
            }

            auto pChannel = AuTryLockMemoryType(this->swpChannel);
            if (pChannel)
            {
                pChannel->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
            }
        }
    };

    // Writes straight into the channel's writable byte buffer.
    struct ChannelBlobWriter : IO::Buffered::BlobWriter
    {
        AuWPtr<SocketChannel> swpChannel;

        ChannelBlobWriter(AuSPtr<AuByteBuffer> pBuffer,
                          AuWPtr<SocketChannel> swpChannel) :
            IO::Buffered::BlobWriter(pBuffer),
            swpChannel(swpChannel)
        {
        }

        virtual void Close() override
        {
            auto pChannel = AuTryLockMemoryType(this->swpChannel);
            if (pChannel)
            {
                pChannel->pParent_->Shutdown(false);
            }
        }

        virtual void Flush() override
        {
            auto pChannel = AuTryLockMemoryType(this->swpChannel);
            if (pChannel)
            {
                pChannel->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
            }
        }
    };

    if (this->pCachedWriter)
    {
        return this->pCachedWriter;
    }

    auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);

    if (this->pSendProtocol)
    {
        this->pCachedWriter = AuMakeSharedThrow<ProtocolStreamWriter>(pSelf,
                                                                      this->pSendProtocol->AsStreamWriter());
    }
    else
    {
        this->pCachedWriter = AuMakeSharedThrow<ChannelBlobWriter>(this->AsWritableByteBuffer(), pSelf);
    }

    return this->pCachedWriter;
}
// Returns the buffer outbound data should be written to: the output channel's writable
// buffer when no send protocol stack is installed, otherwise a buffer taken from the stack.
AuSPtr<Memory::ByteBuffer> SocketChannel::AsWritableByteBuffer()
{
    if (!this->pSendProtocol)
    {
        return this->outputChannel.AsWritableByteBuffer();
    }

    // NOTE(review): this hands out the send stack's *readable* buffer from a function that
    // promises a writable one - looks like a copy-paste from AsReadableByteBuffer().
    // Confirm whether the protocol stack's writable-side accessor was intended here.
    return this->pSendProtocol->AsReadableByteBuffer();
}
// Allocates a protocol stack whose source buffer is the input channel's readable buffer.
// Returns an empty pointer (after pushing a memory error) if the allocation fails.
AuSPtr<Protocol::IProtocolStack> SocketChannel::NewProtocolRecvStack()
{
    auto pStack = AuMakeShared<Protocol::ProtocolStack>();

    if (pStack)
    {
        pStack->pSourceBufer = this->inputChannel.AsReadableByteBuffer();
        pStack->bKillPipeOnFirstRootLevelFalse = true; // harden
        return pStack;
    }

    SysPushErrorMemory();
    return {};
}
// Creates a buffered protocol stack sized by uBytesOutputBuffer that drains into the
// output channel's writable buffer. Returns an empty pointer on failure; a memory error
// is pushed only when the stack itself could not be created.
AuSPtr<Protocol::IProtocolStack> SocketChannel::NewProtocolSendStack()
{
    auto pBufferedStack = Protocol::NewBufferedProtocolStack(this->uBytesOutputBuffer);
    if (!pBufferedStack)
    {
        SysPushErrorMemory();
        return {};
    }

    auto pStack = AuStaticCast<Protocol::ProtocolStack>(pBufferedStack);
    if (pStack)
    {
        pStack->pDrainBuffer = this->outputChannel.AsWritableByteBuffer();
        pStack->bKillPipeOnFirstRootLevelFalse = true; // harden
        return pStack;
    }

    return {};
}
// Installs the receive-side protocol stack (an empty pointer removes it) and drops the
// cached stream reader so a later AsStreamReader() call does not return a stale one.
void SocketChannel::SpecifyRecvProtocol(const AuSPtr<Protocol::IProtocolStack> &pRecvProtocol)
{
this->pRecvProtocol = pRecvProtocol;
this->pCachedReader.reset();
}
// Installs the send-side protocol stack (an empty pointer removes it).
// AsStreamWriter() checks its cache before looking at the protocol, so the cached writer
// has to be dropped here for the new stack to take effect.
void SocketChannel::SpecifySendProtocol(const AuSPtr<Protocol::IProtocolStack> &pSendProtocol)
{
this->pSendProtocol = pSendProtocol;
this->pCachedWriter.reset();
}
// Forwards the requested next-frame target length to the input channel's net reader.
// The reader's return value is deliberately discarded.
// Does nothing once the reader is gone: Release() resets inputChannel.pNetReader, and
// dereferencing it afterwards would be a null access.
void SocketChannel::SetNextFrameTargetLength(AuUInt uNextFrameSize)
{
    if (!this->inputChannel.pNetReader)
    {
        return;
    }

    (void)this->inputChannel.pNetReader->SetNextFrameTargetLength(uNextFrameSize);
}
// Returns the input channel net reader's current next-frame target length.
// Returns 0 once the reader is gone: Release() resets inputChannel.pNetReader, and
// dereferencing it afterwards would be a null access.
AuUInt SocketChannel::GetNextFrameTargetLength()
{
    if (!this->inputChannel.pNetReader)
    {
        return 0;
    }

    return this->inputChannel.pNetReader->GetNextFrameTargetLength();
}
// Returns the receive statistics as a shared pointer built from the parent socket's
// shared pointer plus the address of the recvStats_ member.
// NOTE(review): presumably the shared_ptr aliasing form, i.e. the result keeps the parent
// socket alive rather than owning the stats object - confirm against AuSPtr's definition.
AuSPtr<ISocketStats> SocketChannel::GetRecvStats()
{
return AuSPtr<ISocketStats>(this->pParent_->SharedFromThis(), &this->recvStats_);
}
// Returns the send statistics as a shared pointer built from the parent socket's
// shared pointer plus the address of the sendStats_ member.
// NOTE(review): presumably the shared_ptr aliasing form, i.e. the result keeps the parent
// socket alive rather than owning the stats object - confirm against AuSPtr's definition.
AuSPtr<ISocketStats> SocketChannel::GetSendStats()
{
return AuSPtr<ISocketStats>(this->pParent_->SharedFromThis(), &this->sendStats_);
}
// Direct mutable access to the send statistics member; the reference is only valid while
// this channel object exists.
SocketStats &SocketChannel::GetSendStatsEx()
{
return this->sendStats_;
}
// Direct mutable access to the receive statistics member; the reference is only valid
// while this channel object exists.
SocketStats &SocketChannel::GetRecvStatsEx()
{
return this->recvStats_;
}
// Registers a channel event listener.
// When called on the worker thread while readable data is already buffered, the listener
// is given one OnData() call up front; if it returns false the listener is completed
// immediately and never added to the list. The whole operation runs under spinLock.
void SocketChannel::AddEventListener(const AuSPtr<ISocketChannelEventListener> &pListener)
{
    AU_LOCK_GUARD(this->spinLock);

    auto pWorker = this->pParent_->ToWorkerEx();

    // The buffer is only inspected when we are on the worker thread (short-circuit)
    const bool bDataAlreadyPending = pWorker->IsOnThread() &&
                                     this->AsReadableByteBuffer()->GetNextLinearRead().length;

    if (bDataAlreadyPending && !pListener->OnData())
    {
        pListener->OnComplete();
        return;
    }

    this->eventListeners.push_back(pListener);
}
// Unregisters a channel event listener under spinLock.
// OnRejected() is delivered only when the listener was actually removed and the call is
// made on the worker thread.
void SocketChannel::RemoveEventListener(const AuSPtr<ISocketChannelEventListener> &pListener)
{
    AU_LOCK_GUARD(this->spinLock);

    const bool bWasRegistered = AuTryRemove(this->eventListeners, pListener);

    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker->IsOnThread())
    {
        return;
    }

    if (bWasRegistered)
    {
        pListener->OnRejected();
    }
}
// The channel is valid only when both the output and the input sub-channel report valid.
// The input side is not queried when the output side is already invalid.
bool SocketChannel::IsValid()
{
    if (!static_cast<bool>(this->outputChannel.IsValid()))
    {
        return false;
    }

    return static_cast<bool>(this->inputChannel.IsValid());
}
// Requests a resize of both the input and the output buffer to uBytes.
// On the worker thread the resize logic runs inline; from any other thread it is scheduled
// onto the worker. Returns false when there is no worker or scheduling fails.
// pCallbackOptional is handed on to DoBufferResizeOnThread.
bool SocketChannel::SpecifyBufferSize(AuUInt uBytes,
                                      const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        // Keep the channel (via its parent socket) alive until the scheduled job has run
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);

        return static_cast<bool>(pWorker->TryScheduleInternalTemplate<AuNullS>([pSelf, uBytes, pCallbackOptional](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &info)
        {
            pSelf->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
        }, AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
    }

    this->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
    return true;
}
// Requests a resize of the output buffer only.
// On the worker thread the resize logic runs inline; from any other thread it is scheduled
// onto the worker. Returns false when there is no worker or scheduling fails.
// pCallbackOptional is handed on to DoBufferResizeOnThread.
bool SocketChannel::SpecifyOutputBufferSize(AuUInt uBytes,
                                            const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        // Keep the channel (via its parent socket) alive until the scheduled job has run
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);

        return static_cast<bool>(pWorker->TryScheduleInternalTemplate<AuNullS>([pSelf, uBytes, pCallbackOptional](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &info)
        {
            pSelf->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
        }, AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
    }

    this->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
    return true;
}
// Requests a resize of the input buffer only.
// On the worker thread the resize logic runs inline; from any other thread it is scheduled
// onto the worker. Returns false when there is no worker or scheduling fails.
// pCallbackOptional is handed on to DoBufferResizeOnThread.
bool SocketChannel::SpecifyInputBufferSize(AuUInt uBytes,
                                           const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        // Keep the channel (via its parent socket) alive until the scheduled job has run
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);

        return static_cast<bool>(pWorker->TryScheduleInternalTemplate<AuNullS>([pSelf, uBytes, pCallbackOptional](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &info)
        {
            pSelf->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
        }, AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
    }

    this->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
    return true;
}
// Requests an out-of-frame write on the output channel.
// If a send protocol stack is installed it is ticked first, before the output channel is
// asked to schedule the write.
void SocketChannel::ScheduleOutOfFrameWrite()
{
if (this->pSendProtocol)
{
this->pSendProtocol->DoTick();
}
this->outputChannel.ScheduleOutOfFrameWrite();
}
// Exposes the raw read transaction, but only when manual read mode is enabled;
// otherwise an empty pointer is returned.
AuSPtr<IAsyncTransaction> SocketChannel::ToReadTransaction()
{
    if (!this->bIsManualRead)
    {
        return AuSPtr<IAsyncTransaction> {};
    }

    return this->inputChannel.pNetReadTransaction;
}
// Exposes the raw write transaction, but only when manual write mode is enabled;
// otherwise an empty pointer is returned.
AuSPtr<IAsyncTransaction> SocketChannel::ToWriteTransaction()
{
    if (!this->bIsManualWrite)
    {
        return AuSPtr<IAsyncTransaction> {};
    }

    return this->outputChannel.ToWriteTransaction();
}
// Stores the requested no-delay preference and pushes it to the parent socket.
// Only honoured when manual write mode is enabled; otherwise nothing changes and false
// is returned.
bool SocketChannel::SpecifyTCPNoDelay(bool bFuckNagle)
{
    const bool bPermitted = this->bIsManualWrite;

    if (bPermitted)
    {
        this->bTcpNoDelay = bFuckNagle;
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
    }

    return bPermitted;
}
// Not implemented in this file: bAllocateFence is ignored and the call always reports failure.
bool SocketChannel::SpecifyTransactionsHaveIOFence(bool bAllocateFence)
{
return false;
}
// Enables or disables manual (direct AIO) write mode.
// Only allowed while the channel is not yet established; returns false otherwise.
// The flag gates ToWriteTransaction() and SpecifyTCPNoDelay().
bool SocketChannel::SpecifyManualWrite(bool bEnableDirectAIOWrite)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Honour the caller's choice; previously the argument was ignored and the mode was
    // unconditionally switched on, so it could never be turned back off.
    this->bIsManualWrite = bEnableDirectAIOWrite;
    return true;
}
// Enables or disables manual (direct AIO) read mode.
// Only allowed while the channel is not yet established; returns false otherwise.
// The flag gates ToReadTransaction().
bool SocketChannel::SpecifyManualRead(bool bEnableDirectAIORead)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Honour the caller's choice; previously the argument was ignored and the mode was
    // unconditionally switched on, so it could never be turned back off.
    this->bIsManualRead = bEnableDirectAIORead;
    return true;
}
// Records the per-tick asynchronous read limit, in bytes, in uBytesToFlip.
// Only allowed while the channel is not yet established; returns false otherwise.
bool SocketChannel::SpecifyPerTickAsyncReadLimit(AuUInt uBytes)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Store the requested byte count; previously `true` was assigned, which collapsed
    // every requested limit to 1.
    this->uBytesToFlip = uBytes;
    return true;
}
// Calls End() on both the send and the receive statistics.
void SocketChannel::StopTime()
{
this->sendStats_.End();
this->recvStats_.End();
}
// Calls Start() on both the send and the receive statistics.
void SocketChannel::StartTime()
{
this->sendStats_.Start();
this->recvStats_.Start();
}
// Tears the channel's attachments down: ends the statistics, clears the private user data,
// and drops the cached reader/writer, both protocol stacks and the input channel's net reader.
// After this call inputChannel.pNetReader is empty.
void SocketChannel::Release()
{
this->StopTime();
this->PrivateUserDataClear();
this->pCachedReader.reset();
this->pCachedWriter.reset();
this->pRecvProtocol.reset();
this->pSendProtocol.reset();
this->inputChannel.pNetReader.reset();
}
// Resizes `buffer` to uLength with its flagNoRealloc temporarily cleared, then restores
// the flag to whatever it was before. Returns the result of AuByteBuffer::Resize.
static bool ForceResize(AuByteBuffer &buffer, AuUInt uLength)
{
    const auto bSavedNoRealloc = buffer.flagNoRealloc;

    buffer.flagNoRealloc = false;
    const bool bResized = buffer.Resize(uLength);
    buffer.flagNoRealloc = bSavedNoRealloc;

    return bResized;
}
// Applies or queues a buffer resize request.
//  bInput / bOutput   - which buffer(s) the request targets
//  uBytes             - requested size
//  pCallbackOptional  - optional completion callback; may be empty
// Before establishment the sizes are merely recorded and success is reported right away.
// After establishment the output buffer is resized immediately when the output channel
// allows it; anything that could not be done immediately is stored as a pending
// "retarget" (size + callback) for DoReallocWriteTick / DoReallocReadTick.
void SocketChannel::DoBufferResizeOnThread(bool bInput,
bool bOutput,
AuUInt uBytes,
const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
// Not established yet: only remember the sizes
if (!this->bIsEstablished)
{
if (bInput)
{
this->uBytesInputBuffer = uBytes;
}
if (bOutput)
{
this->uBytesOutputBuffer = uBytes;
}
if (pCallbackOptional)
{
pCallbackOptional->OnSuccess((void *)nullptr);
}
return;
}
bool bHasCurrentSuccess {};
// Try the output resize right now if the output channel permits it
if (bOutput)
{
if (this->outputChannel.CanResize())
{
if (ForceResize(this->outputChannel.GetByteBuffer(), uBytes))
{
// Output is done; the request is fully satisfied only if input was not requested too
bOutput = false;
bHasCurrentSuccess = !bInput;
}
}
}
if (bHasCurrentSuccess)
{
if (pCallbackOptional)
{
pCallbackOptional->OnSuccess((void *)nullptr);
}
return;
}
// Whatever is still outstanding gets queued under the lock
{
AU_LOCK_GUARD(this->spinLock);
if (bOutput)
{
// A previously queued output request is superseded: fail its callback, and if the
// same callback object was also queued for the input side, forget it there as well
if (this->uBytesOutputBufferRetarget)
{
if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
{
pOutput->OnFailure((void *)nullptr);
if (this->pRetargetInput == pOutput)
{
this->pRetargetInput = {};
}
}
}
this->uBytesOutputBufferRetarget = uBytes;
this->pRetargetOutput = pCallbackOptional;
}
if (bInput)
{
// Same supersede handling for a previously queued input request
if (this->uBytesInputBufferRetarget)
{
if (auto pInput = AuExchange(this->pRetargetInput, {}))
{
pInput->OnFailure((void *)nullptr);
if (this->pRetargetOutput == pInput)
{
this->pRetargetOutput = {};
}
}
}
this->uBytesInputBufferRetarget = uBytes;
this->pRetargetInput = pCallbackOptional;
}
}
}
// Applies a pending output-buffer retarget, if there is one, under spinLock.
// On failure the queued output callback is failed (and dropped from the input side if it
// is the same object). On success the pending size is cleared and the callback succeeds,
// unless the same callback is still queued for the input side, in which case reporting is
// left to the input path.
void SocketChannel::DoReallocWriteTick()
{
AU_LOCK_GUARD(this->spinLock);
// Nothing queued
if (!this->uBytesOutputBufferRetarget)
{
return;
}
if (!ForceResize(this->outputChannel.GetByteBuffer(), uBytesOutputBufferRetarget))
{
SysPushErrorMemory();
if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
{
pOutput->OnFailure((void *)nullptr);
if (this->pRetargetInput == pOutput)
{
this->pRetargetInput = {};
}
}
// NOTE(review): uBytesOutputBufferRetarget is left set on this path, so the resize
// would be attempted again on the next call - confirm that is intended.
return;
}
this->uBytesOutputBufferRetarget = 0;
if (auto pOutput = AuExchange(this->pRetargetOutput, {}))
{
// A callback shared with a still-pending input retarget is not completed here
if (this->pRetargetInput != pOutput)
{
pOutput->OnSuccess((void *)nullptr);
}
}
}
// Applies a pending input-buffer retarget, if there is one, under spinLock.
// On failure the queued input callback is failed (and dropped from the output side if it
// is the same object). On success the pending size is cleared and the callback succeeds,
// unless the same callback is still queued for the output side, in which case reporting is
// left to the output path.
void SocketChannel::DoReallocReadTick()
{
    AU_LOCK_GUARD(this->spinLock);

    // Nothing queued
    if (!this->uBytesInputBufferRetarget)
    {
        return;
    }

    // Resize to the *input* retarget size. The previous code passed
    // uBytesOutputBufferRetarget here, so the input buffer was resized to whatever the
    // output side had pending (0 when nothing was pending) instead of the requested size.
    if (!ForceResize(*this->inputChannel.AsReadableByteBuffer(), this->uBytesInputBufferRetarget))
    {
        SysPushErrorMemory();

        if (auto pInput = AuExchange(this->pRetargetInput, {}))
        {
            pInput->OnFailure((void *)nullptr);

            if (this->pRetargetOutput == pInput)
            {
                this->pRetargetOutput = {};
            }
        }

        // NOTE(review): uBytesInputBufferRetarget is left set on this path, so the resize
        // would be attempted again on the next call - confirm that is intended.
        return;
    }

    this->uBytesInputBufferRetarget = 0;

    if (auto pInput = AuExchange(this->pRetargetInput, {}))
    {
        // A callback shared with a still-pending output retarget is not completed here
        if (this->pRetargetOutput != pInput)
        {
            pInput->OnSuccess((void *)nullptr);
        }
    }
}
// Returns the stored no-delay preference (bTcpNoDelay); the socket itself is not queried.
bool SocketChannel::GetCurrentTCPNoDelay()
{
return this->bTcpNoDelay;
}
// Returns the recorded input buffer size (uBytesInputBuffer).
// In the code visible in this file that value is set by the constructor and by
// DoBufferResizeOnThread while the channel is not yet established; resizes applied after
// establishment do not update it.
AuUInt SocketChannel::GetInputBufferSize()
{
return this->uBytesInputBuffer;
}
// Returns the recorded output buffer size (uBytesOutputBuffer).
// In the code visible in this file that value is set by the constructor and by
// DoBufferResizeOnThread while the channel is not yet established; resizes applied after
// establishment do not update it.
AuUInt SocketChannel::GetOutputBufferSize()
{
return this->uBytesOutputBuffer;
}
}