1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/lib/promise/party.h"
16
17 #include <grpc/event_engine/event_engine.h>
18 #include <grpc/event_engine/memory_allocator.h>
19 #include <grpc/grpc.h>
20 #include <stdio.h>
21
22 #include <algorithm>
23 #include <atomic>
24 #include <memory>
25 #include <thread>
26 #include <vector>
27
28 #include "absl/base/thread_annotations.h"
29 #include "absl/log/log.h"
30 #include "gtest/gtest.h"
31 #include "src/core/lib/event_engine/default_event_engine.h"
32 #include "src/core/lib/event_engine/event_engine_context.h"
33 #include "src/core/lib/iomgr/exec_ctx.h"
34 #include "src/core/lib/promise/context.h"
35 #include "src/core/lib/promise/inter_activity_latch.h"
36 #include "src/core/lib/promise/poll.h"
37 #include "src/core/lib/promise/seq.h"
38 #include "src/core/lib/promise/sleep.h"
39 #include "src/core/lib/resource_quota/arena.h"
40 #include "src/core/lib/resource_quota/memory_quota.h"
41 #include "src/core/lib/resource_quota/resource_quota.h"
42 #include "src/core/util/notification.h"
43 #include "src/core/util/ref_counted_ptr.h"
44 #include "src/core/util/sync.h"
45 #include "src/core/util/time.h"
46
47 namespace grpc_core {
48
49 ///////////////////////////////////////////////////////////////////////////////
50 // PartyTest
51
52 class PartyTest : public ::testing::Test {
53 protected:
MakeParty()54 RefCountedPtr<Party> MakeParty() {
55 auto arena = SimpleArenaAllocator()->MakeArena();
56 arena->SetContext<grpc_event_engine::experimental::EventEngine>(
57 event_engine_.get());
58 return Party::Make(std::move(arena));
59 }
60
61 private:
62 std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
63 grpc_event_engine::experimental::GetDefaultEventEngine();
64 };
65
// Creates a party with no participants and lets it go out of scope.
TEST_F(PartyTest, Noop) {
  RefCountedPtr<Party> party = MakeParty();
}
67
// A single participant counts down from 10, requesting an immediate repoll on
// every step; when the count hits zero it resolves to 42, which is handed to
// the on-complete callback.
TEST_F(PartyTest, CanSpawnAndRun) {
  auto party = MakeParty();
  Notification finished;
  auto countdown = [remaining = 10]() mutable -> Poll<int> {
    EXPECT_GT(remaining, 0);
    // No waker is created, so the repoll request is what drives progress.
    GetContext<Activity>()->ForceImmediateRepoll();
    if (--remaining == 0) return 42;
    return Pending{};
  };
  party->Spawn("TestSpawn", std::move(countdown), [&finished](int result) {
    EXPECT_EQ(result, 42);
    finished.Notify();
  });
  finished.WaitForNotification();
}
86
// party1_main waits on a promise obtained from party2->SpawnWaitable(); that
// party2 promise waits on an InterActivityLatch which a second party1
// participant sets afterwards.
TEST_F(PartyTest, CanSpawnWaitableAndRun) {
  auto party1 = MakeParty();
  auto party2 = MakeParty();
  Notification party1_finished;
  InterActivityLatch<void> release;
  auto wait_on_party2 = [&party2, &release]() {
    return party2->SpawnWaitable("party2_main",
                                 [&release]() { return release.Wait(); });
  };
  party1->Spawn("party1_main", std::move(wait_on_party2),
                [&party1_finished](Empty) { party1_finished.Notify(); });
  // The latch has not been set yet, so party1_main cannot have completed.
  ASSERT_FALSE(party1_finished.HasBeenNotified());
  party1->Spawn(
      "party1_notify_latch", [&release]() { release.Set(); }, [](Empty) {});
  party1_finished.WaitForNotification();
}
105
// The outer participant spawns an inner participant on the same party from
// inside its own poll and resolves immediately to 1234; the inner one counts
// down via immediate repolls and resolves to 42. Both callbacks must fire.
TEST_F(PartyTest, CanSpawnFromSpawn) {
  auto party = MakeParty();
  Notification outer_done;
  Notification inner_done;
  auto spawn_inner = [party, &inner_done]() -> Poll<int> {
    auto countdown = [remaining = 10]() mutable -> Poll<int> {
      GetContext<Activity>()->ForceImmediateRepoll();
      if (--remaining == 0) return 42;
      return Pending{};
    };
    party->Spawn("TestSpawnInner", std::move(countdown),
                 [&inner_done](int result) {
                   EXPECT_EQ(result, 42);
                   inner_done.Notify();
                 });
    return 1234;
  };
  party->Spawn("TestSpawn", std::move(spawn_inner), [&outer_done](int result) {
    EXPECT_EQ(result, 1234);
    outer_done.Notify();
  });
  outer_done.WaitForNotification();
  inner_done.WaitForNotification();
}
134
// Each poll stores a fresh owning waker and signals n[poll_index]. Polls 1..9
// end Pending and must each wait for one Wakeup() from the test thread; poll
// 10 resolves to 42.
//
// Exactly nine wakeups are issued. A tenth would target a participant whose
// final poll has already returned. Before each wakeup the test also checks
// that the participant has not been repolled on its own. Both points match
// CanWakeupWithNonOwningWaker.
TEST_F(PartyTest, CanWakeupWithOwningWaker) {
  auto party = MakeParty();
  Notification n[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [i = 0, &waker, &n]() mutable -> Poll<int> {
        // Written before n[i] is notified, so the test thread may read it
        // once that notification is observed.
        waker = GetContext<Activity>()->MakeOwningWaker();
        n[i].Notify();
        ++i;
        if (i == 10) return 42;
        return Pending{};
      },
      [&complete](int x) {
        EXPECT_EQ(x, 42);
        complete.Notify();
      });
  for (int poll = 0; poll < 9; ++poll) {
    n[poll].WaitForNotification();
    // The participant must not be repolled until it is woken.
    EXPECT_FALSE(n[poll + 1].HasBeenNotified());
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
159
// Each poll stores a fresh non-owning waker and signals polled[k] for the
// k-th poll (0-based). The test thread wakes the participant nine times,
// checking before each wakeup that the next poll has not happened early; the
// tenth poll resolves to 42.
TEST_F(PartyTest, CanWakeupWithNonOwningWaker) {
  auto party = MakeParty();
  Notification polled[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [remaining = 10, &waker, &polled]() mutable -> Poll<int> {
        waker = GetContext<Activity>()->MakeNonOwningWaker();
        --remaining;
        polled[9 - remaining].Notify();
        if (remaining == 0) return 42;
        return Pending{};
      },
      [&complete](int result) {
        EXPECT_EQ(result, 42);
        complete.Notify();
      });
  for (int wake = 0; wake < 9; ++wake) {
    polled[wake].WaitForNotification();
    EXPECT_FALSE(polled[wake + 1].HasBeenNotified());
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
185
// A participant that never completes hands out a non-owning waker. After the
// test drops its party reference, the waker still reports wakeable; calling
// Wakeup() on it must be safe and leaves it unwakeable.
TEST_F(PartyTest, CanWakeupWithNonOwningWakerAfterOrphaning) {
  auto party = MakeParty();
  Notification waker_ready;
  Waker waker;
  auto capture_waker = [&waker, &waker_ready]() mutable -> Poll<int> {
    // Expected to be polled only once.
    EXPECT_FALSE(waker_ready.HasBeenNotified());
    waker = GetContext<Activity>()->MakeNonOwningWaker();
    waker_ready.Notify();
    return Pending{};
  };
  party->Spawn("TestSpawn", std::move(capture_waker),
               [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_FALSE(waker.is_unwakeable());
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
205
// A non-owning waker obtained from a never-completing participant is
// destroyed (without ever being woken) after the test has dropped its party
// reference; the destruction must be safe.
TEST_F(PartyTest, CanDropNonOwningWakeAfterOrphaning) {
  auto party = MakeParty();
  Notification waker_ready;
  std::unique_ptr<Waker> held_waker;
  auto capture_waker = [&held_waker, &waker_ready]() mutable -> Poll<int> {
    EXPECT_FALSE(waker_ready.HasBeenNotified());
    held_waker =
        std::make_unique<Waker>(GetContext<Activity>()->MakeNonOwningWaker());
    waker_ready.Notify();
    return Pending{};
  };
  party->Spawn("TestSpawn", std::move(capture_waker),
               [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_NE(held_waker, nullptr);
  // Destroy the waker explicitly, after the party reference is gone.
  held_waker.reset();
}
225
// Same setup as CanWakeupWithNonOwningWakerAfterOrphaning, but the
// wakeability check happens while the party reference is still held; after
// dropping the reference, Wakeup() must be harmless and leave the waker
// unwakeable.
TEST_F(PartyTest, CanWakeupNonOwningOrphanedWakerWithNoEffect) {
  auto party = MakeParty();
  Notification waker_ready;
  Waker waker;
  auto capture_waker = [&waker, &waker_ready]() mutable -> Poll<int> {
    EXPECT_FALSE(waker_ready.HasBeenNotified());
    waker = GetContext<Activity>()->MakeNonOwningWaker();
    waker_ready.Notify();
    return Pending{};
  };
  party->Spawn("TestSpawn", std::move(capture_waker),
               [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  EXPECT_FALSE(waker.is_unwakeable());
  party.reset();
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
245
// Participants spawned while a Party::WakeupHold is in scope must not run
// until the hold is released; once it is, both run to completion.
TEST_F(PartyTest, CanBulkSpawn) {
  auto party = MakeParty();
  Notification first_ran;
  Notification second_ran;
  {
    Party::WakeupHold hold(party.get());
    party->Spawn(
        "spawn1", []() {}, [&first_ran](Empty) { first_ran.Notify(); });
    party->Spawn(
        "spawn2", []() {}, [&second_ran](Empty) { second_ran.Notify(); });
    // Re-check repeatedly so that an early run has a chance to be observed.
    for (int spin = 0; spin < 5000; ++spin) {
      EXPECT_FALSE(first_ran.HasBeenNotified());
      EXPECT_FALSE(second_ran.HasBeenNotified());
    }
  }
  first_ran.WaitForNotification();
  second_ran.WaitForNotification();
}
262
// Like CanBulkSpawn, but with two nested WakeupHolds on the same party; the
// participants may run only after both holds have gone out of scope.
TEST_F(PartyTest, CanNestWakeupHold) {
  auto party = MakeParty();
  Notification first_ran;
  Notification second_ran;
  {
    Party::WakeupHold outer_hold(party.get());
    Party::WakeupHold inner_hold(party.get());
    party->Spawn(
        "spawn1", []() {}, [&first_ran](Empty) { first_ran.Notify(); });
    party->Spawn(
        "spawn2", []() {}, [&second_ran](Empty) { second_ran.Notify(); });
    for (int spin = 0; spin < 5000; ++spin) {
      EXPECT_FALSE(first_ran.HasBeenNotified());
      EXPECT_FALSE(second_ran.HasBeenNotified());
    }
  }
  first_ran.WaitForNotification();
  second_ran.WaitForNotification();
}
280
// Several threads concurrently spawn participants onto one shared party. Each
// participant sleeps 10ms and resolves to 42, and each thread waits for its
// participant to complete before spawning the next one.
TEST_F(PartyTest, ThreadStressTest) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 100;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        ExecCtx ctx;  // needed for Sleep
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
305
306 class PromiseNotification {
307 public:
PromiseNotification(bool owning_waker)308 explicit PromiseNotification(bool owning_waker)
309 : owning_waker_(owning_waker) {}
310
Wait()311 auto Wait() {
312 return [this]() -> Poll<int> {
313 MutexLock lock(&mu_);
314 if (done_) return 42;
315 if (!polled_) {
316 if (owning_waker_) {
317 waker_ = GetContext<Activity>()->MakeOwningWaker();
318 } else {
319 waker_ = GetContext<Activity>()->MakeNonOwningWaker();
320 }
321 polled_ = true;
322 }
323 return Pending{};
324 };
325 }
326
Notify()327 void Notify() {
328 Waker waker;
329 {
330 MutexLock lock(&mu_);
331 done_ = true;
332 waker = std::move(waker_);
333 }
334 waker.Wakeup();
335 }
336
NotifyUnderLock()337 void NotifyUnderLock() {
338 MutexLock lock(&mu_);
339 done_ = true;
340 waker_.WakeupAsync();
341 }
342
343 private:
344 Mutex mu_;
345 const bool owning_waker_;
346 bool done_ ABSL_GUARDED_BY(mu_) = false;
347 bool polled_ ABSL_GUARDED_BY(mu_) = false;
348 Waker waker_ ABSL_GUARDED_BY(mu_);
349 };
350
// Like ThreadStressTest, but each participant first waits on a
// PromiseNotification using an owning waker. The spawning thread releases it
// with Notify() right after Spawn().
TEST_F(PartyTest, ThreadStressTestWithOwningWaker) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 100;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        ExecCtx ctx;  // needed for Sleep
        PromiseNotification promise_start(/*owning_waker=*/true);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
378
// Like ThreadStressTestWithOwningWaker, but the participant is released with
// NotifyUnderLock(), which requests the wakeup via WakeupAsync() while the
// PromiseNotification's mutex is held.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerHoldingLock) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 100;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        ExecCtx ctx;  // needed for Sleep
        PromiseNotification promise_start(/*owning_waker=*/true);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.NotifyUnderLock();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
406
// Like ThreadStressTestWithOwningWaker, but the PromiseNotification registers
// a non-owning waker.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWaker) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 100;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        ExecCtx ctx;  // needed for Sleep
        PromiseNotification promise_start(/*owning_waker=*/false);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
434
// Owning-waker stress test without the Sleep step (so no ExecCtx is needed),
// run for many more iterations per thread.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerNoSleep) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 10000;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        PromiseNotification promise_start(/*owning_waker=*/true);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
460
// Non-owning-waker stress test without the Sleep step (so no ExecCtx is
// needed), run for many more iterations per thread.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWakerNoSleep) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 10000;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        PromiseNotification promise_start(/*owning_waker=*/false);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
486
// Each outer participant spawns an inner participant on the same party, then
// sleeps, releases the inner one, and waits for it to finish before resolving
// to 42. This runs concurrently from several threads sharing one party.
TEST_F(PartyTest, ThreadStressTestWithInnerSpawn) {
  constexpr int kNumThreads = 8;
  auto party = MakeParty();
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    threads.emplace_back([party]() {
      constexpr int kIterations = 100;
      for (int iteration = 0; iteration < kIterations; ++iteration) {
        ExecCtx ctx;  // needed for Sleep
        PromiseNotification inner_start(/*owning_waker=*/true);
        PromiseNotification inner_complete(/*owning_waker=*/false);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(
                // Step 1: spawn the inner participant, which waits on
                // inner_start and then signals inner_complete.
                [party, &inner_start, &inner_complete]() -> Poll<int> {
                  party->Spawn("TestSpawnInner",
                               Seq(inner_start.Wait(), []() { return 0; }),
                               [&inner_complete](int inner_result) {
                                 EXPECT_EQ(inner_result, 0);
                                 inner_complete.Notify();
                               });
                  return 0;
                },
                // Step 2: sleep for 10ms before releasing the inner one.
                Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                // Step 3: release the inner participant.
                [&inner_start]() {
                  inner_start.Notify();
                  return 0;
                },
                // Step 4: wait for the inner participant to complete.
                inner_complete.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_complete.WaitForNotification();
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
528
// p1 (on party1) spawns p2 on party2 and p3 on party3 from inside its own
// poll. p2 and p3 deliberately block their threads on Notifications, so they
// can only make progress if those nested spawns do not run inline inside p1.
// `whats_going_on` records the expected global ordering 0 -> 7. The
// Notifications between steps are what order accesses to this plain int
// across threads.
TEST_F(PartyTest, NestedWakeup) {
  auto party1 = MakeParty();
  auto party2 = MakeParty();
  auto party3 = MakeParty();
  int whats_going_on = 0;
  Notification done1;
  Notification started2;
  Notification done2;
  Notification started3;
  Notification notify_done;
  party1->Spawn(
      "p1",
      [&]() {
        EXPECT_EQ(whats_going_on, 0);
        whats_going_on = 1;
        party2->Spawn(
            "p2",
            [&]() {
              done1.WaitForNotification();
              started2.Notify();
              started3.WaitForNotification();
              EXPECT_EQ(whats_going_on, 3);
              whats_going_on = 4;
            },
            [&](Empty) {
              EXPECT_EQ(whats_going_on, 4);
              whats_going_on = 5;
              done2.Notify();
            });
        party3->Spawn(
            "p3",
            [&]() {
              started2.WaitForNotification();
              started3.Notify();
              done2.WaitForNotification();
              EXPECT_EQ(whats_going_on, 5);
              whats_going_on = 6;
            },
            [&](Empty) {
              EXPECT_EQ(whats_going_on, 6);
              whats_going_on = 7;
              notify_done.Notify();
            });
        EXPECT_EQ(whats_going_on, 1);
        whats_going_on = 2;
      },
      [&](Empty) {
        EXPECT_EQ(whats_going_on, 2);
        whats_going_on = 3;
        done1.Notify();
      });
  notify_done.WaitForNotification();
  // notify_done is signalled only after the final transition, so the whole
  // sequence must have run; without this check, step 7 is never verified.
  EXPECT_EQ(whats_going_on, 7);
}
582
583 } // namespace grpc_core
584
main(int argc,char ** argv)585 int main(int argc, char** argv) {
586 ::testing::InitGoogleTest(&argc, argv);
587 grpc_init();
588 int r = RUN_ALL_TESTS();
589 grpc_shutdown();
590 return r;
591 }
592