1 /******************************************************************************
2 *
3 * Copyright 2020 Google, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at:
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 ******************************************************************************/
18
19 #include "os/internal/wakelock_native.h"
20
21 #include <aidl/android/system/suspend/BnSuspendCallback.h>
22 #include <aidl/android/system/suspend/BnWakelockCallback.h>
23 #include <aidl/android/system/suspend/ISuspendControlService.h>
24 #include <android/binder_auto_utils.h>
25 #include <android/binder_interface_utils.h>
26 #include <android/binder_manager.h>
27 #include <android/binder_process.h>
28 #include <gtest/gtest.h>
29
30 #include <chrono>
31 #include <future>
32 #include <memory>
33 #include <mutex>
34
35 namespace testing {
36
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44
// Wakelock name acquired/released by every test and used to register the
// process-wide WakelockCallback.
static const std::string kTestWakelockName{"BtWakelockNativeTestLock"};

// Recursive lock taken by the binder callbacks and by PromiseFutureContext
// whenever they touch the promise slots below.
static std::recursive_mutex mutex;

// Promise slots armed by a PromiseFutureContext and fulfilled (then cleared) by
// WakelockCallback::notifyAcquired() / notifyReleased(). Empty when no test is
// waiting on the corresponding notification.
static std::unique_ptr<std::promise<void>> acquire_promise;
static std::unique_ptr<std::promise<void>> release_promise;
50
51 class PromiseFutureContext {
52 public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53 static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55 if (promise != nullptr) {
56 std::promise<void>* prom = promise.release();
57 prom->set_value();
58 delete prom;
59 }
60 }
61
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)62 explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise, bool expect_fulfillment)
63 : promise_(promise), expect_fulfillment_(expect_fulfillment) {
64 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
65 EXPECT_EQ(promise_, nullptr);
66 promise_ = std::make_unique<std::promise<void>>();
67 future_ = promise->get_future();
68 }
69
~PromiseFutureContext()70 ~PromiseFutureContext() {
71 auto future_status = future_.wait_for(std::chrono::seconds(2));
72 if (expect_fulfillment_) {
73 EXPECT_EQ(future_status, std::future_status::ready);
74 } else {
75 EXPECT_NE(future_status, std::future_status::ready);
76 }
77 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
78 promise_ = nullptr;
79 }
80
81 private:
82 std::unique_ptr<std::promise<void>>& promise_;
83 bool expect_fulfillment_ = true;
84 std::future<void> future_;
85 };
86
87 class WakelockCallback : public BnWakelockCallback {
88 public:
notifyAcquired()89 ScopedAStatus notifyAcquired() override {
90 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
91 net_acquired_count++;
92 fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
93 PromiseFutureContext::FulfilPromise(acquire_promise);
94 return ScopedAStatus::ok();
95 }
notifyReleased()96 ScopedAStatus notifyReleased() override {
97 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
98 net_acquired_count--;
99 fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
100 PromiseFutureContext::FulfilPromise(release_promise);
101 return ScopedAStatus::ok();
102 }
103
104 int net_acquired_count = 0;
105 };
106
107 class SuspendCallback : public BnSuspendCallback {
108 public:
notifyWakeup(bool success,const std::vector<std::string> & wakeup_reasons)109 ScopedAStatus notifyWakeup(bool success, const std::vector<std::string>& wakeup_reasons) override {
110 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
111 fprintf(stderr, "notifyWakeup\n");
112 return ScopedAStatus::ok();
113 }
114 };
115
116 // There is no way to unregister these callbacks besides when this process dies
117 // Hence, we want to have only one copy of these callbacks per process
118 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
119 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
120
121 class WakelockNativeTest : public Test {
122 protected:
SetUp()123 void SetUp() override {
124 ABinderProcess_setThreadPoolMaxThreadCount(1);
125 ABinderProcess_startThreadPool();
126
127 WakelockNative::Get().Initialize();
128
129 auto binder_raw = AServiceManager_waitForService("suspend_control");
130 ASSERT_NE(binder_raw, nullptr);
131 binder.set(binder_raw);
132 control_service_ = ISuspendControlService::fromBinder(binder);
133 if (control_service_ == nullptr) {
134 FAIL() << "Fail to obtain suspend_control";
135 }
136
137 if (suspend_callback == nullptr) {
138 suspend_callback = SharedRefBase::make<SuspendCallback>();
139 bool is_registered = false;
140 ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
141 if (!is_registered || !status.isOk()) {
142 FAIL() << "Fail to register suspend callback";
143 }
144 }
145
146 if (control_callback == nullptr) {
147 control_callback = SharedRefBase::make<WakelockCallback>();
148 bool is_registered = false;
149 ScopedAStatus status =
150 control_service_->registerWakelockCallback(control_callback, kTestWakelockName, &is_registered);
151 if (!is_registered || !status.isOk()) {
152 FAIL() << "Fail to register wakeup callback";
153 }
154 }
155 control_callback->net_acquired_count = 0;
156 }
157
TearDown()158 void TearDown() override {
159 control_service_ = nullptr;
160 binder.set(nullptr);
161 WakelockNative::Get().CleanUp();
162 }
163
164 SpAIBinder binder;
165 std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
166 };
167
// A single Acquire()/Release() pair must move the net acquired count
// 0 -> 1 -> 0, with each step confirmed by the matching callback.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // Leaving this scope waits for notifyAcquired() to fulfil acquire_promise.
    PromiseFutureContext acquire_waiter(acquire_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // Leaving this scope waits for notifyReleased() to fulfil release_promise.
    PromiseFutureContext release_waiter(release_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
185
// Acquiring an already-held wakelock must still return SUCCESS but must not
// produce a second acquire notification; one Release() then drops it.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // First acquire: notifyAcquired() is expected.
  {
    PromiseFutureContext first_acquire(acquire_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // Second acquire: succeeds, but no further notification may arrive.
  {
    PromiseFutureContext second_acquire(acquire_promise, /*expect_fulfillment=*/false);
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // A single release brings the net count back to zero.
  {
    PromiseFutureContext release_waiter(release_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
210
// Releasing an already-released wakelock must still return SUCCESS but must not
// produce another release notification.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Acquire: notifyAcquired() is expected.
  {
    PromiseFutureContext acquire_waiter(acquire_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // First release: notifyReleased() is expected.
  {
    PromiseFutureContext first_release(release_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Second release: succeeds, but no further notification may arrive.
  {
    PromiseFutureContext second_release(release_promise, /*expect_fulfillment=*/false);
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
235
// Repeated acquire/release cycles; every cycle must move the net acquired
// count 0 -> 1 -> 0 with both notifications observed.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  constexpr int kCycles = 10;
  for (int cycle = 0; cycle < kCycles; ++cycle) {
    {
      PromiseFutureContext acquire_waiter(acquire_promise, /*expect_fulfillment=*/true);
      const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 1);

    {
      PromiseFutureContext release_waiter(release_promise, /*expect_fulfillment=*/true);
      const auto status = WakelockNative::Get().Release(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 0);
  }
}
255
// CleanUp() while the wakelock is held must release it, which is observed as a
// release notification and a net count back at zero.
TEST_F(WakelockNativeTest, test_clean_up) {
  // NOTE(review): SetUp() has already called Initialize(); this second call
  // presumably checks that re-initialization is harmless. Confirm intent.
  WakelockNative::Get().Initialize();
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquire_waiter(acquire_promise, /*expect_fulfillment=*/true);
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // No explicit Release(): CleanUp() itself is expected to trigger notifyReleased().
    PromiseFutureContext release_waiter(release_promise, /*expect_fulfillment=*/true);
    WakelockNative::Get().CleanUp();
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
273
274 } // namespace testing