/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: AuNetSocket.cpp
    Date: 2022-8-16
    Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocket.hpp"
#include "AuNetEndpoint.hpp"
#include "AuNetWorker.hpp"
#include "AuIPAddress.hpp"
#include "AuNetError.hpp"
#include "AuNetSocketServer.hpp"

#if defined(AURORA_IS_MODERNNT_DERIVED)
#include "AuNetStream.NT.hpp"
#endif

namespace Aurora::IO::Net
{
    // Builds a NetError and lets NetError_SetCurrent fill it in (presumably from the
    // calling thread's last OS socket error — confirm against AuNetError).
    // NOTE(review): this static helper is not referenced anywhere in this translation unit.
    static NetError GetLastNetError()
    {
        NetError error;
        NetError_SetCurrent(error);
        return error;
    }

    // Constructs a socket around an already-existing OS handle (osHandle may be 0).
    // The socket is registered with its worker unconditionally and first, before any
    // handle validation. An osHandleOwner_ is always allocated. It is only initialised
    // with the handle when the handle is non-zero, and only on non-NT targets (the NT
    // Init call is commented out).
    SocketBase::SocketBase(struct NetInterface *pInterface,
                           struct NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           AuUInt osHandle) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        osHandle_(osHandle)
    {
        this->pWorker_->AddSocket(this);

        this->osHandleOwner_ = AuMakeShared();

        if (!this->osHandle_)
        {
            return;
        }

    #if defined(AURORA_IS_MODERNNT_DERIVED)
        //this->osHandleOwner_->Init((HANDLE)this->osHandle_, (HANDLE)this->osHandle_);
    #else
        this->osHandleOwner_->Init((int)this->osHandle_, (int)this->osHandle_);
    #endif
    }

    // Constructs a not-yet-connected socket that targets a single remote endpoint.
    // The endpoint is stored in remoteEndpoint_. No OS handle is supplied here, so
    // osHandle_ keeps whatever default the class declaration gives it.
    // NOTE(review): AddSocket is only reached when osHandle_ is non-zero, but this
    // constructor never assigns osHandle_. Unless the header defaults it to a non-zero
    // value, this socket is never registered with pWorker_. SendFinalize, however,
    // always calls RemoveSocket. Compare the osHandle constructor, which registers
    // unconditionally. Confirm whether the check was meant to test osHandleOwner_
    // (allocation failure) instead.
    SocketBase::SocketBase(struct NetInterface *pInterface,
                           struct NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           const NetEndpoint &endpoint) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        remoteEndpoint_(endpoint)
    {
        this->osHandleOwner_ = AuMakeShared();

        if (!this->osHandle_)
        {
            return;
        }

        this->pWorker_->AddSocket(this);
    }

    // Constructs a not-yet-connected socket that may try several remote addresses in
    // turn (see ConnectNext). The driver reference inside the copied connectMany_
    // descriptor is dropped; this socket only keeps its own pSocketDriver_.
    // NOTE(review): this constructor never calls AddSocket, and the trailing osHandle_
    // check has no statements after it, so it is effectively a no-op.
    SocketBase::SocketBase(NetInterface *pInterface,
                           NetWorker *pWorker,
                           const AuSPtr &pSocketDriver,
                           const NetSocketConnectMany &connectMany) :
        connectOperation(this),
        socketChannel_(this),
        pInterface_(pInterface),
        pWorker_(pWorker),
        pSocketDriver_(pSocketDriver),
        bHasRemoteMany_(true),
        connectMany_(connectMany)
    {
        this->connectMany_.pDriver.reset();

        this->osHandleOwner_ = AuMakeShared();

        if (!this->osHandle_)
        {
            return;
        }
    }

    // Finalises the socket (Destroy -> SendFinalize). SendFinalize is latched, so this
    // is a no-op if the socket was already finalised.
    // NOTE(review): if CloseSocket/SendOnData dispatch through virtual functions,
    // calls reached from a destructor resolve to SocketBase's versions, not derived
    // overrides — confirm that is acceptable.
    SocketBase::~SocketBase()
    {
        this->Destroy();
    }

    // True only when all of the following hold:
    // - a handle owner exists;
    // - the connect operation reports valid;
    // - the OS handle is neither 0 nor -1 (-1 presumably the "invalid socket" sentinel);
    // - construction was not force-failed.
    bool SocketBase::IsValid()
    {
        return bool(this->osHandleOwner_) &&
               bool(this->connectOperation.IsValid()) &&
               bool(this->osHandle_ != 0) &&
               bool(this->osHandle_ != -1) &&
               bool(!this->bForceFailConstruct_);
    }

    // Pops the first remaining address from connectMany_.ips. It builds an endpoint
    // from that address plus the shared port/protocol and starts connecting to it.
    // Returns false when no addresses remain; otherwise returns Connect()'s result.
    bool SocketBase::ConnectNext()
    {
        if (this->connectMany_.ips.empty())
        {
            return false;
        }

        auto topLevelEntry = this->connectMany_.ips[0];
        this->connectMany_.ips.erase(this->connectMany_.ips.begin());

        NetEndpoint endpoint;
        endpoint.uPort = this->connectMany_.uPort;
        endpoint.ip = topLevelEntry;
        endpoint.transportProtocol = this->connectMany_.protocol;
        return this->Connect(endpoint);
    }

    // Begins connecting to `endpoint`, which becomes remoteEndpoint_ (after
    // OptimizeEndpoint). The local endpoint inherits the remote transport protocol.
    // Returns false without touching connectOperation if the endpoint is rejected
    // (endpointSize_ == 0). Otherwise it tries the connect strategies in order —
    // overlapped, then non-blocking, then blocking — stopping at the first that
    // succeeds. If all three fail, connectOperation.OnIOFailure() is signalled.
    bool SocketBase::Connect(const NetEndpoint &endpoint)
    {
        this->remoteEndpoint_ = endpoint;
        this->endpointSize_ = OptimizeEndpoint(this->remoteEndpoint_);

        this->localEndpoint_.transportProtocol = this->remoteEndpoint_.transportProtocol;

        if (!this->endpointSize_)
        {
            SysPushErrorIO("Invalid remote endpoint");
            return false;
        }

        bool bStatus = this->ConnectOverlapped() ||
                       this->ConnectNonblocking() ||
                       this->ConnectBlocking();

        if (!bStatus)
        {
            this->connectOperation.OnIOFailure();
        }

        return bStatus;
    }

    // Returns the raw OS socket handle (0 if none has been assigned).
    AuUInt SocketBase::ToPlatformHandle()
    {
        return this->osHandle_;
    }

    // Returns the user-supplied socket driver; empty once SendFinalize has run.
    AuSPtr SocketBase::GetUserDriver()
    {
        return this->pSocketDriver_;
    }

    // Returns the remote endpoint as last set by a constructor, Connect, or
    // UpdateRemoteEndpoint.
    const NetEndpoint &SocketBase::GetRemoteEndpoint()
    {
        return this->remoteEndpoint_;
    }

    // Returns the local endpoint as last set by Connect or UpdateLocalEndpoint.
    const NetEndpoint &SocketBase::GetLocalEndpoint()
    {
        return this->localEndpoint_;
    }

    // Called when a connection attempt completes successfully: refresh both endpoint
    // records, then run the establish path.
    void SocketBase::ConnectFinished()
    {
        this->UpdateLocalEndpoint();
        this->UpdateRemoteEndpoint();
        DoMain();
    }

    // Establish path. The channel's Establish() runs on every call. The remainder runs
    // at most once, latched by bHasConnected_. The driver's OnEstablish callback is
    // skipped if the socket has already errored or ended. A throwing OnEstablish starts
    // an error shutdown with a default-constructed NetError. Finally an out-of-frame
    // write is scheduled.
    void SocketBase::DoMain()
    {
        this->socketChannel_.Establish();

        if (AuExchange(this->bHasConnected_, true))
        {
            return;
        }

        if (this->bHasErrored_)
        {
            return;
        }

        if (this->bHasEnded)
        {
            return;
        }

        // Local copy so the driver stays referenced for the duration of the callback,
        // even if pSocketDriver_ is reset meanwhile (SendFinalize resets it).
        // Presumably AuSPtr is a shared-ownership pointer — confirm.
        auto pDriver = this->pSocketDriver_;
        if (bool(pDriver))
        {
            try
            {
                pDriver->OnEstablish();
            }
            catch (...)
            {
                SysPushErrorCatch();
                this->SendErrorBeginShutdown({});
            }
        }

        this->ToChannel()->ScheduleOutOfFrameWrite();
    }

    // A connection attempt failed. Try the next queued address (connect-many mode);
    // if there is none, report `error` and finalise without a stream.
    void SocketBase::ConnectFailed(const NetError &error)
    {
        if (this->ConnectNext())
        {
            return;
        }

        this->SendErrorNoStream(error);
    }

    // Reports a fatal error and finalises immediately (no graceful shutdown).
    // Only the first error is recorded and reported; bHasErrored_ latches.
    // Exceptions thrown by the driver's OnFatalErrorReported are logged and swallowed.
    void SocketBase::SendErrorNoStream(const NetError &error)
    {
        if (AuExchange(this->bHasErrored_, true))
        {
            return;
        }

        this->error_ = error;

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFatalErrorReported(error);
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        this->SendFinalize();
    }

    // Reports a fatal error, then begins Shutdown() rather than finalising immediately.
    // It shares the bHasErrored_ latch with SendErrorNoStream, so at most one of the
    // two ever reports.
    void SocketBase::SendErrorBeginShutdown(const NetError &error)
    {
        if (AuExchange(this->bHasErrored_, true))
        {
            return;
        }

        this->error_ = error;

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFatalErrorReported(error);
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        this->Shutdown();
    }

    // Dispatches newly-available input.
    // - Ticks the receive protocol stack (if any).
    // - Notifies the driver via OnStreamUpdated.
    // - Schedules an out-of-frame write.
    // - Adds the number of bytes consumed (read-pointer advance) to the receive stats.
    // After finalisation only the protocol tick and write scheduling happen: no driver
    // callback and no stats accounting.
    void SocketBase::SendOnData()
    {
        auto pReadableBuffer = this->socketChannel_.AsReadableByteBuffer();
        auto pStartOffset = pReadableBuffer->readPtr;

        if (this->bHasFinalized_)
        {
            if (this->socketChannel_.pRecvProtocol)
            {
                this->socketChannel_.pRecvProtocol->DoTick();
            }

            // Uses socketChannel_ directly rather than ToChannel(). This path is reached
            // from SendFinalize, which the destructor can trigger, presumably when
            // AuSharedFromThis() is no longer usable — confirm.
            this->socketChannel_.ScheduleOutOfFrameWrite();
            return;
        }

        if (this->socketChannel_.pRecvProtocol)
        {
            this->socketChannel_.pRecvProtocol->DoTick();
        }

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnStreamUpdated();
            }
            catch (...)
            {
                SysPushErrorCatch();
                this->SendErrorBeginShutdown({});
            }
        }

        this->ToChannel()->ScheduleOutOfFrameWrite();

        // Bytes the protocol stack / driver consumed from the input buffer during this
        // call.
        auto uHeadDelta = pReadableBuffer->readPtr - pStartOffset;
        this->socketChannel_.GetRecvStatsEx().AddBytes(uHeadDelta);
    }

    // Returns the first fatal error recorded by SendErrorNoStream/SendErrorBeginShutdown.
    const NetError &SocketBase::GetError()
    {
        return this->error_;
    }

    // Returns a pointer to the embedded channel that shares ownership with this socket.
    // This is presumably shared_ptr's aliasing constructor: owner = this socket,
    // pointee = socketChannel_. Requires this object to be owned by a shared pointer
    // (AuSharedFromThis).
    AuSPtr SocketBase::ToChannel()
    {
        return AuSPtr(AuSharedFromThis(), &this->socketChannel_);
    }

    // Marks the handle owner write-locked. It does not close the socket or notify the
    // driver.
    void SocketBase::ShutdownLite()
    {
        if (this->osHandleOwner_)
        {
            this->osHandleOwner_->bWriteLock = true;
        }
    }

    // Tears the socket down; equivalent to SendFinalize (latched, safe to repeat).
    void SocketBase::Destroy()
    {
        this->SendFinalize();
    }

    // Returns the owning worker through its public interface type.
    INetWorker *SocketBase::ToWorker()
    {
        return this->pWorker_;
    }

    // Returns the owning worker as its concrete implementation type.
    NetWorker *SocketBase::ToWorkerEx()
    {
        return this->pWorker_;
    }

    // Pre-establish hook, run before the connection is announced to the driver.
    // Steps:
    // - Apply the server's default input buffer size (if a server with a non-zero
    //   size is given).
    // - Warm the input channel.
    // - Let the driver veto via OnPreestablish.
    // Returns the driver's verdict, or true when there is no driver.
    // NOTE(review): if OnPreestablish throws, the exception is logged and the function
    // still returns true, i.e. establishment proceeds. Confirm this is intended rather
    // than returning false.
    bool SocketBase::SendPreestablish(SocketServer *pServer)
    {
        if (pServer && pServer->uDefaultInputStreamSize)
        {
            this->socketChannel_.uBytesInputBuffer = pServer->uDefaultInputStreamSize;
        }

        // Allocate stream resources, in case we need to start working with the source
        // buffer. Think: setting up protocol stacks, or accessing the base bytebuffer,
        // before the pipe has been allocated.
        this->socketChannel_.inputChannel.WarmOnEstablish();

        if (this->pSocketDriver_)
        {
            try
            {
                return this->pSocketDriver_->OnPreestablish(AuSharedFromThis());
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        return true;
    }

    // Signals end-of-stream to the driver once (latched by bHasEnded), then finalises.
    // Exceptions thrown by OnEnd are logged and swallowed.
    void SocketBase::SendEnd()
    {
        if (AuExchange(this->bHasEnded, true))
        {
            return;
        }

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnEnd();
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        this->SendFinalize();
    }

    // Final teardown; runs at most once (latched by bHasFinalized_). In order:
    // 1. Unregister from the worker.
    // 2. Notify the driver via OnFinalize and drop the driver reference.
    // 3. On NT, make the pending write transaction syncable and force the next write
    //    to wait.
    // 4. Run one final SendOnData. Because bHasFinalized_ is already set, it takes the
    //    finalized branch: protocol tick plus write scheduling only.
    // 5. Drop the protocol stacks and net reader, reset the read transaction.
    // 6. Close the socket.
    // The order matters: the final SendOnData must run before the protocol stacks are
    // released, and the driver has already been released by then.
    void SocketBase::SendFinalize()
    {
        if (AuExchange(this->bHasFinalized_, true))
        {
            return;
        }

        this->pWorker_->RemoveSocket(this);

        if (this->pSocketDriver_)
        {
            try
            {
                this->pSocketDriver_->OnFinalize();
            }
            catch (...)
            {
                SysPushErrorCatch();
            }

            this->pSocketDriver_.reset();
        }

        auto pWriteTransaction = this->socketChannel_.outputChannel.ToWriteTransaction();
    #if defined(AURORA_IS_MODERNNT_DERIVED)
        AuStaticCast(pWriteTransaction)->MakeSyncable();
        AuStaticCast(pWriteTransaction)->ForceNextWriteWait();
    #endif

        this->SendOnData();

        this->socketChannel_.pRecvProtocol.reset();
        this->socketChannel_.pSendProtocol.reset();
        this->socketChannel_.inputChannel.pNetReader.reset();
        // NOTE(review): dereferenced without a null check — assumes
        // pNetReadTransaction is always allocated by this point; confirm.
        this->socketChannel_.inputChannel.pNetReadTransaction->Reset();

        this->CloseSocket();
    }
}