1 /******************************************************************************
2 *
3 * Copyright 2020 Google, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at:
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 ******************************************************************************/
18
#include "os/internal/wakelock_native.h"

#include <aidl/android/system/suspend/BnSuspendCallback.h>
#include <aidl/android/system/suspend/BnWakelockCallback.h>
#include <aidl/android/system/suspend/ISuspendControlService.h>
#include <android/binder_auto_utils.h>
#include <android/binder_interface_utils.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <gtest/gtest.h>

#include <chrono>
#include <cstdio>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
34
35 namespace testing {
36
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44
// Name of the wakelock the tests acquire/release and register the wakelock callback for.
static const std::string kTestWakelockName = "BtWakelockNativeTestLock";

// Serializes the promise slots below and the callback bookkeeping. It is recursive
// because the callbacks already hold it when they call
// PromiseFutureContext::FulfilPromise(), which locks it again.
static std::recursive_mutex mutex;

// One-shot slots: armed by a PromiseFutureContext, fulfilled and cleared by the
// matching callback. Empty (null) when nothing is waiting.
static std::unique_ptr<std::promise<void>> acquire_promise;
static std::unique_ptr<std::promise<void>> release_promise;
50
51 class PromiseFutureContext {
52 public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53 static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55 if (promise != nullptr) {
56 std::promise<void>* prom = promise.release();
57 prom->set_value();
58 delete prom;
59 }
60 }
61
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)62 explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise, bool expect_fulfillment)
63 : promise_(promise), expect_fulfillment_(expect_fulfillment) {
64 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
65 EXPECT_EQ(promise_, nullptr);
66 promise_ = std::make_unique<std::promise<void>>();
67 future_ = promise->get_future();
68 }
69
~PromiseFutureContext()70 ~PromiseFutureContext() {
71 auto future_status = future_.wait_for(std::chrono::seconds(2));
72 if (expect_fulfillment_) {
73 EXPECT_EQ(future_status, std::future_status::ready);
74 } else {
75 EXPECT_NE(future_status, std::future_status::ready);
76 }
77 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
78 promise_ = nullptr;
79 }
80
81 private:
82 std::unique_ptr<std::promise<void>>& promise_;
83 bool expect_fulfillment_ = true;
84 std::future<void> future_;
85 };
86
87 class WakelockCallback : public BnWakelockCallback {
88 public:
notifyAcquired()89 ScopedAStatus notifyAcquired() override {
90 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
91 net_acquired_count++;
92 fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
93 PromiseFutureContext::FulfilPromise(acquire_promise);
94 return ScopedAStatus::ok();
95 }
notifyReleased()96 ScopedAStatus notifyReleased() override {
97 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
98 net_acquired_count--;
99 fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
100 PromiseFutureContext::FulfilPromise(release_promise);
101 return ScopedAStatus::ok();
102 }
103
104 int net_acquired_count = 0;
105 };
106
107 class SuspendCallback : public BnSuspendCallback {
108 public:
notifyWakeup(bool,const std::vector<std::string> &)109 ScopedAStatus notifyWakeup(
110 bool /* success */, const std::vector<std::string>& /* wakeup_reasons */) override {
111 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
112 fprintf(stderr, "notifyWakeup\n");
113 return ScopedAStatus::ok();
114 }
115 };
116
117 // There is no way to unregister these callbacks besides when this process dies
118 // Hence, we want to have only one copy of these callbacks per process
119 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
120 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
121
122 class WakelockNativeTest : public Test {
123 protected:
SetUp()124 void SetUp() override {
125 ABinderProcess_setThreadPoolMaxThreadCount(1);
126 ABinderProcess_startThreadPool();
127
128 WakelockNative::Get().Initialize();
129
130 auto binder_raw = AServiceManager_waitForService("suspend_control");
131 ASSERT_NE(binder_raw, nullptr);
132 binder.set(binder_raw);
133 control_service_ = ISuspendControlService::fromBinder(binder);
134 if (control_service_ == nullptr) {
135 FAIL() << "Fail to obtain suspend_control";
136 }
137
138 if (suspend_callback == nullptr) {
139 suspend_callback = SharedRefBase::make<SuspendCallback>();
140 bool is_registered = false;
141 ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
142 if (!is_registered || !status.isOk()) {
143 FAIL() << "Fail to register suspend callback";
144 }
145 }
146
147 if (control_callback == nullptr) {
148 control_callback = SharedRefBase::make<WakelockCallback>();
149 bool is_registered = false;
150 ScopedAStatus status =
151 control_service_->registerWakelockCallback(control_callback, kTestWakelockName, &is_registered);
152 if (!is_registered || !status.isOk()) {
153 FAIL() << "Fail to register wakeup callback";
154 }
155 }
156 control_callback->net_acquired_count = 0;
157 }
158
TearDown()159 void TearDown() override {
160 control_service_ = nullptr;
161 binder.set(nullptr);
162 WakelockNative::Get().CleanUp();
163 }
164
165 SpAIBinder binder;
166 std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
167 };
168
// A single Acquire followed by a single Release must each be reported to the
// wakelock callback, moving the balance 0 -> 1 -> 0.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext expect_acquire{acquire_promise, true};
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }  // Waits here for notifyAcquired.
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext expect_release{release_promise, true};
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }  // Waits here for notifyReleased.
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
186
// Acquiring an already-held wakelock still returns SUCCESS but must NOT produce a
// second notifyAcquired; a single Release then brings the balance back to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext expect_acquire{acquire_promise, true};
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // Expect no callback: the context waits out its timeout to confirm silence.
    PromiseFutureContext expect_no_acquire{acquire_promise, false};
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext expect_release{release_promise, true};
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
211
// Releasing a wakelock that is no longer held still returns SUCCESS but must NOT
// produce a second notifyReleased; the balance stays at zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext expect_acquire{acquire_promise, true};
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext expect_release{release_promise, true};
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // Expect no callback: the context waits out its timeout to confirm silence.
    PromiseFutureContext expect_no_release{release_promise, false};
    const auto status = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
236
// Repeats the acquire/release cycle several times; every cycle must be observed by
// the matching callbacks and return the balance to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  constexpr int kCycles = 10;
  for (int cycle = 0; cycle < kCycles; cycle++) {
    {
      PromiseFutureContext expect_acquire{acquire_promise, true};
      const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 1);

    {
      PromiseFutureContext expect_release{release_promise, true};
      const auto status = WakelockNative::Get().Release(kTestWakelockName);
      ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 0);
  }
}
256
// CleanUp() while the wakelock is still held is expected to release it, which the
// wakelock callback must observe as a notifyReleased.
TEST_F(WakelockNativeTest, test_clean_up) {
  // NOTE(review): SetUp() already calls Initialize(); this second call presumably
  // checks that re-initialization is harmless. Confirm the intent.
  WakelockNative::Get().Initialize();
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext expect_acquire{acquire_promise, true};
    const auto status = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(status, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext expect_release{release_promise, true};
    WakelockNative::Get().CleanUp();
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
274
275 } // namespace testing
276