127 lines
5.4 KiB
C++
127 lines
5.4 KiB
C++
|
/***
    Copyright (C) 2021 J Reece Wilson (a/k/a "Reece"). All rights reserved.

    File: ThreadPool.hpp
    Date: 2021-10-30
    Author: Reece
***/
|
||
|
#pragma once
|
||
|
|
||
|
namespace Aurora::Async
|
||
|
{
|
||
|
// Incomplete types: this header only ever refers to them through AuSPtr<...>
// (see ThreadPool's GetGroup / GetThreadState / ThreadDb_t), so a forward
// declaration is sufficient here.
struct ThreadState;
struct GroupState;

// class WorkItem;   // forward declaration currently disabled
|
||
|
|
||
|
struct IThreadPoolInternal
|
||
|
{
|
||
|
virtual bool WaitFor(WorkerId_t unlocker, const AuSPtr<Threading::IWaitable> &primitive, AuUInt32 ms) = 0;
|
||
|
virtual void Run(WorkerId_t target, AuSPtr<IAsyncRunnable> runnable) = 0;
|
||
|
virtual IThreadPool *ToThreadPool() = 0;
|
||
|
virtual void IncrementTasksRunning() = 0;
|
||
|
virtual void DecrementTasksRunning() = 0;
|
||
|
};
|
||
|
|
||
|
|
||
|
struct ThreadPool : public IThreadPool, public IThreadPoolInternal, std::enable_shared_from_this<ThreadPool>
|
||
|
{
|
||
|
ThreadPool();
|
||
|
|
||
|
// IThreadPoolInternal
|
||
|
bool WaitFor(WorkerId_t unlocker, const AuSPtr<Threading::IWaitable> &primitive, AuUInt32 ms) override;
|
||
|
void Run(WorkerId_t target, AuSPtr<IAsyncRunnable> runnable) override;
|
||
|
IThreadPool *ToThreadPool() override;
|
||
|
void IncrementTasksRunning() override;
|
||
|
void DecrementTasksRunning() override;
|
||
|
|
||
|
// IThreadPool
|
||
|
virtual bool Spawn(WorkerId_t workerId) override;
|
||
|
|
||
|
virtual void SetRunningMode(bool eventRunning) override;
|
||
|
|
||
|
virtual bool Create(WorkerId_t workerId) override;
|
||
|
|
||
|
virtual bool InRunnerMode() override;
|
||
|
|
||
|
virtual bool Poll() override;
|
||
|
virtual bool RunOnce() override;
|
||
|
virtual bool Run() override;
|
||
|
|
||
|
virtual void Shutdown() override;
|
||
|
virtual bool Exiting() override;
|
||
|
|
||
|
virtual AuSPtr<IWorkItem> NewWorkItem(const WorkerId_t &worker, const AuSPtr<IWorkItemHandler> &task, bool supportsBlocking) override;
|
||
|
virtual AuSPtr<IWorkItem> NewFence() override;
|
||
|
|
||
|
virtual Threading::Threads::ThreadShared_t ResolveHandle(WorkerId_t) override;
|
||
|
|
||
|
virtual AuBST<ThreadGroup_t, AuList<ThreadId_t>> GetThreads() override;
|
||
|
|
||
|
virtual WorkerId_t GetCurrentThread() override;
|
||
|
|
||
|
virtual bool Sync(WorkerId_t workerId, AuUInt32 timeoutMs, bool requireSignal) override;
|
||
|
virtual void Signal(WorkerId_t workerId) override;
|
||
|
virtual void SyncAllSafe() override;
|
||
|
|
||
|
virtual void AddFeature(WorkerId_t id, AuSPtr<Threading::Threads::IThreadFeature> feature, bool async) override;
|
||
|
|
||
|
virtual void AssertInThreadGroup(ThreadGroup_t group) override;
|
||
|
virtual void AssertWorker(WorkerId_t id) override;
|
||
|
|
||
|
virtual bool ScheduleLoopSource(const AuSPtr<Loop::ILoopSource> &loopSource, WorkerId_t workerId, AuUInt32 timeout, const AuConsumer<AuSPtr<Loop::ILoopSource>, bool> &callback) override;
|
||
|
|
||
|
// Internal API
|
||
|
|
||
|
bool Spawn(WorkerId_t workerId, bool create);
|
||
|
|
||
|
bool InternalRunOne(bool block);
|
||
|
bool PollInternal(bool block);
|
||
|
bool PollLoopSource(bool block);
|
||
|
|
||
|
size_t GetThreadWorkersCount(ThreadGroup_t group);
|
||
|
|
||
|
virtual void CleanUpWorker(WorkerId_t wid) {};
|
||
|
virtual void CleanWorkerPoolReservedZeroFree() {}; // calls shutdown under async apps
|
||
|
|
||
|
// Secret old fiber api
|
||
|
bool CtxYield();
|
||
|
int CtxPollPush();
|
||
|
void CtxPollReturn(const AuSPtr<ThreadState> &state, int status, bool hitTask);
|
||
|
|
||
|
// TLS handle
|
||
|
struct WorkerWPId_t : WorkerId_t
|
||
|
{
|
||
|
WorkerWPId_t()
|
||
|
{}
|
||
|
|
||
|
WorkerWPId_t(const WorkerPId_t &ref) : WorkerId_t(ref.first, ref.second), pool(ref.pool)
|
||
|
{}
|
||
|
|
||
|
AuWPtr<IThreadPool> pool;
|
||
|
};
|
||
|
|
||
|
AuThreads::TLSVariable<WorkerWPId_t> tlsWorkerId;
|
||
|
|
||
|
private:
|
||
|
// TODO: BarrierMultiple
|
||
|
bool Barrier(WorkerId_t, AuUInt32 ms, bool requireSignal, bool drop);
|
||
|
|
||
|
protected:
|
||
|
void Entrypoint(WorkerId_t id);
|
||
|
|
||
|
private:
|
||
|
void ThisExiting();
|
||
|
|
||
|
AuSPtr<GroupState> GetGroup(ThreadGroup_t type);
|
||
|
AuSPtr<ThreadState> GetThreadState();
|
||
|
AuSPtr<ThreadState> GetThreadHandle(WorkerId_t id);
|
||
|
|
||
|
using ThreadDb_t = AuBST<ThreadGroup_t, AuSPtr<GroupState>>;
|
||
|
|
||
|
ThreadDb_t threads_;
|
||
|
bool shuttingdown_ {};
|
||
|
AuThreadPrimitives::RWLockUnique_t rwlock_;
|
||
|
std::atomic_int tasksRunning_;
|
||
|
bool runnersRunning_;
|
||
|
};
|
||
|
}
|