// AuroraRuntime/Source/IO/Net/AuNetSocket.cpp
// 2023-09-23 20:36:28 +01:00
//
// 655 lines
// 18 KiB
// C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuNetSocket.cpp
Date: 2022-8-16
Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocket.hpp"
#include "AuNetEndpoint.hpp"
#include "AuNetWorker.hpp"
#include "AuIPAddress.hpp"
#include "AuNetError.hpp"
#include "AuNetSocketServer.hpp"
#include "AuNetInterface.hpp"
#include "../AuIOHandle.hpp"
#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#endif
namespace Aurora::IO::Net
{
// Returns a NetError populated by NetError_SetCurrent (presumably from the most
// recent OS-level networking error; confirm against AuNetError).
static NetError GetLastNetError()
{
    NetError lastError;
    NetError_SetCurrent(lastError);
    return lastError;
}
// Adopts an already-existing OS socket handle.
// The socket registers itself with its worker first, then allocates its IO handle
// owner. A zero handle is accepted, in which case nothing is bound to the owner.
SocketBase::SocketBase(struct NetInterface *pInterface,
                       struct NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       AuUInt osHandle) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver),
    osHandle_(osHandle)
{
    this->pWorker_->AddSocket(this);
    this->osHandleOwner_ = AuIO::IOHandleShared();

#if !defined(AURORA_IS_MODERNNT_DERIVED)
    // Non-NT targets: the same descriptor serves as both the read and the write side.
    if (this->osHandle_)
    {
        this->osHandleOwner_->InitFromPairMove(static_cast<int>(this->osHandle_),
                                               static_cast<int>(this->osHandle_));
    }
#else
    // NT builds currently do not bind the handle here; the equivalent call is disabled:
    //this->osHandleOwner_->InitFromPairMove((HANDLE)this->osHandle_, (HANDLE)this->osHandle_);
#endif
}
// Constructs a socket that targets a single, already-known remote endpoint.
// The endpoint is copied into remoteEndpoint_; no OS socket is created in this constructor.
SocketBase::SocketBase(struct NetInterface *pInterface,
struct NetWorker *pWorker,
const AuSPtr<ISocketDriver> &pSocketDriver,
const NetEndpoint &endpoint) :
connectOperation(this),
socketChannel_(this),
pInterface_(pInterface),
pWorker_(pWorker),
pSocketDriver_(pSocketDriver),
remoteEndpoint_(endpoint)
{
this->osHandleOwner_ = AuIO::IOHandleShared();
// NOTE(review): osHandle_ is not initialized by this constructor's init list, so this
// guard depends on osHandle_'s default value (declared in the header). If that default
// is 0, the early return always fires and AddSocket() below is never reached from here,
// unlike the AuUInt-handle constructor, which registers unconditionally. Confirm whether
// worker registration happens elsewhere (e.g. in a derived constructor).
if (!this->osHandle_)
{
return;
}
this->pWorker_->AddSocket(this);
}
// Constructs a socket that targets a hostname/port pair.
// Port and transport protocol are always recorded in remoteEndpoint_.
// IP-literal hosts are written straight into remoteEndpoint_ and optimized.
// DNS names are parked in resolveLater for TryStartResolve() to look up.
SocketBase::SocketBase(struct NetInterface *pInterface,
                       struct NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       const AuPair<NetHostname, AuUInt16> &endpoint,
                       AuNet::ETransportProtocol eProtocol) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver)
{
    const auto &[netHost, uTargetPort] = endpoint;

    this->remoteEndpoint_.uPort = uTargetPort;
    this->remoteEndpoint_.transportProtocol = eProtocol;

    if (netHost.type != AuNet::EHostnameType::eHostByIp)
    {
        this->resolveLater = netHost.hostname;
    }
    else
    {
        this->remoteEndpoint_.ip = netHost.address;
        OptimizeEndpoint(this->remoteEndpoint_);
    }

    this->osHandleOwner_ = AuIO::IOHandleShared();

    // NOTE(review): osHandle_ is not set by this constructor, so whether the worker
    // registration below runs depends on osHandle_'s default value. Confirm.
    if (this->osHandle_)
    {
        this->pWorker_->AddSocket(this);
    }
}
// Constructs a socket that will try a list of candidate remote addresses in turn
// (consumed by ConnectNext()).
SocketBase::SocketBase(NetInterface *pInterface,
                       NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       const NetSocketConnectMany &connectMany) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver),
    bHasRemoteMany_(true),
    connectMany_(connectMany)
{
    // Our copy of the request does not keep the request's driver reference alive;
    // this socket talks to pSocketDriver_ (presumably to avoid a redundant reference).
    this->connectMany_.pDriver.reset();

    this->osHandleOwner_ = AuIO::IOHandleShared();

    // NOTE(review): unlike the sibling constructors, nothing here calls
    // pWorker_->AddSocket(this); confirm registration happens elsewhere.
}
// Tears the socket down through Destroy(). Finalization is guarded by bHasFinalized_,
// so this is harmless if the socket was already finalized.
SocketBase::~SocketBase()
{
    Destroy();
}
// Reports whether this socket is usable.
// A socket still waiting on DNS resolution (non-empty resolveLater) always counts as valid.
// Otherwise all of the following must hold:
// - it owns an IO handle;
// - its connect operation is valid;
// - its OS handle is neither 0 nor all-bits-set (-1);
// - construction was not force-failed.
// The early return makes the `a || (b && c && ...)` grouping explicit. The original
// single expression mixed || and && without parentheses (-Wparentheses).
bool SocketBase::IsValid()
{
    if (this->resolveLater.size())
    {
        return true;
    }

    return bool(this->osHandleOwner_) &&
           bool(this->connectOperation.IsValid()) &&
           bool(this->osHandle_ != 0) &&
           bool(this->osHandle_ != -1) &&
           bool(!this->bForceFailConstruct_);
}
bool SocketBase::TryStartResolve()
{
auto pThat = this->SharedFromThis();
if (this->bResolving_)
{
return true;
}
this->bResolving_ = true;
auto address = this->resolveLater;
this->resolveLater.clear();
auto pResolver = this->pInterface_->GetResolveService()->SimpleAllResolve(address,
AuMakeSharedThrow<AuAsync::PromiseCallbackFunctional<AuList<AuNet::IPAddress>,
AuNet::NetError>>(
[=](const AuSPtr<AuList<AuNet::IPAddress>> &ips)
{
pThat->bResolving_ = false;
AuList<NetSocketConnectBase> names;
for (auto &ip : *ips)
{
NetSocketConnectBase base(ip, pThat->remoteEndpoint_.uPort);
base.byEndpoint.value().transportProtocol = this->remoteEndpoint_.transportProtocol;
names.push_back(base);
}
pThat->connectMany_.names = names;
pThat->bHasRemoteMany_ = true;
pThat->RenewSocket();
pThat->ConnectNext();
},
[=](const AuSPtr<AuNet::NetError> &error)
{
pThat->SendErrorNoStream(error ? *error.get() : AuNet::NetError {});
}));
return bool(pResolver);
}
// Pops candidates off connectMany_.names until one can be turned into a concrete
// NetEndpoint, then attempts to connect to it.
//
// Some candidates cannot be dialed directly: unresolved DNS hostnames, or entries that
// carry neither an endpoint nor a host. These are skipped instead of aborting the whole
// attempt. Previously the first such entry made this return false even though usable
// candidates followed, and an entry with neither field called value() on an empty
// optional.
//
// Returns false when no usable candidate remains; otherwise the result of Connect().
bool SocketBase::ConnectNext()
{
    while (!this->connectMany_.names.empty())
    {
        auto base = this->connectMany_.names[0];
        this->connectMany_.names.erase(this->connectMany_.names.begin());

        NetEndpoint endpoint;
        if (base.byEndpoint)
        {
            endpoint = base.byEndpoint.value();
        }
        else if (base.byHost)
        {
            const auto &val = base.byHost.value();
            if (val.netHostname.type == EHostnameType::eHostByDns)
            {
                // Hostnames must be resolved before they can be connected to.
                continue;
            }
            endpoint.uPort = val.uPort;
            endpoint.ip = val.netHostname.address;
            endpoint.transportProtocol = val.protocol;
        }
        else
        {
            // Nothing to connect to in this entry.
            continue;
        }

        return this->Connect(endpoint);
    }

    return false;
}
// Begins connecting to a single remote endpoint.
// If the socket is not valid, a fatal error is reported immediately via SendErrorNoStream.
// Otherwise the endpoint is stored and optimized. Then the overlapped, non-blocking and
// blocking connect strategies are tried in that order; the first that returns true wins.
// If all of them fail, connectOperation.OnIOFailure() is signalled.
// Returns true if a connect attempt was started/completed by one of the strategies.
bool SocketBase::Connect(const NetEndpoint &endpoint)
{
if (!this->IsValid())
{
this->SendErrorNoStream({});
return false;
}
this->remoteEndpoint_ = endpoint;
// OptimizeEndpoint returns the size of the endpoint; zero means it is unusable.
this->endpointSize_ = OptimizeEndpoint(this->remoteEndpoint_);
// The local endpoint always mirrors the remote transport protocol.
this->localEndpoint_.transportProtocol = this->remoteEndpoint_.transportProtocol;
if (!this->endpointSize_)
{
// NOTE(review): this path only logs and returns false. Nothing is reported to the
// driver/listeners and no further connect candidate is tried from here; callers
// must handle the false return themselves.
SysPushErrorIO("Invalid remote endpoint");
return false;
}
// Short-circuit: later strategies run only if earlier ones return false.
bool bStatus = this->ConnectOverlapped() ||
this->ConnectNonblocking() ||
this->ConnectBlocking();
if (!bStatus)
{
this->connectOperation.OnIOFailure();
}
return bStatus;
}
// Raw OS socket handle. It may be 0 or all-bits-set when no usable handle exists
// (both are rejected by IsValid()).
AuUInt SocketBase::ToPlatformHandle()
{
    return osHandle_;
}
// User-supplied socket driver. This becomes empty once SendFinalize() has released it.
AuSPtr<ISocketDriver> SocketBase::GetUserDriver()
{
    return pSocketDriver_;
}
// Cached remote endpoint (set by the constructors/Connect and refreshed by
// UpdateRemoteEndpoint()).
const NetEndpoint &SocketBase::GetRemoteEndpoint()
{
    return remoteEndpoint_;
}
// Cached local endpoint (protocol mirrored in Connect, refreshed by UpdateLocalEndpoint()).
const NetEndpoint &SocketBase::GetLocalEndpoint()
{
    return localEndpoint_;
}
void SocketBase::ConnectFinished()
{
this->UpdateLocalEndpoint();
this->UpdateRemoteEndpoint();
DoMain();
}
// Post-connect entry point.
// - Establishes the socket channel.
// - On the first call only, and only if the socket has neither errored nor ended,
//   notifies the driver via OnEstablish().
// - For TCP sockets, kicks an out-of-frame write so any data queued before the
//   connection completed can be flushed.
void SocketBase::DoMain()
{
// Runs on every call, including repeated ones rejected by the guard below.
this->socketChannel_.Establish();
// Only the first caller proceeds past here.
if (AuExchange(this->bHasConnected_, true))
{
return;
}
if (this->bHasErrored_)
{
return;
}
if (this->bHasEnded)
{
return;
}
// Local copy keeps the driver alive for the duration of the callback.
auto pDriver = this->pSocketDriver_;
if (bool(pDriver))
{
try
{
pDriver->OnEstablish();
}
catch (...)
{
// A throwing driver is treated as fatal for this socket.
SysPushErrorCatch();
this->SendErrorBeginShutdown({});
}
}
// NOTE(review): this is still reached after SendErrorBeginShutdown above; confirm that
// scheduling a write on a shutting-down socket is intended.
if (this->localEndpoint_.transportProtocol == ETransportProtocol::eProtocolTCP)
{
this->ToChannel()->ScheduleOutOfFrameWrite();
}
}
// A connect attempt failed. Fall back to the next queued candidate, and report the
// error only when ConnectNext() cannot continue.
void SocketBase::ConnectFailed(const NetError &error)
{
    const bool bRetrying = this->ConnectNext();
    if (!bRetrying)
    {
        this->SendErrorNoStream(error);
    }
}
void SocketBase::SendErrorNoStream(const NetError &error)
{
if (AuExchange(this->bHasErrored_, true))
{
return;
}
this->error_ = error;
this->RejectAllListeners();
if (this->pSocketDriver_)
{
try
{
this->pSocketDriver_->OnFatalErrorReported(error);
}
catch (...)
{
SysPushErrorCatch();
}
}
this->SendFinalize();
}
void SocketBase::SendErrorBeginShutdown(const NetError &error)
{
if (AuExchange(this->bHasErrored_, true))
{
return;
}
this->error_ = error;
if (this->pSocketDriver_)
{
try
{
this->pSocketDriver_->OnFatalErrorReported(error);
}
catch (...)
{
SysPushErrorCatch();
}
}
RejectAllListeners();
this->Shutdown();
}
void SocketBase::RejectAllListeners()
{
AU_LOCK_GUARD(this->socketChannel_.spinLock);
for (const auto &pListener : this->socketChannel_.eventListeners)
{
pListener->OnRejected();
}
this->socketChannel_.eventListeners.clear();
}
void SocketBase::CompleteAllListeners()
{
AU_LOCK_GUARD(this->socketChannel_.spinLock);
for (const auto &pListener : this->socketChannel_.eventListeners)
{
pListener->OnComplete();
}
this->socketChannel_.eventListeners.clear();
}
// Dispatches "new data available" to every consumer of the input stream.
// - Ticks the receive protocol stack (if any).
// - Notifies the driver (OnStreamUpdated) and then each channel event listener (OnData).
// - Schedules an out-of-frame write.
// - Adds the number of input bytes consumed during the call to the receive stats.
void SocketBase::SendOnData()
{
// Remember the read head so the bytes consumed by this call can be counted at the end.
auto pReadableBuffer = this->socketChannel_.AsReadableByteBuffer();
auto pStartOffset = pReadableBuffer ? pReadableBuffer->readPtr : nullptr;
// Finalized sockets only flush: tick the protocol stack and schedule a write.
// This path goes through socketChannel_ directly rather than ToChannel(). SendFinalize
// calls SendOnData, and can itself run from ~SocketBase via Destroy(), where
// AuSharedFromThis() is presumably unusable.
// NOTE(review): this early return skips the receive-stats update at the bottom.
if (this->bHasFinalized_)
{
if (this->socketChannel_.pRecvProtocol)
{
this->socketChannel_.pRecvProtocol->DoTick();
}
this->socketChannel_.ScheduleOutOfFrameWrite();
return;
}
if (this->socketChannel_.pRecvProtocol)
{
this->socketChannel_.pRecvProtocol->DoTick();
}
if (this->pSocketDriver_)
{
try
{
this->pSocketDriver_->OnStreamUpdated();
}
catch (...)
{
// A throwing driver is treated as fatal for this socket.
SysPushErrorCatch();
this->SendErrorBeginShutdown({});
}
}
{
// Snapshot the listeners under the lock, then call them without holding it.
AuList<AuSPtr<ISocketChannelEventListener>> listeners;
AuList<AuSPtr<ISocketChannelEventListener>> listenersToEvict;
{
AU_LOCK_GUARD(this->socketChannel_.spinLock);
listeners = this->socketChannel_.eventListeners;
}
// NOTE(review): if this reservation fails, no listener is notified this round.
if (AuTryReserve(listenersToEvict, listeners.size()))
{
for (const auto &pListener : listeners)
{
// A listener returning false is done: complete it and schedule its removal.
if (!pListener->OnData())
{
pListener->OnComplete();
listenersToEvict.push_back(pListener);
}
}
{
AU_LOCK_GUARD(this->socketChannel_.spinLock);
for (const auto &pListener : listenersToEvict)
{
AuTryRemove(this->socketChannel_.eventListeners, pListener);
}
}
}
}
this->ToChannel()->ScheduleOutOfFrameWrite();
// Bytes consumed from the input buffer during this call.
auto uHeadDelta = pReadableBuffer ? (pReadableBuffer->readPtr - pStartOffset) : 0;
this->socketChannel_.GetRecvStatsEx().AddBytes(uHeadDelta);
}
// Returns the socket's recorded error.
// If none has been recorded (netError is kEnumInvalid), it lazily derives one from the
// OS error code of the pending read transaction, or failing that the write transaction,
// and caches it in error_.
const NetError &SocketBase::GetError()
{
if (this->error_.netError == ENetworkError::kEnumInvalid)
{
// Note: some interfaces will now have a `HasErrorCode` that reports non-fatal errors
// This should be useful for reporting the disconnect reason when otherwise filtered as not-an-error
// On NT, HasErrorCode() is consulted; elsewhere, Failed().
if (this->socketChannel_.inputChannel.pNetReadTransaction &&
#if defined(AURORA_IS_MODERNNT_DERIVED)
AuStaticCast<NtAsyncNetworkTransaction>(this->socketChannel_.inputChannel.pNetReadTransaction)->HasErrorCode()
#else
this->socketChannel_.inputChannel.pNetReadTransaction->Failed()
#endif
)
{
NetError_SetOsError(this->error_, this->socketChannel_.inputChannel.pNetReadTransaction->GetOSErrorCode());
}
else if (this->socketChannel_.outputChannel.pNetWriteTransaction_ &&
#if defined(AURORA_IS_MODERNNT_DERIVED)
AuStaticCast<NtAsyncNetworkTransaction>(this->socketChannel_.outputChannel.pNetWriteTransaction_)->HasErrorCode()
#else
this->socketChannel_.outputChannel.pNetWriteTransaction_->Failed()
#endif
)
{
NetError_SetOsError(this->error_, this->socketChannel_.outputChannel.pNetWriteTransaction_->GetOSErrorCode());
}
}
return this->error_;
}
// Returns a pointer to the embedded socketChannel_ that shares ownership with this
// socket (the aliasing form of the shared-pointer constructor, assuming AuSPtr is
// std::shared_ptr). Requires that the socket is currently owned by a shared pointer.
AuSPtr<ISocketChannel> SocketBase::ToChannel()
{
    auto pOwner = AuSharedFromThis();
    return AuSPtr<ISocketChannel>(pOwner, &this->socketChannel_);
}
void SocketBase::ShutdownLite()
{
if (this->osHandleOwner_)
{
AuStaticCast<AFileHandle>(this->osHandleOwner_)->uOSWriteHandle.Reset();
}
}
void SocketBase::Destroy()
{
this->SendFinalize();
}
// Owning worker, exposed through the public INetWorker interface.
INetWorker *SocketBase::ToWorker()
{
    return pWorker_;
}
// Owning worker as its concrete NetWorker type.
NetWorker *SocketBase::ToWorkerEx()
{
    return pWorker_;
}
// Prepares the socket before it is established.
// - Adopts the server's default input stream size (if given).
// - Warms up the input channel.
// - Asks the driver, via OnPreestablish, whether to proceed.
//
// Returns true once the socket has been successfully pre-established (later calls are
// no-ops that return true). Returns false if the driver declined or threw.
//
// Fixes:
// - The once-only flag was previously not set when a driver accepted, so the guard never
//   fired for driver-backed sockets and the work above could repeat.
// - A throwing driver was silently treated as acceptance. It is now treated as a
//   rejection, consistent with how DoMain/SendOnData treat driver exceptions as fatal.
bool SocketBase::SendPreestablish(SocketServer *pServer)
{
    if (this->bHasPreestablished_)
    {
        return true;
    }

    if (pServer && pServer->uDefaultInputStreamSize)
    {
        this->socketChannel_.uBytesInputBuffer = pServer->uDefaultInputStreamSize;
    }

    // Allocate stream resources now, in case the driver needs to work with the source
    // buffer (e.g. set up protocol stacks or access the base byte buffer) before the
    // pipe itself has been allocated.
    this->socketChannel_.inputChannel.WarmOnEstablish();

    if (this->pSocketDriver_)
    {
        bool bAccepted {};
        try
        {
            bAccepted = this->pSocketDriver_->OnPreestablish(AuSharedFromThis());
        }
        catch (...)
        {
            SysPushErrorCatch();
            return false;
        }

        if (!bAccepted)
        {
            return false;
        }
    }

    return this->bHasPreestablished_ = true;
}
void SocketBase::SendEnd()
{
if (AuExchange(this->bHasEnded, true))
{
return;
}
if (this->pSocketDriver_)
{
try
{
this->pSocketDriver_->OnEnd();
}
catch (...)
{
SysPushErrorCatch();
}
}
this->SendFinalize();
}
// Final, once-only teardown of the socket. Order matters:
//  1. unregister from the worker;
//  2. reject (if errored) or complete all channel listeners;
//  3. notify the driver via OnFinalize() and drop the driver reference;
//  4. flush via SendOnData() (which takes its finalized-path branch);
//  5. release the channel and detach transactions (NT only);
//  6. close the OS socket.
void SocketBase::SendFinalize()
{
if (AuExchange(this->bHasFinalized_, true))
{
return;
}
this->pWorker_->RemoveSocket(this);
if (this->bHasErrored_)
{
this->RejectAllListeners();
}
else
{
this->CompleteAllListeners();
}
if (this->pSocketDriver_)
{
try
{
this->pSocketDriver_->OnFinalize();
}
catch (...)
{
SysPushErrorCatch();
}
// No driver callbacks are delivered after finalization.
this->pSocketDriver_.reset();
}
auto pWriteTransaction = this->socketChannel_.outputChannel.ToWriteTransaction();
#if defined(AURORA_IS_MODERNNT_DERIVED)
// NT: switch the write transaction to a waiting mode before the final flush below
// (presumably so that pending output is written out before the socket is closed).
if (pWriteTransaction)
{
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->MakeSyncable();
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->ForceNextWriteWait();
}
#endif
// bHasFinalized_ is already set, so this only ticks the recv protocol and schedules a write.
this->SendOnData();
this->socketChannel_.Release();
#if defined(AURORA_IS_MODERNNT_DERIVED)
// Break the transaction -> socket back-references so late completions cannot reach us.
if (this->socketChannel_.inputChannel.pNetReadTransaction)
{
AuStaticCast<NtAsyncNetworkTransaction>(this->socketChannel_.inputChannel.pNetReadTransaction)->pSocket = nullptr;
}
this->socketChannel_.inputChannel.pNetReadTransaction.reset();
if (pWriteTransaction)
{
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->pSocket = nullptr;
}
this->socketChannel_.outputChannel.pParent_ = nullptr;
#else
// TODO:
// NOTE(review): non-NT builds do not detach transactions or clear outputChannel.pParent_.
#endif
this->CloseSocket();
}
}