AuroraRuntime/Source/IO/Net/AuNetStream.NT.cpp
2024-01-06 04:32:54 +00:00

587 lines
16 KiB
C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuNetStream.NT.cpp
Date: 2022-8-17
Author: Reece
***/
#include "Networking.hpp"
#include "AuNetStream.NT.hpp"
#include "AuNetSocket.hpp"
#include "AuNetWorker.hpp"
#include "AuNetEndpoint.hpp"
#include <Source/IO/Loop/LSEvent.hpp>
namespace Aurora::IO::Net
{
NtAsyncNetworkTransaction::NtAsyncNetworkTransaction(SocketBase *pSocket) :
    pSocket(pSocket)
{
    // Stream sockets (TCP) issue their WSARecv/WSASend calls with MSG_PARTIAL in dwRecvFlags;
    // every other transport keeps the default flags.
    //
    // Other members check pSocket for null before using it; do the same here instead of
    // crashing during construction.
    if (!this->pSocket)
    {
        return;
    }

    if (this->pSocket->GetRemoteEndpoint().transportProtocol == ETransportProtocol::eProtocolTCP)
    {
        this->dwRecvFlags = MSG_PARTIAL;
    }
}
NtAsyncNetworkTransaction::~NtAsyncNetworkTransaction()
{
    // Cancel any still-outstanding operation and clear per-operation state.
    this->Reset();
}
// Completion routine passed to WSARecv/WSARecvFrom/WSASend/WSASendTo by StartRead/StartWrite.
// It is run as an APC when the issuing thread waits alertably (see the SleepEx drain in
// IDontWannaUsePorts()). It recovers the transaction from its embedded `overlap` member, takes
// back the self-pin held in pPin, records any error, signals overlap.hEvent, releases the buffer
// hold and completes the transaction with the transferred byte count.
static void __stdcall WsaOverlappedCompletionRoutine(DWORD dwErrorCode,
DWORD cbTransferred,
LPWSAOVERLAPPED lpOverlapped,
DWORD dwFlags
)
{
// Recover the owning transaction from the address of its embedded OVERLAPPED.
auto transaction = reinterpret_cast<NtAsyncNetworkTransaction *>(reinterpret_cast<AuUInt8 *>(lpOverlapped) - offsetof(NtAsyncNetworkTransaction, overlap));
// Take over the pin placed by IDontWannaUsePorts(); it keeps the transaction alive until this routine returns.
// NOTE(review): `hold` is dereferenced without a null check - assumes pPin is always set while an operation is pending.
auto hold = AuExchange(transaction->pPin, {});
if (dwErrorCode)
{
hold->bHasFailed = true;
hold->dwOsErrorCode = dwErrorCode;
}
else if (!hold->dwLastAbstractStat)
{
// Successful completion, but no operation is marked outstanding (dwLastAbstractStat was cleared,
// e.g. by Reset()): nothing is reported and pMemoryHold is left as-is.
return;
}
SetEvent(lpOverlapped->hEvent);
// Release the buffer hold before completing: with pMemoryHold cleared, a zero-byte error-free
// completion takes CompleteEx's forced path and is reported once as end-of-stream.
// pHold keeps the buffer alive until this routine returns.
auto pHold = AuExchange(hold->pMemoryHold, {});
hold->CompleteEx(cbTransferred, true);
}
bool NtAsyncNetworkTransaction::StartRead(AuUInt64 offset, const AuSPtr<AuMemoryViewWrite> &memoryView)
{
    // Begins an overlapped receive into memoryView: WSARecv, or WSARecvFrom into netEndpoint in
    // datagram mode. Completion is delivered through WsaOverlappedCompletionRoutine and the event
    // stored in overlap.hEvent (see GetAlertable()).
    //   offset     - recorded as dwLastAbstractOffset and reported back to the subscriber.
    //   memoryView - destination buffer; held in pMemoryHold until the operation completes.
    // Returns false if the operation could not be started.

    if (this->bDisallowRecv)
    {
        SysPushErrorIO("Recv isn't allowed");
        return false;
    }

    if (!this->pSocket)
    {
        return false;
    }

    if (this->pSocket->bHasEnded)
    {
        return false;
    }

    if (this->bIsIrredeemable)
    {
        SysPushErrorIO("Transaction was signaled to be destroyed to reset mid synchronizable operation. You can no longer use this stream object");
        return false;
    }

    // Validate arguments before IDontWannaUsePorts(). That call pins `this` in pPin, and any
    // early return after it would otherwise leave the self-reference behind.
    if (!memoryView)
    {
        SysPushErrorArg();
        return false;
    }

    if (!IDontWannaUsePorts())
    {
        return false;
    }

    // Checked only after IDontWannaUsePorts() has drained queued completion routines, which
    // release pMemoryHold.
    if (this->pMemoryHold)
    {
        SysPushErrorIO("IO Operation in progress");
        // Nothing was started, so drop the pin we just took. Return immediately afterwards,
        // because the pin may have been the last reference to `this`.
        this->pPin.reset();
        return false;
    }

    this->bLatch = false;
    this->pMemoryHold = memoryView;
    this->bHasFailed = false;
    this->dwLastAbstractStat = memoryView->length; // non-zero marks an operation as outstanding
    this->dwLastAbstractOffset = offset;
    this->dwLastBytes = 0;
    this->overlap.Offset = AuBitsToLower<AuUInt64>(offset);
    this->overlap.OffsetHigh = AuBitsToHigher<AuUInt64>(offset);
    this->bIsWriting = false;

    // NOTE(review): WSABUF::len is a ULONG; views longer than ULONG_MAX are truncated here.
    WSABUF bufferArray[] {
        {
            (ULONG)memoryView->length,
            memoryView->Begin<CHAR>()
        }
    };

    this->overlap.hEvent = this->GetAlertable();

    // In/out flags for WSARecv(From); seeded with MSG_PARTIAL for TCP by the constructor.
    DWORD dwFlags { this->dwRecvFlags };
    int ret;

    if (!this->bDatagramMode)
    {
        if (!pWSARecv)
        {
            return this->TranslateLastError(false, true);
        }

        ret = pWSARecv(this->GetSocket(),
                       bufferArray,
                       1,
                       NULL,
                       &dwFlags,
                       &this->overlap,
                       WsaOverlappedCompletionRoutine);
    }
    else
    {
        if (!pWSARecvFrom)
        {
            return this->TranslateLastError(false, true);
        }

        this->iSocketLength = this->pSocket->endpointSize_;
        ret = pWSARecvFrom(this->GetSocket(),
                           bufferArray,
                           1,
                           NULL,
                           &dwFlags,
                           (sockaddr *)netEndpoint.hint,
                           &this->iSocketLength,
                           &this->overlap,
                           WsaOverlappedCompletionRoutine);
    }

    // -1 == SOCKET_ERROR; TranslateLastError treats WSA_IO_PENDING as success.
    return this->TranslateLastError(ret != -1);
}
bool NtAsyncNetworkTransaction::StartWrite(AuUInt64 offset, const AuSPtr<AuMemoryViewRead> &memoryView)
{
    // Begins an overlapped send of memoryView: WSASend, or WSASendTo to netEndpoint in datagram
    // mode. Completion is delivered through WsaOverlappedCompletionRoutine and the event stored
    // in overlap.hEvent (see GetAlertable()).
    //   offset     - recorded as dwLastAbstractOffset and reported back to the subscriber.
    //   memoryView - source buffer; held in pMemoryHold until the operation completes.
    // Returns false if the operation could not be started.
    //
    // Checks that need no pin run first, in the same order as StartRead. That way a rejected
    // call never leaves the self-reference taken by IDontWannaUsePorts() behind in pPin.

    if (this->bDisallowSend)
    {
        SysPushErrorIO("Send isn't allowed");
        return false;
    }

    if (!this->pSocket)
    {
        return false;
    }

    if (this->pSocket->bHasEnded)
    {
        return false;
    }

    if (this->bIsIrredeemable)
    {
        SysPushErrorIO("Transaction was signaled to be destroyed to reset mid synchronizable operation. You can no longer use this stream object");
        return false;
    }

    if (!memoryView)
    {
        SysPushErrorArg();
        return false;
    }

    if (!IDontWannaUsePorts())
    {
        return false;
    }

    // Checked only after IDontWannaUsePorts() has drained queued completion routines, which
    // release pMemoryHold.
    if (this->pMemoryHold)
    {
        SysPushErrorIO("IO Operation in progress");
        // Nothing was started, so drop the pin we just took. Return immediately afterwards,
        // because the pin may have been the last reference to `this`.
        this->pPin.reset();
        return false;
    }

    this->bLatch = false;
    this->pMemoryHold = memoryView;
    this->bHasFailed = false;
    this->dwLastAbstractStat = memoryView->length; // DispatchCb compares the sent byte count against this
    this->dwLastAbstractOffset = offset;
    this->dwLastBytes = 0;
    this->bIsWriting = true;
    this->overlap.Offset = AuBitsToLower<AuUInt64>(offset);
    this->overlap.OffsetHigh = AuBitsToHigher<AuUInt64>(offset);

    // NOTE(review): WSABUF::len is a ULONG; views longer than ULONG_MAX are truncated here.
    WSABUF bufferArray[] {
        {
            (ULONG)memoryView->length,
            (CHAR *)memoryView->Begin<CHAR>()
        }
    };

    this->overlap.hEvent = this->GetAlertable();

    int ret;

    if (!this->bDatagramMode)
    {
        if (!pWSASend)
        {
            return this->TranslateLastError(false, true);
        }

        // NOTE(review): dwRecvFlags (MSG_PARTIAL for TCP) is also used as the send flags; WSASend
        // documents MSG_PARTIAL for message-oriented transports only - confirm this is intended.
        ret = pWSASend(this->GetSocket(),
                       bufferArray,
                       1,
                       NULL,
                       this->dwRecvFlags,
                       &this->overlap,
                       WsaOverlappedCompletionRoutine);
    }
    else
    {
        if (!pWSASendTo)
        {
            return this->TranslateLastError(false, true);
        }

        this->iSocketLength = this->pSocket->endpointSize_;
        ret = pWSASendTo(this->GetSocket(),
                         bufferArray,
                         1,
                         NULL,
                         this->dwRecvFlags,
                         (sockaddr *)netEndpoint.hint,
                         EndpointToLength(netEndpoint),
                         &this->overlap,
                         WsaOverlappedCompletionRoutine);
    }

    // -1 == SOCKET_ERROR; TranslateLastError treats WSA_IO_PENDING as success.
    return this->TranslateLastError(ret != -1);
}
bool NtAsyncNetworkTransaction::TranslateLastError(bool bReturnValue, bool bForceFail)
{
    // Maps the result of a WSA overlapped call onto the transaction's state.
    //   bReturnValue - true when the WSA call itself succeeded (did not return SOCKET_ERROR).
    //   bForceFail   - treat as a hard failure without consulting WSAGetLastError (used when the
    //                  WSA entry point is unavailable); recorded as ERROR_IO_PREEMPTED.
    // Returns true if the call succeeded, is pending, or ended the stream gracefully
    // (EOF / aborted / reset / disconnected); false on any other error.

    if (bReturnValue)
    {
        return true;
    }

    auto er = bForceFail ?
        ERROR_IO_PREEMPTED :
        pWSAGetLastError();

    if (er == ERROR_IO_PENDING)
    {
        // The completion routine will finish the operation. If ForceNextWriteWait() was called,
        // wait on pWaitable once.
        // NOTE(review): confirm whether WaitOn(0) polls or blocks.
        if (AuExchange(this->bForceNextWait, false))
        {
            SysAssertDbg(this->pWaitable, "missing waitable for wait next");
            this->pWaitable->WaitOn(0);
        }
        return true;
    }

    if ((er == ERROR_HANDLE_EOF) ||
        (er == WSAECONNABORTED) ||
        (er == WSAECONNRESET) ||
        (er == WSAEDISCON))
    {
        // Graceful termination: report end-of-stream once and end the socket.
        this->dwLastAbstractStat = 0;
        this->bHasFailed = true; // to pass completion
        this->dwOsErrorCode = er;// to suppress actual error condition

        if (!AuExchange(this->bSendEOSOnce, true))
        {
            DispatchCb(0);
        }

        if (this->pSocket)
        {
            this->pSocket->SendEnd();
        }

        this->pMemoryHold.reset();
        // Last access to `this`: the pin may be the final reference.
        this->pPin.reset();
        return true;
    }

    // Hard failure. Take the pin into a local so that `this` stays alive until the state below
    // has been written, even if the pin was the last reference.
    auto pPinHold = AuExchange(this->pPin, {});

    this->Reset();
    this->dwOsErrorCode = er;
    this->bHasFailed = true;
    SysPushErrorNet("QoA async FIO error: {} on {}", this->dwOsErrorCode, (int)this->GetSocket());
    return false;
}
bool NtAsyncNetworkTransaction::HasCompletedForGCWI()
{
    // Completion-group polling hook; uses the same criterion as HasCompleted().
    const bool bDone = this->HasCompleted();
    return bDone;
}
void NtAsyncNetworkTransaction::CleanupForGCWI()
{
    // Completion-group teardown hook: unbind the trigger event from the OVERLAPPED,
    // then detach from the group.
    auto &overlapped = this->overlap;
    overlapped.hEvent = INVALID_HANDLE_VALUE;

    AuResetMember(this->pCompletionGroup_);
}
bool NtAsyncNetworkTransaction::TryAttachToCompletionGroup(const AuSPtr<CompletionGroup::ICompletionGroup> &pCompletionGroup)
{
if (bool(this->pCompletionGroup_) ||
!pCompletionGroup)
{
return false;
}
auto pLoopSource = pCompletionGroup->GetTriggerLoopSource();
if (!pLoopSource)
{
return false;
}
this->overlap.hEvent = (HANDLE)AuStaticCast<Loop::LSEvent>(pLoopSource)->GetHandle();
pCompletionGroup->AddWorkItem(this->SharedFromThis());
this->pCompletionGroup_ = pCompletionGroup;
return true;
}
CompletionGroup::ICompletionGroupWorkHandle *NtAsyncNetworkTransaction::ToCompletionGroupHandle()
{
    // The transaction is itself the completion-group work handle.
    return static_cast<CompletionGroup::ICompletionGroupWorkHandle *>(this);
}
AuSPtr<CompletionGroup::ICompletionGroup> NtAsyncNetworkTransaction::GetCompletionGroup()
{
    // Group attached via TryAttachToCompletionGroup, or null when detached.
    auto pGroup = this->pCompletionGroup_;
    return pGroup;
}
bool NtAsyncNetworkTransaction::Complete()
{
return CompleteEx(0);
}
// Completes the current operation and dispatches its result.
//   completeRoutine - a non-zero value is the byte count reported by WsaOverlappedCompletionRoutine;
//                     0 takes the polling path (this also happens when the routine reports 0 bytes).
//   bForce          - on the polling path, allow a zero-byte result / end-of-stream to be dispatched.
// Returns true once the operation has completed (now or earlier). Returns false if it is still
// pending or there was nothing to report.
bool NtAsyncNetworkTransaction::CompleteEx(AuUInt completeRoutine, bool bForce)
{
DWORD read {};
if (!completeRoutine)
{
// Already finished (bytes recorded or failure flagged): nothing more to do.
if (this->dwLastBytes || this->bHasFailed)
{
return true;
}
// No buffer held: no operation is outstanding, or the completion routine has already released
// the buffer. When forced, report end-of-stream exactly once.
if (!this->pMemoryHold)
{
if (bForce)
{
if (!AuExchange(this->bSendEOSOnce, true))
{
DispatchCb(0);
return true;
}
}
return false;
}
// Non-blocking (bWait = false) query of the overlapped result. A zero-byte result is only
// dispatched when forced.
// NOTE(review): WSAGetOverlappedResult is the documented call for sockets; this relies on the
// SOCKET being usable as a kernel HANDLE.
if (::GetOverlappedResult((HANDLE)this->GetSocket(), &this->overlap, &read, false) && (read || bForce))
{
DispatchCb(read);
return true;
}
}
else
{
// Called with a non-zero byte count. Graceful-termination error codes are reported once as
// a zero-byte end-of-stream; otherwise the byte count is dispatched.
if ((this->dwOsErrorCode == ERROR_HANDLE_EOF) ||
(this->dwOsErrorCode == WSAECONNABORTED) ||
(this->dwOsErrorCode == WSAECONNRESET) ||
(this->dwOsErrorCode == WSAEDISCON))
{
if (!AuExchange(this->bSendEOSOnce, true))
{
DispatchCb(0);
}
}
else
{
DispatchCb(completeRoutine);
}
return true;
}
return false;
}
bool NtAsyncNetworkTransaction::HasFailed()
{
    // A failure counts as real unless it was a graceful end of stream
    // (EOF, connection aborted/reset, or disconnect).
    if (!this->bHasFailed)
    {
        return false;
    }

    switch (this->dwOsErrorCode)
    {
        case ERROR_HANDLE_EOF:
        case WSAECONNABORTED:
        case WSAECONNRESET:
        case WSAEDISCON:
            return false;
        default:
            return true;
    }
}
bool NtAsyncNetworkTransaction::HasErrorCode()
{
    // An OS error code is reported for any failure except plain end-of-file.
    if (!this->bHasFailed)
    {
        return false;
    }

    return this->dwOsErrorCode != ERROR_HANDLE_EOF;
}
bool NtAsyncNetworkTransaction::HasCompleted()
{
    // Completed when the operation failed, or when it transferred a non-zero byte count.
    if (this->bHasFailed)
    {
        return true;
    }

    return this->dwLastBytes != 0;
}
AuUInt NtAsyncNetworkTransaction::GetOSErrorCode()
{
    // Last recorded OS / Winsock error code.
    const AuUInt uCode = this->dwOsErrorCode;
    return uCode;
}
AuUInt32 NtAsyncNetworkTransaction::GetLastPacketLength()
{
    // Byte count of the most recently dispatched completion.
    const AuUInt32 uBytes = this->dwLastBytes;
    return uBytes;
}
void NtAsyncNetworkTransaction::SetCallback(const AuSPtr<IAsyncFinishedSubscriber> &sub)
{
    // Replaces the subscriber notified by DispatchCb. A null subscriber disables notifications.
    this->pSub = sub;
}
bool NtAsyncNetworkTransaction::Wait(AuUInt32 timeout)
{
    // Non-blocking: reports whether the completion latch (set by DispatchCb, re-armed by
    // StartRead/StartWrite) has fired.
    // NOTE(review): `timeout` is ignored; callers expecting a blocking wait get an immediate answer.
    (void)timeout;

    const bool bSignaled = this->bLatch;
    return bSignaled;
}
AuSPtr<AuLoop::ILoopSource> NtAsyncNetworkTransaction::NewLoopSource()
{
    // Returns a loop source that can be waited on for this transaction:
    //  - the attached completion group's trigger, if any (GetAlertable() binds overlap.hEvent to it);
    //  - otherwise the event created by MakeSyncable(), if any;
    //  - otherwise null.
    // A completion group is optional, so check it before dereferencing; the original dereferenced
    // pCompletionGroup_ unconditionally.
    if (auto pCompletionGroup = this->pCompletionGroup_)
    {
        if (auto pLoopSource = pCompletionGroup->GetTriggerLoopSource())
        {
            return pLoopSource;
        }
    }

    if (auto pWaitable = this->pWaitable)
    {
        return pWaitable;
    }

    return {};
}
void NtAsyncNetworkTransaction::SetBaseOffset(AuUInt64 uBaseOffset)
{
    // No-op: this transaction does not use a base offset.
    (void)uBaseOffset;
}
// Resets per-operation state.
// If an operation is still outstanding (dwLastAbstractStat != 0), it is cancelled and flagged as
// failed with ERROR_ABANDONED_WAIT_0. The transaction is also marked irredeemable, so
// StartRead/StartWrite refuse any further use. Otherwise the failure flag is simply cleared.
// pMemoryHold and pPin are not touched here. The completion routine releases pPin, but releases
// pMemoryHold only when it reports an error: on success it returns early once dwLastAbstractStat
// has been cleared.
void NtAsyncNetworkTransaction::Reset()
{
if (this->dwLastAbstractStat)
{
this->bIsIrredeemable = true;
this->bHasFailed = true;
if (pCancelIoEx)
{
// Cancel only the operation that uses this transaction's OVERLAPPED.
pCancelIoEx((HANDLE)this->GetSocket(), &this->overlap);
}
else
{
// Fallback: CancelIo is not per-OVERLAPPED; it cancels the calling thread's pending I/O on the handle.
::CancelIo((HANDLE)this->GetSocket());
}
this->dwOsErrorCode = ERROR_ABANDONED_WAIT_0;
}
else
{
this->bHasFailed = false;
}
this->dwLastBytes = 0;
this->dwLastAbstractStat = 0;
}
void NtAsyncNetworkTransaction::MakeSyncable()
{
this->pWaitable = AuLoop::NewLSEvent(false, true);
SysAssert(this->pWaitable);
this->overlap.hEvent = (HANDLE)AuStaticCast<Loop::LSEvent>(this->pWaitable)->GetHandle();
}
void NtAsyncNetworkTransaction::ForceNextWriteWait()
{
    // One-shot request consumed by TranslateLastError: the next operation that goes
    // pending also waits on pWaitable.
    this->bForceNextWait = true;
}
// Pins `this` (pPin = shared self) so the transaction outlives the completion routine of the
// operation about to be issued; the routine takes the pin back when it runs.
// If a pin is already present, an earlier operation's routine has not run yet. In that case,
// drain queued completion routines on this thread with an alertable SleepEx(0) and retry once.
// Returns true when the pin is held and a new operation may be issued. Returns false if the
// previous operation still has not completed.
// NOTE(review): on failure pPin is still (re)set to `this`; presumably the still-pending
// completion routine releases it.
bool NtAsyncNetworkTransaction::IDontWannaUsePorts()
{
if (AuExchange(this->pPin, AuSharedFromThis()))
{
// Run every queued completion routine (APC) for this thread without blocking.
while (SleepEx(0, true) == WAIT_IO_COMPLETION)
{
}
// If no routine took the pin, the previous operation is still outstanding.
if (AuExchange(this->pPin, AuSharedFromThis()))
{
SysPushErrorUnavailableError();
return {};
}
}
return true;
}
void NtAsyncNetworkTransaction::DispatchCb(AuUInt32 read)
{
    // Records the result of the current operation and notifies the subscriber at most once per
    // operation (bLatch is re-armed by StartRead/StartWrite).
    // For writes, a byte count that differs from the requested length is flagged as WSAEMSGSIZE.
    const bool bShortWrite = this->bIsWriting && read != this->dwLastAbstractStat;
    if (bShortWrite)
    {
        this->dwOsErrorCode = WSAEMSGSIZE;
        this->bHasFailed = true;
    }

    this->dwLastAbstractStat = 0;
    this->dwLastBytes = read;

    const bool bAlreadyLatched = AuExchange(this->bLatch, true);
    if (!bAlreadyLatched && this->pSub)
    {
        this->pSub->OnAsyncFileOpFinished(this->dwLastAbstractOffset, read);
    }
}
SOCKET NtAsyncNetworkTransaction::GetSocket()
{
    // Platform handle of the owning socket, or a zero SOCKET when there is none.
    if (!this->pSocket)
    {
        return SOCKET {};
    }

    return this->pSocket->ToPlatformHandle();
}
HANDLE NtAsyncNetworkTransaction::GetAlertable()
{
    // Event handle to store in overlap.hEvent:
    //  - the completion group's trigger event when a group is attached;
    //  - otherwise the owning socket's worker event;
    //  - INVALID_HANDLE_VALUE when the relevant source is unavailable.
    if (this->pCompletionGroup_)
    {
        auto pTrigger = this->pCompletionGroup_->GetTriggerLoopSource();
        return pTrigger ?
            (HANDLE)AuStaticCast<Loop::LSEvent>(pTrigger)->GetHandle() :
            INVALID_HANDLE_VALUE;
    }

    auto pOwner = this->pSocket;
    if (!pOwner)
    {
        return INVALID_HANDLE_VALUE;
    }

    auto pWorkerEvent = AuStaticCast<AuLoop::LSEvent>(pOwner->ToWorkerEx()->ToEvent());
    return (HANDLE)pWorkerEvent->GetHandle();
}
}