• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include <base/logging.h>
19 
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <mutex>
24 #include <unordered_set>
25 
26 #include "bta_groups.h"
27 #include "btif_profile_storage.h"
28 #include "types/bluetooth/uuid.h"
29 #include "types/raw_address.h"
30 
31 using bluetooth::Uuid;
32 
33 namespace bluetooth {
34 namespace groups {
35 
class DeviceGroupsImpl;

/* Singleton implementation. DeviceGroups::Initialize() allocates it and
 * DeviceGroups::CleanUp() deletes it (resetting it to nullptr) once the last
 * client unregisters. */
DeviceGroupsImpl* instance = nullptr;

/* Taken by Initialize(), CleanUp() and DebugDump() while they access
 * |instance|. */
std::mutex instance_mutex;

/* Exclusive upper bound of the ids handed out by automatic group creation. */
static constexpr int kMaxGroupId = 0xEF;
40 
41 class DeviceGroup {
42  public:
DeviceGroup(int group_id,Uuid uuid)43   DeviceGroup(int group_id, Uuid uuid)
44       : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)45   void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)46   void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const47   bool Contains(const RawAddress& addr) const {
48     return (devices_.count(addr) != 0);
49   }
50 
ForEachDevice(std::function<void (const RawAddress &)> cb) const51   void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
52     for (auto const& addr : devices_) {
53       cb(addr);
54     }
55   }
56 
Size(void) const57   int Size(void) const { return devices_.size(); }
GetGroupId(void) const58   int GetGroupId(void) const { return group_id_; }
GetUuid(void) const59   const Uuid& GetUuid(void) const { return group_uuid_; }
60 
61  private:
62   friend std::ostream& operator<<(std::ostream& out,
63                                   const bluetooth::groups::DeviceGroup& value);
64   int group_id_;
65   Uuid group_uuid_;
66   std::unordered_set<RawAddress> devices_;
67 };
68 
69 class DeviceGroupsImpl : public DeviceGroups {
70   static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
71   static constexpr size_t GROUP_STORAGE_HEADER_SZ =
72       sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
73       sizeof(uint8_t); /* num_of_groups */
74   static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
75       sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
76 
77  public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)78   DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
79     AddCallbacks(callbacks);
80     btif_storage_load_bonded_groups();
81   }
82 
GetGroupId(const RawAddress & addr,Uuid uuid) const83   int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
84     for (const auto& [id, g] : groups_) {
85       if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
86     }
87     return kGroupUnknown;
88   }
89 
add_to_group(const RawAddress & addr,DeviceGroup * group)90   void add_to_group(const RawAddress& addr, DeviceGroup* group) {
91     group->Add(addr);
92 
93     bool first_device_in_group = (group->Size() == 1);
94 
95     for (auto c : callbacks_) {
96       if (first_device_in_group) {
97         c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
98       } else {
99         c->OnGroupMemberAdded(addr, group->GetGroupId());
100       }
101     }
102   }
103 
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)104   int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
105     DeviceGroup* group = nullptr;
106 
107     if (group_id == kGroupUnknown) {
108       auto gid = GetGroupId(addr, uuid);
109       if (gid != kGroupUnknown) return gid;
110       group = create_group(uuid);
111     } else {
112       group = get_or_create_group_with_id(group_id, uuid);
113       if (!group) {
114         return kGroupUnknown;
115       }
116     }
117 
118     LOG_ASSERT(group);
119 
120     if (group->Contains(addr)) {
121       LOG(ERROR) << __func__ << " device " << ADDRESS_TO_LOGGABLE_STR(addr)
122                  << " already in the group: " << group_id;
123       return group->GetGroupId();
124     }
125 
126     add_to_group(addr, group);
127 
128     btif_storage_add_groups(addr);
129     return group->GetGroupId();
130   }
131 
RemoveDevice(const RawAddress & addr,int group_id)132   void RemoveDevice(const RawAddress& addr, int group_id) override {
133     int num_of_groups_dev_belongs = 0;
134 
135     /* Remove from all the groups. Usually happens on unbond */
136     for (auto it = groups_.begin(); it != groups_.end();) {
137       auto& [id, g] = *it;
138       if (!g.Contains(addr)) {
139         ++it;
140         continue;
141       }
142 
143       num_of_groups_dev_belongs++;
144 
145       if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
146         ++it;
147         continue;
148       }
149 
150       num_of_groups_dev_belongs--;
151 
152       g.Remove(addr);
153       for (auto c : callbacks_) {
154         c->OnGroupMemberRemoved(addr, id);
155       }
156 
157       if (g.Size() == 0) {
158         for (auto c : callbacks_) {
159           c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
160         }
161         it = groups_.erase(it);
162       } else {
163         ++it;
164       }
165     }
166 
167     btif_storage_remove_groups(addr);
168     if (num_of_groups_dev_belongs > 0) {
169       btif_storage_add_groups(addr);
170     }
171   }
172 
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const173   bool SerializeGroups(const RawAddress& addr,
174                        std::vector<uint8_t>& out) const {
175     auto num_groups = std::count_if(
176         groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
177           return id_group_pair.second.Contains(addr);
178         });
179     if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
180       return false;
181 
182     out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
183     auto* ptr = out.data();
184 
185     /* header */
186     UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
187     UINT8_TO_STREAM(ptr, num_groups);
188 
189     /* group entries */
190     for (const auto& [id, g] : groups_) {
191       if (g.Contains(addr)) {
192         UINT8_TO_STREAM(ptr, id);
193 
194         Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
195         memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
196         ptr += Uuid::kNumBytes128;
197       }
198     }
199 
200     return true;
201   }
202 
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)203   void DeserializeGroups(const RawAddress& addr,
204                          const std::vector<uint8_t>& in) {
205     if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
206 
207     auto* ptr = in.data();
208 
209     uint8_t magic;
210     STREAM_TO_UINT8(magic, ptr);
211 
212     if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
213       uint8_t num_groups;
214       STREAM_TO_UINT8(num_groups, ptr);
215 
216       if (in.size() <
217           GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
218         LOG(ERROR) << "Invalid persistent storage data";
219         return;
220       }
221 
222       /* group entries */
223       while (num_groups--) {
224         uint8_t id;
225         STREAM_TO_UINT8(id, ptr);
226 
227         Uuid::UUID128Bit uuid128;
228         STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
229 
230         auto* group =
231             get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
232         if (group) add_to_group(addr, group);
233 
234         for (auto c : callbacks_) {
235           c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
236         }
237       }
238     }
239   }
240 
AddCallbacks(DeviceGroupsCallbacks * callbacks)241   void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
242     callbacks_.push_back(std::move(callbacks));
243 
244     /* Notify new user about known groups */
245     for (const auto& [id, g] : groups_) {
246       auto group_uuid = g.GetUuid();
247       auto group_id = g.GetGroupId();
248       g.ForEachDevice([&](auto& dev) {
249         callbacks->OnGroupAdded(dev, group_uuid, group_id);
250       });
251     }
252   }
253 
Clear(DeviceGroupsCallbacks * callbacks)254   bool Clear(DeviceGroupsCallbacks* callbacks) {
255     auto it = find_if(callbacks_.begin(), callbacks_.end(),
256                       [callbacks](auto c) { return c == callbacks; });
257 
258     if (it != callbacks_.end()) callbacks_.erase(it);
259 
260     if (callbacks_.size() != 0) {
261       return false;
262     }
263     /* When all clients were unregistered */
264     groups_.clear();
265     return true;
266   }
267 
Dump(int fd)268   void Dump(int fd) {
269     std::stringstream stream;
270 
271     stream << "  Num. registered clients: " << callbacks_.size() << std::endl;
272     stream << "  Groups:\n";
273     for (const auto& kv_pair : groups_) {
274       stream << kv_pair.second << std::endl;
275     }
276 
277     dprintf(fd, "%s", stream.str().c_str());
278   }
279 
280  private:
find_device_group(int group_id)281   DeviceGroup* find_device_group(int group_id) {
282     return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
283   }
284 
get_or_create_group_with_id(int group_id,Uuid uuid)285   DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
286     auto group = find_device_group(group_id);
287     if (group) {
288       if (group->GetUuid() != uuid) {
289         LOG(ERROR) << __func__ << " group " << group_id
290                    << " exists but for different uuid: " << group->GetUuid()
291                    << ", user request uuid: " << uuid;
292         return nullptr;
293       }
294 
295       LOG(INFO) << __func__ << " group already exists: " << group_id;
296       return group;
297     }
298 
299     DeviceGroup new_group(group_id, uuid);
300     groups_.insert({group_id, std::move(new_group)});
301 
302     return &groups_.at(group_id);
303   }
304 
create_group(Uuid & uuid)305   DeviceGroup* create_group(Uuid& uuid) {
306     /* Generate new group id and return empty group */
307     /* Find first free id */
308 
309     int group_id = -1;
310     for (int i = 1; i < kMaxGroupId; i++) {
311       if (groups_.count(i) == 0) {
312         group_id = i;
313         break;
314       }
315     }
316 
317     if (group_id < 0) {
318       LOG(ERROR) << __func__ << " too many groups";
319       return nullptr;
320     }
321 
322     DeviceGroup group(group_id, uuid);
323     groups_.insert({group_id, std::move(group)});
324 
325     return &groups_.at(group_id);
326   }
327 
328   std::map<int, DeviceGroup> groups_;
329   std::list<DeviceGroupsCallbacks*> callbacks_;
330 };
331 
Initialize(DeviceGroupsCallbacks * callbacks)332 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
333   std::scoped_lock<std::mutex> lock(instance_mutex);
334   if (instance == nullptr) {
335     instance = new DeviceGroupsImpl(callbacks);
336     return;
337   }
338 
339   instance->AddCallbacks(callbacks);
340 }
341 
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)342 void DeviceGroups::AddFromStorage(const RawAddress& addr,
343                                   const std::vector<uint8_t>& in) {
344   if (!instance) {
345     LOG(ERROR) << __func__ << ": Not initialized yet";
346     return;
347   }
348 
349   instance->DeserializeGroups(addr, in);
350 }
351 
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)352 bool DeviceGroups::GetForStorage(const RawAddress& addr,
353                                  std::vector<uint8_t>& out) {
354   if (!instance) {
355     LOG(ERROR) << __func__ << ": Not initialized yet";
356     return false;
357   }
358 
359   return instance->SerializeGroups(addr, out);
360 }
361 
CleanUp(DeviceGroupsCallbacks * callbacks)362 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
363   std::scoped_lock<std::mutex> lock(instance_mutex);
364   if (!instance) return;
365 
366   if (instance->Clear(callbacks)) {
367     delete (instance);
368     instance = nullptr;
369   }
370 }
371 
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)372 std::ostream& operator<<(std::ostream& out,
373                          bluetooth::groups::DeviceGroup const& group) {
374   out << "    == Group id: " << group.group_id_ << " == \n"
375       << "      Uuid: " << group.group_uuid_ << std::endl;
376   out << "      Devices:\n";
377   for (auto const& addr : group.devices_) {
378     out << "        " << ADDRESS_TO_LOGGABLE_STR(addr) << std::endl;
379   }
380   return out;
381 }
382 
DebugDump(int fd)383 void DeviceGroups::DebugDump(int fd) {
384   std::scoped_lock<std::mutex> lock(instance_mutex);
385   dprintf(fd, "Device Groups Manager:\n");
386   if (instance)
387     instance->Dump(fd);
388   else
389     dprintf(fd, "  Not initialized \n");
390 }
391 
Get()392 DeviceGroups* DeviceGroups::Get() { return instance; }
393 
394 }  // namespace groups
395 }  // namespace bluetooth
396