• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/promise/party.h"
16 
17 #include <grpc/event_engine/event_engine.h>
18 #include <grpc/event_engine/memory_allocator.h>
19 #include <grpc/grpc.h>
20 #include <stdio.h>
21 
22 #include <algorithm>
23 #include <atomic>
24 #include <memory>
25 #include <thread>
26 #include <vector>
27 
28 #include "absl/base/thread_annotations.h"
29 #include "absl/log/log.h"
30 #include "gtest/gtest.h"
31 #include "src/core/lib/event_engine/default_event_engine.h"
32 #include "src/core/lib/event_engine/event_engine_context.h"
33 #include "src/core/lib/iomgr/exec_ctx.h"
34 #include "src/core/lib/promise/context.h"
35 #include "src/core/lib/promise/inter_activity_latch.h"
36 #include "src/core/lib/promise/poll.h"
37 #include "src/core/lib/promise/seq.h"
38 #include "src/core/lib/promise/sleep.h"
39 #include "src/core/lib/resource_quota/arena.h"
40 #include "src/core/lib/resource_quota/memory_quota.h"
41 #include "src/core/lib/resource_quota/resource_quota.h"
42 #include "src/core/util/notification.h"
43 #include "src/core/util/ref_counted_ptr.h"
44 #include "src/core/util/sync.h"
45 #include "src/core/util/time.h"
46 
47 namespace grpc_core {
48 
49 ///////////////////////////////////////////////////////////////////////////////
50 // PartyTest
51 
52 class PartyTest : public ::testing::Test {
53  protected:
MakeParty()54   RefCountedPtr<Party> MakeParty() {
55     auto arena = SimpleArenaAllocator()->MakeArena();
56     arena->SetContext<grpc_event_engine::experimental::EventEngine>(
57         event_engine_.get());
58     return Party::Make(std::move(arena));
59   }
60 
61  private:
62   std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_ =
63       grpc_event_engine::experimental::GetDefaultEventEngine();
64 };
65 
// Creating a Party and immediately dropping the only reference must be safe.
TEST_F(PartyTest, Noop) {
  auto unused_party = MakeParty();
}
67 
// A spawned promise that forces an immediate repoll on every poll is driven
// until it resolves (on its tenth poll), and the resolved value reaches the
// on-complete callback.
TEST_F(PartyTest, CanSpawnAndRun) {
  auto party = MakeParty();
  Notification finished;
  auto count_down = [remaining = 10]() mutable -> Poll<int> {
    EXPECT_GT(remaining, 0);
    GetContext<Activity>()->ForceImmediateRepoll();
    --remaining;
    if (remaining != 0) return Pending{};
    return 42;
  };
  auto on_done = [&finished](int result) {
    EXPECT_EQ(result, 42);
    finished.Notify();
  };
  party->Spawn("TestSpawn", std::move(count_down), std::move(on_done));
  finished.WaitForNotification();
}
86 
// A participant on one party awaits a waitable participant spawned on a
// second party. That second participant waits on `latch`, which a later
// participant on the first party sets.
TEST_F(PartyTest, CanSpawnWaitableAndRun) {
  auto awaiting_party = MakeParty();
  auto awaited_party = MakeParty();
  Notification finished;
  InterActivityLatch<void> latch;
  auto await_other_party = [&awaited_party, &latch]() {
    return awaited_party->SpawnWaitable("party2_main",
                                        [&latch]() { return latch.Wait(); });
  };
  awaiting_party->Spawn("party1_main", std::move(await_other_party),
                        [&finished](Empty) { finished.Notify(); });
  // The latch is not set yet, so nothing may have completed.
  ASSERT_FALSE(finished.HasBeenNotified());
  awaiting_party->Spawn(
      "party1_notify_latch", [&latch]() { latch.Set(); }, [](Empty) {});
  finished.WaitForNotification();
}
105 
// While being polled, a promise spawns another participant onto the same
// Party. The outer promise resolves immediately, and the inner one resolves
// after ten forced repolls.
TEST_F(PartyTest, CanSpawnFromSpawn) {
  auto party = MakeParty();
  Notification outer_done;
  Notification inner_done;
  // `party` is captured by value, so the outer promise holds its own
  // reference.
  auto spawn_inner = [party, &inner_done]() -> Poll<int> {
    auto count_down = [remaining = 10]() mutable -> Poll<int> {
      GetContext<Activity>()->ForceImmediateRepoll();
      if (--remaining == 0) return 42;
      return Pending{};
    };
    party->Spawn("TestSpawnInner", std::move(count_down),
                 [&inner_done](int result) {
                   EXPECT_EQ(result, 42);
                   inner_done.Notify();
                 });
    return 1234;
  };
  party->Spawn("TestSpawn", std::move(spawn_inner), [&outer_done](int result) {
    EXPECT_EQ(result, 1234);
    outer_done.Notify();
  });
  outer_done.WaitForNotification();
  inner_done.WaitForNotification();
}
134 
// On every poll the promise stores a fresh owning waker and signals which step
// it reached. The test thread then uses that waker to drive the next poll. The
// promise resolves on its tenth poll.
TEST_F(PartyTest, CanWakeupWithOwningWaker) {
  auto party = MakeParty();
  Notification polled[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [step = 0, &waker, &polled]() mutable -> Poll<int> {
        waker = GetContext<Activity>()->MakeOwningWaker();
        polled[step].Notify();
        ++step;
        if (step != 10) return Pending{};
        return 42;
      },
      [&complete](int result) {
        EXPECT_EQ(result, 42);
        complete.Notify();
      });
  for (auto& step_reached : polled) {
    step_reached.WaitForNotification();
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
159 
// On every poll the promise stores a fresh non-owning waker and signals which
// step it reached. Before each wakeup, the test checks that the promise has not
// advanced to the following step on its own. The promise resolves on its tenth
// poll.
TEST_F(PartyTest, CanWakeupWithNonOwningWaker) {
  auto party = MakeParty();
  Notification polled[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [step = 0, &waker, &polled]() mutable -> Poll<int> {
        waker = GetContext<Activity>()->MakeNonOwningWaker();
        polled[step].Notify();
        ++step;
        if (step < 10) return Pending{};
        return 42;
      },
      [&complete](int result) {
        EXPECT_EQ(result, 42);
        complete.Notify();
      });
  for (int step = 0; step + 1 < 10; ++step) {
    polled[step].WaitForNotification();
    EXPECT_FALSE(polled[step + 1].HasBeenNotified());
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
185 
// A non-owning waker is captured from a promise that never resolves, and then
// the test's Party reference is dropped. The test expects the waker to still
// report wakeable, and to report unwakeable once Wakeup() has been called.
TEST_F(PartyTest, CanWakeupWithNonOwningWakerAfterOrphaning) {
  auto party = MakeParty();
  Notification waker_ready;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() -> Poll<int> {
        // The test expects this promise to be polled only once.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        waker = GetContext<Activity>()->MakeNonOwningWaker();
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_FALSE(waker.is_unwakeable());
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
205 
// A non-owning waker is captured from a promise that never resolves. After the
// test's Party reference is dropped, the waker is destroyed without ever being
// woken.
TEST_F(PartyTest, CanDropNonOwningWakeAfterOrphaning) {
  auto party = MakeParty();
  Notification waker_ready;
  std::unique_ptr<Waker> waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() -> Poll<int> {
        // The test expects this promise to be polled only once.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        waker = std::make_unique<Waker>(
            GetContext<Activity>()->MakeNonOwningWaker());
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_NE(waker, nullptr);
  // Destroy the waker explicitly, after the party reference has been dropped.
  waker.reset();
}
225 
// A non-owning waker is captured from a promise that never resolves, and is
// checked to be wakeable while the Party is still referenced. After the test's
// Party reference is dropped, waking it leaves it unwakeable.
TEST_F(PartyTest, CanWakeupNonOwningOrphanedWakerWithNoEffect) {
  auto party = MakeParty();
  Notification waker_ready;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() -> Poll<int> {
        // The test expects this promise to be polled only once.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        waker = GetContext<Activity>()->MakeNonOwningWaker();
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  EXPECT_FALSE(waker.is_unwakeable());
  party.reset();
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
245 
// Two participants are spawned while a WakeupHold is in place. The test checks
// that neither completes while the hold exists, and that both complete after
// it is released.
TEST_F(PartyTest, CanBulkSpawn) {
  auto party = MakeParty();
  Notification first_done;
  Notification second_done;
  {
    Party::WakeupHold hold(party.get());
    party->Spawn(
        "spawn1", []() {}, [&first_done](Empty) { first_done.Notify(); });
    party->Spawn(
        "spawn2", []() {}, [&second_done](Empty) { second_done.Notify(); });
    // Check repeatedly so that premature execution has time to show up.
    for (int check = 0; check < 5000; ++check) {
      EXPECT_FALSE(first_done.HasBeenNotified());
      EXPECT_FALSE(second_done.HasBeenNotified());
    }
  }  // `hold` is released here.
  first_done.WaitForNotification();
  second_done.WaitForNotification();
}
262 
// Same as bulk spawning, but with two nested WakeupHolds on the same Party.
// Nothing may complete until both holds have been released.
TEST_F(PartyTest, CanNestWakeupHold) {
  auto party = MakeParty();
  Notification first_done;
  Notification second_done;
  {
    Party::WakeupHold outer_hold(party.get());
    Party::WakeupHold inner_hold(party.get());
    party->Spawn(
        "spawn1", []() {}, [&first_done](Empty) { first_done.Notify(); });
    party->Spawn(
        "spawn2", []() {}, [&second_done](Empty) { second_done.Notify(); });
    // Check repeatedly so that premature execution has time to show up.
    for (int check = 0; check < 5000; ++check) {
      EXPECT_FALSE(first_done.HasBeenNotified());
      EXPECT_FALSE(second_done.HasBeenNotified());
    }
  }  // `inner_hold`, then `outer_hold`, are released here.
  first_done.WaitForNotification();
  second_done.WaitForNotification();
}
280 
// Eight threads share one Party. Each thread runs 100 iterations of: spawn a
// promise that sleeps 10ms and then yields 42, and wait for it to complete.
TEST_F(PartyTest, ThreadStressTest) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 100; ++iteration) {
      ExecCtx ctx;  // needed for Sleep
      Notification promise_complete;
      party->Spawn("TestSpawn",
                   Seq(Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                       []() -> Poll<int> { return 42; }),
                   [&promise_complete](int result) {
                     EXPECT_EQ(result, 42);
                     promise_complete.Notify();
                   });
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
305 
306 class PromiseNotification {
307  public:
PromiseNotification(bool owning_waker)308   explicit PromiseNotification(bool owning_waker)
309       : owning_waker_(owning_waker) {}
310 
Wait()311   auto Wait() {
312     return [this]() -> Poll<int> {
313       MutexLock lock(&mu_);
314       if (done_) return 42;
315       if (!polled_) {
316         if (owning_waker_) {
317           waker_ = GetContext<Activity>()->MakeOwningWaker();
318         } else {
319           waker_ = GetContext<Activity>()->MakeNonOwningWaker();
320         }
321         polled_ = true;
322       }
323       return Pending{};
324     };
325   }
326 
Notify()327   void Notify() {
328     Waker waker;
329     {
330       MutexLock lock(&mu_);
331       done_ = true;
332       waker = std::move(waker_);
333     }
334     waker.Wakeup();
335   }
336 
NotifyUnderLock()337   void NotifyUnderLock() {
338     MutexLock lock(&mu_);
339     done_ = true;
340     waker_.WakeupAsync();
341   }
342 
343  private:
344   Mutex mu_;
345   const bool owning_waker_;
346   bool done_ ABSL_GUARDED_BY(mu_) = false;
347   bool polled_ ABSL_GUARDED_BY(mu_) = false;
348   Waker waker_ ABSL_GUARDED_BY(mu_);
349 };
350 
// Eight threads share one Party. Each thread runs 100 iterations of: spawn a
// promise that waits on a PromiseNotification (owning waker), sleeps 10ms and
// yields 42; call Notify(); then wait for completion.
TEST_F(PartyTest, ThreadStressTestWithOwningWaker) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 100; ++iteration) {
      ExecCtx ctx;  // needed for Sleep
      PromiseNotification promise_start(true);
      Notification promise_complete;
      party->Spawn("TestSpawn",
                   Seq(promise_start.Wait(),
                       Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                       []() -> Poll<int> { return 42; }),
                   [&promise_complete](int result) {
                     EXPECT_EQ(result, 42);
                     promise_complete.Notify();
                   });
      promise_start.Notify();
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
378 
// Same shape as the owning-waker stress test, but the start notification is
// delivered with NotifyUnderLock(). That call wakes the waker asynchronously
// while the notification's mutex is held.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerHoldingLock) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 100; ++iteration) {
      ExecCtx ctx;  // needed for Sleep
      PromiseNotification promise_start(true);
      Notification promise_complete;
      party->Spawn("TestSpawn",
                   Seq(promise_start.Wait(),
                       Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                       []() -> Poll<int> { return 42; }),
                   [&promise_complete](int result) {
                     EXPECT_EQ(result, 42);
                     promise_complete.Notify();
                   });
      promise_start.NotifyUnderLock();
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
406 
// Eight threads share one Party. Each thread runs 100 iterations of: spawn a
// promise that waits on a PromiseNotification (non-owning waker), sleeps 10ms
// and yields 42; call Notify(); then wait for completion.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWaker) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 100; ++iteration) {
      ExecCtx ctx;  // needed for Sleep
      PromiseNotification promise_start(false);
      Notification promise_complete;
      party->Spawn("TestSpawn",
                   Seq(promise_start.Wait(),
                       Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                       []() -> Poll<int> { return 42; }),
                   [&promise_complete](int result) {
                     EXPECT_EQ(result, 42);
                     promise_complete.Notify();
                   });
      promise_start.Notify();
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
434 
// High-iteration variant without a sleep. Eight threads each run 10000
// iterations of: spawn a promise that waits on a PromiseNotification (owning
// waker) and then yields 42; call Notify(); wait for completion.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerNoSleep) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 10000; ++iteration) {
      PromiseNotification promise_start(true);
      Notification promise_complete;
      party->Spawn(
          "TestSpawn",
          Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
          [&promise_complete](int result) {
            EXPECT_EQ(result, 42);
            promise_complete.Notify();
          });
      promise_start.Notify();
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
460 
// High-iteration variant without a sleep. Eight threads each run 10000
// iterations of: spawn a promise that waits on a PromiseNotification
// (non-owning waker) and then yields 42; call Notify(); wait for completion.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWakerNoSleep) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 10000; ++iteration) {
      PromiseNotification promise_start(false);
      Notification promise_complete;
      party->Spawn(
          "TestSpawn",
          Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
          [&promise_complete](int result) {
            EXPECT_EQ(result, 42);
            promise_complete.Notify();
          });
      promise_start.Notify();
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
486 
// Eight threads share one Party; each runs 100 iterations. In each iteration,
// an outer promise does the following, in order:
//   1. spawns an inner participant that waits on `inner_start`;
//   2. sleeps 10ms;
//   3. releases the inner participant via `inner_start`;
//   4. waits on `inner_complete`, which the inner participant's on-complete
//      callback signals;
//   5. yields 42.
TEST_F(PartyTest, ThreadStressTestWithInnerSpawn) {
  auto party = MakeParty();
  auto run_iterations = [party]() {
    for (int iteration = 0; iteration < 100; ++iteration) {
      ExecCtx ctx;  // needed for Sleep
      PromiseNotification inner_start(true);
      PromiseNotification inner_complete(false);
      Notification promise_complete;
      // Step 1: spawn the inner participant onto the same party.
      auto spawn_inner = [party, &inner_start,
                          &inner_complete]() -> Poll<int> {
        party->Spawn("TestSpawnInner",
                     Seq(inner_start.Wait(), []() { return 0; }),
                     [&inner_complete](int result) {
                       EXPECT_EQ(result, 0);
                       inner_complete.Notify();
                     });
        return 0;
      };
      // Step 3: let the inner participant proceed.
      auto release_inner = [&inner_start]() {
        inner_start.Notify();
        return 0;
      };
      party->Spawn(
          "TestSpawn",
          Seq(std::move(spawn_inner),
              Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
              std::move(release_inner), inner_complete.Wait(),
              []() -> Poll<int> { return 42; }),
          [&promise_complete](int result) {
            EXPECT_EQ(result, 42);
            promise_complete.Notify();
          });
      promise_complete.WaitForNotification();
    }
  };
  std::vector<std::thread> workers;
  workers.reserve(8);
  for (int t = 0; t < 8; ++t) workers.emplace_back(run_iterations);
  for (std::thread& worker : workers) worker.join();
}
528 
// A promise on party1 spawns participants onto party2 and party3. Those
// participants block their executing threads on Notifications that are only
// signalled later: p2 waits for party1's on-complete, and p2/p3 handshake with
// each other. The test can only finish if spawning from inside party1's
// promise does not run p2 or p3 inline on party1's thread.
//
// `stage` pins the global order of events. Each step checks the previous
// value and then advances it:
//   0 -> 1  p1 promise starts
//   1 -> 2  p1 promise finishes after spawning p2 and p3
//   2 -> 3  p1 on-complete, which then releases p2 via `p1_done`
//   3 -> 4  p2 promise, after the p2/p3 handshake
//   4 -> 5  p2 on-complete, which then releases p3 via `p2_done`
//   5 -> 6  p3 promise
//   6 -> 7  p3 on-complete, which then ends the test via `all_done`
TEST_F(PartyTest, NestedWakeup) {
  auto party1 = MakeParty();
  auto party2 = MakeParty();
  auto party3 = MakeParty();
  int stage = 0;
  Notification p1_done;
  Notification p2_started;
  Notification p2_done;
  Notification p3_started;
  Notification all_done;
  party1->Spawn(
      "p1",
      [&]() {
        EXPECT_EQ(stage, 0);
        stage = 1;
        party2->Spawn(
            "p2",
            [&]() {
              p1_done.WaitForNotification();
              p2_started.Notify();
              p3_started.WaitForNotification();
              EXPECT_EQ(stage, 3);
              stage = 4;
            },
            [&](Empty) {
              EXPECT_EQ(stage, 4);
              stage = 5;
              p2_done.Notify();
            });
        party3->Spawn(
            "p3",
            [&]() {
              p2_started.WaitForNotification();
              p3_started.Notify();
              p2_done.WaitForNotification();
              EXPECT_EQ(stage, 5);
              stage = 6;
            },
            [&](Empty) {
              EXPECT_EQ(stage, 6);
              stage = 7;
              all_done.Notify();
            });
        EXPECT_EQ(stage, 1);
        stage = 2;
      },
      [&](Empty) {
        EXPECT_EQ(stage, 2);
        stage = 3;
        p1_done.Notify();
      });
  all_done.WaitForNotification();
}
582 
583 }  // namespace grpc_core
584 
main(int argc,char ** argv)585 int main(int argc, char** argv) {
586   ::testing::InitGoogleTest(&argc, argv);
587   grpc_init();
588   int r = RUN_ALL_TESTS();
589   grpc_shutdown();
590   return r;
591 }
592