/***
    Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: AuNetStream.NT.cpp
    Date: 2022-8-17
    Author: Reece
***/
#include "Networking.hpp"
#include "AuNetStream.NT.hpp"
#include "AuNetSocket.hpp"
#include "AuNetWorker.hpp"
#include "AuNetEndpoint.hpp"
// NOTE(review): in this copy of the file the target of the following #include,
// and the template argument lists of the reinterpret_cast / AuStaticCast / AuSPtr
// uses below, appear to be missing. Restore them from upstream before building.
#include 

namespace Aurora::IO::Net
{
    // Binds the transaction to its owning socket. For TCP remotes, MSG_PARTIAL is
    // used as the flags value passed to WSARecv/WSASend (see StartRead/StartWrite).
    NtAsyncNetworkTransaction::NtAsyncNetworkTransaction(SocketBase *pSocket) :
        pSocket(pSocket)
    {
        if (this->pSocket->GetRemoteEndpoint().transportProtocol == ETransportProtocol::eProtocolTCP)
        {
            this->dwRecvFlags = MSG_PARTIAL;
        }
    }

    // Cancels any operation still accounted as outstanding (see Reset()).
    NtAsyncNetworkTransaction::~NtAsyncNetworkTransaction()
    {
        Reset();
    }

    // Completion routine passed to WSARecv/WSARecvFrom/WSASend/WSASendTo.
    // The owning transaction is recovered from its embedded `overlap` member.
    static void __stdcall WsaOverlappedCompletionRoutine(DWORD dwErrorCode, DWORD cbTransferred, LPWSAOVERLAPPED lpOverlapped, DWORD dwFlags )
    {
        auto transaction = reinterpret_cast(reinterpret_cast(lpOverlapped) - offsetof(NtAsyncNetworkTransaction, overlap));

        // Take over the self-pin installed by IDontWannaUsePorts(); `hold` keeps the
        // transaction alive until this routine returns.
        auto hold = AuExchange(transaction->pPin, {});

        if (dwErrorCode)
        {
            hold->bHasFailed = true;
            hold->dwOsErrorCode = dwErrorCode;
        }
        else if (!hold->dwLastAbstractStat)
        {
            // No error and no outstanding request length (e.g. Reset() already
            // zeroed it): nothing to report.
            return;
        }

        SetEvent(lpOverlapped->hEvent);

        // Detach the buffers from the transaction before completing; the locals keep
        // them alive until CompleteEx (and any user callback) has returned.
        auto pHold1 = AuExchange(hold->readView, {});
        auto pHold2 = AuExchange(hold->writeView, {});
        hold->CompleteEx(cbTransferred, true);
    }

    // Starts an overlapped receive into `memoryView`.
    // `offset` is stored in OVERLAPPED.Offset/OffsetHigh and reported back to the
    // subscriber via OnAsyncFileOpFinished.
    // Returns true if the receive was issued, completed, or hit end-of-stream
    // synchronously; false on a state/argument error or a hard socket error.
    bool NtAsyncNetworkTransaction::StartRead(AuUInt64 offset, const AuMemoryViewWrite &memoryView)
    {
        if (this->bDisallowRecv)
        {
            SysPushErrorIO("Recv isn't allowed");
            return false;
        }

        if (!this->pSocket)
        {
            return false;
        }

        if (this->pSocket->bHasEnded)
        {
            return false;
        }

        if (this->bIsIrredeemable)
        {
            SysPushErrorIO("Transaction was signaled to be destroyed to reset mid synchronizable operation. You can no longer use this stream object");
            return false;
        }

        // Validate arguments before pinning so a rejected call cannot leave the
        // self-reference installed.
        if (!memoryView)
        {
            SysPushErrorArg();
            return {};
        }

        if (!IDontWannaUsePorts())
        {
            return false;
        }

        if (this->readView || this->writeView)
        {
            SysPushErrorIO("IO Operation in progress");
            // IDontWannaUsePorts() just installed our pin and no operation of ours is
            // in flight (a pending one would have kept the old pin and failed that
            // call), so release it instead of leaking a self-reference.
            this->pPin.reset();
            return {};
        }

        this->bLatch = false;
        this->writeView = memoryView;
        this->bHasFailed = false;
        // Clear any error left by a previous operation; CompleteEx classifies
        // completions by this value.
        this->dwOsErrorCode = 0;
        this->dwLastAbstractStat = memoryView.length;
        this->dwLastAbstractOffset = offset;
        this->dwLastBytes = 0;
        this->overlap.Offset = AuBitsToLower(offset);
        this->overlap.OffsetHigh = AuBitsToHigher(offset);
        this->bIsWriting = false;

        WSABUF bufferArray[] {
            { (ULONG)writeView.length, writeView.Begin() }
        };

        this->overlap.hEvent = this->GetAlertable();

        // In/out receive flags (MSG_PARTIAL for TCP, see constructor).
        DWORD todo { this->dwRecvFlags };
        int ret;

        if (!this->bDatagramMode)
        {
            if (!pWSARecv)
            {
                return this->TranslateLastError(false, true);
            }

            ret = pWSARecv(this->GetSocket(), bufferArray, 1, NULL, &todo, &this->overlap, WsaOverlappedCompletionRoutine);
        }
        else
        {
            if (!pWSARecvFrom)
            {
                return this->TranslateLastError(false, true);
            }

            // Member storage: the from-length must stay valid until completion.
            this->iSocketLength = this->pSocket->endpointSize_;
            ret = pWSARecvFrom(this->GetSocket(), bufferArray, 1, NULL, &todo, (sockaddr *)netEndpoint.hint, &this->iSocketLength, &this->overlap, WsaOverlappedCompletionRoutine);
        }

        return this->TranslateLastError(ret != -1);
    }

    // Starts an overlapped send of `memoryView`. Same contract as StartRead.
    // A completion that transfers fewer bytes than requested is flagged as
    // WSAEMSGSIZE in DispatchCb.
    bool NtAsyncNetworkTransaction::StartWrite(AuUInt64 offset, const AuMemoryViewRead &memoryView)
    {
        if (this->bDisallowSend)
        {
            SysPushErrorIO("Send isn't allowed");
            return false;
        }

        // Socket state is checked before pinning (matching StartRead). Otherwise a
        // write after the socket ended would return with the self-pin still set.
        if (!this->pSocket)
        {
            return false;
        }

        if (this->pSocket->bHasEnded)
        {
            return false;
        }

        if (this->bIsIrredeemable)
        {
            SysPushErrorIO("Transaction was signaled to be destroyed to reset mid synchronizable operation. You can no longer use this stream object");
            return false;
        }

        if (!memoryView)
        {
            SysPushErrorArg();
            return {};
        }

        if (!IDontWannaUsePorts())
        {
            return false;
        }

        if (this->readView || this->writeView)
        {
            SysPushErrorIO("IO Operation in progress");
            // See StartRead: nothing of ours is in flight, so drop the pin we just took.
            this->pPin.reset();
            return {};
        }

        this->bLatch = false;
        this->readView = memoryView;
        this->bHasFailed = false;
        this->dwOsErrorCode = 0;
        this->dwLastAbstractStat = memoryView.length;
        this->dwLastAbstractOffset = offset;
        this->dwLastBytes = 0;
        this->bIsWriting = true;
        this->overlap.Offset = AuBitsToLower(offset);
        this->overlap.OffsetHigh = AuBitsToHigher(offset);

        WSABUF bufferArray[] {
            { (ULONG)memoryView.length, (CHAR *)memoryView.Begin() }
        };

        this->overlap.hEvent = this->GetAlertable();

        int ret;

        if (!this->bDatagramMode)
        {
            if (!pWSASend)
            {
                return this->TranslateLastError(false, true);
            }

            ret = pWSASend(this->GetSocket(), bufferArray, 1, NULL, this->dwRecvFlags, &this->overlap, WsaOverlappedCompletionRoutine);
        }
        else
        {
            if (!pWSASendTo)
            {
                return this->TranslateLastError(false, true);
            }

            this->iSocketLength = this->pSocket->endpointSize_;
            ret = pWSASendTo(this->GetSocket(), bufferArray, 1, NULL, this->dwRecvFlags, (sockaddr *)netEndpoint.hint, EndpointToLength(netEndpoint), &this->overlap, WsaOverlappedCompletionRoutine);
        }

        return this->TranslateLastError(ret != -1);
    }

    // Maps the result of a WSA call (or a forced failure, e.g. a missing WSA entry
    // point) onto transaction state:
    //  - success or pending: returns true; the completion routine finishes the job.
    //  - connection-close class errors: treated as end-of-stream; returns true.
    //  - anything else: Reset() and records the error; returns false.
    bool NtAsyncNetworkTransaction::TranslateLastError(bool bReturnValue, bool bForceFail)
    {
        auto er = bForceFail ? ERROR_IO_PREEMPTED : pWSAGetLastError();

        if (bReturnValue)
        {
            return true;
        }
        else if (er == ERROR_IO_PENDING) // same value as WSA_IO_PENDING
        {
            if (AuExchange(this->bForceNextWait, false))
            {
                SysAssertDbg(this->pWaitable, "missing waitable for wait next");
                this->pWaitable->WaitOn(0);
            }

            return true;
        }
        else if ((er == ERROR_HANDLE_EOF) ||
                 (er == WSAECONNABORTED) ||
                 (er == WSAECONNRESET) ||
                 (er == WSAEDISCON))
        {
            this->dwLastAbstractStat = 0;
            this->bHasFailed = true; // to pass completion
            this->dwOsErrorCode = er;// to suppress actual error condition

            if (!AuExchange(this->bSendEOSOnce, true))
            {
                DispatchCb(0);
            }

            if (this->pSocket)
            {
                this->pSocket->SendEnd();
            }

            AuResetMember(this->readView);
            AuResetMember(this->writeView);
            this->pPin.reset();
            return true;
        }
        else
        {
            this->Reset();
            this->dwOsErrorCode = er;
            this->bHasFailed = true;
            SysPushErrorNet("QoA async FIO error: {} on {}", this->dwOsErrorCode, (int)this->GetSocket());

            // Release the self-pin last: it may be the final strong reference to this
            // object, so nothing may touch `this` afterwards.
            this->pPin.reset();
            return false;
        }
    }

    // Completion-group hook: forwards to HasCompleted().
    bool NtAsyncNetworkTransaction::HasCompletedForGCWI()
    {
        return this->HasCompleted();
    }

    // Completion-group hook: detaches the group and its event handle.
    void NtAsyncNetworkTransaction::CleanupForGCWI()
    {
        this->overlap.hEvent = INVALID_HANDLE_VALUE;
        AuResetMember(this->pCompletionGroup_);
    }

    // Attaches to `pCompletionGroup` (at most once). The group's trigger loop source
    // handle becomes the OVERLAPPED event. Returns false if already attached, if
    // `pCompletionGroup` is null, or if the group has no trigger loop source.
    bool NtAsyncNetworkTransaction::TryAttachToCompletionGroup(const AuSPtr &pCompletionGroup)
    {
        if (bool(this->pCompletionGroup_) ||
            !pCompletionGroup)
        {
            return false;
        }

        auto pLoopSource = pCompletionGroup->GetTriggerLoopSource();
        if (!pLoopSource)
        {
            return false;
        }

        this->overlap.hEvent = (HANDLE)AuStaticCast(pLoopSource)->GetHandle();
        pCompletionGroup->AddWorkItem(this->SharedFromThis());
        this->pCompletionGroup_ = pCompletionGroup;
        return true;
    }

    CompletionGroup::ICompletionGroupWorkHandle *NtAsyncNetworkTransaction::ToCompletionGroupHandle()
    {
        return this;
    }

    AuSPtr NtAsyncNetworkTransaction::GetCompletionGroup()
    {
        return this->pCompletionGroup_;
    }

    // Polls for completion without forcing; see CompleteEx.
    bool NtAsyncNetworkTransaction::Complete()
    {
        return CompleteEx(0);
    }

    // Finishes an operation and dispatches the callback.
    // completeRoutine == 0: poll path. Returns true if already completed or failed;
    //   otherwise queries GetOverlappedResult. With no buffers attached and bForce
    //   set, reports end-of-stream (once).
    // completeRoutine != 0: byte count from the completion routine. Connection-close
    //   class errors are reported as end-of-stream (once); otherwise the count is
    //   dispatched.
    bool NtAsyncNetworkTransaction::CompleteEx(AuUInt completeRoutine, bool bForce)
    {
        DWORD read {};

        if (!completeRoutine)
        {
            if (this->dwLastBytes || this->bHasFailed)
            {
                return true;
            }

            if (!this->readView && !this->writeView)
            {
                if (bForce)
                {
                    if (!AuExchange(this->bSendEOSOnce, true))
                    {
                        DispatchCb(0);
                        return true;
                    }
                }

                return false;
            }

            if (::GetOverlappedResult((HANDLE)this->GetSocket(), &this->overlap, &read, false) &&
                (read || bForce))
            {
                DispatchCb(read);
                return true;
            }
        }
        else
        {
            if ((this->dwOsErrorCode == ERROR_HANDLE_EOF) ||
                (this->dwOsErrorCode == WSAECONNABORTED) ||
                (this->dwOsErrorCode == WSAECONNRESET) ||
                (this->dwOsErrorCode == WSAEDISCON))
            {
                if (!AuExchange(this->bSendEOSOnce, true))
                {
                    DispatchCb(0);
                }
            }
            else
            {
                DispatchCb(completeRoutine);
            }

            return true;
        }

        return false;
    }

    // True on failure, excluding the connection-close class codes that are reported
    // as end-of-stream.
    bool NtAsyncNetworkTransaction::HasFailed()
    {
        return this->bHasFailed &&
               this->dwOsErrorCode != ERROR_HANDLE_EOF &&
               this->dwOsErrorCode != WSAECONNABORTED &&
               this->dwOsErrorCode != WSAECONNRESET &&
               this->dwOsErrorCode != WSAEDISCON;
    }

    // NOTE(review): unlike HasFailed(), only ERROR_HANDLE_EOF is excluded here, so
    // connection resets/aborts do report an error code. Confirm this is intended.
    bool NtAsyncNetworkTransaction::HasErrorCode()
    {
        return this->bHasFailed &&
               this->dwOsErrorCode != ERROR_HANDLE_EOF;
    }

    bool NtAsyncNetworkTransaction::HasCompleted()
    {
        return this->bHasFailed ||
               this->dwLastBytes ||
               this->bLatch;
    }

    AuUInt NtAsyncNetworkTransaction::GetOSErrorCode()
    {
        return this->dwOsErrorCode;
    }

    AuUInt32 NtAsyncNetworkTransaction::GetLastPacketLength()
    {
        return this->dwLastBytes;
    }

    void NtAsyncNetworkTransaction::SetCallback(const AuSPtr &sub)
    {
        this->pSub = sub;
    }

    // Non-blocking: reports whether the callback latch has fired. `timeout` is
    // ignored.
    bool NtAsyncNetworkTransaction::Wait(AuUInt32 timeout)
    {
        return this->bLatch;
    }

    // Loop source to wait on, in priority order: the completion group's trigger,
    // the synchronous waitable (MakeSyncable), then the socket worker's event.
    AuSPtr NtAsyncNetworkTransaction::NewLoopSource()
    {
        if (auto pCompletionGroup = this->pCompletionGroup_)
        {
            if (auto pLoopSource = pCompletionGroup->GetTriggerLoopSource())
            {
                return pLoopSource;
            }
        }

        if (auto pWaitable = this->pWaitable)
        {
            return pWaitable;
        }

        if (auto pSocket = this->pSocket)
        {
            return AuStaticCast(pSocket->ToWorkerEx()->ToEvent());
        }

        return {};
    }

    // No-op for sockets.
    void NtAsyncNetworkTransaction::SetBaseOffset(AuUInt64 uBaseOffset)
    {

    }

    // If an operation is still accounted as outstanding (dwLastAbstractStat != 0),
    // cancels it and marks the transaction irredeemable (further Start* calls fail).
    // Otherwise just clears the failure flag.
    void NtAsyncNetworkTransaction::Reset()
    {
        if (this->dwLastAbstractStat)
        {
            this->bIsIrredeemable = true;
            this->bHasFailed = true;

            if (pCancelIoEx)
            {
                pCancelIoEx((HANDLE)this->GetSocket(), &this->overlap);
            }
            else
            {
                ::CancelIo((HANDLE)this->GetSocket());
            }

            this->dwOsErrorCode = ERROR_ABANDONED_WAIT_0;
        }
        else
        {
            this->bHasFailed = false;
        }

        this->dwLastBytes = 0;
        this->dwLastAbstractStat = 0;
    }

    // Creates a dedicated event for synchronous waits and installs it as the
    // OVERLAPPED event.
    void NtAsyncNetworkTransaction::MakeSyncable()
    {
        this->pWaitable = AuLoop::NewLSEventSlow(false, true);
        SysAssert(this->pWaitable);

        this->overlap.hEvent = (HANDLE)AuStaticCast(this->pWaitable)->GetHandle();
    }

    // Makes the next pending operation block on pWaitable (see TranslateLastError).
    void NtAsyncNetworkTransaction::ForceNextWriteWait()
    {
        this->bForceNextWait = true;
    }

    // Installs a strong self-reference (pPin) for the lifetime of the next operation.
    // The completion routine releases it.
    // If a previous pin is still set, queued APCs are drained with an alertable
    // SleepEx(0) so a finished completion routine can clear it. If it is still set
    // afterwards, an operation is genuinely in flight and this fails.
    bool NtAsyncNetworkTransaction::IDontWannaUsePorts()
    {
        if (AuExchange(this->pPin, AuSharedFromThis()))
        {
            while (SleepEx(0, true) == WAIT_IO_COMPLETION)
            {

            }

            if (AuExchange(this->pPin, AuSharedFromThis()))
            {
                SysPushErrorUnavailableError();
                return {};
            }
        }

        return true;
    }

    // Records the transfer result and invokes the subscriber at most once per
    // operation (guarded by bLatch). A write that moved fewer bytes than requested
    // is flagged as WSAEMSGSIZE.
    void NtAsyncNetworkTransaction::DispatchCb(AuUInt32 read)
    {
        if (this->bIsWriting)
        {
            if (read != this->dwLastAbstractStat)
            {
                this->dwOsErrorCode = WSAEMSGSIZE;
                this->bHasFailed = true;
            }
        }

        this->dwLastAbstractStat = 0;
        this->dwLastBytes = read;

        if (AuExchange(this->bLatch, true))
        {
            return;
        }

        if (this->pSub)
        {
            this->RoutineCallCB(read);
        }
    }

#if defined(__AUHAS_COROUTINES_CO_AWAIT) && defined(AU_LANG_CPP_20_)
    // Coroutine variant of the subscriber call; exceptions from the subscriber are
    // caught and logged.
    AuVoidTask NtAsyncNetworkTransaction::RoutineCallCB_(AuUInt32 len)
    {
        if (this->pSub)
        {
            try
            {
                this->pSub->OnAsyncFileOpFinished(this->dwLastAbstractOffset, len);
            }
            catch (...)
            {
                SysPushErrorCatch();
            }
        }

        co_return;
    }
#endif

    // Calls the subscriber, via the coroutine variant when enabled by runtime config.
    // NOTE(review): the direct path does not catch subscriber exceptions, while the
    // coroutine path does.
    void NtAsyncNetworkTransaction::RoutineCallCB(AuUInt32 len)
    {
    #if defined(__AUHAS_COROUTINES_CO_AWAIT) && defined(AU_LANG_CPP_20_)
        if (gRuntimeConfig.ioConfig.bAPCUseCoroutineStack)
        {
            this->RoutineCallCB_(len);
            return;
        }
    #endif

        if (this->pSub)
        {
            this->pSub->OnAsyncFileOpFinished(this->dwLastAbstractOffset, len);
        }
    }

    // Native socket handle, or a default-constructed SOCKET when detached.
    SOCKET NtAsyncNetworkTransaction::GetSocket()
    {
        return this->pSocket ?
            this->pSocket->ToPlatformHandle() :
            SOCKET {};
    }

    // Event handle to place in OVERLAPPED.hEvent: the completion group's trigger if
    // attached, else the socket worker's event, else INVALID_HANDLE_VALUE.
    // NOTE(review): unlike NewLoopSource(), pWaitable is not considered here, so the
    // handle installed by MakeSyncable() is overwritten by StartRead/StartWrite.
    // Confirm this is intended.
    HANDLE NtAsyncNetworkTransaction::GetAlertable()
    {
        if (this->pCompletionGroup_)
        {
            auto pLoopSource = this->pCompletionGroup_->GetTriggerLoopSource();
            if (!pLoopSource)
            {
                return INVALID_HANDLE_VALUE;
            }

            return (HANDLE)AuStaticCast(pLoopSource)->GetHandle();
        }
        else
        {
            if (auto pSocket = this->pSocket)
            {
                return (HANDLE)AuStaticCast(pSocket->ToWorkerEx()->ToEvent())->GetHandle();
            }
            else
            {
                return INVALID_HANDLE_VALUE;
            }
        }
    }
}