• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include <base/logging.h>
19 
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <unordered_set>
24 
25 #include "bta_groups.h"
26 #include "btif_storage.h"
27 #include "types/bluetooth/uuid.h"
28 #include "types/raw_address.h"
29 
30 using bluetooth::Uuid;
31 
32 namespace bluetooth {
33 namespace groups {
34 
class DeviceGroupsImpl;

/* Module-wide singleton. DeviceGroups::Initialize() allocates it on first use
 * and DeviceGroups::CleanUp() frees it once the last client has unregistered.
 * nullptr means "not initialized". */
DeviceGroupsImpl* instance = nullptr;

/* Automatic group-id allocation (create_group) hands out ids in the
 * half-open range [1, kMaxGroupId). */
static constexpr int kMaxGroupId{0xEF};
38 
39 class DeviceGroup {
40  public:
DeviceGroup(int group_id,Uuid uuid)41   DeviceGroup(int group_id, Uuid uuid)
42       : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)43   void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)44   void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const45   bool Contains(const RawAddress& addr) const {
46     return (devices_.count(addr) != 0);
47   }
48 
ForEachDevice(std::function<void (const RawAddress &)> cb) const49   void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
50     for (auto const& addr : devices_) {
51       cb(addr);
52     }
53   }
54 
Size(void) const55   int Size(void) const { return devices_.size(); }
GetGroupId(void) const56   int GetGroupId(void) const { return group_id_; }
GetUuid(void) const57   const Uuid& GetUuid(void) const { return group_uuid_; }
58 
59  private:
60   friend std::ostream& operator<<(std::ostream& out,
61                                   const bluetooth::groups::DeviceGroup& value);
62   int group_id_;
63   Uuid group_uuid_;
64   std::unordered_set<RawAddress> devices_;
65 };
66 
67 class DeviceGroupsImpl : public DeviceGroups {
68   static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
69   static constexpr size_t GROUP_STORAGE_HEADER_SZ =
70       sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
71       sizeof(uint8_t); /* num_of_groups */
72   static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
73       sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
74 
75  public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)76   DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
77     AddCallbacks(callbacks);
78     btif_storage_load_bonded_groups();
79   }
80 
GetGroupId(const RawAddress & addr,Uuid uuid) const81   int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
82     for (const auto& [id, g] : groups_) {
83       if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
84     }
85     return kGroupUnknown;
86   }
87 
add_to_group(const RawAddress & addr,DeviceGroup * group)88   void add_to_group(const RawAddress& addr, DeviceGroup* group) {
89     group->Add(addr);
90 
91     bool first_device_in_group = (group->Size() == 1);
92 
93     for (auto c : callbacks_) {
94       if (first_device_in_group) {
95         c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
96       } else {
97         c->OnGroupMemberAdded(addr, group->GetGroupId());
98       }
99     }
100   }
101 
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)102   int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
103     DeviceGroup* group = nullptr;
104 
105     if (group_id == kGroupUnknown) {
106       auto gid = GetGroupId(addr, uuid);
107       if (gid != kGroupUnknown) return gid;
108       group = create_group(uuid);
109     } else {
110       group = get_or_create_group_with_id(group_id, uuid);
111       if (!group) {
112         return kGroupUnknown;
113       }
114     }
115 
116     LOG_ASSERT(group);
117 
118     if (group->Contains(addr)) {
119       LOG(ERROR) << __func__ << " device " << addr
120                  << " already in the group: " << group_id;
121       return group->GetGroupId();
122     }
123 
124     add_to_group(addr, group);
125 
126     btif_storage_add_groups(addr);
127     return group->GetGroupId();
128   }
129 
RemoveDevice(const RawAddress & addr,int group_id)130   void RemoveDevice(const RawAddress& addr, int group_id) override {
131     int num_of_groups_dev_belongs = 0;
132 
133     /* Remove from all the groups. Usually happens on unbond */
134     for (auto it = groups_.begin(); it != groups_.end();) {
135       auto& [id, g] = *it;
136       if (!g.Contains(addr)) {
137         ++it;
138         continue;
139       }
140 
141       num_of_groups_dev_belongs++;
142 
143       if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
144         ++it;
145         continue;
146       }
147 
148       num_of_groups_dev_belongs--;
149 
150       g.Remove(addr);
151       for (auto c : callbacks_) {
152         c->OnGroupMemberRemoved(addr, id);
153       }
154 
155       if (g.Size() == 0) {
156         for (auto c : callbacks_) {
157           c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
158         }
159         it = groups_.erase(it);
160       } else {
161         ++it;
162       }
163     }
164 
165     btif_storage_remove_groups(addr);
166     if (num_of_groups_dev_belongs > 0) {
167       btif_storage_add_groups(addr);
168     }
169   }
170 
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const171   bool SerializeGroups(const RawAddress& addr,
172                        std::vector<uint8_t>& out) const {
173     auto num_groups = std::count_if(
174         groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
175           return id_group_pair.second.Contains(addr);
176         });
177     if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
178       return false;
179 
180     out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
181     auto* ptr = out.data();
182 
183     /* header */
184     UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
185     UINT8_TO_STREAM(ptr, num_groups);
186 
187     /* group entries */
188     for (const auto& [id, g] : groups_) {
189       if (g.Contains(addr)) {
190         UINT8_TO_STREAM(ptr, id);
191 
192         Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
193         memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
194         ptr += Uuid::kNumBytes128;
195       }
196     }
197 
198     return true;
199   }
200 
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)201   void DeserializeGroups(const RawAddress& addr,
202                          const std::vector<uint8_t>& in) {
203     if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
204 
205     auto* ptr = in.data();
206 
207     uint8_t magic;
208     STREAM_TO_UINT8(magic, ptr);
209 
210     if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
211       uint8_t num_groups;
212       STREAM_TO_UINT8(num_groups, ptr);
213 
214       if (in.size() <
215           GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
216         LOG(ERROR) << "Invalid persistent storage data";
217         return;
218       }
219 
220       /* group entries */
221       while (num_groups--) {
222         uint8_t id;
223         STREAM_TO_UINT8(id, ptr);
224 
225         Uuid::UUID128Bit uuid128;
226         STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
227 
228         auto* group =
229             get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
230         if (group) add_to_group(addr, group);
231 
232         for (auto c : callbacks_) {
233           c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
234         }
235       }
236     }
237   }
238 
AddCallbacks(DeviceGroupsCallbacks * callbacks)239   void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
240     callbacks_.push_back(std::move(callbacks));
241 
242     /* Notify new user about known groups */
243     for (const auto& [id, g] : groups_) {
244       auto group_uuid = g.GetUuid();
245       auto group_id = g.GetGroupId();
246       g.ForEachDevice([&](auto& dev) {
247         callbacks->OnGroupAdded(dev, group_uuid, group_id);
248       });
249     }
250   }
251 
Clear(DeviceGroupsCallbacks * callbacks)252   bool Clear(DeviceGroupsCallbacks* callbacks) {
253     auto it = find_if(callbacks_.begin(), callbacks_.end(),
254                       [callbacks](auto c) { return c == callbacks; });
255 
256     if (it != callbacks_.end()) callbacks_.erase(it);
257 
258     if (callbacks_.size() != 0) {
259       return false;
260     }
261     /* When all clients were unregistered */
262     groups_.clear();
263     return true;
264   }
265 
Dump(int fd)266   void Dump(int fd) {
267     std::stringstream stream;
268 
269     stream << "  Num. registered clients: " << callbacks_.size() << std::endl;
270     stream << "  Groups:\n";
271     for (const auto& kv_pair : groups_) {
272       stream << kv_pair.second << std::endl;
273     }
274 
275     dprintf(fd, "%s", stream.str().c_str());
276   }
277 
278  private:
find_device_group(int group_id)279   DeviceGroup* find_device_group(int group_id) {
280     return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
281   }
282 
get_or_create_group_with_id(int group_id,Uuid uuid)283   DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
284     auto group = find_device_group(group_id);
285     if (group) {
286       if (group->GetUuid() != uuid) {
287         LOG(ERROR) << __func__ << " group " << group_id
288                    << " exists but for different uuid: " << group->GetUuid()
289                    << ", user request uuid: " << uuid;
290         return nullptr;
291       }
292 
293       LOG(INFO) << __func__ << " group already exists: " << group_id;
294       return group;
295     }
296 
297     DeviceGroup new_group(group_id, uuid);
298     groups_.insert({group_id, std::move(new_group)});
299 
300     return &groups_.at(group_id);
301   }
302 
create_group(Uuid & uuid)303   DeviceGroup* create_group(Uuid& uuid) {
304     /* Generate new group id and return empty group */
305     /* Find first free id */
306 
307     int group_id = -1;
308     for (int i = 1; i < kMaxGroupId; i++) {
309       if (groups_.count(i) == 0) {
310         group_id = i;
311         break;
312       }
313     }
314 
315     if (group_id < 0) {
316       LOG(ERROR) << __func__ << " too many groups";
317       return nullptr;
318     }
319 
320     DeviceGroup group(group_id, uuid);
321     groups_.insert({group_id, std::move(group)});
322 
323     return &groups_.at(group_id);
324   }
325 
326   std::map<int, DeviceGroup> groups_;
327   std::list<DeviceGroupsCallbacks*> callbacks_;
328 };
329 
Initialize(DeviceGroupsCallbacks * callbacks)330 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
331   if (instance == nullptr) {
332     instance = new DeviceGroupsImpl(callbacks);
333     return;
334   }
335 
336   instance->AddCallbacks(callbacks);
337 }
338 
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)339 void DeviceGroups::AddFromStorage(const RawAddress& addr,
340                                   const std::vector<uint8_t>& in) {
341   if (!instance) {
342     LOG(ERROR) << __func__ << ": Not initialized yet";
343     return;
344   }
345 
346   instance->DeserializeGroups(addr, in);
347 }
348 
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)349 bool DeviceGroups::GetForStorage(const RawAddress& addr,
350                                  std::vector<uint8_t>& out) {
351   if (!instance) {
352     LOG(ERROR) << __func__ << ": Not initialized yet";
353     return false;
354   }
355 
356   return instance->SerializeGroups(addr, out);
357 }
358 
CleanUp(DeviceGroupsCallbacks * callbacks)359 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
360   if (!instance) return;
361 
362   if (instance->Clear(callbacks)) {
363     delete (instance);
364     instance = nullptr;
365   }
366 }
367 
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)368 std::ostream& operator<<(std::ostream& out,
369                          bluetooth::groups::DeviceGroup const& group) {
370   out << "    == Group id: " << group.group_id_ << " == \n"
371       << "      Uuid: " << group.group_uuid_ << std::endl;
372   out << "      Devices:\n";
373   for (auto const& addr : group.devices_) {
374     out << "        " << addr << std::endl;
375   }
376   return out;
377 }
378 
DebugDump(int fd)379 void DeviceGroups::DebugDump(int fd) {
380   dprintf(fd, "Device Groups Manager:\n");
381   if (instance)
382     instance->Dump(fd);
383   else
384     dprintf(fd, "  Not initialized \n");
385 }
386 
Get()387 DeviceGroups* DeviceGroups::Get() { return instance; }
388 
389 }  // namespace groups
390 }  // namespace bluetooth
391