// AuroraRPC/Source/AuRPCServerChannel.cpp
//
// 253 lines
// 5.7 KiB
// C++

/***
Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuRPCServerChannel.cpp
Date: 2022-6-29
Author: Reece
***/
#include <AuroraRuntime.hpp>
#include "AuRPC.hpp"
#include "AuRPCServerChannel.hpp"
#include "AuRPCRequest.hpp"
#include "AuRPCPipePacket.hpp"
// Builds a server-side channel bound to an RPC context (`parent`) and its owning server.
// The `pipe` member is constructed with `this` (presumably as its event sink — confirm
// against AuRPCPipe). The steady-clock connect timestamp is taken in the body, i.e. only
// after every member has been initialized.
AuRPCServerChannel::AuRPCServerChannel(AuSPtr<AuRPC> parent, AuSPtr<AuRPCServer> server)
    : parent(parent),
      pipe(this),
      server_(server)
{
    uConnectTime_ = AuTime::SteadyClockNS();
}
// Detaches this channel from its server on destruction by delegating to RemoveFromParent().
// NOTE(review): RemoveFromParent() calls SharedFromThis() for non-temp channels; inside a
// destructor no owning shared reference exists any more, so confirm this path is safe here.
// NOTE(review): OnDisconnect() also calls RemoveFromParent(), so the server's disconnect
// callback may be invoked twice for the same channel — confirm that is intended.
AuRPCServerChannel::~AuRPCServerChannel()
{
    RemoveFromParent();
}
// Returns the RPC context this channel was constructed with.
AuSPtr<AuRPC> AuRPCServerChannel::ToContext()
{
    return parent;
}
// Hands a finalized response to the pipe for transmission.
// If the response's message buffer recorded a write error, nothing is sent and the
// channel is torn down through FatalError() instead.
void AuRPCServerChannel::SendResponse(AuSPtr<AuRPCResponse> response)
{
    if (!response->message->flagWriteError)
    {
        AuRPCPipePacket outgoing;
        outgoing.serverResponse = response;
        pipe.SendPacket(outgoing);
        return;
    }

    FatalError();
}
// Returns the string form of this channel's underlying pipe handle, as produced by
// its ExportToString().
AuString AuRPCServerChannel::ExportString()
{
    auto &pipeHandle = pipe.pipe;
    return pipeHandle->ExportToString();
}
// Connect hook: logs the event and accepts the connection unconditionally.
bool AuRPCServerChannel::OnConnect()
{
    RpcLogDebug("Server channel: on connect");
    const bool bAccept = true;
    return bAccept;
}
void AuRPCServerChannel::OnDisconnect(bool error)
{
RpcLogDebug("Server channel disconnected: {}", error);
this->RemoveFromParent();
}
void AuRPCServerChannel::RemoveFromParent()
{
if (auto pCallbacks = this->server_->pCallbacks)
{
if (!this->isTempChannel_)
{
pCallbacks->OnDisconnect(this->SharedFromThis());
}
}
if (auto pParent = this->server_)
{
AU_LOCK_GUARD(pParent->lock->AsWritable());
if (auto pChannel = pParent->channel)
{
AuRemoveIf(pChannel->subchannels_,
[=](AuSPtr<AuRPCServerChannel> pChannel) -> bool
{
return pChannel.get() == this;
});
}
}
}
// Parses every complete frame currently buffered in `view` and dispatches it.
//
// Frame layout, as read below: a u32 length that covers the whole frame including the
// length field itself, then a u8 packet type, then a type-specific payload. kRequestRPC
// payloads start with serviceId (u32), methodId (u32) and cookie (u64).
//
// Returns true when only a partial frame (or nothing) is left and more data is needed,
// or when a request allocation fails and the current frame is left in place to be
// retried. Returns false after FatalError() has been raised for a malformed frame length.
//
// After each frame is handled, readPtr always resumes at that frame's end. Previously,
// connect frames with trailing bytes and frames of unknown type left readPtr
// mid-frame, so the rest of the frame was misparsed as new frames.
bool AuRPCServerChannel::OnDataAvailable(AuByteBuffer& view)
{
    // Length prefix (u32) + packet type (u8).
    constexpr AuUInt32 kMinFrameLength = 5;
    // Fixed header + serviceId (u32) + methodId (u32) + cookie (u64).
    constexpr AuUInt32 kMinRpcFrameLength = kMinFrameLength + 4 + 4 + 8;

    for (;;)
    {
        auto bytesAvailable = view.RemainingBytes();
        if (bytesAvailable < 4)
        {
            // Not even a complete length prefix yet.
            return true;
        }

        auto pFrameStart = view.readPtr;
        auto frameLength = view.Read<AuUInt32>();

        // A length smaller than the fixed header would make the type read spill into the
        // next frame and would move readPtr backwards when skipping; treat it as corrupt.
        // (This also rejects a zero length, as before.)
        if (frameLength < kMinFrameLength)
        {
            FatalError();
            return false;
        }

        if (frameLength > bytesAvailable)
        {
            // Partial frame: rewind and wait for more data.
            view.readPtr = pFrameStart;
            return true;
        }

        // Temporarily clamps the buffer's logical end to this frame while `fnHandler` runs,
        // so a handler reading the payload cannot run into the following frame, then
        // restores the original length and write head.
        auto runWithinFrame = [&](auto &&fnHandler)
        {
            auto oldLength = view.length;
            auto oldWriteHead = view.writePtr;
            view.length = (pFrameStart - view.base) + frameLength;
            view.writePtr = view.base + view.length;
            fnHandler();
            view.writePtr = oldWriteHead;
            view.length = oldLength;
        };

        auto packetType = view.Read<AuUInt8>();
        if (packetType == kRequestConnect)
        {
            if (this->isTempChannel_)
            {
                this->SendToNewChannel();
            }
            else
            {
                this->SendConnectOK();
            }
        }
        else if (packetType == kRequestRPC)
        {
            // The fixed RPC header must fit inside the frame.
            if (frameLength < kMinRpcFrameLength)
            {
                FatalError();
                return false;
            }

            auto request = AuMakeShared<AuRPCRequest>();
            if (!request)
            {
                SysPushErrorMem();
                // Leave the frame unconsumed so it is retried on the next call.
                view.readPtr = pFrameStart;
                return true;
            }

            request->serviceId = view.Read<AuUInt32>();
            request->methodId = view.Read<AuUInt32>();
            request->cookie = view.Read<AuUInt64>();

            runWithinFrame([&]
            {
                this->server_->Dispatch(this, request, &view);
            });
        }
        else if (packetType == kGeneralServerMessage)
        {
            runWithinFrame([&]
            {
                if (auto pCallbacks = this->server_->pCallbacks)
                {
                    pCallbacks->OnMessage(this->SharedFromThis(), view);
                }
            });
        }
        else
        {
            RpcLogDebug("Server channel: skipping unknown packet type {}", AuUInt32(packetType));
        }

        // Resume at the next frame boundary regardless of how much the handler consumed.
        view.readPtr = pFrameStart + frameLength;
    }
}
// Returns the AuTime::SteadyClockNS() value captured in the constructor.
AuUInt64 AuRPCServerChannel::GetConnectTimeNS()
{
    return uConnectTime_;
}
void AuRPCServerChannel::SendMessage(const AuMemoryViewRead &view)
{
auto pMessage = AuMakeSharedPanic<AuRPCResponseOwned>();
pMessage->PrepareResponse(kGeneralClientMessage);
pMessage->buffer->Write(view.ptr, view.length);
pMessage->FinalizeWrite();
this->SendResponse(pMessage);
}
void AuRPCServerChannel::SendConnectOK()
{
RpcLogDebug("AuRPCServerChannel::SendConnectOK");
auto res = AuMakeShared<AuRPCResponseOwned>();
if (!res)
{
FatalError();
return;
}
res->PrepareResponse(kResponseConnectOK);
res->FinalizeWrite();
this->SendResponse(res);
if (auto pCallbacks = this->server_->pCallbacks)
{
pCallbacks->OnConnect(this->SharedFromThis());
}
}
void AuRPCServerChannel::SendToNewChannel()
{
RpcLogDebug("AuRPCServerChannel::SendToNewChannel");
auto res = AuMakeShared<AuRPCResponseOwned>();
if (!res)
{
FatalError();
return;
}
res->PrepareResponse(kResponseMulticonnect);
auto channel = this->server_->NewChannel(false);
if (!channel)
{
this->FatalError();
return;
}
if (!AuTryInsert(this->subchannels_, channel))
{
this->FatalError();
return;
}
res->buffer->Write(channel->ExportString());
res->FinalizeWrite();
this->SendResponse(res);
}
void AuRPCServerChannel::FatalError()
{
RpcLogDebug("AuRPCServerChannel::FatalError");
this->pipe.Deinit();
}
bool AuRPCServerChannel::Init()
{
return this->pipe.Init(this->parent->optPipeLength);
}