• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>

#include "pw_allocator/allocator.h"
#include "pw_allocator/layout.h"
#include "pw_async2/dispatcher.h"
#include "pw_log/log.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
25 
26 namespace pw::async2 {
27 
28 // Forward-declare `Coro` so that it can be referenced by the promise type APIs.
29 template <std::constructible_from<pw::Status> T>
30 class Coro;
31 
32 /// Context required for creating and executing coroutines.
33 class CoroContext {
34  public:
35   /// Creates a `CoroContext` which will allocate coroutine state using
36   /// `alloc`.
CoroContext(pw::allocator::Allocator & alloc)37   explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
alloc()38   pw::allocator::Allocator& alloc() const { return alloc_; }
39 
40  private:
41   pw::allocator::Allocator& alloc_;
42 };
43 
44 // The internal coroutine API implementation details enabling `Coro<T>`.
45 //
46 // Users of `Coro<T>` need not concern themselves with these details, unless
47 // they think it sounds like fun ;)
48 namespace internal {
49 
// Storage for a value of type `T` that will only be produced later.
//
// The point of this type is to avoid `std::optional`'s bookkeeping whenever
// `T` can be default-initialized.
//
// A freshly constructed instance holds:
// - a value-initialized `T`, when `T` is default-initializable (handled by a
//   partial specialization for such types), or
// - `std::nullopt`, otherwise (this primary template).
template <typename T>
class OptionalOrDefault final {
 public:
  // Starts out disengaged; no `T` is constructed until one is assigned.
  OptionalOrDefault() : value_(std::nullopt) {}

  // Stores `new_value`, forwarding it into the underlying `std::optional<T>`.
  template <typename V>
  OptionalOrDefault& operator=(V&& new_value) {
    value_ = std::forward<V>(new_value);
    return *this;
  }

  // Returns the stored value.
  //
  // `PW_ASSERT`s that a value was previously assigned.
  T& operator*() {
    PW_ASSERT(value_.has_value());
    return *value_;
  }

 private:
  std::optional<T> value_;
};
82 
83 // A specialization of `OptionalOrDefault<T>` for `default_initializable`
84 // types.
85 template <std::default_initializable T>
86 class OptionalOrDefault<T> final {
87  public:
88   // Create a container for a to-be-provided value by default-initializing.
OptionalOrDefault()89   OptionalOrDefault() : value_() {}
90 
91   // Assign a value.
92   template <typename U>
93   OptionalOrDefault& operator=(U&& value) {
94     value_ = std::forward<U>(value);
95     return *this;
96   }
97 
98   // Retrieve the inner value.
99   //
100   // This operation will return a default-constructed `T` if no value was
101   // assigned. Typical users should not rely on this, and should instead
102   // only retrieve values assigned using `operator=`.
103   T& operator*() { return value_; }
104 
105  private:
106   T value_;
107 };
108 
109 // A wrapper for `std::coroutine_handle` that assumes unique ownership of the
110 // underlying `PromiseType`.
111 //
112 // This type will `destroy()` the underlying promise in its destructor, or
113 // when `Release()` is called.
114 template <typename PromiseType>
115 class OwningCoroutineHandle final {
116  public:
117   // Construct a null (`!IsValid()`) handle.
OwningCoroutineHandle(nullptr_t)118   OwningCoroutineHandle(nullptr_t) : promise_handle_(nullptr) {}
119 
120   /// Take ownership of `promise_handle`.
OwningCoroutineHandle(std::coroutine_handle<PromiseType> && promise_handle)121   OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
122       : promise_handle_(std::move(promise_handle)) {}
123 
124   // Empty out `other` and transfers ownership of its `promise_handle`
125   // to `this`.
OwningCoroutineHandle(OwningCoroutineHandle && other)126   OwningCoroutineHandle(OwningCoroutineHandle&& other)
127       : promise_handle_(std::move(other.promise_handle_)) {
128     other.promise_handle_ = nullptr;
129   }
130 
131   // Empty out `other` and transfers ownership of its `promise_handle`
132   // to `this`.
133   OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
134     Release();
135     promise_handle_ = std::move(other.promise_handle_);
136     other.promise_handle_ = nullptr;
137     return *this;
138   }
139 
140   // `destroy()`s the underlying `promise_handle` if valid.
~OwningCoroutineHandle()141   ~OwningCoroutineHandle() { Release(); }
142 
143   // Return whether or not this value contains a `promise_handle`.
144   //
145   // This will return `false` if this `OwningCoroutineHandle` was
146   // `nullptr`-initialized, moved from, or if `Release` was invoked.
IsValid()147   [[nodiscard]] bool IsValid() const {
148     return promise_handle_.address() != nullptr;
149   }
150 
151   // Return a reference to the underlying `PromiseType`.
152   //
153   // Precondition: `IsValid()` must be `true`.
promise()154   [[nodiscard]] PromiseType& promise() const {
155     return promise_handle_.promise();
156   }
157 
158   // Whether or not the underlying coroutine has completed.
159   //
160   // Precondition: `IsValid()` must be `true`.
done()161   [[nodiscard]] bool done() const { return promise_handle_.done(); }
162 
163   // Resume the underlying coroutine.
164   //
165   // Precondition: `IsValid()` must be `true`, and `done()` must be
166   // `false`.
resume()167   void resume() { promise_handle_.resume(); }
168 
169   // Invokes `destroy()` on the underlying promise and deallocates its
170   // associated storage.
Release()171   void Release() {
172     void* address = promise_handle_.address();
173     if (address != nullptr) {
174       pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
175       promise_handle_.destroy();
176       promise_handle_ = nullptr;
177       dealloc.Deallocate(address);
178     }
179   }
180 
181  private:
182   std::coroutine_handle<PromiseType> promise_handle_;
183 };
184 
185 // Forward-declare the wrapper type for values passed to `co_await`.
186 template <typename Pendable, typename PromiseType>
187 class Awaitable;
188 
189 // A container for values passed in and out of the promise.
190 //
191 // The C++20 coroutine `resume()` function cannot accept arguments no return
192 // values, so instead coroutine inputs and outputs are funneled through this
193 // type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
194 // so that the coroutine object can access it.
195 template <typename T>
196 struct InOut final {
197   // The `Context` passed into the coroutine via `Pend`.
198   Context* input_cx;
199 
200   // The output assigned to by the coroutine if the coroutine is `done()`.
201   OptionalOrDefault<T>* output;
202 };
203 
204 // A base class for `Awaitable` instances.
205 //
206 // Each `co_await` statement creates an `Awaitable` object whose `Pend`
207 // method must be completed before the coroutine's `resume()` function can
208 // be invoked.
209 class AwaitableBase {
210  public:
211   // Attempt to complete the current pendable value passed to `co_await`,
212   // storing its return value inside the `Awaitable` object so that it can
213   // be retrieved by the coroutine.
214   virtual Poll<> PendFillReturnValue(Context& cx) = 0;
215 
216  protected:
217   // A protected destructor ensures that child classes are never destroyed
218   // through a base pointer, so no virtual destructor is needed.
~AwaitableBase()219   ~AwaitableBase() {}
220 };
221 
222 // The `promise_type` of `Coro<T>`.
223 //
224 // To understand this type, it may be necessary to refer to the reference
225 // documentation for the C++20 coroutine API.
226 template <typename T>
227 class CoroPromiseType final {
228  public:
229   // Construct the `CoroPromiseType` using the arguments passed to a
230   // function returning `Coro<T>`.
231   //
232   // The first argument *must* be a `CoroContext`. The other
233   // arguments are unused, but must be accepted in order for this to compile.
234   template <typename... Args>
CoroPromiseType(CoroContext & cx,const Args &...)235   CoroPromiseType(CoroContext& cx, const Args&...)
236       : dealloc_(cx.alloc()), current_awaitable_(nullptr), in_out_(nullptr) {}
237 
238   // Get the `Coro<T>` after successfully allocating the coroutine space
239   // and constructing `this`.
240   Coro<T> get_return_object();
241 
242   // Do not begin executing the `Coro<T>` until `resume()` has been invoked
243   // for the first time.
initial_suspend()244   std::suspend_always initial_suspend() { return {}; }
245 
246   // Unconditionally suspend to prevent `destroy()` being invoked.
247   //
248   // The caller of `resume()` needs to first observe `done()` before the
249   // state can be destroyed.
250   //
251   // Setting this to suspend means that the caller is responsible for invoking
252   // `destroy()`.
final_suspend()253   std::suspend_always final_suspend() noexcept { return {}; }
254 
255   // Store the `co_return` argument in the `InOut<T>` object provided by
256   // the `Pend` wrapper.
257   template <std::convertible_to<T> From>
return_value(From && value)258   void return_value(From&& value) {
259     *in_out_->output = std::forward<From>(value);
260   }
261 
262   // Ignore exceptions in coroutines.
263   //
264   // Pigweed is not designed to be used with exceptions: `Result` or a
265   // similar type should be used to propagate errors.
unhandled_exception()266   void unhandled_exception() { PW_ASSERT(false); }
267 
268   // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
269   static Coro<T> get_return_object_on_allocation_failure();
270 
271   // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
272   // state.
273   template <typename... Args>
new(std::size_t n,CoroContext & coro_cx,const Args &...)274   static void* operator new(std::size_t n,
275                             CoroContext& coro_cx,
276                             const Args&...) noexcept {
277     return coro_cx.alloc().Allocate(pw::allocator::Layout(n));
278   }
279 
280   // Deallocate the space for both this `CoroPromiseType<T>` and the
281   // coroutine state.
282   //
283   // In reality, we do nothing here!!!
284   //
285   // Coroutines do not support `destroying_delete`, so we can't access
286   // `dealloc_` here, and therefore have no way to deallocate.
287   // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
delete(void *)288   static void operator delete(void*) {}
289 
290   // Handle a `co_await` call by accepting a type with a
291   // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
292   // yield a `U` once complete.
293   template <typename Pendable>
await_transform(Pendable && pendable)294   Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
295     return pendable;
296   }
297 
298   // Returns a reference to the `Context` passed in.
cx()299   Context& cx() { return *in_out_->input_cx; }
300 
301   pw::allocator::Deallocator& dealloc_;
302   AwaitableBase* current_awaitable_;
303   InOut<T>* in_out_;
304 };
305 
306 // The object created by invoking `co_await` in a `Coro<T>` function.
307 //
308 // This wraps a `Pendable` type and implements the awaitable interface
309 // expected by the standard coroutine API.
310 template <typename Pendable, typename PromiseType>
311 class Awaitable final : AwaitableBase {
312  public:
313   // The `OutputType` in `Poll<OutputType> Pendable::Pend(Context&)`.
314   using OutputType =
315       std::remove_cvref_t<decltype(std::declval<Pendable>()
316                                        .Pend(std::declval<Context&>())
317                                        .value())>;
318 
Awaitable(Pendable && pendable)319   Awaitable(Pendable&& pendable)
320       : pendable_(std::forward<Pendable>(pendable)) {}
321 
322   // Confirms that `await_suspend` must be invoked.
await_ready()323   bool await_ready() { return false; }
324 
325   // Returns whether or not the current coroutine should be suspended.
326   //
327   // This is invoked once as part of every `co_await` call after
328   // `await_ready` returns `false`.
329   //
330   // In the process, this method attempts to complete the inner `Pendable`
331   // before suspending this coroutine.
await_suspend(const std::coroutine_handle<PromiseType> & promise)332   bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
333     Context& cx = promise.promise().cx();
334     if (PendFillReturnValue(cx).IsPending()) {
335       /// The coroutine should suspend since the await-ed thing is pending.
336       promise.promise().current_awaitable_ = this;
337       return true;
338     }
339     return false;
340   }
341 
342   // Returns `return_value`.
343   //
344   // This is automatically invoked by the language runtime when the promise's
345   // `resume()` method is called.
await_resume()346   OutputType&& await_resume() { return std::move(*return_value_); }
347 
348   // Attempts to complete the `Pendable` value, storing its return value
349   // upon completion.
350   //
351   // This method must return `Ready()` before the coroutine can be safely
352   // resumed, as otherwise the return value will not be available when
353   // `await_resume` is called to produce the result of `co_await`.
PendFillReturnValue(Context & cx)354   Poll<> PendFillReturnValue(Context& cx) final {
355     Poll<OutputType> poll_res = pendable_.Pend(cx);
356     if (poll_res.IsPending()) {
357       return Pending();
358     }
359     return_value_ = std::move(*poll_res);
360     return Ready();
361   }
362 
363  private:
364   Pendable pendable_;
365   OptionalOrDefault<OutputType> return_value_;
366 };
367 
368 }  // namespace internal
369 
370 /// An asynchronous coroutine which implements the C++20 coroutine API.
371 ///
372 /// # Why coroutines?
373 /// Coroutines allow a series of asynchronous operations to be written as
374 /// straight line code. Rather than manually writing a state machine, users can
375 /// `co_await` any Pigweed asynchronous value (types with a
376 /// `Poll<T> Pend(Context&)` method).
377 ///
378 /// # Allocation
379 /// Pigweed's `Coro<T>` API supports checked, fallible, heap-free allocation.
380 /// The first argument to any coroutine function must be a
381 /// `CoroContext` (or a reference to one). This allows the
382 /// coroutine to allocate space for asynchronously-held stack variables using
383 /// the allocator member of the `CoroContext`.
384 ///
385 /// Failure to allocate coroutine "stack" space will result in the `Coro<T>`
386 /// returning `Status::Invalid()`.
387 ///
388 /// # Creating a coroutine function
389 /// To create a coroutine, a function must:
390 /// - Have an annotated return type of `Coro<T>` where `T` is some type
391 ///   constructible from `pw::Status`, such as `pw::Status` or
392 ///   `pw::Result<U>`.
393 /// - Use `co_return <value>` rather than `return <value>` for any
394 ///   `return` statements. This also requires the use of `PW_CO_TRY` and
395 ///   `PW_CO_TRY_ASSIGN` rather than `PW_TRY` and `PW_TRY_ASSIGN`.
396 /// - Accept a value convertible to `pw::allocator::Allocator&` as its first
397 ///   argument. This allocator will be used to allocate storage for coroutine
398 ///   stack variables held across a `co_await` point.
399 ///
400 /// # Using `co_await`
401 /// Inside a coroutine function, `co_await <expr>` can be used on any type
402 /// with a `Poll<T> Pend(Context&)` method. The result will be a value of
403 /// type `T`.
404 ///
405 /// # Example
406 /// @rst
407 /// .. literalinclude:: examples/coro.cc
408 ///    :language: cpp
409 ///    :linenos:
410 ///    :start-after: [pw_async2-examples-coro-injection]
411 ///    :end-before: [pw_async2-examples-coro-injection]
412 /// @endrst
413 template <std::constructible_from<pw::Status> T>
414 class Coro final {
415  public:
416   /// Whether or not this `Coro<T>` is a valid coroutine.
417   ///
418   /// This will return `false` if coroutine state allocation failed or if
419   /// this `Coro<T>::Pend` method previously returned a `Ready` value.
IsValid()420   [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
421 
422   /// Attempt to complete this coroutine, returning the result if complete.
423   ///
424   /// Returns `Status::Internal()` if `!IsValid()`, which may occur if
425   /// coroutine state allocation fails.
Pend(Context & cx)426   Poll<T> Pend(Context& cx) {
427     if (!IsValid()) {
428       // This coroutine failed to allocate its internal state.
429       // (Or `Pend` is being erroniously invoked after previously completing.)
430       return Ready(Status::Internal());
431     }
432 
433     // If an `Awaitable` value is currently being processed, it must be
434     // allowed to complete and store its return value before we can resume
435     // the coroutine.
436     if (promise_handle_.promise().current_awaitable_ != nullptr &&
437         promise_handle_.promise()
438             .current_awaitable_->PendFillReturnValue(cx)
439             .IsPending()) {
440       return Pending();
441     }
442     // Create the arguments (and output storage) for the coroutine.
443     internal::InOut<T> in_out;
444     internal::OptionalOrDefault<T> return_value;
445     in_out.input_cx = &cx;
446     in_out.output = &return_value;
447     promise_handle_.promise().in_out_ = &in_out;
448 
449     // Resume the coroutine, triggering `Awaitable::await_resume()` and the
450     // returning of the resulting value from `co_await`.
451     promise_handle_.resume();
452     if (!promise_handle_.done()) {
453       return Pending();
454     }
455 
456     // Destroy the coroutine state: it has completed, and further calls to
457     // `resume` would result in undefined behavior.
458     promise_handle_.Release();
459 
460     // When the coroutine completed in `resume()` above, it stored its
461     // `co_return` value into `return_value`. This retrieves that value.
462     return std::move(*return_value);
463   }
464 
465   /// Used by the compiler in order to create a `Coro<T>` from a coroutine
466   /// function.
467   using promise_type = ::pw::async2::internal::CoroPromiseType<T>;
468 
469  private:
470   // Allow `CoroPromiseType<T>::get_return_object()` and
471   // `CoroPromiseType<T>::get_retunr_object_on_allocation_failure()` to
472   // use the private constructor below.
473   friend promise_type;
474 
475   /// Create a new `Coro<T>` using a (possibly null) handle.
Coro(internal::OwningCoroutineHandle<promise_type> && promise_handle)476   Coro(internal::OwningCoroutineHandle<promise_type>&& promise_handle)
477       : promise_handle_(std::move(promise_handle)) {}
478 
479   internal::OwningCoroutineHandle<promise_type> promise_handle_;
480 };
481 
482 // Implement the remaining internal pieces that require a definition of
483 // `Coro<T>`.
484 namespace internal {
485 
486 template <typename T>
get_return_object()487 Coro<T> CoroPromiseType<T>::get_return_object() {
488   return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
489       std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
490 }
491 
492 template <typename T>
get_return_object_on_allocation_failure()493 Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
494   return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
495 }
496 
497 }  // namespace internal
498 }  // namespace pw::async2
499