• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "model/setup/async_manager.h"
18 
#include <fcntl.h>        // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h>  // for Message, TestPartResult, SuiteApi...
#include <netdb.h>        // for gethostbyname, h_addr, hostent, h_errno
#include <netinet/in.h>   // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>        // for printf
#include <sys/socket.h>   // for socket, AF_INET, accept, bind
#include <sys/types.h>    // for in_addr_t, ssize_t
#include <time.h>         // for NULL, size_t
#include <unistd.h>       // for close, write, read

#include <atomic>              // for atomic
#include <cerrno>              // for errno, EAGAIN, EBADF
#include <chrono>              // for steady_clock, milliseconds
#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, memcpy, strcmp, strcpy, strlen
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <string_view>         // for string_view
#include <thread>              // for this_thread
#include <tuple>               // for tuple
37 
38 namespace rootcanal {
39 
// A thread-safe boolean flag that other threads can block on until it is set.
// The tests use it to hand-shake between the test thread and AsyncManager
// callbacks.
class Event {
 public:
  // Stores `set` (default true) as the flag value and wakes every thread
  // blocked in wait_for(). Passing false clears the flag.
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  // Clears the flag; equivalent to set(false).
  void reset() { set(false); }

  // Blocks until the flag is true or `timeout` elapses. Returns the flag value
  // observed when the wait ends: true if it was set, false on timeout.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  // Returns the current flag value. The read happens under the mutex so it
  // does not race with set() being called from another thread.
  bool operator*() const {
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

 private:
  // mutable so the const observer operator*() can lock it.
  mutable std::mutex m_;
  std::condition_variable cv_;
  bool set_{false};
};
62 
63 class AsyncManagerSocketTest : public ::testing::Test {
64  public:
65   static const uint16_t kPort = 6111;
66   static const size_t kBufferSize = 16;
67 
CheckBufferEquals()68   bool CheckBufferEquals() {
69     return strcmp(server_buffer_, client_buffer_) == 0;
70   }
71 
72  protected:
StartServer()73   int StartServer() {
74     struct sockaddr_in serv_addr = {};
75     int fd = socket(AF_INET, SOCK_STREAM, 0);
76     EXPECT_FALSE(fd < 0);
77 
78     serv_addr.sin_family = AF_INET;
79     serv_addr.sin_addr.s_addr = INADDR_ANY;
80     serv_addr.sin_port = htons(kPort);
81     int reuse_flag = 1;
82     EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
83                             sizeof(reuse_flag)) < 0);
84     EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
85 
86     listen(fd, 1);
87     return fd;
88   }
89 
AcceptConnection(int fd)90   int AcceptConnection(int fd) {
91     struct sockaddr_in cli_addr;
92     memset(&cli_addr, 0, sizeof(cli_addr));
93     socklen_t clilen = sizeof(cli_addr);
94 
95     int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
96     EXPECT_FALSE(connection_fd < 0);
97 
98     return connection_fd;
99   }
100 
ConnectSocketPair()101   std::tuple<int, int> ConnectSocketPair() {
102     int cli = ConnectClient();
103     WriteFromClient(cli);
104     AwaitServerResponse(cli);
105     int ser = connection_fd_;
106     connection_fd_ = -1;
107     return {cli, ser};
108   }
109 
ReadIncomingMessage(int fd)110   void ReadIncomingMessage(int fd) {
111     int n;
112     do {
113       n = read(fd, server_buffer_, kBufferSize - 1);
114     } while (n == -1 && errno == EAGAIN);
115 
116     if (n == 0 || errno == EBADF) {
117       // Got EOF, or file descriptor disconnected.
118       async_manager_.StopWatchingFileDescriptor(fd);
119       close(fd);
120     } else {
121       ASSERT_GE(n, 0) << strerror(errno);
122       n = write(fd, "1", 1);
123     }
124   }
125 
SetUp()126   void SetUp() override {
127     memset(server_buffer_, 0, kBufferSize);
128     memset(client_buffer_, 0, kBufferSize);
129     socket_fd_ = -1;
130     connection_fd_ = -1;
131 
132     socket_fd_ = StartServer();
133 
134     async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
135       connection_fd_ = AcceptConnection(fd);
136 
137       async_manager_.WatchFdForNonBlockingReads(
138           connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
139     });
140   }
141 
TearDown()142   void TearDown() override {
143     async_manager_.StopWatchingFileDescriptor(socket_fd_);
144     close(socket_fd_);
145     close(connection_fd_);
146     ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
147               std::string_view(client_buffer_, kBufferSize));
148   }
149 
ConnectClient()150   int ConnectClient() {
151     int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
152     EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
153 
154     struct hostent* server;
155     server = gethostbyname("localhost");
156     EXPECT_FALSE(server == NULL) << strerror(errno);
157 
158     struct sockaddr_in serv_addr;
159     memset((void*)&serv_addr, 0, sizeof(serv_addr));
160     serv_addr.sin_family = AF_INET;
161     serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
162     serv_addr.sin_port = htons(kPort);
163 
164     int result =
165         connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
166     EXPECT_GE(result, 0) << strerror(errno);
167 
168     return socket_cli_fd;
169   }
170 
WriteFromClient(int socket_cli_fd)171   void WriteFromClient(int socket_cli_fd) {
172     strcpy(client_buffer_, "1");
173     int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
174     ASSERT_GT(n, 0) << strerror(errno);
175   }
176 
AwaitServerResponse(int socket_cli_fd)177   void AwaitServerResponse(int socket_cli_fd) {
178     int n = read(socket_cli_fd, client_buffer_, 1);
179     ASSERT_GT(n, 0) << strerror(errno);
180   }
181 
182  protected:
183   AsyncManager async_manager_;
184   int socket_fd_;
185   int connection_fd_;
186   char server_buffer_[kBufferSize];
187   char client_buffer_[kBufferSize];
188 };
189 
// A single client connects, sends one message, receives the server's reply,
// and disconnects.
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
199 
// A read callback may call StopWatchingFileDescriptor() on its own fd without
// deadlocking or crashing the AsyncManager.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  // 32 bytes of 'x'. std::string's fill constructor is (count, ch); the old
  // ('x', 32) order produced 120 spaces.
  std::string data(32, 'x');

  // Written on the AsyncManager thread and polled on this thread, so it must
  // be atomic.
  std::atomic<bool> stopped{false};
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    // Drain whatever is pending; the fd is non-blocking.
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  while (!stopped.load()) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  SUCCEED();
  close(socket_cli_fd);
}
227 
// A periodic task may cancel itself from inside its own callback: the first
// cancel succeeds and an immediate second cancel reports nothing to cancel.
// NOTE(review): the test assumes the scheduled task's id is 1 — confirm
// against AsyncManager's id assignment.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event task_ran;

  auto cancel_self = [this, &task_ran]() {
    EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
        << "We were scheduled, so cancel should return true";
    EXPECT_FALSE(async_manager_.CancelAsyncTask(1))
        << "We were not scheduled, so cancel should return false";
    task_ran.set(true);
  };
  async_manager_.ExecAsyncPeriodically(1, 1ms, 2ms, cancel_self);

  const bool observed = task_ran.wait_for(100ms);
  EXPECT_TRUE(observed);
}
241 
// CancelAsyncTask() must not return while the task's callback is still
// executing: the callback sleeps ~20 ms and checks that cancellation has not
// completed yet.
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  std::atomic<bool> cancel_done{false};
  std::atomic<bool> task_complete{false};
  AsyncTaskId task_id = async_manager_.ExecAsyncPeriodically(
      1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
        // Let the other thread know we are in the callback.
        running.set(true);
        // Wee bit of a hack that relies on timing..
        std::this_thread::sleep_for(20ms);
        EXPECT_FALSE(cancel_done.load())
            << "Task cancellation did not wait for us to complete!";
        task_complete.store(true);
      });

  EXPECT_TRUE(running.wait_for(100ms));
  // Elapsed time must come from a monotonic clock; system_clock can jump.
  auto start = std::chrono::steady_clock::now();

  // There is a 20ms wait.. so we know that this should take some time.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(task_id))
      << "We were scheduled, so cancel should return true";
  cancel_done.store(true);
  EXPECT_TRUE(task_complete.load())
      << "We managed to cancel a task while it was not yet finished.";
  auto end = std::chrono::steady_clock::now();
  auto passed_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
272 
TEST_F(AsyncManagerSocketTest,NoEventsAfterUnsubscribe)273 TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
274   // This tests makes sure the AsyncManager never fires an event
275   // after calling StopWatchingFileDescriptor.
276   using clock = std::chrono::system_clock;
277   using namespace std::chrono_literals;
278 
279   clock::time_point time_fast_called;
280   clock::time_point time_slow_called;
281   clock::time_point time_stopped_listening;
282 
283   int round = 0;
284   auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
285   fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);
286 
287   auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
288   fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);
289 
290   std::string data(1, 'x');
291 
292   // The idea here is as follows:
293   // We want to make sure that an unsubscribed callback never gets called.
294   // This is to make sure we can safely do things like this:
295   //
296   // class Foo {
297   //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
298   //     am_->WatchFdForNonBlockingReads(
299   //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
300   //   }
301   //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
302   //
303   //   AsyncManager* am_;
304   //   int fd_;
305   // };
306   //
307   // We are going to force a failure as follows:
308   //
309   // The slow callback needs to be called first, if it does not we cannot
310   // force failure, so we have to try multiple times.
311   //
312   // t1, is the thread doing the loop.
313   // t2, is the async manager handler thread.
314   //
315   // t1 will block until the slowcallback.
316   // t2 will now block (for at most 250 ms).
317   // t1 will unsubscribe the fast callback.
318   // 2 cases:
319   //   with bug:
320   //      - t1 takes a timestamp, unblocks t2,
321   //      - t2 invokes the fast callback, and gets a timestamp.
322   //      - Now the unsubscribe time is before the callback time.
323   //   without bug.:
324   //      - t1 locks un unsusbcribe in asyn manager
325   //      - t2 unlocks due to timeout,
326   //      - t2 invokes the fast callback, and gets a timestamp.
327   //      - t1 is unlocked and gets a timestamp.
328   //      - Now the unsubscribe time is after the callback time..
329 
330   do {
331     Event unblock_slow, inslow, infast;
332     time_fast_called = {};
333     time_slow_called = {};
334     time_stopped_listening = {};
335     printf("round: %d\n", round++);
336 
337     // Register fd events
338     async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
339       if (*inslow) return;
340       time_slow_called = clock::now();
341       printf("slow: %lld\n",
342              time_slow_called.time_since_epoch().count() % 10000);
343       inslow.set();
344       unblock_slow.wait_for(25ms);
345     });
346 
347     async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
348       if (*infast) return;
349       time_fast_called = clock::now();
350       printf("fast: %lld\n",
351              time_fast_called.time_since_epoch().count() % 10000);
352       infast.set();
353     });
354 
355     // Generate fd events
356     write(fast_cli_fd, data.data(), data.size());
357     write(slow_cli_fd, data.data(), data.size());
358 
359     // Block in the right places.
360     if (inslow.wait_for(25ms)) {
361       async_manager_.StopWatchingFileDescriptor(fast_s_fd);
362       time_stopped_listening = clock::now();
363       printf("stop: %lld\n",
364              time_stopped_listening.time_since_epoch().count() % 10000);
365       unblock_slow.set();
366     }
367 
368     infast.wait_for(25ms);
369 
370     // Unregister.
371     async_manager_.StopWatchingFileDescriptor(fast_s_fd);
372     async_manager_.StopWatchingFileDescriptor(slow_s_fd);
373   } while (time_fast_called < time_slow_called);
374 
375   // fast before stop listening.
376   ASSERT_LT(time_fast_called.time_since_epoch().count(),
377             time_stopped_listening.time_since_epoch().count());
378 
379   // Cleanup
380   close(fast_cli_fd);
381   close(fast_s_fd);
382   close(slow_cli_fd);
383   close(slow_s_fd);
384 }
385 
// Opens, uses, and closes 30 client connections one after another; each one
// must get a reply from the server.
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  static const int kConnectionCount = 30;
  int remaining = kConnectionCount;
  while (remaining-- > 0) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
395 
// Opens 30 client connections at once and sends on each, then collects every
// reply. The server must handle many simultaneously open connections.
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  static const int num_connections = 30;
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    // 0 is a valid descriptor; only negative values indicate failure.
    ASSERT_GE(socket_cli_fd[i], 0) << "connection " << i;
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {
    AwaitServerResponse(socket_cli_fd[i]);
    close(socket_cli_fd[i]);
  }
}
409 
410 class AsyncManagerTest : public ::testing::Test {
411  public:
412   AsyncManager async_manager_;
413 };
414 
// Exercises only construction and destruction of the fixture's AsyncManager.
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty.
}
416 
// A task cancelled before its 2 ms delay elapses reports a successful cancel
// and never runs.
TEST_F(AsyncManagerTest, TestCancelTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // Written on the AsyncManager thread if the task ever runs, so it must be
  // atomic to avoid a data race with the read below.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2),
      [&task1_ran]() { task1_ran.store(true); });
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task1_ran.load());
}
427 
// After a short task has run, cancelling it reports false; a 2 s task still
// pending at that point can be cancelled and reports true.
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // The flags are set on the AsyncManager thread and polled here. With a plain
  // bool the spin-wait below is a data race the compiler may turn into an
  // infinite loop, so they must be atomic.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2),
      [&task1_ran]() { task1_ran.store(true); });
  std::atomic<bool> task2_ran{false};
  AsyncTaskId task2_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2),
      [&task2_ran]() { task2_ran.store(true); });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  while (!task1_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
448 
// CancelAsyncTasksFromUser(user1) removes user1's still-pending tasks, so no
// individual cancel succeeds afterwards. user1's 2 ms tasks and user2's task
// have already run by then; user1's 2 s tasks never run.
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // Set on the AsyncManager thread and polled here; atomics make the
  // spin-wait below well-defined (a plain bool load may be hoisted).
  std::atomic<bool> task1_ran{false};
  std::atomic<bool> task2_ran{false};
  std::atomic<bool> task3_ran{false};
  std::atomic<bool> task4_ran{false};
  std::atomic<bool> task5_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2),
      [&task1_ran]() { task1_ran.store(true); });
  AsyncTaskId task2_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2),
      [&task2_ran]() { task2_ran.store(true); });
  AsyncTaskId task3_id = async_manager_.ExecAsync(
      user1, std::chrono::milliseconds(2),
      [&task3_ran]() { task3_ran.store(true); });
  AsyncTaskId task4_id = async_manager_.ExecAsync(
      user1, std::chrono::seconds(2),
      [&task4_ran]() { task4_ran.store(true); });
  AsyncTaskId task5_id = async_manager_.ExecAsync(
      user2, std::chrono::milliseconds(2),
      [&task5_ran]() { task5_ran.store(true); });
  ASSERT_FALSE(task1_ran.load());
  while (!task1_ran.load() || !task3_ran.load() || !task5_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  async_manager_.CancelAsyncTasksFromUser(user1);
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
492 
493 }  // namespace rootcanal
494