511 lines
12 KiB
C++
511 lines
12 KiB
C++
/***
|
|
Copyright (C) 2023 J Reece Wilson (a/k/a "Reece"). All rights reserved.
|
|
|
|
File: AuFutures.hpp
|
|
Date: 2023-05-25
|
|
Author: Reece
|
|
***/
|
|
#pragma once
|
|
|
|
template<typename T, typename Error_t = void>
|
|
struct AuFuture : AuEnableSharedFromThis<AuFuture<T, Error_t>>
|
|
{
|
|
private:
|
|
template <typename A>
|
|
struct CppFun
|
|
{
|
|
using B = A;
|
|
};
|
|
|
|
template <>
|
|
struct CppFun<void>
|
|
{
|
|
struct Dummy
|
|
{ };
|
|
using B = Dummy;
|
|
};
|
|
|
|
|
|
using Move_t = AuConditional_t<AuIsVoid_v<T>, typename CppFun<T>::B &&, T>;
|
|
using Move2_t = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B &&, Error_t>;
|
|
using ErrorStore_t = AuConditional_t<AuIsVoid_v<Error_t>, typename CppFun<Error_t>::B, Error_t>;
|
|
public:
|
|
|
|
using CompleteCallback_f = AuConditional_t<AuIsVoid_v<T>, AuVoidFunc, AuConsumer<Move_t>>;
|
|
using ErrorCallback_f = AuConditional_t<AuIsVoid_v<Error_t>, AuVoidFunc, AuConsumer<Move2_t>>;
|
|
|
|
AU_NO_COPY_NO_MOVE(AuFuture);
|
|
|
|
void OnComplete(CompleteCallback_f callback)
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
|
|
if (this->bComplete)
|
|
{
|
|
SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");
|
|
|
|
if constexpr (AuIsVoid_v<T>)
|
|
{
|
|
callback();
|
|
}
|
|
else
|
|
{
|
|
callback(this->value);
|
|
}
|
|
|
|
DoWaterFalls();
|
|
return;
|
|
}
|
|
|
|
SysAssertDbg(!this->callback);
|
|
this->callback = callback;
|
|
|
|
if (!this->pid)
|
|
{
|
|
this->pid = AuAsync::GetCurrentWorkerPId();
|
|
}
|
|
else
|
|
{
|
|
SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
|
|
}
|
|
}
|
|
|
|
void OnFailure(ErrorCallback_f onFailure)
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
|
|
if (this->bFailed)
|
|
{
|
|
SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");
|
|
|
|
if constexpr (AuIsVoid_v<Error_t>)
|
|
{
|
|
onFailure();
|
|
}
|
|
else
|
|
{
|
|
onFailure(this->errorValue);
|
|
}
|
|
|
|
DoWaterFalls();
|
|
return;
|
|
}
|
|
|
|
SysAssertDbg(!this->onFailure);
|
|
if (!this->pid)
|
|
{
|
|
this->pid = AuAsync::GetCurrentWorkerPId();
|
|
}
|
|
else
|
|
{
|
|
SysAssert(this->pid == AuAsync::GetCurrentWorkerPId());
|
|
}
|
|
|
|
this->onFailure = onFailure;
|
|
}
|
|
|
|
template<typename T1 = T, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
|
|
void Complete(Move_t value)
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
|
|
|
|
this->value = AuMove(value);
|
|
this->bComplete = true;
|
|
|
|
SubmitComplete();
|
|
}
|
|
|
|
template<typename T1 = T, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
|
|
void Complete()
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
|
|
|
|
this->bComplete = true;
|
|
|
|
SubmitComplete();
|
|
}
|
|
|
|
template<typename T1 = Error_t, AuEnableIf_t<!AuIsVoid_v<T1>> * = nullptr>
|
|
void Fail(Move2_t error)
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
|
|
|
|
this->errorValue = AuMove(error);
|
|
this->bFailed = true;
|
|
|
|
SubmitComplete();
|
|
}
|
|
|
|
template<typename T1 = Error_t, AuEnableIf_t<AuIsVoid_v<T1>> * = nullptr>
|
|
void Fail()
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
SysAssert(!AuExchange(this->bDone, true), "Future has already finished");
|
|
|
|
this->bFailed = true;
|
|
|
|
SubmitComplete();
|
|
}
|
|
|
|
static AuSPtr<AuFuture<T, Error_t>> New()
|
|
{
|
|
AuDebug::AddMemoryCrunch();
|
|
auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(), AuDefaultDeleter<AuFuture<T, Error_t>> {});
|
|
AuDebug::DecMemoryCrunch();
|
|
return pRet;
|
|
}
|
|
|
|
static AuSPtr<AuFuture<T, Error_t>> New(AuConsumer<Move_t> callback)
|
|
{
|
|
AuDebug::AddMemoryCrunch();
|
|
auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(callback), AuDefaultDeleter<AuFuture<T, Error_t>> {});
|
|
AuDebug::DecMemoryCrunch();
|
|
return pRet;
|
|
}
|
|
|
|
static AuSPtr<AuFuture<T, Error_t>> New(AuConsumer<Move_t> callback, ErrorCallback_f onFailure)
|
|
{
|
|
AuDebug::AddMemoryCrunch();
|
|
auto pRet = AuSPtr<AuFuture<T, Error_t>>(new AuFuture(callback, onFailure), AuDefaultDeleter<AuFuture<T, Error_t>> {});
|
|
AuDebug::DecMemoryCrunch();
|
|
return pRet;
|
|
}
|
|
|
|
private:
|
|
|
|
void SubmitComplete()
|
|
{
|
|
if (AuAsync::GetCurrentWorkerPId() == this->pid)
|
|
{
|
|
if (!this->onFailure && !this->callback)
|
|
{
|
|
DoWaterFalls();
|
|
return;
|
|
}
|
|
|
|
SysAssert(!AuExchange(this->bDoneCb, true), "Future has already called a completion callback");
|
|
|
|
if (this->bComplete)
|
|
{
|
|
if (auto callback = AuExchange(this->callback, {}))
|
|
{
|
|
if constexpr (AuIsVoid_v<T>)
|
|
{
|
|
callback();
|
|
}
|
|
else
|
|
{
|
|
callback(this->value);
|
|
}
|
|
}
|
|
}
|
|
else if (this->bFailed)
|
|
{
|
|
if (auto callback = AuExchange(this->onFailure, {}))
|
|
{
|
|
if constexpr (AuIsVoid_v<Error_t>)
|
|
{
|
|
callback();
|
|
}
|
|
else
|
|
{
|
|
callback(this->errorValue);
|
|
}
|
|
}
|
|
}
|
|
|
|
DoWaterFalls();
|
|
}
|
|
else
|
|
{
|
|
if (!this->onFailure && !this->callback)
|
|
{
|
|
DoWaterFalls();
|
|
return;
|
|
}
|
|
|
|
AuDebug::AddMemoryCrunch();
|
|
AuAsync::NewWorkItem(this->pid.value(), AuMakeSharedPanic<AuAsync::BasicWorkStdFunc>([pThat = this->SharedFromThis()]
|
|
{
|
|
AU_LOCK_GUARD(pThat->mutex);
|
|
pThat->SubmitComplete();
|
|
}))->Dispatch();
|
|
AuDebug::DecMemoryCrunch();
|
|
}
|
|
}
|
|
|
|
void DoWaterFalls()
|
|
{
|
|
auto callbacks = AuExchange(this->waterfall, {});
|
|
|
|
for (const auto &callback : callbacks)
|
|
{
|
|
callback(this->bComplete, this->bFailed);
|
|
}
|
|
}
|
|
|
|
void AddWaterFall(AuConsumer<bool, bool> callback)
|
|
{
|
|
AU_LOCK_GUARD(this->mutex);
|
|
|
|
if (this->bDoneCb)
|
|
{
|
|
callback(this->bComplete, this->bFailed);
|
|
return;
|
|
}
|
|
|
|
AuDebug::AddMemoryCrunch();
|
|
SysAssert(AuTryInsert(this->waterfall, callback));
|
|
AuDebug::DecMemoryCrunch();
|
|
}
|
|
|
|
AuFuture()
|
|
{
|
|
this->pid = AuAsync::GetCurrentWorkerPId();
|
|
}
|
|
|
|
AuFuture(AuConsumer<Move_t> callback) : callback(callback)
|
|
{
|
|
this->pid = AuAsync::GetCurrentWorkerPId();
|
|
}
|
|
|
|
AuFuture(AuConsumer<Move_t> callback, ErrorCallback_f onFailure) : callback(callback), onFailure(onFailure)
|
|
{
|
|
this->pid = AuAsync::GetCurrentWorkerPId();
|
|
}
|
|
|
|
CppFun<T>::B value;
|
|
ErrorStore_t errorValue;
|
|
AuThreadPrimitives::Mutex mutex;
|
|
CompleteCallback_f callback;
|
|
ErrorCallback_f onFailure;
|
|
AuOptionalEx<AuAsync::WorkerPId_t> pid; // todo: make weak?
|
|
bool bComplete {};
|
|
bool bFailed {};
|
|
bool bDone {};
|
|
bool bDoneCb {};
|
|
AuList<AuConsumer<bool, bool>> waterfall;
|
|
|
|
friend struct AuWaterfall;
|
|
};
|
|
|
|
template<typename T, typename Error_t = void>
|
|
using AuSharedFuture = AuSPtr<AuFuture<T, Error_t>>;
|
|
|
|
// Aggregates several AuFutures into a single outcome.
// Futures are added with AddFuture(). The first OnSuccess()/OnFailure() call arms the waterfall
// (Start()), after which no more futures may be added. An outcome is dispatched only once a handler
// for that outcome has been attached.
//   bFailOnAny == true : fail as soon as any future fails; succeed once every future has completed.
//   bFailOnAny == false: fail only if every future failed; otherwise succeed once all have settled.
// NOTE(review): with zero futures added, GetDispatch() reports failure (uCountOfFailed == uCount == 0)
// -- confirm this is intended.
struct AuWaterfall : AuEnableSharedFromThis<AuWaterfall>
{
    AU_NO_COPY_NO_MOVE(AuWaterfall);

    AuWaterfall(bool bFailOnAny = true) :
        bFailOnAny(bFailOnAny)
    {
        this->pFuture = AuFuture<void>::New();
    }

    // Adds a future to the set; asserts if the waterfall has already been armed.
    // Accepts futures with any error type; only success/failure is observed.
    // Returns this waterfall to allow chaining.
    template<typename T, typename Error_t = void>
    AuSPtr<AuWaterfall> AddFuture(AuSharedFuture<T, Error_t> future)
    {
        {
            AU_LOCK_GUARD(this->mutex);
            SysAssert(!this->bReady);
            this->uCount++;
        }

        // Registered without holding this->mutex: the listener below takes this->mutex while the
        // future's own mutex is held, so holding ours here would invert the lock order.
        future->AddWaterFall([pThat = this->SharedFromThis()](bool bSuccess, bool bFailed)
        {
            // Futures may settle concurrently on different threads, so the counters must only
            // be touched under the lock.
            AU_LOCK_GUARD(pThat->mutex);

            if (bSuccess)
            {
                ++pThat->uCountOfComplete;
            }
            else if (bFailed)
            {
                ++pThat->uCountOfFailed;
            }

            if (pThat->bReady)
            {
                pThat->FireDelayed();
            }
        });

        return this->SharedFromThis();
    }

    // Attaches a failure handler. If the waterfall has already resolved, it runs immediately,
    // but only if the outcome was a failure. Otherwise it is queued, the waterfall is armed,
    // and the outcome is re-evaluated.
    void OnFailure(AuVoidFunc onFailure)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);

            if (bSendFail)
            {
                onFailure();
            }
        }
        else
        {
            this->onFailure.push_back(onFailure);
            this->Start();
            this->FireDelayed();
        }
    }

    // Attaches a success handler. If the waterfall has already resolved, it runs immediately,
    // but only if the outcome was a success. Otherwise it is queued, the waterfall is armed,
    // and the outcome is re-evaluated.
    void OnSuccess(AuVoidFunc onSuccess)
    {
        AU_LOCK_GUARD(this->mutex);

        if (this->bDone)
        {
            auto [bSendSuccess, bSendFail] = this->GetDispatch(true);

            if (bSendSuccess)
            {
                onSuccess();
            }
        }
        else
        {
            this->onSuccess.push_back(onSuccess);
            this->Start();
            this->FireDelayed();
        }
    }

    static AuSPtr<AuWaterfall> New(bool bFailOnAny = true)
    {
        return AuMakeSharedThrow<AuWaterfall>(bFailOnAny);
    }

private:

    // Evaluates the counters and returns (sendSuccess, sendFail).
    // Without bForce, an outcome is only reported if a handler for it has been queued.
    // Caller must hold `mutex`.
    AuPair<bool, bool> GetDispatch(bool bForce = false)
    {
        bool bSendSuccess {};
        bool bSendFail {};

        if ((this->bFailOnAny && bool(this->uCountOfFailed)) ||
            (this->uCountOfFailed == this->uCount))
        {
            bSendFail = bool(this->onFailure.size()) || bForce;
        }
        else if ((!this->bFailOnAny || !this->uCountOfFailed) &&
                 this->uCountOfComplete == this->uCount)
        {
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }
        else if (!this->bFailOnAny && ((this->uCountOfComplete + this->uCountOfFailed) == this->uCount))
        {
            bSendSuccess = bool(this->onSuccess.size()) || bForce;
        }

        return AuMakePair(bSendSuccess, bSendFail);
    }

    // Resolves the internal future once (guarded by bDone) if an outcome is ready to dispatch.
    // Caller must hold `mutex`.
    void FireDelayed()
    {
        auto [bSendSuccess, bSendFail] = this->GetDispatch(false);

        if (!bSendSuccess && !bSendFail)
        {
            return;
        }

        if (bSendFail)
        {
            this->bFailed = true;
        }

        if (AuExchange(this->bDone, true))
        {
            // Miss?
            return;
        }

        if (bSendSuccess)
        {
            this->pFuture->Complete();
        }
        else if (bSendFail)
        {
            this->pFuture->Fail();
        }
    }

    // Arms the waterfall exactly once by hooking the internal future's outcome to the handler lists.
    void Start()
    {
        if (AuExchange(this->bReady, true))
        {
            return;
        }

        this->pFuture->OnComplete([pThat = this->SharedFromThis()]()
        {
            pThat->FireSuccess();
        });

        this->pFuture->OnFailure([pThat = this->SharedFromThis()]()
        {
            pThat->FireFailure();
        });
    }

    // Runs and drains the success handlers outside the lock, discarding the failure handlers.
    void FireSuccess()
    {
        decltype(onSuccess) callbacks;

        {
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onSuccess, {});
            this->onFailure.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    // Runs and drains the failure handlers outside the lock, discarding the success handlers.
    void FireFailure()
    {
        decltype(onFailure) callbacks;

        {
            AU_LOCK_GUARD(this->mutex);
            callbacks = AuExchange(this->onFailure, {});
            this->onSuccess.clear();
        }

        for (const auto &callback : callbacks)
        {
            callback();
        }
    }

    AuSharedFuture<void> pFuture;

    AuList<AuVoidFunc> onSuccess;
    AuList<AuVoidFunc> onFailure;

    AuThreadPrimitives::CriticalSection mutex;
    bool bFailOnAny;
    AuUInt uCount {};           // futures added
    AuUInt uCountOfComplete {}; // futures that completed successfully
    AuUInt uCountOfFailed {};   // futures that failed
    bool bReady {};             // armed by Start(); no more futures may be added
    bool bDone {};              // internal future has been resolved
    bool bFailed {};
};
|
|
|
|
// Alias for the AuWaterfall struct itself.
// NOTE(review): unlike AuSharedFuture (an AuSPtr<...> alias), this names the struct rather than
// AuSPtr<AuWaterfall> -- confirm intended before relying on the "Shared" naming.
using AuSharedWaterfall = AuWaterfall;
|