[+] AuVoidTask

(https://devblogs.microsoft.com/oldnewthing/20190116-00/?p=100715)
[*] Amend coroutines Awaitable to prevent use after free
This commit is contained in:
Reece Wilson 2023-07-08 14:49:42 +01:00
parent 69a5bb8061
commit a1b07c634a

View File

@ -693,6 +693,69 @@ namespace std
#endif #endif
} }
struct AuVoidTask
{
struct promise_type
{
AuVoidTask get_return_object()
{
return {};
}
std::suspend_never initial_suspend()
{
return {};
}
std::suspend_never final_suspend() noexcept
{
return {};
}
void return_void()
{ }
void unhandled_exception()
{ }
};
};
#else
#if defined(AURORA_IS_MODERNNT_DERIVED)
#include <pplawait.h>
using AuVoidTask = concurrency::task<void>;
#else
struct AuVoidTask
{
struct promise_type
{
AuVoidTask get_return_object()
{
return {};
}
bool initial_suspend()
{
return false;
}
bool final_suspend() noexcept
{
return false;
}
void return_void()
{ }
void unhandled_exception()
{ }
};
};
#endif
#endif #endif
namespace __detail namespace __detail
@ -702,6 +765,11 @@ namespace __detail
{ {
AuSharedFuture<A, B> pFuture; AuSharedFuture<A, B> pFuture;
Awaitable(AuSharedFuture<A, B> pFuture) :
pFuture(pFuture)
{
}
bool await_ready() bool await_ready()
{ {
return __detail::FutureAccessor::IsFinished(*pFuture.get()); return __detail::FutureAccessor::IsFinished(*pFuture.get());
@ -710,21 +778,28 @@ namespace __detail
template <typename T> template <typename T>
void await_suspend(T h) void await_suspend(T h)
{ {
pFuture->OnComplete([=](const A &) auto pFuture = this->pFuture;
pFuture->OnComplete([h = h](const A &)
{ {
h.resume(); h.resume();
}); });
if (__detail::FutureAccessor::IsFinished(*pFuture.get()))
{
return;
}
if constexpr (!AuIsVoid_v<B>) if constexpr (!AuIsVoid_v<B>)
{ {
pFuture->OnFailure([=](const B &) pFuture->OnFailure([h = h](const B &)
{ {
h.resume(); h.resume();
}); });
} }
else else
{ {
pFuture->OnFailure([=]() pFuture->OnFailure([h = h]()
{ {
h.resume(); h.resume();
}); });
@ -757,21 +832,28 @@ namespace __detail
template <typename T> template <typename T>
void await_suspend(T h) void await_suspend(T h)
{ {
pFuture->OnComplete([=]() auto pFuture = this->pFuture;
pFuture->OnComplete([h = h]()
{ {
h.resume(); h.resume();
}); });
if (__detail::FutureAccessor::IsFinished(*pFuture.get()))
{
return;
}
if constexpr (!AuIsVoid_v<B>) if constexpr (!AuIsVoid_v<B>)
{ {
pFuture->OnFailure([=](const B &) pFuture->OnFailure([h = h](const B &)
{ {
h.resume(); h.resume();
}); });
} }
else else
{ {
pFuture->OnFailure([=]() pFuture->OnFailure([h = h]()
{ {
h.resume(); h.resume();
}); });
@ -788,14 +870,14 @@ namespace __detail
#if defined(__AUHAS_COROUTINES_CO_AWAIT) #if defined(__AUHAS_COROUTINES_CO_AWAIT)
template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(!AuIsVoid_v<A>)> template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(!AuIsVoid_v<A>)>
inline auto operator co_await (const AuSharedFuture<A, B> &pFuture) inline auto operator co_await (AuSharedFuture<A, B> pFuture)
{ {
SysAssert(pFuture); SysAssert(pFuture);
return __detail::Awaitable<A, B> { pFuture }; return __detail::Awaitable<A, B> { pFuture };
} }
template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(AuIsVoid_v<A>)> template <typename A, typename B, AU_TEMPLATE_ENABLE_WHEN(AuIsVoid_v<A>)>
inline auto operator co_await (const AuSharedFuture<A, B> &pFuture) inline auto operator co_await (AuSharedFuture<A, B> pFuture)
{ {
SysAssert(pFuture); SysAssert(pFuture);
return __detail::AwaitableVoid<B> { pFuture }; return __detail::AwaitableVoid<B> { pFuture };