1 /*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "model/setup/async_manager.h"
18
#include <fcntl.h>       // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h> // for Message, TestPartResult, SuiteApi...
#include <netdb.h>       // for gethostbyname, h_addr, hostent
#include <netinet/in.h>  // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>       // for printf
#include <sys/socket.h>  // for socket, AF_INET, accept, bind
#include <sys/types.h>   // for in_addr_t
#include <time.h>        // for NULL, size_t
#include <unistd.h>      // for close, write, read

#include <atomic>             // for atomic
#include <cerrno>             // for errno, EAGAIN, EBADF
#include <chrono>             // for steady_clock, milliseconds
#include <condition_variable> // for condition_variable
#include <cstdint>            // for uint16_t
#include <cstring>            // for memset, strcmp, strcpy, strlen
#include <mutex>              // for mutex
#include <ratio>              // for ratio
#include <string>             // for string
#include <string_view>        // for string_view
#include <thread>
#include <tuple>              // for tuple
37
38 namespace rootcanal {
39
/// A boolean flag guarded by a mutex that threads can wait on.
///
/// set()/reset() change the flag and wake every waiter; wait_for() blocks
/// until the flag is true or the timeout elapses.
class Event {
 public:
  /// Sets the flag to `set` (true by default) and wakes all waiters.
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  /// Clears the flag (equivalent to set(false)).
  void reset() { set(false); }

  /// Blocks until the flag is true or `timeout` has elapsed.
  /// @return true if the flag was (or became) set, false on timeout.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  /// Returns the current value of the flag.
  // The flag is written under m_ (possibly from another thread), so it must
  // also be read under m_; an unlocked read is a data race.
  bool operator*() const {
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

 private:
  mutable std::mutex m_;  // mutable so the const operator* can lock it
  std::condition_variable cv_;
  bool set_{false};
};
62
63 class AsyncManagerSocketTest : public ::testing::Test {
64 public:
65 static const uint16_t kPort = 6111;
66 static const size_t kBufferSize = 16;
67
CheckBufferEquals()68 bool CheckBufferEquals() {
69 return strcmp(server_buffer_, client_buffer_) == 0;
70 }
71
72 protected:
StartServer()73 int StartServer() {
74 struct sockaddr_in serv_addr = {};
75 int fd = socket(AF_INET, SOCK_STREAM, 0);
76 EXPECT_FALSE(fd < 0);
77
78 serv_addr.sin_family = AF_INET;
79 serv_addr.sin_addr.s_addr = INADDR_ANY;
80 serv_addr.sin_port = htons(kPort);
81 int reuse_flag = 1;
82 EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
83 sizeof(reuse_flag)) < 0);
84 EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
85
86 listen(fd, 1);
87 return fd;
88 }
89
AcceptConnection(int fd)90 int AcceptConnection(int fd) {
91 struct sockaddr_in cli_addr;
92 memset(&cli_addr, 0, sizeof(cli_addr));
93 socklen_t clilen = sizeof(cli_addr);
94
95 int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
96 EXPECT_FALSE(connection_fd < 0);
97
98 return connection_fd;
99 }
100
ConnectSocketPair()101 std::tuple<int, int> ConnectSocketPair() {
102 int cli = ConnectClient();
103 WriteFromClient(cli);
104 AwaitServerResponse(cli);
105 int ser = connection_fd_;
106 connection_fd_ = -1;
107 return {cli, ser};
108 }
109
ReadIncomingMessage(int fd)110 void ReadIncomingMessage(int fd) {
111 int n;
112 do {
113 n = read(fd, server_buffer_, kBufferSize - 1);
114 } while (n == -1 && errno == EAGAIN);
115
116 if (n == 0 || errno == EBADF) {
117 // Got EOF, or file descriptor disconnected.
118 async_manager_.StopWatchingFileDescriptor(fd);
119 close(fd);
120 } else {
121 ASSERT_GE(n, 0) << strerror(errno);
122 n = write(fd, "1", 1);
123 }
124 }
125
SetUp()126 void SetUp() override {
127 memset(server_buffer_, 0, kBufferSize);
128 memset(client_buffer_, 0, kBufferSize);
129 socket_fd_ = -1;
130 connection_fd_ = -1;
131
132 socket_fd_ = StartServer();
133
134 async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
135 connection_fd_ = AcceptConnection(fd);
136
137 async_manager_.WatchFdForNonBlockingReads(
138 connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
139 });
140 }
141
TearDown()142 void TearDown() override {
143 async_manager_.StopWatchingFileDescriptor(socket_fd_);
144 close(socket_fd_);
145 close(connection_fd_);
146 ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
147 std::string_view(client_buffer_, kBufferSize));
148 }
149
ConnectClient()150 int ConnectClient() {
151 int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
152 EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
153
154 struct hostent* server;
155 server = gethostbyname("localhost");
156 EXPECT_FALSE(server == NULL) << strerror(errno);
157
158 struct sockaddr_in serv_addr;
159 memset((void*)&serv_addr, 0, sizeof(serv_addr));
160 serv_addr.sin_family = AF_INET;
161 serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
162 serv_addr.sin_port = htons(kPort);
163
164 int result =
165 connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
166 EXPECT_GE(result, 0) << strerror(errno);
167
168 return socket_cli_fd;
169 }
170
WriteFromClient(int socket_cli_fd)171 void WriteFromClient(int socket_cli_fd) {
172 strcpy(client_buffer_, "1");
173 int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
174 ASSERT_GT(n, 0) << strerror(errno);
175 }
176
AwaitServerResponse(int socket_cli_fd)177 void AwaitServerResponse(int socket_cli_fd) {
178 int n = read(socket_cli_fd, client_buffer_, 1);
179 ASSERT_GT(n, 0) << strerror(errno);
180 }
181
182 protected:
183 AsyncManager async_manager_;
184 int socket_fd_;
185 int connection_fd_;
186 char server_buffer_[kBufferSize];
187 char client_buffer_[kBufferSize];
188 };
189
// A single round trip: connect, send "1", wait for the server's reply, close.
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
199
// Verifies that a read callback may call StopWatchingFileDescriptor() on its
// own descriptor without deadlocking the AsyncManager.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  // 32 bytes of 'x' (the std::string constructor takes count first, then char).
  std::string data(32, 'x');

  // Written from the AsyncManager callback and polled on this thread, so it
  // must be atomic.
  std::atomic<bool> stopped{false};
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    // Drain whatever is readable; the descriptor is non-blocking.
    while (read(fd, buf, sizeof(buf)) > 0) {
    }
    stopped = true;
  });

  // Keep generating read events until the callback has run.
  while (!stopped.load()) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  SUCCEED();
  close(socket_cli_fd);
}
227
// A periodic task cancels itself from inside its own callback: the first
// cancel must find it scheduled, the second must not.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event running;

  // NOTE(review): the hard-coded task id 1 assumes a fresh AsyncManager hands
  // out id 1 for its first task — TODO confirm against AsyncManager.
  auto cancel_self = [&running, this]() {
    EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
        << "We were scheduled, so cancel should return true";
    EXPECT_FALSE(async_manager_.CancelAsyncTask(1))
        << "We were not scheduled, so cancel should return false";
    running.set(true);
  };
  async_manager_.ExecAsyncPeriodically(1, 1ms, 2ms, cancel_self);

  EXPECT_TRUE(running.wait_for(100ms));
}
241
// Verifies that CancelAsyncTask() does not return while the task's callback is
// still executing.
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  std::atomic<bool> cancel_done = false;
  std::atomic<bool> task_complete = false;
  AsyncTaskId task_id = async_manager_.ExecAsyncPeriodically(
      1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
        // Let the other thread know we are in the callback.
        running.set(true);
        // A bit of a hack that relies on timing: stay in the callback long
        // enough for the main thread to call CancelAsyncTask() meanwhile.
        std::this_thread::sleep_for(20ms);
        EXPECT_FALSE(cancel_done.load())
            << "Task cancellation did not wait for us to complete!";
        task_complete.store(true);
      });

  EXPECT_TRUE(running.wait_for(100ms));
  // Elapsed time is measured with a monotonic clock; system_clock may be
  // adjusted while the test runs.
  auto start = std::chrono::steady_clock::now();

  // The callback sleeps for 20ms, so this cancellation should take a while.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(task_id))
      << "We were scheduled, so cancel should return true";
  cancel_done.store(true);
  EXPECT_TRUE(task_complete.load())
      << "We managed to cancel a task while it was not yet finished.";
  auto end = std::chrono::steady_clock::now();
  auto passed_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
272
TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
  // This test makes sure the AsyncManager never fires an event
  // after calling StopWatchingFileDescriptor.
  // Timestamps taken on two different threads are compared below, so use a
  // monotonic clock that cannot be stepped mid-test.
  using clock = std::chrono::steady_clock;
  using namespace std::chrono_literals;

  // Short timestamp for the debug trace. clock::rep is not guaranteed to be
  // `long long`, so convert explicitly to match the %lld conversion.
  auto trace_stamp = [](clock::time_point t) {
    return static_cast<long long>(t.time_since_epoch().count() % 10000);
  };

  clock::time_point time_fast_called;
  clock::time_point time_slow_called;
  clock::time_point time_stopped_listening;

  int round = 0;
  auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
  fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);

  auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
  fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);

  std::string data(1, 'x');

  // The idea here is as follows:
  // We want to make sure that an unsubscribed callback never gets called.
  // This is to make sure we can safely do things like this:
  //
  //   class Foo {
  //     Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
  //       am_->WatchFdForNonBlockingReads(
  //           fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
  //     }
  //     ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
  //
  //     AsyncManager* am_;
  //     int fd_;
  //   };
  //
  // We are going to force a failure as follows:
  //
  // The slow callback needs to be called first; if it is not, we cannot
  // force the failure, so we may have to try multiple times.
  //
  // t1 is the thread running the loop.
  // t2 is the AsyncManager handler thread.
  //
  // t1 blocks until the slow callback runs.
  // t2 now blocks in the slow callback (for at most 25 ms).
  // t1 unsubscribes the fast callback.
  // Two cases:
  //   With the bug:
  //   - t1 takes a timestamp and unblocks t2.
  //   - t2 invokes the fast callback and takes a timestamp.
  //   - The unsubscribe time is now before the callback time.
  //   Without the bug:
  //   - t1 blocks in the unsubscribe call inside the AsyncManager.
  //   - t2 unblocks due to the timeout.
  //   - t2 invokes the fast callback and takes a timestamp.
  //   - t1 is unblocked and takes a timestamp.
  //   - The unsubscribe time is now after the callback time.

  do {
    Event unblock_slow, inslow, infast;
    time_fast_called = {};
    time_slow_called = {};
    time_stopped_listening = {};
    printf("round: %d\n", round++);

    // Register fd events.
    async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
      if (*inslow) return;
      time_slow_called = clock::now();
      printf("slow: %lld\n", trace_stamp(time_slow_called));
      inslow.set();
      unblock_slow.wait_for(25ms);
    });

    async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
      if (*infast) return;
      time_fast_called = clock::now();
      printf("fast: %lld\n", trace_stamp(time_fast_called));
      infast.set();
    });

    // Generate fd events.
    write(fast_cli_fd, data.data(), data.size());
    write(slow_cli_fd, data.data(), data.size());

    // Block in the right places.
    if (inslow.wait_for(25ms)) {
      async_manager_.StopWatchingFileDescriptor(fast_s_fd);
      time_stopped_listening = clock::now();
      printf("stop: %lld\n", trace_stamp(time_stopped_listening));
      unblock_slow.set();
    }

    infast.wait_for(25ms);

    // Unregister.
    async_manager_.StopWatchingFileDescriptor(fast_s_fd);
    async_manager_.StopWatchingFileDescriptor(slow_s_fd);
  } while (time_fast_called < time_slow_called);

  // Cleanup first (all watches are already removed) so that a failing
  // assertion below does not leak the sockets.
  close(fast_cli_fd);
  close(fast_s_fd);
  close(slow_cli_fd);
  close(slow_s_fd);

  // fast before stop listening.
  ASSERT_LT(time_fast_called.time_since_epoch().count(),
            time_stopped_listening.time_since_epoch().count());
}
385
// Opens, uses and closes a connection 30 times in sequence against the same
// listening socket.
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  constexpr int kRounds = 30;
  int remaining = kRounds;
  while (remaining-- > 0) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
395
// Opens 30 client connections and sends on each before reading any reply, so
// the server has several open connections at once.
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  static const int num_connections = 30;
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    // 0 is a valid descriptor; only negative values signal failure.
    ASSERT_GE(socket_cli_fd[i], 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {
    AwaitServerResponse(socket_cli_fd[i]);
    close(socket_cli_fd[i]);
  }
}
409
410 class AsyncManagerTest : public ::testing::Test {
411 public:
412 AsyncManager async_manager_;
413 };
414
// Exercises only construction and destruction of the fixture's AsyncManager.
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty.
}
416
// Schedules a task 2ms in the future and cancels it right away; cancellation
// must succeed and the task must not have run.
TEST_F(AsyncManagerTest, TestCancelTask) {
  AsyncUserId user = async_manager_.GetNextUserId();
  bool ran = false;
  AsyncTaskId task_id = async_manager_.ExecAsync(
      user, std::chrono::milliseconds(2), [&ran]() { ran = true; });
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task_id));
  ASSERT_FALSE(ran);
}
427
// A short (2ms) and a long (2s) task from the same user: once the short one
// has run it can no longer be cancelled, while the long one still can.
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // These flags are set from AsyncManager callbacks and polled on this thread,
  // so they must be atomic: busy-waiting on a plain bool is a data race and
  // lets the compiler hoist the load out of the loop.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2), [&task1_ran]() { task1_ran = true; });
  std::atomic<bool> task2_ran{false};
  AsyncTaskId task2_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2), [&task2_ran]() { task2_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  while (!task1_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
448
// user1 schedules two short (2ms) and two long (2s) tasks; user2 schedules one
// short task. After the short tasks run, CancelAsyncTasksFromUser(user1)
// leaves nothing for CancelAsyncTask() to cancel.
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // Set from AsyncManager callbacks and polled on this thread: must be atomic
  // for the busy-wait below to be well-defined.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool> task2_ran{false};
  std::atomic<bool> task3_ran{false};
  std::atomic<bool> task4_ran{false};
  std::atomic<bool> task5_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2), [&task1_ran]() { task1_ran = true; });
  AsyncTaskId task2_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2), [&task2_ran]() { task2_ran = true; });
  AsyncTaskId task3_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2), [&task3_ran]() { task3_ran = true; });
  AsyncTaskId task4_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2), [&task4_ran]() { task4_ran = true; });
  AsyncTaskId task5_id = async_manager_.ExecAsync(
      user2, std::chrono::milliseconds(2), [&task5_ran]() { task5_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  // Wait for every short task to run.
  while (!task1_ran.load() || !task3_ran.load() || !task5_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  async_manager_.CancelAsyncTasksFromUser(user1);
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
492
493 } // namespace rootcanal
494