• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "model/setup/async_manager.h"

#include <fcntl.h>        // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h>  // for Message, TestPartResult, SuiteApi...
#include <netdb.h>        // for gethostbyname, h_addr, hostent
#include <netinet/in.h>   // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>        // for printf
#include <sys/socket.h>   // for socket, AF_INET, accept, bind
#include <sys/types.h>    // for in_addr_t
#include <time.h>         // for NULL, size_t
#include <unistd.h>       // for close, write, read

#include <atomic>              // for atomic
#include <cerrno>              // for errno, EAGAIN, EINTR
#include <chrono>              // for milliseconds, steady_clock, system_clock
#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, strcmp, strcpy, strlen
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <string_view>         // for string_view
#include <thread>
#include <tuple>  // for tuple
37 
38 namespace rootcanal {
39 
// Small thread-safe boolean flag. One thread calls set(); another blocks in
// wait_for() until the flag becomes true or a timeout elapses.
class Event {
 public:
  // Sets the flag to `set` (true by default) and wakes every waiter.
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  // Clears the flag.
  void reset() { set(false); }

  // Blocks until the flag is true or `timeout` elapses. Returns the flag value
  // seen on return, so true means the flag was set in time.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  // Returns the current flag value. The lock is taken so this read cannot
  // race with set() running on a different thread.
  bool operator*() {
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool set_{false};
};
62 
63 class AsyncManagerSocketTest : public ::testing::Test {
64  public:
65   static const uint16_t kPort = 6111;
66   static const size_t kBufferSize = 16;
67 
CheckBufferEquals()68   bool CheckBufferEquals() {
69     return strcmp(server_buffer_, client_buffer_) == 0;
70   }
71 
72  protected:
StartServer()73   int StartServer() {
74     struct sockaddr_in serv_addr = {};
75     int fd = socket(AF_INET, SOCK_STREAM, 0);
76     EXPECT_FALSE(fd < 0);
77 
78     serv_addr.sin_family = AF_INET;
79     serv_addr.sin_addr.s_addr = INADDR_ANY;
80     serv_addr.sin_port = htons(kPort);
81     int reuse_flag = 1;
82     EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
83                             sizeof(reuse_flag)) < 0);
84     EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
85 
86     listen(fd, 1);
87     return fd;
88   }
89 
AcceptConnection(int fd)90   int AcceptConnection(int fd) {
91     struct sockaddr_in cli_addr;
92     memset(&cli_addr, 0, sizeof(cli_addr));
93     socklen_t clilen = sizeof(cli_addr);
94 
95     int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
96     EXPECT_FALSE(connection_fd < 0);
97 
98     return connection_fd;
99   }
100 
ConnectSocketPair()101   std::tuple<int, int> ConnectSocketPair() {
102     int cli = ConnectClient();
103     WriteFromClient(cli);
104     AwaitServerResponse(cli);
105     int ser = connection_fd_;
106     connection_fd_ = -1;
107     return {cli, ser};
108   }
109 
ReadIncomingMessage(int fd)110   void ReadIncomingMessage(int fd) {
111     int n;
112     do {
113       n = read(fd, server_buffer_, kBufferSize - 1);
114     } while (n == -1 && errno == EAGAIN);
115     ASSERT_GE(n, 0) << strerror(errno);
116 
117     if (n == 0) {  // got EOF
118       async_manager_.StopWatchingFileDescriptor(fd);
119       close(fd);
120     } else {
121       n = write(fd, "1", 1);
122     }
123   }
124 
SetUp()125   void SetUp() override {
126     memset(server_buffer_, 0, kBufferSize);
127     memset(client_buffer_, 0, kBufferSize);
128     socket_fd_ = -1;
129     connection_fd_ = -1;
130 
131     socket_fd_ = StartServer();
132 
133     async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
134       connection_fd_ = AcceptConnection(fd);
135 
136       async_manager_.WatchFdForNonBlockingReads(
137           connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
138     });
139   }
140 
TearDown()141   void TearDown() override {
142     async_manager_.StopWatchingFileDescriptor(socket_fd_);
143     close(socket_fd_);
144     close(connection_fd_);
145     ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
146               std::string_view(client_buffer_, kBufferSize));
147   }
148 
ConnectClient()149   int ConnectClient() {
150     int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
151     EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
152 
153     struct hostent* server;
154     server = gethostbyname("localhost");
155     EXPECT_FALSE(server == NULL) << strerror(errno);
156 
157     struct sockaddr_in serv_addr;
158     memset((void*)&serv_addr, 0, sizeof(serv_addr));
159     serv_addr.sin_family = AF_INET;
160     serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
161     serv_addr.sin_port = htons(kPort);
162 
163     int result =
164         connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
165     EXPECT_GE(result, 0) << strerror(errno);
166 
167     return socket_cli_fd;
168   }
169 
WriteFromClient(int socket_cli_fd)170   void WriteFromClient(int socket_cli_fd) {
171     strcpy(client_buffer_, "1");
172     int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
173     ASSERT_GT(n, 0) << strerror(errno);
174   }
175 
AwaitServerResponse(int socket_cli_fd)176   void AwaitServerResponse(int socket_cli_fd) {
177     int n = read(socket_cli_fd, client_buffer_, 1);
178     ASSERT_GT(n, 0) << strerror(errno);
179   }
180 
181  protected:
182   AsyncManager async_manager_;
183   int socket_fd_;
184   int connection_fd_;
185   char server_buffer_[kBufferSize];
186   char client_buffer_[kBufferSize];
187 };
188 
// One client connects, sends a single message and gets the server's
// one-byte reply back.
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
198 
// A read callback may unregister its own descriptor from inside the callback.
// The main thread keeps feeding data until that callback reports it has run.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  // 32 copies of 'x'. std::string's fill constructor takes (count, ch).
  std::string data(32, 'x');

  // Set on the AsyncManager thread and polled here, so it must be atomic.
  std::atomic<bool> stopped{false};
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  while (!stopped) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  SUCCEED();
  close(socket_cli_fd);
}
226 
// A periodic task cancels itself from inside its own callback. The first
// cancel must succeed, and a second one must report nothing is scheduled.
// NOTE(review): assumes this is the first task scheduled on the fixture's
// manager, so its AsyncTaskId is 1. Confirm against AsyncManager's id
// allocation.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event ran;
  async_manager_.ExecAsyncPeriodically(1, 1ms, 2ms, [this, &ran]() {
    const bool first_cancel = async_manager_.CancelAsyncTask(1);
    EXPECT_TRUE(first_cancel)
        << "We were scheduled, so cancel should return true";
    const bool second_cancel = async_manager_.CancelAsyncTask(1);
    EXPECT_FALSE(second_cancel)
        << "We were not scheduled, so cancel should return false";
    ran.set(true);
  });

  const bool ran_in_time = ran.wait_for(10ms);
  EXPECT_TRUE(ran_in_time);
}
240 
// CancelAsyncTask() called while the task is running must not return until
// that run has finished.
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  // Each flag is written on one thread and read on the other, so both must be
  // atomic to avoid a data race.
  std::atomic<bool> cancel_done{false};
  std::atomic<bool> task_complete{false};
  async_manager_.ExecAsyncPeriodically(
      1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
        // Let the other thread know we are in the callback.
        running.set(true);
        // Wee bit of a hack that relies on timing: stay in the callback long
        // enough for the main thread to reach CancelAsyncTask().
        std::this_thread::sleep_for(20ms);
        EXPECT_FALSE(cancel_done.load())
            << "Task cancellation did not wait for us to complete!";
        task_complete = true;
      });

  EXPECT_TRUE(running.wait_for(10ms));
  // steady_clock is monotonic; system_clock can jump and skew the measurement.
  auto start = std::chrono::steady_clock::now();

  // There is a 20ms wait, so we know that this should take some time.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
      << "We were scheduled, so cancel should return true";
  cancel_done = true;
  EXPECT_TRUE(task_complete.load())
      << "We managed to cancel a task while it was not yet finished.";
  auto end = std::chrono::steady_clock::now();
  auto passed_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
271 
TEST_F(AsyncManagerSocketTest,NoEventsAfterUnsubscribe)272 TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
273   // This tests makes sure the AsyncManager never fires an event
274   // after calling StopWatchingFileDescriptor.
275   using clock = std::chrono::system_clock;
276   using namespace std::chrono_literals;
277 
278   clock::time_point time_fast_called;
279   clock::time_point time_slow_called;
280   clock::time_point time_stopped_listening;
281 
282   int round = 0;
283   auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
284   fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);
285 
286   auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
287   fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);
288 
289   std::string data(1, 'x');
290 
291   // The idea here is as follows:
292   // We want to make sure that an unsubscribed callback never gets called.
293   // This is to make sure we can safely do things like this:
294   //
295   // class Foo {
296   //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
297   //     am_->WatchFdForNonBlockingReads(
298   //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
299   //   }
300   //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
301   //
302   //   AsyncManager* am_;
303   //   int fd_;
304   // };
305   //
306   // We are going to force a failure as follows:
307   //
308   // The slow callback needs to be called first, if it does not we cannot
309   // force failure, so we have to try multiple times.
310   //
311   // t1, is the thread doing the loop.
312   // t2, is the async manager handler thread.
313   //
314   // t1 will block until the slowcallback.
315   // t2 will now block (for at most 250 ms).
316   // t1 will unsubscribe the fast callback.
317   // 2 cases:
318   //   with bug:
319   //      - t1 takes a timestamp, unblocks t2,
320   //      - t2 invokes the fast callback, and gets a timestamp.
321   //      - Now the unsubscribe time is before the callback time.
322   //   without bug.:
323   //      - t1 locks un unsusbcribe in asyn manager
324   //      - t2 unlocks due to timeout,
325   //      - t2 invokes the fast callback, and gets a timestamp.
326   //      - t1 is unlocked and gets a timestamp.
327   //      - Now the unsubscribe time is after the callback time..
328 
329   do {
330     Event unblock_slow, inslow, infast;
331     time_fast_called = {};
332     time_slow_called = {};
333     time_stopped_listening = {};
334     printf("round: %d\n", round++);
335 
336     // Register fd events
337     async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
338       if (*inslow) return;
339       time_slow_called = clock::now();
340       printf("slow: %lld\n",
341              time_slow_called.time_since_epoch().count() % 10000);
342       inslow.set();
343       unblock_slow.wait_for(25ms);
344     });
345 
346     async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
347       if (*infast) return;
348       time_fast_called = clock::now();
349       printf("fast: %lld\n",
350              time_fast_called.time_since_epoch().count() % 10000);
351       infast.set();
352     });
353 
354     // Generate fd events
355     write(fast_cli_fd, data.data(), data.size());
356     write(slow_cli_fd, data.data(), data.size());
357 
358     // Block in the right places.
359     if (inslow.wait_for(25ms)) {
360       async_manager_.StopWatchingFileDescriptor(fast_s_fd);
361       time_stopped_listening = clock::now();
362       printf("stop: %lld\n",
363              time_stopped_listening.time_since_epoch().count() % 10000);
364       unblock_slow.set();
365     }
366 
367     infast.wait_for(25ms);
368 
369     // Unregister.
370     async_manager_.StopWatchingFileDescriptor(fast_s_fd);
371     async_manager_.StopWatchingFileDescriptor(slow_s_fd);
372   } while (time_fast_called < time_slow_called);
373 
374   // fast before stop listening.
375   ASSERT_LT(time_fast_called.time_since_epoch().count(),
376             time_stopped_listening.time_since_epoch().count());
377 
378   // Cleanup
379   close(fast_cli_fd);
380   close(fast_s_fd);
381   close(slow_cli_fd);
382   close(slow_s_fd);
383 }
384 
// Thirty clients run one after another. Each connects, completes one round
// trip with the server and disconnects before the next one starts.
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  constexpr int kNumConnections = 30;
  for (int attempt = 0; attempt != kNumConnections; ++attempt) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
394 
// Thirty clients connect and send before any of them reads its reply, so the
// server must service many open connections at once.
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  static const int num_connections = 30;
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    // 0 is a valid descriptor; only negative values signal failure. This
    // matches the check in ConnectClient().
    ASSERT_TRUE(socket_cli_fd[i] >= 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {
    AwaitServerResponse(socket_cli_fd[i]);
    close(socket_cli_fd[i]);
  }
}
408 
409 class AsyncManagerTest : public ::testing::Test {
410  public:
411   AsyncManager async_manager_;
412 };
413 
// Constructing and destroying the fixture (and its AsyncManager) must neither
// hang nor crash.
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty: fixture construction/destruction is the whole test.
}
415 
// Cancelling a task before its 2 ms delay elapses succeeds, and the task has
// not run at that point.
TEST_F(AsyncManagerTest, TestCancelTask) {
  const AsyncUserId user = async_manager_.GetNextUserId();
  bool ran = false;
  bool* const ran_flag = &ran;
  const AsyncTaskId task =
      async_manager_.ExecAsync(user, std::chrono::milliseconds(2),
                               [ran_flag]() { *ran_flag = true; });
  const bool cancelled = async_manager_.CancelAsyncTask(task);
  ASSERT_TRUE(cancelled);
  ASSERT_FALSE(ran);
}
426 
// A finished task can no longer be cancelled, while a task still waiting out
// its 2 s delay can be.
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // The flags are written on the AsyncManager thread and polled here. With a
  // plain bool the spin-wait below is a data race (UB), and the compiler may
  // hoist the load out of the loop, leaving it spinning forever.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool>* task1_ran_ptr = &task1_ran;
  AsyncTaskId task1_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task1_ran_ptr]() { *task1_ran_ptr = true; });
  std::atomic<bool> task2_ran{false};
  std::atomic<bool>* task2_ran_ptr = &task2_ran;
  AsyncTaskId task2_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task2_ran_ptr]() { *task2_ran_ptr = true; });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  while (!task1_ran) {
    std::this_thread::yield();
  }
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
447 
// CancelAsyncTasksFromUser(user1) removes user1's pending tasks. Afterwards
// every task is either finished or cancelled, so no individual cancel
// succeeds.
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // The flags are written on the AsyncManager thread and polled here. They
  // must be atomic, or the spin-wait below is a data race (UB) that may never
  // terminate.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool>* task1_ran_ptr = &task1_ran;
  std::atomic<bool> task2_ran{false};
  std::atomic<bool>* task2_ran_ptr = &task2_ran;
  std::atomic<bool> task3_ran{false};
  std::atomic<bool>* task3_ran_ptr = &task3_ran;
  std::atomic<bool> task4_ran{false};
  std::atomic<bool>* task4_ran_ptr = &task4_ran;
  std::atomic<bool> task5_ran{false};
  std::atomic<bool>* task5_ran_ptr = &task5_ran;
  AsyncTaskId task1_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task1_ran_ptr]() { *task1_ran_ptr = true; });
  AsyncTaskId task2_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task2_ran_ptr]() { *task2_ran_ptr = true; });
  AsyncTaskId task3_id =
      async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                               [task3_ran_ptr]() { *task3_ran_ptr = true; });
  AsyncTaskId task4_id =
      async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                               [task4_ran_ptr]() { *task4_ran_ptr = true; });
  AsyncTaskId task5_id =
      async_manager_.ExecAsync(user2, std::chrono::milliseconds(2),
                               [task5_ran_ptr]() { *task5_ran_ptr = true; });
  ASSERT_FALSE(task1_ran.load());
  // Wait for the three 2 ms tasks; the two 2 s tasks should still be pending.
  while (!task1_ran || !task3_ran || !task5_ran) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  async_manager_.CancelAsyncTasksFromUser(user1);
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
491 
492 }  // namespace rootcanal
493