• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
16 #define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
17 
18 #include <grpc/event_engine/event_engine.h>
19 #include <grpc/support/port_platform.h>
20 #include <stddef.h>
21 #include <stdint.h>
22 
23 #include <atomic>
24 #include <string>
25 #include <utility>
26 
27 #include "absl/base/attributes.h"
28 #include "absl/log/check.h"
29 #include "absl/strings/string_view.h"
30 #include "src/core/lib/debug/trace.h"
31 #include "src/core/lib/event_engine/event_engine_context.h"
32 #include "src/core/lib/promise/activity.h"
33 #include "src/core/lib/promise/context.h"
34 #include "src/core/lib/promise/detail/promise_factory.h"
35 #include "src/core/lib/promise/poll.h"
36 #include "src/core/lib/resource_quota/arena.h"
37 #include "src/core/util/construct_destruct.h"
38 #include "src/core/util/crash.h"
39 #include "src/core/util/ref_counted.h"
40 #include "src/core/util/ref_counted_ptr.h"
41 
42 namespace grpc_core {
43 
namespace party_detail {

// Upper bound on the number of participants a single party can hold: every
// participant needs its own wakeup bit, so the number of bits reserved for
// wakeups caps the participant count.
static constexpr size_t kMaxParticipants{16};

}  // namespace party_detail
51 
52 // A Party is an Activity with multiple participant promises.
53 class Party : public Activity, private Wakeable {
54  private:
55   // Non-owning wakeup handle.
56   class Handle;
57 
58   // One participant in the party.
59   class Participant {
60    public:
61     // Poll the participant. Return true if complete.
62     // Participant should take care of its own deallocation in this case.
63     virtual bool PollParticipantPromise() = 0;
64 
65     // Destroy the participant before finishing.
66     virtual void Destroy() = 0;
67 
68     // Return a Handle instance for this participant.
69     Wakeable* MakeNonOwningWakeable(Party* party);
70 
71    protected:
72     ~Participant();
73 
74    private:
75     Handle* handle_ = nullptr;
76   };
77 
78  public:
79   Party(const Party&) = delete;
80   Party& operator=(const Party&) = delete;
81 
82   // When calling into a Party from outside the promises system we often would
83   // like to perform more than one action.
84   // This class tries to acquire the party lock just once - if it succeeds then
85   // it runs the party in its destructor, effectively holding all wakeups of the
86   // party until it goes out of scope.
87   // If it fails, presumably some other thread holds the lock - and in this case
88   // we don't attempt to do any buffering.
89   class WakeupHold {
90    public:
91     WakeupHold() = default;
WakeupHold(Party * party)92     explicit WakeupHold(Party* party)
93         : prev_state_(party->state_.load(std::memory_order_relaxed)) {
94       // Try to lock
95       if ((prev_state_ & kLocked) == 0 &&
96           party->state_.compare_exchange_weak(prev_state_,
97                                               (prev_state_ | kLocked) + kOneRef,
98                                               std::memory_order_relaxed)) {
99         DCHECK_EQ(prev_state_ & ~(kRefMask | kAllocatedMask), 0u)
100             << "Party should have contained no wakeups on lock";
101         // If we win, record that fact for the destructor
102         party->LogStateChange("WakeupHold", prev_state_,
103                               (prev_state_ | kLocked) + kOneRef);
104         party_ = party;
105       }
106     }
107     WakeupHold(const WakeupHold&) = delete;
108     WakeupHold& operator=(const WakeupHold&) = delete;
WakeupHold(WakeupHold && other)109     WakeupHold(WakeupHold&& other) noexcept
110         : party_(std::exchange(other.party_, nullptr)),
111           prev_state_(other.prev_state_) {}
112     WakeupHold& operator=(WakeupHold&& other) noexcept {
113       std::swap(party_, other.party_);
114       std::swap(prev_state_, other.prev_state_);
115       return *this;
116     }
117 
~WakeupHold()118     ~WakeupHold() {
119       if (party_ == nullptr) return;
120       party_->RunLockedAndUnref(party_, prev_state_);
121     }
122 
123    private:
124     Party* party_ = nullptr;
125     uint64_t prev_state_;
126   };
127 
Make(RefCountedPtr<Arena> arena)128   static RefCountedPtr<Party> Make(RefCountedPtr<Arena> arena) {
129     auto* arena_ptr = arena.get();
130     return RefCountedPtr<Party>(arena_ptr->New<Party>(std::move(arena)));
131   }
132 
133   // Spawn one promise into the party.
134   // The promise will be polled until it is resolved, or until the party is shut
135   // down.
136   // The on_complete callback will be called with the result of the promise if
137   // it completes.
138   // A maximum of sixteen promises can be spawned onto a party.
139   // promise_factory called to create the promise with the party lock taken;
140   // after the promise is created the factory is destroyed.
141   // This means that pointers or references to factory members will be
142   // invalidated after the promise is created - so the promise should not retain
143   // any of these.
144   template <typename Factory, typename OnComplete>
145   void Spawn(absl::string_view name, Factory promise_factory,
146              OnComplete on_complete);
147 
148   template <typename Factory>
149   auto SpawnWaitable(absl::string_view name, Factory factory);
150 
Orphan()151   void Orphan() final { Crash("unused"); }
152 
153   // Activity implementation: not allowed to be overridden by derived types.
154   void ForceImmediateRepoll(WakeupMask mask) final;
CurrentParticipant()155   WakeupMask CurrentParticipant() const final {
156     DCHECK(currently_polling_ != kNotPolling);
157     return 1u << currently_polling_;
158   }
159   Waker MakeOwningWaker() final;
160   Waker MakeNonOwningWaker() final;
161   std::string ActivityDebugTag(WakeupMask wakeup_mask) const final;
162 
IncrementRefCount()163   GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void IncrementRefCount() {
164     const uint64_t prev_state =
165         state_.fetch_add(kOneRef, std::memory_order_relaxed);
166     LogStateChange("IncrementRefCount", prev_state, prev_state + kOneRef);
167   }
Unref()168   GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void Unref() {
169     uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel);
170     LogStateChange("Unref", prev_state, prev_state - kOneRef);
171     if ((prev_state & kRefMask) == kOneRef) PartyIsOver();
172   }
173 
Ref()174   RefCountedPtr<Party> Ref() {
175     IncrementRefCount();
176     return RefCountedPtr<Party>(this);
177   }
178 
179   template <typename T>
RefAsSubclass()180   RefCountedPtr<T> RefAsSubclass() {
181     IncrementRefCount();
182     return RefCountedPtr<T>(DownCast<T*>(this));
183   }
184 
arena()185   Arena* arena() { return arena_.get(); }
186 
187  protected:
188   friend class Arena;
189 
190   // Derived types should be constructed upon `arena`.
Party(RefCountedPtr<Arena> arena)191   explicit Party(RefCountedPtr<Arena> arena) : arena_(std::move(arena)) {
192     CHECK(arena_->GetContext<grpc_event_engine::experimental::EventEngine>() !=
193           nullptr);
194   }
195   ~Party() override;
196 
197   // Main run loop. Must be locked.
198   // Polls participants and drains the add queue until there is no work left to
199   // be done.
200   void RunPartyAndUnref(uint64_t prev_state);
201 
202   bool RefIfNonZero();
203 
204  private:
205   // Concrete implementation of a participant for some promise & oncomplete
206   // type.
207   template <typename SuppliedFactory, typename OnComplete>
208   class ParticipantImpl final : public Participant {
209     using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
210     using Promise = typename Factory::Promise;
211 
212    public:
ParticipantImpl(absl::string_view,SuppliedFactory promise_factory,OnComplete on_complete)213     ParticipantImpl(absl::string_view, SuppliedFactory promise_factory,
214                     OnComplete on_complete)
215         : on_complete_(std::move(on_complete)) {
216       Construct(&factory_, std::move(promise_factory));
217     }
~ParticipantImpl()218     ~ParticipantImpl() {
219       if (!started_) {
220         Destruct(&factory_);
221       } else {
222         Destruct(&promise_);
223       }
224     }
225 
PollParticipantPromise()226     bool PollParticipantPromise() override {
227       if (!started_) {
228         auto p = factory_.Make();
229         Destruct(&factory_);
230         Construct(&promise_, std::move(p));
231         started_ = true;
232       }
233       auto p = promise_();
234       if (auto* r = p.value_if_ready()) {
235         on_complete_(std::move(*r));
236         delete this;
237         return true;
238       }
239       return false;
240     }
241 
Destroy()242     void Destroy() override { delete this; }
243 
244    private:
245     union {
246       GPR_NO_UNIQUE_ADDRESS Factory factory_;
247       GPR_NO_UNIQUE_ADDRESS Promise promise_;
248     };
249     GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_;
250     bool started_ = false;
251   };
252 
253   template <typename SuppliedFactory>
254   class PromiseParticipantImpl final
255       : public RefCounted<PromiseParticipantImpl<SuppliedFactory>,
256                           NonPolymorphicRefCount>,
257         public Participant {
258     using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
259     using Promise = typename Factory::Promise;
260     using Result = typename Promise::Result;
261 
262    public:
PromiseParticipantImpl(absl::string_view,SuppliedFactory promise_factory)263     PromiseParticipantImpl(absl::string_view, SuppliedFactory promise_factory) {
264       Construct(&factory_, std::move(promise_factory));
265     }
266 
~PromiseParticipantImpl()267     ~PromiseParticipantImpl() {
268       switch (state_.load(std::memory_order_acquire)) {
269         case State::kFactory:
270           Destruct(&factory_);
271           break;
272         case State::kPromise:
273           Destruct(&promise_);
274           break;
275         case State::kResult:
276           Destruct(&result_);
277           break;
278       }
279     }
280 
281     // Inside party poll: drive from factory -> promise -> result
PollParticipantPromise()282     bool PollParticipantPromise() override {
283       switch (state_.load(std::memory_order_relaxed)) {
284         case State::kFactory: {
285           auto p = factory_.Make();
286           Destruct(&factory_);
287           Construct(&promise_, std::move(p));
288           state_.store(State::kPromise, std::memory_order_relaxed);
289         }
290           ABSL_FALLTHROUGH_INTENDED;
291         case State::kPromise: {
292           auto p = promise_();
293           if (auto* r = p.value_if_ready()) {
294             Destruct(&promise_);
295             Construct(&result_, std::move(*r));
296             state_.store(State::kResult, std::memory_order_release);
297             waiter_.Wakeup();
298             this->Unref();
299             return true;
300           }
301           return false;
302         }
303         case State::kResult:
304           Crash(
305               "unreachable: promises should not be repolled after completion");
306       }
307     }
308 
309     // Outside party poll: check whether the spawning party has completed this
310     // promise.
PollCompletion()311     Poll<Result> PollCompletion() {
312       switch (state_.load(std::memory_order_acquire)) {
313         case State::kFactory:
314         case State::kPromise:
315           return Pending{};
316         case State::kResult:
317           return std::move(result_);
318       }
319     }
320 
Destroy()321     void Destroy() override { this->Unref(); }
322 
323    private:
324     enum class State : uint8_t { kFactory, kPromise, kResult };
325     union {
326       GPR_NO_UNIQUE_ADDRESS Factory factory_;
327       GPR_NO_UNIQUE_ADDRESS Promise promise_;
328       GPR_NO_UNIQUE_ADDRESS Result result_;
329     };
330     Waker waiter_{GetContext<Activity>()->MakeOwningWaker()};
331     std::atomic<State> state_{State::kFactory};
332   };
333 
334   // State bits:
335   // The atomic state_ field is composed of the following:
336   //   - 24 bits for ref counts
337   //     1 is owned by the party prior to Orphan()
338   //     All others are owned by owning wakers
339   //   - 1 bit to indicate whether the party is locked
340   //     The first thread to set this owns the party until it is unlocked
341   //     That thread will run the main loop until no further work needs to
342   //     be done.
343   //   - 1 bit to indicate whether there are participants waiting to be
344   //   added
345   //   - 16 bits, one per participant, indicating which participants have
346   //   been
347   //     woken up and should be polled next time the main loop runs.
348 
349   // clang-format off
350   // Bits used to store 16 bits of wakeups
351   static constexpr uint64_t kWakeupMask    = 0x0000'0000'0000'ffff;
352   // Bits used to store 16 bits of allocated participant slots.
353   static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000;
354   // Bit indicating locked or not
355   static constexpr uint64_t kLocked        = 0x0000'0008'0000'0000;
356   // Bits used to store 24 bits of ref counts
357   static constexpr uint64_t kRefMask       = 0xffff'ff00'0000'0000;
358   // clang-format on
359 
360   // Shift to get from a participant mask to an allocated mask.
361   static constexpr size_t kAllocatedShift = 16;
362   // How far to shift to get the refcount
363   static constexpr size_t kRefShift = 40;
364   // One ref count
365   static constexpr uint64_t kOneRef = 1ull << kRefShift;
366 
367   // Destroy any remaining participants.
368   // Needs to have normal context setup before calling.
369   void CancelRemainingParticipants();
370 
371   // Run the locked part of the party until it is unlocked.
372   static void RunLockedAndUnref(Party* party, uint64_t prev_state);
373   // Called in response to Unref() hitting zero - ultimately calls PartyOver,
374   // but needs to set some stuff up.
375   // Here so it gets compiled out of line.
376   void PartyIsOver();
377 
378   // Wakeable implementation
Wakeup(WakeupMask wakeup_mask)379   void Wakeup(WakeupMask wakeup_mask) final {
380     GRPC_LATENT_SEE_INNER_SCOPE("Party::Wakeup");
381     if (Activity::current() == this) {
382       wakeup_mask_ |= wakeup_mask;
383       Unref();
384       return;
385     }
386     WakeupFromState(state_.load(std::memory_order_relaxed), wakeup_mask);
387   }
388 
WakeupFromState(uint64_t cur_state,WakeupMask wakeup_mask)389   GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void WakeupFromState(
390       uint64_t cur_state, WakeupMask wakeup_mask) {
391     GRPC_LATENT_SEE_INNER_SCOPE("Party::WakeupFromState");
392     DCHECK_NE(wakeup_mask & kWakeupMask, 0u)
393         << "Wakeup mask must be non-zero: " << wakeup_mask;
394     while (true) {
395       if (cur_state & kLocked) {
396         // If the party is locked, we need to set the wakeup bits, and then
397         // we'll immediately unref. Since something is running this should never
398         // bring the refcount to zero.
399         DCHECK_GT(cur_state & kRefMask, kOneRef);
400         auto new_state = (cur_state | wakeup_mask) - kOneRef;
401         if (state_.compare_exchange_weak(cur_state, new_state,
402                                          std::memory_order_release)) {
403           LogStateChange("Wakeup", cur_state, cur_state | wakeup_mask);
404           return;
405         }
406       } else {
407         // If the party is not locked, we need to lock it and run.
408         DCHECK_EQ(cur_state & kWakeupMask, 0u);
409         if (state_.compare_exchange_weak(cur_state, cur_state | kLocked,
410                                          std::memory_order_acq_rel)) {
411           LogStateChange("WakeupAndRun", cur_state, cur_state | kLocked);
412           wakeup_mask_ |= wakeup_mask;
413           RunLockedAndUnref(this, cur_state);
414           return;
415         }
416       }
417     }
418   }
419 
420   void WakeupAsync(WakeupMask wakeup_mask) final;
421   void Drop(WakeupMask wakeup_mask) final;
422 
423   // Add a participant (backs Spawn, after type erasure to ParticipantFactory).
424   void AddParticipant(Participant* participant);
425   void DelayAddParticipant(Participant* participant);
426 
427   static uint64_t NextAllocationMask(uint64_t current_allocation_mask);
428 
429   GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void LogStateChange(
430       const char* op, uint64_t prev_state, uint64_t new_state,
431       DebugLocation loc = {}) {
432     GRPC_TRACE_LOG(party_state, INFO).AtLocation(loc.file(), loc.line())
433         << this << " " << op << " "
434         << absl::StrFormat("%016" PRIx64 " -> %016" PRIx64, prev_state,
435                            new_state);
436   }
437 
438   // Sentinel value for currently_polling_ when no participant is being polled.
439   static constexpr uint8_t kNotPolling = 255;
440 
441   std::atomic<uint64_t> state_{kOneRef};
442   uint8_t currently_polling_ = kNotPolling;
443   WakeupMask wakeup_mask_ = 0;
444   // All current participants, using a tagged format.
445   // If the lower bit is unset, then this is a Participant*.
446   // If the lower bit is set, then this is a ParticipantFactory*.
447   std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {};
448   RefCountedPtr<Arena> arena_;
449 };
450 
// Context trait specialization for Party: names Activity as its base context
// type.
// NOTE(review): presumably this makes context lookups for Party resolve
// through the Activity context - confirm against ContextSubclass in
// context.h.
template <>
struct ContextSubclass<Party> {
  using Base = Activity;
};
455 
456 template <typename Factory, typename OnComplete>
457 void Party::Spawn(absl::string_view name, Factory promise_factory,
458                   OnComplete on_complete) {
459   GRPC_TRACE_LOG(party_state, INFO) << "PARTY[" << this << "]: spawn " << name;
460   AddParticipant(new ParticipantImpl<Factory, OnComplete>(
461       name, std::move(promise_factory), std::move(on_complete)));
462 }
463 
464 template <typename Factory>
465 auto Party::SpawnWaitable(absl::string_view name, Factory promise_factory) {
466   GRPC_TRACE_LOG(party_state, INFO) << "PARTY[" << this << "]: spawn " << name;
467   auto participant = MakeRefCounted<PromiseParticipantImpl<Factory>>(
468       name, std::move(promise_factory));
469   Participant* p = participant->Ref().release();
470   AddParticipant(p);
471   return [participant = std::move(participant)]() mutable {
472     return participant->PollCompletion();
473   };
474 }
475 
476 }  // namespace grpc_core
477 
478 #endif  // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
479