/*** Copyright (C) 2023 J Reece Wilson (a/k/a "Reece"). All rights reserved. File: AuFutures.hpp Date: 2023-05-25 Author: Reece ***/ #pragma once namespace __detail { struct FutureAccessor; } template struct AuFuture : AuEnableSharedFromThis> { private: template struct CppFun { using B = A; }; template <> struct CppFun { struct Dummy { }; using B = Dummy; }; using Move_t = AuConditional_t, typename CppFun::B &&, T>; using Move2_t = AuConditional_t, typename CppFun::B &&, Error_t>; using ErrorStore_t = AuConditional_t, typename CppFun::B, Error_t>; public: using CompleteCallback_f = AuConditional_t, AuVoidFunc, AuConsumer>; using ErrorCallback_f = AuConditional_t, AuVoidFunc, AuConsumer>; AU_NO_COPY_NO_MOVE(AuFuture); void OnComplete(CompleteCallback_f callback) { AU_LOCK_GUARD(this->mutex); if (this->bComplete) { ExchangeDoneCb(); if constexpr (AuIsVoid_v) { callback(); } else { callback(this->value); } DoWaterFalls(); return; } SysAssertDbg(!this->callback); this->callback = callback; if (!this->pid) { this->pid = AuAsync::GetCurrentWorkerPId(); } else { SysAssert(this->pid.value() == AuAsync::GetCurrentWorkerPId()); } } void OnFailure(ErrorCallback_f onFailure) { AU_LOCK_GUARD(this->mutex); if (this->bFailed) { ExchangeDoneCb(); if constexpr (AuIsVoid_v) { onFailure(); } else { onFailure(this->errorValue); } DoWaterFalls(); return; } SysAssertDbg(!this->onFailure); if (!this->pid) { this->pid = AuAsync::GetCurrentWorkerPId(); } else { SysAssert(this->pid.value() == AuAsync::GetCurrentWorkerPId()); } this->onFailure = onFailure; } template> * = nullptr> void Complete(Move_t value) { AU_LOCK_GUARD(this->mutex); ExchangeDone(); this->value = AuMove(value); this->bComplete = true; SubmitComplete(); } template> * = nullptr> void Complete() { AU_LOCK_GUARD(this->mutex); ExchangeDone(); this->bComplete = true; SubmitComplete(); } template> * = nullptr> void Fail(Move2_t error) { AU_LOCK_GUARD(this->mutex); ExchangeDone(); this->errorValue = AuMove(error); this->bFailed = true; SubmitComplete(); } template> * = nullptr> void Fail() { AU_LOCK_GUARD(this->mutex); ExchangeDone(); this->bFailed = true; SubmitComplete(); } static AuSPtr> New() { AU_DEBUG_MEMCRUNCH; return AuSPtr>(new AuFuture(), AuDefaultDeleter> {}); } static AuSPtr> New(AuConsumer callback) { AU_DEBUG_MEMCRUNCH; return AuSPtr>(new AuFuture(callback), AuDefaultDeleter> {}); } static AuSPtr> New(AuConsumer callback, ErrorCallback_f onFailure) { AU_DEBUG_MEMCRUNCH; return AuSPtr>(new AuFuture(callback, onFailure), AuDefaultDeleter> {}); } protected: friend struct __detail::FutureAccessor; typename CppFun::B &GetValue() { return value; } bool IsFinished() { return this->bComplete || this->bFailed; } bool IsFailed() { return this->bFailed; } private: void ExchangeDone() { #if 0 SysAssert(!AuExchange(this->bDone, true), "Future has already finished"); #else SysAssert(!this->bDone, "Future has already finished"); this->bDone = true; #endif } void ExchangeDoneCb() { #if 0 SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback"); #else SysAssert(!this->bDoneCb, "Future has already called a completion callback"); this->bDoneCb = true; #endif } void SubmitComplete() { if (AuAsync::GetCurrentWorkerPId() == this->pid.value()) { if (!this->onFailure && !this->callback) { DoWaterFalls(); return; } ExchangeDoneCb(); if (this->bComplete) { if (auto callback = AuExchange(this->callback, {})) { if constexpr (AuIsVoid_v) { callback(); } else { callback(this->value); } } } else if (this->bFailed) { if (auto callback = AuExchange(this->onFailure, {})) { if constexpr (AuIsVoid_v) { callback(); } else { callback(this->errorValue); } } } DoWaterFalls(); } else { if (!this->onFailure && !this->callback) { DoWaterFalls(); return; } AuAsync::DispatchOn(this->pid.value(), [pThat = this->SharedFromThis()] { AU_LOCK_GUARD(pThat->mutex); pThat->SubmitComplete(); }); } } void DoWaterFalls() { auto callbacks = AuExchange(this->waterfall, {}); for (const auto &callback : callbacks) { callback(this->bComplete, this->bFailed); } } void AddWaterFall(AuConsumer callback) { AU_LOCK_GUARD(this->mutex); if (this->bDoneCb) { callback(this->bComplete, this->bFailed); return; } { AU_DEBUG_MEMCRUNCH; SysAssert(AuTryInsert(this->waterfall, callback)); } } AuFuture() { this->pid = AuAsync::GetCurrentWorkerPId(); } AuFuture(AuConsumer callback) : callback(callback) { this->pid = AuAsync::GetCurrentWorkerPId(); } AuFuture(AuConsumer callback, ErrorCallback_f onFailure) : callback(callback), onFailure(onFailure) { this->pid = AuAsync::GetCurrentWorkerPId(); } typename CppFun::B value; ErrorStore_t errorValue; AuFutexMutex mutex; CompleteCallback_f callback; ErrorCallback_f onFailure; AuOptionalEx pid; // todo: make weak? AuList> waterfall; AuUInt8 bComplete : 1 {}; AuUInt8 bFailed : 1 {}; AuUInt8 bDone : 1 {}; AuUInt8 bDoneCb : 1 {}; friend struct AuWaterfall; }; template using AuSharedFuture = AuSPtr>; struct AuWaterfall : AuEnableSharedFromThis { AU_NO_COPY_NO_MOVE(AuWaterfall); AuWaterfall(bool bFailOnAny = true) : bFailOnAny(bFailOnAny) { this->pFuture = AuFuture::New(); } template AuSPtr AddFuture(AuSharedFuture future) { AU_LOCK_GUARD(this->mutex); SysAssert(!this->bReady); this->uCount++; future->AddWaterFall([pThat = this->SharedFromThis()](bool bSuccess, bool bFailed) { bool bSendSuccess {}; bool bSendFail {}; if (bSuccess) { ++pThat->uCountOfComplete; } else if (bFailed) { ++pThat->uCountOfFailed; } if (!pThat->bReady) { return; } AU_LOCK_GUARD(pThat->mutex); pThat->FireDelayed(); }); return this->SharedFromThis(); } void OnFailure(AuVoidFunc onFailure) { AU_LOCK_GUARD(this->mutex); if (this->bDone) { auto [bSendSuccess, bSendFail] = this->GetDispatch(true); if (bSendFail) { onFailure(); } } else { this->onFailure.push_back(onFailure); this->Start(); this->FireDelayed(); } } void OnSuccess(AuVoidFunc onSuccess) { AU_LOCK_GUARD(this->mutex); if (this->bDone) { auto [bSendSuccess, bSendFail] = this->GetDispatch(true); if (bSendSuccess) { onSuccess(); } } else { this->onSuccess.push_back(onSuccess); this->Start(); this->FireDelayed(); } } static AuSPtr New(bool bFailOnAny = true) { return AuMakeSharedThrow(bFailOnAny); } private: AuPair GetDispatch(bool bForce = false) { bool bSendSuccess {}; bool bSendFail {}; if ((this->bFailOnAny && bool(this->uCountOfFailed)) || (this->uCountOfFailed == this->uCount)) { bSendFail = bool(this->onFailure.size()) || bForce; } else if ((!this->bFailOnAny || !this->uCountOfFailed) && this->uCountOfComplete == this->uCount) { bSendSuccess = bool(this->onSuccess.size()) || bForce; } else if (!this->bFailOnAny && (this->uCountOfFailed == this->uCount)) { bSendFail = bool(this->onFailure.size()) || bForce; } else if (!this->bFailOnAny && ((this->uCountOfComplete + this->uCountOfFailed) == this->uCount)) { bSendSuccess = bool(this->onSuccess.size()) || bForce; } return AuMakePair(bSendSuccess, bSendFail); } void FireDelayed() { auto [bSendSuccess, bSendFail] = this->GetDispatch(false); if (!bSendSuccess && !bSendFail) { return; } if (bSendFail) { this->bFailed = true; } #if 0 if (AuExchange(this->bDone, true)) { // Miss? return; } #else if (this->bDone) { return; } this->bDone = true; #endif if (bSendSuccess) { this->pFuture->Complete(); } else if (bSendFail) { this->pFuture->Fail(); } } void Start() { #if 0 if (AuExchange(this->bDone, true)) { // Miss? return; } #else if (this->bDone) { return; } this->bDone = true; #endif this->pFuture->OnComplete([pThat = this->SharedFromThis()]() { pThat->FireSuccess(); }); this->pFuture->OnFailure([pThat = this->SharedFromThis()]() { pThat->FireFailure(); }); } void FireSuccess() { decltype(onSuccess) callbacks; { AU_LOCK_GUARD(this->mutex); callbacks = AuExchange(this->onSuccess, {}); this->onFailure.clear(); } for (const auto &callback : callbacks) { callback(); } } void FireFailure() { decltype(onSuccess) callbacks; { AU_LOCK_GUARD(this->mutex); callbacks = AuExchange(this->onFailure, {}); this->onSuccess.clear(); } for (const auto &callback : callbacks) { callback(); } } AuSharedFuture pFuture; AuList onSuccess; AuList onFailure; AuThreadPrimitives::CriticalSection mutex; AuUInt uCount {}; AuUInt uCountOfComplete {}; AuUInt uCountOfFailed {}; #if 0 bool bFailOnAny {}; bool bReady {}; bool bDone {}; bool bFailed {}; #else AuUInt8 bFailOnAny : 1{}; AuUInt8 bReady : 1 {}; AuUInt8 bDone : 1 {}; AuUInt8 bFailed : 1 {}; #endif }; using AuSharedWaterfall = AuSPtr; namespace __detail { struct FutureAccessor { template static auto &GetValue(AuFuture &future) { return future.GetValue(); } template static bool IsFinished(AuFuture &future) { return future.IsFinished(); } template static bool IsFailed(AuFuture &future) { return future.IsFailed(); } }; } #if defined(AU_LANG_CPP_17) || defined(AU_LANG_CPP_14) #if !defined(AU_HasCoRoutinedNoIncludeIfAvailable) #define AU_HasCoRoutinedNoIncludeIfAvailable #endif #endif #if defined(AU_HasCoRoutinedIncluded) #define __AUHAS_COROUTINES_CO_AWAIT #else #if !defined(AU_HasCoRoutinedNoIncludeIfAvailable) #include #endif #define __AUHAS_COROUTINES_CO_AWAIT #endif #if defined(__AUHAS_COROUTINES_CO_AWAIT) namespace std { #if !defined(AU_HasVoidCoRoutineTraitsAvailable) template<> struct coroutine_traits { struct promise_type { void get_return_object() { } void set_exception(exception_ptr const &) noexcept { } std::suspend_always initial_suspend() noexcept { return {}; } std::suspend_always final_suspend() noexcept { return {}; } void return_void() noexcept { } void unhandled_exception() { } }; }; template struct coroutine_traits { struct promise_type { void get_return_object() { } void set_exception(exception_ptr const &) noexcept { } std::suspend_always initial_suspend() noexcept { return {}; } std::suspend_always final_suspend() noexcept { return {}; } void return_void() noexcept { } void unhandled_exception() { } }; }; #endif } struct AuVoidTask { struct promise_type { AuVoidTask get_return_object() { return {}; } std::suspend_never initial_suspend() { return {}; } std::suspend_never final_suspend() noexcept { return {}; } void return_void() { } void unhandled_exception() { } }; }; #else #if defined(AURORA_IS_MODERNNT_DERIVED) #include using AuVoidTask = concurrency::task; #else struct AuVoidTask { struct promise_type { AuVoidTask get_return_object() { return {}; } bool initial_suspend() { return false; } bool final_suspend() noexcept { return false; } void return_void() { } void unhandled_exception() { } }; }; #endif #endif namespace __detail { template struct Awaitable { AuSharedFuture pFuture; Awaitable(AuSharedFuture pFuture) : pFuture(pFuture) { } bool await_ready() { return __detail::FutureAccessor::IsFinished(*pFuture.get()); } template void await_suspend(T h) { auto pFuture = this->pFuture; pFuture->OnComplete([h = h](const A &) { h.resume(); }); if (__detail::FutureAccessor::IsFinished(*pFuture.get())) { return; } if constexpr (!AuIsVoid_v) { pFuture->OnFailure([h = h](const B &) { h.resume(); }); } else { pFuture->OnFailure([h = h]() { h.resume(); }); } } AuOptionalEx await_resume() { auto &refFuture = *pFuture.get(); if (__detail::FutureAccessor::IsFailed(refFuture)) { return {}; } return __detail::FutureAccessor::GetValue(refFuture); } }; template struct AwaitableVoid { AuSharedFuture pFuture; bool await_ready() { return __detail::FutureAccessor::IsFinished(*pFuture.get()); } template void await_suspend(T h) { auto pFuture = this->pFuture; pFuture->OnComplete([h = h]() { h.resume(); }); if (__detail::FutureAccessor::IsFinished(*pFuture.get())) { return; } if constexpr (!AuIsVoid_v) { pFuture->OnFailure([h = h](const B &) { h.resume(); }); } else { pFuture->OnFailure([h = h]() { h.resume(); }); } } bool await_resume() { return !__detail::FutureAccessor::IsFailed(*pFuture.get()); } }; } #if defined(__AUHAS_COROUTINES_CO_AWAIT) template )> inline auto operator co_await (AuSharedFuture pFuture) { SysAssert(pFuture); return __detail::Awaitable { pFuture }; } template )> inline auto operator co_await (AuSharedFuture pFuture) { SysAssert(pFuture); return __detail::AwaitableVoid { pFuture }; } #endif