1 /*
2 * Copyright 2021 HIMSA II K/S - www.himsa.com.
3 * Represented by EHIMA - www.ehima.com
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 #include <base/logging.h>
19
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <unordered_set>
24
25 #include "bta_groups.h"
26 #include "btif_storage.h"
27 #include "types/bluetooth/uuid.h"
28 #include "types/raw_address.h"
29
30 using bluetooth::Uuid;
31
32 namespace bluetooth {
33 namespace groups {
34
class DeviceGroupsImpl;

/* Process-wide manager singleton. Created on the first
 * DeviceGroups::Initialize() call and released by DeviceGroups::CleanUp()
 * once the last registered client has gone away. */
DeviceGroupsImpl* instance = nullptr;

/* Exclusive upper bound for group ids handed out automatically: the id
 * allocator only scans 1 .. kMaxGroupId - 1. */
static constexpr int kMaxGroupId{0xEF};
38
39 class DeviceGroup {
40 public:
DeviceGroup(int group_id,Uuid uuid)41 DeviceGroup(int group_id, Uuid uuid)
42 : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)43 void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)44 void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const45 bool Contains(const RawAddress& addr) const {
46 return (devices_.count(addr) != 0);
47 }
48
ForEachDevice(std::function<void (const RawAddress &)> cb) const49 void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
50 for (auto const& addr : devices_) {
51 cb(addr);
52 }
53 }
54
Size(void) const55 int Size(void) const { return devices_.size(); }
GetGroupId(void) const56 int GetGroupId(void) const { return group_id_; }
GetUuid(void) const57 const Uuid& GetUuid(void) const { return group_uuid_; }
58
59 private:
60 friend std::ostream& operator<<(std::ostream& out,
61 const bluetooth::groups::DeviceGroup& value);
62 int group_id_;
63 Uuid group_uuid_;
64 std::unordered_set<RawAddress> devices_;
65 };
66
67 class DeviceGroupsImpl : public DeviceGroups {
68 static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
69 static constexpr size_t GROUP_STORAGE_HEADER_SZ =
70 sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
71 sizeof(uint8_t); /* num_of_groups */
72 static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
73 sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
74
75 public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)76 DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
77 AddCallbacks(callbacks);
78 btif_storage_load_bonded_groups();
79 }
80
GetGroupId(const RawAddress & addr,Uuid uuid) const81 int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
82 for (const auto& [id, g] : groups_) {
83 if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
84 }
85 return kGroupUnknown;
86 }
87
add_to_group(const RawAddress & addr,DeviceGroup * group)88 void add_to_group(const RawAddress& addr, DeviceGroup* group) {
89 group->Add(addr);
90
91 bool first_device_in_group = (group->Size() == 1);
92
93 for (auto c : callbacks_) {
94 if (first_device_in_group) {
95 c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
96 } else {
97 c->OnGroupMemberAdded(addr, group->GetGroupId());
98 }
99 }
100 }
101
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)102 int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
103 DeviceGroup* group = nullptr;
104
105 if (group_id == kGroupUnknown) {
106 auto gid = GetGroupId(addr, uuid);
107 if (gid != kGroupUnknown) return gid;
108 group = create_group(uuid);
109 } else {
110 group = get_or_create_group_with_id(group_id, uuid);
111 if (!group) {
112 return kGroupUnknown;
113 }
114 }
115
116 LOG_ASSERT(group);
117
118 if (group->Contains(addr)) {
119 LOG(ERROR) << __func__ << " device " << addr
120 << " already in the group: " << group_id;
121 return group->GetGroupId();
122 }
123
124 add_to_group(addr, group);
125
126 btif_storage_add_groups(addr);
127 return group->GetGroupId();
128 }
129
RemoveDevice(const RawAddress & addr,int group_id)130 void RemoveDevice(const RawAddress& addr, int group_id) override {
131 int num_of_groups_dev_belongs = 0;
132
133 /* Remove from all the groups. Usually happens on unbond */
134 for (auto it = groups_.begin(); it != groups_.end();) {
135 auto& [id, g] = *it;
136 if (!g.Contains(addr)) {
137 ++it;
138 continue;
139 }
140
141 num_of_groups_dev_belongs++;
142
143 if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
144 ++it;
145 continue;
146 }
147
148 num_of_groups_dev_belongs--;
149
150 g.Remove(addr);
151 for (auto c : callbacks_) {
152 c->OnGroupMemberRemoved(addr, id);
153 }
154
155 if (g.Size() == 0) {
156 for (auto c : callbacks_) {
157 c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
158 }
159 it = groups_.erase(it);
160 } else {
161 ++it;
162 }
163 }
164
165 btif_storage_remove_groups(addr);
166 if (num_of_groups_dev_belongs > 0) {
167 btif_storage_add_groups(addr);
168 }
169 }
170
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const171 bool SerializeGroups(const RawAddress& addr,
172 std::vector<uint8_t>& out) const {
173 auto num_groups = std::count_if(
174 groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
175 return id_group_pair.second.Contains(addr);
176 });
177 if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
178 return false;
179
180 out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
181 auto* ptr = out.data();
182
183 /* header */
184 UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
185 UINT8_TO_STREAM(ptr, num_groups);
186
187 /* group entries */
188 for (const auto& [id, g] : groups_) {
189 if (g.Contains(addr)) {
190 UINT8_TO_STREAM(ptr, id);
191
192 Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
193 memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
194 ptr += Uuid::kNumBytes128;
195 }
196 }
197
198 return true;
199 }
200
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)201 void DeserializeGroups(const RawAddress& addr,
202 const std::vector<uint8_t>& in) {
203 if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
204
205 auto* ptr = in.data();
206
207 uint8_t magic;
208 STREAM_TO_UINT8(magic, ptr);
209
210 if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
211 uint8_t num_groups;
212 STREAM_TO_UINT8(num_groups, ptr);
213
214 if (in.size() <
215 GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
216 LOG(ERROR) << "Invalid persistent storage data";
217 return;
218 }
219
220 /* group entries */
221 while (num_groups--) {
222 uint8_t id;
223 STREAM_TO_UINT8(id, ptr);
224
225 Uuid::UUID128Bit uuid128;
226 STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
227
228 auto* group =
229 get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
230 if (group) add_to_group(addr, group);
231
232 for (auto c : callbacks_) {
233 c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
234 }
235 }
236 }
237 }
238
AddCallbacks(DeviceGroupsCallbacks * callbacks)239 void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
240 callbacks_.push_back(std::move(callbacks));
241
242 /* Notify new user about known groups */
243 for (const auto& [id, g] : groups_) {
244 auto group_uuid = g.GetUuid();
245 auto group_id = g.GetGroupId();
246 g.ForEachDevice([&](auto& dev) {
247 callbacks->OnGroupAdded(dev, group_uuid, group_id);
248 });
249 }
250 }
251
Clear(DeviceGroupsCallbacks * callbacks)252 bool Clear(DeviceGroupsCallbacks* callbacks) {
253 auto it = find_if(callbacks_.begin(), callbacks_.end(),
254 [callbacks](auto c) { return c == callbacks; });
255
256 if (it != callbacks_.end()) callbacks_.erase(it);
257
258 if (callbacks_.size() != 0) {
259 return false;
260 }
261 /* When all clients were unregistered */
262 groups_.clear();
263 return true;
264 }
265
Dump(int fd)266 void Dump(int fd) {
267 std::stringstream stream;
268
269 stream << " Num. registered clients: " << callbacks_.size() << std::endl;
270 stream << " Groups:\n";
271 for (const auto& kv_pair : groups_) {
272 stream << kv_pair.second << std::endl;
273 }
274
275 dprintf(fd, "%s", stream.str().c_str());
276 }
277
278 private:
find_device_group(int group_id)279 DeviceGroup* find_device_group(int group_id) {
280 return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
281 }
282
get_or_create_group_with_id(int group_id,Uuid uuid)283 DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
284 auto group = find_device_group(group_id);
285 if (group) {
286 if (group->GetUuid() != uuid) {
287 LOG(ERROR) << __func__ << " group " << group_id
288 << " exists but for different uuid: " << group->GetUuid()
289 << ", user request uuid: " << uuid;
290 return nullptr;
291 }
292
293 LOG(INFO) << __func__ << " group already exists: " << group_id;
294 return group;
295 }
296
297 DeviceGroup new_group(group_id, uuid);
298 groups_.insert({group_id, std::move(new_group)});
299
300 return &groups_.at(group_id);
301 }
302
create_group(Uuid & uuid)303 DeviceGroup* create_group(Uuid& uuid) {
304 /* Generate new group id and return empty group */
305 /* Find first free id */
306
307 int group_id = -1;
308 for (int i = 1; i < kMaxGroupId; i++) {
309 if (groups_.count(i) == 0) {
310 group_id = i;
311 break;
312 }
313 }
314
315 if (group_id < 0) {
316 LOG(ERROR) << __func__ << " too many groups";
317 return nullptr;
318 }
319
320 DeviceGroup group(group_id, uuid);
321 groups_.insert({group_id, std::move(group)});
322
323 return &groups_.at(group_id);
324 }
325
326 std::map<int, DeviceGroup> groups_;
327 std::list<DeviceGroupsCallbacks*> callbacks_;
328 };
329
Initialize(DeviceGroupsCallbacks * callbacks)330 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
331 if (instance == nullptr) {
332 instance = new DeviceGroupsImpl(callbacks);
333 return;
334 }
335
336 instance->AddCallbacks(callbacks);
337 }
338
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)339 void DeviceGroups::AddFromStorage(const RawAddress& addr,
340 const std::vector<uint8_t>& in) {
341 if (!instance) {
342 LOG(ERROR) << __func__ << ": Not initialized yet";
343 return;
344 }
345
346 instance->DeserializeGroups(addr, in);
347 }
348
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)349 bool DeviceGroups::GetForStorage(const RawAddress& addr,
350 std::vector<uint8_t>& out) {
351 if (!instance) {
352 LOG(ERROR) << __func__ << ": Not initialized yet";
353 return false;
354 }
355
356 return instance->SerializeGroups(addr, out);
357 }
358
CleanUp(DeviceGroupsCallbacks * callbacks)359 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
360 if (!instance) return;
361
362 if (instance->Clear(callbacks)) {
363 delete (instance);
364 instance = nullptr;
365 }
366 }
367
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)368 std::ostream& operator<<(std::ostream& out,
369 bluetooth::groups::DeviceGroup const& group) {
370 out << " == Group id: " << group.group_id_ << " == \n"
371 << " Uuid: " << group.group_uuid_ << std::endl;
372 out << " Devices:\n";
373 for (auto const& addr : group.devices_) {
374 out << " " << addr << std::endl;
375 }
376 return out;
377 }
378
DebugDump(int fd)379 void DeviceGroups::DebugDump(int fd) {
380 dprintf(fd, "Device Groups Manager:\n");
381 if (instance)
382 instance->Dump(fd);
383 else
384 dprintf(fd, " Not initialized \n");
385 }
386
Get()387 DeviceGroups* DeviceGroups::Get() { return instance; }
388
389 } // namespace groups
390 } // namespace bluetooth
391