1 /*
2 * Copyright 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at:
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <bluetooth/log.h>
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include <unistd.h>
21
22 #include <atomic>
23 #include <cstdlib>
24 #include <memory>
25 #include <thread>
26
27 #include "common/strings.h"
28 #include "gd/module_jniloop.h"
29 #include "gd/module_mainloop.h"
30 #include "main/shim/stack.h"
31 #include "module.h"
32 #include "os/thread.h"
33 #include "stack/include/main_thread.h"
34 #include "test/mock/mock_main_shim_entry.h"
35
36 using ::testing::_;
37
38 using namespace bluetooth;
39 using namespace testing;
40
41 namespace {
// Upper bound, in milliseconds, for the timed waits that synchronize with a
// handler (main thread sync and client quiesce).
constexpr int kSyncMainLoopTimeoutMs{3000};
// Upper bound, in milliseconds, handed to Handler::WaitUntilStopped.
constexpr int kWaitUntilHandlerStoppedMs{2000};
// Number of clients and of stack modules exercised by the tests.
constexpr size_t kNumTestClients{3};
constexpr size_t kNumTestModules{3};
// Number of posting rounds each test performs.
constexpr int kNumIters{100};
// Round, three quarters of the way through, at which the stack is stopped.
constexpr int kAbruptStackShutdownIter{(kNumIters * 3) / 4};
// Name given to the thread the test stack runs on.
constexpr char kTestStackThreadName[]{"test_stack_thread"};
// Tag placed in every TestData sent to a module.
constexpr char kTestDataTag[]{"This is a test"};
50
// Yields the calling thread when the low bit of std::rand() is set; used to
// perturb thread interleaving between posts.
inline void maybe_yield() {
  const bool should_yield = (std::rand() & 1) != 0;
  if (should_yield) {
    std::this_thread::yield();
  }
}
54
55 constexpr size_t kTagLength = 48 + sizeof(' ') + sizeof(' ');
log_tag(std::string tag)56 inline void log_tag(std::string tag) {
57 std::string prepend(kTagLength / 2 - tag.size() / 2, '=');
58 std::string append(kTagLength / 2 - tag.size() / 2, '=');
59 log::info("{} {} {}", prepend, tag, append);
60 }
61
62 class MainThread {
63 public:
MainThread()64 MainThread() { main_thread_start_up(); }
65
~MainThread()66 ~MainThread() {
67 sync_main_handler();
68 main_thread_shut_down();
69 }
70
71 private:
sync_main_handler()72 void sync_main_handler() {
73 std::promise promise = std::promise<void>();
74 std::future future = promise.get_future();
75 post_on_bt_main([&promise]() { promise.set_value(); });
76 future.wait_for(std::chrono::milliseconds(kSyncMainLoopTimeoutMs));
77 }
78 };
79
80 class TestStackManager {
81 public:
TestStackManager()82 TestStackManager() {
83 // Start is executed by the test after each test adds the default
84 // or their own modules
85 }
86
~TestStackManager()87 ~TestStackManager() {
88 log::debug("Deleting stack manager");
89 Stop();
90 }
91
92 TestStackManager(const TestStackManager&) = delete;
93
94 template <typename T>
AddModule()95 void AddModule() {
96 modules_.add<T>();
97 }
98
Start()99 void Start() {
100 if (stack_started_) return;
101 log::info("Started stack manager");
102 stack_started_ = true;
103 bluetooth::os::Thread* stack_thread = new bluetooth::os::Thread(
104 kTestStackThreadName, bluetooth::os::Thread::Priority::NORMAL);
105 bluetooth::shim::Stack::GetInstance()->StartModuleStack(&modules_,
106 stack_thread);
107 }
108
Stop()109 void Stop() {
110 if (!stack_started_) return;
111 stack_started_ = false;
112 bluetooth::shim::Stack::GetInstance()->Stop();
113 }
114
115 // NOTE: Stack manager *must* be active else method returns nullptr
116 // if stack manager has not started or shutdown
117 template <typename T>
GetUnsafeModule()118 static T* GetUnsafeModule() {
119 return bluetooth::shim::Stack::GetInstance()
120 ->GetStackManager()
121 ->GetInstance<T>();
122 }
123
NumModules() const124 size_t NumModules() const { return modules_.NumModules(); }
125
126 private:
127 bluetooth::ModuleList modules_;
128 bool stack_started_{false};
129 };
130
// Data returned via callback from a stack managed module.
// `iter` is zero-initialized so a default-constructed value is never read
// with an indeterminate iteration number; the struct remains an aggregate.
struct TestCallbackData {
  int iter{0};
  std::string tag;
};

// Data sent to a stack managed module via a module API.
// `callback` receives the TestCallbackData produced by the module; `iter` is
// zero-initialized for the same reason as in TestCallbackData.
struct TestData {
  int iter{0};
  std::string tag;
  std::function<void(TestCallbackData callback_data)> callback;
};
143
144 class TestStackModuleBase : public bluetooth::Module,
145 public ModuleMainloop,
146 public ModuleJniloop {
147 public:
148 TestStackModuleBase(const TestStackModuleBase&) = delete;
149 TestStackModuleBase& operator=(const TestStackModuleBase&) = delete;
150
~TestStackModuleBase()151 virtual ~TestStackModuleBase(){};
152 static const ModuleFactory Factory;
153
TestMethod(TestData test_data) const154 virtual void TestMethod(TestData test_data) const {
155 log::info("Test base class iter:{} tag:{}", test_data.iter, test_data.tag);
156 }
157
158 protected:
ListDependencies(ModuleList * list) const159 void ListDependencies(ModuleList* list) const override{};
Start()160 void Start() override { log::error("Started TestStackModuleBase"); };
Stop()161 void Stop() override { log::error("Stopped TestStackModuleBase"); };
ToString() const162 std::string ToString() const override { return std::string("TestFunction"); }
163
164 TestStackModuleBase() = default;
165 };
166
167 class TestStackModule1 : public TestStackModuleBase {
168 public:
169 TestStackModule1(const TestStackModule1&) = delete;
170 TestStackModule1& operator=(const TestStackModule1&) = delete;
171 virtual ~TestStackModule1() = default;
172
173 static const ModuleFactory Factory;
174
175 void TestMethod(TestData test_data) const override;
176
177 private:
178 struct impl;
179 std::shared_ptr<impl> impl_;
180 TestStackModule1();
181 };
182
183 struct TestStackModule1::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule1::impl184 void test(TestData test_data) {
185 TestCallbackData callback_data{
186 .iter = test_data.iter,
187 .tag = std::string(__func__),
188 };
189 PostFunctionOnMain(
190 [](std::function<void(TestCallbackData callback_data)> callback,
191 TestCallbackData data) { callback(data); },
192 test_data.callback, callback_data);
193 }
194 };
195
TestStackModule1()196 TestStackModule1::TestStackModule1() : TestStackModuleBase() {
197 impl_ = std::make_shared<impl>();
198 }
199
TestMethod(TestData test_data) const200 void TestStackModule1::TestMethod(TestData test_data) const {
201 PostMethodOnMain(impl_, &impl::test, test_data);
202 }
203
204 class TestStackModule2 : public TestStackModuleBase {
205 public:
206 TestStackModule2(const TestStackModule2&) = delete;
207 TestStackModule2& operator=(const TestStackModule2&) = delete;
208 virtual ~TestStackModule2() = default;
209
210 static const ModuleFactory Factory;
211
212 void TestMethod(TestData test_data) const override;
213
214 private:
215 struct impl;
216 std::shared_ptr<impl> impl_;
217 TestStackModule2();
218 };
219
220 struct TestStackModule2::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule2::impl221 void test(TestData test_data) {
222 TestCallbackData callback_data{
223 .iter = test_data.iter,
224 .tag = std::string(__func__),
225 };
226 PostFunctionOnMain(
227 [](std::function<void(TestCallbackData callback_data)> callback,
228 TestCallbackData data) { callback(data); },
229 test_data.callback, callback_data);
230 }
231 };
232
TestStackModule2()233 TestStackModule2::TestStackModule2() : TestStackModuleBase() {
234 impl_ = std::make_shared<impl>();
235 }
236
TestMethod(TestData test_data) const237 void TestStackModule2::TestMethod(TestData test_data) const {
238 PostMethodOnMain(impl_, &impl::test, test_data);
239 }
240
241 class TestStackModule3 : public TestStackModuleBase {
242 public:
243 TestStackModule3(const TestStackModule3&) = delete;
244 TestStackModule3& operator=(const TestStackModule3&) = delete;
245 virtual ~TestStackModule3() = default;
246
247 static const ModuleFactory Factory;
248
249 void TestMethod(TestData test_data) const override;
250
251 private:
252 struct impl;
253 std::shared_ptr<impl> impl_;
254 TestStackModule3();
255 };
256
257 struct TestStackModule3::impl : public ModuleMainloop, public ModuleJniloop {
test__anon7a4e79f10111::TestStackModule3::impl258 void test(TestData test_data) {
259 TestCallbackData callback_data{
260 .iter = test_data.iter,
261 .tag = std::string(__func__),
262 };
263 PostFunctionOnMain(
264 [](std::function<void(TestCallbackData callback_data)> callback,
265 TestCallbackData data) { callback(data); },
266 test_data.callback, callback_data);
267 }
268 };
269
TestStackModule3()270 TestStackModule3::TestStackModule3() : TestStackModuleBase() {
271 impl_ = std::make_shared<impl>();
272 }
273
TestMethod(TestData test_data) const274 void TestStackModule3::TestMethod(TestData test_data) const {
275 PostMethodOnMain(impl_, &impl::test, test_data);
276 }
277
278 class TestStackModule4 : public TestStackModuleBase {
279 public:
280 TestStackModule4(const TestStackModule4&) = delete;
281 TestStackModule4& operator=(const TestStackModule3&) = delete;
282 virtual ~TestStackModule4() = default;
283
284 static const ModuleFactory Factory;
285
TestMethod(TestData test_data) const286 void TestMethod(TestData test_data) const override {
287 log::info("mod:{} iter:{} tag:{}", __func__, test_data.iter, test_data.tag);
288 }
289
290 private:
291 struct impl;
292 std::shared_ptr<impl> impl_;
TestStackModule4()293 TestStackModule4() : TestStackModuleBase() {}
294 };
295
296 struct TestStackModule4::impl : public ModuleMainloop, public ModuleJniloop {};
297
298 } // namespace
299
300 const ModuleFactory TestStackModuleBase::Factory =
__anon7a4e79f10602() 301 ModuleFactory([]() { return new TestStackModuleBase(); });
302
303 const ModuleFactory TestStackModule1::Factory =
__anon7a4e79f10702() 304 ModuleFactory([]() { return new TestStackModule1(); });
305 const ModuleFactory TestStackModule2::Factory =
__anon7a4e79f10802() 306 ModuleFactory([]() { return new TestStackModule2(); });
307 const ModuleFactory TestStackModule3::Factory =
__anon7a4e79f10902() 308 ModuleFactory([]() { return new TestStackModule3(); });
309 const ModuleFactory TestStackModule4::Factory =
__anon7a4e79f10a02() 310 ModuleFactory([]() { return new TestStackModule4(); });
311
312 class StackWithMainThreadUnitTest : public ::testing::Test {
313 protected:
SetUp()314 void SetUp() override { main_thread_ = std::make_unique<MainThread>(); }
TearDown()315 void TearDown() override { main_thread_.reset(); }
316
317 private:
318 std::unique_ptr<MainThread> main_thread_;
319 };
320
321 class StackLifecycleUnitTest : public StackWithMainThreadUnitTest {
322 public:
StackManager() const323 std::shared_ptr<TestStackManager> StackManager() const {
324 return stack_manager_;
325 }
326
327 protected:
SetUp()328 void SetUp() override {
329 StackWithMainThreadUnitTest::SetUp();
330 stack_manager_ = std::make_shared<TestStackManager>();
331 }
332
TearDown()333 void TearDown() override {
334 stack_manager_.reset();
335 StackWithMainThreadUnitTest::TearDown();
336 }
337
338 private:
339 std::shared_ptr<TestStackManager> stack_manager_;
340 };
341
// A freshly created stack manager has no modules registered.
TEST_F(StackLifecycleUnitTest, no_modules_in_stack) {
  const size_t module_count = StackManager()->NumModules();
  ASSERT_EQ(0U, module_count);
}
345
346 class StackLifecycleWithDefaultModulesUnitTest : public StackLifecycleUnitTest {
347 protected:
SetUp()348 void SetUp() override {
349 StackLifecycleUnitTest::SetUp();
350 StackManager()->AddModule<TestStackModule1>();
351 StackManager()->AddModule<TestStackModule2>();
352 StackManager()->AddModule<TestStackModule3>();
353 StackManager()->Start();
354 ASSERT_EQ(3U, StackManager()->NumModules());
355 }
356
TearDown()357 void TearDown() override { StackLifecycleUnitTest::TearDown(); }
358 };
359
// Tally of work items a client tried to post: `success` counts posts that
// were handed to the handler, `misses` counts posts that were dropped.
struct CallablePostCnt {
  size_t success{0};
  size_t misses{0};

  // Adds another tally into this one and returns *this, following the usual
  // compound-assignment convention so that chained `+=` keeps accumulating
  // into the same object rather than into a temporary copy.
  CallablePostCnt& operator+=(const CallablePostCnt& post_cnt) {
    success += post_cnt.success;
    misses += post_cnt.misses;
    return *this;
  }
};
368
369 // Provide a client user of the stack manager module services
370 class Client {
371 public:
Client(int id)372 Client(int id) : id_(id) {}
373 Client(const Client&) = default;
374 virtual ~Client() = default;
375
376 // Start up the client a thread and handler
Start()377 void Start() {
378 log::info("Started client {}", id_);
379 thread_ = new os::Thread(common::StringFormat("ClientThread%d", id_),
380 os::Thread::Priority::NORMAL);
381 handler_ = new os::Handler(thread_);
382 handler_->Post(common::BindOnce(
383 [](int id) { log::info("Started client {}", id); }, id_));
384 }
385
386 // Ensure all the client handlers are running
Await()387 void Await() {
388 std::promise<void> promise;
389 std::future future = promise.get_future();
390 handler_->Post(
391 base::BindOnce([](std::promise<void> promise) { promise.set_value(); },
392 std::move(promise)));
393 future.wait();
394 }
395
396 // Post a work task on behalf of this client
Post(common::OnceClosure closure)397 void Post(common::OnceClosure closure) {
398 if (quiesced_) {
399 post_cnt_.misses++;
400 maybe_yield();
401 } else {
402 post_cnt_.success++;
403 handler_->Post(std::move(closure));
404 maybe_yield();
405 }
406 }
407
408 // Safely prevent new work tasks from being posted
Quiesce()409 void Quiesce() {
410 if (quiesced_) return;
411 quiesced_ = true;
412 std::promise promise = std::promise<void>();
413 std::future future = promise.get_future();
414 handler_->Post(common::BindOnce(
415 [](std::promise<void> promise) { promise.set_value(); },
416 std::move(promise)));
417 future.wait_for(std::chrono::milliseconds(kSyncMainLoopTimeoutMs));
418 }
419
420 // Stops the client and associated resources
Stop()421 void Stop() {
422 if (!quiesced_) {
423 Quiesce();
424 }
425 handler_->Clear();
426 handler_->WaitUntilStopped(
427 std::chrono::milliseconds(kWaitUntilHandlerStoppedMs));
428 delete handler_;
429 delete thread_;
430 }
431
Id() const432 int Id() const { return id_; }
433
GetCallablePostCnt() const434 CallablePostCnt GetCallablePostCnt() const { return post_cnt_; }
435
Name() const436 std::string Name() const {
437 return common::StringFormat("%s%d", __func__, id_);
438 }
439
440 private:
441 int id_{0};
442 CallablePostCnt post_cnt_{};
443 bool quiesced_{false};
444 os::Handler* handler_{nullptr};
445 os::Thread* thread_{nullptr};
446 };
447
448 // Convenience object to handle multiple clients with logging
449 class ClientGroup {
450 public:
ClientGroup()451 ClientGroup(){};
452
Start()453 void Start() {
454 for (auto& c : clients_) {
455 c->Start();
456 }
457 log_tag("STARTING");
458 }
459
Await()460 void Await() {
461 for (auto& c : clients_) {
462 c->Await();
463 }
464 log_tag("STARTED");
465 }
466
Quiesce()467 void Quiesce() {
468 log_tag("QUIESCING");
469 for (auto& c : clients_) {
470 c->Quiesce();
471 }
472 log_tag("QUIESCED");
473 }
474
Stop()475 void Stop() {
476 for (auto& c : clients_) {
477 c->Stop();
478 }
479 log_tag("STOPPED");
480 }
481
Dump() const482 void Dump() const {
483 for (auto& c : clients_) {
484 log::info("Callable post cnt client_id:{} success:{} misses:{}", c->Id(),
485 c->GetCallablePostCnt().success,
486 c->GetCallablePostCnt().misses);
487 }
488 }
489
GetCallablePostCnt() const490 CallablePostCnt GetCallablePostCnt() const {
491 CallablePostCnt post_cnt{};
492 for (auto& c : clients_) {
493 post_cnt += c->GetCallablePostCnt();
494 }
495 return post_cnt;
496 }
497
NumClients() const498 size_t NumClients() const { return kNumTestClients; }
499
500 std::unique_ptr<Client> clients_[kNumTestClients] = {
501 std::make_unique<Client>(1), std::make_unique<Client>(2),
502 std::make_unique<Client>(3)};
503 };
504
// Clients can be brought up and torn down while the stack is running, without
// posting any work.
TEST_F(StackLifecycleWithDefaultModulesUnitTest, clients_start) {
  ClientGroup clients;

  // Bring-up: after Await() returns the clients are operational.
  clients.Start();
  clients.Await();

  // Tear-down.
  clients.Quiesce();
  clients.Stop();
}
516
// Every client posts kNumIters rounds of work; each round calls TestMethod on
// modules 1-3 through the unchecked module lookup. At the end every post must
// be accounted for as either a success or a miss.
TEST_F(StackLifecycleWithDefaultModulesUnitTest, client_using_stack_manager) {
  ClientGroup clients;
  clients.Start();
  clients.Await();

  for (int round = 0; round < kNumIters; ++round) {
    for (auto& client : clients.clients_) {
      // One closure per module; the bound shared_ptr keeps the stack manager
      // alive for as long as the closure exists.
      client->Post(base::BindOnce(
          [](int /*id*/, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestData request{
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            stack_manager->GetUnsafeModule<TestStackModule1>()->TestMethod(
                request);
          },
          client->Id(), round, StackManager()));
      client->Post(base::BindOnce(
          [](int /*id*/, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestData request{
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            stack_manager->GetUnsafeModule<TestStackModule2>()->TestMethod(
                request);
          },
          client->Id(), round, StackManager()));
      client->Post(base::BindOnce(
          [](int /*id*/, int iter,
             std::shared_ptr<TestStackManager> stack_manager) {
            TestData request{
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            stack_manager->GetUnsafeModule<TestStackModule3>()->TestMethod(
                request);
          },
          client->Id(), round, StackManager()));
    }
  }

  clients.Quiesce();
  clients.Stop();
  clients.Dump();

  const CallablePostCnt posted = clients.GetCallablePostCnt();
  ASSERT_EQ(clients.NumClients() * kNumIters * kNumTestModules,
            posted.success + posted.misses);
}
565
// Same posting pattern as client_using_stack_manager, but the module is
// reached through Stack::CallOnModule and the stack is stopped part-way
// through the rounds. Each closure records whether CallOnModule reported
// success ("up") or not ("down"). At the end every post must be accounted for
// as either a success or a miss.
TEST_F(StackLifecycleWithDefaultModulesUnitTest,
       client_using_stack_manager_when_shutdown) {
  // Written from the client threads, hence atomic.
  struct Counters {
    struct {
      std::atomic_size_t cnt{0};
    } up, down;
  } counters;

  ClientGroup clients;
  clients.Start();
  clients.Await();

  for (int round = 0; round < kNumIters; ++round) {
    for (auto& client : clients.clients_) {
      client->Post(base::BindOnce(
          [](int /*id*/, int iter, Counters* counters,
             std::shared_ptr<TestStackManager> /*stack_manager*/) {
            TestData request = {
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            const auto delivered =
                bluetooth::shim::Stack::GetInstance()
                    ->CallOnModule<TestStackModule1>(
                        [request](TestStackModule1* mod) {
                          mod->TestMethod(request);
                        });
            if (delivered) {
              ++counters->up.cnt;
            } else {
              ++counters->down.cnt;
            }
          },
          client->Id(), round, &counters, StackManager()));
      client->Post(base::BindOnce(
          [](int /*id*/, int iter, Counters* counters,
             std::shared_ptr<TestStackManager> /*stack_manager*/) {
            TestData request = {
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            const auto delivered =
                bluetooth::shim::Stack::GetInstance()
                    ->CallOnModule<TestStackModule2>(
                        [request](TestStackModule2* mod) {
                          mod->TestMethod(request);
                        });
            if (delivered) {
              ++counters->up.cnt;
            } else {
              ++counters->down.cnt;
            }
          },
          client->Id(), round, &counters, StackManager()));
      client->Post(base::BindOnce(
          [](int /*id*/, int iter, Counters* counters,
             std::shared_ptr<TestStackManager> /*stack_manager*/) {
            TestData request = {
                .iter = iter,
                .tag = std::string(kTestDataTag),
                .callback = [](TestCallbackData) {},
            };
            const auto delivered =
                bluetooth::shim::Stack::GetInstance()
                    ->CallOnModule<TestStackModule3>(
                        [request](TestStackModule3* mod) {
                          mod->TestMethod(request);
                        });
            if (delivered) {
              ++counters->up.cnt;
            } else {
              ++counters->down.cnt;
            }
          },
          client->Id(), round, &counters, StackManager()));
    }
    // Abruptly shut down the stack at some point through the iterations
    if (round == kAbruptStackShutdownIter) {
      log_tag("SHUTTING DOWN STACK");
      StackManager()->Stop();
    }
  }

  // The clients are stopped before `counters` goes out of scope.
  clients.Quiesce();
  clients.Stop();
  log::info("Execution stack availability counters up:{} down:{}",
            counters.up.cnt, counters.down.cnt);

  const CallablePostCnt posted = clients.GetCallablePostCnt();
  ASSERT_EQ(clients.NumClients() * kNumIters * kNumTestModules,
            posted.success + posted.misses);
}
654