AuroraRuntime/Include/Aurora/Async/AuFutures.hpp
2023-05-25 03:10:12 +01:00

511 lines
12 KiB
C++

/***
Copyright (C) 2023 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuFutures.hpp
Date: 2023-05-25
Author: Reece
***/
#pragma once
/**
 * One-shot future that resolves either with a value of type T (Complete) or
 * with an error of type Error_t (Fail). Either type may be void.
 *
 * The worker PId of the constructing thread is recorded. OnComplete/OnFailure
 * assert that they are called from that worker, and registered callbacks are
 * invoked on it: SubmitComplete runs them inline when already on that worker
 * and otherwise dispatches a work item to it.
 *
 * Callbacks are invoked while this->mutex is held.
 * NOTE(review): whether AuThreadPrimitives::Mutex tolerates re-entry is not
 * visible here; callbacks should presumably avoid calling back into the same
 * future - confirm.
 */
template<typename T, typename Error_t = void>
struct AuFuture : AuEnableSharedFromThis<AuFuture<T, Error_t>>
{
private:
    // Maps a type to something that can be stored as a member: non-void types
    // map to themselves, void maps to an empty placeholder struct.
    // The void case is a partial specialization over a defaulted dummy
    // parameter because an explicit specialization (template <>) at class
    // scope is not accepted by every compiler.
    template <typename A, typename Unused_t = void>
    struct CppFun
    {
        using B = A;
    };

    template <typename Unused_t>
    struct CppFun<void, Unused_t>
    {
        struct Dummy
        { };
        using B = Dummy;
    };

    using Move_t       = AuConditional_t<AuIsVoid_v<T>, typename CppFun<T>::B &&, T>;
    using Move2_t      = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B &&, Error_t>;
    using ErrorStore_t = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B, Error_t>;

public:
    // Callback signatures: argument-less when the respective type is void.
    using CompleteCallback_f = AuConditional_t<AuIsVoid_v<T>, AuVoidFunc, AuConsumer<Move_t>>;
    using ErrorCallback_f    = AuConditional_t<AuIsVoid_v<Error_t>, AuVoidFunc, AuConsumer<Move2_t>>;

    AU_NO_COPY_NO_MOVE(AuFuture);

    /**
     * Registers the success callback.
     * If the future has already completed, the callback is invoked immediately
     * (asserting that no completion callback ran before) and pending waterfall
     * listeners are notified. Otherwise the callback is stored for later.
     */
    void OnComplete(CompleteCallback_f callback)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bComplete)
        {
            SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");

            if constexpr (AuIsVoid_v<T>)
            {
                callback();
            }
            else
            {
                callback(this->value);
            }

            DoWaterFalls();
            return;
        }

        SysAssertDbg(!this->callback);
        this->callback = AuMove(callback);

        // Callbacks must be registered from the worker that owns this future.
        if (!this->pid)
        {
            this->pid = AuAsync::GetCurrentWorkerPId();
        }
        else
        {
            SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
        }
    }

    /**
     * Registers the failure callback.
     * If the future has already failed, the callback is invoked immediately
     * (asserting that no completion callback ran before) and pending waterfall
     * listeners are notified. Otherwise the callback is stored for later.
     */
    void OnFailure(ErrorCallback_f onFailure)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bFailed)
        {
            SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");

            if constexpr (AuIsVoid_v<Error_t>)
            {
                onFailure();
            }
            else
            {
                onFailure(this->errorValue);
            }

            DoWaterFalls();
            return;
        }

        SysAssertDbg(!this->onFailure);

        // Callbacks must be registered from the worker that owns this future.
        if (!this->pid)
        {
            this->pid = AuAsync::GetCurrentWorkerPId();
        }
        else
        {
            SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
        }

        this->onFailure = AuMove(onFailure);
    }

    /// Resolves the future successfully with a value (non-void T only). Asserts if already resolved.
    template<typename T1 = T, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
    void Complete(Move_t value)
    {
        AU_LOCK_GUARD(this->mutex);
        SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
        this->value = AuMove(value);
        this->bComplete = true;
        SubmitComplete();
    }

    /// Resolves the future successfully (void T only). Asserts if already resolved.
    template<typename T1 = T, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
    void Complete()
    {
        AU_LOCK_GUARD(this->mutex);
        SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
        this->bComplete = true;
        SubmitComplete();
    }

    /// Resolves the future as failed with an error value (non-void Error_t only). Asserts if already resolved.
    template<typename T1 = Error_t, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
    void Fail(Move2_t error)
    {
        AU_LOCK_GUARD(this->mutex);
        SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
        this->errorValue = AuMove(error);
        this->bFailed = true;
        SubmitComplete();
    }

    /// Resolves the future as failed (void Error_t only). Asserts if already resolved.
    template<typename T1 = Error_t, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
    void Fail()
    {
        AU_LOCK_GUARD(this->mutex);
        SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
        this->bFailed = true;
        SubmitComplete();
    }

    /// Creates a future with no callbacks, owned by the calling worker.
    static AuSPtr<AuFuture<T, Error_t>> New()
    {
        AuDebug::AddMemoryCrunch();
        auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(), AuDefaultDeleter<AuFuture<T, Error_t>> {});
        AuDebug::DecMemoryCrunch();
        return pRet;
    }

    /// Creates a future with a success callback, owned by the calling worker.
    /// Takes CompleteCallback_f so that it is also usable when T is void.
    static AuSPtr<AuFuture<T, Error_t>> New(CompleteCallback_f callback)
    {
        AuDebug::AddMemoryCrunch();
        auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(AuMove(callback)), AuDefaultDeleter<AuFuture<T, Error_t>> {});
        AuDebug::DecMemoryCrunch();
        return pRet;
    }

    /// Creates a future with success and failure callbacks, owned by the calling worker.
    /// Takes CompleteCallback_f so that it is also usable when T is void.
    static AuSPtr<AuFuture<T, Error_t>> New(CompleteCallback_f callback, ErrorCallback_f onFailure)
    {
        AuDebug::AddMemoryCrunch();
        auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(AuMove(callback), AuMove(onFailure)), AuDefaultDeleter<AuFuture<T, Error_t>> {});
        AuDebug::DecMemoryCrunch();
        return pRet;
    }

private:
    /**
     * Delivers the result after Complete/Fail. Every caller holds this->mutex.
     *
     * On the owning worker: invokes the matching stored callback (if any) and
     * then notifies waterfall listeners. On any other worker: if callbacks are
     * registered, re-runs itself on the owning worker via a work item;
     * if none are registered, only the waterfall listeners are notified.
     */
    void SubmitComplete()
    {
        if (AuAsync::GetCurrentWorkerPId() == this->pid)
        {
            if (!this->onFailure && !this->callback)
            {
                DoWaterFalls();
                return;
            }

            SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");

            if (this->bComplete)
            {
                // Take the callback out of the member before invoking it.
                if (auto callback = AuExchange(this->callback, {}))
                {
                    if constexpr (AuIsVoid_v<T>)
                    {
                        callback();
                    }
                    else
                    {
                        callback(this->value);
                    }
                }
            }
            else if (this->bFailed)
            {
                if (auto callback = AuExchange(this->onFailure, {}))
                {
                    if constexpr (AuIsVoid_v<Error_t>)
                    {
                        callback();
                    }
                    else
                    {
                        callback(this->errorValue);
                    }
                }
            }

            DoWaterFalls();
        }
        else
        {
            if (!this->onFailure && !this->callback)
            {
                DoWaterFalls();
                return;
            }

            // The work item holds a shared reference so the future stays alive until it runs.
            AuDebug::AddMemoryCrunch();
            AuAsync::NewWorkItem(this->pid.value(), AuMakeSharedPanic<AuAsync::BasicWorkStdFunc>([pThat = this->SharedFromThis()]
            {
                AU_LOCK_GUARD(pThat->mutex);
                pThat->SubmitComplete();
            }))->Dispatch();
            AuDebug::DecMemoryCrunch();
        }
    }

    /// Notifies and clears all queued waterfall listeners. Every caller holds this->mutex.
    void DoWaterFalls()
    {
        // Remember that listeners were drained so that AddWaterFall can notify
        // late listeners directly instead of queueing them forever.
        this->bWaterfallsFired = true;

        auto callbacks = AuExchange(this->waterfall, {});
        for (const auto &callback : callbacks)
        {
            callback(this->bComplete, this->bFailed);
        }
    }

    /**
     * Adds a listener that receives (bComplete, bFailed) once the result has
     * been delivered. If delivery has already happened - including the case of
     * a future that resolved without any callback registered - the listener is
     * invoked immediately.
     */
    void AddWaterFall(AuConsumer<bool, bool> callback)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDoneCb || this->bWaterfallsFired)
        {
            callback(this->bComplete, this->bFailed);
            return;
        }

        AuDebug::AddMemoryCrunch();
        SysAssert(AuTryInsert(this->waterfall, callback));
        AuDebug::DecMemoryCrunch();
    }

    AuFuture()
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    AuFuture(CompleteCallback_f callback) : callback(AuMove(callback))
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    AuFuture(CompleteCallback_f callback, ErrorCallback_f onFailure) : callback(AuMove(callback)), onFailure(AuMove(onFailure))
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    typename CppFun<T>::B value;                // result storage (placeholder when T is void)
    ErrorStore_t errorValue;                    // error storage (placeholder when Error_t is void)
    AuThreadPrimitives::Mutex mutex;
    CompleteCallback_f callback;
    ErrorCallback_f onFailure;
    AuOptionalEx<AuAsync::WorkerPId_t> pid; // todo: make weak?
    bool bComplete {};                          // Complete() was called
    bool bFailed {};                            // Fail() was called
    bool bDone {};                              // Complete() or Fail() was called
    bool bDoneCb {};                            // a completion/failure callback has been dispatched
    bool bWaterfallsFired {};                   // DoWaterFalls() has run at least once
    AuList<AuConsumer<bool, bool>> waterfall;

    friend struct AuWaterfall;
};
// Shared-pointer handle to an AuFuture; this is the type returned by AuFuture::New.
template<typename T, typename Error_t = void>
using AuSharedFuture = AuSPtr<AuFuture<T, Error_t>>;
/**
 * Aggregates several futures into a single success/failure notification.
 *
 * Futures are added with AddFuture; the first call to OnSuccess/OnFailure
 * arms the waterfall (bReady), after which no further futures may be added.
 * Once the tracked futures have resolved, the registered success or failure
 * callbacks are dispatched through an internal AuFuture<void>.
 *
 * With bFailOnAny set, a single failed future fails the waterfall; otherwise
 * it only fails when every tracked future failed.
 *
 * AddFuture and Start call SharedFromThis(), so the object has to be owned by
 * a shared pointer (see New).
 *
 * NOTE(review): FireSuccess/FireFailure lock this->mutex and can be reached
 * while OnSuccess/OnFailure already hold it on the same thread, so
 * AuThreadPrimitives::CriticalSection is presumably re-entrant - confirm.
 */
struct AuWaterfall : AuEnableSharedFromThis<AuWaterfall>
{
    AU_NO_COPY_NO_MOVE(AuWaterfall);

    AuWaterfall(bool bFailOnAny = true) :
        bFailOnAny(bFailOnAny)
    {
        this->pFuture = AuFuture<void>::New();
    }

    /**
     * Tracks another future. Must be called before OnSuccess/OnFailure
     * (asserts that the waterfall is not yet armed).
     * Returns this waterfall to allow chaining.
     */
    template<typename T>
    AuSPtr<AuWaterfall> AddFuture(AuSharedFuture<T> future)
    {
        AU_LOCK_GUARD(this->mutex);
        SysAssert(!this->bReady);

        this->uCount++;

        future->AddWaterFall([pThat = this->SharedFromThis()](bool bSuccess, bool bFailed)
        {
            // The counters and bReady are shared with OnSuccess/OnFailure, which
            // access them under the same mutex. Holding it for the whole update
            // ensures that either this listener observes bReady, or the arming
            // call observes the updated counters.
            AU_LOCK_GUARD(pThat->mutex);

            if (bSuccess)
            {
                ++pThat->uCountOfComplete;
            }
            else if (bFailed)
            {
                ++pThat->uCountOfFailed;
            }

            // Not armed yet: OnSuccess/OnFailure will evaluate the counters later.
            if (!pThat->bReady)
            {
                return;
            }

            pThat->FireDelayed();
        });

        return this->SharedFromThis();
    }

    /**
     * Registers a failure callback and arms the waterfall.
     * If the waterfall has already finished, the callback is invoked
     * immediately when the outcome is a failure and dropped otherwise.
     */
    void OnFailure(AuVoidFunc onFailure)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);
            if (bSendFail)
            {
                onFailure();
            }
        }
        else
        {
            this->onFailure.push_back(AuMove(onFailure));
            this->Start();
            this->FireDelayed();
        }
    }

    /**
     * Registers a success callback and arms the waterfall.
     * If the waterfall has already finished, the callback is invoked
     * immediately when the outcome is a success and dropped otherwise.
     */
    void OnSuccess(AuVoidFunc onSuccess)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);
            if (bSendSuccess)
            {
                onSuccess();
            }
        }
        else
        {
            this->onSuccess.push_back(AuMove(onSuccess));
            this->Start();
            this->FireDelayed();
        }
    }

    /// Creates a shared waterfall.
    static AuSPtr<AuWaterfall> New(bool bFailOnAny = true)
    {
        return AuMakeSharedThrow<AuWaterfall>(bFailOnAny);
    }

private:
    /**
     * Evaluates the counters and returns (send success, send failure).
     * Without bForce, an outcome is only reported when at least one callback
     * of the matching kind is registered. Both are false while futures are
     * still pending.
     *
     * NOTE(review): with no futures added (uCount == 0) the first condition
     * holds because uCountOfFailed == uCount, so an empty waterfall reports
     * failure - confirm this is intended.
     */
    AuPair<bool, bool> GetDispatch(bool bForce = false)
    {
        bool bSendSuccess {};
        bool bSendFail {};

        if ((this->bFailOnAny && bool(this->uCountOfFailed)) ||
            (this->uCountOfFailed == this->uCount))
        {
            // Any failure in fail-on-any mode, or every future failed.
            bSendFail = bool(this->onFailure.size()) || bForce;
        }
        else if ((!this->bFailOnAny || !this->uCountOfFailed) &&
                 this->uCountOfComplete == this->uCount)
        {
            // Every future completed.
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }
        else if (!this->bFailOnAny && ((this->uCountOfComplete + this->uCountOfFailed) == this->uCount))
        {
            // Tolerant mode: all futures resolved, with a mix of outcomes.
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }

        return AuMakePair(bSendSuccess, bSendFail);
    }

    /// Resolves the internal future once an outcome is available; does so at most once (bDone).
    /// Every caller holds this->mutex.
    void FireDelayed()
    {
        auto [bSendSuccess, bSendFail] = this->GetDispatch(false);

        if (!bSendSuccess && !bSendFail)
        {
            return;
        }

        if (bSendFail)
        {
            this->bFailed = true;
        }

        if (AuExchange(this->bDone, true))
        {
            // Miss?
            return;
        }

        if (bSendSuccess)
        {
            this->pFuture->Complete();
        }
        else if (bSendFail)
        {
            this->pFuture->Fail();
        }
    }

    /// Arms the waterfall on first use: hooks the internal future up to FireSuccess/FireFailure.
    void Start()
    {
        if (AuExchange(this->bReady, true))
        {
            return;
        }

        this->pFuture->OnComplete([pThat = this->SharedFromThis()]()
        {
            pThat->FireSuccess();
        });

        this->pFuture->OnFailure([pThat = this->SharedFromThis()]()
        {
            pThat->FireFailure();
        });
    }

    /// Invokes all success callbacks and discards the failure callbacks.
    void FireSuccess()
    {
        decltype(onSuccess) callbacks;

        {
            // Take the list under the lock; the callbacks themselves run after this scope.
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onSuccess, {});
            this->onFailure.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    /// Invokes all failure callbacks and discards the success callbacks.
    void FireFailure()
    {
        decltype(onFailure) callbacks;

        {
            // Take the list under the lock; the callbacks themselves run after this scope.
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onFailure, {});
            this->onSuccess.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    AuSharedFuture<void> pFuture;               // internal future used to dispatch the outcome
    AuList<AuVoidFunc> onSuccess;
    AuList<AuVoidFunc> onFailure;
    AuThreadPrimitives::CriticalSection mutex;
    bool bFailOnAny;
    AuUInt uCount {};                           // number of futures added
    AuUInt uCountOfComplete {};                 // number of futures that completed
    AuUInt uCountOfFailed {};                   // number of futures that failed
    bool bReady {};                             // armed by the first OnSuccess/OnFailure
    bool bDone {};                              // outcome has been submitted to pFuture
    bool bFailed {};                            // outcome was a failure
};
// NOTE(review): unlike AuSharedFuture, which aliases AuSPtr<AuFuture<...>>, this aliases the
// AuWaterfall type itself rather than AuSPtr<AuWaterfall> - confirm whether that is intended.
using AuSharedWaterfall = AuWaterfall;