• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2024 The Android Open Source Project
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at:
7  *
8  *  http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License.
15  */
16 
#include <bluetooth/log.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <unistd.h>

#include <atomic>
#include <chrono>
#include <cstdlib>
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <thread>

#include "common/strings.h"
#include "gd/module_jniloop.h"
#include "gd/module_mainloop.h"
#include "main/shim/stack.h"
#include "module.h"
#include "os/thread.h"
#include "stack/include/main_thread.h"
#include "test/mock/mock_main_shim_entry.h"

using ::testing::_;

using namespace bluetooth;
using namespace testing;
40 
41 namespace {
// Timeouts, in milliseconds, for waiting on posted work.
constexpr int kSyncMainLoopTimeoutMs{3000};
constexpr int kWaitUntilHandlerStoppedMs{2000};

// Shape of the test: number of clients and number of default stack modules.
constexpr size_t kNumTestClients{3};
constexpr size_t kNumTestModules{3};

// Posting rounds per client, and the round at which the shutdown test stops
// the stack (three quarters of the way through).
constexpr int kNumIters{100};
constexpr int kAbruptStackShutdownIter{kNumIters * 3 / 4};

constexpr char kTestStackThreadName[] = "test_stack_thread";
constexpr char kTestDataTag[] = "This is a test";
50 
// Yields the calling thread's time slice on roughly half of the calls (based
// on the low bit of std::rand()) to vary thread interleavings between runs.
inline void maybe_yield() {
  const bool coin_flip = (std::rand() & 1) != 0;
  if (!coin_flip) {
    return;
  }
  std::this_thread::yield();
}
54 
55 constexpr size_t kTagLength = 48 + sizeof(' ') + sizeof(' ');
log_tag(std::string tag)56 inline void log_tag(std::string tag) {
57   std::string prepend(kTagLength / 2 - tag.size() / 2, '=');
58   std::string append(kTagLength / 2 - tag.size() / 2, '=');
59   log::info("{} {} {}", prepend, tag, append);
60 }
61 
62 class MainThread {
63  public:
MainThread()64   MainThread() { main_thread_start_up(); }
65 
~MainThread()66   ~MainThread() {
67     sync_main_handler();
68     main_thread_shut_down();
69   }
70 
71  private:
sync_main_handler()72   void sync_main_handler() {
73     std::promise promise = std::promise<void>();
74     std::future future = promise.get_future();
75     post_on_bt_main([&promise]() { promise.set_value(); });
76     future.wait_for(std::chrono::milliseconds(kSyncMainLoopTimeoutMs));
77   }
78 };
79 
80 class TestStackManager {
81  public:
TestStackManager()82   TestStackManager() {
83     // Start is executed by the test after each test adds the default
84     // or their own modules
85   }
86 
~TestStackManager()87   ~TestStackManager() {
88     log::debug("Deleting stack manager");
89     Stop();
90   }
91 
92   TestStackManager(const TestStackManager&) = delete;
93 
94   template <typename T>
AddModule()95   void AddModule() {
96     modules_.add<T>();
97   }
98 
Start()99   void Start() {
100     if (stack_started_) return;
101     log::info("Started stack manager");
102     stack_started_ = true;
103     bluetooth::os::Thread* stack_thread = new bluetooth::os::Thread(
104         kTestStackThreadName, bluetooth::os::Thread::Priority::NORMAL);
105     bluetooth::shim::Stack::GetInstance()->StartModuleStack(&modules_,
106                                                             stack_thread);
107   }
108 
Stop()109   void Stop() {
110     if (!stack_started_) return;
111     stack_started_ = false;
112     bluetooth::shim::Stack::GetInstance()->Stop();
113   }
114 
115   // NOTE: Stack manager *must* be active else method returns nullptr
116   // if stack manager has not started or shutdown
117   template <typename T>
GetUnsafeModule()118   static T* GetUnsafeModule() {
119     return bluetooth::shim::Stack::GetInstance()
120         ->GetStackManager()
121         ->GetInstance<T>();
122   }
123 
NumModules() const124   size_t NumModules() const { return modules_.NumModules(); }
125 
126  private:
127   bluetooth::ModuleList modules_;
128   bool stack_started_{false};
129 };
130 
// Data returned via callback from a stack managed module.
struct TestCallbackData {
  int iter{0};      // Iteration number echoed back from the request.
  std::string tag;  // Tag filled in by the module producing the callback.
};

// Data sent to a stack managed module via a module API.
struct TestData {
  int iter{0};  // Iteration number of the request.
  std::string tag;
  // Invoked with the module's reply.
  std::function<void(TestCallbackData callback_data)> callback;
};
143 
144 class TestStackModuleBase : public bluetooth::Module,
145                             public ModuleMainloop,
146                             public ModuleJniloop {
147  public:
148   TestStackModuleBase(const TestStackModuleBase&) = delete;
149   TestStackModuleBase& operator=(const TestStackModuleBase&) = delete;
150 
~TestStackModuleBase()151   virtual ~TestStackModuleBase(){};
152   static const ModuleFactory Factory;
153 
TestMethod(TestData test_data) const154   virtual void TestMethod(TestData test_data) const {
155     log::info("Test base class iter:{} tag:{}", test_data.iter, test_data.tag);
156   }
157 
158  protected:
ListDependencies(ModuleList * list) const159   void ListDependencies(ModuleList* list) const override{};
Start()160   void Start() override { log::error("Started TestStackModuleBase"); };
Stop()161   void Stop() override { log::error("Stopped TestStackModuleBase"); };
ToString() const162   std::string ToString() const override { return std::string("TestFunction"); }
163 
164   TestStackModuleBase() = default;
165 };
166 
167 class TestStackModule1 : public TestStackModuleBase {
168  public:
169   TestStackModule1(const TestStackModule1&) = delete;
170   TestStackModule1& operator=(const TestStackModule1&) = delete;
171   virtual ~TestStackModule1() = default;
172 
173   static const ModuleFactory Factory;
174 
175   void TestMethod(TestData test_data) const override;
176 
177  private:
178   struct impl;
179   std::shared_ptr<impl> impl_;
180   TestStackModule1();
181 };
182 
183 struct TestStackModule1::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule1::impl184   void test(TestData test_data) {
185     TestCallbackData callback_data{
186         .iter = test_data.iter,
187         .tag = std::string(__func__),
188     };
189     PostFunctionOnMain(
190         [](std::function<void(TestCallbackData callback_data)> callback,
191            TestCallbackData data) { callback(data); },
192         test_data.callback, callback_data);
193   }
194 };
195 
TestStackModule1()196 TestStackModule1::TestStackModule1() : TestStackModuleBase() {
197   impl_ = std::make_shared<impl>();
198 }
199 
TestMethod(TestData test_data) const200 void TestStackModule1::TestMethod(TestData test_data) const {
201   PostMethodOnMain(impl_, &impl::test, test_data);
202 }
203 
204 class TestStackModule2 : public TestStackModuleBase {
205  public:
206   TestStackModule2(const TestStackModule2&) = delete;
207   TestStackModule2& operator=(const TestStackModule2&) = delete;
208   virtual ~TestStackModule2() = default;
209 
210   static const ModuleFactory Factory;
211 
212   void TestMethod(TestData test_data) const override;
213 
214  private:
215   struct impl;
216   std::shared_ptr<impl> impl_;
217   TestStackModule2();
218 };
219 
220 struct TestStackModule2::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule2::impl221   void test(TestData test_data) {
222     TestCallbackData callback_data{
223         .iter = test_data.iter,
224         .tag = std::string(__func__),
225     };
226     PostFunctionOnMain(
227         [](std::function<void(TestCallbackData callback_data)> callback,
228            TestCallbackData data) { callback(data); },
229         test_data.callback, callback_data);
230   }
231 };
232 
TestStackModule2()233 TestStackModule2::TestStackModule2() : TestStackModuleBase() {
234   impl_ = std::make_shared<impl>();
235 }
236 
TestMethod(TestData test_data) const237 void TestStackModule2::TestMethod(TestData test_data) const {
238   PostMethodOnMain(impl_, &impl::test, test_data);
239 }
240 
241 class TestStackModule3 : public TestStackModuleBase {
242  public:
243   TestStackModule3(const TestStackModule3&) = delete;
244   TestStackModule3& operator=(const TestStackModule3&) = delete;
245   virtual ~TestStackModule3() = default;
246 
247   static const ModuleFactory Factory;
248 
249   void TestMethod(TestData test_data) const override;
250 
251  private:
252   struct impl;
253   std::shared_ptr<impl> impl_;
254   TestStackModule3();
255 };
256 
257 struct TestStackModule3::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule3::impl258   void test(TestData test_data) {
259     TestCallbackData callback_data{
260         .iter = test_data.iter,
261         .tag = std::string(__func__),
262     };
263     PostFunctionOnMain(
264         [](std::function<void(TestCallbackData callback_data)> callback,
265            TestCallbackData data) { callback(data); },
266         test_data.callback, callback_data);
267   }
268 };
269 
TestStackModule3()270 TestStackModule3::TestStackModule3() : TestStackModuleBase() {
271   impl_ = std::make_shared<impl>();
272 }
273 
TestMethod(TestData test_data) const274 void TestStackModule3::TestMethod(TestData test_data) const {
275   PostMethodOnMain(impl_, &impl::test, test_data);
276 }
277 
278 class TestStackModule4 : public TestStackModuleBase {
279  public:
280   TestStackModule4(const TestStackModule4&) = delete;
281   TestStackModule4& operator=(const TestStackModule3&) = delete;
282   virtual ~TestStackModule4() = default;
283 
284   static const ModuleFactory Factory;
285 
TestMethod(TestData test_data) const286   void TestMethod(TestData test_data) const override {
287     log::info("mod:{} iter:{} tag:{}", __func__, test_data.iter, test_data.tag);
288   }
289 
290  private:
291   struct impl;
292   std::shared_ptr<impl> impl_;
TestStackModule4()293   TestStackModule4() : TestStackModuleBase() {}
294 };
295 
296 struct TestStackModule4::impl : public ModuleMainloop, public ModuleJniloop {};
297 
298 }  // namespace
299 
300 const ModuleFactory TestStackModuleBase::Factory =
__anon7a4e79f10602() 301     ModuleFactory([]() { return new TestStackModuleBase(); });
302 
303 const ModuleFactory TestStackModule1::Factory =
__anon7a4e79f10702() 304     ModuleFactory([]() { return new TestStackModule1(); });
305 const ModuleFactory TestStackModule2::Factory =
__anon7a4e79f10802() 306     ModuleFactory([]() { return new TestStackModule2(); });
307 const ModuleFactory TestStackModule3::Factory =
__anon7a4e79f10902() 308     ModuleFactory([]() { return new TestStackModule3(); });
309 const ModuleFactory TestStackModule4::Factory =
__anon7a4e79f10a02() 310     ModuleFactory([]() { return new TestStackModule4(); });
311 
312 class StackWithMainThreadUnitTest : public ::testing::Test {
313  protected:
SetUp()314   void SetUp() override { main_thread_ = std::make_unique<MainThread>(); }
TearDown()315   void TearDown() override { main_thread_.reset(); }
316 
317  private:
318   std::unique_ptr<MainThread> main_thread_;
319 };
320 
321 class StackLifecycleUnitTest : public StackWithMainThreadUnitTest {
322  public:
StackManager() const323   std::shared_ptr<TestStackManager> StackManager() const {
324     return stack_manager_;
325   }
326 
327  protected:
SetUp()328   void SetUp() override {
329     StackWithMainThreadUnitTest::SetUp();
330     stack_manager_ = std::make_shared<TestStackManager>();
331   }
332 
TearDown()333   void TearDown() override {
334     stack_manager_.reset();
335     StackWithMainThreadUnitTest::TearDown();
336   }
337 
338  private:
339   std::shared_ptr<TestStackManager> stack_manager_;
340 };
341 
TEST_F(StackLifecycleUnitTest, no_modules_in_stack) {
  // A freshly created stack manager has no registered modules.
  const size_t module_count = StackManager()->NumModules();
  ASSERT_EQ(0U, module_count);
}
345 
346 class StackLifecycleWithDefaultModulesUnitTest : public StackLifecycleUnitTest {
347  protected:
SetUp()348   void SetUp() override {
349     StackLifecycleUnitTest::SetUp();
350     StackManager()->AddModule<TestStackModule1>();
351     StackManager()->AddModule<TestStackModule2>();
352     StackManager()->AddModule<TestStackModule3>();
353     StackManager()->Start();
354     ASSERT_EQ(3U, StackManager()->NumModules());
355   }
356 
TearDown()357   void TearDown() override { StackLifecycleUnitTest::TearDown(); }
358 };
359 
// Tally of Client::Post() calls: |success| counts closures handed to the
// client's handler; |misses| counts closures dropped after quiescing.
struct CallablePostCnt {
  size_t success{0};
  size_t misses{0};

  // Adds |post_cnt| into this tally. Returns *this by reference, per the
  // compound-assignment convention. The previous version returned a fresh
  // temporary, so chained updates such as (a += b) += c were lost.
  CallablePostCnt& operator+=(const CallablePostCnt& post_cnt) {
    success += post_cnt.success;
    misses += post_cnt.misses;
    return *this;
  }
};
368 
369 // Provide a client user of the stack manager module services
370 class Client {
371  public:
Client(int id)372   Client(int id) : id_(id) {}
373   Client(const Client&) = default;
374   virtual ~Client() = default;
375 
376   // Start up the client a thread and handler
Start()377   void Start() {
378     log::info("Started client {}", id_);
379     thread_ = new os::Thread(common::StringFormat("ClientThread%d", id_),
380                              os::Thread::Priority::NORMAL);
381     handler_ = new os::Handler(thread_);
382     handler_->Post(common::BindOnce(
383         [](int id) { log::info("Started client {}", id); }, id_));
384   }
385 
386   // Ensure all the client handlers are running
Await()387   void Await() {
388     std::promise<void> promise;
389     std::future future = promise.get_future();
390     handler_->Post(
391         base::BindOnce([](std::promise<void> promise) { promise.set_value(); },
392                        std::move(promise)));
393     future.wait();
394   }
395 
396   // Post a work task on behalf of this client
Post(common::OnceClosure closure)397   void Post(common::OnceClosure closure) {
398     if (quiesced_) {
399       post_cnt_.misses++;
400       maybe_yield();
401     } else {
402       post_cnt_.success++;
403       handler_->Post(std::move(closure));
404       maybe_yield();
405     }
406   }
407 
408   // Safely prevent new work tasks from being posted
Quiesce()409   void Quiesce() {
410     if (quiesced_) return;
411     quiesced_ = true;
412     std::promise promise = std::promise<void>();
413     std::future future = promise.get_future();
414     handler_->Post(common::BindOnce(
415         [](std::promise<void> promise) { promise.set_value(); },
416         std::move(promise)));
417     future.wait_for(std::chrono::milliseconds(kSyncMainLoopTimeoutMs));
418   }
419 
420   // Stops the client and associated resources
Stop()421   void Stop() {
422     if (!quiesced_) {
423       Quiesce();
424     }
425     handler_->Clear();
426     handler_->WaitUntilStopped(
427         std::chrono::milliseconds(kWaitUntilHandlerStoppedMs));
428     delete handler_;
429     delete thread_;
430   }
431 
Id() const432   int Id() const { return id_; }
433 
GetCallablePostCnt() const434   CallablePostCnt GetCallablePostCnt() const { return post_cnt_; }
435 
Name() const436   std::string Name() const {
437     return common::StringFormat("%s%d", __func__, id_);
438   }
439 
440  private:
441   int id_{0};
442   CallablePostCnt post_cnt_{};
443   bool quiesced_{false};
444   os::Handler* handler_{nullptr};
445   os::Thread* thread_{nullptr};
446 };
447 
448 // Convenience object to handle multiple clients with logging
449 class ClientGroup {
450  public:
ClientGroup()451   ClientGroup(){};
452 
Start()453   void Start() {
454     for (auto& c : clients_) {
455       c->Start();
456     }
457     log_tag("STARTING");
458   }
459 
Await()460   void Await() {
461     for (auto& c : clients_) {
462       c->Await();
463     }
464     log_tag("STARTED");
465   }
466 
Quiesce()467   void Quiesce() {
468     log_tag("QUIESCING");
469     for (auto& c : clients_) {
470       c->Quiesce();
471     }
472     log_tag("QUIESCED");
473   }
474 
Stop()475   void Stop() {
476     for (auto& c : clients_) {
477       c->Stop();
478     }
479     log_tag("STOPPED");
480   }
481 
Dump() const482   void Dump() const {
483     for (auto& c : clients_) {
484       log::info("Callable post cnt client_id:{} success:{} misses:{}", c->Id(),
485                 c->GetCallablePostCnt().success,
486                 c->GetCallablePostCnt().misses);
487     }
488   }
489 
GetCallablePostCnt() const490   CallablePostCnt GetCallablePostCnt() const {
491     CallablePostCnt post_cnt{};
492     for (auto& c : clients_) {
493       post_cnt += c->GetCallablePostCnt();
494     }
495     return post_cnt;
496   }
497 
NumClients() const498   size_t NumClients() const { return kNumTestClients; }
499 
500   std::unique_ptr<Client> clients_[kNumTestClients] = {
501       std::make_unique<Client>(1), std::make_unique<Client>(2),
502       std::make_unique<Client>(3)};
503 };
504 
TEST_F(StackLifecycleWithDefaultModulesUnitTest, clients_start) {
  // Bring all clients up and wait until each handler has run a task.
  ClientGroup client_group;
  client_group.Start();
  client_group.Await();

  // Clients are operational; shut them down again.
  client_group.Quiesce();
  client_group.Stop();
}
516 
TEST_F(StackLifecycleWithDefaultModulesUnitTest, client_using_stack_manager) {
  ClientGroup client_group;
  client_group.Start();
  client_group.Await();

  // Every round, each client posts one task per module. Each task looks up its
  // module directly through the running stack manager.
  for (int round = 0; round < kNumIters; ++round) {
    for (auto& client : client_group.clients_) {
      const int client_id = client->Id();
      client->Post(base::BindOnce(
          [](int /* id */, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestStackModule1* module =
                stack_manager->GetUnsafeModule<TestStackModule1>();
            module->TestMethod({
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData /* data */) {},
            });
          },
          client_id, round, StackManager()));
      client->Post(base::BindOnce(
          [](int /* id */, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestStackModule2* module =
                stack_manager->GetUnsafeModule<TestStackModule2>();
            module->TestMethod({
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData /* data */) {},
            });
          },
          client_id, round, StackManager()));
      client->Post(base::BindOnce(
          [](int /* id */, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestStackModule3* module =
                stack_manager->GetUnsafeModule<TestStackModule3>();
            module->TestMethod({
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData /* data */) {},
            });
          },
          client_id, round, StackManager()));
    }
  }

  client_group.Quiesce();
  client_group.Stop();
  client_group.Dump();

  // Every Post() call is counted either as a success or as a miss.
  const CallablePostCnt total = client_group.GetCallablePostCnt();
  ASSERT_EQ(client_group.NumClients() * kNumIters * kNumTestModules,
            total.success + total.misses);
}
565 
TEST_F(StackLifecycleWithDefaultModulesUnitTest,client_using_stack_manager_when_shutdown)566 TEST_F(StackLifecycleWithDefaultModulesUnitTest,
567        client_using_stack_manager_when_shutdown) {
568   struct Counters {
569     struct {
570       std::atomic_size_t cnt{0};
571     } up, down;
572   } counters;
573 
574   ClientGroup client_group;
575   client_group.Start();
576   client_group.Await();
577 
578   for (int i = 0; i < kNumIters; i++) {
579     for (auto& c : client_group.clients_) {
580       c->Post(base::BindOnce(
581           [](int id, int iter, Counters* counters,
582              std::shared_ptr<TestStackManager> stack_manager) {
583             TestData test_data = {
584                 .iter = iter,
585                 .tag = std::string(kTestDataTag),
586                 .callback = [](TestCallbackData data) {},
587             };
588             if (bluetooth::shim::Stack::GetInstance()
589                     ->CallOnModule<TestStackModule1>(
590                         [test_data](TestStackModule1* mod) {
591                           mod->TestMethod(test_data);
592                         })) {
593               counters->up.cnt++;
594             } else {
595               counters->down.cnt++;
596             }
597           },
598           c->Id(), i, &counters, StackManager()));
599       c->Post(base::BindOnce(
600           [](int id, int iter, Counters* counters,
601              std::shared_ptr<TestStackManager> stack_manager) {
602             TestData test_data = {
603                 .iter = iter,
604                 .tag = std::string(kTestDataTag),
605                 .callback = [](TestCallbackData data) {},
606             };
607             if (bluetooth::shim::Stack::GetInstance()
608                     ->CallOnModule<TestStackModule2>(
609                         [test_data](TestStackModule2* mod) {
610                           mod->TestMethod(test_data);
611                         })) {
612               counters->up.cnt++;
613             } else {
614               counters->down.cnt++;
615             }
616           },
617           c->Id(), i, &counters, StackManager()));
618       c->Post(base::BindOnce(
619           [](int id, int iter, Counters* counters,
620              std::shared_ptr<TestStackManager> stack_manager) {
621             TestData test_data = {
622                 .iter = iter,
623                 .tag = std::string(kTestDataTag),
624                 .callback = [](TestCallbackData data) {},
625             };
626             if (bluetooth::shim::Stack::GetInstance()
627                     ->CallOnModule<TestStackModule3>(
628                         [test_data](TestStackModule3* mod) {
629                           mod->TestMethod(test_data);
630                         })) {
631               counters->up.cnt++;
632             } else {
633               counters->down.cnt++;
634             }
635           },
636           c->Id(), i, &counters, StackManager()));
637     }
638     // Abruptly shutdown stack at some point through the iterations
639     if (i == kAbruptStackShutdownIter) {
640       log_tag("SHUTTING DOWN STACK");
641       StackManager()->Stop();
642     }
643   }
644 
645   client_group.Quiesce();
646   client_group.Stop();
647   log::info("Execution stack availability counters up:{} down:{}",
648             counters.up.cnt, counters.down.cnt);
649 
650   ASSERT_EQ(client_group.NumClients() * kNumIters * kNumTestModules,
651             client_group.GetCallablePostCnt().success +
652                 client_group.GetCallablePostCnt().misses);
653 }
654