698 lines
19 KiB
C++
698 lines
19 KiB
C++
/***
|
|
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
|
|
|
|
File: AuNetSocket.cpp
|
|
Date: 2022-8-16
|
|
Author: Reece
|
|
***/
|
|
#include "Networking.hpp"
|
|
#include "AuNetSocket.hpp"
|
|
#include "AuNetEndpoint.hpp"
|
|
#include "AuNetWorker.hpp"
|
|
#include "AuIPAddress.hpp"
|
|
#include "AuNetError.hpp"
|
|
#include "AuNetSocketServer.hpp"
|
|
#include "AuNetInterface.hpp"
|
|
|
|
#include "../AuIOHandle.hpp"
|
|
|
|
#if defined(AURORA_IS_MODERNNT_DERIVED)
|
|
#include "AuNetStream.NT.hpp"
|
|
#endif
|
|
|
|
namespace Aurora::IO::Net
|
|
{
|
|
// Builds a NetError filled in by NetError_SetCurrent (presumably from the calling
// thread's most recent OS/socket error - confirm against NetError_SetCurrent) and
// returns it by value.
static NetError GetLastNetError()
{
    NetError lastError;
    NetError_SetCurrent(lastError);
    return lastError;
}
|
|
|
|
// Constructs a socket around an existing OS handle (presumably an accepted
// connection when a parent server is supplied).
// - Always registers itself with its worker (note: unlike the other constructors,
//   this happens before the handle is checked).
// - When osHandle is non-zero, wraps it in the shared IO handle owner (non-NT only;
//   the NT path is intentionally disabled) and notifies the concrete parent server
//   that a child was created.
// pParent is the public server interface (kept weakly in wpParent_); pParent2 is the
// concrete SocketServer that receives child bookkeeping callbacks.
SocketBase::SocketBase(struct NetInterface *pInterface,
                       struct NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       AuUInt osHandle,
                       AuSPtr<ISocketServer> pParent,
                       SocketServer *pParent2) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver),
    osHandle_(osHandle),
    wpParent_(pParent),
    wpParent2_(pParent2)
{
    this->pWorker_->AddSocket(this);

    this->osHandleOwner_ = AuIO::IOHandleShared();
    if (!this->osHandle_)
    {
        return;
    }

#if defined(AURORA_IS_MODERNNT_DERIVED)
    //this->osHandleOwner_->InitFromPairMove((HANDLE)this->osHandle_, (HANDLE)this->osHandle_);
#else
    this->osHandleOwner_->InitFromPairMove((int)this->osHandle_, (int)this->osHandle_);
#endif

    // Only dereference the concrete server when it was actually supplied; checking
    // pParent alone would crash on a null pParent2.
    if (pParent && pParent2)
    {
        pParent2->OnNotifyChildCreated(this);
    }
}
|
|
|
|
// Constructs an outbound socket aimed at an already-resolved remote endpoint.
// osHandle_ is not set by this constructor (it keeps its default member value), so
// the socket is only registered with the worker if that value is non-zero.
SocketBase::SocketBase(struct NetInterface *pInterface,
                       struct NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       const NetEndpoint &endpoint) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver),
    remoteEndpoint_(endpoint)
{
    this->osHandleOwner_ = AuIO::IOHandleShared();

    if (this->osHandle_)
    {
        this->pWorker_->AddSocket(this);
    }
}
|
|
|
|
// Constructs an outbound socket from a (hostname, port) pair.
// An IP-typed hostname is stored directly as the remote endpoint (then normalized by
// OptimizeEndpoint). Any other hostname is queued in resolveLater for TryStartResolve,
// with only the port and protocol recorded for now.
// As with the NetEndpoint constructor, worker registration depends on the default
// value of osHandle_, which this constructor does not set.
SocketBase::SocketBase(struct NetInterface *pInterface,
                       struct NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       const AuPair<NetHostname, AuUInt16> &endpoint,
                       AuNet::ETransportProtocol eProtocol) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver)
{
    auto &[target, uTargetPort] = endpoint;

    // Port and protocol are known up front regardless of how the host is expressed
    this->remoteEndpoint_.uPort = uTargetPort;
    this->remoteEndpoint_.transportProtocol = eProtocol;

    if (target.type != AuNet::EHostnameType::eHostByIp)
    {
        // Needs a DNS lookup first; see TryStartResolve
        this->resolveLater = target.hostname;
    }
    else
    {
        this->remoteEndpoint_.ip = target.address;
        OptimizeEndpoint(this->remoteEndpoint_);
    }

    this->osHandleOwner_ = AuIO::IOHandleShared();

    if (this->osHandle_)
    {
        this->pWorker_->AddSocket(this);
    }
}
|
|
|
|
|
|
// Constructs an outbound socket that will try a list of connect candidates in turn
// (see ConnectNext). The descriptor's own driver reference is dropped because this
// socket keeps its driver in pSocketDriver_.
// NOTE(review): unlike the other constructors, this one never calls AddSocket -
// confirm that registration happens elsewhere (e.g. when the socket is renewed).
SocketBase::SocketBase(NetInterface *pInterface,
                       NetWorker *pWorker,
                       const AuSPtr<ISocketDriver> &pSocketDriver,
                       const NetSocketConnectMany &connectMany) :
    connectOperation(this),
    socketChannel_(this),
    pInterface_(pInterface),
    pWorker_(pWorker),
    pSocketDriver_(pSocketDriver),
    bHasRemoteMany_(true),
    connectMany_(connectMany)
{
    this->connectMany_.pDriver.reset();
    this->osHandleOwner_ = AuIO::IOHandleShared();
}
|
|
|
|
// Tears the socket down through Destroy().
SocketBase::~SocketBase()
{
    Destroy();
}
|
|
|
|
bool SocketBase::IsValid()
|
|
{
|
|
return (this->resolveLater.size()) ||
|
|
bool(this->osHandleOwner_) &&
|
|
bool(this->connectOperation.IsValid()) &&
|
|
bool(this->osHandle_ != 0) &&
|
|
bool(this->osHandle_ != -1) &&
|
|
bool(!this->bForceFailConstruct_);
|
|
}
|
|
|
|
bool SocketBase::TryStartResolve()
|
|
{
|
|
auto pThat = this->SharedFromThis();
|
|
|
|
if (this->bResolving_)
|
|
{
|
|
return true;
|
|
}
|
|
|
|
this->bResolving_ = true;
|
|
auto address = this->resolveLater;
|
|
this->resolveLater.clear();
|
|
|
|
auto pResolver = this->pInterface_->GetResolveService()->SimpleAllResolve(address,
|
|
AuMakeSharedThrow<AuAsync::PromiseCallbackFunctional<AuList<AuNet::IPAddress>,
|
|
AuNet::NetError>>(
|
|
[=](const AuSPtr<AuList<AuNet::IPAddress>> &ips)
|
|
{
|
|
pThat->bResolving_ = false;
|
|
|
|
AuList<NetSocketConnectBase> names;
|
|
for (auto &ip : *ips)
|
|
{
|
|
NetSocketConnectBase base(ip, pThat->remoteEndpoint_.uPort);
|
|
base.byEndpoint.value().transportProtocol = this->remoteEndpoint_.transportProtocol;
|
|
names.push_back(base);
|
|
}
|
|
|
|
pThat->connectMany_.names = names;
|
|
|
|
pThat->bHasRemoteMany_ = true;
|
|
pThat->RenewSocket();
|
|
pThat->ConnectNext();
|
|
},
|
|
[=](const AuSPtr<AuNet::NetError> &error)
|
|
{
|
|
pThat->SendErrorNoStream(error ? *error.get() : AuNet::NetError {});
|
|
}));
|
|
|
|
return bool(pResolver);
|
|
}
|
|
|
|
bool SocketBase::ConnectNext()
|
|
{
|
|
if (this->connectMany_.names.empty())
|
|
{
|
|
return false;
|
|
}
|
|
|
|
auto base = this->connectMany_.names[0];
|
|
this->connectMany_.names.erase(this->connectMany_.names.begin());
|
|
|
|
NetEndpoint endpoint;
|
|
if (base.byEndpoint)
|
|
{
|
|
endpoint = base.byEndpoint.value();
|
|
}
|
|
else
|
|
{
|
|
auto val = base.byHost.value();
|
|
if (val.netHostname.type == EHostnameType::eHostByDns)
|
|
{
|
|
return false;
|
|
}
|
|
endpoint.uPort = val.uPort;
|
|
endpoint.ip = val.netHostname.address;
|
|
endpoint.transportProtocol = val.protocol;
|
|
}
|
|
return this->Connect(endpoint);
|
|
}
|
|
|
|
bool SocketBase::Connect(const NetEndpoint &endpoint)
|
|
{
|
|
if (!this->IsValid())
|
|
{
|
|
this->SendErrorNoStream({});
|
|
return false;
|
|
}
|
|
|
|
this->remoteEndpoint_ = endpoint;
|
|
this->endpointSize_ = OptimizeEndpoint(this->remoteEndpoint_);
|
|
this->localEndpoint_.transportProtocol = this->remoteEndpoint_.transportProtocol;
|
|
|
|
if (!this->endpointSize_)
|
|
{
|
|
SysPushErrorIO("Invalid remote endpoint");
|
|
return false;
|
|
}
|
|
|
|
bool bStatus = this->ConnectOverlapped() ||
|
|
this->ConnectNonblocking() ||
|
|
this->ConnectBlocking();
|
|
|
|
if (!bStatus)
|
|
{
|
|
this->connectOperation.OnIOFailure();
|
|
}
|
|
|
|
return bStatus;
|
|
}
|
|
|
|
// Returns the raw OS socket handle stored in osHandle_ (IsValid treats 0 and -1 as invalid).
AuUInt SocketBase::ToPlatformHandle()
{
    const AuUInt uHandle = this->osHandle_;
    return uHandle;
}
|
|
|
|
// Returns the user-supplied socket driver; this becomes null once SendFinalize has
// released it.
AuSPtr<ISocketDriver> SocketBase::GetUserDriver()
{
    return pSocketDriver_;
}
|
|
|
|
// Returns a strong reference to the worker this socket is bound to.
AuSPtr<INetWorker> SocketBase::GetLockedWorkerThread()
{
    return pWorker_->SharedFromThis();
}
|
|
|
|
// Returns the cached remote endpoint. It is set at construction and by Connect, and
// refreshed by UpdateRemoteEndpoint once connected.
const NetEndpoint &SocketBase::GetRemoteEndpoint()
{
    return remoteEndpoint_;
}
|
|
|
|
// Returns the cached local endpoint. Its protocol is set by Connect, and it is
// refreshed by UpdateLocalEndpoint once connected.
const NetEndpoint &SocketBase::GetLocalEndpoint()
{
    return localEndpoint_;
}
|
|
|
|
// Returns the owning server if one was supplied at construction and it is still
// alive; otherwise null.
AuSPtr<ISocketServer> SocketBase::GetSocketServer()
{
    auto pServer = AuTryLockMemoryType(this->wpParent_);
    return pServer;
}
|
|
|
|
void SocketBase::ConnectFinished()
|
|
{
|
|
this->UpdateLocalEndpoint();
|
|
this->UpdateRemoteEndpoint();
|
|
DoMain();
|
|
}
|
|
|
|
void SocketBase::DoMain()
|
|
{
|
|
this->socketChannel_.Establish();
|
|
|
|
if (AuExchange(this->bHasConnected_, true))
|
|
{
|
|
return;
|
|
}
|
|
|
|
if (this->bHasErrored_)
|
|
{
|
|
return;
|
|
}
|
|
|
|
if (this->bHasEnded)
|
|
{
|
|
return;
|
|
}
|
|
|
|
auto pDriver = this->pSocketDriver_;
|
|
if (bool(pDriver))
|
|
{
|
|
try
|
|
{
|
|
pDriver->OnEstablish();
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
this->SendErrorBeginShutdown({});
|
|
}
|
|
}
|
|
|
|
if (this->localEndpoint_.transportProtocol == ETransportProtocol::eProtocolTCP)
|
|
{
|
|
this->ToChannel()->ScheduleOutOfFrameWrite();
|
|
}
|
|
}
|
|
|
|
void SocketBase::ConnectFailed(const NetError &error)
|
|
{
|
|
if (this->ConnectNext())
|
|
{
|
|
return;
|
|
}
|
|
|
|
this->SendErrorNoStream(error);
|
|
}
|
|
|
|
void SocketBase::SendErrorNoStream(const NetError &error)
|
|
{
|
|
if (AuExchange(this->bHasErrored_, true))
|
|
{
|
|
return;
|
|
}
|
|
|
|
this->error_ = error;
|
|
|
|
this->RejectAllListeners();
|
|
|
|
this->socketChannel_.StopTime();
|
|
|
|
if (this->pSocketDriver_)
|
|
{
|
|
try
|
|
{
|
|
this->pSocketDriver_->OnFatalErrorReported(error);
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
}
|
|
}
|
|
|
|
this->SendFinalize();
|
|
}
|
|
|
|
void SocketBase::SendErrorBeginShutdown(const NetError &error)
|
|
{
|
|
if (AuExchange(this->bHasErrored_, true))
|
|
{
|
|
return;
|
|
}
|
|
|
|
this->error_ = error;
|
|
|
|
this->socketChannel_.StopTime();
|
|
|
|
if (this->pSocketDriver_)
|
|
{
|
|
try
|
|
{
|
|
this->pSocketDriver_->OnFatalErrorReported(error);
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
}
|
|
}
|
|
|
|
RejectAllListeners();
|
|
|
|
this->Shutdown();
|
|
}
|
|
|
|
void SocketBase::RejectAllListeners()
|
|
{
|
|
AU_LOCK_GUARD(this->socketChannel_.spinLock);
|
|
|
|
for (const auto &pListener : this->socketChannel_.eventListeners)
|
|
{
|
|
pListener->OnRejected();
|
|
}
|
|
|
|
this->socketChannel_.eventListeners.clear();
|
|
}
|
|
|
|
void SocketBase::CompleteAllListeners()
|
|
{
|
|
AU_LOCK_GUARD(this->socketChannel_.spinLock);
|
|
|
|
for (const auto &pListener : this->socketChannel_.eventListeners)
|
|
{
|
|
pListener->OnComplete();
|
|
}
|
|
|
|
this->socketChannel_.eventListeners.clear();
|
|
}
|
|
|
|
void SocketBase::SendOnData()
|
|
{
|
|
auto pReadableBuffer = this->socketChannel_.AsReadableByteBuffer();
|
|
auto pStartOffset = pReadableBuffer ? pReadableBuffer->readPtr : nullptr;
|
|
|
|
if (this->bHasFinalized_)
|
|
{
|
|
if (this->socketChannel_.pRecvProtocol)
|
|
{
|
|
this->socketChannel_.pRecvProtocol->DoTick();
|
|
}
|
|
|
|
this->socketChannel_.ScheduleOutOfFrameWrite();
|
|
return;
|
|
}
|
|
|
|
if (this->socketChannel_.pRecvProtocol)
|
|
{
|
|
this->socketChannel_.pRecvProtocol->DoTick();
|
|
}
|
|
|
|
if (this->pSocketDriver_)
|
|
{
|
|
try
|
|
{
|
|
this->pSocketDriver_->OnStreamUpdated();
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
this->SendErrorBeginShutdown({});
|
|
}
|
|
}
|
|
|
|
{
|
|
AuList<AuSPtr<ISocketChannelEventListener>> listeners;
|
|
AuList<AuSPtr<ISocketChannelEventListener>> listenersToEvict;
|
|
{
|
|
AU_LOCK_GUARD(this->socketChannel_.spinLock);
|
|
listeners = this->socketChannel_.eventListeners;
|
|
}
|
|
|
|
if (AuTryReserve(listenersToEvict, listeners.size()))
|
|
{
|
|
for (const auto &pListener : listeners)
|
|
{
|
|
if (!pListener->OnData())
|
|
{
|
|
pListener->OnComplete();
|
|
listenersToEvict.push_back(pListener);
|
|
}
|
|
}
|
|
|
|
{
|
|
AU_LOCK_GUARD(this->socketChannel_.spinLock);
|
|
for (const auto &pListener : listenersToEvict)
|
|
{
|
|
AuTryRemove(this->socketChannel_.eventListeners, pListener);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
this->ToChannel()->ScheduleOutOfFrameWrite();
|
|
|
|
auto uHeadDelta = pReadableBuffer ? (pReadableBuffer->readPtr - pStartOffset) : 0;
|
|
this->socketChannel_.GetRecvStatsEx().AddBytes(uHeadDelta);
|
|
|
|
if (auto pServerRecvStats = this->socketChannel_.GetRecvStatsEx2())
|
|
{
|
|
pServerRecvStats->AddBytes(uHeadDelta);
|
|
}
|
|
}
|
|
|
|
// Returns the socket's recorded error.
// If no error has been recorded (netError == kEnumInvalid), it first tries to fill one
// in from the OS error code of the pending read transaction, or failing that the write
// transaction. On NT the transaction's HasErrorCode() is consulted, which can also
// report non-fatal errors (e.g. a disconnect reason that is otherwise filtered as
// not-an-error); elsewhere HasFailed() is used.
const NetError &SocketBase::GetError()
{
    if (this->error_.netError != ENetworkError::kEnumInvalid)
    {
        return this->error_;
    }

    // True when the transaction exists and carries an OS error code worth reporting
    auto hasOsError = [](const auto &pTransaction) -> bool
    {
        if (!pTransaction)
        {
            return false;
        }
#if defined(AURORA_IS_MODERNNT_DERIVED)
        return AuStaticCast<NtAsyncNetworkTransaction>(pTransaction)->HasErrorCode();
#else
        return pTransaction->HasFailed();
#endif
    };

    auto &pReadTransaction = this->socketChannel_.inputChannel.pNetReadTransaction;
    auto &pWriteTransaction = this->socketChannel_.outputChannel.pNetWriteTransaction_;

    if (hasOsError(pReadTransaction))
    {
        NetError_SetOsError(this->error_, pReadTransaction->GetOSErrorCode());
    }
    else if (hasOsError(pWriteTransaction))
    {
        NetError_SetOsError(this->error_, pWriteTransaction->GetOSErrorCode());
    }

    return this->error_;
}
|
|
|
|
// Exposes the embedded socket channel as a shared pointer.
// It is built with the two-argument (aliasing-style) constructor from this socket's
// own shared reference, so holding the channel keeps the socket alive.
AuSPtr<ISocketChannel> SocketBase::ToChannel()
{
    ISocketChannel *pChannel = &this->socketChannel_;
    return AuSPtr<ISocketChannel>(AuSharedFromThis(), pChannel);
}
|
|
|
|
void SocketBase::ShutdownLite()
|
|
{
|
|
if (this->osHandleOwner_)
|
|
{
|
|
AuStaticCast<AFileHandle>(this->osHandleOwner_)->uOSWriteHandle.Reset();
|
|
}
|
|
}
|
|
|
|
void SocketBase::Destroy()
|
|
{
|
|
if (auto pParent = AuTryLockMemoryType(this->wpParent_))
|
|
{
|
|
this->wpParent2_->OnNotifyChildRemoved(this);
|
|
}
|
|
|
|
this->SendFinalize();
|
|
}
|
|
|
|
// Returns the worker this socket is bound to, as the public INetWorker interface.
INetWorker *SocketBase::ToWorker()
{
    return pWorker_;
}
|
|
|
|
// Returns the worker this socket is bound to, as the concrete NetWorker.
NetWorker *SocketBase::ToWorkerEx()
{
    return pWorker_;
}
|
|
|
|
// Prepares the socket before it is established and asks the driver whether to accept it.
// - Applies the server's default input buffer size, if one is given.
// - Starts the channel timer.
// - Warms the input channel, so protocol stacks can access the input byte buffer
//   before the pipe is fully set up.
// - Calls the driver's OnPreestablish.
// Returns true when the socket may proceed. An accepted result is latched in
// bHasPreestablished_, so later calls return true without repeating the work. A driver
// that rejects, or throws, yields false.
bool SocketBase::SendPreestablish(SocketServer *pServer)
{
    if (this->bHasPreestablished_)
    {
        return true;
    }

    if (pServer && pServer->uDefaultInputStreamSize)
    {
        this->socketChannel_.uBytesInputBuffer = pServer->uDefaultInputStreamSize;
    }

    this->socketChannel_.StartTime();

    // Allocate stream resources now, in case we need to start working with the source buffer
    // (think: setting up protocol stacks, accessing the base bytebuffer) before the pipe is allocated
    this->socketChannel_.inputChannel.WarmOnEstablish();

    bool bAccepted {true};
    if (this->pSocketDriver_)
    {
        try
        {
            bAccepted = this->pSocketDriver_->OnPreestablish(AuSharedFromThis());
        }
        catch (...)
        {
            // Treat a throwing driver as a rejection rather than silently accepting the socket
            SysPushErrorCatch();
            bAccepted = false;
        }
    }

    // Previously the driver path returned before latching, so the guard above never applied
    this->bHasPreestablished_ = bAccepted;
    return bAccepted;
}
|
|
|
|
void SocketBase::SendEnd()
|
|
{
|
|
if (AuExchange(this->bHasEnded, true))
|
|
{
|
|
return;
|
|
}
|
|
|
|
this->socketChannel_.StopTime();
|
|
|
|
if (this->pSocketDriver_)
|
|
{
|
|
try
|
|
{
|
|
this->pSocketDriver_->OnEnd();
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
}
|
|
}
|
|
|
|
this->SendFinalize();
|
|
}
|
|
|
|
void SocketBase::SendFinalize()
|
|
{
|
|
if (AuExchange(this->bHasFinalized_, true))
|
|
{
|
|
return;
|
|
}
|
|
|
|
this->socketChannel_.StopTime();
|
|
|
|
this->pWorker_->RemoveSocket(this);
|
|
|
|
if (this->bHasErrored_)
|
|
{
|
|
this->RejectAllListeners();
|
|
}
|
|
else
|
|
{
|
|
this->CompleteAllListeners();
|
|
}
|
|
|
|
if (this->pSocketDriver_)
|
|
{
|
|
try
|
|
{
|
|
this->pSocketDriver_->OnFinalize();
|
|
}
|
|
catch (...)
|
|
{
|
|
SysPushErrorCatch();
|
|
}
|
|
|
|
this->pSocketDriver_.reset();
|
|
}
|
|
|
|
if (auto pParent = AuTryLockMemoryType(this->wpParent_))
|
|
{
|
|
this->wpParent2_->OnNotifyChildRemoved(this);
|
|
}
|
|
|
|
auto pWriteTransaction = this->socketChannel_.outputChannel.ToWriteTransaction();
|
|
|
|
#if defined(AURORA_IS_MODERNNT_DERIVED)
|
|
if (pWriteTransaction)
|
|
{
|
|
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->MakeSyncable();
|
|
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->ForceNextWriteWait();
|
|
}
|
|
#endif
|
|
|
|
this->SendOnData();
|
|
|
|
this->socketChannel_.Release();
|
|
|
|
#if defined(AURORA_IS_MODERNNT_DERIVED)
|
|
if (this->socketChannel_.inputChannel.pNetReadTransaction)
|
|
{
|
|
AuStaticCast<NtAsyncNetworkTransaction>(this->socketChannel_.inputChannel.pNetReadTransaction)->pSocket = nullptr;
|
|
}
|
|
|
|
this->socketChannel_.inputChannel.pNetReadTransaction.reset();
|
|
|
|
if (pWriteTransaction)
|
|
{
|
|
AuStaticCast<NtAsyncNetworkTransaction>(pWriteTransaction)->pSocket = nullptr;
|
|
}
|
|
|
|
this->socketChannel_.outputChannel.pParent_ = nullptr;
|
|
#else
|
|
// TODO:
|
|
#endif
|
|
|
|
|
|
this->CloseSocket();
|
|
}
|
|
} |