From 9711061dab67029d53cbf06702b9c29d708cfeb6 Mon Sep 17 00:00:00 2001 From: Jamie Reece Wilson Date: Fri, 29 Nov 2024 07:58:50 +0000 Subject: [PATCH] [+] Aurora::Threading::WaitForMultipleAddressesOr [+] Aurora::Threading::WaitForMultipleAddressesAnd [+] Aurora::Threading::WaitMultipleEntry [+] Aurora::Threading::WaitMulipleContainer --- Include/Aurora/Threading/WakeOnAddress.hpp | 42 ++ Source/Threading/AuWakeOnAddress.cpp | 470 +++++++++++++++++++-- Source/Threading/AuWakeOnAddress.hpp | 24 +- 3 files changed, 505 insertions(+), 31 deletions(-) diff --git a/Include/Aurora/Threading/WakeOnAddress.hpp b/Include/Aurora/Threading/WakeOnAddress.hpp index 84b527ad..a63551b7 100644 --- a/Include/Aurora/Threading/WakeOnAddress.hpp +++ b/Include/Aurora/Threading/WakeOnAddress.hpp @@ -101,6 +101,43 @@ namespace Aurora::Threading AUE_DEFINE(EWaitMethod, ( eNotEqual, eEqual, eLessThanCompare, eGreaterThanCompare, eLessThanOrEqualsCompare, eGreaterThanOrEqualsCompare, eAnd, eNotAnd )) + + struct AU_ALIGN(sizeof(void *)) WaitMultipleEntry + { + AuUInt8 internalContext[128]; + + // See WaitOnAddressSpecialSteady + EWaitMethod eMethod { EWaitMethod::eNotEqual }; + + // See WaitOnAddressSteady + union + { + const void * pTargetAddress; + const volatile void *pTargetVolatileAddress; + }; + + // See WaitOnAddressSteady + const void *pCompareAddress; + + // See WaitOnAddressSteady + AuUInt8 uSize {}; + + // For each valid state change, this counter gets incremented by 1, allowing for list reuse. + AuUInt16 uHasStateChangedCounter {}; + + // Skip this current entry + bool bIgnoreCurrentFlag {}; + }; + + struct WaitMulipleContainer + { + // Assign this to an virtually contiguous array of WaitMultipleEntry entries. + // A AuList, std::vector, WaitMultipleHead extent[N], or { raw pointer, count } will suffice. + AuMemoryViewWrite waitArray; + + // 0 = indefinite, AuTime::SteadyClockXXX convention + AuUInt64 qwNanoseconds {}; + }; AUKN_SYM void WakeAllOnAddress(const void *pTargetAddress); @@ -177,6 +214,11 @@ namespace Aurora::Threading AuUInt64 qwNanoseconds, AuOptional optAlreadySpun = {} /*hint: do not spin before switching. subject to global config.*/); + AUKN_SYM bool WaitForMultipleAddressesOr(const WaitMulipleContainer &waitMultipleOnAddress); + + AUKN_SYM bool WaitForMultipleAddressesAnd(const WaitMulipleContainer &waitMultipleOnAddress); + + // C++ doesn't allow for implicit casting between nonvolatile and volatile pointers. // The following stubs unify the above APIs for non-volatile marked atomic containers. // Whether the underlying data of "pTargetAddress" is thread-locally-volatile or not is upto the chosen compiler intrin used to load/store and/or whether you upcast to volatile later on. diff --git a/Source/Threading/AuWakeOnAddress.cpp b/Source/Threading/AuWakeOnAddress.cpp index 858bd804..c1c6565a 100644 --- a/Source/Threading/AuWakeOnAddress.cpp +++ b/Source/Threading/AuWakeOnAddress.cpp @@ -29,6 +29,21 @@ namespace Aurora::Threading { + struct WaitMulipleContainer; + + struct MultipleInternalContext + { + WaitState state; + WaitEntry *pBefore {}; + WaitEntry *pNext {}; + bool bOldIgnore {}; + AuUInt16 uOldStateChangedCounter {}; + }; + + static WaitEntry **GetPBeforeFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress); + static WaitEntry **GetPNextFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress); + static const void *GetPCompareFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress); + #if defined(HACK_NO_INVALID_ACCESS_LEAK_SHARED_REF_ON_DESTROYED_THREAD) static thread_local AuSPtr tlsWaitEntry = AuMakeSharedPanic(); #else @@ -236,15 +251,138 @@ namespace Aurora::Threading return false; } - bool WaitEntry::TrySignalAddress(const void *pAddress) + bool WaitEntry::SleepLossy(AuUInt64 qwNanosecondsAbs) { - if (this->pAddress != pAddress) + #if !defined(WOA_SEMAPHORE_MODE) + AU_LOCK_GUARD(this->mutex); + #endif + + if (qwNanosecondsAbs) { + #if defined(WOA_SEMAPHORE_MODE) + return this->semaphore->LockAbsNS(uEndTime); + #else + auto uNow = AuTime::SteadyClockNS(); + + while (uNow < qwNanosecondsAbs) + { + if (!AuAtomicLoad(&this->bAlive)) + { + return true; + } + + auto uTimeRemNS = qwNanosecondsAbs - uNow; + if (this->variable.WaitForSignalNsEx(&this->mutex, uTimeRemNS, false)) + { + return true; + } + + uNow = AuTime::SteadyClockNS(); + } + return false; + #endif + } + else + { + if (!AuAtomicLoad(&this->bAlive)) + { + return true; + } + + #if defined(WOA_SEMAPHORE_MODE) + this->semaphore->Lock(); + #else + this->variable.WaitForSignalNsEx(&this->mutex, 0, false); + #endif + + return true; } - if (this->pCompareAddress) + return false; + } + + WaitEntry *WaitEntry::GetBefore(const void *pAddress) + { + if (auto pSpecial = this->pSpecial) { + return *GetPBeforeFromContainer(pSpecial, pAddress); + } + else + { + return this->pBefore; + } + } + + void WaitEntry::SetBefore(const void *pAddress, WaitEntry *pNext) + { + if (auto pSpecial = this->pSpecial) + { + if (auto pNextEntry = GetPBeforeFromContainer(pSpecial, pAddress)) + { + *pNextEntry = pNext; + } + else + { + SysUnreachable(); + } + } + else + { + this->pBefore = pNext; + } + } + + WaitEntry *WaitEntry::GetNext(const void *pAddress) + { + if (auto pSpecial = this->pSpecial) + { + return *GetPNextFromContainer(pSpecial, pAddress); + } + else + { + return this->pNext; + } + } + + void WaitEntry::SetNext(const void *pAddress, WaitEntry *pNext) + { + if (auto pSpecial = this->pSpecial) + { + if (auto pNextEntry = GetPNextFromContainer(pSpecial, pAddress)) + { + *pNextEntry = pNext; + } + else + { + SysUnreachable(); + } + } + else + { + this->pNext = pNext; + } + } + + bool WaitEntry::TrySignalAddress(const void *pAddress) + { + if (auto pSpecial = this->pSpecial) + { + if (auto pCompare = GetPCompareFromContainer(pSpecial, pAddress)) + { + if (WaitBuffer::Compare(pAddress, this->uSize, pCompare, kMax64, this->eWaitMethod)) + { + return false; + } + } + } + else if (this->pCompareAddress) + { + if (this->pAddress != pAddress) + { + return false; + } + if (WaitBuffer::Compare(pAddress, this->uSize, this->pCompareAddress, kMax64, this->eWaitMethod)) { return false; @@ -610,6 +748,7 @@ namespace Aurora::Threading pReturn->uSize = uSize; pReturn->pCompareAddress = pCompareAddress; pReturn->eWaitMethod = eWaitMethod; + pReturn->pSpecial = nullptr; if (bScheduleFirst /*First in, First Out*/) { @@ -629,8 +768,8 @@ namespace Aurora::Threading pReturn->bAlive = true; if (auto pLoadFromMemory = this->waitList.pHead) { - pLoadFromMemory->pBefore = pReturn; - pReturn->pNext = pLoadFromMemory; + pLoadFromMemory->SetBefore(pAddress, pReturn); + pReturn->SetNext(pAddress, pLoadFromMemory); } else { @@ -648,8 +787,8 @@ namespace Aurora::Threading pReturn->bAlive = true; if (auto pLoadFromMemory = this->waitList.pTail) { - pLoadFromMemory->pNext = pReturn; - pReturn->pBefore = pLoadFromMemory; + pLoadFromMemory->SetNext(pAddress, pReturn); + pReturn->SetBefore(pAddress, pLoadFromMemory); } else { @@ -663,8 +802,70 @@ namespace Aurora::Threading return pReturn; } + WaitEntry *ProcessWaitNodeContainer::WaitBufferFrom2(const void *pAddress, + AuUInt8 uSize, + const void *pAddressCompare, + EWaitMethod eWaitMethod, + MultipleInternalContext *pContext, + const WaitMulipleContainer *pContainer) + { + #if defined(HACK_NO_INVALID_ACCESS_LEAK_SHARED_REF_ON_DESTROYED_THREAD) + auto pReturn = tlsWaitEntry.get(); + #else + auto pReturn = &tlsWaitEntry; + #endif + + pReturn->pAddress = pAddress; + pReturn->uSize = uSize; + pReturn->pCompareAddress = pAddressCompare; + pReturn->eWaitMethod = eWaitMethod; + + Lock(); + + if (!WaitBuffer::Compare(pAddress, uSize, pAddressCompare, kMax64, eWaitMethod)) + { + pReturn->bAlive = false; + Unlock(); + return nullptr; + } + + bool bAddToArray {}; + if (!pReturn->bAlive) + { + pReturn->bAlive = true; + bAddToArray = true; + } + else + { + // TODO: traverse list and reject duplicates + bAddToArray = true; + } + + if (bAddToArray) + { + if (auto pLoadFromMemory = this->waitList.pHead) + { + if (pLoadFromMemory != pReturn) + { + pLoadFromMemory->SetBefore(pAddress, pReturn); + pReturn->SetNext(pAddress, pLoadFromMemory); + } + } + else + { + this->waitList.pTail = pReturn; + } + this->waitList.pHead = pReturn; + } + + pReturn->pSpecial = pContainer; + Unlock(); + + return pReturn; + } + template - bool ProcessWaitNodeContainer::IterateWake(T callback) + bool ProcessWaitNodeContainer::IterateWake(const void *pAddress, T callback) { bool bRetStatus { true }; @@ -698,11 +899,11 @@ namespace Aurora::Threading auto [bCont, bRemove] = callback(*pCurrentHead); - pBefore = pCurrentHead->pBefore; + pBefore = pCurrentHead->GetBefore(pAddress); if (bRemove) { - this->RemoveEntry(pCurrentHead); + this->RemoveEntry(pAddress, pCurrentHead); } else { @@ -729,46 +930,49 @@ namespace Aurora::Threading } template - void ProcessWaitNodeContainer::RemoveEntry(WaitEntry *pEntry) + void ProcessWaitNodeContainer::RemoveEntry(const void *pAddress, WaitEntry *pEntry) { + auto pNext = pEntry->GetNext(pAddress); + auto pBefore = pEntry->GetBefore(pAddress); + if (this->waitList.pHead == pEntry) { - this->waitList.pHead = pEntry->pNext; + this->waitList.pHead = pNext; } if (this->waitList.pTail == pEntry) { - this->waitList.pTail = pEntry->pBefore; + this->waitList.pTail = pBefore; } - if (auto pBefore = pEntry->pBefore) + if (pBefore) { - pBefore->pNext = pEntry->pNext; + pBefore->SetNext(pAddress, pNext); } - if (auto pNext = pEntry->pNext) + if (pNext) { - pNext->pBefore = pEntry->pBefore; + pNext->SetBefore(pAddress, pBefore); } if (bAllUnderLock) { - pEntry->pBefore = nullptr; - pEntry->pNext = nullptr; + pEntry->SetNext(pAddress, nullptr); + pEntry->SetBefore(pAddress, nullptr); //pEntry->bAlive = false; - redundant } } - void ProcessWaitNodeContainer::RemoveSelf(WaitEntry *pSelf) + void ProcessWaitNodeContainer::RemoveSelf(const void *pAddress, WaitEntry *pSelf) { { this->Lock(); - this->RemoveEntry(pSelf); + this->RemoveEntry(pAddress, pSelf); this->Unlock(); } - pSelf->pBefore = nullptr; - pSelf->pNext = nullptr; + pSelf->SetBefore(pAddress, nullptr); + pSelf->SetNext(pAddress, nullptr); pSelf->bAlive = false; } @@ -789,15 +993,20 @@ namespace Aurora::Threading return this->list[AddressToIndex].WaitBufferFrom(pAddress, uSize, bScheduleFirst, pCompareAddress, eWaitMethod); } + WaitEntry *ProcessWaitContainer::WaitBufferFrom2(const void *pAddress, AuUInt8 uSize, const void *pAddressCompare, EWaitMethod eWaitMethod, MultipleInternalContext *pContext, const WaitMulipleContainer *pContainer) + { + return this->list[AddressToIndex].WaitBufferFrom2(pAddress, uSize, pAddressCompare, eWaitMethod, pContext, pContainer); + } + template bool ProcessWaitContainer::IterateWake(const void *pAddress, T callback) { - return this->list[AddressToIndex].IterateWake(callback); + return this->list[AddressToIndex].IterateWake(pAddress, callback); } void ProcessWaitContainer::RemoveSelf(const void *pAddress, WaitEntry *pSelf) { - return this->list[AddressToIndex].RemoveSelf(pSelf); + return this->list[AddressToIndex].RemoveSelf(pAddress, pSelf); } bool IsNativeWaitOnSupported() @@ -865,6 +1074,10 @@ namespace Aurora::Threading SysAssertDbg(uWordSize <= 32); auto pWaitEntry = gProcessWaitables.WaitBufferFrom(pTargetAddress, uWordSize, true, pCompareAddress2, T); + if (!pWaitEntry) + { + return true; + } // Unlocked update to a safer comparison address; hardens against bad code { @@ -1609,4 +1822,211 @@ namespace Aurora::Threading return false; } + + static WaitEntry **GetPNextFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress) + { + auto uCount = pContainer->waitArray.Count(); + auto pBase = pContainer->waitArray.Begin(); + + for (AU_ITERATE_N(i, uCount)) + { + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + if (pBase[i].pTargetAddress == pAddress) + { + return &pCurrent->pNext; + } + } + + return nullptr; + } + + static WaitEntry **GetPBeforeFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress) + { + auto uCount = pContainer->waitArray.Count(); + auto pBase = pContainer->waitArray.Begin(); + + for (AU_ITERATE_N(i, uCount)) + { + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + if (pBase[i].pTargetAddress == pAddress) + { + return &pCurrent->pBefore; + } + } + + return nullptr; + } + + static const void *GetPCompareFromContainer(const WaitMulipleContainer *pContainer, const void *pAddress) + { + auto uCount = pContainer->waitArray.Count(); + auto pBase = pContainer->waitArray.Begin(); + + for (AU_ITERATE_N(i, uCount)) + { + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + if (pBase[i].pTargetAddress == pAddress) + { + return pBase[i].pCompareAddress; + } + } + + return nullptr; + } + + static_assert(sizeof(MultipleInternalContext) <= sizeof(WaitMultipleEntry::internalContext)); + + AUKN_SYM bool WaitForMultipleAddressesOr(const WaitMulipleContainer &waitMultipleOnAddress) + { + bool bResult {}, bAny {}, bSleepStatus {}; + WaitEntry *pWaitEntryMain {}, *pWaitEntryAux {}; + + SysAssertDbg(!IsWaitOnRecommended(), "WoA not in emulation mode"); + + auto uCount = waitMultipleOnAddress.waitArray.Count(); + auto pBase = waitMultipleOnAddress.waitArray.Begin(); + + #if defined(HACK_NO_INVALID_ACCESS_LEAK_SHARED_REF_ON_DESTROYED_THREAD) + auto pTempHoldMe = tlsWaitEntry; + #endif + + do + { + for (AU_ITERATE_N(i, uCount)) + { + auto ¤t = pBase[i]; + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + auto &state = pCurrent->state; + + if (current.bIgnoreCurrentFlag) + { + continue; + } + + pCurrent->pBefore = nullptr; + pCurrent->pNext = nullptr; + + pWaitEntryAux = gProcessWaitables.WaitBufferFrom2(current.pTargetAddress, current.uSize, current.pCompareAddress, current.eMethod, pCurrent, &waitMultipleOnAddress); + if (!pWaitEntryAux) + { + break; + } + else + { + pWaitEntryMain = pWaitEntryAux; + } + + state.qwNanosecondsAbs = waitMultipleOnAddress.qwNanoseconds; + bAny = true; + } + + if (!bAny) + { + return true; + } + + bSleepStatus = pWaitEntryAux && pWaitEntryAux->SleepLossy(waitMultipleOnAddress.qwNanoseconds); + + for (AU_ITERATE_N(i, uCount)) + { + auto ¤t = pBase[i]; + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + auto &state = pCurrent->state; + + if (current.bIgnoreCurrentFlag) + { + continue; + } + + if (!WaitBuffer::Compare(current.pTargetAddress, current.uSize, current.pCompareAddress, kMax64, current.eMethod)) + { + current.uHasStateChangedCounter++; + bResult = true; + } + + gProcessWaitables.RemoveSelf(current.pTargetAddress, pWaitEntryMain); + } + } + while (!bResult && (!waitMultipleOnAddress.qwNanoseconds || bSleepStatus)); + + #if defined(HACK_NO_INVALID_ACCESS_LEAK_SHARED_REF_ON_DESTROYED_THREAD) + pTempHoldMe.reset(); + #endif + + return bResult; + } + + AUKN_SYM bool WaitForMultipleAddressesAnd(const WaitMulipleContainer &waitMultipleOnAddress) + { + auto uCount = waitMultipleOnAddress.waitArray.Count(); + auto pBase = waitMultipleOnAddress.waitArray.Begin(); + + for (AU_ITERATE_N(i, uCount)) + { + auto ¤t = pBase[i]; + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + pCurrent->bOldIgnore = current.bIgnoreCurrentFlag; + pCurrent->uOldStateChangedCounter = current.uHasStateChangedCounter; + } + + bool bFound {}; + bool bTimeout {}; + do + { + bool bRet = WaitForMultipleAddressesOr(waitMultipleOnAddress); + + bFound = false; + + for (AU_ITERATE_N(i, uCount)) + { + auto ¤t = pBase[i]; + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + if (current.bIgnoreCurrentFlag) + { + continue; + } + + bool bTriggered = current.uHasStateChangedCounter - pCurrent->uOldStateChangedCounter; + + if (bTriggered) + { + current.bIgnoreCurrentFlag = true; + } + else + { + bFound = true; + } + } + + if (waitMultipleOnAddress.qwNanoseconds && !bRet) + { + bTimeout = true; + break; + } + } + while (bFound); + + for (AU_ITERATE_N(i, uCount)) + { + auto ¤t = pBase[i]; + auto pCurrent = AuReinterpretCast(pBase[i].internalContext); + + current.bIgnoreCurrentFlag = pCurrent->bOldIgnore; + current.uHasStateChangedCounter = pCurrent->uOldStateChangedCounter + (bFound ? 0 : 1); + } + + if (!bFound) + { + return true; + } + else + { + return !bTimeout; + } + } } \ No newline at end of file diff --git a/Source/Threading/AuWakeOnAddress.hpp b/Source/Threading/AuWakeOnAddress.hpp index dcb7c3ee..eb81b24e 100644 --- a/Source/Threading/AuWakeOnAddress.hpp +++ b/Source/Threading/AuWakeOnAddress.hpp @@ -28,6 +28,8 @@ namespace Aurora::Threading static const auto kPlatformFutexNoForcedAlignedU32 = AuBuild::kIsNTDerived; struct WaitState; + struct MultipleInternalContext; + struct WaitMulipleContainer; struct WaitBuffer { @@ -62,9 +64,10 @@ namespace Aurora::Threading { WaitEntry(); ~WaitEntry(); - - WaitEntry * volatile pNext {}; - WaitEntry * volatile pBefore {}; + + WaitEntry * volatile pNext {}; + WaitEntry * volatile pBefore {}; + const WaitMulipleContainer * volatile pSpecial {}; // synch #if defined(WOA_SEMAPHORE_MODE) @@ -105,7 +108,14 @@ namespace Aurora::Threading template bool SleepOn(WaitState &state); + bool SleepLossy(AuUInt64 qwNanosecondsAbs); bool TrySignalAddress(const void *pAddress); + + auline WaitEntry *GetNext(const void *pAddress); + auline void SetNext(const void *pAddress, WaitEntry *pNext); + + auline WaitEntry *GetBefore(const void *pAddress); + auline void SetBefore(const void *pAddress, WaitEntry *pNext); }; struct ProcessListWait @@ -120,14 +130,15 @@ namespace Aurora::Threading ProcessListWait waitList; WaitEntry *WaitBufferFrom(const void *pAddress, AuUInt8 uSize, bool bScheduleFirst, const void *pAddressCompare, EWaitMethod eWaitMethod); + WaitEntry *WaitBufferFrom2(const void *pAddress, AuUInt8 uSize, const void *pAddressCompare, EWaitMethod eWaitMethod, MultipleInternalContext *pContext, const WaitMulipleContainer *pContainer); template - bool IterateWake(T callback); + bool IterateWake(const void *pAddress, T callback); - void RemoveSelf(WaitEntry *pSelf); + void RemoveSelf(const void *pAddress, WaitEntry *pSelf); template - void RemoveEntry(WaitEntry *pSelf); + void RemoveEntry(const void *pAddress, WaitEntry *pSelf); void Lock(); @@ -139,6 +150,7 @@ namespace Aurora::Threading ProcessWaitNodeContainer list[kDefaultWaitPerProcess]; WaitEntry *WaitBufferFrom(const void *pAddress, AuUInt8 uSize, bool bScheduleFirst, const void *pAddressCompare, EWaitMethod eWaitMethod); + WaitEntry *WaitBufferFrom2(const void *pAddress, AuUInt8 uSize, const void *pAddressCompare, EWaitMethod eWaitMethod, MultipleInternalContext *pContext, const WaitMulipleContainer *pContainer); template bool IterateWake(const void *pAddress, T callback);