AuroraRuntime/Source/Async/ThreadPool.cpp

1643 lines
44 KiB
C++

/***
Copyright (C) 2021 J Reece Wilson (a/k/a "Reece"). All rights reserved.
File: ThreadPool.cpp
Date: 2021-10-30
Author: Reece
***/
#include <Source/RuntimeInternal.hpp>
#include "Async.hpp"
#include "ThreadPool.hpp"
#include "AsyncApp.hpp"
#include "WorkItem.hpp"
#include "AuSchedular.hpp"
#include "ThreadWorkerQueueShim.hpp"
#include "IAsyncRunnable.hpp"
#include "AuAsyncFuncRunnable.hpp"
#include "AuAsyncFuncWorker.hpp"
namespace Aurora::Async
{
//STATIC_TLS(WorkerId_t, tlsWorkerId);
// Thread-local weak reference to the ThreadPool the calling thread is currently bound to.
// Assigned by Spawn (when adopting the calling thread), by Entrypoint, and temporarily by the
// ASYNC_THREADGROUP_TLS_* macros; read by GetCurrentWorkerPId and the GetThreadState* accessors.
static thread_local AuWPtr<ThreadPool> tlsCurrentThreadPool;
// Downcasts an IThreadPool handle to the concrete ThreadPool.
// The global async app is special-cased: it is wrapped via AuUnsafeRaiiToShared
// (presumably a non-owning shared pointer) instead of being pointer-cast.
inline auto GetWorkerInternal(const AuSPtr<IThreadPool> &pool)
{
    const bool bIsGlobalApp = pool.get() == AuStaticCast<IAsyncApp>(gAsyncApp);
    if (bIsGlobalApp)
    {
        return AuUnsafeRaiiToShared(AuStaticCast<ThreadPool>(gAsyncApp));
    }
    return AuStaticPointerCast<ThreadPool>(pool);
}
// Returns the pool-qualified id of the calling worker thread, or an empty id when the
// thread is not bound to a live pool (or the id's owning pool has since expired).
AUKN_SYM WorkerPId_t GetCurrentWorkerPId()
{
    auto pCurrentPool = tlsCurrentThreadPool.lock();
    if (pCurrentPool)
    {
        auto workerId = *pCurrentPool->tlsWorkerId;

        // The stored id only references its pool weakly; upgrade it before building the PId.
        auto pOwningPool = AuTryLockMemoryType(workerId.pool);
        if (pOwningPool)
        {
            return WorkerPId_t(pOwningPool, workerId);
        }
    }
    return {};
}
//
ThreadPool::ThreadPool() :
    shutdownEvent_(false, false, true)
{
    // Cache the reader-side view of the pool rwlock; it is used by many short critical sections.
    this->pRWReadView = this->rwlock_->AsReadable();
}
// internal pool interface
// Convenience overload: qualifies the bare worker id with the calling thread's current pool
// and forwards to the WorkerPId_t overload.
bool ThreadPool::WaitFor(WorkerId_t unlocker, const AuSPtr<Threading::IWaitable> &primitive, AuUInt32 timeoutMs)
{
    auto pCallerPool = AuAsync::GetCurrentWorkerPId().GetPool();
    return this->WaitFor(WorkerPId_t { pCallerPool, unlocker }, primitive, timeoutMs);
}
// Waits for `primitive` to become acquirable on behalf of worker `unlocker`.
//
// timeoutMs == 0 means no deadline; otherwise an absolute steady-clock deadline is derived from it.
//
// Cooperative path: when `unlocker` is empty, or when the caller is a worker of this pool and is
// itself the thread expected to do the unlocking (same group and thread id, or kThreadIdAny on a
// single-worker group), the caller keeps pumping its own queue between TryLock attempts instead of
// blocking, so it cannot dead-lock on work only it can run. That loop gives up once the pool starts
// shutting down or the deadline passes.
//
// Otherwise the primitive is registered in the target worker's externalFences (ThisExiting unlocks
// those when that worker exits) and the caller blocks on the primitive directly.
//
// Returns true if the primitive was acquired.
bool ThreadPool::WaitFor(WorkerPId_t unlocker, const AuSPtr<Threading::IWaitable> &primitive, AuUInt32 timeoutMs)
{
    AuUInt64 uEndTimeNS = timeoutMs ?
        AuTime::SteadyClockNS() + AuMSToNS<AuUInt64>(timeoutMs) :
        0;

    if (auto pCurThread = GetThreadState())
    {
        bool bStat {};
        {
            bStat = !bool(unlocker);
        }

        if (!bStat)
        {
            // Legacy matching logic (to clean up)
            bool bWorkerIdMatches = (unlocker.second == pCurThread->thread.id.second) ||
                                    ((unlocker.second == Async::kThreadIdAny) &&
                                     (GetThreadWorkersCount(unlocker.first) == 1));

            bStat = (unlocker.first == pCurThread->thread.id.first) &&
                    (unlocker.GetPool().get() == this) &&
                    (bWorkerIdMatches);
        }

        if (bStat)
        {
            while (true)
            {
                // InternalRunOne increments this counter, so it must start initialized.
                AuUInt32 didntAsk {};
                bool bTimedOut {};

                if (primitive->TryLock())
                {
                    return true;
                }

                this->InternalRunOne(pCurThread, true, false, didntAsk);

                if (uEndTimeNS)
                {
                    bTimedOut = AuTime::SteadyClockNS() >= uEndTimeNS;
                }

                if (primitive->TryLock())
                {
                    return true;
                }

                // Abort once the pool is shutting down or the deadline has passed.
                // (This previously tested `!AuAtomicLoad(...)`, which made every ordinary wait fail
                // after a single pump and made cooperative waits during shutdown spin indefinitely.)
                if (AuAtomicLoad(&this->shuttingdown_) ||
                    bTimedOut)
                {
                    return false;
                }
            }
        }
    }

    {
        AuSPtr<ThreadState> pHandle;

        if (auto pPool = unlocker)
        {
            auto pPoolEx = AuStaticCast<ThreadPool>(unlocker.GetPool());
            AU_LOCK_GLOBAL_GUARD(pPoolEx->rwlock_->AsReadable());

            auto pShutdownLock = Aurora::Threading::GetShutdownReadLock();

            if ((pHandle = AuStaticCast<ThreadPool>(unlocker.GetPool())->GetThreadHandle(unlocker)))
            {
                AU_LOCK_GUARD(pHandle->externalFencesLock);
                if (pHandle->exitingflag2)
                {
                    // Target worker is already exiting: it will never release new fences, so only try once.
                    pShutdownLock->Unlock();
                    bool bRet = primitive->TryLock();
                    pShutdownLock->Lock();
                    return bRet;
                }
                else
                {
                    // Register so the target's exit path can unlock us if it dies first.
                    pHandle->externalFences.push_back(primitive.get());
                }
            }
            else if (unlocker.GetPool().get() == this)
            {
                // Worker unknown to this pool: plain timed wait, without holding the shutdown read lock.
                pShutdownLock->Unlock();
                bool bRet = primitive->LockMS(timeoutMs);
                pShutdownLock->Lock();
                return bRet;
            }
        }

        bool bRet = primitive->LockAbsNS(uEndTimeNS);

        if (pHandle)
        {
            AU_LOCK_GLOBAL_GUARD(pHandle->externalFencesLock);
            AuTryRemove(pHandle->externalFences, primitive.get());
        }

        return bRet;
    }
}
// Public dispatch entry: always accounts the submission in the pool-wide task counter.
void ThreadPool::Run(WorkerId_t target, AuSPtr<IAsyncRunnable> runnable)
{
    constexpr bool kIncrementCounter = true;
    this->Run(target, runnable, kIncrementCounter);
}
// Queues `runnable` on `target`'s group.
// A specific (non-kThreadIdAny) worker must exist and still accept submissions, otherwise the
// runnable is cancelled via CancelAsync. When bIncrement is set, the pool task counter is bumped.
// Afterwards either every group worker (any-thread) or the chosen worker is woken.
void ThreadPool::Run(WorkerId_t target, AuSPtr<IAsyncRunnable> runnable, bool bIncrement)
{
    auto pGroupState = GetGroup(target.first);
    SysAssert(static_cast<bool>(pGroupState), "couldn't dispatch a task to an offline group");

    const bool bAnyThread = target.second == Async::kThreadIdAny;

    AuSPtr<ThreadState> pWorker;
    if (!bAnyThread)
    {
        pWorker = pGroupState->GetThreadByIndex(target.second);
        if (!pWorker || pWorker->shutdown.bDropSubmissions)
        {
            runnable->CancelAsync();
            return;
        }
    }

    if (bIncrement)
    {
        AuAtomicAdd(&this->uAtomicCounter, 1u);
    }

    pGroupState->workQueue.AddWorkEntry(AuMakePair(target.second, runnable));

    if (bAnyThread)
    {
        pGroupState->SignalAll();
    }
    else
    {
        pWorker->sync.SetEvent(true, true);
    }
}
// Identity up-cast: ThreadPool already implements IThreadPool.
IThreadPool *ThreadPool::ToThreadPool()
{
    IThreadPool *pInterface = this;
    return pInterface;
}
// ithreadpool
// Number of workers currently registered in `group`.
// Unknown/offline groups report 0 (this used to dereference a null group pointer, which is
// reachable from WaitFor with a caller-supplied group id).
size_t ThreadPool::GetThreadWorkersCount(ThreadGroup_t group)
{
    auto pGroup = GetGroup(group);
    if (!pGroup)
    {
        return 0;
    }
    return pGroup->workers.size();
}
// Runner mode: once the pool's task counter reaches zero and IsDepleted holds, the pump
// (PollInternal_Base) shuts the pool down by itself.
void ThreadPool::SetRunningMode(bool eventRunning)
{
    this->runnersRunning_ = eventRunning;
}

// Spawns a new OS thread to host worker `workerId`.
bool ThreadPool::Spawn(WorkerId_t workerId)
{
    constexpr bool kAdoptCallingThread = false;
    return this->Spawn(workerId, kAdoptCallingThread);
}

// Registers the calling thread itself as worker `workerId` (no new thread is created).
bool ThreadPool::Create(WorkerId_t workerId)
{
    constexpr bool kAdoptCallingThread = true;
    return this->Spawn(workerId, kAdoptCallingThread);
}

bool ThreadPool::InRunnerMode()
{
    return this->runnersRunning_;
}
// Temporarily binds the calling thread to this pool for the duration of a pump call.
// Declares two locals in the enclosing scope:
//   weakTlsHandle - a reference to the thread-local tlsCurrentThreadPool slot;
//   tlsHandle     - a strong snapshot of whichever pool the thread was bound to beforehand.
#define ASYNC_THREADGROUP_TLS_PREP \
    auto &weakTlsHandle = tlsCurrentThreadPool; \
    auto tlsHandle = AuTryLockMemoryType(weakTlsHandle); \
    if (!tlsHandle || tlsHandle.get() != this) \
    { \
        weakTlsHandle = AuSharedFromThis(); \
    }
// Restores the binding captured by ASYNC_THREADGROUP_TLS_PREP when the thread was previously bound
// to a different pool. weakTlsHandle aliases tlsCurrentThreadPool, so the previous pool must be taken
// from the tlsHandle snapshot (assigning the slot to itself, as before, left the thread bound to this pool).
#define ASYNC_THREADGROUP_TLS_UNSET \
    if (tlsHandle && tlsHandle.get() != this) \
    { \
        weakTlsHandle = tlsHandle; \
    }
bool ThreadPool::Poll()
{
AuUInt32 uCount {};
ASYNC_THREADGROUP_TLS_PREP;
auto bRet = InternalRunOne(GetThreadStateLocal(), false, false, uCount);
ASYNC_THREADGROUP_TLS_UNSET;
return bRet;
}
bool ThreadPool::RunOnce()
{
AuUInt32 uCount {};
ASYNC_THREADGROUP_TLS_PREP;
auto bRet = InternalRunOne(GetThreadStateLocal(), true, false, uCount);
ASYNC_THREADGROUP_TLS_UNSET;
return bRet;
}
// Worker main loop: pumps (blocking, until work) until the thread is exiting, the pool reaches
// its final shutdown phase (shuttingdown_ & 2), or the worker's bBreakMainLoop is set.
// Threads that are not workers simply wait on the pool's shutdown event and return true.
// Returns whether at least one pump completed.
bool ThreadPool::Run()
{
    bool ranOnce {};

    ASYNC_THREADGROUP_TLS_PREP;
    auto pJobRunner = GetThreadStateLocal();

    if (!pJobRunner)
    {
        ASYNC_THREADGROUP_TLS_UNSET;
        this->shutdownEvent_->LockMS(0);
        return true;
    }

    auto auThread = AuThreads::GetThread();

    while ((!auThread->Exiting()) &&
           ((AuAtomicLoad(&this->shuttingdown_) & 2) != 2) &&
           (!pJobRunner->shutdown.bBreakMainLoop))
    {
        AuUInt32 uCount {};

        // Do work (blocking)
        if (!InternalRunOne(pJobRunner, true, true, uCount))
        {
            if ((AuAtomicLoad(&this->shuttingdown_) & 2) == 2)
            {
                // Leave via the common exit so the TLS pool binding is restored
                // (a direct `return` here used to skip ASYNC_THREADGROUP_TLS_UNSET).
                break;
            }
        }
        ranOnce = true;
    }

    ASYNC_THREADGROUP_TLS_UNSET;
    return ranOnce;
}
// One pump iteration for `state`, bracketed by EarlyExitTick calls.
// When the worker's loop queue has more than one source attached (presumably the baseline is the
// worker's own wake-up source), the loop queue performs the optional blocking and the task queue
// is then drained without blocking; otherwise the task queue itself may block.
bool ThreadPool::InternalRunOne(AuSPtr<ThreadState> state, bool block, bool bUntilWork, AuUInt32 &uCount)
{
    if (!state)
    {
        SysPushErrorUninitialized("Not an async thread");
        return false;
    }

    EarlyExitTick();

    bool bSuccess {};
    {
        auto pLoop = state->asyncLoop;
        pLoop->OnFrame();

        bool bBlockOnTasks = block;
        if (pLoop->GetSourceCount() > 1)
        {
            if (block)
            {
                pLoop->WaitAny(0);
            }
            else
            {
                pLoop->PumpNonblocking();
            }
            bBlockOnTasks = false;
        }

        bSuccess = PollInternal(state, bBlockOnTasks, bUntilWork, uCount);
    }

    EarlyExitTick();
    return bSuccess;
}
#if defined(__AUHAS_COROUTINES_CO_AWAIT) && defined(AU_LANG_CPP_20_)
// Runs PollInternal_Base inside a C++20 coroutine body and reports its result through bRet.
// PollInternal routes nested pumps (uStackCallDepth != 0) here when
// gRuntimeConfig.async.bEnableCpp20RecursiveCallstack is enabled.
// NOTE(review): presumably the coroutine frame exists to keep nested pump state off the native
// call stack — confirm against AuVoidTask's implementation.
AuVoidTask ThreadPool::PollInternal_ForceCoRoutine(AuSPtr<ThreadState> state, bool block, bool bUntilWork, AuUInt32 &uCount, bool &bRet)
{
bRet = PollInternal_Base(state, block, bUntilWork, uCount);
co_return;
}
#endif
// Dispatches to PollInternal_Base, optionally via a coroutine frame for nested pumps
// (only when built with C++20 coroutines and enabled in the runtime config).
bool ThreadPool::PollInternal(AuSPtr<ThreadState> state, bool block, bool bUntilWork, AuUInt32 &uCount)
{
#if defined(__AUHAS_COROUTINES_CO_AWAIT) && defined(AU_LANG_CPP_20_)
    const bool bNested = state->stackState.uStackCallDepth != 0;
    if (bNested && gRuntimeConfig.async.bEnableCpp20RecursiveCallstack)
    {
        bool bResult {};
        this->PollInternal_ForceCoRoutine(state, block, bUntilWork, uCount, bResult);
        return bResult;
    }
#endif
    return this->PollInternal_Base(state, block, bUntilWork, uCount);
}
// Core pump for one worker.
// Phase 1 (under state->sync.cvWorkMutex): dequeue up to stackState.uWorkMultipopCount entries from the
// group queue into state->pendingWorkItems, optionally blocking on the worker's condition variable until
// work arrives (block), and looping while nothing was obtained if bUntilWork is also set.
// Phase 2 (no lock): run the pending entries, incrementing uCount once per executed task and decrementing
// the pool-wide uAtomicCounter. Entries left over (e.g. the thread began exiting) are re-queued on the group.
// In runner mode, the pool shuts itself down once uAtomicCounter is 0 and IsDepleted holds.
// Returns false when nothing was dequeued, when the loop queue has other sources to service, or on the
// runner-mode nested early-out; true once the dispatch phase has been reached.
bool ThreadPool::PollInternal_Base(AuSPtr<ThreadState> state, bool block, bool bUntilWork, AuUInt32 &uCount)
{
if (!state)
{
SysPushErrorUninitialized("Not an async thread");
return false;
}
auto group = state->parent.lock();
{
AU_LOCK_GUARD(state->sync.cvWorkMutex);
do
{
// NOTE(review): despite its name, bFailedOOM holds Dequeue's return value and the `!bFailedOOM`
// branch below is the one treated as the OOM/retry path — confirm against Dequeue's contract.
bool bFailedOOM = group->workQueue.Dequeue(state->pendingWorkItems,
state->stackState.uWorkMultipopCount,
state->thread.id.second);
state->sync.UpdateCVState(state.get());
// Consider blocking for more work
if (!block)
{
break;
}
// OOM: hardened: sleep for 0.01MS if the heap for task dequeue is full.
// Until the mixed heap object is implemented, we can only dequeue 2^16 tasks globally at a time into a reserved heap.
if (!bFailedOOM)
{
if (state->pendingWorkItems.empty())
{
AuThreading::SleepNs(10'000);
continue;
}
else
{
break;
}
}
// Block if no work items are present
if (state->pendingWorkItems.empty())
{
// pre-wakeup thread terminating check
if (state->thread.pThread->Exiting())
{
break;
}
if (AuAtomicLoad(&this->shuttingdown_) & 2)
{
break;
}
// OOM: hardened: do not sleep after OOM re-try
if (group->workQueue.IsEmpty(this, state->thread.id))
{
state->sync.cvVariable->WaitForSignal();
}
if (AuAtomicLoad(&this->shuttingdown_) & 2)
{
break;
}
// Post-wakeup thread terminating check
if (state->thread.pThread->Exiting())
{
break;
}
}
// Still no tasks, but the loop queue has extra sources (or pending commits): return so the
// caller can service the loop queue instead.
if (state->pendingWorkItems.empty() && (
(this->GetThreadState()->asyncLoop->GetSourceCount() > 1) ||
this->GetThreadState()->asyncLoop->CommitPending())) //(this->ToKernelWorkQueue()->IsSignaledPeek()))
{
return false;
}
}
while (state->pendingWorkItems.empty() && (block && bUntilWork));
// Non-blocking poll outside final shutdown: if the group queue is now empty, reset the worker's
// loop-source event and clear cvLSActive, under the group queue/workers locks.
if (!block &&
!(this->shuttingdown_ & 2)) // quick hack: is worthy of io reset by virtue of having polled externally (most likely for IO ticks, unlikely for intraprocess ticks)
{
AU_LOCK_GLOBAL_GUARD(group->workQueue.mutex); // dont atomically increment our work counters [signal under mutex group]...
AU_LOCK_GUARD(group->workersMutex); // dont atomically increment our work counters [broadcast]...
// ...these primitives are far less expensive to hit than resetting kernel primitives
// AU_LOCK_GUARD(state->cvWorkMutex) used to protect us
if (group->workQueue.IsEmpty(this, state->thread.id))
{
state->sync.eventLs->Reset(); // ...until we're done
AuAtomicStore(&state->sync.cvLSActive, 0u);
}
}
}
if (state->pendingWorkItems.empty())
{
if (InRunnerMode())
{
if ((AuAtomicLoad(&this->uAtomicCounter) == 0) &&
this->IsDepleted(state))
{
Shutdown();
}
}
return false;
}
int runningTasks {};
// Cookie snapshot: if a nested pump (run from inside a task) bumps uStackCookie, the pending
// list may have been modified underneath us, so iteration restarts from the beginning.
auto uStartCookie = state->stackState.uStackCookie++;
// Account for
// while (AuAsync.GetCurrentPool()->runForever());
// in the first task (or deeper)
if (InRunnerMode() && state->stackState.uStackCallDepth) // are we one call deep?
{
if ((AuAtomicLoad(&this->uAtomicCounter) == state->stackState.uStackCallDepth) &&
this->IsDepleted(state))
{
return false;
}
}
//
for (auto itr = state->pendingWorkItems.begin(); itr != state->pendingWorkItems.end(); )
{
if (state->thread.pThread->Exiting())
{
break;
}
// Dispatch
auto oops = itr->second;
// Remove from our local job queue
itr = state->pendingWorkItems.erase(itr);
state->stackState.uStackCallDepth++;
//SysBenchmark(fmt::format("RunAsync: {}", block));
// Dispatch
if (oops)
{
oops->RunAsync();
}
uCount++;
// Atomically decrement global task counter
runningTasks = AuAtomicSub(&this->uAtomicCounter, 1u);
state->stackState.uStackCallDepth--;
if (uStartCookie != state->stackState.uStackCookie)
{
uStartCookie = state->stackState.uStackCookie;
itr = state->pendingWorkItems.begin();
}
}
// Return popped work back to the groups work pool when our -pump loops were preempted
if (state->pendingWorkItems.size())
{
AU_LOCK_GUARD(state->sync.cvWorkMutex);
for (const auto &item : state->pendingWorkItems)
{
group->workQueue.AddWorkEntry(item);
}
state->pendingWorkItems.clear();
state->sync.cvVariable->Broadcast();
state->sync.eventLs->Set();
}
// Account for
// while (AuAsync.GetCurrentPool()->runForever());
// in the top most task
if (InRunnerMode())
{
if ((runningTasks == 0) &&
(AuAtomicLoad(&this->uAtomicCounter) == 0) &&
this->IsDepleted(state))
{
Shutdown();
}
}
return true;
}
// While much of this subsystem needs good rewrite, under no circumstance should the shutdown process be "simpified" or "cleaned up"
// This is our expected behaviour. Any changes will likely introduce hard to catch bugs across various softwares and exit conditions.
//
// Shutdown sequence (idempotent: only the first caller to set bit 0 of shuttingdown_ proceeds):
//   1. Barrier every other worker (no timeout) so already-submitted work is synced.
//   2. Bump the pool-wide abort cookie once per barriered worker.
//   3. Set bit 1 of shuttingdown_ (the "final" phase the pumps test with `& 2`).
//   4. Under the pool read lock: flag every worker state, send exit signals to owned threads, set `running` events.
//   5. Wake every worker's sync event, then wait up to 250ms per worker on its isDeadEvent.
//   6. Terminate owned threads other than the caller, set shutdownEvent_, mark the caller as the killer thread.
//   7. Shut down parent pools (listWeakDepsParents_) that are not in runner mode and are now idle.
void ThreadPool::Shutdown()
{
AU_DEBUG_MEMCRUNCH;
auto trySelfPid = AuAsync::GetCurrentWorkerPId();
// Update shutting down flag
// Specify the root-level shutdown flag for 'ok, u can work, but you're shutting down soon [microseconds, probably]'
{
if (AuAtomicTestAndSet(&this->shuttingdown_, 0) != 0)
{
return;
}
}
auto pLocalRunner = this->GetThreadStateNoWarn();
AuList<WorkerId_t> toBarrier;
// wait for regular prio work to complete
{
for (auto pGroup : this->threadGroups_)
{
if (!pGroup)
{
continue;
}
AU_LOCK_GLOBAL_GUARD(pGroup->workersMutex);
for (auto &[id, worker] : pGroup->workers)
{
if (trySelfPid == worker->thread.id)
{
continue;
}
toBarrier.push_back(worker->thread.id);
}
}
for (const auto &id : toBarrier)
{
if (trySelfPid == id)
{
continue;
}
this->Barrier(id, 0, false, false /* no reject*/); // absolute safest point in shutdown; sync to already submitted work
}
}
// increment abort cookies
{
for (const auto &id : toBarrier)
{
if (trySelfPid == id)
{
continue;
}
AuAtomicAdd(&this->uAtomicShutdownCookie, 1u);
}
}
// set shutdown flags
{
AuAtomicTestAndSet(&this->shuttingdown_, 1);
}
// Finally set the shutdown flag on all of our thread contexts
// then release them from the runners/workers list
// then release all group contexts
AuList<AuThreads::ThreadShared_t> threads;
AuList<AuSPtr<ThreadState>> states;
{
AU_LOCK_GLOBAL_GUARD(this->pRWReadView);
for (auto pGroup : this->threadGroups_)
{
if (!pGroup)
{
continue;
}
for (auto &[id, pState] : pGroup->workers)
{
// main loop:
if (pState)
{
states.push_back(pState);
pState->shuttingdown = true;
}
else
{
// NOTE(review): this branch runs only when pState is null, so this line (and the pState
// dereferences below) would dereference a null pointer — presumably workers never holds
// null entries; confirm before relying on this branch.
pState->shuttingdown = true;
}
// thread object:
if (pState->thread.bOwnsThread)
{
pState->thread.pThread->SendExitSignal();
threads.push_back(pState->thread.pThread);
}
// unrefreeze signals:
auto &event = pState->running;
if (event)
{
event->Set();
}
}
}
}
// Break all condvar loops, just in case
for (const auto &pState : states)
{
pState->sync.SetEvent();
}
// Final sync to exit
{
for (const auto &id : toBarrier)
{
if (trySelfPid == id)
{
continue;
}
auto handle = this->GetThreadHandle(id);
if (handle)
{
// NOTE(review): clears bDropSubmissions before the bounded (250ms) wait on the worker's
// death event — intent not visible from this file; confirm.
handle->shutdown.bDropSubmissions = false;
handle->isDeadEvent->LockMS(250);
}
}
}
// Sync to shutdown threads to prevent a race condition whereby the async subsystem shuts down before the threads
auto pSelf = AuThreads::GetThread();
for (const auto &thread : threads)
{
if (thread.get() != pSelf)
{
thread->Terminate();
}
}
// Is dead flag
this->shutdownEvent_->Set();
//
if (pLocalRunner)
{
pLocalRunner->shutdown.bIsKillerThread = true;
}
// Notify observing threads of our work exhaustion
for (const auto &wOther : this->listWeakDepsParents_)
{
if (auto pThat = AuTryLockMemoryType(wOther))
{
if (pThat->InRunnerMode())
{
continue;
}
if (!pThat->IsSelfDepleted(nullptr))
{
continue;
}
if (pThat->uAtomicCounter)
{
continue;
}
pThat->Shutdown();
}
}
}
// True once the pool has begun shutting down, or when the calling worker's state is flagged as exiting.
// Threads with no worker state in this pool (non-workers, or bound to another pool) only see the
// pool-wide flag; previously they dereferenced a null ThreadState.
bool ThreadPool::Exiting()
{
    if (this->shuttingdown_)
    {
        return true;
    }

    auto pState = GetThreadState();
    return pState && pState->exiting;
}
// Non-blocking pump; returns the number of tasks executed.
// In strict mode, a pump that succeeded without executing a countable task still reports 1.
AuUInt32 ThreadPool::PollAndCount(bool bStrict)
{
    ASYNC_THREADGROUP_TLS_PREP;
    AuUInt32 uExecuted {};
    const bool bRanAtLeastOne = this->InternalRunOne(this->GetThreadStateNoWarn(), false, false, uExecuted);
    ASYNC_THREADGROUP_TLS_UNSET;

    if (uExecuted)
    {
        return uExecuted;
    }
    return bStrict ? AuUInt32(bRanAtLeastOne) : 0u;
}
// Repeats non-blocking pumps until a pass executes nothing; returns the total executed.
AuUInt32 ThreadPool::RunAllPending()
{
    ASYNC_THREADGROUP_TLS_PREP;

    AuUInt32 uCountTotal {};
    for (;;)
    {
        AuUInt32 uPass {};
        (void)this->InternalRunOne(this->GetThreadStateNoWarn(), false, true, uPass);
        if (!uPass)
        {
            break;
        }
        uCountTotal += uPass;
    }

    ASYNC_THREADGROUP_TLS_UNSET;
    return uCountTotal;
}
// Creates a work item that runs `task` on `worker` of this pool; a null task yields a null item.
AuSPtr<IWorkItem> ThreadPool::NewWorkItem(const WorkerId_t &worker,
                                          const AuSPtr<IWorkItemHandler> &task)
{
    if (task)
    {
        return AuMakeShared<WorkItem>(this, WorkerPId_t { this->SharedFromThis(), worker }, task);
    }
    return {};
}

// Creates a work item wrapping a plain callback to run on `worker` of this pool.
AuSPtr<IWorkItem> ThreadPool::NewWorkFunction(const WorkerId_t &worker,
                                              AuVoidFunc callback)
{
    SysAssert(callback);
    WorkerPId_t target { this->SharedFromThis(), worker };
    return AuMakeShared<AsyncFuncWorker>(this, AuMove(target), AuMove(callback));
}

// Creates a handler-less work item bound to the calling worker.
AuSPtr<IWorkItem> ThreadPool::NewFence()
{
    AuSPtr<IWorkItemHandler> pNoHandler {};
    return AuMakeShared<WorkItem>(this, AuAsync::GetCurrentWorkerPId(), pNoHandler);
}
// Maps a worker id to its thread object; null when the worker is unknown.
AuThreads::ThreadShared_t ThreadPool::ResolveHandle(WorkerId_t id)
{
    if (auto pState = GetThreadHandle(id))
    {
        return pState->thread.pThread;
    }
    return {};
}
// Snapshot of group -> thread ids for every allocated group, each taken under that group's workersMutex.
AuBST<ThreadGroup_t, AuList<ThreadId_t>> ThreadPool::GetThreads()
{
    AU_DEBUG_MEMCRUNCH;
    AuBST<ThreadGroup_t, AuList<ThreadId_t>> groupToThreads;

    for (auto pGroup : this->threadGroups_)
    {
        if (!pGroup)
        {
            continue;
        }

        AuList<ThreadId_t> threadIds;
        AU_LOCK_GUARD(pGroup->workersMutex);
        AuTryReserve(threadIds, pGroup->workers.size());
        for (const auto &entry : pGroup->workers)
        {
            threadIds.push_back(entry.second->thread.id.second);
        }
        groupToThreads[pGroup->group] = threadIds;
    }

    return groupToThreads;
}
// Calling thread's worker id within this pool (converted from the per-thread tlsWorkerId).
WorkerId_t ThreadPool::GetCurrentThread()
{
    WorkerId_t currentId = tlsWorkerId;
    return currentId;
}
// Per-worker IO singleton accessors. Each resolves the worker's state and forwards to its
// singletons object with a pool-qualified id; unknown workers yield null.
AuSPtr<AuIO::IIOProcessor> ThreadPool::GetIOProcessor(WorkerId_t id)
{
    auto pState = this->GetThreadHandle(id);
    if (!pState)
    {
        return {};
    }
    return pState->singletons.GetIOProcessor({ this->SharedFromThis(), id });
}

AuSPtr<IO::CompletionGroup::ICompletionGroup> ThreadPool::GetIOGroup(WorkerId_t id)
{
    auto pState = this->GetThreadHandle(id);
    if (!pState)
    {
        return {};
    }
    return pState->singletons.GetIOGroup({ this->SharedFromThis(), id });
}

AuSPtr<AuIO::Net::INetInterface> ThreadPool::GetIONetInterface(WorkerId_t id)
{
    auto pState = this->GetThreadHandle(id);
    if (!pState)
    {
        return {};
    }
    return pState->singletons.GetIONetInterface({ this->SharedFromThis(), id });
}

AuSPtr<AuIO::Net::INetWorker> ThreadPool::GetIONetWorker(WorkerId_t id)
{
    auto pState = this->GetThreadHandle(id);
    if (!pState)
    {
        return {};
    }
    return pState->singletons.GetIONetWorker({ this->SharedFromThis(), id });
}
// Barriers one worker, or every worker of the group for kThreadIdAny.
// requireSignal is never applied to the calling thread's own id (it would park itself).
// Returns false on the first barrier that fails.
bool ThreadPool::Sync(WorkerId_t workerId, AuUInt32 timeoutMs, bool requireSignal)
{
    //AU_LOCK_GUARD(this->pRWReadView);
    const auto callerThreadId = GetCurrentThread().second;

    if (workerId.second != Async::kThreadIdAny)
    {
        return Barrier(workerId, timeoutMs, requireSignal && workerId.second != callerThreadId, false);
    }

    // Snapshot the group's workers under the pool read lock, then barrier outside of it.
    decltype(GroupState::workers) snapshot;
    {
        AU_LOCK_GLOBAL_GUARD(this->pRWReadView);
        if (auto pGroup = GetGroup(workerId.first))
        {
            snapshot = pGroup->workers;
        }
    }

    for (auto &entry : snapshot)
    {
        const auto &targetId = entry.second->thread.id;
        // TODO: each barrier gets the full timeout; elapsed time is not subtracted.
        if (!Barrier(targetId, timeoutMs, requireSignal && targetId.second != callerThreadId, false))
        {
            return false;
        }
    }
    return true;
}
// Sets the `running` event of one worker, or of every worker in the group for kThreadIdAny
// (releasing workers parked by a requireSignal Barrier).
void ThreadPool::Signal(WorkerId_t workerId)
{
    auto group = GetGroup(workerId.first);

    if (workerId.second != Async::kThreadIdAny)
    {
        if (auto pThread = GetThreadHandle(workerId))
        {
            pThread->running->Set();
        }
        return;
    }

    AU_LOCK_GLOBAL_GUARD(group->workersMutex);
    for (auto &entry : group->workers)
    {
        entry.second->running->Set();
    }
}
// Nudges one worker's sync event, or every worker in the group for kThreadIdAny.
void ThreadPool::Wakeup(WorkerId_t workerId)
{
    auto group = GetGroup(workerId.first);

    const bool bBroadcast = workerId.second == Async::kThreadIdAny;
    if (bBroadcast)
    {
        group->SignalAll(false);
        return;
    }

    if (auto pThread = GetThreadHandle(workerId))
    {
        pThread->sync.SetEvent(true, false);
    }
}
// Exposes the worker's eventLs as a loop source; null for unknown workers.
AuSPtr<AuLoop::ILoopSource> ThreadPool::WorkerToLoopSource(WorkerId_t workerId)
{
    if (auto pWorker = GetThreadHandle(workerId))
    {
        return pWorker->sync.eventLs;
    }
    return {};
}
// Barriers every worker of every group with no timeout, holding the pool read lock throughout;
// asserts that each barrier succeeds.
void ThreadPool::SyncAllSafe()
{
    AU_LOCK_GLOBAL_GUARD(this->pRWReadView);

    for (auto pGroup : this->threadGroups_)
    {
        if (pGroup)
        {
            for (auto &jobWorker : pGroup->workers)
            {
                SysAssert(Barrier(jobWorker.second->thread.id, 0, false, false));
            }
        }
    }
}
// Registers a thread feature on worker `id` and calls its Init() on that worker.
// Unless bNonBlock is set, waits for that work item to complete.
void ThreadPool::AddFeature(WorkerId_t id,
                            AuSPtr<AuThreads::IThreadFeature> pFeature,
                            bool bNonBlock)
{
    auto onWorker = [=]()
    {
        auto pState = GetThreadState();
        {
            AU_LOCK_GUARD(pState->tlsFeatures.mutex);
            pState->tlsFeatures.features.push_back(pFeature);
        }
        pFeature->Init();
    };

    auto pWorkItem = DispatchOn({ this->SharedFromThis(), id }, AuMove(onWorker));

    if (bNonBlock)
    {
        return;
    }
    pWorkItem->BlockUntilComplete();
}
// Asserts that the calling thread's worker id (tlsWorkerId) belongs to `group`.
void ThreadPool::AssertInThreadGroup(ThreadGroup_t group)
{
SysAssert(static_cast<WorkerId_t>(tlsWorkerId).first == group);
}
// Asserts that the calling thread's worker id (tlsWorkerId) is exactly `id`.
void ThreadPool::AssertWorker(WorkerId_t id)
{
SysAssert(static_cast<WorkerId_t>(tlsWorkerId) == id);
}
// Loop queue of the calling worker. Non-worker threads get an error pushed and a null result,
// matching the id-based overload below (previously this dereferenced a null ThreadState).
AuSPtr<AuLoop::ILoopQueue> ThreadPool::ToKernelWorkQueue()
{
    auto pSelf = this->GetThreadState();
    if (!pSelf)
    {
        SysPushErrorUninitialized("Not an async thread");
        return {};
    }
    return pSelf->asyncLoop;
}

// Loop queue of worker `workerId`; pushes an error and returns null when the worker is unknown.
AuSPtr<AuLoop::ILoopQueue> ThreadPool::ToKernelWorkQueue(WorkerId_t workerId)
{
    auto worker = this->GetThreadHandle(workerId);
    if (!worker)
    {
        SysPushErrorGeneric("Couldn't find requested worker");
        return {};
    }
    return worker->asyncLoop;
}
bool ThreadPool::IsSelfDepleted(const AuSPtr<ThreadState> &pState)
{
AuSPtr<AuLoop::ILoopQueue> pLoopQueue;
if (pState)
{
pLoopQueue = pState->asyncLoop;
}
else
{
if (auto pLocalThread = this->GetThreadStateLocal())
{
pLoopQueue = pLocalThread->asyncLoop;
}
}
if (!pLoopQueue)
{
return true;
}
return pLoopQueue->GetSourceCount() <= 1 + this->uAtomicIOProcessorsWorthlessSources + this->uAtomicIOProcessors;
}
// True when this pool is self-depleted and every still-alive dependency pool is both
// self-depleted and has a zero task counter.
bool ThreadPool::IsDepleted(const AuSPtr<ThreadState> &state)
{
    if (!IsSelfDepleted(state))
    {
        return false;
    }

    for (const auto &wOther : this->listWeakDeps_)
    {
        auto pThat = AuTryLockMemoryType(wOther);
        if (!pThat)
        {
            continue;
        }
        if (!pThat->IsSelfDepleted(nullptr) || AuAtomicLoad(&pThat->uAtomicCounter))
        {
            return false;
        }
    }
    return true;
}
// Links `pPool` as a dependency: this pool tracks it in listWeakDeps_ (consulted by IsDepleted),
// and it tracks this pool in listWeakDepsParents_ (notified from Shutdown). Null is ignored.
void ThreadPool::AddDependency(AuSPtr<IThreadPool> pPool)
{
    if (pPool)
    {
        auto pDependency = AuStaticCast<ThreadPool>(pPool);
        this->listWeakDeps_.push_back(pDependency);
        pDependency->listWeakDepsParents_.push_back(AuSharedFromThis());
    }
}
// Returns the embedded shutdown event as a shared pointer that keeps this pool alive
// (shared-pointer aliasing constructor: owner = this pool, pointee = shutdownEvent_).
AuSPtr<AuThreading::IWaitable> ThreadPool::GetShutdownEvent()
{
    auto pOwner = AuSharedFromThis();
    return AuSPtr<AuThreading::IWaitable>(pOwner, this->shutdownEvent_.AsPointer());
}
// Unimplemented fiber hooks; they were used for experimentation and are no longer in use.
int ThreadPool::CtxPollPush()
{
    // TODO (Reece): implement a context switching library
    // Refer to the old implementation of this on pastebin
    constexpr int kNoContext = 0;
    return kNoContext;
}

void ThreadPool::CtxPollReturn(const AuSPtr<ThreadState> &state, int status, bool hitTask)
{
    // Intentionally empty (fiber hook, unimplemented).
}
// Cooperative yield: one non-blocking pump of the calling worker's queue.
// During final shutdown (shuttingdown_ & 2), a worker that is dropping submissions does not pump.
// A thread that is not a worker of this pool no longer dereferences a null state here; it falls
// through to InternalRunOne, which reports "Not an async thread" and returns false.
// (The unused local and the dead `#else` loop, which called a nonexistent 3-argument
// InternalRunOne overload, have been removed.)
bool ThreadPool::CtxYield()
{
    auto pA = this->GetThreadStateNoWarn();

    if (pA && (AuAtomicLoad(&this->shuttingdown_) & 2)) // fast
    {
        if (pA->shutdown.bDropSubmissions)
        {
            return false;
        }
    }

    AuUInt32 uCount {};
    return this->InternalRunOne(pA, false, false, uCount);
}
//
// Bumps the pool-wide abort cookie; fences captured earlier via QueryAbortFence will then
// report "should abort" from QueryShouldAbort.
void ThreadPool::IncrementAbortFenceOnPool()
{
    AuAtomicAdd(&this->uAtomicShutdownCookie, 1u);
}

// Bumps the per-worker abort fence of one worker, or of every worker in the group for kThreadIdAny.
void ThreadPool::IncrementAbortFenceOnWorker(WorkerId_t workerId)
{
    auto group = GetGroup(workerId.first);

    if (workerId.second != kThreadIdAny)
    {
        if (auto pState = this->GetThreadHandle(workerId))
        {
            AuAtomicAdd(&pState->shutdown.uShutdownFence, 1u);
        }
        return;
    }

    AU_LOCK_GLOBAL_GUARD(group->workersMutex);
    for (auto &entry : group->workers)
    {
        AuAtomicAdd(&entry.second->shutdown.uShutdownFence, 1u);
    }
}
// Captures an abort fence: the worker's uShutdownFence in the high 32 bits and the pool cookie in
// the low bits. The worker defaults to the calling worker; an unknown worker leaves the high half 0.
AuUInt64 ThreadPool::QueryAbortFence(AuOptional<WorkerId_t> optWorkerId)
{
    auto pState = this->GetThreadHandle(optWorkerId.value_or(GetCurrentWorkerPId()));
    if (!pState)
    {
        return this->uAtomicShutdownCookie;
    }

    const AuUInt64 uWorkerHalf = AuUInt64(pState->shutdown.uShutdownFence) << 32ull;
    return uWorkerHalf | AuUInt64(this->uAtomicShutdownCookie);
}

// True if either half of a fence from QueryAbortFence is stale: the pool cookie changed, or a
// non-zero worker half no longer matches that worker's current fence.
bool ThreadPool::QueryShouldAbort(AuOptional<WorkerId_t> optWorkerId, AuUInt64 uFenceMagic)
{
    if (AuBitsToLower(uFenceMagic) != AuAtomicLoad(&this->uAtomicShutdownCookie))
    {
        return true;
    }

    // A zero worker half carries no per-worker information.
    const auto uThreadCookie = AuBitsToHigher(uFenceMagic);
    if (uThreadCookie == 0)
    {
        return false;
    }

    auto pState = this->GetThreadHandle(optWorkerId.value_or(GetCurrentWorkerPId()));
    if (!pState)
    {
        return false;
    }
    return uThreadCookie != pState->shutdown.uShutdownFence;
}
// internal api
// Registers worker `workerId`, holding the pool write lock for the whole operation.
//   create == false: spawns a new thread running Entrypoint(workerId). The thread is started while
//                    the write lock is still held, so Entrypoint's initial read-lock acquisition
//                    waits until this function has finished registering the worker.
//   create == true:  adopts the calling thread (wrapped with a no-op deleter), binds its TLS to this
//                    pool and installs a last-hope TLS hook that calls ThisExiting on thread exit.
// Allocates the thread group on first use; fails if the worker id already exists.
// Returns false on any failure.
bool ThreadPool::Spawn(WorkerId_t workerId, bool create)
{
AU_LOCK_GLOBAL_GUARD(this->rwlock_->AsWritable());
if (create)
{
tlsCurrentThreadPool = AuSharedFromThis();
}
AuSPtr<GroupState> pGroup;
// Try fetch or allocate group
{
if (!(pGroup = threadGroups_[workerId.first]))
{
pGroup = AuMakeShared<GroupState>();
if (!pGroup->Init())
{
SysPushErrorMemory("Not enough memory to intiialize a new group state");
return false;
}
pGroup->group = workerId.first;
this->threadGroups_[workerId.first] = pGroup;
}
}
// Assert worker does not already exist
{
AuSPtr<ThreadState>* ret;
if (AuTryFind(pGroup->workers, workerId.second, ret))
{
SysPushErrorGeneric("Thread ID already exists");
return false;
}
}
auto pThreadState = pGroup->CreateWorker(workerId, create);
if (!pThreadState)
{
return {};
}
// Pool configured with a heap: back the worker's pending-work list with it.
if (this->pHeap)
{
AuResetMember(pThreadState->pendingWorkItems, AuMemory::PmrCppHeapWrapper<WorkEntry_t>());
}
if (!create)
{
pThreadState->thread.pThread= AuThreads::ThreadShared(AuThreads::ThreadInfo(
AuMakeShared<AuThreads::IThreadVectorsFunctional>(AuThreads::IThreadVectorsFunctional::OnEntry_t(std::bind(&ThreadPool::Entrypoint, this, workerId)),
AuThreads::IThreadVectorsFunctional::OnExit_t{}),
gRuntimeConfig.async.threadPoolDefaultStackSize
));
if (!pThreadState->thread.pThread)
{
return {};
}
pThreadState->thread.pThread->Run();
}
else
{
// Non-owning handle to the calling thread (no-op deleter).
pThreadState->thread.pThread = AuSPtr<AuThreads::IAuroraThread>(AuThreads::GetThread(), [](AuThreads::IAuroraThread *){});
// TODO: this is just a hack
// we should implement this properly
pThreadState->thread.pThread->AddLastHopeTlsHook(AuMakeShared<AuThreads::IThreadFeatureFunctional>([]() -> void
{
}, []() -> void
{
auto pid = GetCurrentWorkerPId();
if (pid)
{
GetWorkerInternal(pid.GetPool())->ThisExiting();
}
}));
tlsCurrentThreadPool = AuWeakFromThis();
tlsWorkerId = WorkerPId_t(AuSharedFromThis(), workerId);
}
pGroup->AddWorker(workerId.second, pThreadState);
return true;
}
// private api
// Queues a marker task on `workerId` and waits (up to `ms`, 0 = no deadline) until it runs.
// The marker optionally marks the worker as dropping further submissions (`drop`) and, when
// `requireSignal` is set, parks the worker on its `running` event after releasing us
// (Signal() releases it). Returns false if the caller is not a worker, the task could not be
// allocated, the wait failed, or the task was cancelled instead of run.
//
// The cancel flag is heap-allocated and shared with the task (instead of a stack local captured
// by reference) because the task may be cancelled after this function has already returned on
// timeout. The flag is also published before the semaphore is released, so the waiter cannot
// observe a stale "not failed".
AU_NOINLINE bool ThreadPool::Barrier(WorkerId_t workerId, AuUInt32 ms, bool requireSignal, bool drop)
{
    auto self = GetThreadState();
    if (!self)
    {
        return {};
    }

    auto &semaphore = self->syncSema;
    auto unsafeSemaphore = semaphore.AsPointer();

    auto pFailed = AuMakeShared<AuUInt32>(0u);
    if (!pFailed)
    {
        return false;
    }

    auto work = AuMakeShared<AsyncFuncRunnable>(
        [=]()
        {
            auto state = GetThreadState();

            if (drop)
            {
                state->shutdown.bDropSubmissions = true;
            }

            if (requireSignal)
            {
                state->running->Reset();
            }

            unsafeSemaphore->Unlock(1);

            if (requireSignal)
            {
                state->running->Lock();
            }
        },
        [=]()
        {
            AuAtomicStore(pFailed.get(), 1u);
            unsafeSemaphore->Unlock(1);
        }
    );

    if (!work)
    {
        return false;
    }

    Run(workerId, work);

    return WaitFor(workerId, AuUnsafeRaiiToShared(semaphore.AsPointer()), ms) &&
           !AuAtomicLoad(pFailed.get());
}
// Thread entry for workers spawned by Spawn(..., create == false).
// Binds the thread's TLS to this pool/worker, runs the main loop, then (outside shutdown, for
// workers other than {0, 0}) performs a final drop-barrier on itself before ThisExiting.
void ThreadPool::Entrypoint(WorkerId_t id)
{
// Acquire and immediately release the read lock: Spawn holds the write lock while starting this
// thread, so this waits until the worker has been registered.
{
AU_LOCK_GLOBAL_GUARD(this->pRWReadView);
}
tlsCurrentThreadPool = AuWeakFromThis();
tlsWorkerId = WorkerPId_t(AuSharedFromThis(), id);
auto job = GetThreadState();
Run();
if (id != WorkerId_t {0, 0})
{
AU_LOCK_GLOBAL_GUARD(this->pRWReadView);
if (!AuAtomicLoad(&this->shuttingdown_) && !job->shutdown.bDropSubmissions)
{
// Pump and barrier + reject all after atomically
Barrier(id, 0, false, true);
}
}
ThisExiting();
if (id == WorkerId_t {0, 0})
{
CleanWorkerPoolReservedZeroFree();
}
}
// Called around every pump. Only acts in the final shutdown phase (shuttingdown_ & 2):
// wakes the caller's group, drains the caller's queue non-blockingly (guarded against re-entry),
// cleans up the worker's thread features, sets isDeadEvent and requests the main loop to break.
// The worker state is now null-checked *before* it is dereferenced (previously
// `jobWorker->parent.lock()` ran ahead of the check), and an expired parent group no longer crashes.
void ThreadPool::EarlyExitTick()
{
    if ((AuAtomicLoad(&this->shuttingdown_) & 2) != 2)
    {
        return;
    }

    auto jobWorker = GetThreadState();
    if (!jobWorker)
    {
        SysPushErrorUninitialized("Not an async thread");
        return;
    }

    if (auto state = jobWorker->parent.lock())
    {
        state->SignalAll();
    }

    {
        // Re-entrancy guard: the drain below runs tasks, which may pump again.
        if (AuExchange(jobWorker->bAlreadyDoingExitTick, true))
        {
            return;
        }

        AuUInt32 uCount {};
        do
        {
            uCount = 0;
            this->PollInternal(jobWorker, false, false, uCount);
        }
        while (uCount);
    }

    AuList<AuSPtr<AuThreads::IThreadFeature>> features;
    {
        AU_LOCK_GUARD(jobWorker->tlsFeatures.mutex);
        features = AuExchange(jobWorker->tlsFeatures.features, {});
    }

    {
        for (const auto &thread : features)
        {
            try
            {
                thread->Cleanup();
            }
            catch (...)
            {
                SysPushErrorCatch("Couldn't clean up thread feature!");
            }
        }

        jobWorker->isDeadEvent->Set();
        jobWorker->bAlreadyDoingExitTick = false;
        jobWorker->shutdown.bBreakMainLoop = true;
    }
}
// Tears down the calling worker's state. Under the pool read lock it marks the worker dead, cleans
// it up, terminates its scheduled tasks, over-releases syncSema so pending Barriers cannot hang,
// sets exitingflag2 and unlocks every external fence registered by WaitFor, and detaches its thread
// features. Feature Cleanup() then runs outside the lock; finally the worker is decommitted from
// its group and deinitialised.
// NOTE(review): neither `state` nor `pLocalState` is null-checked here — presumably callers only
// invoke this on a registered worker thread; confirm.
void ThreadPool::ThisExiting()
{
AU_DEBUG_MEMCRUNCH;
auto id = GetCurrentThread();
auto state = GetGroup(id.first);
auto pLocalState = state->GetThreadByIndex(id.second);
AuList<AuSPtr<AuThreads::IThreadFeature>> features;
{
AU_LOCK_GLOBAL_GUARD(this->pRWReadView);
pLocalState->isDeadEvent->Set();
CleanUpWorker(id);
TerminateSceduledTasks(this, id);
pLocalState->syncSema->Unlock(10); // prevent ::Barrier dead-locks
// exitingflag2 stops WaitFor from registering new fences (it only TryLocks instead);
// fences registered so far are released here.
{
AU_LOCK_GUARD(pLocalState->externalFencesLock);
pLocalState->exitingflag2 = true;
for (const auto &pIWaitable : pLocalState->externalFences)
{
pIWaitable->Unlock();
}
pLocalState->externalFences.clear();
}
{
AU_LOCK_GUARD(pLocalState->tlsFeatures.mutex);
features = AuExchange(pLocalState->tlsFeatures.features, {});
}
}
// Feature cleanup runs without the pool lock; exceptions are reported, not propagated.
{
for (const auto &thread : features)
{
try
{
thread->Cleanup();
}
catch (...)
{
SysPushErrorConcurrentRejected("Couldn't clean up thread feature!");
}
}
features.clear();
}
{
state->Decommit(id.second);
}
pLocalState->Deinit();
}
// Direct slot lookup in threadGroups_; an empty pointer means the group was never allocated.
// NOTE(review): `type` is not range-checked here — presumably threadGroups_ covers every
// ThreadGroup_t value; confirm.
AuSPtr<GroupState> ThreadPool::GetGroup(ThreadGroup_t type)
{
    auto &pSlot = this->threadGroups_[type];
    return pSlot;
}
// Calling thread's worker state in this pool, or null when the thread is not bound to a live pool
// or its group/worker is missing. Internal/debug builds also reject (with an error) threads bound
// to a different pool.
AuSPtr<ThreadState> ThreadPool::GetThreadState()
{
    auto pBoundPool = tlsCurrentThreadPool.lock();
    if (!pBoundPool)
    {
        return {};
    }

#if defined(AU_CFG_ID_INTERNAL) || defined(AU_CFG_ID_DEBUG)
    if (pBoundPool.get() != this)
    {
        SysPushErrorGeneric("wrong thread");
        return {};
    }
#endif

    auto workerId = *tlsWorkerId;
    if (auto pGroup = GetGroup(workerId.first))
    {
        return pGroup->GetThreadByIndex(workerId.second);
    }
    return {};
}
// Variant of GetThreadState that silently returns null, in every build, when the calling thread
// is unbound or bound to a different pool.
AuSPtr<ThreadState> ThreadPool::GetThreadStateNoWarn()
{
    auto pBoundPool = tlsCurrentThreadPool.lock();
    if (!pBoundPool || pBoundPool.get() != this)
    {
        return {};
    }

    auto workerId = *tlsWorkerId;
    if (auto pGroup = GetGroup(workerId.first))
    {
        return pGroup->GetThreadByIndex(workerId.second);
    }
    return {};
}
// Resolves the stored tlsWorkerId against this pool's groups, requiring only that the id's
// owning pool is still alive (it is kept locked for the duration of the lookup).
// NOTE(review): the owning pool is not compared against `this` — confirm that is intended.
AuSPtr<ThreadState> ThreadPool::GetThreadStateLocal()
{
    auto workerId = *tlsWorkerId;

    auto pOwner = AuTryLockMemoryType(workerId.pool);
    if (!pOwner)
    {
        return {};
    }

    auto pGroup = GetGroup(workerId.first);
    if (!pGroup)
    {
        return {};
    }
    return pGroup->GetThreadByIndex(workerId.second);
}
// State of any worker of this pool by id; null when the group or worker is missing.
AuSPtr<ThreadState> ThreadPool::GetThreadHandle(WorkerId_t id)
{
    if (auto pGroup = GetGroup(id.first))
    {
        return pGroup->GetThreadByIndex(id.second);
    }
    return {};
}
// States for `id`: the single worker for a concrete id, or every worker of the group
// (snapshotted under workersMutex) for kThreadIdAny. Empty when the group is missing.
AuList<AuSPtr<ThreadState>> ThreadPool::GetThreadHandles(WorkerId_t id)
{
    auto group = GetGroup(id.first);
    if (!group)
    {
        return {};
    }

    AuList<AuSPtr<ThreadState>> ret;
    if (id.second != Async::kThreadIdAny)
    {
        if (auto pPtr = group->GetThreadByIndex(id.second))
        {
            ret.push_back(pPtr);
        }
    }
    else
    {
        AU_LOCK_GLOBAL_GUARD(group->workersMutex);
        AuTryReserve(ret, group->workers.size());
        for (const auto &[key, value] : group->workers)
        {
            ret.push_back(value);
        }
    }

    // Plain return of the local: wrapping it in AuMove() would inhibit NRVO.
    return ret;
}
// Creates a new, empty thread pool.
// The scheduler is started here (not at startup) so that applications which never use async
// are not burdened with its overhead.
AUKN_SYM AuSPtr<IThreadPool> NewThreadPool()
{
    StartSched();
    AuSPtr<IThreadPool> pPool = AuMakeShared<ThreadPool>();
    return pPool;
}
}