// AuroraRuntime/Include/Aurora/Async/AuFutures.hpp
//
// 878 lines
// 19 KiB
// C++

/***
Copyright (C) 2023 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: AuFutures.hpp
Date: 2023-05-25
Author: Reece
***/
#pragma once
// Forward declaration so AuFuture can befriend the accessor used by the coroutine
// awaitables; the full definition appears later in this header.
namespace __detail { struct FutureAccessor; }
/**
 * Single-assignment future that ends either in Complete(value) or in Fail(error).
 * T and/or Error_t may be void, in which case no value / error payload is carried.
 *
 * - At most one of Complete()/Fail() may be called (asserted by ExchangeDone()).
 * - One completion callback and one failure callback may be registered, either up
 *   front through New(...) or later through OnComplete()/OnFailure().
 * - Callbacks are delivered on the worker recorded in `pid` (the constructing worker by
 *   default). When Complete()/Fail() runs on a different worker, delivery is posted to
 *   `pid` through AuAsync::NewWorkItem.
 * - "Waterfall" callbacks (registered by AuWaterfall through AddWaterFall) run after the
 *   completion callbacks, or immediately on completion when no callbacks are registered.
 *
 * All mutable state is guarded by `mutex`.
 */
template<typename T, typename Error_t = void>
struct AuFuture : AuEnableSharedFromThis<AuFuture<T, Error_t>>
{
private:
    // Maps a type to itself and `void` to an empty placeholder so that value / error
    // storage can be declared uniformly. A partial specialization (keyed on the unused
    // second parameter) is used because explicit specializations at class scope are
    // rejected by some compilers (GCC: "explicit specialization in non-namespace scope").
    template <typename A, typename = void>
    struct CppFun
    {
        using B = A;
    };

    template <typename Unused>
    struct CppFun<void, Unused>
    {
        struct Dummy
        { };
        using B = Dummy;
    };

    using Move_t = AuConditional_t<AuIsVoid_v<T>, typename CppFun<T>::B &&, T>;
    using Move2_t = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B &&, Error_t>;
    using ErrorStore_t = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B, Error_t>;

public:
    using CompleteCallback_f = AuConditional_t<AuIsVoid_v<T>, AuVoidFunc, AuConsumer<Move_t>>;
    using ErrorCallback_f = AuConditional_t<AuIsVoid_v<Error_t>, AuVoidFunc, AuConsumer<Move2_t>>;

    AU_NO_COPY_NO_MOVE(AuFuture);

    /**
     * Registers the success callback.
     * If the future has already completed, the callback runs immediately on the calling
     * thread (receiving a copy of the stored value for non-void T), followed by any
     * pending waterfalls. Otherwise it is stored; only one may be registered, and the
     * calling worker must match `pid` (recorded here if not yet set).
     */
    void OnComplete(CompleteCallback_f callback)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bComplete)
        {
            ExchangeDoneCb();

            if constexpr (AuIsVoid_v<T>)
            {
                callback();
            }
            else
            {
                callback(this->value);
            }

            DoWaterFalls();
            return;
        }

        SysAssertDbg(!this->callback);
        this->callback = callback;

        if (!this->pid)
        {
            this->pid = AuAsync::GetCurrentWorkerPId();
        }
        else
        {
            SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
        }
    }

    /**
     * Registers the failure callback.
     * If the future has already failed, the callback runs immediately on the calling
     * thread (receiving the stored error for non-void Error_t), followed by any pending
     * waterfalls. Otherwise it is stored under the same rules as OnComplete().
     */
    void OnFailure(ErrorCallback_f onFailure)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bFailed)
        {
            ExchangeDoneCb();

            if constexpr (AuIsVoid_v<Error_t>)
            {
                onFailure();
            }
            else
            {
                onFailure(this->errorValue);
            }

            DoWaterFalls();
            return;
        }

        SysAssertDbg(!this->onFailure);

        if (!this->pid)
        {
            this->pid = AuAsync::GetCurrentWorkerPId();
        }
        else
        {
            SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
        }

        this->onFailure = onFailure;
    }

    /// Completes the future with `value` (non-void T). Asserts if already finished.
    template<typename T1 = T, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
    void Complete(Move_t value)
    {
        AU_LOCK_GUARD(this->mutex);
        ExchangeDone();
        this->value = AuMove(value);
        this->bComplete = true;
        SubmitComplete();
    }

    /// Completes the future (void T). Asserts if already finished.
    template<typename T1 = T, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
    void Complete()
    {
        AU_LOCK_GUARD(this->mutex);
        ExchangeDone();
        this->bComplete = true;
        SubmitComplete();
    }

    /// Fails the future with `error` (non-void Error_t). Asserts if already finished.
    template<typename T1 = Error_t, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
    void Fail(Move2_t error)
    {
        AU_LOCK_GUARD(this->mutex);
        ExchangeDone();
        this->errorValue = AuMove(error);
        this->bFailed = true;
        SubmitComplete();
    }

    /// Fails the future (void Error_t). Asserts if already finished.
    template<typename T1 = Error_t, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
    void Fail()
    {
        AU_LOCK_GUARD(this->mutex);
        ExchangeDone();
        this->bFailed = true;
        SubmitComplete();
    }

    /// Creates a future with no callbacks registered.
    static AuSPtr<AuFuture<T, Error_t>> New()
    {
        AU_DEBUG_MEMCRUNCH;
        return AuSPtr<AuFuture<T, Error_t>>(new AuFuture(), AuDefaultDeleter<AuFuture<T, Error_t>> {});
    }

    /// Creates a future with a success callback pre-registered.
    /// (CompleteCallback_f is identical to AuConsumer<Move_t> for non-void T; for void T
    /// it is AuVoidFunc, which makes this overload usable for void futures as well.)
    static AuSPtr<AuFuture<T, Error_t>> New(CompleteCallback_f callback)
    {
        AU_DEBUG_MEMCRUNCH;
        return AuSPtr<AuFuture<T, Error_t>>(new AuFuture(callback), AuDefaultDeleter<AuFuture<T, Error_t>> {});
    }

    /// Creates a future with both success and failure callbacks pre-registered.
    static AuSPtr<AuFuture<T, Error_t>> New(CompleteCallback_f callback, ErrorCallback_f onFailure)
    {
        AU_DEBUG_MEMCRUNCH;
        return AuSPtr<AuFuture<T, Error_t>>(new AuFuture(callback, onFailure), AuDefaultDeleter<AuFuture<T, Error_t>> {});
    }

protected:
    friend struct __detail::FutureAccessor;

    // Stored value (placeholder object for void T). Read without locking.
    typename CppFun<T>::B &GetValue()
    {
        return value;
    }

    // True once Complete() or Fail() has been called. Read without locking.
    bool IsFinished()
    {
        return this->bComplete || this->bFailed;
    }

    // True once Fail() has been called. Read without locking.
    bool IsFailed()
    {
        return this->bFailed;
    }

private:
    // Marks the future as finished; asserts on a second Complete()/Fail().
    void ExchangeDone()
    {
        SysAssert(!this->bDone, "Future has already finished");
        this->bDone = true;
    }

    // Marks the completion callbacks as dispatched; asserts on double dispatch.
    void ExchangeDoneCb()
    {
        SysAssert(!this->bDoneCb, "Future has already called a completion callback");
        this->bDoneCb = true;
    }

    // Delivers the completion outcome. Called with `mutex` held.
    // On the `pid` worker the matching callback runs inline, followed by the waterfalls.
    // On any other worker, if callbacks exist, delivery is re-posted to `pid`; if no
    // callbacks exist, only the waterfalls run (on the current thread).
    void SubmitComplete()
    {
        if (AuAsync::GetCurrentWorkerPId() == this->pid)
        {
            if (!this->onFailure && !this->callback)
            {
                DoWaterFalls();
                return;
            }

            ExchangeDoneCb();

            if (this->bComplete)
            {
                if (auto callback = AuExchange(this->callback, {}))
                {
                    if constexpr (AuIsVoid_v<T>)
                    {
                        callback();
                    }
                    else
                    {
                        callback(this->value);
                    }
                }
            }
            else if (this->bFailed)
            {
                if (auto callback = AuExchange(this->onFailure, {}))
                {
                    if constexpr (AuIsVoid_v<Error_t>)
                    {
                        callback();
                    }
                    else
                    {
                        callback(this->errorValue);
                    }
                }
            }

            DoWaterFalls();
        }
        else
        {
            if (!this->onFailure && !this->callback)
            {
                DoWaterFalls();
                return;
            }

            // The work item holds a strong reference so the future outlives the hop.
            AuAsync::NewWorkItem(this->pid.value(), AuMakeSharedPanic<AuAsync::BasicWorkStdFunc>([pThat = this->SharedFromThis()]
            {
                AU_LOCK_GUARD(pThat->mutex);
                pThat->SubmitComplete();
            }))->Dispatch();
        }
    }

    // Runs (and removes) every queued waterfall with (bComplete, bFailed), and records
    // that the waterfall list has been drained so later AddWaterFall() calls fire at once.
    void DoWaterFalls()
    {
        this->bWaterfallsRan = true;

        auto callbacks = AuExchange(this->waterfall, {});
        for (const auto &callback : callbacks)
        {
            callback(this->bComplete, this->bFailed);
        }
    }

    // Registers a waterfall callback (used by AuWaterfall).
    void AddWaterFall(AuConsumer<bool, bool> callback)
    {
        AU_LOCK_GUARD(this->mutex);

        // The waterfall list has already been drained if the completion callbacks were
        // dispatched (bDoneCb) OR the future finished with no callbacks registered. In the
        // latter case SubmitComplete() drains the list without ever setting bDoneCb, so
        // checking bDoneCb alone would park this callback forever.
        if (this->bDoneCb || this->bWaterfallsRan)
        {
            callback(this->bComplete, this->bFailed);
            return;
        }

        AuDebug::AddMemoryCrunch();
        SysAssert(AuTryInsert(this->waterfall, callback));
        AuDebug::DecMemoryCrunch();
    }

    AuFuture()
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    AuFuture(CompleteCallback_f callback) :
        callback(callback)
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    AuFuture(CompleteCallback_f callback, ErrorCallback_f onFailure) :
        callback(callback),
        onFailure(onFailure)
    {
        this->pid = AuAsync::GetCurrentWorkerPId();
    }

    typename CppFun<T>::B value;
    ErrorStore_t errorValue;
    AuThreadPrimitives::Mutex mutex;
    CompleteCallback_f callback;
    ErrorCallback_f onFailure;
    AuOptionalEx<AuAsync::WorkerPId_t> pid; // todo: make weak?
    AuList<AuConsumer<bool, bool>> waterfall;
    AuUInt8 bComplete : 1 {};
    AuUInt8 bFailed : 1 {};
    AuUInt8 bDone : 1 {};
    AuUInt8 bDoneCb : 1 {};
    AuUInt8 bWaterfallsRan : 1 {}; // set once DoWaterFalls() has drained `waterfall`

    friend struct AuWaterfall;
};
// Shared-pointer handle to an AuFuture; AuFuture<...>::New() hands futures out in
// exactly this form.
template <class T, class Error_t = void>
using AuSharedFuture = AuSPtr<AuFuture<T, Error_t>>;
/**
 * Aggregates several AuFutures into a single success / failure outcome.
 *
 * Usage: create via New(), AddFuture() every input, then register OnSuccess() and/or
 * OnFailure() handlers. The first handler registration starts the waterfall
 * (bReady); adding futures afterwards is asserted against.
 *
 *  bFailOnAny == true : any failed input fails the waterfall.
 *  bFailOnAny == false: the waterfall fails only if every input failed; otherwise it
 *                       succeeds once every input has finished.
 *
 * The outcome is forwarded through an internal AuFuture<void> (pFuture), whose
 * callbacks run the registered handlers.
 */
struct AuWaterfall : AuEnableSharedFromThis<AuWaterfall>
{
    AU_NO_COPY_NO_MOVE(AuWaterfall);

    AuWaterfall(bool bFailOnAny = true) :
        bFailOnAny(bFailOnAny)
    {
        this->pFuture = AuFuture<void>::New();
    }

    /**
     * Adds an input future. Must be called before any OnSuccess()/OnFailure().
     * Accepts futures of any value and error type.
     * @return this waterfall, for chaining.
     */
    template<typename T, typename Error_t = void>
    AuSPtr<AuWaterfall> AddFuture(AuSharedFuture<T, Error_t> future)
    {
        {
            AU_LOCK_GUARD(this->mutex);
            SysAssert(!this->bReady);
            this->uCount++;
        }

        // Registered outside our own lock: the future invokes waterfalls while holding
        // its own mutex, and the callback below takes ours, so holding ours here would
        // invert the lock order.
        future->AddWaterFall([pThat = this->SharedFromThis()](bool bSuccess, bool bFailed)
        {
            // Counters are shared with GetDispatch(); update them under the lock.
            AU_LOCK_GUARD(pThat->mutex);

            if (bSuccess)
            {
                ++pThat->uCountOfComplete;
            }
            else if (bFailed)
            {
                ++pThat->uCountOfFailed;
            }

            // Until a handler has been registered (Start()), just accumulate.
            if (!pThat->bReady)
            {
                return;
            }

            pThat->FireDelayed();
        });

        return this->SharedFromThis();
    }

    /// Registers a failure handler. If the outcome is already decided, runs it
    /// immediately when the waterfall failed; otherwise queues it and starts the waterfall.
    void OnFailure(AuVoidFunc onFailure)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);
            if (bSendFail)
            {
                onFailure();
            }
        }
        else
        {
            this->onFailure.push_back(onFailure);
            this->Start();
            this->FireDelayed();
        }
    }

    /// Registers a success handler. If the outcome is already decided, runs it
    /// immediately when the waterfall succeeded; otherwise queues it and starts the waterfall.
    void OnSuccess(AuVoidFunc onSuccess)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);
            if (bSendSuccess)
            {
                onSuccess();
            }
        }
        else
        {
            this->onSuccess.push_back(onSuccess);
            this->Start();
            this->FireDelayed();
        }
    }

    static AuSPtr<AuWaterfall> New(bool bFailOnAny = true)
    {
        return AuMakeSharedThrow<AuWaterfall>(bFailOnAny);
    }

private:
    // Decides, from the input counters, whether success or failure should be reported.
    // Without bForce, a side is only reported when handlers for it are queued.
    AuPair<bool, bool> GetDispatch(bool bForce = false)
    {
        bool bSendSuccess {};
        bool bSendFail {};

        // "Every input failed" only holds when there is at least one input; an empty
        // waterfall has no failures and falls through to the success branch below.
        bool bAllFailed = this->uCount && (this->uCountOfFailed == this->uCount);

        if ((this->bFailOnAny && bool(this->uCountOfFailed)) || bAllFailed)
        {
            bSendFail = bool(this->onFailure.size()) || bForce;
        }
        else if ((!this->bFailOnAny || !this->uCountOfFailed) &&
                 this->uCountOfComplete == this->uCount)
        {
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }
        else if (!this->bFailOnAny && ((this->uCountOfComplete + this->uCountOfFailed) == this->uCount))
        {
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }

        return AuMakePair(bSendSuccess, bSendFail);
    }

    // Completes or fails pFuture once the outcome is decided; latches bDone so this
    // happens only once. Called with `mutex` held.
    void FireDelayed()
    {
        auto [bSendSuccess, bSendFail] = this->GetDispatch(false);
        if (!bSendSuccess && !bSendFail)
        {
            return;
        }

        if (bSendFail)
        {
            this->bFailed = true;
        }

        if (this->bDone)
        {
            return;
        }
        this->bDone = true;

        if (bSendSuccess)
        {
            this->pFuture->Complete();
        }
        else if (bSendFail)
        {
            this->pFuture->Fail();
        }
    }

    // Marks the waterfall as ready (no more AddFuture()) and hooks pFuture's outcome to
    // the handler lists. Idempotent. This latches bReady, not bDone: bDone is reserved
    // for FireDelayed(), and latching it here would make FireDelayed() bail out forever.
    void Start()
    {
        if (this->bReady)
        {
            return;
        }
        this->bReady = true;

        this->pFuture->OnComplete([pThat = this->SharedFromThis()]()
        {
            pThat->FireSuccess();
        });

        this->pFuture->OnFailure([pThat = this->SharedFromThis()]()
        {
            pThat->FireFailure();
        });
    }

    // Runs the queued success handlers (outside the lock) and drops the failure handlers.
    void FireSuccess()
    {
        decltype(onSuccess) callbacks;
        {
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onSuccess, {});
            this->onFailure.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    // Runs the queued failure handlers (outside the lock) and drops the success handlers.
    void FireFailure()
    {
        decltype(onSuccess) callbacks;
        {
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onFailure, {});
            this->onSuccess.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    AuSharedFuture<void> pFuture;
    AuList<AuVoidFunc> onSuccess;
    AuList<AuVoidFunc> onFailure;
    AuThreadPrimitives::CriticalSection mutex;
    AuUInt uCount {};
    AuUInt uCountOfComplete {};
    AuUInt uCountOfFailed {};
    AuUInt8 bFailOnAny : 1 {};
    AuUInt8 bReady : 1 {};   // a handler was registered; inputs may now trigger FireDelayed()
    AuUInt8 bDone : 1 {};    // outcome forwarded to pFuture
    AuUInt8 bFailed : 1 {};
};
using AuSharedWaterfall = AuSPtr<AuWaterfall>;
namespace __detail
{
struct FutureAccessor
{
template<typename A, typename B>
static auto &GetValue(AuFuture<A, B> &future)
{
return future.GetValue();
}
template<typename A, typename B>
static bool IsFinished(AuFuture<A, B> &future)
{
return future.IsFinished();
}
template<typename A, typename B>
static bool IsFailed(AuFuture<A, B> &future)
{
return future.IsFailed();
}
};
}
// Coroutine support configuration:
//  - C++14/17 builds opt out of including <coroutine> automatically by defining
//    AU_HasCoRoutinedNoIncludeIfAvailable (unless it is already defined).
//  - Otherwise <coroutine> is included here, unless AU_HasCoRoutinedIncluded reports
//    that the including code has already done so.
// NOTE(review): __AUHAS_COROUTINES_CO_AWAIT ends up defined on every path, so the
// `#else` fallback guarded by it below is never compiled. On C++14/17 builds the
// std::suspend_* types used below must therefore come from an include made elsewhere.
// Confirm whether the define was meant to depend on <coroutine> being available.
#if (defined(AU_LANG_CPP_17) || defined(AU_LANG_CPP_14)) && !defined(AU_HasCoRoutinedNoIncludeIfAvailable)
#define AU_HasCoRoutinedNoIncludeIfAvailable
#endif

#if !defined(AU_HasCoRoutinedIncluded) && !defined(AU_HasCoRoutinedNoIncludeIfAvailable)
#include <coroutine>
#endif

#define __AUHAS_COROUTINES_CO_AWAIT
#if defined(__AUHAS_COROUTINES_CO_AWAIT)
namespace std
{
#if !defined(AU_HasVoidCoRoutineTraitsAvailable)
template<>
struct coroutine_traits<void>
{
struct promise_type
{
void get_return_object()
{ }
void set_exception(exception_ptr const &) noexcept
{ }
std::suspend_always initial_suspend() noexcept
{
return {};
}
std::suspend_always final_suspend() noexcept
{
return {};
}
void return_void() noexcept
{ }
void unhandled_exception()
{ }
};
};
template<class... T>
struct coroutine_traits<void, T...>
{
struct promise_type
{
void get_return_object()
{ }
void set_exception(exception_ptr const &) noexcept
{ }
std::suspend_always initial_suspend() noexcept
{
return {};
}
std::suspend_always final_suspend() noexcept
{
return {};
}
void return_void() noexcept
{ }
void unhandled_exception()
{ }
};
};
#endif
}
// Stateless return type for fire-and-forget coroutines: the body starts running
// immediately and the frame is released when it finishes (both suspend hooks return
// suspend_never).
struct AuVoidTask
{
    struct promise_type
    {
        // Begin executing the body right away.
        std::suspend_never initial_suspend()
        {
            return std::suspend_never {};
        }

        // Do not park at the end; the frame is destroyed on completion.
        std::suspend_never final_suspend() noexcept
        {
            return std::suspend_never {};
        }

        // The task object carries no state.
        AuVoidTask get_return_object()
        {
            return AuVoidTask {};
        }

        void return_void()
        {
        }

        // Exceptions escaping the body are swallowed.
        void unhandled_exception()
        {
        }
    };
};
#else
#if !defined(AURORA_IS_MODERNNT_DERIVED)
// Fallback fire-and-forget task type. The suspend hooks return bool (false = do not
// suspend); presumably this targets pre-standard coroutine implementations that accept
// bool-returning hooks. NOTE(review): confirm against the toolchains this path serves.
struct AuVoidTask
{
    struct promise_type
    {
        // Start the body immediately.
        bool initial_suspend()
        {
            return false;
        }

        // Do not park at the end.
        bool final_suspend() noexcept
        {
            return false;
        }

        // The task object carries no state.
        AuVoidTask get_return_object()
        {
            return AuVoidTask {};
        }

        void return_void()
        {
        }

        // Exceptions escaping the body are swallowed.
        void unhandled_exception()
        {
        }
    };
};
#else
// NT-derived builds reuse the PPL task type.
#include <pplawait.h>
using AuVoidTask = concurrency::task<void>;
#endif
#endif
namespace __detail
{
    /**
     * Awaiter produced by `co_await` on a value-carrying AuSharedFuture<A, B>.
     * The awaiting coroutine is resumed when the future completes or fails.
     * The co_await expression yields the value on success and an empty
     * AuOptionalEx<A> on failure.
     */
    template <typename A, typename B>
    struct Awaitable
    {
        AuSharedFuture<A, B> pFuture;

        Awaitable(AuSharedFuture<A, B> pFuture) :
            pFuture(pFuture)
        {
        }

        // No suspension is needed if the future has already finished.
        bool await_ready()
        {
            return __detail::FutureAccessor::IsFinished(*pFuture.get());
        }

        template <typename T>
        void await_suspend(T h)
        {
            // Local strong reference: h.resume() may run the coroutine to completion and
            // destroy the frame that owns *this before this function returns.
            auto pFuture = this->pFuture;

            pFuture->OnComplete([h = h](const A &)
            {
                h.resume();
            });

            // If the future has succeeded, the completion callback has either already run
            // (resuming us) or been posted to the future's worker, so there is nothing
            // left to do. If it FAILED since await_ready(), the completion callback will
            // never fire, so the failure path below must still be registered or the
            // coroutine would never be resumed.
            // NOTE(review): the two registrations are separate locked calls; a failure
            // posted from another worker between them can still race.
            if (__detail::FutureAccessor::IsFinished(*pFuture.get()) &&
                !__detail::FutureAccessor::IsFailed(*pFuture.get()))
            {
                return;
            }

            if constexpr (!AuIsVoid_v<B>)
            {
                pFuture->OnFailure([h = h](const B &)
                {
                    h.resume();
                });
            }
            else
            {
                pFuture->OnFailure([h = h]()
                {
                    h.resume();
                });
            }
        }

        // Copy of the value on success; empty on failure.
        AuOptionalEx<A> await_resume()
        {
            auto &refFuture = *pFuture.get();
            if (__detail::FutureAccessor::IsFailed(refFuture))
            {
                return {};
            }
            return __detail::FutureAccessor::GetValue(refFuture);
        }
    };

    /**
     * Awaiter produced by `co_await` on an AuSharedFuture<void, B>.
     * The co_await expression yields true on success and false on failure.
     */
    template <typename B>
    struct AwaitableVoid
    {
        AuSharedFuture<void, B> pFuture;

        // No suspension is needed if the future has already finished.
        bool await_ready()
        {
            return __detail::FutureAccessor::IsFinished(*pFuture.get());
        }

        template <typename T>
        void await_suspend(T h)
        {
            // Local strong reference: resuming may destroy the frame that owns *this.
            auto pFuture = this->pFuture;

            pFuture->OnComplete([h = h]()
            {
                h.resume();
            });

            // Only a successful future guarantees the completion callback has run or is
            // posted; a failed one still needs the failure callback (see Awaitable).
            if (__detail::FutureAccessor::IsFinished(*pFuture.get()) &&
                !__detail::FutureAccessor::IsFailed(*pFuture.get()))
            {
                return;
            }

            if constexpr (!AuIsVoid_v<B>)
            {
                pFuture->OnFailure([h = h](const B &)
                {
                    h.resume();
                });
            }
            else
            {
                pFuture->OnFailure([h = h]()
                {
                    h.resume();
                });
            }
        }

        // true on success, false on failure.
        bool await_resume()
        {
            return !__detail::FutureAccessor::IsFailed(*pFuture.get());
        }
    };
}
#if defined(__AUHAS_COROUTINES_CO_AWAIT)
// `co_await` on a future without a value: the expression yields true on success and
// false on failure (via __detail::AwaitableVoid). Asserts on a null future.
template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(AuIsVoid_v<A>)>
inline auto operator co_await (AuSharedFuture<A, B> pFuture)
{
    SysAssert(pFuture);
    return __detail::AwaitableVoid<B> { pFuture };
}

// `co_await` on a value-carrying future: the expression yields AuOptionalEx<A>, empty
// when the future failed (via __detail::Awaitable). Asserts on a null future.
template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(!AuIsVoid_v<A>)>
inline auto operator co_await (AuSharedFuture<A, B> pFuture)
{
    SysAssert(pFuture);
    return __detail::Awaitable<A, B>(pFuture);
}
#endif