1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>

#include "pw_allocator/allocator.h"
#include "pw_allocator/layout.h"
#include "pw_async2/dispatcher.h"
#include "pw_log/log.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
25
26 namespace pw::async2 {
27
28 // Forward-declare `Coro` so that it can be referenced by the promise type APIs.
29 template <std::constructible_from<pw::Status> T>
30 class Coro;
31
32 /// Context required for creating and executing coroutines.
33 class CoroContext {
34 public:
35 /// Creates a `CoroContext` which will allocate coroutine state using
36 /// `alloc`.
CoroContext(pw::allocator::Allocator & alloc)37 explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
alloc()38 pw::allocator::Allocator& alloc() const { return alloc_; }
39
40 private:
41 pw::allocator::Allocator& alloc_;
42 };
43
44 // The internal coroutine API implementation details enabling `Coro<T>`.
45 //
46 // Users of `Coro<T>` need not concern themselves with these details, unless
47 // they think it sounds like fun ;)
48 namespace internal {
49
// Storage for a value of type `T` that will only be produced later.
//
// The point of this type is to pay for `std::optional` only when needed.
// This primary template handles `T` that are NOT default-initializable and
// wraps a `std::optional<T>`. The partial specialization below stores a bare
// `T` for types that are default-initializable.
//
// A freshly constructed container therefore holds either:
//  - a value-initialized `T` (default-initializable specialization), or
//  - nothing, i.e. `std::nullopt` (this primary template).
template <typename T>
class OptionalOrDefault final {
 public:
  // Starts out empty; no `T` is constructed until one is assigned.
  OptionalOrDefault() : value_(std::nullopt) {}

  // Stores `new_value`, forwarding it into the underlying `std::optional`.
  template <typename Value>
  OptionalOrDefault& operator=(Value&& new_value) {
    value_ = std::forward<Value>(new_value);
    return *this;
  }

  // Returns a reference to the stored value.
  //
  // Fails a `PW_ASSERT` if nothing has been assigned yet.
  T& operator*() {
    PW_ASSERT(value_.has_value());
    return *value_;
  }

 private:
  std::optional<T> value_;
};
82
83 // A specialization of `OptionalOrDefault<T>` for `default_initializable`
84 // types.
85 template <std::default_initializable T>
86 class OptionalOrDefault<T> final {
87 public:
88 // Create a container for a to-be-provided value by default-initializing.
OptionalOrDefault()89 OptionalOrDefault() : value_() {}
90
91 // Assign a value.
92 template <typename U>
93 OptionalOrDefault& operator=(U&& value) {
94 value_ = std::forward<U>(value);
95 return *this;
96 }
97
98 // Retrieve the inner value.
99 //
100 // This operation will return a default-constructed `T` if no value was
101 // assigned. Typical users should not rely on this, and should instead
102 // only retrieve values assigned using `operator=`.
103 T& operator*() { return value_; }
104
105 private:
106 T value_;
107 };
108
109 // A wrapper for `std::coroutine_handle` that assumes unique ownership of the
110 // underlying `PromiseType`.
111 //
112 // This type will `destroy()` the underlying promise in its destructor, or
113 // when `Release()` is called.
114 template <typename PromiseType>
115 class OwningCoroutineHandle final {
116 public:
117 // Construct a null (`!IsValid()`) handle.
OwningCoroutineHandle(nullptr_t)118 OwningCoroutineHandle(nullptr_t) : promise_handle_(nullptr) {}
119
120 /// Take ownership of `promise_handle`.
OwningCoroutineHandle(std::coroutine_handle<PromiseType> && promise_handle)121 OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
122 : promise_handle_(std::move(promise_handle)) {}
123
124 // Empty out `other` and transfers ownership of its `promise_handle`
125 // to `this`.
OwningCoroutineHandle(OwningCoroutineHandle && other)126 OwningCoroutineHandle(OwningCoroutineHandle&& other)
127 : promise_handle_(std::move(other.promise_handle_)) {
128 other.promise_handle_ = nullptr;
129 }
130
131 // Empty out `other` and transfers ownership of its `promise_handle`
132 // to `this`.
133 OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
134 Release();
135 promise_handle_ = std::move(other.promise_handle_);
136 other.promise_handle_ = nullptr;
137 return *this;
138 }
139
140 // `destroy()`s the underlying `promise_handle` if valid.
~OwningCoroutineHandle()141 ~OwningCoroutineHandle() { Release(); }
142
143 // Return whether or not this value contains a `promise_handle`.
144 //
145 // This will return `false` if this `OwningCoroutineHandle` was
146 // `nullptr`-initialized, moved from, or if `Release` was invoked.
IsValid()147 [[nodiscard]] bool IsValid() const {
148 return promise_handle_.address() != nullptr;
149 }
150
151 // Return a reference to the underlying `PromiseType`.
152 //
153 // Precondition: `IsValid()` must be `true`.
promise()154 [[nodiscard]] PromiseType& promise() const {
155 return promise_handle_.promise();
156 }
157
158 // Whether or not the underlying coroutine has completed.
159 //
160 // Precondition: `IsValid()` must be `true`.
done()161 [[nodiscard]] bool done() const { return promise_handle_.done(); }
162
163 // Resume the underlying coroutine.
164 //
165 // Precondition: `IsValid()` must be `true`, and `done()` must be
166 // `false`.
resume()167 void resume() { promise_handle_.resume(); }
168
169 // Invokes `destroy()` on the underlying promise and deallocates its
170 // associated storage.
Release()171 void Release() {
172 void* address = promise_handle_.address();
173 if (address != nullptr) {
174 pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
175 promise_handle_.destroy();
176 promise_handle_ = nullptr;
177 dealloc.Deallocate(address);
178 }
179 }
180
181 private:
182 std::coroutine_handle<PromiseType> promise_handle_;
183 };
184
// `Awaitable` is the awaiter produced for every `co_await`. It is declared
// here because `CoroPromiseType::await_transform` names it before its full
// definition appears.
template <typename PendableType, typename Promise>
class Awaitable;
188
189 // A container for values passed in and out of the promise.
190 //
191 // The C++20 coroutine `resume()` function cannot accept arguments no return
192 // values, so instead coroutine inputs and outputs are funneled through this
193 // type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
194 // so that the coroutine object can access it.
195 template <typename T>
196 struct InOut final {
197 // The `Context` passed into the coroutine via `Pend`.
198 Context* input_cx;
199
200 // The output assigned to by the coroutine if the coroutine is `done()`.
201 OptionalOrDefault<T>* output;
202 };
203
204 // A base class for `Awaitable` instances.
205 //
206 // Each `co_await` statement creates an `Awaitable` object whose `Pend`
207 // method must be completed before the coroutine's `resume()` function can
208 // be invoked.
209 class AwaitableBase {
210 public:
211 // Attempt to complete the current pendable value passed to `co_await`,
212 // storing its return value inside the `Awaitable` object so that it can
213 // be retrieved by the coroutine.
214 virtual Poll<> PendFillReturnValue(Context& cx) = 0;
215
216 protected:
217 // A protected destructor ensures that child classes are never destroyed
218 // through a base pointer, so no virtual destructor is needed.
~AwaitableBase()219 ~AwaitableBase() {}
220 };
221
222 // The `promise_type` of `Coro<T>`.
223 //
224 // To understand this type, it may be necessary to refer to the reference
225 // documentation for the C++20 coroutine API.
226 template <typename T>
227 class CoroPromiseType final {
228 public:
229 // Construct the `CoroPromiseType` using the arguments passed to a
230 // function returning `Coro<T>`.
231 //
232 // The first argument *must* be a `CoroContext`. The other
233 // arguments are unused, but must be accepted in order for this to compile.
234 template <typename... Args>
CoroPromiseType(CoroContext & cx,const Args &...)235 CoroPromiseType(CoroContext& cx, const Args&...)
236 : dealloc_(cx.alloc()), current_awaitable_(nullptr), in_out_(nullptr) {}
237
238 // Get the `Coro<T>` after successfully allocating the coroutine space
239 // and constructing `this`.
240 Coro<T> get_return_object();
241
242 // Do not begin executing the `Coro<T>` until `resume()` has been invoked
243 // for the first time.
initial_suspend()244 std::suspend_always initial_suspend() { return {}; }
245
246 // Unconditionally suspend to prevent `destroy()` being invoked.
247 //
248 // The caller of `resume()` needs to first observe `done()` before the
249 // state can be destroyed.
250 //
251 // Setting this to suspend means that the caller is responsible for invoking
252 // `destroy()`.
final_suspend()253 std::suspend_always final_suspend() noexcept { return {}; }
254
255 // Store the `co_return` argument in the `InOut<T>` object provided by
256 // the `Pend` wrapper.
257 template <std::convertible_to<T> From>
return_value(From && value)258 void return_value(From&& value) {
259 *in_out_->output = std::forward<From>(value);
260 }
261
262 // Ignore exceptions in coroutines.
263 //
264 // Pigweed is not designed to be used with exceptions: `Result` or a
265 // similar type should be used to propagate errors.
unhandled_exception()266 void unhandled_exception() { PW_ASSERT(false); }
267
268 // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
269 static Coro<T> get_return_object_on_allocation_failure();
270
271 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
272 // state.
273 template <typename... Args>
new(std::size_t n,CoroContext & coro_cx,const Args &...)274 static void* operator new(std::size_t n,
275 CoroContext& coro_cx,
276 const Args&...) noexcept {
277 return coro_cx.alloc().Allocate(pw::allocator::Layout(n));
278 }
279
280 // Deallocate the space for both this `CoroPromiseType<T>` and the
281 // coroutine state.
282 //
283 // In reality, we do nothing here!!!
284 //
285 // Coroutines do not support `destroying_delete`, so we can't access
286 // `dealloc_` here, and therefore have no way to deallocate.
287 // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
delete(void *)288 static void operator delete(void*) {}
289
290 // Handle a `co_await` call by accepting a type with a
291 // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
292 // yield a `U` once complete.
293 template <typename Pendable>
await_transform(Pendable && pendable)294 Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
295 return pendable;
296 }
297
298 // Returns a reference to the `Context` passed in.
cx()299 Context& cx() { return *in_out_->input_cx; }
300
301 pw::allocator::Deallocator& dealloc_;
302 AwaitableBase* current_awaitable_;
303 InOut<T>* in_out_;
304 };
305
306 // The object created by invoking `co_await` in a `Coro<T>` function.
307 //
308 // This wraps a `Pendable` type and implements the awaitable interface
309 // expected by the standard coroutine API.
310 template <typename Pendable, typename PromiseType>
311 class Awaitable final : AwaitableBase {
312 public:
313 // The `OutputType` in `Poll<OutputType> Pendable::Pend(Context&)`.
314 using OutputType =
315 std::remove_cvref_t<decltype(std::declval<Pendable>()
316 .Pend(std::declval<Context&>())
317 .value())>;
318
Awaitable(Pendable && pendable)319 Awaitable(Pendable&& pendable)
320 : pendable_(std::forward<Pendable>(pendable)) {}
321
322 // Confirms that `await_suspend` must be invoked.
await_ready()323 bool await_ready() { return false; }
324
325 // Returns whether or not the current coroutine should be suspended.
326 //
327 // This is invoked once as part of every `co_await` call after
328 // `await_ready` returns `false`.
329 //
330 // In the process, this method attempts to complete the inner `Pendable`
331 // before suspending this coroutine.
await_suspend(const std::coroutine_handle<PromiseType> & promise)332 bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
333 Context& cx = promise.promise().cx();
334 if (PendFillReturnValue(cx).IsPending()) {
335 /// The coroutine should suspend since the await-ed thing is pending.
336 promise.promise().current_awaitable_ = this;
337 return true;
338 }
339 return false;
340 }
341
342 // Returns `return_value`.
343 //
344 // This is automatically invoked by the language runtime when the promise's
345 // `resume()` method is called.
await_resume()346 OutputType&& await_resume() { return std::move(*return_value_); }
347
348 // Attempts to complete the `Pendable` value, storing its return value
349 // upon completion.
350 //
351 // This method must return `Ready()` before the coroutine can be safely
352 // resumed, as otherwise the return value will not be available when
353 // `await_resume` is called to produce the result of `co_await`.
PendFillReturnValue(Context & cx)354 Poll<> PendFillReturnValue(Context& cx) final {
355 Poll<OutputType> poll_res = pendable_.Pend(cx);
356 if (poll_res.IsPending()) {
357 return Pending();
358 }
359 return_value_ = std::move(*poll_res);
360 return Ready();
361 }
362
363 private:
364 Pendable pendable_;
365 OptionalOrDefault<OutputType> return_value_;
366 };
367
368 } // namespace internal
369
370 /// An asynchronous coroutine which implements the C++20 coroutine API.
371 ///
372 /// # Why coroutines?
373 /// Coroutines allow a series of asynchronous operations to be written as
374 /// straight line code. Rather than manually writing a state machine, users can
375 /// `co_await` any Pigweed asynchronous value (types with a
376 /// `Poll<T> Pend(Context&)` method).
377 ///
378 /// # Allocation
379 /// Pigweed's `Coro<T>` API supports checked, fallible, heap-free allocation.
380 /// The first argument to any coroutine function must be a
381 /// `CoroContext` (or a reference to one). This allows the
382 /// coroutine to allocate space for asynchronously-held stack variables using
383 /// the allocator member of the `CoroContext`.
384 ///
385 /// Failure to allocate coroutine "stack" space will result in the `Coro<T>`
386 /// returning `Status::Invalid()`.
387 ///
388 /// # Creating a coroutine function
389 /// To create a coroutine, a function must:
390 /// - Have an annotated return type of `Coro<T>` where `T` is some type
391 /// constructible from `pw::Status`, such as `pw::Status` or
392 /// `pw::Result<U>`.
393 /// - Use `co_return <value>` rather than `return <value>` for any
394 /// `return` statements. This also requires the use of `PW_CO_TRY` and
395 /// `PW_CO_TRY_ASSIGN` rather than `PW_TRY` and `PW_TRY_ASSIGN`.
396 /// - Accept a value convertible to `pw::allocator::Allocator&` as its first
397 /// argument. This allocator will be used to allocate storage for coroutine
398 /// stack variables held across a `co_await` point.
399 ///
400 /// # Using `co_await`
401 /// Inside a coroutine function, `co_await <expr>` can be used on any type
402 /// with a `Poll<T> Pend(Context&)` method. The result will be a value of
403 /// type `T`.
404 ///
405 /// # Example
406 /// @rst
407 /// .. literalinclude:: examples/coro.cc
408 /// :language: cpp
409 /// :linenos:
410 /// :start-after: [pw_async2-examples-coro-injection]
411 /// :end-before: [pw_async2-examples-coro-injection]
412 /// @endrst
413 template <std::constructible_from<pw::Status> T>
414 class Coro final {
415 public:
416 /// Whether or not this `Coro<T>` is a valid coroutine.
417 ///
418 /// This will return `false` if coroutine state allocation failed or if
419 /// this `Coro<T>::Pend` method previously returned a `Ready` value.
IsValid()420 [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
421
422 /// Attempt to complete this coroutine, returning the result if complete.
423 ///
424 /// Returns `Status::Internal()` if `!IsValid()`, which may occur if
425 /// coroutine state allocation fails.
Pend(Context & cx)426 Poll<T> Pend(Context& cx) {
427 if (!IsValid()) {
428 // This coroutine failed to allocate its internal state.
429 // (Or `Pend` is being erroniously invoked after previously completing.)
430 return Ready(Status::Internal());
431 }
432
433 // If an `Awaitable` value is currently being processed, it must be
434 // allowed to complete and store its return value before we can resume
435 // the coroutine.
436 if (promise_handle_.promise().current_awaitable_ != nullptr &&
437 promise_handle_.promise()
438 .current_awaitable_->PendFillReturnValue(cx)
439 .IsPending()) {
440 return Pending();
441 }
442 // Create the arguments (and output storage) for the coroutine.
443 internal::InOut<T> in_out;
444 internal::OptionalOrDefault<T> return_value;
445 in_out.input_cx = &cx;
446 in_out.output = &return_value;
447 promise_handle_.promise().in_out_ = &in_out;
448
449 // Resume the coroutine, triggering `Awaitable::await_resume()` and the
450 // returning of the resulting value from `co_await`.
451 promise_handle_.resume();
452 if (!promise_handle_.done()) {
453 return Pending();
454 }
455
456 // Destroy the coroutine state: it has completed, and further calls to
457 // `resume` would result in undefined behavior.
458 promise_handle_.Release();
459
460 // When the coroutine completed in `resume()` above, it stored its
461 // `co_return` value into `return_value`. This retrieves that value.
462 return std::move(*return_value);
463 }
464
465 /// Used by the compiler in order to create a `Coro<T>` from a coroutine
466 /// function.
467 using promise_type = ::pw::async2::internal::CoroPromiseType<T>;
468
469 private:
470 // Allow `CoroPromiseType<T>::get_return_object()` and
471 // `CoroPromiseType<T>::get_retunr_object_on_allocation_failure()` to
472 // use the private constructor below.
473 friend promise_type;
474
475 /// Create a new `Coro<T>` using a (possibly null) handle.
Coro(internal::OwningCoroutineHandle<promise_type> && promise_handle)476 Coro(internal::OwningCoroutineHandle<promise_type>&& promise_handle)
477 : promise_handle_(std::move(promise_handle)) {}
478
479 internal::OwningCoroutineHandle<promise_type> promise_handle_;
480 };
481
482 // Implement the remaining internal pieces that require a definition of
483 // `Coro<T>`.
484 namespace internal {
485
486 template <typename T>
get_return_object()487 Coro<T> CoroPromiseType<T>::get_return_object() {
488 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
489 std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
490 }
491
492 template <typename T>
get_return_object_on_allocation_failure()493 Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
494 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
495 }
496
497 } // namespace internal
498 } // namespace pw::async2
499