// AuroraRuntime/Source/IO/Net/AuNetSocketServerAcceptReadOperation.NT.cpp
//
// 287 lines
// 8.4 KiB
// C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuNetSocketServerAcceptReadOperation.NT.cpp
Date: 2022-8-22
Author: Reece
***/
#include "Networking.hpp"
#include "AuNetSocketServer.hpp"
#include "AuNetSocket.hpp"
#include "AuNetEndpoint.hpp"
#include "AuNetInterface.hpp"
#include "AuNetWorker.hpp"
#include "AuNetError.hpp"
namespace Aurora::IO::Net
{
/**
 * Constructs the NT accept operation for a listening server socket.
 *
 * @param pInterface network interface used to pick workers for accepted sockets
 * @param pParent    the listening server this operation accepts connections for
 *
 * The initializer list is written in the order C++ actually runs it: base
 * classes first, then members. pInterface_ is therefore listed last; listing it
 * first was misleading and triggers -Wreorder.
 */
SocketServerAcceptReadOperation::SocketServerAcceptReadOperation(NetInterface *pInterface,
                                                                 SocketServer *pParent) :
    SocketServerAcceptReadOperationBase(pParent),
    SocketOverlappedOperation(true),
    pInterface_(pInterface)
{
    // Eagerly try to resolve the AcceptEx extension pointer; IsValid() retries
    // lazily if this attempt could not complete.
    this->InitOnce();
}
/**
 * Reports whether this accept operation can be issued.
 *
 * @return true when the base operation is valid and the AcceptEx extension
 *         function pointer has been resolved (resolution is attempted first).
 */
bool SocketServerAcceptReadOperation::IsValid()
{
    // Resolve AcceptEx lazily before judging validity.
    this->InitOnce();

    if (!SocketServerAcceptReadOperationBase::IsValid())
    {
        return false;
    }

    return bool(lpfnAcceptEx);
}
/**
 * Completion handler for a successful overlapped AcceptEx.
 *
 * Finalizes the accepted socket (SO_UPDATE_ACCEPT_CONTEXT), copies its endpoint
 * addresses, and hands it off to its worker (asynchronously for multithreaded
 * servers, inline otherwise). Then re-arms the listener.
 *
 * A failure while finalizing or dispatching one accepted connection only fails
 * that connection. The listener is still re-armed so one bad client cannot stop
 * the server from accepting further connections.
 */
void SocketServerAcceptReadOperation::OnOverlappedComplete()
{
    // Pretick() prepared nothing, so there is no accepted connection to hand off.
    if (!this->nextSocketPtr)
    {
        return;
    }

    // Without setsockopt no accepted socket can ever be finalized. Fail this one
    // and do not re-arm, since every subsequent accept would fail the same way.
    if (!psetsockopt)
    {
        this->nextSocketPtr->SendErrorNoStream({});
        return;
    }

    // An AcceptEx socket does not inherit the listener's properties until
    // SO_UPDATE_ACCEPT_CONTEXT is set on it.
    SOCKET hListenHandle = (SOCKET)this->pParent_->ToPlatformHandle();
    int ret = psetsockopt(this->nextSocket,
                          SOL_SOCKET,
                          SO_UPDATE_ACCEPT_CONTEXT,
                          (char *)&hListenHandle,
                          sizeof(SOCKET));

    if (ret == -1) // SOCKET_ERROR
    {
        // Read the error first, before any other call can overwrite it.
        auto dwError = pWSAGetLastError();
        SysPushErrorNet("Couldn't enable socket after overlapped accept");
        NetError error;
        error.osError = dwError;
        this->nextSocketPtr->SendErrorNoStream(error);
    }
    else
    {
        UpdateNextSocketAddresses();

        if (this->pParent_->bMultiThreaded)
        {
            // Preestablish (deferred by Pretick) and DoMain run on the accepted
            // socket's own worker.
            auto pCallback = AuMakeShared<IIOProcessorWorkUnitFunctional>([socket = this->nextSocketPtr]()
            {
                if (!socket->SendPreestablish())
                {
                    SysPushErrorNet("Couldn't preestablish next socket");
                    socket->SendErrorNoStream({});
                    return;
                }

                socket->DoMain();
            },
            []()
            {
            });

            if (!pCallback)
            {
                SysPushErrorMemory();
                SysPushErrorNet("Memory");
                this->nextSocketPtr->SendErrorNoStream({});
            }
            else if (!this->nextSocketPtr->ToWorker()->ToProcessor()->SubmitIOWorkItem(pCallback))
            {
                SysPushErrorMemory();
                SysPushErrorNet("couldnt schedule read tick on thread");
                this->nextSocketPtr->SendErrorNoStream({});
            }
        }
        else
        {
            this->nextSocketPtr->DoMain();
        }
    }

    // accept next (also after a per-connection failure above)
    this->pParent_->ScheduleAcceptTick(); // We **cannot** readd the current event in the trigger callback
}
/**
 * Failure handler for the overlapped accept; presumably invoked by the
 * overlapped-operation machinery when the pending AcceptEx fails.
 *
 * @param error the network error describing the failure
 *
 * Only logs the error.
 * NOTE(review): unlike OnOverlappedComplete, this neither calls
 * pParent_->ScheduleAcceptTick() nor notifies nextSocketPtr. Confirm that
 * re-arming / cleanup after a failed accept is handled elsewhere.
 */
void SocketServerAcceptReadOperation::OnOverlappedFailure(const NetError &error)
{
    SysPushErrorNet("Accept fail: {}", NetErrorToExtendedString(error));
}
/**
 * No-op in this NT implementation.
 * NOTE(review): in this file, accepting the next connection is driven by
 * pParent_->ScheduleAcceptTick() from OnOverlappedComplete, not from here.
 * Confirm DoNext is not expected to do additional work.
 */
void SocketServerAcceptReadOperation::DoNext()
{
}
/**
 * Issues one overlapped AcceptEx on the listening socket.
 *
 * @return false if the operation is invalid, buffers cannot be sized, or the
 *         next socket cannot be prepared; otherwise the result of
 *         FinishOperation for the AcceptEx call.
 */
bool SocketServerAcceptReadOperation::DoTick()
{
    this->InitOnce();

    if (!this->IsValid())
    {
        return false;
    }

    // AcceptEx requires each address slot to be at least 16 bytes larger than
    // the sockaddr size for the transport in use.
    const auto uAddressSlotSize = this->pParent_->endpointSize_ + 16;

    // Size the buffer *before* Pretick(). Pretick() creates the next socket and,
    // for single-threaded servers, already calls SendPreestablish on it, so
    // failing afterwards would abandon a half-announced socket.
    if (!AuTryResize(this->addresses_, uAddressSlotSize * 2))
    {
        return false;
    }

    if (!this->Pretick())
    {
        return false;
    }

    this->addressLengthA_ = 0;

    // Receive length 0: only the local and remote address slots are filled in.
    auto bRet = lpfnAcceptEx(this->pParent_->ToPlatformHandle(),
                             this->nextSocket,
                             this->addresses_.data(),
                             0,
                             uAddressSlotSize,
                             uAddressSlotSize,
                             &this->addressLengthA_,
                             &this->overlapped);

    return this->FinishOperation(this->pParent_->SharedFromThis(),
                                 AuUnsafeRaiiToShared(this->pParent_->ToWorker()),
                                 bRet);
}
/**
 * Resolves the AcceptEx extension function pointer (lpfnAcceptEx) for the
 * listening socket, if it has not been resolved yet.
 *
 * Returns without effect if the pointer is already set, or if the dynamically
 * loaded WSAGetLastError / WSAIoctl entry points are unavailable. On WSAIoctl
 * failure an IO error is pushed and lpfnAcceptEx stays unresolved, so a later
 * call will retry.
 */
void SocketServerAcceptReadOperation::InitOnce()
{
    if (lpfnAcceptEx)
    {
        return;
    }

    // pWSAGetLastError is required too: once AcceptEx is resolved, the
    // completion path calls it without re-checking.
    if (!pWSAGetLastError)
    {
        return;
    }

    if (!pWSAIoctl)
    {
        return;
    }

    GUID GuidAcceptEx = WSAID_ACCEPTEX;
    DWORD dwBytes {};

    // Query synchronously. With both lpOverlapped and lpCompletionRoutine NULL,
    // WSAIoctl treats the socket as non-overlapped for this call, so the pointer
    // is written before it returns. A stack OVERLAPPED must not be used here:
    // if the request pended, or the listener were bound to a completion port,
    // the completion could reference this frame after it has returned.
    if (pWSAIoctl(this->pParent_->ToPlatformHandle(),
                  SIO_GET_EXTENSION_FUNCTION_POINTER,
                  &GuidAcceptEx,
                  sizeof(GuidAcceptEx),
                  &lpfnAcceptEx,
                  sizeof(lpfnAcceptEx),
                  &dwBytes,
                  nullptr,
                  nullptr) != 0)
    {
        SysPushErrorIO();
        return;
    }
}
/**
 * Prepares the socket that the next AcceptEx will accept into.
 *
 * Obtains a driver from the server's factory and picks a worker (a scheduled
 * one for multithreaded servers, the server's own otherwise). It then creates
 * an overlapped OS socket matching the listener's domain/transport and wraps it
 * in nextSocketPtr. Single-threaded servers send preestablish immediately;
 * multithreaded servers defer it to the accept-completion work item.
 *
 * @return true when nextSocket / nextSocketPtr are ready for AcceptEx.
 */
bool SocketServerAcceptReadOperation::Pretick()
{
    auto &localAddress = this->pParent_->GetLocalEndpoint();

    if (!pWSASocketW)
    {
        return false;
    }

    // Acquire everything that can fail *before* creating the OS socket. Nothing
    // in this function closes the raw handle, so later failures would leak it.
    auto pFactory = this->pParent_->GetFactory();
    if (!bool(pFactory))
    {
        SysPushErrorNet("Socket missing factory");
        return false;
    }

    auto pNewDriver = pFactory->NewSocketDriver();
    if (!bool(pNewDriver))
    {
        SysPushErrorNet("Socket factory failed to provide a new instance ahead of server acceptance");
        return false;
    }

    NetWorker *pWorker {};
    if (this->pParent_->bMultiThreaded)
    {
        pWorker = this->pInterface_->TryScheduleEx().get();
    }
    else
    {
        pWorker = this->pParent_->ToWorkerEx();
    }

    // The completion path dereferences ToWorker(); refuse to build a socket without one.
    if (!pWorker)
    {
        SysPushErrorNet("No worker available for the next socket");
        return false;
    }

    nextSocket = pWSASocketW(
        IPToDomain(localAddress),
        TransportToPlatformType(localAddress),
        IPPROTO_IP,
        nullptr,
        0,
        WSA_FLAG_OVERLAPPED
    );

    if (nextSocket == INVALID_SOCKET)
    {
        SysPushErrorNet("No socket");
        return false;
    }

    nextSocketPtr = AuMakeShared<Socket>(this->pInterface_,
                                         pWorker,
                                         pNewDriver,
                                         (AuUInt)nextSocket,
                                         AuSPtr<ISocketServer>(this->pParent_->SharedFromThis(), this->pParent_),
                                         this->pParent_);
    if (!bool(nextSocketPtr))
    {
        // TODO: schedule retry
        // NOTE(review): nextSocket is not yet owned by a Socket here and leaks;
        // close it once a closesocket entry point is available in this file.
        SysPushErrorNet("Couldn't allocate a socket");
        return false;
    }

    // Multithreaded servers defer SendPreestablish to the completion work item
    // to minimize RPCs.
    if (!this->pParent_->bMultiThreaded &&
        !nextSocketPtr->SendPreestablish())
    {
        SysPushErrorNet("Couldn't preestablish next socket");
        return false;
    }

    return true;
}
/**
 * Copies the local and remote addresses returned by AcceptEx into the accepted
 * socket's endpoints, then deoptimizes both and stamps them with the
 * listener's transport protocol.
 *
 * Buffer layout assumed (DoTick passes receive length 0): a local slot of
 * endpointSize_ + 16 bytes followed by a remote slot of the same size.
 * NOTE(review): Microsoft documents these slots as an internal format meant to
 * be parsed with GetAcceptExSockaddrs. Reading the sockaddr from the start of
 * each slot is an assumption to confirm.
 */
void SocketServerAcceptReadOperation::UpdateNextSocketAddresses()
{
    SysAssert(this->nextSocketPtr);

    auto &pAccepted = this->nextSocketPtr;
    const auto uEndpointSize = this->pParent_->endpointSize_;
    const auto uRemoteSlotOffset = uEndpointSize + 16;

    pAccepted->endpointSize_ = uEndpointSize;

    AuMemcpy(pAccepted->localEndpoint_.hint, &this->addresses_[0], uEndpointSize);
    AuMemcpy(pAccepted->remoteEndpoint_.hint, &this->addresses_[uRemoteSlotOffset], uEndpointSize);

    DeoptimizeEndpoint(pAccepted->remoteEndpoint_);
    DeoptimizeEndpoint(pAccepted->localEndpoint_);

    const auto eTransport = this->pParent_->GetLocalEndpoint().transportProtocol;
    pAccepted->localEndpoint_.transportProtocol = eTransport;
    pAccepted->remoteEndpoint_.transportProtocol = eTransport;
}
}