1 /*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "model/setup/async_manager.h"

#include <fcntl.h>       // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h> // for Message, TestPartResult, SuiteApi...
#include <netdb.h>       // for gethostbyname, h_addr, hostent
#include <netinet/in.h>  // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>       // for printf
#include <sys/socket.h>  // for socket, AF_INET, accept, bind
#include <sys/types.h>   // for in_addr_t
#include <time.h>        // for NULL, size_t
#include <unistd.h>      // for close, write, read

#include <atomic>              // for atomic
#include <cerrno>              // for errno, EAGAIN, EINTR
#include <chrono>              // for steady_clock, milliseconds
#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, strcmp, strcpy, strlen
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <string_view>         // for string_view
#include <thread>              // for this_thread
#include <tuple>               // for tuple
37
38 namespace rootcanal {
39
/// A resettable boolean flag that other threads can block on.
///
/// All accesses to the flag go through the mutex, so set(), reset(),
/// wait_for() and operator* may be used concurrently from different threads.
class Event {
 public:
  /// Sets the flag to `set` (true by default) and wakes every waiter.
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  /// Clears the flag; equivalent to set(false).
  void reset() { set(false); }

  /// Blocks until the flag is true or `timeout` has elapsed.
  /// @return true if the flag is set, false if the timeout expired while the
  ///         flag was still clear.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  /// Returns the current value of the flag without blocking.
  bool operator*() const {
    // Read under the mutex: set() may run concurrently on another thread,
    // and an unsynchronized read of set_ would be a data race.
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

 private:
  mutable std::mutex m_;  // mutable so the const operator* can lock it
  std::condition_variable cv_;
  bool set_{false};
};
62
63 class AsyncManagerSocketTest : public ::testing::Test {
64 public:
65 static const uint16_t kPort = 6111;
66 static const size_t kBufferSize = 16;
67
CheckBufferEquals()68 bool CheckBufferEquals() {
69 return strcmp(server_buffer_, client_buffer_) == 0;
70 }
71
72 protected:
StartServer()73 int StartServer() {
74 struct sockaddr_in serv_addr = {};
75 int fd = socket(AF_INET, SOCK_STREAM, 0);
76 EXPECT_FALSE(fd < 0);
77
78 serv_addr.sin_family = AF_INET;
79 serv_addr.sin_addr.s_addr = INADDR_ANY;
80 serv_addr.sin_port = htons(kPort);
81 int reuse_flag = 1;
82 EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
83 sizeof(reuse_flag)) < 0);
84 EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
85
86 listen(fd, 1);
87 return fd;
88 }
89
AcceptConnection(int fd)90 int AcceptConnection(int fd) {
91 struct sockaddr_in cli_addr;
92 memset(&cli_addr, 0, sizeof(cli_addr));
93 socklen_t clilen = sizeof(cli_addr);
94
95 int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
96 EXPECT_FALSE(connection_fd < 0);
97
98 return connection_fd;
99 }
100
ConnectSocketPair()101 std::tuple<int, int> ConnectSocketPair() {
102 int cli = ConnectClient();
103 WriteFromClient(cli);
104 AwaitServerResponse(cli);
105 int ser = connection_fd_;
106 connection_fd_ = -1;
107 return {cli, ser};
108 }
109
ReadIncomingMessage(int fd)110 void ReadIncomingMessage(int fd) {
111 int n;
112 do {
113 n = read(fd, server_buffer_, kBufferSize - 1);
114 } while (n == -1 && errno == EAGAIN);
115 ASSERT_GE(n, 0) << strerror(errno);
116
117 if (n == 0) { // got EOF
118 async_manager_.StopWatchingFileDescriptor(fd);
119 close(fd);
120 } else {
121 n = write(fd, "1", 1);
122 }
123 }
124
SetUp()125 void SetUp() override {
126 memset(server_buffer_, 0, kBufferSize);
127 memset(client_buffer_, 0, kBufferSize);
128 socket_fd_ = -1;
129 connection_fd_ = -1;
130
131 socket_fd_ = StartServer();
132
133 async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
134 connection_fd_ = AcceptConnection(fd);
135
136 async_manager_.WatchFdForNonBlockingReads(
137 connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
138 });
139 }
140
TearDown()141 void TearDown() override {
142 async_manager_.StopWatchingFileDescriptor(socket_fd_);
143 close(socket_fd_);
144 close(connection_fd_);
145 ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
146 std::string_view(client_buffer_, kBufferSize));
147 }
148
ConnectClient()149 int ConnectClient() {
150 int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
151 EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
152
153 struct hostent* server;
154 server = gethostbyname("localhost");
155 EXPECT_FALSE(server == NULL) << strerror(errno);
156
157 struct sockaddr_in serv_addr;
158 memset((void*)&serv_addr, 0, sizeof(serv_addr));
159 serv_addr.sin_family = AF_INET;
160 serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
161 serv_addr.sin_port = htons(kPort);
162
163 int result =
164 connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
165 EXPECT_GE(result, 0) << strerror(errno);
166
167 return socket_cli_fd;
168 }
169
WriteFromClient(int socket_cli_fd)170 void WriteFromClient(int socket_cli_fd) {
171 strcpy(client_buffer_, "1");
172 int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
173 ASSERT_GT(n, 0) << strerror(errno);
174 }
175
AwaitServerResponse(int socket_cli_fd)176 void AwaitServerResponse(int socket_cli_fd) {
177 int n = read(socket_cli_fd, client_buffer_, 1);
178 ASSERT_GT(n, 0) << strerror(errno);
179 }
180
181 protected:
182 AsyncManager async_manager_;
183 int socket_fd_;
184 int connection_fd_;
185 char server_buffer_[kBufferSize];
186 char client_buffer_[kBufferSize];
187 };
188
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  // One full round trip: connect, send a message, wait for the reply, and
  // hang up.
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
198
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  // 32 'x' bytes. The previous `std::string data('x', 32)` selected the
  // (count, ch) constructor and produced 120 spaces instead.
  std::string data(32, 'x');

  // Set by the AsyncManager callback and polled on this thread, so it must
  // be atomic to avoid a data race.
  std::atomic<bool> stopped{false};
  // Replace the fixture's callback with one that unsubscribes itself from
  // inside the callback, then drains whatever is readable.
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  // Keep producing input until the callback has fired.
  while (!stopped) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  SUCCEED();
  close(socket_cli_fd);
}
226
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event running;

  // Periodic task (user 1, first run after 1ms, then every 2ms) that cancels
  // itself on its first invocation and then signals the test thread.
  // NOTE(review): it cancels task id 1, assuming that is the id this fresh
  // AsyncManager assigns to its first task; the returned id is not used.
  auto cancel_self = [&running, this]() {
    EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
        << "We were scheduled, so cancel should return true";
    EXPECT_FALSE(async_manager_.CancelAsyncTask(1))
        << "We were not scheduled, so cancel should return false";
    running.set(true);
  };
  async_manager_.ExecAsyncPeriodically(1, 1ms, 2ms, cancel_self);

  EXPECT_TRUE(running.wait_for(10ms));
}
240
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  // Shared between this thread and the AsyncManager thread. They are atomic
  // so the test stays well-defined even if cancellation does NOT wait for the
  // running task, which is exactly the bug this test is meant to catch.
  std::atomic<bool> cancel_done{false};
  std::atomic<bool> task_complete{false};
  async_manager_.ExecAsyncPeriodically(
      1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
        // Let the other thread know we are in the callback.
        running.set(true);
        // Wee bit of a hack that relies on timing: stay inside the callback
        // long enough for the test thread to call CancelAsyncTask.
        std::this_thread::sleep_for(20ms);
        EXPECT_FALSE(cancel_done.load())
            << "Task cancellation did not wait for us to complete!";
        task_complete = true;
      });

  EXPECT_TRUE(running.wait_for(10ms));
  // Measuring an interval, so use a monotonic clock (system_clock may jump).
  auto start = std::chrono::steady_clock::now();

  // There is a 20ms wait in the task, so a cancel that waits should take
  // some time.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
      << "We were scheduled, so cancel should return true";
  cancel_done = true;
  EXPECT_TRUE(task_complete.load())
      << "We managed to cancel a task while it was not yet finished.";
  auto end = std::chrono::steady_clock::now();
  auto passed_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
271
TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
  // This test makes sure the AsyncManager never fires an event
  // after calling StopWatchingFileDescriptor.
  //
  // Timestamps taken on two different threads are compared below, so use a
  // monotonic clock: system_clock can be adjusted while the test runs.
  using clock = std::chrono::steady_clock;
  using namespace std::chrono_literals;

  clock::time_point time_fast_called;
  clock::time_point time_slow_called;
  clock::time_point time_stopped_listening;

  // Prints "<label>: <tick count modulo 10000>". clock::rep is not
  // guaranteed to be long long, so cast explicitly to match %lld.
  auto print_time = [](const char* label, clock::time_point t) {
    printf("%s: %lld\n", label,
           static_cast<long long>(t.time_since_epoch().count() % 10000));
  };

  int round = 0;
  auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
  fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);

  auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
  fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);

  std::string data(1, 'x');

  // The idea here is as follows:
  // We want to make sure that an unsubscribed callback never gets called.
  // This is to make sure we can safely do things like this:
  //
  // class Foo {
  //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
  //     am_->WatchFdForNonBlockingReads(
  //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
  //   }
  //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
  //
  //   AsyncManager* am_;
  //   int fd_;
  // };
  //
  // We are going to force a failure as follows:
  //
  // The slow callback needs to be called first; if it is not, we cannot
  // force the failure, so we may have to try multiple times.
  //
  // t1 is the thread running this loop.
  // t2 is the AsyncManager handler thread.
  //
  // t1 blocks until the slow callback runs.
  // t2 now blocks inside the slow callback (for at most 25 ms).
  // t1 unsubscribes the fast callback.
  // Two cases:
  //   with the bug:
  //     - t1 takes a timestamp and unblocks t2,
  //     - t2 invokes the fast callback and takes a timestamp,
  //     - so the unsubscribe time is before the callback time.
  //   without the bug:
  //     - t1 blocks inside the AsyncManager's unsubscribe,
  //     - t2 unblocks due to the timeout,
  //     - t2 invokes the fast callback and takes a timestamp,
  //     - t1 is unblocked and takes a timestamp,
  //     - so the unsubscribe time is after the callback time.

  do {
    Event unblock_slow, inslow, infast;
    time_fast_called = {};
    time_slow_called = {};
    time_stopped_listening = {};
    printf("round: %d\n", round++);

    // Register fd events
    async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
      if (*inslow) return;
      time_slow_called = clock::now();
      print_time("slow", time_slow_called);
      inslow.set();
      unblock_slow.wait_for(25ms);
    });

    async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
      if (*infast) return;
      time_fast_called = clock::now();
      print_time("fast", time_fast_called);
      infast.set();
    });

    // Generate fd events
    write(fast_cli_fd, data.data(), data.size());
    write(slow_cli_fd, data.data(), data.size());

    // Block in the right places.
    if (inslow.wait_for(25ms)) {
      async_manager_.StopWatchingFileDescriptor(fast_s_fd);
      time_stopped_listening = clock::now();
      print_time("stop", time_stopped_listening);
      unblock_slow.set();
    }

    infast.wait_for(25ms);

    // Unregister.
    async_manager_.StopWatchingFileDescriptor(fast_s_fd);
    async_manager_.StopWatchingFileDescriptor(slow_s_fd);
  } while (time_fast_called < time_slow_called);

  // fast before stop listening. EXPECT (not ASSERT) so the cleanup below
  // still runs when the check fails.
  EXPECT_LT(time_fast_called.time_since_epoch().count(),
            time_stopped_listening.time_since_epoch().count());

  // Cleanup
  close(fast_cli_fd);
  close(fast_s_fd);
  close(slow_cli_fd);
  close(slow_s_fd);
}
384
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  // Open, use and close one connection at a time, many times in a row, so
  // each new client connects only after the previous one has hung up.
  constexpr int kNumConnections = 30;
  for (int attempt = 0; attempt != kNumConnections; ++attempt) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
394
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  static const int num_connections = 30;
  int socket_cli_fd[num_connections];
  // Open and write on every connection before reading any reply, so all of
  // them are open at the same time.
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    // 0 is a valid descriptor; only negative values signal failure.
    ASSERT_GE(socket_cli_fd[i], 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {
    AwaitServerResponse(socket_cli_fd[i]);
    close(socket_cli_fd[i]);
  }
}
408
409 class AsyncManagerTest : public ::testing::Test {
410 public:
411 AsyncManager async_manager_;
412 };
413
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty: only constructs and destroys the fixture's
  // AsyncManager.
}
415
TEST_F(AsyncManagerTest, TestCancelTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // Written by the AsyncManager thread if the task ever runs; atomic so the
  // read below is not a data race.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool>* task1_ran_ptr = &task1_ran;
  AsyncTaskId task1_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task1_ran_ptr]() { *task1_ran_ptr = true; });
  // Cancelling before the 2ms delay expires must succeed and the task must
  // never run.
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task1_ran.load());
}
426
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // The flags are written on the AsyncManager thread and polled here. With a
  // plain bool the spin loop below is a data race, and the compiler may
  // hoist the load out of the loop and spin forever.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool>* task1_ran_ptr = &task1_ran;
  AsyncTaskId task1_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task1_ran_ptr]() { *task1_ran_ptr = true; });
  std::atomic<bool> task2_ran{false};
  std::atomic<bool>* task2_ran_ptr = &task2_ran;
  AsyncTaskId task2_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task2_ran_ptr]() { *task2_ran_ptr = true; });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  // Wait for the short task to fire.
  while (!task1_ran.load()) {
    std::this_thread::yield();
  }
  // The task that already ran can no longer be cancelled...
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  // ...while the long one is still pending and can be.
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
447
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // Each flag is set on the AsyncManager thread and polled here, so it must
  // be atomic. A plain bool makes the spin loop below a data race the
  // optimizer may turn into an infinite loop.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool>* task1_ran_ptr = &task1_ran;
  std::atomic<bool> task2_ran{false};
  std::atomic<bool>* task2_ran_ptr = &task2_ran;
  std::atomic<bool> task3_ran{false};
  std::atomic<bool>* task3_ran_ptr = &task3_ran;
  std::atomic<bool> task4_ran{false};
  std::atomic<bool>* task4_ran_ptr = &task4_ran;
  std::atomic<bool> task5_ran{false};
  std::atomic<bool>* task5_ran_ptr = &task5_ran;
  // user1 gets two short (2ms) and two long (2s) tasks; user2 one short task.
  AsyncTaskId task1_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task1_ran_ptr]() { *task1_ran_ptr = true; });
  AsyncTaskId task2_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task2_ran_ptr]() { *task2_ran_ptr = true; });
  AsyncTaskId task3_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task3_ran_ptr]() { *task3_ran_ptr = true; });
  AsyncTaskId task4_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task4_ran_ptr]() { *task4_ran_ptr = true; });
  AsyncTaskId task5_id =
      async_manager_.ExecAsync(user2, std::chrono::milliseconds(2),
                               [task5_ran_ptr]() { *task5_ran_ptr = true; });
  ASSERT_FALSE(task1_ran.load());
  // Wait for all three short tasks to fire.
  while (!task1_ran.load() || !task3_ran.load() || !task5_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  async_manager_.CancelAsyncTasksFromUser(user1);
  // Nothing is left to cancel: user1's pending tasks were cancelled above
  // and the remaining tasks have already run.
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
491
492 } // namespace rootcanal
493