/*
 * Copyright (c) 2021, The OpenThread Authors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the copyright holder nor the
 * names of its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include "common/task_runner.hpp"

#include <atomic>
#include <cerrno>
#include <chrono>
#include <mutex>
#include <string>
#include <sys/select.h>
#include <thread>
#include <unistd.h>
#include <vector>

#include <CppUTest/TestHarness.h>

// Test group used by every TaskRunner test in this file; it needs no per-test setup or teardown.
TEST_GROUP(TaskRunner)
{
};

/**
 * Posts a chain of three nested tasks, where each task posts the next one, then
 * drives a single Update()/select()/Process() round and expects all three tasks
 * to have run once Process() returns.
 */
TEST(TaskRunner, TestSingleThread)
{
    int                   executed = 0;
    otbr::MainloopContext context;
    otbr::TaskRunner      taskRunner;

    FD_ZERO(&context.mReadFdSet);
    FD_ZERO(&context.mWriteFdSet);
    FD_ZERO(&context.mErrorFdSet);
    context.mMaxFd           = -1;
    context.mTimeout.tv_sec  = 10;
    context.mTimeout.tv_usec = 0;

    // Outer task -> middle task -> inner task; each one bumps `executed` exactly once.
    taskRunner.Post([&]() {
        ++executed;
        taskRunner.Post([&]() {
            ++executed;
            taskRunner.Post([&]() { ++executed; });
        });
    });

    taskRunner.Update(context);
    const int readyCount = select(context.mMaxFd + 1, &context.mReadFdSet, &context.mWriteFdSet,
                                  &context.mErrorFdSet, &context.mTimeout);
    CHECK_EQUAL(1, readyCount);

    taskRunner.Process(context);
    CHECK_EQUAL(3, executed);
}

/**
 * Posts three immediate tasks and verifies that one mainloop iteration runs
 * them in the order in which they were posted.
 */
TEST(TaskRunner, TestTasksOrder)
{
    std::string           str;
    otbr::TaskRunner      taskRunner;
    int                   rval;
    otbr::MainloopContext mainloop;

    taskRunner.Post([&]() { str.push_back('a'); });
    taskRunner.Post([&]() { str.push_back('b'); });
    taskRunner.Post([&]() { str.push_back('c'); });

    mainloop.mMaxFd   = -1;
    mainloop.mTimeout = {2, 0};

    FD_ZERO(&mainloop.mReadFdSet);
    FD_ZERO(&mainloop.mWriteFdSet);
    FD_ZERO(&mainloop.mErrorFdSet);

    taskRunner.Update(mainloop);
    rval = select(mainloop.mMaxFd + 1, &mainloop.mReadFdSet, &mainloop.mWriteFdSet, &mainloop.mErrorFdSet,
                  &mainloop.mTimeout);
    // CHECK_EQUAL (as in the other immediate-task tests) reports the actual select() result on
    // failure, whereas CHECK_TRUE(rval == 1) would only report "false".
    CHECK_EQUAL(1, rval);

    taskRunner.Process(mainloop);

    // Make sure the tasks are executed in the order of posting.
    STRCMP_EQUAL("abc", str.c_str());
}

/**
 * Ten worker threads each post one immediate increment task. The test thread
 * keeps servicing the mainloop until all ten increments have been executed by
 * Process(), then joins the workers.
 */
TEST(TaskRunner, TestMultipleThreads)
{
    std::atomic<int>         counter{0};
    otbr::TaskRunner         taskRunner;
    std::vector<std::thread> workers;

    // Every worker runs the same body: post a single `++counter` task.
    const auto postIncrement = [&]() { taskRunner.Post([&]() { ++counter; }); };

    for (int n = 0; n < 10; n++)
    {
        workers.emplace_back(postIncrement);
    }

    while (counter < 10)
    {
        otbr::MainloopContext context;

        FD_ZERO(&context.mReadFdSet);
        FD_ZERO(&context.mWriteFdSet);
        FD_ZERO(&context.mErrorFdSet);
        context.mMaxFd           = -1;
        context.mTimeout.tv_sec  = 10;
        context.mTimeout.tv_usec = 0;

        taskRunner.Update(context);
        const int readyCount = select(context.mMaxFd + 1, &context.mReadFdSet, &context.mWriteFdSet,
                                      &context.mErrorFdSet, &context.mTimeout);
        CHECK_EQUAL(1, readyCount);

        taskRunner.Process(context);
    }

    for (std::thread &worker : workers)
    {
        worker.join();
    }

    CHECK_EQUAL(10, counter.load());
}

/**
 * Each of `kNumThreads` worker threads calls PostAndWait() with a task that
 * returns the incremented value of the atomic `counter`. The increments are
 * atomic, so the returned values are exactly 1..kNumThreads in some order, and
 * their sum must be kNumThreads * (kNumThreads + 1) / 2.
 */
TEST(TaskRunner, TestPostAndWait)
{
    constexpr int kNumThreads  = 10;
    constexpr int kExpectedSum = kNumThreads * (kNumThreads + 1) / 2; // 1 + 2 + ... + 10 = 55

    std::atomic<int>         total{0};
    std::atomic<int>         counter{0};
    otbr::TaskRunner         taskRunner;
    std::vector<std::thread> threads;

    // Increase the `counter` to kNumThreads in separate threads and accumulate the values
    // returned by PostAndWait().
    for (int i = 0; i < kNumThreads; ++i)
    {
        threads.emplace_back([&]() { total += taskRunner.PostAndWait<int>([&]() { return ++counter; }); });
    }

    while (counter.load() < kNumThreads)
    {
        int                   rval;
        otbr::MainloopContext mainloop;

        mainloop.mMaxFd   = -1;
        mainloop.mTimeout = {10, 0};

        FD_ZERO(&mainloop.mReadFdSet);
        FD_ZERO(&mainloop.mWriteFdSet);
        FD_ZERO(&mainloop.mErrorFdSet);

        taskRunner.Update(mainloop);
        rval = select(mainloop.mMaxFd + 1, &mainloop.mReadFdSet, &mainloop.mWriteFdSet, &mainloop.mErrorFdSet,
                      &mainloop.mTimeout);
        CHECK_EQUAL(1, rval);

        taskRunner.Process(mainloop);
    }

    for (auto &th : threads)
    {
        th.join();
    }

    // Load the atomics explicitly rather than relying on their implicit conversion inside
    // CHECK_EQUAL, which may evaluate its arguments more than once.
    CHECK_EQUAL(kExpectedSum, total.load());
    CHECK_EQUAL(kNumThreads, counter.load());
}

/**
 * Ten worker threads each post an increment task delayed by 10 ms. The test
 * thread services the mainloop until every delayed increment has run, then
 * joins the workers.
 */
TEST(TaskRunner, TestDelayedTasks)
{
    std::atomic<int>         counter{0};
    otbr::TaskRunner         taskRunner;
    std::vector<std::thread> threads;

    for (int n = 0; n < 10; n++)
    {
        threads.emplace_back([&taskRunner, &counter]() {
            taskRunner.Post(std::chrono::milliseconds(10), [&counter]() { ++counter; });
        });
    }

    while (counter < 10)
    {
        otbr::MainloopContext context;

        FD_ZERO(&context.mReadFdSet);
        FD_ZERO(&context.mWriteFdSet);
        FD_ZERO(&context.mErrorFdSet);
        context.mMaxFd           = -1;
        context.mTimeout.tv_sec  = 2;
        context.mTimeout.tv_usec = 0;

        taskRunner.Update(context);
        // A timeout (0) or an interrupted wait is acceptable here; only a real select() error fails.
        const int rval = select(context.mMaxFd + 1, &context.mReadFdSet, &context.mWriteFdSet,
                                &context.mErrorFdSet, &context.mTimeout);
        CHECK_TRUE(rval >= 0 || errno == EINTR);

        taskRunner.Process(context);
    }

    for (std::thread &worker : threads)
    {
        worker.join();
    }

    CHECK_EQUAL(10, counter.load());
}

/**
 * Posts three delayed tasks whose delays are not in posting order ('b' has the
 * shortest delay). Expects them to run ordered by delay, with the two equal
 * delays keeping their posting order ('a' before 'c').
 */
TEST(TaskRunner, TestDelayedTasksOrder)
{
    std::string      str;
    otbr::TaskRunner taskRunner;

    // Posts a task that appends `aLetter` to `str` after `aDelayMs` milliseconds.
    const auto appendAfter = [&](int aDelayMs, char aLetter) {
        taskRunner.Post(std::chrono::milliseconds(aDelayMs), [&str, aLetter]() { str.push_back(aLetter); });
    };

    appendAfter(10, 'a');
    appendAfter(9, 'b');
    appendAfter(10, 'c');

    while (str.size() < 3)
    {
        otbr::MainloopContext context;

        FD_ZERO(&context.mReadFdSet);
        FD_ZERO(&context.mWriteFdSet);
        FD_ZERO(&context.mErrorFdSet);
        context.mMaxFd           = -1;
        context.mTimeout.tv_sec  = 2;
        context.mTimeout.tv_usec = 0;

        taskRunner.Update(context);
        const int rval = select(context.mMaxFd + 1, &context.mReadFdSet, &context.mWriteFdSet,
                                &context.mErrorFdSet, &context.mTimeout);
        CHECK_TRUE(rval >= 0 || errno == EINTR);

        taskRunner.Process(context);
    }

    // Make sure that tasks with smaller delay are executed earlier.
    STRCMP_EQUAL("bac", str.c_str());
}

/**
 * Posts five delayed tasks ('a'..'e' at 10..50 ms) and cancels three of them
 * before they are due, in three different ways:
 *   - `tid2` directly from the test thread,
 *   - `tid3` from another task executed by the mainloop,
 *   - `tid4` from a separate thread.
 * Only 'a' and 'e' may run. Cancelling IDs that already ran or were already
 * cancelled must be harmless.
 */
TEST(TaskRunner, TestCancelDelayedTasks)
{
    std::string              str;
    otbr::TaskRunner         taskRunner;
    otbr::TaskRunner::TaskId tid1, tid2, tid3, tid4, tid5;

    tid1 = taskRunner.Post(std::chrono::milliseconds(10), [&]() { str.push_back('a'); });
    tid2 = taskRunner.Post(std::chrono::milliseconds(20), [&]() { str.push_back('b'); });
    tid3 = taskRunner.Post(std::chrono::milliseconds(30), [&]() { str.push_back('c'); });
    tid4 = taskRunner.Post(std::chrono::milliseconds(40), [&]() { str.push_back('d'); });
    tid5 = taskRunner.Post(std::chrono::milliseconds(50), [&]() { str.push_back('e'); });

    // Task IDs are positive and increase with posting order.
    CHECK(0 < tid1);
    CHECK(tid1 < tid2);
    CHECK(tid2 < tid3);
    CHECK(tid3 < tid4);
    CHECK(tid4 < tid5);

    // Cancel from the test thread before the mainloop runs.
    taskRunner.Cancel(tid2);

    // Cancel from a task run by the mainloop; it fires at 10 ms, before `tid3` is due at 30 ms.
    taskRunner.Post(std::chrono::milliseconds(10), [&]() { taskRunner.Cancel(tid3); });

    // Cancel from another thread shortly after start, well before `tid4` is due at 40 ms.
    // std::this_thread::sleep_for replaces the obsolete POSIX usleep() with the same 20 us delay.
    std::thread t([&]() {
        std::this_thread::sleep_for(std::chrono::microseconds(20));
        taskRunner.Cancel(tid4);
    });

    while (str.size() < 2)
    {
        int                   rval;
        otbr::MainloopContext mainloop;

        mainloop.mMaxFd   = -1;
        mainloop.mTimeout = {2, 0};

        FD_ZERO(&mainloop.mReadFdSet);
        FD_ZERO(&mainloop.mWriteFdSet);
        FD_ZERO(&mainloop.mErrorFdSet);

        taskRunner.Update(mainloop);
        rval = select(mainloop.mMaxFd + 1, &mainloop.mReadFdSet, &mainloop.mWriteFdSet, &mainloop.mErrorFdSet,
                      &mainloop.mTimeout);
        CHECK_TRUE(rval >= 0 || errno == EINTR);

        taskRunner.Process(mainloop);
    }

    // Join before asserting: a failing assertion exits the test early, and destroying a
    // still-joinable std::thread calls std::terminate, which would abort the whole test run.
    t.join();

    // Make sure none of the cancelled delayed tasks were executed.
    STRCMP_EQUAL("ae", str.c_str());

    // Make sure it's fine to cancel task IDs that have already run (tid1) or were already cancelled (tid2).
    taskRunner.Cancel(tid1);
    taskRunner.Cancel(tid2);
}

/**
 * Exercises all three posting APIs from worker threads: ten immediate Post()
 * calls, ten Post() calls delayed by 10 ms and ten PostAndWait() calls. Each
 * posted task bumps `counter` once. The test thread services the mainloop until
 * all thirty increments have happened, then joins the workers.
 */
TEST(TaskRunner, TestAllAPIs)
{
    std::atomic<int>         counter{0};
    otbr::TaskRunner         taskRunner;
    std::vector<std::thread> threads;

    const auto postNow     = [&]() { taskRunner.Post([&]() { ++counter; }); };
    const auto postDelayed = [&]() { taskRunner.Post(std::chrono::milliseconds(10), [&]() { ++counter; }); };
    const auto postAndWait = [&]() { taskRunner.PostAndWait<int>([&]() { return ++counter; }); };

    for (int n = 0; n < 10; n++)
    {
        threads.emplace_back(postNow);
        threads.emplace_back(postDelayed);
        threads.emplace_back(postAndWait);
    }

    while (counter < 30)
    {
        otbr::MainloopContext context;

        FD_ZERO(&context.mReadFdSet);
        FD_ZERO(&context.mWriteFdSet);
        FD_ZERO(&context.mErrorFdSet);
        context.mMaxFd           = -1;
        context.mTimeout.tv_sec  = 2;
        context.mTimeout.tv_usec = 0;

        taskRunner.Update(context);
        const int rval = select(context.mMaxFd + 1, &context.mReadFdSet, &context.mWriteFdSet,
                                &context.mErrorFdSet, &context.mTimeout);
        CHECK_TRUE(rval >= 0 || errno == EINTR);

        taskRunner.Process(context);
    }

    for (std::thread &worker : threads)
    {
        worker.join();
    }

    CHECK_EQUAL(30, counter.load());
}
