AuroraRPC/Source/AuRPCClientChannel.cpp

347 lines
8.1 KiB
C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuRPCClientChannel.cpp
Date: 2022-6-29
Author: Reece
***/
#include <AuroraRuntime.hpp>
#include "AuRPC.hpp"
#include "AuRPCClientChannel.hpp"
#include "AuRPCRequest.hpp"
#include "AuRPCPipePacket.hpp"
AuRPCClientChannel::AuRPCClientChannel(AuSPtr<AuRPC> parent) :
parent_(parent),
pipe_(this)
{ }
bool AuRPCClientChannel::OnConnect()
{
auto request = AuMakeShared<AuRPCRequest>();
request->WriteHeaderConnect();
AuRPCPipePacket packet;
packet.clientRequest = request;
packet.protPacket = true;
packet.clientChannel = SharedFromThis();
this->pipe_.SendPacket(packet);
return true;
}
// Tears the channel down from our side: the pipe is deinitialised first,
// then Finalize() aborts pending requests and notifies the callbacks.
void AuRPCClientChannel::Disconnect()
{
    auto &transport = this->pipe_;
    transport.Deinit();

    this->Finalize();
}
// Queues a request on the pipe, tagged with this channel.
// The interface pointer is downcast with AuStaticCast (presumably an unchecked
// static pointer cast — callers must pass an AuRPCRequest); a null request is
// rejected with SysPushErrorArg.
void AuRPCClientChannel::SendRequest(AuSPtr<AuIRPCRequest> response2)
{
    auto pRequest = AuStaticCast<AuRPCRequest>(response2);
    if (!pRequest)
    {
        SysPushErrorArg();
        return;
    }

    RpcLogDebug("Client sending request of bytes: {}", pRequest->GetData().length);

    AuRPCPipePacket requestPacket;
    requestPacket.clientRequest = pRequest;
    requestPacket.clientChannel = SharedFromThis();

    this->pipe_.SendPacket(requestPacket);
}
// Steady-clock timestamp (ns) captured by ProcessConnectionOK when the server
// accepted the connection; before that it is whatever uConnectTime_ was
// initialised to.
AuUInt64 AuRPCClientChannel::GetConnectTimeNS()
{
    const AuUInt64 uTimestamp = this->uConnectTime_;
    return uTimestamp;
}
bool AuRPCClientChannel::SendMessage(const AuMemoryViewRead &view, bool bLarge)
{
if (!bLarge &&
view.length < this->parent_->GetLargePacketLength())
{
auto pMessage = AuMakeShared<AuRPCRequest>();
if (!pMessage)
{
return false;
}
pMessage->dataType = kGeneralServerMessage;
pMessage->SetData(view);
this->SendRequest(pMessage);
return true;
}
else
{
auto pSharedMemory = AuIPC::NewSharedMemory(view.length);
if (!pSharedMemory)
{
return false;
}
auto dest = pSharedMemory->GetMemory();
AuMemcpy(dest.ptr, view.ptr, view.length);
auto pMessage = AuMakeShared<AuRPCRequest>();
if (!pMessage)
{
return false;
}
pMessage->dataType = kGeneralMassiveMessage;
AuByteBuffer data;
data.Write(pSharedMemory->ExportToString());
if (!data)
{
return false;
}
pMessage->SetData(data);
this->SendRequest(pMessage);
return true;
}
}
// Returns the owning AuRPC context this channel was constructed with.
AuSPtr<AuRPC> AuRPCClientChannel::ToContext()
{
    const auto &pParent = this->parent_;
    return pParent;
}
// Pipe-level disconnect notification.
// A drop that happens while we are switching to an alternate endpoint
// (bConnectingAlternate_ set by a kResponseMulticonnect frame) before any
// kResponseConnectOK was seen is expected, so it does not finalize the channel.
// `error` is currently unused.
void AuRPCClientChannel::OnDisconnect(bool error)
{
    RpcLogDebug("AuRPCClientChannel::OnDisconnect");

    const bool bSwitchingEndpoint = !this->bConnected_ && this->bConnectingAlternate_;
    if (bSwitchingEndpoint)
    {
        return;
    }

    Finalize();
}
// Parses and dispatches every complete frame currently buffered in `view`.
//
// Frame layout, as implied by the parsing below:
//   [AuUInt32 frameLength][AuUInt8 packetType][payload...]
// frameLength covers the whole frame, including its own 4-byte prefix.
//
// Returns:
//   true  - ran out of complete frames (any partial frame is left unread),
//           a response object could not be allocated (frame left unread),
//           the server redirected us to an alternate endpoint, or a large
//           message hit a fatal IO error (channel already torn down);
//   false - a malformed frame header was seen and the channel was disconnected.
bool AuRPCClientChannel::OnDataAvailable(AuByteBuffer &view)
{
    RpcLogDebug("AuRPCClientChannel::OnDataAvailable");

    // Smallest parseable frame: the length prefix plus the packet-type byte.
    constexpr AuUInt32 kFrameHeaderLength = sizeof(AuUInt32) + sizeof(AuUInt8);

    while (true)
    {
        auto bytesAvailable = view.RemainingBytes();
        if (bytesAvailable < sizeof(AuUInt32))
        {
            // Not even a complete length prefix yet.
            return true;
        }

        auto oldRead = view.readPtr;
        auto frameLength = view.Read<AuUInt32>();
        if (frameLength < kFrameHeaderLength)
        {
            // Zero-length or header-truncated frames cannot be parsed and would
            // desynchronise every frame that follows.
            Disconnect();
            return false;
        }

        if (frameLength > bytesAvailable)
        {
            // Partial frame: rewind to the prefix and wait for more data.
            view.readPtr = oldRead;
            return true;
        }

        auto endPtr = oldRead + frameLength;
        auto oldLength = view.length;
        auto oldWriteHead = view.writePtr;

        // Clamp the buffer to this frame so payload readers cannot run into the next one.
        auto narrowToFrame = [&]()
        {
            view.length = (oldRead - view.base) + frameLength;
            view.writePtr = view.base + view.length;
        };

        // Undo narrowToFrame and place the read head at the start of the next frame.
        auto finishFrame = [&]()
        {
            view.readPtr = endPtr;
            view.writePtr = oldWriteHead;
            view.length = oldLength;
        };

        auto packetType = view.Read<AuUInt8>();
        if (packetType == kResponseConnectOK)
        {
            // Consume the whole frame (including any trailing bytes) before acting on it.
            view.readPtr = endPtr;
            this->bConnected_ = true;
            this->ProcessConnectionOK();
        }
        else if (packetType == kResponseMulticonnect)
        {
            // Extract the alternate address *before* Deinit(): the pipe being torn
            // down is presumably what backs this buffer.
            narrowToFrame();
            auto alternateAddress = view.Read<AuString>();
            finishFrame();

            this->bConnectingAlternate_ = true;
            this->pipe_.Deinit();
            if (!this->Init(alternateAddress))
            {
                // OnDisconnect ignores drops while bConnectingAlternate_ is set, so
                // nothing else would abort outstanding requests; do it here.
                SysPushErrorIO("Couldnt connect to alternate RPC endpoint");
                this->Finalize();
            }

            // Nothing further on the old stream belongs to the new connection.
            return true;
        }
        else if (packetType == kResponseRPC)
        {
            auto response = AuMakeShared<AuRPCResponse>();
            if (!response)
            {
                // Leave the frame unread rather than dropping the response.
                SysPushErrorMem();
                view.readPtr = oldRead;
                return true;
            }

            narrowToFrame();
            response->message = &view;
            response->Deserialize();
            this->ProcessResponse(response);
            finishFrame();
        }
        else if (packetType == kGeneralBroadcast ||
                 packetType == kGeneralClientMessage)
        {
            narrowToFrame();
            if (auto pCallbacks = this->callbacks_)
            {
                if (packetType == kGeneralBroadcast)
                {
                    pCallbacks->OnBroadcast(view);
                }
                else
                {
                    pCallbacks->OnMessage(view);
                }
            }
            finishFrame();
        }
        else if (packetType == kGeneralMassiveMessage)
        {
            // Payload is a shared-memory handle string; the message body lives there.
            narrowToFrame();
            auto sharedHandle = view.Read<AuString>();
            if (view)
            {
                auto pMemory = AuIPC::ImportSharedMemory(sharedHandle);
                if (!pMemory)
                {
                    // Restore the buffer before tearing down the channel.
                    finishFrame();
                    SysPushErrorIO("Couldnt open shared memory for large packet");
                    this->FatalIOError();
                    return true;
                }

                AuByteBuffer buf(pMemory->GetMemory());
                if (!buf)
                {
                    finishFrame();
                    SysPushErrorMemory();
                    this->FatalIOError();
                    return true;
                }

                if (auto pCallbacks = this->callbacks_)
                {
                    pCallbacks->OnMessage(buf);
                }
            }
            finishFrame();
        }
        else
        {
            // Unknown packet type: skip the whole frame so the stream stays aligned.
            view.readPtr = endPtr;
        }
    }
}
// Opens the underlying pipe to the given IPC endpoint string; the result of
// the pipe's Init is returned to the caller.
bool AuRPCClientChannel::Init(const AuString &ipc)
{
    const bool bOpened = this->pipe_.Init(ipc);
    return bOpened;
}
void AuRPCClientChannel::FatalIOError()
{
RpcLogDebug("AuRPCClientChannel::FatalIOError");
Disconnect();
}
// Handles the server's connect acknowledgement: records the connect time,
// notifies the callbacks, then re-submits every request that was pending in
// outstandingRequests through SendRequest.
void AuRPCClientChannel::ProcessConnectionOK()
{
    this->uConnectTime_ = AuTime::SteadyClockNS();

    // Drain the pending list up front; SendRequest below re-submits each entry.
    auto pending = AuExchange(this->outstandingRequests, AuList<AuSPtr<AuRPCRequest>>{});

    // Keep a strong reference for the duration of the call: OnConnect may end up
    // in Finalize() (e.g. via Disconnect), which resets callbacks_.
    if (auto pCallbacks = this->callbacks_)
    {
        pCallbacks->OnConnect();
    }

    for (const auto &pRequest : pending)
    {
        this->SendRequest(pRequest);
    }
}
// Routes a response to the outstanding request whose address matches the
// response cookie. The first match is removed from outstandingRequests before
// its callback (if any) is invoked; unmatched responses are ignored.
void AuRPCClientChannel::ProcessResponse(const AuSPtr<AuRPCResponse> &response)
{
    RpcLogDebug("AuRPCClientChannel::ProcessResponse");

    auto &pending = this->outstandingRequests;

    auto match = pending.begin();
    while (match != pending.end() &&
           response->cookie != AuUInt(match->get()))
    {
        ++match;
    }

    if (match == pending.end())
    {
        return;
    }

    // Keep the request alive after it leaves the list.
    auto pRequest = *match;
    pending.erase(match);

    if (pRequest->callback)
    {
        pRequest->callback->OnResponse(*response);
    }
}
// True once a kResponseConnectOK frame has been received on this channel.
bool AuRPCClientChannel::IsConnected()
{
    const bool bState = this->bConnected_;
    return bState;
}
// Installs the user callback interface. If the channel is already connected,
// the new callbacks receive OnConnect immediately (before being stored), so
// late subscribers still observe the connect event. Null is rejected with
// SysPushErrorArg and leaves the current callbacks in place.
void AuRPCClientChannel::SetCallbacks(const AuSPtr<AuIRPCChannelCallbacks> &callbacks)
{
    if (callbacks)
    {
        if (IsConnected())
        {
            callbacks->OnConnect();
        }

        this->callbacks_ = callbacks;
        return;
    }

    SysPushErrorArg();
}
// Terminal cleanup for the channel.
// Every call empties outstandingRequests. Only the first call (guarded by
// bIsDead_) completes those requests with ERPCError::eAborted, fires
// OnDisconnect on the user callbacks, and drops them.
void AuRPCClientChannel::Finalize()
{
    auto pending = AuExchange(this->outstandingRequests, AuList<AuSPtr<AuRPCRequest>>{});

    if (AuExchange(this->bIsDead_, true))
    {
        return;
    }

    for (const auto &pRequest : pending)
    {
        // Requests may carry no callback; ProcessResponse guards the same way.
        if (pRequest && pRequest->callback)
        {
            pRequest->callback->OnResponse(AuRPCResponse(ERPCError::eAborted));
        }
    }

    if (this->callbacks_)
    {
        this->callbacks_->OnDisconnect();
        this->callbacks_.reset();
    }
}