AuroraRuntime/Source/IO/Net/AuNetSocketChannel.cpp
2023-06-04 22:19:57 +01:00

743 lines
22 KiB
C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuNetSocketChannel.cpp
Date: 2022-8-16
Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocketChannel.hpp"
#include "AuNetSocket.hpp"
#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#else
#include "AuNetStream.Linux.hpp"
#endif
#include <Source/IO/AuIOPipeProcessor.hpp>
#include "AuNetWorker.hpp"
#include <Source/IO/Protocol/Protocol.hpp>
#include <Source/IO/Protocol/AuProtocolStack.hpp>
namespace Aurora::IO::Net
{
// Default TCP_NODELAY preference for new channels (seeds SocketChannel::bTcpNoDelay).
static constexpr bool kDefaultFuckNagle = true;
SocketChannel::SocketChannel(SocketBase *pParent) :
pParent_(pParent),
#if defined(AURORA_IS_MODERNNT_DERIVED)
outputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<NtAsyncNetworkTransaction>(pParent))),
inputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<NtAsyncNetworkTransaction>(pParent))),
#else
outputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<LinuxAsyncNetworkTransaction>(pParent))),
inputChannel(pParent, AuStaticCast<IAsyncTransaction>(AuMakeShared<LinuxAsyncNetworkTransaction>(pParent))),
#endif
uBytesInputBuffer(kDefaultStreamSize),
uBytesOutputBuffer(kDefaultStreamSize),
bTcpNoDelay(kDefaultFuckNagle)
{
}
/**
 * Finalizes the channel once its socket is established.
 *
 * - TCP: applies the channel's stored TCP_NODELAY preference (bTcpNoDelay) to the socket.
 * - UDP: switches both the read and the write transaction into datagram mode, which selects
 *   the extended send/recv branch.
 * Finally lets the input channel run its own establishment step.
 */
void SocketChannel::Establish()
{
    const auto transportProtocol = this->pParent_->GetLocalEndpoint().transportProtocol;

    if (transportProtocol == ETransportProtocol::eProtocolTCP)
    {
        // Use the stored preference instead of a hard-coded `true`. Otherwise a value set via
        // SpecifyTCPNoDelay before establishment is overridden, and GetCurrentTCPNoDelay()
        // stops matching the socket. The default (kDefaultFuckNagle) is still true.
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
    }

    // Switch over to extended send/recv (it is just a flag to use a different branch - fetch last addr is just a hack)
    if (transportProtocol == ETransportProtocol::eProtocolUDP)
    {
#if defined(AURORA_IS_MODERNNT_DERIVED)
        AuStaticCast<NtAsyncNetworkTransaction>(outputChannel.ToWriteTransaction())->bDatagramMode = true;
        AuStaticCast<NtAsyncNetworkTransaction>(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#else
        AuStaticCast<LinuxAsyncNetworkTransaction>(outputChannel.ToWriteTransaction())->bDatagramMode = true;
        AuStaticCast<LinuxAsyncNetworkTransaction>(inputChannel.pNetReadTransaction)->bDatagramMode = true;
#endif
    }

    this->inputChannel.OnEstablish();
}
// Returns a strong reference to the socket that owns this channel.
AuSPtr<ISocket> SocketChannel::ToParent()
{
    AuSPtr<ISocket> pSocket = this->pParent_->SharedFromThis();
    return pSocket;
}
/**
 * Returns a stream reader over the channel's inbound data.
 *
 * If a receive protocol is installed, its own reader is returned; that result is not cached
 * here. Otherwise a BlobReader over AsReadableByteBuffer() is built on first use and cached
 * in pCachedReader. That reader's Close() calls Shutdown(false) on the parent socket.
 */
AuSPtr<IStreamReader> SocketChannel::AsStreamReader()
{
    if (this->pRecvProtocol)
    {
        return this->pRecvProtocol->AsStreamReader();
    }

    if (!this->pCachedReader)
    {
        // Blob reader that maps Close() onto a socket shutdown; it only weakly references the channel.
        struct ChannelBlobReader : AuIO::Buffered::BlobReader
        {
            AuWPtr<SocketChannel> swpChannel;

            ChannelBlobReader(AuSPtr<AuByteBuffer> pBuffer,
                              AuWPtr<SocketChannel> swpChannel) :
                AuIO::Buffered::BlobReader(pBuffer),
                swpChannel(swpChannel)
            {
            }

            void Close() override
            {
                auto pOwner = AuTryLockMemoryType(this->swpChannel);
                if (!pOwner)
                {
                    return;
                }
                pOwner->pParent_->Shutdown(false);
            }
        };

        // Aliasing pointer: shares ownership with the parent socket but points at this channel.
        AuSPtr<SocketChannel> pSelf(this->pParent_->SharedFromThis(), this);
        this->pCachedReader = AuMakeSharedThrow<ChannelBlobReader>(this->AsReadableByteBuffer(), pSelf);
    }

    return this->pCachedReader;
}
// Inbound bytes: the raw input channel buffer, or the receive protocol's readable buffer when one is installed.
AuSPtr<Memory::ByteBuffer> SocketChannel::AsReadableByteBuffer()
{
    if (!this->pRecvProtocol)
    {
        return this->inputChannel.AsReadableByteBuffer();
    }

    return this->pRecvProtocol->AsReadableByteBuffer();
}
/**
 * Returns the stream writer for outbound data, creating and caching it (pCachedWriter) on first use.
 *
 * - With a send protocol installed: the writer forwards IsOpen/Write/Flush to the protocol's
 *   writer, which it holds weakly. Flush additionally schedules an out-of-frame write on the
 *   parent's channel.
 *   NOTE(review): if the protocol stack does not itself retain the object returned by its
 *   AsStreamWriter(), the weak reference expires and writes report eErrorHandleClosed - confirm.
 * - Without one: a BlobWriter over AsWritableByteBuffer() whose Flush schedules an
 *   out-of-frame write.
 * In both cases Close() calls Shutdown(false) on the parent socket. SpecifySendProtocol and
 * Release drop the cache.
 */
AuSPtr<IStreamWriter> SocketChannel::AsStreamWriter()
{
    if (this->pCachedWriter)
    {
        return this->pCachedWriter;
    }

    if (!this->pSendProtocol)
    {
        // Plain buffered writer straight onto the writable byte buffer.
        struct ChannelBlobWriter : IO::Buffered::BlobWriter
        {
            AuWPtr<SocketChannel> swpChannel;

            ChannelBlobWriter(AuSPtr<AuByteBuffer> pBuffer,
                              AuWPtr<SocketChannel> swpChannel) :
                IO::Buffered::BlobWriter(pBuffer),
                swpChannel(swpChannel)
            {
            }

            void Close() override
            {
                auto pOwner = AuTryLockMemoryType(this->swpChannel);
                if (!pOwner)
                {
                    return;
                }
                pOwner->pParent_->Shutdown(false);
            }

            void Flush() override
            {
                auto pOwner = AuTryLockMemoryType(this->swpChannel);
                if (!pOwner)
                {
                    return;
                }
                pOwner->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
            }
        };

        AuSPtr<SocketChannel> pSelf(this->pParent_->SharedFromThis(), this);
        this->pCachedWriter = AuMakeSharedThrow<ChannelBlobWriter>(this->AsWritableByteBuffer(), pSelf);
        return this->pCachedWriter;
    }

    // Forwards to the send protocol's writer; both targets are only weakly referenced.
    struct ProtocolForwardingWriter : IO::IStreamWriter
    {
        AuWPtr<IO::IStreamWriter> swpIStreamWriter;
        AuWPtr<SocketChannel> swpChannel;

        ProtocolForwardingWriter(AuWPtr<SocketChannel> swpChannel,
                                 AuWPtr<IStreamWriter> swpIStreamWriter) :
            swpIStreamWriter(swpIStreamWriter),
            swpChannel(swpChannel)
        {
        }

        EStreamError IsOpen() override
        {
            return AuTryLockMemoryType(this->swpIStreamWriter) ?
                EStreamError::eErrorNone :
                EStreamError::eErrorHandleClosed;
        }

        EStreamError Write(const AuMemoryViewStreamRead &read) override
        {
            auto pTarget = AuTryLockMemoryType(this->swpIStreamWriter);
            if (!pTarget)
            {
                return EStreamError::eErrorHandleClosed;
            }
            return pTarget->Write(read);
        }

        void Close() override
        {
            auto pOwner = AuTryLockMemoryType(this->swpChannel);
            if (!pOwner)
            {
                return;
            }
            pOwner->pParent_->Shutdown(false);
        }

        void Flush() override
        {
            auto pTarget = AuTryLockMemoryType(this->swpIStreamWriter);
            if (pTarget)
            {
                pTarget->Flush();
            }

            auto pOwner = AuTryLockMemoryType(this->swpChannel);
            if (pOwner)
            {
                pOwner->pParent_->ToChannel()->ScheduleOutOfFrameWrite();
            }
        }
    };

    AuSPtr<SocketChannel> pSelf(this->pParent_->SharedFromThis(), this);
    this->pCachedWriter = AuMakeSharedThrow<ProtocolForwardingWriter>(pSelf,
                                                                      this->pSendProtocol->AsStreamWriter());
    return this->pCachedWriter;
}
/**
 * Returns the buffer that callers should write outbound bytes into.
 *
 * With a send protocol installed, this is the protocol stack's writable (input) buffer. That
 * matches AsStreamWriter(), which forwards to pSendProtocol->AsStreamWriter(), so data passes
 * through the protocol before reaching the output channel. Without one, it is the raw output
 * channel buffer.
 */
AuSPtr<Memory::ByteBuffer> SocketChannel::AsWritableByteBuffer()
{
    if (this->pSendProtocol)
    {
        // The stack's *readable* buffer is its output side. Handing that out let writers skip
        // the send protocol entirely.
        return this->pSendProtocol->AsWritableByteBuffer();
    }

    return this->outputChannel.AsWritableByteBuffer();
}
/**
 * Creates (but does not install) a receive-side protocol stack fed from the raw input channel
 * buffer. The stack is hardened with bKillPipeOnFirstRootLevelFalse.
 * Pass the result to SpecifyRecvProtocol to activate it.
 * Returns an empty pointer, after SysPushErrorMemory(), if allocation fails.
 */
AuSPtr<Protocol::IProtocolStack> SocketChannel::NewProtocolRecvStack()
{
    auto pStack = AuMakeShared<Protocol::ProtocolStack>();
    if (pStack)
    {
        pStack->pSourceBufer = this->inputChannel.AsReadableByteBuffer();
        pStack->bKillPipeOnFirstRootLevelFalse = true;
        return pStack;
    }

    SysPushErrorMemory();
    return {};
}
/**
 * Creates (but does not install) a send-side protocol stack.
 * The stack is buffered with uBytesOutputBuffer bytes, its drain buffer is the output
 * channel's writable buffer, and it is hardened with bKillPipeOnFirstRootLevelFalse.
 * Pass the result to SpecifySendProtocol to activate it.
 * Returns an empty pointer if creation fails (SysPushErrorMemory() is raised in that case).
 */
AuSPtr<Protocol::IProtocolStack> SocketChannel::NewProtocolSendStack()
{
    auto pGenericStack = Protocol::NewBufferedProtocolStack(this->uBytesOutputBuffer);
    if (!pGenericStack)
    {
        SysPushErrorMemory();
        return {};
    }

    auto pStack = AuStaticCast<Protocol::ProtocolStack>(pGenericStack);
    if (pStack)
    {
        pStack->pDrainBuffer = this->outputChannel.AsWritableByteBuffer();
        pStack->bKillPipeOnFirstRootLevelFalse = true;
        return pStack;
    }

    return {};
}
/**
 * Installs the receive protocol; an empty pointer removes it.
 * The cached plain reader is discarded afterwards so the next AsStreamReader() call reflects
 * the new configuration.
 */
void SocketChannel::SpecifyRecvProtocol(const AuSPtr<Protocol::IProtocolStack> &pRecvProtocol)
{
    this->pRecvProtocol = pRecvProtocol;

    // Invalidate the reader cache.
    this->pCachedReader.reset();
}
/**
 * Installs the send protocol; an empty pointer removes it.
 * The cached writer is discarded afterwards so the next AsStreamWriter() call builds one for
 * the new configuration.
 */
void SocketChannel::SpecifySendProtocol(const AuSPtr<Protocol::IProtocolStack> &pSendProtocol)
{
    this->pSendProtocol = pSendProtocol;

    // Invalidate the writer cache.
    this->pCachedWriter.reset();
}
// Exposes the inbound statistics through an aliasing pointer that shares ownership with the parent socket.
AuSPtr<ISocketStats> SocketChannel::GetRecvStats()
{
    auto pOwner = this->pParent_->SharedFromThis();
    return AuSPtr<ISocketStats>(pOwner, &this->recvStats_);
}
// Exposes the outbound statistics through an aliasing pointer that shares ownership with the parent socket.
AuSPtr<ISocketStats> SocketChannel::GetSendStats()
{
    auto pOwner = this->pParent_->SharedFromThis();
    return AuSPtr<ISocketStats>(pOwner, &this->sendStats_);
}
// Direct (non-owning) access to the outbound statistics object.
SocketStats &SocketChannel::GetSendStatsEx()
{
    SocketStats &stats = this->sendStats_;
    return stats;
}
// Direct (non-owning) access to the inbound statistics object.
SocketStats &SocketChannel::GetRecvStatsEx()
{
    SocketStats &stats = this->recvStats_;
    return stats;
}
/**
 * Registers pListener for channel events.
 *
 * If this is called on the worker thread while readable data is already buffered, the
 * listener first gets an immediate OnData(). When that returns false, the listener receives
 * OnComplete() and is not registered.
 * spinLock is held for the whole call, including those callbacks.
 * NOTE(review): a listener that re-enters Add/RemoveEventListener from OnData/OnComplete would
 * re-acquire spinLock; confirm the lock is recursive or that this cannot happen.
 */
void SocketChannel::AddEventListener(const AuSPtr<ISocketChannelEventListener> &pListener)
{
    AU_LOCK_GUARD(this->spinLock);

    auto pWorker = this->pParent_->ToWorkerEx();
    const bool bBacklogVisible = pWorker->IsOnThread() &&
                                 this->AsReadableByteBuffer()->GetNextLinearRead().length;

    if (bBacklogVisible && !pListener->OnData())
    {
        // The listener declined the buffered data; finish it instead of subscribing it.
        pListener->OnComplete();
        return;
    }

    this->eventListeners.push_back(pListener);
}
/**
 * Unregisters pListener. When the listener was actually registered and this runs on the
 * worker thread, it is told via OnRejected(). spinLock is held throughout.
 */
void SocketChannel::RemoveEventListener(const AuSPtr<ISocketChannelEventListener> &pListener)
{
    AU_LOCK_GUARD(this->spinLock);

    const bool bWasRegistered = AuTryRemove(this->eventListeners, pListener);

    auto pWorker = this->pParent_->ToWorkerEx();
    const bool bOnWorkerThread = pWorker->IsOnThread();

    if (!bOnWorkerThread || !bWasRegistered)
    {
        return;
    }

    pListener->OnRejected();
}
// The channel is valid only while both its output and its input halves report valid.
bool SocketChannel::IsValid()
{
    if (!bool(this->outputChannel.IsValid()))
    {
        return false;
    }

    return bool(this->inputChannel.IsValid());
}
/**
 * Requests that both the input and the output buffer be resized to uBytes.
 *
 * On the socket's worker thread, the request is applied at once via
 * DoBufferResizeOnThread(true, true, ...). From any other thread it is posted to that worker;
 * the task captures an aliasing pointer to this channel (sharing ownership with the parent
 * socket).
 *
 * @return false when the socket has no worker or the task could not be scheduled (this
 *         function does not invoke pCallbackOptional in that case); true otherwise.
 */
bool SocketChannel::SpecifyBufferSize(AuUInt uBytes,
                                      const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);
        const bool bScheduled = bool(pWorker->TryScheduleInternalTemplate<AuNullS>(
            [=](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &)
            {
                pSelf->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
            },
            AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
        return bScheduled;
    }

    this->DoBufferResizeOnThread(true, true, uBytes, pCallbackOptional);
    return true;
}
/**
 * Requests that only the output buffer be resized to uBytes.
 *
 * On the socket's worker thread, the request is applied at once via
 * DoBufferResizeOnThread(false, true, ...); otherwise it is posted to that worker.
 *
 * @return false when the socket has no worker or the task could not be scheduled (this
 *         function does not invoke pCallbackOptional in that case); true otherwise.
 */
bool SocketChannel::SpecifyOutputBufferSize(AuUInt uBytes,
                                            const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);
        const bool bScheduled = bool(pWorker->TryScheduleInternalTemplate<AuNullS>(
            [=](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &)
            {
                pSelf->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
            },
            AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
        return bScheduled;
    }

    this->DoBufferResizeOnThread(false, true, uBytes, pCallbackOptional);
    return true;
}
/**
 * Requests that only the input buffer be resized to uBytes.
 *
 * On the socket's worker thread, the request is applied at once via
 * DoBufferResizeOnThread(true, false, ...); otherwise it is posted to that worker.
 *
 * @return false when the socket has no worker or the task could not be scheduled (this
 *         function does not invoke pCallbackOptional in that case); true otherwise.
 */
bool SocketChannel::SpecifyInputBufferSize(AuUInt uBytes,
                                           const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    auto pWorker = this->pParent_->ToWorkerEx();
    if (!pWorker)
    {
        return false;
    }

    if (!pWorker->IsOnThread())
    {
        auto pSelf = AuSPtr<SocketChannel>(this->pParent_->SharedFromThis(), this);
        const bool bScheduled = bool(pWorker->TryScheduleInternalTemplate<AuNullS>(
            [=](const AuSPtr<AuAsync::PromiseCallback<AuNullS>> &)
            {
                pSelf->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
            },
            AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>>{}));
        return bScheduled;
    }

    this->DoBufferResizeOnThread(true, false, uBytes, pCallbackOptional);
    return true;
}
// Ticks the send protocol (if any) so pending data is pushed through it, then asks the output channel for an out-of-frame write.
void SocketChannel::ScheduleOutOfFrameWrite()
{
    if (auto &pProtocol = this->pSendProtocol)
    {
        pProtocol->DoTick();
    }

    this->outputChannel.ScheduleOutOfFrameWrite();
}
// Hands out the raw read transaction only when manual reads are enabled; otherwise returns an empty pointer.
AuSPtr<IAsyncTransaction> SocketChannel::ToReadTransaction()
{
    if (!this->bIsManualRead)
    {
        return AuSPtr<IAsyncTransaction> {};
    }

    return this->inputChannel.pNetReadTransaction;
}
// Hands out the raw write transaction only when manual writes are enabled; otherwise returns an empty pointer.
AuSPtr<IAsyncTransaction> SocketChannel::ToWriteTransaction()
{
    if (!this->bIsManualWrite)
    {
        return AuSPtr<IAsyncTransaction> {};
    }

    return this->outputChannel.ToWriteTransaction();
}
/**
 * Stores the TCP_NODELAY preference and pushes it to the socket immediately
 * (UpdateNagleAnyThread).
 * Only accepted while manual writes are enabled (bIsManualWrite); returns false otherwise.
 * NOTE(review): the reason for gating on bIsManualWrite is not visible here - confirm it is intended.
 */
bool SocketChannel::SpecifyTCPNoDelay(bool bFuckNagle)
{
    if (this->bIsManualWrite)
    {
        this->bTcpNoDelay = bFuckNagle;
        this->pParent_->UpdateNagleAnyThread(this->bTcpNoDelay);
        return true;
    }

    return false;
}
// IO fences on transactions are not supported by this channel: the request is ignored and false is returned.
bool SocketChannel::SpecifyTransactionsHaveIOFence(bool bAllocateFence)
{
    (void)bAllocateFence;
    return false;
}
/**
 * Enables or disables manual (direct AIO) writes. While enabled, ToWriteTransaction() exposes
 * the underlying write transaction.
 * Must be called before the channel is established; afterwards it returns false and changes nothing.
 */
bool SocketChannel::SpecifyManualWrite(bool bEnableDirectAIOWrite)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Store the requested mode rather than a hard-coded `true`, so passing false actually
    // turns manual writes off.
    this->bIsManualWrite = bEnableDirectAIOWrite;
    return true;
}
/**
 * Enables or disables manual (direct AIO) reads. While enabled, ToReadTransaction() exposes
 * the underlying read transaction.
 * Must be called before the channel is established; afterwards it returns false and changes nothing.
 */
bool SocketChannel::SpecifyManualRead(bool bEnableDirectAIORead)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Store the requested mode rather than a hard-coded `true`, so passing false actually
    // turns manual reads off.
    this->bIsManualRead = bEnableDirectAIORead;
    return true;
}
/**
 * Sets the per-tick async read limit (stored in uBytesToFlip) to uBytes.
 * Must be called before the channel is established; afterwards it returns false and changes nothing.
 */
bool SocketChannel::SpecifyPerTickAsyncReadLimit(AuUInt uBytes)
{
    if (this->bIsEstablished)
    {
        return false;
    }

    // Store the caller's limit. Assigning `true` here recorded 1 regardless of uBytes.
    this->uBytesToFlip = uBytes;
    return true;
}
/**
 * Drops the channel's private user data and every reference it holds to cached
 * readers/writers, installed protocols, and the input channel's reader.
 */
void SocketChannel::Release()
{
    this->PrivateUserDataClear();

    // Cached stream adapters.
    this->pCachedReader.reset();
    this->pCachedWriter.reset();

    // Installed protocol stacks.
    this->pRecvProtocol.reset();
    this->pSendProtocol.reset();

    // Reader held by the input half.
    this->inputChannel.pNetReader.reset();
}
/**
 * Applies a buffer resize request; invoked on the socket's worker thread by the Specify*BufferSize helpers.
 *
 * @param bInput            request applies to the input buffer
 * @param bOutput           request applies to the output buffer
 * @param uBytes            requested size in bytes
 * @param pCallbackOptional may be empty; notified with OnSuccess/OnFailure where this function settles it
 *
 * Before establishment, only the configured sizes (uBytesInputBuffer/uBytesOutputBuffer) are
 * recorded, and success is reported immediately.
 * After establishment:
 *  - output requests whose channel reports CanResize() are resized in place by copying the
 *    buffered bytes into a newly allocated buffer;
 *  - any other request is parked under spinLock (uBytes*BufferRetarget + pRetarget*) for
 *    DoReallocWriteTick / DoReallocReadTick. An earlier request still pending in the same
 *    direction has its callback failed first.
 *
 * NOTE(review): when both bInput and bOutput are set (SpecifyBufferSize) after establishment,
 * only the output buffer is handled: the in-place path returns early, and the parking branch is
 * an if/else on bOutput. The input resize is silently dropped.
 * NOTE(review): the in-place path allocates AuByteBuffer(uBytes, false, false), while the tick
 * paths use (size, true, false); confirm the differing second argument is intended.
 * NOTE(review): uBytesInputBuffer/uBytesOutputBuffer are only updated before establishment,
 * so the Get*BufferSize getters do not reflect later resizes.
 */
void SocketChannel::DoBufferResizeOnThread(bool bInput,
                                           bool bOutput,
                                           AuUInt uBytes,
                                           const AuSPtr<AuAsync::PromiseCallback<AuNullS, AuNullS>> &pCallbackOptional)
{
    // Not connected yet: just remember the sizes for when the buffers get created.
    if (!this->bIsEstablished)
    {
        if (bInput)
        {
            this->uBytesInputBuffer = uBytes;
        }
        if (bOutput)
        {
            this->uBytesOutputBuffer = uBytes;
        }
        if (pCallbackOptional)
        {
            pCallbackOptional->OnSuccess((void *)nullptr);
        }
        return;
    }

    // Immediate output resize when the output channel allows it.
    if (bOutput)
    {
        if (this->outputChannel.CanResize())
        {
            AuByteBuffer newBuffer(uBytes, false, false);
            if (!(newBuffer.IsValid()))
            {
                SysPushErrorMemory();
                if (pCallbackOptional)
                {
                    pCallbackOptional->OnFailure((void *)nullptr);
                }
                return;
            }
            auto &byteBufferRef = this->outputChannel.GetByteBuffer();
            // Remember the read head so it can be restored if the copy fails part-way
            // (presumably WriteFrom advances the source's read head).
            auto oldReadHead = byteBufferRef.readPtr;
            if (!newBuffer.WriteFrom(byteBufferRef))
            {
                SysPushErrorMemory();
                this->outputChannel.GetByteBuffer().readPtr = oldReadHead;
                if (pCallbackOptional)
                {
                    pCallbackOptional->OnFailure((void *)nullptr);
                }
                return;
            }
            byteBufferRef = AuMove(newBuffer);
            if (pCallbackOptional)
            {
                pCallbackOptional->OnSuccess((void *)nullptr);
            }
            return;
        }
    }

    // Defer to the next read/write tick; a newer request supersedes (and fails) a pending one.
    {
        AU_LOCK_GUARD(this->spinLock);
        if (bOutput)
        {
            if (this->uBytesOutputBufferRetarget)
            {
                if (this->pRetargetOutput)
                {
                    this->pRetargetOutput->OnFailure((void *)nullptr);
                }
            }
            this->uBytesOutputBufferRetarget = uBytes;
            this->pRetargetOutput = pCallbackOptional;
        }
        else
        {
            if (this->uBytesInputBufferRetarget)
            {
                if (this->pRetargetInput)
                {
                    this->pRetargetInput->OnFailure((void *)nullptr);
                }
            }
            this->uBytesInputBufferRetarget = uBytes;
            this->pRetargetInput = pCallbackOptional;
        }
    }
}
/**
 * Completes a parked output-buffer resize (queued by DoBufferResizeOnThread), if one is pending.
 *
 * Allocates a buffer of the requested size, copies the still-buffered output into it, and
 * swaps it in. The pending request is settled exactly once: the retarget size and callback are
 * cleared before the callback is told OnSuccess or OnFailure. On a failed copy, the live
 * buffer's read/write heads are restored. Runs under spinLock.
 */
void SocketChannel::DoReallocWriteTick()
{
    AU_LOCK_GUARD(this->spinLock);

    if (!this->uBytesOutputBufferRetarget)
    {
        return;
    }

    // Take the pending request before acting on it. Leaving it queued after a failure made
    // every later tick retry and fire OnFailure again on the same callback.
    auto uRequestedBytes = this->uBytesOutputBufferRetarget;
    auto pCallback = this->pRetargetOutput;
    this->uBytesOutputBufferRetarget = 0;
    this->pRetargetOutput.reset();

    AuByteBuffer newBuffer(uRequestedBytes, true, false);
    if (!(newBuffer.IsValid()))
    {
        SysPushErrorMemory();
        if (pCallback)
        {
            pCallback->OnFailure((void *)nullptr);
        }
        return;
    }

    auto &byteBufferRef = this->outputChannel.GetByteBuffer();

    // Snapshot both heads so a failed copy leaves the live buffer exactly as it was.
    auto oldReadHead = byteBufferRef.readPtr;
    auto oldWriteHead = byteBufferRef.writePtr;
    if (!newBuffer.WriteFrom(byteBufferRef))
    {
        SysPushErrorMemory();
        byteBufferRef.readPtr = oldReadHead;
        byteBufferRef.writePtr = oldWriteHead;
        if (pCallback)
        {
            pCallback->OnFailure((void *)nullptr);
        }
        return;
    }

    byteBufferRef = AuMove(newBuffer);

    if (pCallback)
    {
        pCallback->OnSuccess((void *)nullptr);
    }
}
/**
 * Completes a parked input-buffer resize (queued by DoBufferResizeOnThread), if one is pending.
 *
 * Allocates a buffer of the requested size, copies the still-unread input into it, and
 * replaces the input channel's readable buffer contents with it. The pending request is
 * settled exactly once: the retarget size and callback are cleared before the callback is told
 * OnSuccess or OnFailure. On a failed copy, the buffer's read/write heads are restored.
 * Runs under spinLock.
 */
void SocketChannel::DoReallocReadTick()
{
    AU_LOCK_GUARD(this->spinLock);

    if (!this->uBytesInputBufferRetarget)
    {
        return;
    }

    // Take the pending request before acting on it. Leaving it queued after a failure made
    // every later tick retry and fire OnFailure again on the same callback.
    auto uRequestedBytes = this->uBytesInputBufferRetarget;
    auto pCallback = this->pRetargetInput;
    this->uBytesInputBufferRetarget = 0;
    this->pRetargetInput.reset();

    AuByteBuffer newBuffer(uRequestedBytes, true, false);
    if (!(newBuffer.IsValid()))
    {
        SysPushErrorMemory();
        if (pCallback)
        {
            pCallback->OnFailure((void *)nullptr);
        }
        return;
    }

    auto pByteBuffer = this->inputChannel.AsReadableByteBuffer();

    // Snapshot both heads so a failed copy leaves the live buffer exactly as it was.
    auto oldReadHead = pByteBuffer->readPtr;
    auto oldWriteHead = pByteBuffer->writePtr;
    if (!newBuffer.WriteFrom(*pByteBuffer))
    {
        SysPushErrorMemory();
        pByteBuffer->readPtr = oldReadHead;
        pByteBuffer->writePtr = oldWriteHead;
        if (pCallback)
        {
            pCallback->OnFailure((void *)nullptr);
        }
        return;
    }

    // Move-assign into the shared buffer object so existing holders of the pointer see the new storage.
    *pByteBuffer = AuMove(newBuffer);

    if (pCallback)
    {
        pCallback->OnSuccess((void *)nullptr);
    }
}
// Reports the TCP_NODELAY preference stored on this channel (initialised from kDefaultFuckNagle).
bool SocketChannel::GetCurrentTCPNoDelay()
{
    const bool bNoDelay = this->bTcpNoDelay;
    return bNoDelay;
}
/**
 * Returns the configured input buffer size: kDefaultStreamSize, or the size requested before
 * establishment. NOTE: DoBufferResizeOnThread does not update it for resizes after establishment.
 */
AuUInt SocketChannel::GetInputBufferSize()
{
    const AuUInt uConfigured = this->uBytesInputBuffer;
    return uConfigured;
}
/**
 * Returns the configured output buffer size: kDefaultStreamSize, or the size requested before
 * establishment. NOTE: DoBufferResizeOnThread does not update it for resizes after establishment.
 */
AuUInt SocketChannel::GetOutputBufferSize()
{
    const AuUInt uConfigured = this->uBytesOutputBuffer;
    return uConfigured;
}
}