1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_async2/dispatcher.h"

#include <optional>
#include <utility>

#include "gtest/gtest.h"
#include "pw_containers/vector.h"
19
20 namespace pw::async2 {
21 namespace {
22
23 class MockTask : public Task {
24 public:
25 bool should_complete = false;
26 int polled = 0;
27 int destroyed = 0;
28 std::optional<Waker> last_waker = std::nullopt;
29
30 private:
DoPend(Context & cx)31 Poll<> DoPend(Context& cx) override {
32 ++polled;
33 last_waker = cx.GetWaker(WaitReason::Unspecified());
34 if (should_complete) {
35 return Ready();
36 } else {
37 return Pending();
38 }
39 }
DoDestroy()40 void DoDestroy() override { ++destroyed; }
41 };
42
// A posted task that is ready on its first poll should make RunUntilStalled
// report ready after exactly one poll, and the task should be destroyed once.
TEST(Dispatcher, RunUntilStalledPendsPostedTask) {
  MockTask task;
  task.should_complete = true;
  Dispatcher dispatcher;
  dispatcher.Post(task);

  auto poll_result = dispatcher.RunUntilStalled(task);
  EXPECT_TRUE(poll_result.IsReady());
  EXPECT_EQ(task.polled, 1);
  EXPECT_EQ(task.destroyed, 1);
}
52
// A posted task that stays pending should make RunUntilStalled return a
// not-ready result after a single poll, without the task being destroyed.
TEST(Dispatcher, RunUntilStalledReturnsOnNotReady) {
  MockTask task;
  task.should_complete = false;
  Dispatcher dispatcher;
  dispatcher.Post(task);

  auto poll_result = dispatcher.RunUntilStalled(task);
  EXPECT_FALSE(poll_result.IsReady());
  EXPECT_EQ(task.polled, 1);
  EXPECT_EQ(task.destroyed, 0);
}
62
// A task that returned Pending must not be polled again until its waker is
// invoked, even if it would now complete. Once woken, the next
// RunUntilStalled polls it again and it completes.
TEST(Dispatcher, RunUntilStalledDoesNotPendSleepingTask) {
  MockTask task;
  task.should_complete = false;
  Dispatcher dispatcher;
  dispatcher.Post(task);

  EXPECT_FALSE(dispatcher.RunUntilStalled(task).IsReady());
  EXPECT_EQ(task.polled, 1);
  EXPECT_EQ(task.destroyed, 0);

  // Changing the flag alone must not trigger another poll: nothing woke it.
  task.should_complete = true;
  EXPECT_FALSE(dispatcher.RunUntilStalled(task).IsReady());
  EXPECT_EQ(task.polled, 1);
  EXPECT_EQ(task.destroyed, 0);

  // `last_waker` is only populated by DoPend; dereferencing an empty
  // std::optional is undefined behavior, so stop the test cleanly instead.
  ASSERT_TRUE(task.last_waker.has_value());
  std::move(*task.last_waker).Wake();
  EXPECT_TRUE(dispatcher.RunUntilStalled(task).IsReady());
  EXPECT_EQ(task.polled, 2);
  EXPECT_EQ(task.destroyed, 1);
}
83
// Several tasks that depend on a shared counter should all be driven to
// completion by a single RunUntilStalled call, with woken tasks re-polled.
// NOTE(review): despite its name, this test exercises RunUntilStalled(), not
// RunToCompletion(); the name is kept so existing test filters still match.
TEST(Dispatcher, RunUntilCompletePendsMultipleTasks) {
  // Increments `*counter` on every poll. While the counter is below `until`,
  // the task parks its waker in `*wakers` and stays pending. The poll that
  // reaches `until` wakes every parked task and completes.
  class CounterTask : public Task {
   public:
    CounterTask(pw::Vector<Waker>* wakers, int* counter, int until)
        : counter_(counter), until_(until), wakers_(wakers) {}

   private:
    Poll<> DoPend(Context& cx) override {
      ++(*counter_);
      if (*counter_ >= until_) {
        for (auto& waker : *wakers_) {
          std::move(waker).Wake();
        }
        // The wakers have been moved from; drop them so later completions do
        // not iterate over them again.
        wakers_->clear();
        return Ready();
      }
      wakers_->push_back(cx.GetWaker(WaitReason::Unspecified()));
      return Pending();
    }

    int* counter_;
    int until_;
    pw::Vector<Waker>* wakers_;
  };

  int counter = 0;
  constexpr int kNumTasks = 3;
  pw::Vector<Waker, kNumTasks> wakers;
  CounterTask task_one(&wakers, &counter, kNumTasks);
  CounterTask task_two(&wakers, &counter, kNumTasks);
  CounterTask task_three(&wakers, &counter, kNumTasks);
  Dispatcher dispatcher;
  dispatcher.Post(task_one);
  dispatcher.Post(task_two);
  dispatcher.Post(task_three);
  EXPECT_TRUE(dispatcher.RunUntilStalled().IsReady());
  // We expect to see 5 total calls to `Pend`:
  // - two which increment counter and return pending
  // - one which increments the counter, returns complete, and wakes the
  //   others
  // - two which have woken back up and complete
  EXPECT_EQ(counter, 5);
}
126
// A task may post another task to the dispatcher from inside its own DoPend.
// The newly posted task should be polled and destroyed within the same
// RunUntilStalled call.
TEST(Dispatcher, PostToDispatcherFromInsidePendSucceeds) {
  // Posts `task_to_post` to the dispatcher that is polling it, then completes.
  class TaskPoster : public Task {
   public:
    // `explicit`: a Task& should never silently convert into a TaskPoster.
    explicit TaskPoster(Task& task_to_post) : task_to_post_(&task_to_post) {}

   private:
    Poll<> DoPend(Context& cx) override {
      cx.dispatcher().Post(*task_to_post_);
      return Ready();
    }
    Task* task_to_post_;
  };

  MockTask posted_task;
  posted_task.should_complete = true;
  TaskPoster task_poster(posted_task);

  Dispatcher dispatcher;
  dispatcher.Post(task_poster);
  EXPECT_TRUE(dispatcher.RunUntilStalled().IsReady());
  EXPECT_EQ(posted_task.polled, 1);
  EXPECT_EQ(posted_task.destroyed, 1);
}
150
// RunToCompletion on a task that is ready immediately should poll it exactly
// once, and the task should be destroyed once.
TEST(Dispatcher, RunToCompletionPendsPostedTask) {
  MockTask task;
  task.should_complete = true;

  Dispatcher dispatcher;
  dispatcher.Post(task);
  dispatcher.RunToCompletion(task);

  // Both counters are checked after the run has returned.
  EXPECT_EQ(task.polled, 1);
  EXPECT_EQ(task.destroyed, 1);
}
160
161 } // namespace
162 } // namespace pw::async2
163