1 /*
2 * Copyright 2021 HIMSA II K/S - www.himsa.com.
3 * Represented by EHIMA - www.ehima.com
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 #include <base/logging.h>
19
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <mutex>
24 #include <unordered_set>
25
26 #include "bta_groups.h"
27 #include "btif_profile_storage.h"
28 #include "types/bluetooth/uuid.h"
29 #include "types/raw_address.h"
30
31 using bluetooth::Uuid;
32
33 namespace bluetooth {
34 namespace groups {
35
/* Group ids handed out by DeviceGroupsImpl::create_group() are searched in the
 * half-open range [1, kMaxGroupId). */
static constexpr int kMaxGroupId = 0xEF;

class DeviceGroupsImpl;

/* Process-wide singleton. Created by DeviceGroups::Initialize() and deleted by
 * DeviceGroups::CleanUp() once the last registered client has been cleared. */
DeviceGroupsImpl* instance = nullptr;

/* Taken by DeviceGroups::Initialize(), CleanUp() and DebugDump(). */
std::mutex instance_mutex;
40
41 class DeviceGroup {
42 public:
DeviceGroup(int group_id,Uuid uuid)43 DeviceGroup(int group_id, Uuid uuid)
44 : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)45 void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)46 void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const47 bool Contains(const RawAddress& addr) const {
48 return (devices_.count(addr) != 0);
49 }
50
ForEachDevice(std::function<void (const RawAddress &)> cb) const51 void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
52 for (auto const& addr : devices_) {
53 cb(addr);
54 }
55 }
56
Size(void) const57 int Size(void) const { return devices_.size(); }
GetGroupId(void) const58 int GetGroupId(void) const { return group_id_; }
GetUuid(void) const59 const Uuid& GetUuid(void) const { return group_uuid_; }
60
61 private:
62 friend std::ostream& operator<<(std::ostream& out,
63 const bluetooth::groups::DeviceGroup& value);
64 int group_id_;
65 Uuid group_uuid_;
66 std::unordered_set<RawAddress> devices_;
67 };
68
69 class DeviceGroupsImpl : public DeviceGroups {
70 static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
71 static constexpr size_t GROUP_STORAGE_HEADER_SZ =
72 sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
73 sizeof(uint8_t); /* num_of_groups */
74 static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
75 sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
76
77 public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)78 DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
79 AddCallbacks(callbacks);
80 btif_storage_load_bonded_groups();
81 }
82
GetGroupId(const RawAddress & addr,Uuid uuid) const83 int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
84 for (const auto& [id, g] : groups_) {
85 if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
86 }
87 return kGroupUnknown;
88 }
89
add_to_group(const RawAddress & addr,DeviceGroup * group)90 void add_to_group(const RawAddress& addr, DeviceGroup* group) {
91 group->Add(addr);
92
93 bool first_device_in_group = (group->Size() == 1);
94
95 for (auto c : callbacks_) {
96 if (first_device_in_group) {
97 c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
98 } else {
99 c->OnGroupMemberAdded(addr, group->GetGroupId());
100 }
101 }
102 }
103
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)104 int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
105 DeviceGroup* group = nullptr;
106
107 if (group_id == kGroupUnknown) {
108 auto gid = GetGroupId(addr, uuid);
109 if (gid != kGroupUnknown) return gid;
110 group = create_group(uuid);
111 } else {
112 group = get_or_create_group_with_id(group_id, uuid);
113 if (!group) {
114 return kGroupUnknown;
115 }
116 }
117
118 LOG_ASSERT(group);
119
120 if (group->Contains(addr)) {
121 LOG(ERROR) << __func__ << " device " << ADDRESS_TO_LOGGABLE_STR(addr)
122 << " already in the group: " << group_id;
123 return group->GetGroupId();
124 }
125
126 add_to_group(addr, group);
127
128 btif_storage_add_groups(addr);
129 return group->GetGroupId();
130 }
131
RemoveDevice(const RawAddress & addr,int group_id)132 void RemoveDevice(const RawAddress& addr, int group_id) override {
133 int num_of_groups_dev_belongs = 0;
134
135 /* Remove from all the groups. Usually happens on unbond */
136 for (auto it = groups_.begin(); it != groups_.end();) {
137 auto& [id, g] = *it;
138 if (!g.Contains(addr)) {
139 ++it;
140 continue;
141 }
142
143 num_of_groups_dev_belongs++;
144
145 if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
146 ++it;
147 continue;
148 }
149
150 num_of_groups_dev_belongs--;
151
152 g.Remove(addr);
153 for (auto c : callbacks_) {
154 c->OnGroupMemberRemoved(addr, id);
155 }
156
157 if (g.Size() == 0) {
158 for (auto c : callbacks_) {
159 c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
160 }
161 it = groups_.erase(it);
162 } else {
163 ++it;
164 }
165 }
166
167 btif_storage_remove_groups(addr);
168 if (num_of_groups_dev_belongs > 0) {
169 btif_storage_add_groups(addr);
170 }
171 }
172
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const173 bool SerializeGroups(const RawAddress& addr,
174 std::vector<uint8_t>& out) const {
175 auto num_groups = std::count_if(
176 groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
177 return id_group_pair.second.Contains(addr);
178 });
179 if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
180 return false;
181
182 out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
183 auto* ptr = out.data();
184
185 /* header */
186 UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
187 UINT8_TO_STREAM(ptr, num_groups);
188
189 /* group entries */
190 for (const auto& [id, g] : groups_) {
191 if (g.Contains(addr)) {
192 UINT8_TO_STREAM(ptr, id);
193
194 Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
195 memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
196 ptr += Uuid::kNumBytes128;
197 }
198 }
199
200 return true;
201 }
202
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)203 void DeserializeGroups(const RawAddress& addr,
204 const std::vector<uint8_t>& in) {
205 if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
206
207 auto* ptr = in.data();
208
209 uint8_t magic;
210 STREAM_TO_UINT8(magic, ptr);
211
212 if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
213 uint8_t num_groups;
214 STREAM_TO_UINT8(num_groups, ptr);
215
216 if (in.size() <
217 GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
218 LOG(ERROR) << "Invalid persistent storage data";
219 return;
220 }
221
222 /* group entries */
223 while (num_groups--) {
224 uint8_t id;
225 STREAM_TO_UINT8(id, ptr);
226
227 Uuid::UUID128Bit uuid128;
228 STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
229
230 auto* group =
231 get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
232 if (group) add_to_group(addr, group);
233
234 for (auto c : callbacks_) {
235 c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
236 }
237 }
238 }
239 }
240
AddCallbacks(DeviceGroupsCallbacks * callbacks)241 void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
242 callbacks_.push_back(std::move(callbacks));
243
244 /* Notify new user about known groups */
245 for (const auto& [id, g] : groups_) {
246 auto group_uuid = g.GetUuid();
247 auto group_id = g.GetGroupId();
248 g.ForEachDevice([&](auto& dev) {
249 callbacks->OnGroupAdded(dev, group_uuid, group_id);
250 });
251 }
252 }
253
Clear(DeviceGroupsCallbacks * callbacks)254 bool Clear(DeviceGroupsCallbacks* callbacks) {
255 auto it = find_if(callbacks_.begin(), callbacks_.end(),
256 [callbacks](auto c) { return c == callbacks; });
257
258 if (it != callbacks_.end()) callbacks_.erase(it);
259
260 if (callbacks_.size() != 0) {
261 return false;
262 }
263 /* When all clients were unregistered */
264 groups_.clear();
265 return true;
266 }
267
Dump(int fd)268 void Dump(int fd) {
269 std::stringstream stream;
270
271 stream << " Num. registered clients: " << callbacks_.size() << std::endl;
272 stream << " Groups:\n";
273 for (const auto& kv_pair : groups_) {
274 stream << kv_pair.second << std::endl;
275 }
276
277 dprintf(fd, "%s", stream.str().c_str());
278 }
279
280 private:
find_device_group(int group_id)281 DeviceGroup* find_device_group(int group_id) {
282 return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
283 }
284
get_or_create_group_with_id(int group_id,Uuid uuid)285 DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
286 auto group = find_device_group(group_id);
287 if (group) {
288 if (group->GetUuid() != uuid) {
289 LOG(ERROR) << __func__ << " group " << group_id
290 << " exists but for different uuid: " << group->GetUuid()
291 << ", user request uuid: " << uuid;
292 return nullptr;
293 }
294
295 LOG(INFO) << __func__ << " group already exists: " << group_id;
296 return group;
297 }
298
299 DeviceGroup new_group(group_id, uuid);
300 groups_.insert({group_id, std::move(new_group)});
301
302 return &groups_.at(group_id);
303 }
304
create_group(Uuid & uuid)305 DeviceGroup* create_group(Uuid& uuid) {
306 /* Generate new group id and return empty group */
307 /* Find first free id */
308
309 int group_id = -1;
310 for (int i = 1; i < kMaxGroupId; i++) {
311 if (groups_.count(i) == 0) {
312 group_id = i;
313 break;
314 }
315 }
316
317 if (group_id < 0) {
318 LOG(ERROR) << __func__ << " too many groups";
319 return nullptr;
320 }
321
322 DeviceGroup group(group_id, uuid);
323 groups_.insert({group_id, std::move(group)});
324
325 return &groups_.at(group_id);
326 }
327
328 std::map<int, DeviceGroup> groups_;
329 std::list<DeviceGroupsCallbacks*> callbacks_;
330 };
331
Initialize(DeviceGroupsCallbacks * callbacks)332 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
333 std::scoped_lock<std::mutex> lock(instance_mutex);
334 if (instance == nullptr) {
335 instance = new DeviceGroupsImpl(callbacks);
336 return;
337 }
338
339 instance->AddCallbacks(callbacks);
340 }
341
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)342 void DeviceGroups::AddFromStorage(const RawAddress& addr,
343 const std::vector<uint8_t>& in) {
344 if (!instance) {
345 LOG(ERROR) << __func__ << ": Not initialized yet";
346 return;
347 }
348
349 instance->DeserializeGroups(addr, in);
350 }
351
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)352 bool DeviceGroups::GetForStorage(const RawAddress& addr,
353 std::vector<uint8_t>& out) {
354 if (!instance) {
355 LOG(ERROR) << __func__ << ": Not initialized yet";
356 return false;
357 }
358
359 return instance->SerializeGroups(addr, out);
360 }
361
CleanUp(DeviceGroupsCallbacks * callbacks)362 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
363 std::scoped_lock<std::mutex> lock(instance_mutex);
364 if (!instance) return;
365
366 if (instance->Clear(callbacks)) {
367 delete (instance);
368 instance = nullptr;
369 }
370 }
371
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)372 std::ostream& operator<<(std::ostream& out,
373 bluetooth::groups::DeviceGroup const& group) {
374 out << " == Group id: " << group.group_id_ << " == \n"
375 << " Uuid: " << group.group_uuid_ << std::endl;
376 out << " Devices:\n";
377 for (auto const& addr : group.devices_) {
378 out << " " << ADDRESS_TO_LOGGABLE_STR(addr) << std::endl;
379 }
380 return out;
381 }
382
DebugDump(int fd)383 void DeviceGroups::DebugDump(int fd) {
384 std::scoped_lock<std::mutex> lock(instance_mutex);
385 dprintf(fd, "Device Groups Manager:\n");
386 if (instance)
387 instance->Dump(fd);
388 else
389 dprintf(fd, " Not initialized \n");
390 }
391
Get()392 DeviceGroups* DeviceGroups::Get() { return instance; }
393
394 } // namespace groups
395 } // namespace bluetooth
396