/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: AuNetSocket.cpp
    Date: 2022-8-16
    Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocket.hpp"
#include "AuNetEndpoint.hpp"
#include "AuNetWorker.hpp"
#include "AuIPAddress.hpp"
#include "AuNetError.hpp"
#include "AuNetSocketServer.hpp"
#include "AuNetInterface.hpp"

#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#endif

namespace Aurora::IO::Net
{
    // Returns a NetError populated by NetError_SetCurrent
    // (presumably from the calling thread's last OS error — confirm).
    static NetError GetLastNetError()
    {
        NetError error;
        NetError_SetCurrent(error);
        return error;
    }

    // Wraps an already-existing OS socket handle.
    // The socket is registered with the worker immediately. A non-zero handle is
    // bound to osHandleOwner_ (on NT the Init call is currently commented out).
    SocketBase::SocketBase(struct NetInterface *pInterface,
                           struct NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           AuUInt osHandle) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        osHandle_(osHandle)
    {
        this->pWorker_->AddSocket(this);

        this->osHandleOwner_ = AuMakeShared();
        if (!this->osHandle_)
        {
            return;
        }

    #if defined(AURORA_IS_MODERNNT_DERIVED)
        //this->osHandleOwner_->Init((HANDLE)this->osHandle_, (HANDLE)this->osHandle_);
    #else
        this->osHandleOwner_->Init((int)this->osHandle_, (int)this->osHandle_);
    #endif
    }

    // Client socket targeting a single, already-resolved remote endpoint.
    // NOTE(review): osHandle_ is not assigned by this constructor. If its default
    // is 0, AddSocket is skipped here (unlike the handle constructor above) —
    // presumably registration happens later when a handle is created; confirm.
    SocketBase::SocketBase(struct NetInterface *pInterface,
                           struct NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           const NetEndpoint &endpoint) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        remoteEndpoint_(endpoint)
    {
        this->osHandleOwner_ = AuMakeShared();
        if (!this->osHandle_)
        {
            return;
        }

        this->pWorker_->AddSocket(this);
    }

    // Client socket targeting a (host, port) pair.
    // An IP host is written straight into remoteEndpoint_. A hostname is stashed
    // in resolveLater and later resolved by TryStartResolve into connectMany_.
    SocketBase::SocketBase(struct NetInterface *pInterface,
                           struct NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           const AuPair &endpoint,
                           AuNet::ETransportProtocol eProtocol) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver)
    {
        auto &[host, uPort] = endpoint;

        if (host.type == AuNet::EHostnameType::eHostByIp)
        {
            this->remoteEndpoint_.ip = host.address;
            this->remoteEndpoint_.uPort = uPort;
            this->remoteEndpoint_.transportProtocol = eProtocol;
            OptimizeEndpoint(this->remoteEndpoint_);
        }
        else
        {
            this->resolveLater = host.hostname;
            this->remoteEndpoint_.uPort = uPort;
            this->remoteEndpoint_.transportProtocol = eProtocol;
            this->connectMany_.uPort = uPort;
            this->connectMany_.protocol = eProtocol;
        }

        this->osHandleOwner_ = AuMakeShared();
        if (!this->osHandle_)
        {
            return;
        }

        this->pWorker_->AddSocket(this);
    }

    // Client socket that will try each address in connectMany in turn (see ConnectNext).
    // The copied request's driver reference is dropped; this socket keeps its own pSocketDriver_.
    SocketBase::SocketBase(NetInterface *pInterface,
                           NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           const NetSocketConnectMany &connectMany) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        bHasRemoteMany_(true),
        connectMany_(connectMany)
    {
        this->connectMany_.pDriver.reset();

        this->osHandleOwner_ = AuMakeShared();
        if (!this->osHandle_)
        {
            return;
        }
    }

    SocketBase::~SocketBase()
    {
        this->Destroy();
    }

    // A socket is valid if it is still awaiting hostname resolution, or if it has
    // a handle owner, a valid connect operation, a usable OS handle (neither 0
    // nor -1), and construction was not force-failed.
    bool SocketBase::IsValid()
    {
        return (this->resolveLater.size()) ||
               (bool(this->osHandleOwner_) &&
                bool(this->connectOperation.IsValid()) &&
                bool(this->osHandle_ != 0) &&
                bool(this->osHandle_ != -1) &&
                bool(!this->bForceFailConstruct_));
    }

    // Starts asynchronous resolution of resolveLater (at most once at a time).
    // On success the resolved addresses are appended to connectMany_, the socket
    // is renewed, and the first connection attempt begins. On failure a fatal
    // error is reported. Returns whether a resolve request was obtained (or one is
    // already running).
    bool SocketBase::TryStartResolve()
    {
        // Keeps this socket alive until one of the resolver callbacks has run.
        auto pThat = this->SharedFromThis();

        if (this->bResolving_)
        {
            return true;
        }

        this->bResolving_ = true;

        auto address = this->resolveLater;
        this->resolveLater.clear();

        auto pResolver = this->pInterface_->GetResolveService()->SimpleAllResolve(address,
            AuMakeSharedThrow, AuNet::NetError>>(
                [=](const AuSPtr> &ips)
                {
                    pThat->bResolving_ = false;
                    pThat->connectMany_.uPort = pThat->remoteEndpoint_.uPort;
                    pThat->connectMany_.ips.insert(pThat->connectMany_.ips.end(), ips->begin(), ips->end());
                    pThat->bHasRemoteMany_ = true;
                    pThat->RenewSocket();
                    pThat->ConnectNext();
                },
                [=](const AuSPtr &error)
                {
                    pThat->SendErrorNoStream(error ? *error.get() : AuNet::NetError {});
                }));

        return bool(pResolver);
    }

    // Pops the first pending address from connectMany_ and attempts to connect to it.
    // Returns false when no addresses remain or the attempt could not be started.
    bool SocketBase::ConnectNext()
    {
        if (this->connectMany_.ips.empty())
        {
            return false;
        }

        auto topLevelEntry = this->connectMany_.ips[0];
        this->connectMany_.ips.erase(this->connectMany_.ips.begin());

        NetEndpoint endpoint;
        endpoint.uPort = this->connectMany_.uPort;
        endpoint.ip = topLevelEntry;
        endpoint.transportProtocol = this->connectMany_.protocol;
        return this->Connect(endpoint);
    }

    // Connects to `endpoint`, trying the overlapped, then non-blocking, then
    // blocking strategies in that order. An invalid socket reports a fatal error.
    // If every strategy fails, connectOperation.OnIOFailure() is invoked.
    bool SocketBase::Connect(const NetEndpoint &endpoint)
    {
        if (!this->IsValid())
        {
            this->SendErrorNoStream({});
            return false;
        }

        this->remoteEndpoint_ = endpoint;
        this->endpointSize_ = OptimizeEndpoint(this->remoteEndpoint_);

        this->localEndpoint_.transportProtocol = this->remoteEndpoint_.transportProtocol;

        if (!this->endpointSize_)
        {
            SysPushErrorIO("Invalid remote endpoint");
            return false;
        }

        bool bStatus = this->ConnectOverlapped() ||
                       this->ConnectNonblocking() ||
                       this->ConnectBlocking();

        if (!bStatus)
        {
            this->connectOperation.OnIOFailure();
        }

        return bStatus;
    }

    AuUInt SocketBase::ToPlatformHandle()
    {
        return this->osHandle_;
    }

    AuSPtr SocketBase::GetUserDriver()
    {
        return this->pSocketDriver_;
    }

    const NetEndpoint &SocketBase::GetRemoteEndpoint()
    {
        return this->remoteEndpoint_;
    }

    const NetEndpoint &SocketBase::GetLocalEndpoint()
    {
        return this->localEndpoint_;
    }

    // Called once the connection completes: refreshes both endpoints, then runs DoMain.
    void SocketBase::ConnectFinished()
    {
        this->UpdateLocalEndpoint();
        this->UpdateRemoteEndpoint();
        DoMain();
    }

    // Establishes the channel and, the first time only (and only if the socket has
    // not errored or ended), notifies the driver via OnEstablish. A throwing driver
    // begins an error shutdown. Always schedules an out-of-frame write afterwards
    // unless an early return was taken.
    void SocketBase::DoMain()
    {
        this->socketChannel_.Establish();

        if (AuExchange(this->bHasConnected_, true))
        {
            return;
        }

        if (this->bHasErrored_)
        {
            return;
        }

        if (this->bHasEnded)
        {
            return;
        }

        auto pDriver = this->pSocketDriver_;
        if (bool(pDriver))
        {
            try
            {
                pDriver->OnEstablish();
            }
            catch (...)
            {
                SysPushErrorCatch();
                this->SendErrorBeginShutdown({});
            }
        }

        this->ToChannel()->ScheduleOutOfFrameWrite();
    }

    // Falls back to the next candidate address; only reports `error` once none remain.
    void SocketBase::ConnectFailed(const NetError &error)
    {
        if (this->ConnectNext())
        {
            return;
        }

        this->SendErrorNoStream(error);
    }

    // Records the first fatal error, rejects all listeners, notifies the driver,
    // then finalizes the socket immediately (no graceful shutdown).
    void SocketBase::SendErrorNoStream(const NetError &error)
    {
        if (AuExchange(this->bHasErrored_, true))
        {
            return;
        }

        this->error_ = error;

        this->RejectAllListeners();

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFatalErrorReported(error);
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        this->SendFinalize();
    }

    // Records the first fatal error, notifies the driver, rejects all listeners,
    // then begins a Shutdown() rather than finalizing immediately.
    void SocketBase::SendErrorBeginShutdown(const NetError &error)
    {
        if (AuExchange(this->bHasErrored_, true))
        {
            return;
        }

        this->error_ = error;

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFatalErrorReported(error);
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        RejectAllListeners();
        this->Shutdown();
    }

    // Removes every registered channel listener and calls OnRejected on each.
    // The list is detached under spinLock, and callbacks run after the lock is
    // released, matching how SendOnData invokes listeners.
    void SocketBase::RejectAllListeners()
    {
        decltype(this->socketChannel_.eventListeners) listeners;

        {
            AU_LOCK_GUARD(this->socketChannel_.spinLock);
            listeners.swap(this->socketChannel_.eventListeners);
        }

        for (const auto &pListener : listeners)
        {
            pListener->OnRejected();
        }
    }

    // Removes every registered channel listener and calls OnComplete on each.
    // The list is detached under spinLock, and callbacks run after the lock is released.
    void SocketBase::CompleteAllListeners()
    {
        decltype(this->socketChannel_.eventListeners) listeners;

        {
            AU_LOCK_GUARD(this->socketChannel_.spinLock);
            listeners.swap(this->socketChannel_.eventListeners);
        }

        for (const auto &pListener : listeners)
        {
            pListener->OnComplete();
        }
    }

    // Dispatches newly-available input:
    //  - ticks the receive protocol;
    //  - notifies the driver (OnStreamUpdated);
    //  - calls OnData on each listener, completing and evicting any that return false;
    //  - schedules a write.
    // After finalization, only the protocol tick and the write are performed.
    // Bytes consumed (the advance of the readable buffer's readPtr) are added to
    // the receive stats.
    void SocketBase::SendOnData()
    {
        auto pReadableBuffer = this->socketChannel_.AsReadableByteBuffer();
        auto pStartOffset = pReadableBuffer ? pReadableBuffer->readPtr : nullptr;

        if (this->bHasFinalized_)
        {
            if (this->socketChannel_.pRecvProtocol)
            {
                this->socketChannel_.pRecvProtocol->DoTick();
            }

            this->socketChannel_.ScheduleOutOfFrameWrite();
            return;
        }

        if (this->socketChannel_.pRecvProtocol)
        {
            this->socketChannel_.pRecvProtocol->DoTick();
        }

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnStreamUpdated();
            }
            catch (...)
            {
                SysPushErrorCatch();
                this->SendErrorBeginShutdown({});
            }
        }

        {
            decltype(this->socketChannel_.eventListeners) listeners;
            decltype(this->socketChannel_.eventListeners) listenersToEvict;

            // Snapshot under the lock; listener callbacks run without it held.
            {
                AU_LOCK_GUARD(this->socketChannel_.spinLock);
                listeners = this->socketChannel_.eventListeners;
            }

            if (AuTryReserve(listenersToEvict, listeners.size()))
            {
                for (const auto &pListener : listeners)
                {
                    if (!pListener->OnData())
                    {
                        pListener->OnComplete();
                        listenersToEvict.push_back(pListener);
                    }
                }

                {
                    AU_LOCK_GUARD(this->socketChannel_.spinLock);
                    for (const auto &pListener : listenersToEvict)
                    {
                        AuTryRemove(this->socketChannel_.eventListeners, pListener);
                    }
                }
            }
        }

        this->ToChannel()->ScheduleOutOfFrameWrite();

        auto uHeadDelta = pReadableBuffer ? (pReadableBuffer->readPtr - pStartOffset) : 0;
        this->socketChannel_.GetRecvStatsEx().AddBytes(uHeadDelta);
    }

    // Returns the recorded fatal error. If none was recorded, the read and then
    // the write transaction is checked for an OS error code, which is cached into error_.
    const NetError &SocketBase::GetError()
    {
        if (this->error_.netError == ENetworkError::kEnumInvalid)
        {
            // Note: some interfaces will now have a `HasErrorCode` that reports non-fatal errors
            // This should be useful for reporting the disconnect reason when otherwise filtered as not-an-error
            if (this->socketChannel_.inputChannel.pNetReadTransaction &&
            #if defined(AURORA_IS_MODERNNT_DERIVED)
                AuStaticCast(this->socketChannel_.inputChannel.pNetReadTransaction)->HasErrorCode()
            #else
                this->socketChannel_.inputChannel.pNetReadTransaction->Failed()
            #endif
                )
            {
                NetError_SetOsError(this->error_, this->socketChannel_.inputChannel.pNetReadTransaction->GetOSErrorCode());
            }
            else if (this->socketChannel_.outputChannel.pNetWriteTransaction_ &&
            #if defined(AURORA_IS_MODERNNT_DERIVED)
                AuStaticCast(this->socketChannel_.outputChannel.pNetWriteTransaction_)->HasErrorCode()
            #else
                this->socketChannel_.outputChannel.pNetWriteTransaction_->Failed()
            #endif
                )
            {
                NetError_SetOsError(this->error_, this->socketChannel_.outputChannel.pNetWriteTransaction_->GetOSErrorCode());
            }
        }

        return this->error_;
    }

    // Aliasing shared pointer: shares this socket's ownership but points at the embedded channel.
    AuSPtr SocketBase::ToChannel()
    {
        return AuSPtr(AuSharedFromThis(), &this->socketChannel_);
    }

    // Sets bWriteLock on the handle owner, if there is one.
    void SocketBase::ShutdownLite()
    {
        if (this->osHandleOwner_)
        {
            this->osHandleOwner_->bWriteLock = true;
        }
    }

    void SocketBase::Destroy()
    {
        this->SendFinalize();
    }

    INetWorker *SocketBase::ToWorker()
    {
        return this->pWorker_;
    }

    NetWorker *SocketBase::ToWorkerEx()
    {
        return this->pWorker_;
    }

    // Prepares input resources and asks the driver (if any) whether to proceed.
    // pServer, when non-null, may override the input buffer size.
    // Returns the driver's verdict, or true when there is no driver.
    bool SocketBase::SendPreestablish(SocketServer *pServer)
    {
        if (this->bHasPreestablished_)
        {
            return true;
        }

        if (pServer && pServer->uDefaultInputStreamSize)
        {
            this->socketChannel_.uBytesInputBuffer = pServer->uDefaultInputStreamSize;
        }

        this->socketChannel_.inputChannel.WarmOnEstablish(); // Allocate stream resources, in case we need to start working with the source buffer
                                                             // think: setting up protocol stacks, accessing the base bytebuffer, without the pipe being
                                                             // allocated already

        if (this->pSocketDriver_)
        {
            try
            {
                // Latch a successful verdict so the guard above short-circuits repeat calls
                // (previously the flag was never set when a driver was present).
                return this->bHasPreestablished_ = this->pSocketDriver_->OnPreestablish(AuSharedFromThis());
            }
            catch (...)
            {
                // NOTE(review): a throwing driver falls through and is treated as success; preserved as-is.
                SysPushErrorCatch();
            }
        }

        return this->bHasPreestablished_ = true;
    }

    // Handles a graceful end of stream (first call only): notifies the driver, then finalizes.
    void SocketBase::SendEnd()
    {
        if (AuExchange(this->bHasEnded, true))
        {
            return;
        }

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnEnd();
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        this->SendFinalize();
    }

    // One-shot teardown, performed in this order:
    //  1. unregister from the worker;
    //  2. reject listeners (if errored) or complete them;
    //  3. notify and drop the driver;
    //  4. flush remaining input through SendOnData;
    //  5. release the protocols and reader, reset the read transaction;
    //  6. close the socket.
    void SocketBase::SendFinalize()
    {
        if (AuExchange(this->bHasFinalized_, true))
        {
            return;
        }

        this->pWorker_->RemoveSocket(this);

        if (this->bHasErrored_)
        {
            this->RejectAllListeners();
        }
        else
        {
            this->CompleteAllListeners();
        }

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFinalize();
            }
            catch (...)
            {
                SysPushErrorCatch();
            }

            this->pSocketDriver_.reset();
        }

        auto pWriteTransaction = this->socketChannel_.outputChannel.ToWriteTransaction();
    #if defined(AURORA_IS_MODERNNT_DERIVED)
        AuStaticCast(pWriteTransaction)->MakeSyncable();
        AuStaticCast(pWriteTransaction)->ForceNextWriteWait();
    #endif

        this->SendOnData();

        this->socketChannel_.pRecvProtocol.reset();
        this->socketChannel_.pSendProtocol.reset();
        this->socketChannel_.inputChannel.pNetReader.reset();

        // GetError treats pNetReadTransaction as nullable; do not dereference it blindly here.
        if (this->socketChannel_.inputChannel.pNetReadTransaction)
        {
            this->socketChannel_.inputChannel.pNetReadTransaction->Reset();
        }

        this->CloseSocket();
    }
}