1 // Copyright 2023 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H 16 #define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H 17 18 #include <grpc/event_engine/event_engine.h> 19 #include <grpc/support/port_platform.h> 20 #include <stddef.h> 21 #include <stdint.h> 22 23 #include <atomic> 24 #include <string> 25 #include <utility> 26 27 #include "absl/base/attributes.h" 28 #include "absl/log/check.h" 29 #include "absl/strings/string_view.h" 30 #include "src/core/lib/debug/trace.h" 31 #include "src/core/lib/event_engine/event_engine_context.h" 32 #include "src/core/lib/promise/activity.h" 33 #include "src/core/lib/promise/context.h" 34 #include "src/core/lib/promise/detail/promise_factory.h" 35 #include "src/core/lib/promise/poll.h" 36 #include "src/core/lib/resource_quota/arena.h" 37 #include "src/core/util/construct_destruct.h" 38 #include "src/core/util/crash.h" 39 #include "src/core/util/ref_counted.h" 40 #include "src/core/util/ref_counted_ptr.h" 41 42 namespace grpc_core { 43 44 namespace party_detail { 45 46 // Number of bits reserved for wakeups gives us the maximum number of 47 // participants. 48 static constexpr size_t kMaxParticipants = 16; 49 50 } // namespace party_detail 51 52 // A Party is an Activity with multiple participant promises. 53 class Party : public Activity, private Wakeable { 54 private: 55 // Non-owning wakeup handle. 56 class Handle; 57 58 // One participant in the party. 59 class Participant { 60 public: 61 // Poll the participant. Return true if complete. 62 // Participant should take care of its own deallocation in this case. 63 virtual bool PollParticipantPromise() = 0; 64 65 // Destroy the participant before finishing. 66 virtual void Destroy() = 0; 67 68 // Return a Handle instance for this participant. 69 Wakeable* MakeNonOwningWakeable(Party* party); 70 71 protected: 72 ~Participant(); 73 74 private: 75 Handle* handle_ = nullptr; 76 }; 77 78 public: 79 Party(const Party&) = delete; 80 Party& operator=(const Party&) = delete; 81 82 // When calling into a Party from outside the promises system we often would 83 // like to perform more than one action. 84 // This class tries to acquire the party lock just once - if it succeeds then 85 // it runs the party in its destructor, effectively holding all wakeups of the 86 // party until it goes out of scope. 87 // If it fails, presumably some other thread holds the lock - and in this case 88 // we don't attempt to do any buffering. 89 class WakeupHold { 90 public: 91 WakeupHold() = default; WakeupHold(Party * party)92 explicit WakeupHold(Party* party) 93 : prev_state_(party->state_.load(std::memory_order_relaxed)) { 94 // Try to lock 95 if ((prev_state_ & kLocked) == 0 && 96 party->state_.compare_exchange_weak(prev_state_, 97 (prev_state_ | kLocked) + kOneRef, 98 std::memory_order_relaxed)) { 99 DCHECK_EQ(prev_state_ & ~(kRefMask | kAllocatedMask), 0u) 100 << "Party should have contained no wakeups on lock"; 101 // If we win, record that fact for the destructor 102 party->LogStateChange("WakeupHold", prev_state_, 103 (prev_state_ | kLocked) + kOneRef); 104 party_ = party; 105 } 106 } 107 WakeupHold(const WakeupHold&) = delete; 108 WakeupHold& operator=(const WakeupHold&) = delete; WakeupHold(WakeupHold && other)109 WakeupHold(WakeupHold&& other) noexcept 110 : party_(std::exchange(other.party_, nullptr)), 111 prev_state_(other.prev_state_) {} 112 WakeupHold& operator=(WakeupHold&& other) noexcept { 113 std::swap(party_, other.party_); 114 std::swap(prev_state_, other.prev_state_); 115 return *this; 116 } 117 ~WakeupHold()118 ~WakeupHold() { 119 if (party_ == nullptr) return; 120 party_->RunLockedAndUnref(party_, prev_state_); 121 } 122 123 private: 124 Party* party_ = nullptr; 125 uint64_t prev_state_; 126 }; 127 Make(RefCountedPtr<Arena> arena)128 static RefCountedPtr<Party> Make(RefCountedPtr<Arena> arena) { 129 auto* arena_ptr = arena.get(); 130 return RefCountedPtr<Party>(arena_ptr->New<Party>(std::move(arena))); 131 } 132 133 // Spawn one promise into the party. 134 // The promise will be polled until it is resolved, or until the party is shut 135 // down. 136 // The on_complete callback will be called with the result of the promise if 137 // it completes. 138 // A maximum of sixteen promises can be spawned onto a party. 139 // promise_factory called to create the promise with the party lock taken; 140 // after the promise is created the factory is destroyed. 141 // This means that pointers or references to factory members will be 142 // invalidated after the promise is created - so the promise should not retain 143 // any of these. 144 template <typename Factory, typename OnComplete> 145 void Spawn(absl::string_view name, Factory promise_factory, 146 OnComplete on_complete); 147 148 template <typename Factory> 149 auto SpawnWaitable(absl::string_view name, Factory factory); 150 Orphan()151 void Orphan() final { Crash("unused"); } 152 153 // Activity implementation: not allowed to be overridden by derived types. 154 void ForceImmediateRepoll(WakeupMask mask) final; CurrentParticipant()155 WakeupMask CurrentParticipant() const final { 156 DCHECK(currently_polling_ != kNotPolling); 157 return 1u << currently_polling_; 158 } 159 Waker MakeOwningWaker() final; 160 Waker MakeNonOwningWaker() final; 161 std::string ActivityDebugTag(WakeupMask wakeup_mask) const final; 162 IncrementRefCount()163 GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void IncrementRefCount() { 164 const uint64_t prev_state = 165 state_.fetch_add(kOneRef, std::memory_order_relaxed); 166 LogStateChange("IncrementRefCount", prev_state, prev_state + kOneRef); 167 } Unref()168 GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void Unref() { 169 uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel); 170 LogStateChange("Unref", prev_state, prev_state - kOneRef); 171 if ((prev_state & kRefMask) == kOneRef) PartyIsOver(); 172 } 173 Ref()174 RefCountedPtr<Party> Ref() { 175 IncrementRefCount(); 176 return RefCountedPtr<Party>(this); 177 } 178 179 template <typename T> RefAsSubclass()180 RefCountedPtr<T> RefAsSubclass() { 181 IncrementRefCount(); 182 return RefCountedPtr<T>(DownCast<T*>(this)); 183 } 184 arena()185 Arena* arena() { return arena_.get(); } 186 187 protected: 188 friend class Arena; 189 190 // Derived types should be constructed upon `arena`. Party(RefCountedPtr<Arena> arena)191 explicit Party(RefCountedPtr<Arena> arena) : arena_(std::move(arena)) { 192 CHECK(arena_->GetContext<grpc_event_engine::experimental::EventEngine>() != 193 nullptr); 194 } 195 ~Party() override; 196 197 // Main run loop. Must be locked. 198 // Polls participants and drains the add queue until there is no work left to 199 // be done. 200 void RunPartyAndUnref(uint64_t prev_state); 201 202 bool RefIfNonZero(); 203 204 private: 205 // Concrete implementation of a participant for some promise & oncomplete 206 // type. 207 template <typename SuppliedFactory, typename OnComplete> 208 class ParticipantImpl final : public Participant { 209 using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>; 210 using Promise = typename Factory::Promise; 211 212 public: ParticipantImpl(absl::string_view,SuppliedFactory promise_factory,OnComplete on_complete)213 ParticipantImpl(absl::string_view, SuppliedFactory promise_factory, 214 OnComplete on_complete) 215 : on_complete_(std::move(on_complete)) { 216 Construct(&factory_, std::move(promise_factory)); 217 } ~ParticipantImpl()218 ~ParticipantImpl() { 219 if (!started_) { 220 Destruct(&factory_); 221 } else { 222 Destruct(&promise_); 223 } 224 } 225 PollParticipantPromise()226 bool PollParticipantPromise() override { 227 if (!started_) { 228 auto p = factory_.Make(); 229 Destruct(&factory_); 230 Construct(&promise_, std::move(p)); 231 started_ = true; 232 } 233 auto p = promise_(); 234 if (auto* r = p.value_if_ready()) { 235 on_complete_(std::move(*r)); 236 delete this; 237 return true; 238 } 239 return false; 240 } 241 Destroy()242 void Destroy() override { delete this; } 243 244 private: 245 union { 246 GPR_NO_UNIQUE_ADDRESS Factory factory_; 247 GPR_NO_UNIQUE_ADDRESS Promise promise_; 248 }; 249 GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_; 250 bool started_ = false; 251 }; 252 253 template <typename SuppliedFactory> 254 class PromiseParticipantImpl final 255 : public RefCounted<PromiseParticipantImpl<SuppliedFactory>, 256 NonPolymorphicRefCount>, 257 public Participant { 258 using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>; 259 using Promise = typename Factory::Promise; 260 using Result = typename Promise::Result; 261 262 public: PromiseParticipantImpl(absl::string_view,SuppliedFactory promise_factory)263 PromiseParticipantImpl(absl::string_view, SuppliedFactory promise_factory) { 264 Construct(&factory_, std::move(promise_factory)); 265 } 266 ~PromiseParticipantImpl()267 ~PromiseParticipantImpl() { 268 switch (state_.load(std::memory_order_acquire)) { 269 case State::kFactory: 270 Destruct(&factory_); 271 break; 272 case State::kPromise: 273 Destruct(&promise_); 274 break; 275 case State::kResult: 276 Destruct(&result_); 277 break; 278 } 279 } 280 281 // Inside party poll: drive from factory -> promise -> result PollParticipantPromise()282 bool PollParticipantPromise() override { 283 switch (state_.load(std::memory_order_relaxed)) { 284 case State::kFactory: { 285 auto p = factory_.Make(); 286 Destruct(&factory_); 287 Construct(&promise_, std::move(p)); 288 state_.store(State::kPromise, std::memory_order_relaxed); 289 } 290 ABSL_FALLTHROUGH_INTENDED; 291 case State::kPromise: { 292 auto p = promise_(); 293 if (auto* r = p.value_if_ready()) { 294 Destruct(&promise_); 295 Construct(&result_, std::move(*r)); 296 state_.store(State::kResult, std::memory_order_release); 297 waiter_.Wakeup(); 298 this->Unref(); 299 return true; 300 } 301 return false; 302 } 303 case State::kResult: 304 Crash( 305 "unreachable: promises should not be repolled after completion"); 306 } 307 } 308 309 // Outside party poll: check whether the spawning party has completed this 310 // promise. PollCompletion()311 Poll<Result> PollCompletion() { 312 switch (state_.load(std::memory_order_acquire)) { 313 case State::kFactory: 314 case State::kPromise: 315 return Pending{}; 316 case State::kResult: 317 return std::move(result_); 318 } 319 } 320 Destroy()321 void Destroy() override { this->Unref(); } 322 323 private: 324 enum class State : uint8_t { kFactory, kPromise, kResult }; 325 union { 326 GPR_NO_UNIQUE_ADDRESS Factory factory_; 327 GPR_NO_UNIQUE_ADDRESS Promise promise_; 328 GPR_NO_UNIQUE_ADDRESS Result result_; 329 }; 330 Waker waiter_{GetContext<Activity>()->MakeOwningWaker()}; 331 std::atomic<State> state_{State::kFactory}; 332 }; 333 334 // State bits: 335 // The atomic state_ field is composed of the following: 336 // - 24 bits for ref counts 337 // 1 is owned by the party prior to Orphan() 338 // All others are owned by owning wakers 339 // - 1 bit to indicate whether the party is locked 340 // The first thread to set this owns the party until it is unlocked 341 // That thread will run the main loop until no further work needs to 342 // be done. 343 // - 1 bit to indicate whether there are participants waiting to be 344 // added 345 // - 16 bits, one per participant, indicating which participants have 346 // been 347 // woken up and should be polled next time the main loop runs. 348 349 // clang-format off 350 // Bits used to store 16 bits of wakeups 351 static constexpr uint64_t kWakeupMask = 0x0000'0000'0000'ffff; 352 // Bits used to store 16 bits of allocated participant slots. 353 static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000; 354 // Bit indicating locked or not 355 static constexpr uint64_t kLocked = 0x0000'0008'0000'0000; 356 // Bits used to store 24 bits of ref counts 357 static constexpr uint64_t kRefMask = 0xffff'ff00'0000'0000; 358 // clang-format on 359 360 // Shift to get from a participant mask to an allocated mask. 361 static constexpr size_t kAllocatedShift = 16; 362 // How far to shift to get the refcount 363 static constexpr size_t kRefShift = 40; 364 // One ref count 365 static constexpr uint64_t kOneRef = 1ull << kRefShift; 366 367 // Destroy any remaining participants. 368 // Needs to have normal context setup before calling. 369 void CancelRemainingParticipants(); 370 371 // Run the locked part of the party until it is unlocked. 372 static void RunLockedAndUnref(Party* party, uint64_t prev_state); 373 // Called in response to Unref() hitting zero - ultimately calls PartyOver, 374 // but needs to set some stuff up. 375 // Here so it gets compiled out of line. 376 void PartyIsOver(); 377 378 // Wakeable implementation Wakeup(WakeupMask wakeup_mask)379 void Wakeup(WakeupMask wakeup_mask) final { 380 GRPC_LATENT_SEE_INNER_SCOPE("Party::Wakeup"); 381 if (Activity::current() == this) { 382 wakeup_mask_ |= wakeup_mask; 383 Unref(); 384 return; 385 } 386 WakeupFromState(state_.load(std::memory_order_relaxed), wakeup_mask); 387 } 388 WakeupFromState(uint64_t cur_state,WakeupMask wakeup_mask)389 GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void WakeupFromState( 390 uint64_t cur_state, WakeupMask wakeup_mask) { 391 GRPC_LATENT_SEE_INNER_SCOPE("Party::WakeupFromState"); 392 DCHECK_NE(wakeup_mask & kWakeupMask, 0u) 393 << "Wakeup mask must be non-zero: " << wakeup_mask; 394 while (true) { 395 if (cur_state & kLocked) { 396 // If the party is locked, we need to set the wakeup bits, and then 397 // we'll immediately unref. Since something is running this should never 398 // bring the refcount to zero. 399 DCHECK_GT(cur_state & kRefMask, kOneRef); 400 auto new_state = (cur_state | wakeup_mask) - kOneRef; 401 if (state_.compare_exchange_weak(cur_state, new_state, 402 std::memory_order_release)) { 403 LogStateChange("Wakeup", cur_state, cur_state | wakeup_mask); 404 return; 405 } 406 } else { 407 // If the party is not locked, we need to lock it and run. 408 DCHECK_EQ(cur_state & kWakeupMask, 0u); 409 if (state_.compare_exchange_weak(cur_state, cur_state | kLocked, 410 std::memory_order_acq_rel)) { 411 LogStateChange("WakeupAndRun", cur_state, cur_state | kLocked); 412 wakeup_mask_ |= wakeup_mask; 413 RunLockedAndUnref(this, cur_state); 414 return; 415 } 416 } 417 } 418 } 419 420 void WakeupAsync(WakeupMask wakeup_mask) final; 421 void Drop(WakeupMask wakeup_mask) final; 422 423 // Add a participant (backs Spawn, after type erasure to ParticipantFactory). 424 void AddParticipant(Participant* participant); 425 void DelayAddParticipant(Participant* participant); 426 427 static uint64_t NextAllocationMask(uint64_t current_allocation_mask); 428 429 GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION void LogStateChange( 430 const char* op, uint64_t prev_state, uint64_t new_state, 431 DebugLocation loc = {}) { 432 GRPC_TRACE_LOG(party_state, INFO).AtLocation(loc.file(), loc.line()) 433 << this << " " << op << " " 434 << absl::StrFormat("%016" PRIx64 " -> %016" PRIx64, prev_state, 435 new_state); 436 } 437 438 // Sentinel value for currently_polling_ when no participant is being polled. 439 static constexpr uint8_t kNotPolling = 255; 440 441 std::atomic<uint64_t> state_{kOneRef}; 442 uint8_t currently_polling_ = kNotPolling; 443 WakeupMask wakeup_mask_ = 0; 444 // All current participants, using a tagged format. 445 // If the lower bit is unset, then this is a Participant*. 446 // If the lower bit is set, then this is a ParticipantFactory*. 447 std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {}; 448 RefCountedPtr<Arena> arena_; 449 }; 450 451 template <> 452 struct ContextSubclass<Party> { 453 using Base = Activity; 454 }; 455 456 template <typename Factory, typename OnComplete> 457 void Party::Spawn(absl::string_view name, Factory promise_factory, 458 OnComplete on_complete) { 459 GRPC_TRACE_LOG(party_state, INFO) << "PARTY[" << this << "]: spawn " << name; 460 AddParticipant(new ParticipantImpl<Factory, OnComplete>( 461 name, std::move(promise_factory), std::move(on_complete))); 462 } 463 464 template <typename Factory> 465 auto Party::SpawnWaitable(absl::string_view name, Factory promise_factory) { 466 GRPC_TRACE_LOG(party_state, INFO) << "PARTY[" << this << "]: spawn " << name; 467 auto participant = MakeRefCounted<PromiseParticipantImpl<Factory>>( 468 name, std::move(promise_factory)); 469 Participant* p = participant->Ref().release(); 470 AddParticipant(p); 471 return [participant = std::move(participant)]() mutable { 472 return participant->PollCompletion(); 473 }; 474 } 475 476 } // namespace grpc_core 477 478 #endif // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H 479