/*** Copyright (C) 2023 J Reece Wilson (a/k/a "Reece"). All rights reserved. File: AuFutures.hpp Date: 2023-05-25 Author: Reece ***/ #pragma once template struct AuFuture : AuEnableSharedFromThis> { private: template struct CppFun { using B = A; }; template <> struct CppFun { struct Dummy { }; using B = Dummy; }; using Move_t = AuConditional_t, typename CppFun::B &&, T>; using Move2_t = AuConditional_t, typename CppFun::B &&, Error_t>; using ErrorStore_t = AuConditional_t, typename CppFun::B, Error_t>; public: using CompleteCallback_f = AuConditional_t, AuVoidFunc, AuConsumer>; using ErrorCallback_f = AuConditional_t, AuVoidFunc, AuConsumer>; AU_NO_COPY_NO_MOVE(AuFuture); void OnComplete(CompleteCallback_f callback) { AU_LOCK_GUARD(this->mutex); if (this->bComplete) { SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback"); if constexpr (AuIsVoid_v) { callback(); } else { callback(this->value); } DoWaterFalls(); return; } SysAssertDbg(!this->callback); this->callback = callback; if (!this->pid) { this->pid = AuAsync::GetCurrentWorkerPId(); } else { SysAssert(this->pid == AuAsync::GetCurrentWorkerPId()); } } void OnFailure(ErrorCallback_f onFailure) { AU_LOCK_GUARD(this->mutex); if (this->bFailed) { SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback"); if constexpr (AuIsVoid_v) { onFailure(); } else { onFailure(this->errorValue); } DoWaterFalls(); return; } SysAssertDbg(!this->onFailure); if (!this->pid) { this->pid = AuAsync::GetCurrentWorkerPId(); } else { SysAssert(this->pid == AuAsync::GetCurrentWorkerPId()); } this->onFailure = onFailure; } template> * = nullptr> void Complete(Move_t value) { AU_LOCK_GUARD(this->mutex); SysAssert(!AuExchange(this->bDone, true), "Future has already finished"); this->value = AuMove(value); this->bComplete = true; SubmitComplete(); } template> * = nullptr> void Complete() { AU_LOCK_GUARD(this->mutex); SysAssert(!AuExchange(this->bDone, true), "Future has already finished"); this->bComplete = true; SubmitComplete(); } template> * = nullptr> void Fail(Move2_t error) { AU_LOCK_GUARD(this->mutex); SysAssert(!AuExchange(this->bDone, true), "Future has already finished"); this->errorValue = AuMove(error); this->bFailed = true; SubmitComplete(); } template> * = nullptr> void Fail() { AU_LOCK_GUARD(this->mutex); SysAssert(!AuExchange(this->bDone, true), "Future has already finished"); this->bFailed = true; SubmitComplete(); } static AuSPtr> New() { AuDebug::AddMemoryCrunch(); auto pRet = AuSPtr>(new AuFuture(), AuDefaultDeleter> {}); AuDebug::DecMemoryCrunch(); return pRet; } static AuSPtr> New(AuConsumer callback) { AuDebug::AddMemoryCrunch(); auto pRet = AuSPtr>(new AuFuture(callback), AuDefaultDeleter> {}); AuDebug::DecMemoryCrunch(); return pRet; } static AuSPtr> New(AuConsumer callback, ErrorCallback_f onFailure) { AuDebug::AddMemoryCrunch(); auto pRet = AuSPtr>(new AuFuture(callback, onFailure), AuDefaultDeleter> {}); AuDebug::DecMemoryCrunch(); return pRet; } private: void SubmitComplete() { if (AuAsync::GetCurrentWorkerPId() == this->pid) { if (!this->onFailure && !this->callback) { DoWaterFalls(); return; } SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback"); if (this->bComplete) { if (auto callback = AuExchange(this->callback, {})) { if constexpr (AuIsVoid_v) { callback(); } else { callback(this->value); } } } else if (this->bFailed) { if (auto callback = AuExchange(this->onFailure, {})) { if constexpr (AuIsVoid_v) { callback(); } else { callback(this->errorValue); } } } DoWaterFalls(); } else { if (!this->onFailure && !this->callback) { DoWaterFalls(); return; } AuDebug::AddMemoryCrunch(); AuAsync::NewWorkItem(this->pid.value(), AuMakeSharedPanic([pThat = this->SharedFromThis()] { AU_LOCK_GUARD(pThat->mutex); pThat->SubmitComplete(); }))->Dispatch(); AuDebug::DecMemoryCrunch(); } } void DoWaterFalls() { auto callbacks = AuExchange(this->waterfall, {}); for (const auto &callback : callbacks) { callback(this->bComplete, this->bFailed); } } void AddWaterFall(AuConsumer callback) { AU_LOCK_GUARD(this->mutex); if (this->bDoneCb) { callback(this->bComplete, this->bFailed); return; } AuDebug::AddMemoryCrunch(); SysAssert(AuTryInsert(this->waterfall, callback)); AuDebug::DecMemoryCrunch(); } AuFuture() { this->pid = AuAsync::GetCurrentWorkerPId(); } AuFuture(AuConsumer callback) : callback(callback) { this->pid = AuAsync::GetCurrentWorkerPId(); } AuFuture(AuConsumer callback, ErrorCallback_f onFailure) : callback(callback), onFailure(onFailure) { this->pid = AuAsync::GetCurrentWorkerPId(); } CppFun::B value; ErrorStore_t errorValue; AuThreadPrimitives::Mutex mutex; CompleteCallback_f callback; ErrorCallback_f onFailure; AuOptionalEx pid; // todo: make weak? bool bComplete {}; bool bFailed {}; bool bDone {}; bool bDoneCb {}; AuList> waterfall; friend struct AuWaterfall; }; template using AuSharedFuture = AuSPtr>; struct AuWaterfall : AuEnableSharedFromThis { AU_NO_COPY_NO_MOVE(AuWaterfall); AuWaterfall(bool bFailOnAny = true) : bFailOnAny(bFailOnAny) { this->pFuture = AuFuture::New(); } template AuSPtr AddFuture(AuSharedFuture future) { AU_LOCK_GUARD(this->mutex); SysAssert(!this->bReady); this->uCount++; future->AddWaterFall([pThat = this->SharedFromThis()](bool bSuccess, bool bFailed) { bool bSendSuccess {}; bool bSendFail {}; if (bSuccess) { ++pThat->uCountOfComplete; } else if (bFailed) { ++pThat->uCountOfFailed; } if (!pThat->bReady) { return; } AU_LOCK_GUARD(pThat->mutex); pThat->FireDelayed(); }); return this->SharedFromThis(); } void OnFailure(AuVoidFunc onFailure) { AU_LOCK_GUARD(this->mutex); if (this->bDone) { auto [bSendSuccess, bSendFail] = this->GetDispatch(true); if (bSendFail) { onFailure(); } } else { this->onFailure.push_back(onFailure); this->Start(); this->FireDelayed(); } } void OnSuccess(AuVoidFunc onSuccess) { AU_LOCK_GUARD(this->mutex); if (this->bDone) { auto [bSendSuccess, bSendFail] = this->GetDispatch(true); if (bSendSuccess) { onSuccess(); } } else { this->onSuccess.push_back(onSuccess); this->Start(); this->FireDelayed(); } } static AuSPtr New(bool bFailOnAny = true) { return AuMakeSharedThrow(bFailOnAny); } private: AuPair GetDispatch(bool bForce = false) { bool bSendSuccess {}; bool bSendFail {}; if ((this->bFailOnAny && bool(this->uCountOfFailed)) || (this->uCountOfFailed == this->uCount)) { bSendFail = bool(this->onFailure.size()) || bForce; } else if ((!this->bFailOnAny || !this->uCountOfFailed) && this->uCountOfComplete == this->uCount) { bSendSuccess = bool(this->onSuccess.size()) || bForce; } else if (!this->bFailOnAny && ((this->uCountOfComplete + this->uCountOfFailed) == this->uCount)) { bSendSuccess = bool(this->onSuccess.size()) || bForce; } return AuMakePair(bSendSuccess, bSendFail); } void FireDelayed() { auto [bSendSuccess, bSendFail] = this->GetDispatch(false); if (!bSendSuccess && !bSendFail) { return; } if (bSendFail) { this->bFailed = true; } if (AuExchange(this->bDone, true)) { // Miss? return; } if (bSendSuccess) { this->pFuture->Complete(); } else if (bSendFail) { this->pFuture->Fail(); } } void Start() { if (AuExchange(this->bReady, true)) { return; } this->pFuture->OnComplete([pThat = this->SharedFromThis()]() { pThat->FireSuccess(); }); this->pFuture->OnFailure([pThat = this->SharedFromThis()]() { pThat->FireFailure(); }); } void FireSuccess() { decltype(onSuccess) callbacks; { AU_LOCK_GUARD(this->mutex); callbacks = AuExchange(this->onSuccess, {}); this->onFailure.clear(); } for (const auto &callback : callbacks) { callback(); } } void FireFailure() { decltype(onSuccess) callbacks; { AU_LOCK_GUARD(this->mutex); callbacks = AuExchange(this->onFailure, {}); this->onSuccess.clear(); } for (const auto &callback : callbacks) { callback(); } } AuSharedFuture pFuture; AuList onSuccess; AuList onFailure; AuThreadPrimitives::CriticalSection mutex; bool bFailOnAny; AuUInt uCount {}; AuUInt uCountOfComplete {}; AuUInt uCountOfFailed {}; bool bReady {}; bool bDone {}; bool bFailed {}; }; using AuSharedWaterfall = AuWaterfall;