• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "chre/util/system/message_router.h"

#include "chre/util/dynamic_vector.h"
#include "chre/util/system/callback_allocator.h"
#include "chre/util/system/message_common.h"
#include "chre/util/system/message_router_mocks.h"
#include "chre_api/chre.h"

#include "pw_allocator/libc_allocator.h"
#include "pw_allocator/unique_ptr.h"
#include "pw_intrusive_ptr/intrusive_ptr.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <set>
#include <utility>

using ::testing::_;
38 
39 namespace chre::message {
40 namespace {
41 
42 constexpr size_t kMaxMessageHubs = 3;
43 constexpr size_t kMaxSessions = 10;
44 constexpr size_t kMaxFreeCallbackRecords = kMaxSessions * 2;
45 constexpr size_t kNumEndpoints = 3;
46 
47 const EndpointInfo kEndpointInfos[kNumEndpoints] = {
48     EndpointInfo(/* id= */ 1, /* name= */ "endpoint1", /* version= */ 1,
49                  EndpointType::NANOAPP, CHRE_MESSAGE_PERMISSION_NONE),
50     EndpointInfo(/* id= */ 2, /* name= */ "endpoint2", /* version= */ 10,
51                  EndpointType::HOST_NATIVE, CHRE_MESSAGE_PERMISSION_BLE),
52     EndpointInfo(/* id= */ 3, /* name= */ "endpoint3", /* version= */ 100,
53                  EndpointType::GENERIC, CHRE_MESSAGE_PERMISSION_AUDIO)};
54 const char kServiceDescriptorForEndpoint2[] = "TEST_SERVICE.TEST";
55 
56 class MessageRouterTest : public ::testing::Test {};
57 
58 //! Iterates over the endpoints
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)59 void forEachEndpoint(const pw::Function<bool(const EndpointInfo &)> &function) {
60   for (const EndpointInfo &endpointInfo : kEndpointInfos) {
61     if (function(endpointInfo)) {
62       return;
63     }
64   }
65 }
66 
67 //! Base class for MessageHubCallbacks used in tests
68 class MessageHubCallbackBase : public MessageRouter::MessageHubCallback {
69  public:
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)70   void forEachEndpoint(
71       const pw::Function<bool(const EndpointInfo &)> &function) override {
72     ::chre::message::forEachEndpoint(function);
73   }
74 
getEndpointInfo(EndpointId endpointId)75   std::optional<EndpointInfo> getEndpointInfo(EndpointId endpointId) override {
76     for (const EndpointInfo &endpointInfo : kEndpointInfos) {
77       if (endpointInfo.id == endpointId) {
78         return endpointInfo;
79       }
80     }
81     return std::nullopt;
82   }
83 
onSessionOpenRequest(const Session &)84   void onSessionOpenRequest(const Session & /* session */) override {}
85 
getEndpointForService(const char * serviceDescriptor)86   std::optional<EndpointId> getEndpointForService(
87       const char *serviceDescriptor) override {
88     if (serviceDescriptor != nullptr &&
89         std::strcmp(serviceDescriptor, kServiceDescriptorForEndpoint2) == 0) {
90       return kEndpointInfos[1].id;
91     }
92     return std::nullopt;
93   }
94 
doesEndpointHaveService(EndpointId endpointId,const char * serviceDescriptor)95   bool doesEndpointHaveService(EndpointId endpointId,
96                                const char *serviceDescriptor) override {
97     return serviceDescriptor != nullptr && endpointId == kEndpointInfos[1].id &&
98            std::strcmp(serviceDescriptor, kServiceDescriptorForEndpoint2) == 0;
99   }
100 
forEachService(const pw::Function<bool (const EndpointInfo &,const ServiceInfo &)> & function)101   void forEachService(
102       const pw::Function<bool(const EndpointInfo &, const ServiceInfo &)>
103           &function) override {
104     function(kEndpointInfos[1],
105              ServiceInfo(kServiceDescriptorForEndpoint2, /* majorVersion= */ 1,
106                          /* minorVersion= */ 0, RpcFormat::CUSTOM));
107   }
108 
onHubRegistered(const MessageHubInfo &)109   void onHubRegistered(const MessageHubInfo & /* info */) override {}
110 
onHubUnregistered(MessageHubId)111   void onHubUnregistered(MessageHubId /* id */) override {}
112 
onEndpointRegistered(MessageHubId,EndpointId)113   void onEndpointRegistered(MessageHubId /* messageHubId */,
114                             EndpointId /* endpointId */) override {}
115 
onEndpointUnregistered(MessageHubId,EndpointId)116   void onEndpointUnregistered(MessageHubId /* messageHubId */,
117                               EndpointId /* endpointId */) override {}
118 
pw_recycle()119   void pw_recycle() override {
120     delete this;
121   }
122 };
123 
124 //! MessageHubCallback that stores the data passed to onMessageReceived and
125 //! onSessionClosed
126 class MessageHubCallbackStoreData : public MessageHubCallbackBase {
127  public:
MessageHubCallbackStoreData(Message * message,Session * session,Reason * reason=nullptr,Session * openedSession=nullptr)128   MessageHubCallbackStoreData(Message *message, Session *session,
129                               Reason *reason = nullptr,
130                               Session *openedSession = nullptr)
131       : mMessage(message),
132         mSession(session),
133         mReason(reason),
134         mOpenedSession(openedSession) {}
135 
onMessageReceived(pw::UniquePtr<std::byte[]> && data,uint32_t messageType,uint32_t messagePermissions,const Session & session,bool sentBySessionInitiator)136   bool onMessageReceived(pw::UniquePtr<std::byte[]> &&data,
137                          uint32_t messageType, uint32_t messagePermissions,
138                          const Session &session,
139                          bool sentBySessionInitiator) override {
140     if (mMessage != nullptr) {
141       mMessage->sender = sentBySessionInitiator ? session.initiator
142                                                 : session.peer;
143       mMessage->recipient =
144           sentBySessionInitiator ? session.peer : session.initiator;
145       mMessage->sessionId = session.sessionId;
146       mMessage->data = std::move(data);
147       mMessage->messageType = messageType;
148       mMessage->messagePermissions = messagePermissions;
149     }
150     return true;
151   }
152 
onSessionClosed(const Session & session,Reason reason)153   void onSessionClosed(const Session &session, Reason reason) override {
154     if (mSession != nullptr) {
155       *mSession = session;
156     }
157     if (mReason != nullptr) {
158       *mReason = reason;
159     }
160   }
161 
onSessionOpened(const Session & session)162   void onSessionOpened(const Session &session) override {
163     if (mOpenedSession != nullptr) {
164       *mOpenedSession = session;
165     }
166   }
167 
onEndpointRegistered(MessageHubId messageHubId,EndpointId endpointId)168   void onEndpointRegistered(MessageHubId messageHubId,
169                             EndpointId endpointId) override {
170     mRegisteredEndpoints.insert(std::make_pair(messageHubId, endpointId));
171   }
172 
onEndpointUnregistered(MessageHubId messageHubId,EndpointId endpointId)173   void onEndpointUnregistered(MessageHubId messageHubId,
174                               EndpointId endpointId) override {
175     mRegisteredEndpoints.erase(std::make_pair(messageHubId, endpointId));
176   }
177 
hasEndpointBeenRegistered(MessageHubId messageHubId,EndpointId endpointId)178   bool hasEndpointBeenRegistered(MessageHubId messageHubId,
179                                  EndpointId endpointId) {
180     return mRegisteredEndpoints.find(std::make_pair(
181                messageHubId, endpointId)) != mRegisteredEndpoints.end();
182   }
183 
184  private:
185   Message *mMessage;
186   Session *mSession;
187   Reason *mReason;
188   Session *mOpenedSession;
189   std::set<std::pair<MessageHubId, EndpointId>> mRegisteredEndpoints;
190 };
191 
192 //! MessageHubCallback that always fails to process messages
193 class MessageHubCallbackAlwaysFails : public MessageHubCallbackBase {
194  public:
MessageHubCallbackAlwaysFails(bool * wasMessageReceivedCalled,bool * wasSessionClosedCalled)195   MessageHubCallbackAlwaysFails(bool *wasMessageReceivedCalled,
196                                 bool *wasSessionClosedCalled)
197       : mWasMessageReceivedCalled(wasMessageReceivedCalled),
198         mWasSessionClosedCalled(wasSessionClosedCalled) {}
199 
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)200   bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
201                          uint32_t /* messageType */,
202                          uint32_t /* messagePermissions */,
203                          const Session & /* session */,
204                          bool /* sentBySessionInitiator */) override {
205     if (mWasMessageReceivedCalled != nullptr) {
206       *mWasMessageReceivedCalled = true;
207     }
208     return false;
209   }
210 
onSessionClosed(const Session &,Reason)211   void onSessionClosed(const Session & /* session */,
212                        Reason /* reason */) override {
213     if (mWasSessionClosedCalled != nullptr) {
214       *mWasSessionClosedCalled = true;
215     }
216   }
217 
onSessionOpened(const Session &)218   void onSessionOpened(const Session & /* session */) override {}
219 
220  private:
221   bool *mWasMessageReceivedCalled;
222   bool *mWasSessionClosedCalled;
223 };
224 
225 //! MessageHubCallback that tracks open session requests calls
226 class MessageHubCallbackOpenSessionRequest : public MessageHubCallbackBase {
227  public:
MessageHubCallbackOpenSessionRequest(bool * wasSessionOpenRequestCalled)228   MessageHubCallbackOpenSessionRequest(bool *wasSessionOpenRequestCalled)
229       : mWasSessionOpenRequestCalled(wasSessionOpenRequestCalled) {}
230 
onSessionOpenRequest(const Session &)231   void onSessionOpenRequest(const Session & /* session */) override {
232     if (mWasSessionOpenRequestCalled != nullptr) {
233       *mWasSessionOpenRequestCalled = true;
234     }
235   }
236 
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)237   bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
238                          uint32_t /* messageType */,
239                          uint32_t /* messagePermissions */,
240                          const Session & /* session */,
241                          bool /* sentBySessionInitiator */) override {
242     return true;
243   }
244 
onSessionClosed(const Session &,Reason)245   void onSessionClosed(const Session & /* session */,
246                        Reason /* reason */) override {}
247 
onSessionOpened(const Session &)248   void onSessionOpened(const Session & /* session */) override {}
249 
250  private:
251   bool *mWasSessionOpenRequestCalled;
252 };
253 
254 //! MessageHubCallback that calls MessageHub APIs during callbacks
255 class MessageHubCallbackCallsMessageHubApisDuringCallback
256     : public MessageHubCallbackBase {
257  public:
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)258   bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
259                          uint32_t /* messageType */,
260                          uint32_t /* messagePermissions */,
261                          const Session & /* session */,
262                          bool /* sentBySessionInitiator */) override {
263     if (mMessageHub != nullptr) {
264       // Call a function that locks the MessageRouter mutex
265       mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
266                                kEndpointInfos[1].id);
267     }
268     return true;
269   }
270 
onSessionClosed(const Session &,Reason)271   void onSessionClosed(const Session & /* session */,
272                        Reason /* reason */) override {
273     if (mMessageHub != nullptr) {
274       // Call a function that locks the MessageRouter mutex
275       mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
276                                kEndpointInfos[1].id);
277     }
278   }
279 
onSessionOpened(const Session &)280   void onSessionOpened(const Session & /* session */) override {
281     if (mMessageHub != nullptr) {
282       // Call a function that locks the MessageRouter mutex
283       mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
284                                kEndpointInfos[1].id);
285     }
286   }
287 
setMessageHub(MessageRouter::MessageHub * messageHub)288   void setMessageHub(MessageRouter::MessageHub *messageHub) {
289     mMessageHub = messageHub;
290   }
291 
292  private:
293   MessageRouter::MessageHub *mMessageHub = nullptr;
294 };
295 
//! Registering a second hub under an existing name must fail, even when the
//! requested ID is unused.
TEST_F(MessageRouterTest, RegisterMessageHubNameIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  EXPECT_TRUE(messageHub2.has_value());

  // Reuse only the name: the ID (3) is fresh and the router still has
  // capacity (kMaxMessageHubs == 3). The rejection can therefore only come
  // from the duplicate name.
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub1", /* id= */ 3, callback);
  EXPECT_FALSE(messageHub3.has_value());
}
313 
//! Registering a second hub under an existing ID must fail, even with a new
//! name.
TEST_F(MessageRouterTest, RegisterMessageHubIdIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> sharedCallback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // Two hubs with distinct names and IDs; both stay alive until the test ends.
  std::optional<MessageRouter::MessageHub> hubOne =
      router.registerMessageHub("hub1", /* id= */ 1, sharedCallback);
  EXPECT_TRUE(hubOne.has_value());
  std::optional<MessageRouter::MessageHub> hubTwo =
      router.registerMessageHub("hub2", /* id= */ 2, sharedCallback);
  EXPECT_TRUE(hubTwo.has_value());

  // New name, already-used ID: must be rejected.
  std::optional<MessageRouter::MessageHub> hubWithDuplicateId =
      router.registerMessageHub("hub3", /* id= */ 1, sharedCallback);
  EXPECT_FALSE(hubWithDuplicateId.has_value());
}
331 
//! forEachMessageHub reports every registered hub, in registration order.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): the hubs are dereferenced below, and dereferencing
  // an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT the size before indexing into messageHubs.
  ASSERT_EQ(messageHubs.size(), 3u);
  // Compare hub names by content (EXPECT_STREQ), not by pointer identity.
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());
}
365 
//! Destroying a MessageHub removes it from the list reported by
//! forEachMessageHub while the other hubs keep their order.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubsWithUnregister) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): the hubs are dereferenced below, and dereferencing
  // an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT the size before indexing; compare names by content (EXPECT_STREQ),
  // not by pointer identity.
  ASSERT_EQ(messageHubs.size(), 3u);
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());

  // Clear messageHubs and reset messageHub2; hub2 is expected to disappear
  // from the router's list.
  messageHubs.clear();
  messageHub2.reset();

  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  ASSERT_EQ(messageHubs.size(), 2u);
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub3->getId());
}
416 
//! Once kMaxMessageHubs hubs are registered, further registrations fail.
TEST_F(MessageRouterTest, RegisterMessageHubTooManyFails) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  static_assert(kMaxMessageHubs == 3);
  // Sized by kMaxMessageHubs so the static_assert above and this table cannot
  // drift apart silently.
  constexpr const char *kNames[kMaxMessageHubs] = {"hub1", "hub2", "hub3"};

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // Fill the router to capacity, keeping every hub alive (registered).
  MessageRouter::MessageHub messageHubs[kMaxMessageHubs];
  for (size_t i = 0; i < kMaxMessageHubs; ++i) {
    std::optional<MessageRouter::MessageHub> messageHub =
        router.registerMessageHub(kNames[i], /* id= */ i, callback);
    // ASSERT: *messageHub below is undefined behavior on an empty optional.
    ASSERT_TRUE(messageHub.has_value());
    messageHubs[i] = std::move(*messageHub);
  }

  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("shouldfail", /* id= */ kMaxMessageHubs * 2,
                                callback);
  EXPECT_FALSE(messageHub.has_value());
}
438 
//! Every registered hub reports every endpoint its callback exposes.
TEST_F(MessageRouterTest, GetEndpointInfo) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): each hub is dereferenced in the loop below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  for (size_t i = 0; i < kNumEndpoints; ++i) {
    EXPECT_EQ(
        router.getEndpointInfo(messageHub1->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub2->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub3->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
  }
}
467 
//! A known service descriptor resolves to the endpoint (and hub) offering it.
TEST_F(MessageRouterTest, GetEndpointForService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): messageHub1 and endpoint are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());

  std::optional<Endpoint> endpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, kServiceDescriptorForEndpoint2);
  ASSERT_TRUE(endpoint.has_value());

  EXPECT_EQ(endpoint->messageHubId, messageHub1->getId());
  EXPECT_EQ(endpoint->endpointId, kEndpointInfos[1].id);
}
485 
//! The router reports that endpoint2 offers kServiceDescriptorForEndpoint2.
TEST_F(MessageRouterTest, DoesEndpointHaveService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): messageHub1 is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());

  EXPECT_TRUE(router.doesEndpointHaveService(messageHub1->getId(),
                                             kEndpointInfos[1].id,
                                             kServiceDescriptorForEndpoint2));
}
500 
//! forEachService visits the single service advertised by the test callback.
TEST_F(MessageRouterTest, ForEachService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());

  // Count invocations. Without this the test would pass vacuously if the
  // lambda (and therefore every EXPECT inside it) never ran.
  size_t numServicesVisited = 0;
  router.forEachService([&numServicesVisited](const MessageHubInfo &hub,
                                              const EndpointInfo &endpoint,
                                              const ServiceInfo &service) {
    ++numServicesVisited;
    EXPECT_EQ(hub.id, 1);
    EXPECT_EQ(endpoint.id, kEndpointInfos[1].id);
    EXPECT_STREQ(service.serviceDescriptor, kServiceDescriptorForEndpoint2);
    EXPECT_EQ(service.majorVersion, 1);
    EXPECT_EQ(service.minorVersion, 0);
    EXPECT_EQ(service.format, RpcFormat::CUSTOM);
    return true;
  });
  // One hub, one advertised service.
  EXPECT_EQ(numServicesVisited, 1u);
}
523 
//! Unknown and null service descriptors resolve to no endpoint.
TEST_F(MessageRouterTest, GetEndpointForServiceBadServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> storeDataCallback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> hub =
      router.registerMessageHub("hub1", /* id= */ 1, storeDataCallback);
  EXPECT_TRUE(hub.has_value());

  // A descriptor that no endpoint advertises.
  std::optional<Endpoint> unknownServiceEndpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, "SERVICE_THAT_DOES_NOT_EXIST");
  EXPECT_FALSE(unknownServiceEndpoint.has_value());

  // A null descriptor.
  std::optional<Endpoint> nullServiceEndpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, /* serviceDescriptor= */ nullptr);
  EXPECT_FALSE(nullServiceEndpoint.has_value());
}
542 
//! A session between endpoints on two hubs is visible from both hubs, and
//! closing it notifies both hubs and removes it everywhere.
TEST_F(MessageRouterTest, RegisterSessionTwoDifferentMessageHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT) on every optional/ID used afterwards. A failure then
  // ends the test instead of dereferencing an empty std::optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open a session from endpoint 1 on messageHub to endpoint 2 on messageHub2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
592 
TEST_F(MessageRouterTest,RegisterSessionVerifyAllCallbacksAreCalled)593 TEST_F(MessageRouterTest, RegisterSessionVerifyAllCallbacksAreCalled) {
594   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
595   Session sessionClosedFromCallback1;
596   Session sessionClosedFromCallback2;
597   Session sessionOpenedFromCallback1;
598   Session sessionOpenedFromCallback2;
599   Reason sessionCloseReason1;
600   Reason sessionCloseReason2;
601   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
602       pw::MakeRefCounted<MessageHubCallbackStoreData>(
603           /* message= */ nullptr, &sessionClosedFromCallback1,
604           &sessionCloseReason1, &sessionOpenedFromCallback1);
605   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
606       pw::MakeRefCounted<MessageHubCallbackStoreData>(
607           /* message= */ nullptr, &sessionClosedFromCallback2,
608           &sessionCloseReason2, &sessionOpenedFromCallback2);
609 
610   std::optional<MessageRouter::MessageHub> messageHub =
611       router.registerMessageHub("hub1", /* id= */ 1, callback);
612   EXPECT_TRUE(messageHub.has_value());
613   std::optional<MessageRouter::MessageHub> messageHub2 =
614       router.registerMessageHub("hub2", /* id= */ 2, callback2);
615   EXPECT_TRUE(messageHub2.has_value());
616 
617   // Open session from messageHub:1 to messageHub2:2
618   SessionId sessionId = messageHub->openSession(
619       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
620   EXPECT_NE(sessionId, SESSION_ID_INVALID);
621   messageHub2->onSessionOpenComplete(sessionId);
622 
623   // Verify that onSessionOpened is called on both message hubs
624   EXPECT_EQ(sessionOpenedFromCallback1.sessionId, sessionId);
625   EXPECT_EQ(sessionOpenedFromCallback1.initiator.messageHubId,
626             messageHub->getId());
627   EXPECT_EQ(sessionOpenedFromCallback1.initiator.endpointId,
628             kEndpointInfos[0].id);
629   EXPECT_EQ(sessionOpenedFromCallback1.peer.messageHubId, messageHub2->getId());
630   EXPECT_EQ(sessionOpenedFromCallback1.peer.endpointId, kEndpointInfos[1].id);
631 
632   EXPECT_EQ(sessionOpenedFromCallback2.sessionId, sessionId);
633   EXPECT_EQ(sessionOpenedFromCallback2.initiator.messageHubId,
634             messageHub->getId());
635   EXPECT_EQ(sessionOpenedFromCallback2.initiator.endpointId,
636             kEndpointInfos[0].id);
637   EXPECT_EQ(sessionOpenedFromCallback2.peer.messageHubId, messageHub2->getId());
638   EXPECT_EQ(sessionOpenedFromCallback2.peer.endpointId, kEndpointInfos[1].id);
639 
640   // Close the session with a reason
641   Reason reason = Reason::TIMEOUT;
642   EXPECT_TRUE(messageHub->closeSession(sessionId, reason));
643 
644   // Verify that onSessionClosed is called on both message hubs
645   EXPECT_EQ(sessionClosedFromCallback1.sessionId, sessionId);
646   EXPECT_EQ(sessionClosedFromCallback1.initiator.messageHubId,
647             messageHub->getId());
648   EXPECT_EQ(sessionClosedFromCallback1.initiator.endpointId,
649             kEndpointInfos[0].id);
650   EXPECT_EQ(sessionClosedFromCallback1.peer.messageHubId, messageHub2->getId());
651   EXPECT_EQ(sessionClosedFromCallback1.peer.endpointId, kEndpointInfos[1].id);
652   EXPECT_EQ(sessionCloseReason1, reason);
653 
654   EXPECT_EQ(sessionClosedFromCallback2.sessionId, sessionId);
655   EXPECT_EQ(sessionClosedFromCallback2.initiator.messageHubId,
656             messageHub->getId());
657   EXPECT_EQ(sessionClosedFromCallback2.initiator.endpointId,
658             kEndpointInfos[0].id);
659   EXPECT_EQ(sessionClosedFromCallback2.peer.messageHubId, messageHub2->getId());
660   EXPECT_EQ(sessionClosedFromCallback2.peer.endpointId, kEndpointInfos[1].id);
661   EXPECT_EQ(sessionCloseReason2, reason);
662 }
663 
// Verifies that when the peer hub rejects a pending open request by closing
// the session, the session is removed from both hubs and the initiating hub's
// onSessionClosed callback receives the session and the rejection reason.
TEST_F(MessageRouterTest, RegisterSessionGetsRejectedAndClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  // Deliberately initialized to a value different from the expected reason so
  // that a missing close callback fails the final check instead of reading an
  // indeterminate value.
  Reason sessionCloseReason = Reason::TIMEOUT;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(
          /* message= */ nullptr, &sessionFromCallback1, &sessionCloseReason);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // The peer hub rejects the still-pending session by closing it
  Reason reason = Reason::OPEN_ENDPOINT_SESSION_REQUEST_REJECTED;
  EXPECT_TRUE(messageHub2->closeSession(sessionId, reason));

  // The session must no longer be visible from either hub
  std::optional<Session> sessionAfterRejection =
      messageHub->getSessionWithId(sessionId);
  EXPECT_FALSE(sessionAfterRejection.has_value());
  std::optional<Session> sessionAfterRejection2 =
      messageHub2->getSessionWithId(sessionId);
  EXPECT_FALSE(sessionAfterRejection2.has_value());

  // The initiating hub is notified of the closure together with the reason
  EXPECT_EQ(sessionFromCallback1.sessionId, sessionId);
  EXPECT_EQ(sessionFromCallback1.initiator.messageHubId, messageHub->getId());
  EXPECT_EQ(sessionFromCallback1.initiator.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(sessionFromCallback1.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionFromCallback1.peer.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(sessionCloseReason, reason);
}
706 
// Verifies that when the peer hub does not answer an open request, re-issuing
// the same request returns the same (still pending) session ID. It also checks
// that the open-session request is delivered to the peer hub and not to the
// initiating hub.
TEST_F(MessageRouterTest, RegisterSessionSecondHubDoesNotRespond) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  bool wasOpenSessionRequestCalled = false;
  bool wasOpenSessionRequestCalled2 = false;
  pw::IntrusivePtr<MessageHubCallbackOpenSessionRequest> callback =
      pw::MakeRefCounted<MessageHubCallbackOpenSessionRequest>(
          &wasOpenSessionRequestCalled);
  pw::IntrusivePtr<MessageHubCallbackOpenSessionRequest> callback2 =
      pw::MakeRefCounted<MessageHubCallbackOpenSessionRequest>(
          &wasOpenSessionRequestCalled2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Message Hub 2 does not respond - only the peer's callback was invoked
  EXPECT_FALSE(wasOpenSessionRequestCalled);
  EXPECT_TRUE(wasOpenSessionRequestCalled2);

  // Open session from messageHub:1 to messageHub2:2 - try again. The retry
  // must reuse the pending session instead of allocating a new one.
  // NOTE(review): only hub 1's flag is reset here, so the EXPECT_TRUE on hub
  // 2's flag below does not by itself prove the request was re-sent.
  wasOpenSessionRequestCalled = false;
  SessionId sessionId2 = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  EXPECT_EQ(sessionId, sessionId2);
  EXPECT_FALSE(wasOpenSessionRequestCalled);
  EXPECT_TRUE(wasOpenSessionRequestCalled2);

  // Respond then close the session
  messageHub2->onSessionOpenComplete(sessionId2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
}
747 
// Verifies that a session opened with a service descriptor records that
// descriptor, is visible identically from both hubs, and is reported to both
// hubs' callbacks when closed.
TEST_F(MessageRouterTest, RegisterSessionWithServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      kServiceDescriptorForEndpoint2);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  EXPECT_TRUE(sessionAfterRegistering->hasServiceDescriptor);
  EXPECT_STREQ(sessionAfterRegistering->serviceDescriptor,
               kServiceDescriptorForEndpoint2);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs; the
  // callbacks only receive the session once it is closed.
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
800 
TEST_F(MessageRouterTest,RegisterSessionWithAndWithoutServiceDescriptorSameEndpoints)801 TEST_F(MessageRouterTest,
802        RegisterSessionWithAndWithoutServiceDescriptorSameEndpoints) {
803   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
804   Session sessionFromCallback1;
805   Session sessionFromCallback2;
806   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
807       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
808                                                       &sessionFromCallback1);
809   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
810       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
811                                                       &sessionFromCallback2);
812 
813   std::optional<MessageRouter::MessageHub> messageHub =
814       router.registerMessageHub("hub1", /* id= */ 1, callback);
815   EXPECT_TRUE(messageHub.has_value());
816   std::optional<MessageRouter::MessageHub> messageHub2 =
817       router.registerMessageHub("hub2", /* id= */ 2, callback2);
818   EXPECT_TRUE(messageHub2.has_value());
819 
820   // Open session from messageHub:1 to messageHub2:2 with service descriptor
821   SessionId sessionId = messageHub->openSession(
822       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
823       kServiceDescriptorForEndpoint2);
824   EXPECT_NE(sessionId, SESSION_ID_INVALID);
825 
826   // Open session from messageHub:1 to messageHub2:2 without service descriptor
827   SessionId sessionId2 = messageHub->openSession(
828       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
829   EXPECT_NE(sessionId2, SESSION_ID_INVALID);
830   EXPECT_NE(sessionId2, sessionId);
831 
832   // Get the first session from messageHub and compare it with messageHub2
833   std::optional<Session> sessionAfterRegistering =
834       messageHub->getSessionWithId(sessionId);
835   EXPECT_TRUE(sessionAfterRegistering.has_value());
836   EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
837   EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
838             messageHub->getId());
839   EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
840             kEndpointInfos[0].id);
841   EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
842   EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
843   EXPECT_TRUE(sessionAfterRegistering->hasServiceDescriptor);
844   EXPECT_STREQ(sessionAfterRegistering->serviceDescriptor,
845                kServiceDescriptorForEndpoint2);
846   std::optional<Session> sessionAfterRegistering2 =
847       messageHub2->getSessionWithId(sessionId);
848   EXPECT_TRUE(sessionAfterRegistering2.has_value());
849   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
850 
851   // Get the second session from messageHub and compare it with messageHub2
852   std::optional<Session> sessionAfterRegistering3 =
853       messageHub->getSessionWithId(sessionId2);
854   EXPECT_TRUE(sessionAfterRegistering3.has_value());
855   EXPECT_EQ(sessionAfterRegistering3->sessionId, sessionId2);
856   EXPECT_EQ(sessionAfterRegistering3->initiator.messageHubId,
857             messageHub->getId());
858   EXPECT_EQ(sessionAfterRegistering3->initiator.endpointId,
859             kEndpointInfos[0].id);
860   EXPECT_EQ(sessionAfterRegistering3->peer.messageHubId, messageHub2->getId());
861   EXPECT_EQ(sessionAfterRegistering3->peer.endpointId, kEndpointInfos[1].id);
862   EXPECT_FALSE(sessionAfterRegistering3->hasServiceDescriptor);
863   EXPECT_STREQ(sessionAfterRegistering3->serviceDescriptor, "");
864   std::optional<Session> sessionAfterRegistering4 =
865       messageHub2->getSessionWithId(sessionId2);
866   EXPECT_TRUE(sessionAfterRegistering4.has_value());
867   EXPECT_EQ(*sessionAfterRegistering3, *sessionAfterRegistering4);
868 }
869 
// Verifies that opening a session to an endpoint with a service descriptor
// that endpoint does not match is rejected with SESSION_ID_INVALID.
TEST_F(MessageRouterTest, RegisterSessionWithBadServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to endpoint 3 on messageHub2, using the
  // service descriptor intended for endpoint 2 (presumably not offered by
  // endpoint 3) - must be rejected.
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[2].id,
      kServiceDescriptorForEndpoint2);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
894 
// Verifies that destroying (and thereby unregistering) one hub closes its open
// sessions. The remaining hub is notified through its callback and can no
// longer find the session.
TEST_F(MessageRouterTest, UnregisterMessageHubCausesSessionClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2 and accept it
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Destroy messageHub2 and verify the session is closed on the other hub
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  messageHub2.reset();
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
}
941 
// Verifies that sessions whose initiator and peer live on the same hub are
// accepted. This covers both an endpoint talking to itself and two different
// endpoints on that hub.
TEST_F(MessageRouterTest, RegisterSessionSameMessageHubIsValid) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): messageHub is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from endpoint 2 to endpoint 2, both on messageHub:1
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[1].id, messageHub->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from endpoint 1 to endpoint 3, both on messageHub:1
  SessionId sessionId2 = messageHub->openSession(
      kEndpointInfos[0].id, messageHub->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  EXPECT_NE(sessionId2, sessionId);
}
970 
// Verifies that router-generated session IDs stay outside the caller-reserved
// range. That range starts at kReservedSessionId; explicitly requested IDs
// below it are rejected (see RegisterSessionOpenSessionNotReservedRegionRejected).
// The test opens and closes more sessions than the generated range holds, so
// ID generation must wrap around.
TEST_F(MessageRouterTest, RegisterSessionReservedSessionIdAreRespected) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub:2 more than the number of
  // non-reserved session IDs - should wrap around without ever handing out a
  // reserved ID.
  for (size_t i = 0; i < kReservedSessionId * 2; ++i) {
    SessionId sessionId = messageHub->openSession(
        kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
    EXPECT_NE(sessionId, SESSION_ID_INVALID);
    EXPECT_LT(sessionId, kReservedSessionId);
    messageHub2->onSessionOpenComplete(sessionId);
    EXPECT_TRUE(messageHub->closeSession(sessionId));
  }
}
999 
// Verifies that an explicitly requested session ID below kReservedSessionId
// (i.e. outside the caller-reserved range) is rejected.
TEST_F(MessageRouterTest, RegisterSessionOpenSessionNotReservedRegionRejected) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub:2 and provide an invalid
  // session ID (not in the reserved range)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      /* serviceDescriptor= */ nullptr, kReservedSessionId / 2);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
1025 
// Verifies that a caller may open a session with an explicit ID from the
// reserved range, and that the router uses exactly that ID.
TEST_F(MessageRouterTest, RegisterSessionOpenSessionWithReservedSessionId) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub:2 and provide a reserved
  // session ID; the router must honor it rather than generate its own.
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      /* serviceDescriptor= */ nullptr, kReservedSessionId);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  EXPECT_EQ(sessionId, kReservedSessionId);
  messageHub2->onSessionOpenComplete(sessionId);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
}
1053 
// Verifies that a session may connect two endpoints that share the same
// endpoint ID as long as they live on different hubs.
TEST_F(MessageRouterTest, RegisterSessionDifferentMessageHubsSameEndpoints) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from endpoint 1 on messageHub:1 to endpoint 1 on messageHub:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);
}
1078 
// Verifies that opening a session to an endpoint ID the peer hub does not
// expose is rejected.
TEST_F(MessageRouterTest,
       RegisterSessionTwoDifferentMessageHubsInvalidEndpoint) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub with other non-registered endpoint - not
  // valid
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), /* toEndpointId= */ 10);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
1102 
TEST_F(MessageRouterTest,ThirdMessageHubTriesToFindOthersSession)1103 TEST_F(MessageRouterTest, ThirdMessageHubTriesToFindOthersSession) {
1104   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1105   Session sessionFromCallback1;
1106   Session sessionFromCallback2;
1107   Session sessionFromCallback3;
1108   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1109       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1110                                                       &sessionFromCallback1);
1111   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1112       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1113                                                       &sessionFromCallback2);
1114   pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
1115       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1116                                                       &sessionFromCallback3);
1117 
1118   std::optional<MessageRouter::MessageHub> messageHub =
1119       router.registerMessageHub("hub1", /* id= */ 1, callback);
1120   EXPECT_TRUE(messageHub.has_value());
1121   std::optional<MessageRouter::MessageHub> messageHub2 =
1122       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1123   EXPECT_TRUE(messageHub2.has_value());
1124   std::optional<MessageRouter::MessageHub> messageHub3 =
1125       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1126   EXPECT_TRUE(messageHub3.has_value());
1127 
1128   // Open session from messageHub:1 to messageHub2:2
1129   SessionId sessionId = messageHub->openSession(
1130       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1131   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1132 
1133   // Get session from messageHub and compare it with messageHub2
1134   std::optional<Session> sessionAfterRegistering =
1135       messageHub->getSessionWithId(sessionId);
1136   EXPECT_TRUE(sessionAfterRegistering.has_value());
1137   EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
1138   EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
1139             messageHub->getId());
1140   EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
1141             kEndpointInfos[0].id);
1142   EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
1143   EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
1144   std::optional<Session> sessionAfterRegistering2 =
1145       messageHub2->getSessionWithId(sessionId);
1146   EXPECT_TRUE(sessionAfterRegistering2.has_value());
1147   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
1148 
1149   // Third message hub tries to find the session - not found
1150   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
1151   // Third message hub tries to close the session - not found
1152   EXPECT_FALSE(messageHub3->closeSession(sessionId));
1153 
1154   // Get session from messageHub and compare it with messageHub2 again
1155   sessionAfterRegistering = messageHub->getSessionWithId(sessionId);
1156   EXPECT_TRUE(sessionAfterRegistering.has_value());
1157   EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
1158   EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
1159             messageHub->getId());
1160   EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
1161             kEndpointInfos[0].id);
1162   EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
1163   EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
1164   sessionAfterRegistering2 = messageHub2->getSessionWithId(sessionId);
1165   EXPECT_TRUE(sessionAfterRegistering2.has_value());
1166   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
1167 
1168   // Close the session and verify it is closed on both message hubs
1169   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
1170   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
1171   EXPECT_TRUE(messageHub->closeSession(sessionId));
1172   EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
1173   EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
1174   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback3);
1175   EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
1176   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
1177 }
1178 
TEST_F(MessageRouterTest,ThreeMessageHubsAndThreeSessions)1179 TEST_F(MessageRouterTest, ThreeMessageHubsAndThreeSessions) {
1180   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1181   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1182       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1183                                                       /* session= */ nullptr);
1184   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1185       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1186                                                       /* session= */ nullptr);
1187   pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
1188       pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
1189                                                       /* session= */ nullptr);
1190 
1191   std::optional<MessageRouter::MessageHub> messageHub =
1192       router.registerMessageHub("hub1", /* id= */ 1, callback);
1193   EXPECT_TRUE(messageHub.has_value());
1194   std::optional<MessageRouter::MessageHub> messageHub2 =
1195       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1196   EXPECT_TRUE(messageHub2.has_value());
1197   std::optional<MessageRouter::MessageHub> messageHub3 =
1198       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1199   EXPECT_TRUE(messageHub3.has_value());
1200 
1201   // Open session from messageHub:1 to messageHub2:2
1202   SessionId sessionId = messageHub->openSession(
1203       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1204   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1205   messageHub2->onSessionOpenComplete(sessionId);
1206 
1207   // Open session from messageHub2:2 to messageHub3:3
1208   SessionId sessionId2 = messageHub2->openSession(
1209       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1210   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1211   messageHub3->onSessionOpenComplete(sessionId2);
1212 
1213   // Open session from messageHub3:3 to messageHub1:1
1214   SessionId sessionId3 = messageHub3->openSession(
1215       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1216   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1217   messageHub->onSessionOpenComplete(sessionId3);
1218 
1219   // Get sessions and compare
1220   // Find session: MessageHub1:1 -> MessageHub2:2
1221   std::optional<Session> sessionAfterRegistering =
1222       messageHub->getSessionWithId(sessionId);
1223   EXPECT_TRUE(sessionAfterRegistering.has_value());
1224   std::optional<Session> sessionAfterRegistering2 =
1225       messageHub2->getSessionWithId(sessionId);
1226   EXPECT_TRUE(sessionAfterRegistering2.has_value());
1227   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
1228   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
1229 
1230   // Find session: MessageHub2:2 -> MessageHub3:3
1231   sessionAfterRegistering = messageHub2->getSessionWithId(sessionId2);
1232   EXPECT_TRUE(sessionAfterRegistering.has_value());
1233   sessionAfterRegistering2 = messageHub3->getSessionWithId(sessionId2);
1234   EXPECT_TRUE(sessionAfterRegistering2.has_value());
1235   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
1236   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
1237 
1238   // Find session: MessageHub3:3 -> MessageHub1:1
1239   sessionAfterRegistering = messageHub3->getSessionWithId(sessionId3);
1240   EXPECT_TRUE(sessionAfterRegistering.has_value());
1241   sessionAfterRegistering2 = messageHub->getSessionWithId(sessionId3);
1242   EXPECT_TRUE(sessionAfterRegistering2.has_value());
1243   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId3).has_value());
1244   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
1245 
1246   // Close sessions from receivers and verify they are closed on all hubs
1247   EXPECT_TRUE(messageHub2->closeSession(sessionId));
1248   EXPECT_TRUE(messageHub3->closeSession(sessionId2));
1249   EXPECT_TRUE(messageHub->closeSession(sessionId3));
1250   for (SessionId id : {sessionId, sessionId2, sessionId3}) {
1251     EXPECT_FALSE(messageHub->getSessionWithId(id).has_value());
1252     EXPECT_FALSE(messageHub2->getSessionWithId(id).has_value());
1253     EXPECT_FALSE(messageHub3->getSessionWithId(id).has_value());
1254   }
1255 }
1256 
TEST_F(MessageRouterTest,SendMessageToSession)1257 TEST_F(MessageRouterTest, SendMessageToSession) {
1258   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1259   constexpr size_t kMessageSize = 5;
1260   pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
1261   pw::UniquePtr<std::byte[]> messageData =
1262       allocator.MakeUniqueArray<std::byte>(kMessageSize);
1263   for (size_t i = 0; i < 5; ++i) {
1264     messageData[i] = static_cast<std::byte>(i + 1);
1265   }
1266 
1267   Message messageFromCallback1;
1268   Message messageFromCallback2;
1269   Message messageFromCallback3;
1270   Session sessionFromCallback1;
1271   Session sessionFromCallback2;
1272   Session sessionFromCallback3;
1273   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1274       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
1275                                                       &sessionFromCallback1);
1276   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1277       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
1278                                                       &sessionFromCallback2);
1279   pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
1280       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
1281                                                       &sessionFromCallback3);
1282 
1283   std::optional<MessageRouter::MessageHub> messageHub =
1284       router.registerMessageHub("hub1", /* id= */ 1, callback);
1285   EXPECT_TRUE(messageHub.has_value());
1286   std::optional<MessageRouter::MessageHub> messageHub2 =
1287       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1288   EXPECT_TRUE(messageHub2.has_value());
1289   std::optional<MessageRouter::MessageHub> messageHub3 =
1290       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1291   EXPECT_TRUE(messageHub3.has_value());
1292 
1293   // Open session from messageHub:1 to messageHub2:2
1294   SessionId sessionId = messageHub->openSession(
1295       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1296   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1297   messageHub2->onSessionOpenComplete(sessionId);
1298 
1299   // Open session from messageHub2:2 to messageHub3:3
1300   SessionId sessionId2 = messageHub2->openSession(
1301       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1302   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1303   messageHub3->onSessionOpenComplete(sessionId2);
1304 
1305   // Open session from messageHub3:3 to messageHub1:1
1306   SessionId sessionId3 = messageHub3->openSession(
1307       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1308   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1309   messageHub->onSessionOpenComplete(sessionId3);
1310 
1311   // Send message from messageHub:1 to messageHub2:2
1312   ASSERT_TRUE(messageHub->sendMessage(std::move(messageData),
1313                                       /* messageType= */ 1,
1314                                       /* messagePermissions= */ 0, sessionId));
1315   EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
1316   EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
1317   EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
1318   EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
1319   EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
1320   EXPECT_EQ(messageFromCallback2.messageType, 1);
1321   EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
1322   EXPECT_EQ(messageFromCallback2.data.size(), kMessageSize);
1323   for (size_t i = 0; i < kMessageSize; ++i) {
1324     EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
1325   }
1326 
1327   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1328   for (size_t i = 0; i < 5; ++i) {
1329     messageData[i] = static_cast<std::byte>(i + 1);
1330   }
1331 
1332   // Send message from messageHub2:2 to messageHub:1
1333   ASSERT_TRUE(messageHub2->sendMessage(std::move(messageData),
1334                                        /* messageType= */ 2,
1335                                        /* messagePermissions= */ 3, sessionId));
1336   EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
1337   EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
1338   EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
1339   EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
1340   EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
1341   EXPECT_EQ(messageFromCallback1.messageType, 2);
1342   EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
1343   EXPECT_EQ(messageFromCallback1.data.size(), kMessageSize);
1344   for (size_t i = 0; i < kMessageSize; ++i) {
1345     EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
1346   }
1347 }
1348 
TEST_F(MessageRouterTest,SendMessageOnHalfOpenSessionIsRejected)1349 TEST_F(MessageRouterTest, SendMessageOnHalfOpenSessionIsRejected) {
1350   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1351   constexpr size_t kMessageSize = 5;
1352   pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
1353   pw::UniquePtr<std::byte[]> messageData =
1354       allocator.MakeUniqueArray<std::byte>(kMessageSize);
1355   for (size_t i = 0; i < 5; ++i) {
1356     messageData[i] = static_cast<std::byte>(i + 1);
1357   }
1358 
1359   Message messageFromCallback1;
1360   Message messageFromCallback2;
1361   Session sessionFromCallback1;
1362   Session sessionFromCallback2;
1363   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1364       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
1365                                                       &sessionFromCallback1);
1366   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1367       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
1368                                                       &sessionFromCallback2);
1369 
1370   std::optional<MessageRouter::MessageHub> messageHub =
1371       router.registerMessageHub("hub1", /* id= */ 1, callback);
1372   EXPECT_TRUE(messageHub.has_value());
1373   std::optional<MessageRouter::MessageHub> messageHub2 =
1374       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1375   EXPECT_TRUE(messageHub2.has_value());
1376 
1377   // Open session from messageHub:1 to messageHub2:2 but do not complete it
1378   SessionId sessionId = messageHub->openSession(
1379       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1380   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1381 
1382   // Try to send a message from messageHub:1 to messageHub2:2 - should fail
1383   EXPECT_FALSE(messageHub->sendMessage(std::move(messageData),
1384                                        /* messageType= */ 1,
1385                                        /* messagePermissions= */ 0, sessionId));
1386 
1387   // Now complete the session
1388   messageHub2->onSessionOpenComplete(sessionId);
1389 
1390   // Send message from messageHub:1 to messageHub2:2
1391   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1392   for (size_t i = 0; i < 5; ++i) {
1393     messageData[i] = static_cast<std::byte>(i + 1);
1394   }
1395 
1396   ASSERT_TRUE(messageHub->sendMessage(std::move(messageData),
1397                                       /* messageType= */ 1,
1398                                       /* messagePermissions= */ 0, sessionId));
1399   EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
1400   EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
1401   EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
1402   EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
1403   EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
1404   EXPECT_EQ(messageFromCallback2.messageType, 1);
1405   EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
1406   EXPECT_EQ(messageFromCallback2.data.size(), kMessageSize);
1407   for (size_t i = 0; i < kMessageSize; ++i) {
1408     EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
1409   }
1410 
1411   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1412   for (size_t i = 0; i < 5; ++i) {
1413     messageData[i] = static_cast<std::byte>(i + 1);
1414   }
1415 
1416   // Send message from messageHub2:2 to messageHub:1
1417   ASSERT_TRUE(messageHub2->sendMessage(std::move(messageData),
1418                                        /* messageType= */ 2,
1419                                        /* messagePermissions= */ 3, sessionId));
1420   EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
1421   EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
1422   EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
1423   EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
1424   EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
1425   EXPECT_EQ(messageFromCallback1.messageType, 2);
1426   EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
1427   EXPECT_EQ(messageFromCallback1.data.size(), kMessageSize);
1428   for (size_t i = 0; i < kMessageSize; ++i) {
1429     EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
1430   }
1431 }
1432 
TEST_F(MessageRouterTest,SendMessageToSessionUsingPointerAndFreeCallback)1433 TEST_F(MessageRouterTest, SendMessageToSessionUsingPointerAndFreeCallback) {
1434   struct FreeCallbackContext {
1435     bool *freeCallbackCalled;
1436     std::byte *message;
1437     size_t length;
1438   };
1439 
1440   pw::Vector<CallbackAllocator<FreeCallbackContext>::CallbackRecord, 10>
1441       freeCallbackRecords;
1442   CallbackAllocator<FreeCallbackContext> allocator(
1443       [](std::byte *message, size_t length, FreeCallbackContext &&context) {
1444         *context.freeCallbackCalled =
1445             message == context.message && length == context.length;
1446       },
1447       freeCallbackRecords);
1448 
1449   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1450   constexpr size_t kMessageSize = 5;
1451   std::byte messageData[kMessageSize];
1452   for (size_t i = 0; i < 5; ++i) {
1453     messageData[i] = static_cast<std::byte>(i + 1);
1454   }
1455 
1456   Message messageFromCallback1;
1457   Message messageFromCallback2;
1458   Message messageFromCallback3;
1459   Session sessionFromCallback1;
1460   Session sessionFromCallback2;
1461   Session sessionFromCallback3;
1462   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1463       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
1464                                                       &sessionFromCallback1);
1465   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1466       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
1467                                                       &sessionFromCallback2);
1468   pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
1469       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
1470                                                       &sessionFromCallback3);
1471 
1472   std::optional<MessageRouter::MessageHub> messageHub =
1473       router.registerMessageHub("hub1", /* id= */ 1, callback);
1474   EXPECT_TRUE(messageHub.has_value());
1475   std::optional<MessageRouter::MessageHub> messageHub2 =
1476       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1477   EXPECT_TRUE(messageHub2.has_value());
1478   std::optional<MessageRouter::MessageHub> messageHub3 =
1479       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1480   EXPECT_TRUE(messageHub3.has_value());
1481 
1482   // Open session from messageHub:1 to messageHub2:2
1483   SessionId sessionId = messageHub->openSession(
1484       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1485   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1486   messageHub2->onSessionOpenComplete(sessionId);
1487 
1488   // Open session from messageHub2:2 to messageHub3:3
1489   SessionId sessionId2 = messageHub2->openSession(
1490       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1491   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1492   messageHub3->onSessionOpenComplete(sessionId2);
1493 
1494   // Open session from messageHub3:3 to messageHub1:1
1495   SessionId sessionId3 = messageHub3->openSession(
1496       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1497   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1498   messageHub->onSessionOpenComplete(sessionId3);
1499 
1500   // Send message from messageHub:1 to messageHub2:2
1501   bool freeCallbackCalled = false;
1502   FreeCallbackContext freeCallbackContext = {
1503       .freeCallbackCalled = &freeCallbackCalled,
1504       .message = messageData,
1505       .length = kMessageSize,
1506   };
1507   pw::UniquePtr<std::byte[]> data = allocator.MakeUniqueArrayWithCallback(
1508       messageData, kMessageSize, std::move(freeCallbackContext));
1509   ASSERT_NE(data.get(), nullptr);
1510 
1511   ASSERT_TRUE(messageHub->sendMessage(std::move(data),
1512                                       /* messageType= */ 1,
1513                                       /* messagePermissions= */ 0, sessionId));
1514   EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
1515   EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
1516   EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
1517   EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
1518   EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
1519   EXPECT_EQ(messageFromCallback2.messageType, 1);
1520   EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
1521   EXPECT_EQ(messageFromCallback2.data.size(), kMessageSize);
1522   for (size_t i = 0; i < kMessageSize; ++i) {
1523     EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
1524   }
1525 
1526   // Check if free callback was called
1527   EXPECT_FALSE(freeCallbackCalled);
1528   EXPECT_EQ(messageFromCallback2.data.get(), messageData);
1529   messageFromCallback2.data.Reset();
1530   EXPECT_TRUE(freeCallbackCalled);
1531 
1532   // Send message from messageHub2:2 to messageHub:1
1533   freeCallbackCalled = false;
1534   FreeCallbackContext freeCallbackContext2 = {
1535       .freeCallbackCalled = &freeCallbackCalled,
1536       .message = messageData,
1537       .length = kMessageSize,
1538   };
1539   data = allocator.MakeUniqueArrayWithCallback(messageData, kMessageSize,
1540                                                std::move(freeCallbackContext2));
1541   ASSERT_NE(data.get(), nullptr);
1542 
1543   ASSERT_TRUE(messageHub2->sendMessage(std::move(data),
1544                                        /* messageType= */ 2,
1545                                        /* messagePermissions= */ 3, sessionId));
1546   EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
1547   EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
1548   EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
1549   EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
1550   EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
1551   EXPECT_EQ(messageFromCallback1.messageType, 2);
1552   EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
1553   EXPECT_EQ(messageFromCallback1.data.size(), kMessageSize);
1554   for (size_t i = 0; i < kMessageSize; ++i) {
1555     EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
1556   }
1557 
1558   // Check if free callback was called
1559   EXPECT_FALSE(freeCallbackCalled);
1560   EXPECT_EQ(messageFromCallback1.data.get(), messageData);
1561   messageFromCallback1.data.Reset();
1562   EXPECT_TRUE(freeCallbackCalled);
1563 }
1564 
TEST_F(MessageRouterTest,SendMessageToSessionInvalidHubAndSession)1565 TEST_F(MessageRouterTest, SendMessageToSessionInvalidHubAndSession) {
1566   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1567   constexpr size_t kMessageSize = 5;
1568   pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
1569   pw::UniquePtr<std::byte[]> messageData =
1570       allocator.MakeUniqueArray<std::byte>(kMessageSize);
1571   for (size_t i = 0; i < 5; ++i) {
1572     messageData[i] = static_cast<std::byte>(i + 1);
1573   }
1574 
1575   Message messageFromCallback1;
1576   Message messageFromCallback2;
1577   Message messageFromCallback3;
1578   Session sessionFromCallback1;
1579   Session sessionFromCallback2;
1580   Session sessionFromCallback3;
1581   pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
1582       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
1583                                                       &sessionFromCallback1);
1584   pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
1585       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
1586                                                       &sessionFromCallback2);
1587   pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
1588       pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
1589                                                       &sessionFromCallback3);
1590 
1591   std::optional<MessageRouter::MessageHub> messageHub =
1592       router.registerMessageHub("hub1", /* id= */ 1, callback);
1593   EXPECT_TRUE(messageHub.has_value());
1594   std::optional<MessageRouter::MessageHub> messageHub2 =
1595       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1596   EXPECT_TRUE(messageHub2.has_value());
1597   std::optional<MessageRouter::MessageHub> messageHub3 =
1598       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1599   EXPECT_TRUE(messageHub3.has_value());
1600 
1601   // Open session from messageHub:1 to messageHub2:2
1602   SessionId sessionId = messageHub->openSession(
1603       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1604   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1605   messageHub2->onSessionOpenComplete(sessionId);
1606 
1607   // Open session from messageHub2:2 to messageHub3:3
1608   SessionId sessionId2 = messageHub2->openSession(
1609       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1610   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1611   messageHub3->onSessionOpenComplete(sessionId2);
1612 
1613   // Open session from messageHub3:3 to messageHub1:1
1614   SessionId sessionId3 = messageHub3->openSession(
1615       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1616   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1617   messageHub->onSessionOpenComplete(sessionId3);
1618 
1619   // Send message from messageHub:1 to messageHub2:2
1620   EXPECT_FALSE(messageHub->sendMessage(std::move(messageData),
1621                                        /* messageType= */ 1,
1622                                        /* messagePermissions= */ 0,
1623                                        sessionId2));
1624   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData),
1625                                         /* messageType= */ 2,
1626                                         /* messagePermissions= */ 3,
1627                                         sessionId3));
1628   EXPECT_FALSE(messageHub3->sendMessage(std::move(messageData),
1629                                         /* messageType= */ 2,
1630                                         /* messagePermissions= */ 3,
1631                                         sessionId));
1632 }
1633 
TEST_F(MessageRouterTest,SendMessageToSessionCallbackFailureClosesSession)1634 TEST_F(MessageRouterTest, SendMessageToSessionCallbackFailureClosesSession) {
1635   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1636   constexpr size_t kMessageSize = 5;
1637   pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
1638   pw::UniquePtr<std::byte[]> messageData =
1639       allocator.MakeUniqueArray<std::byte>(kMessageSize);
1640   for (size_t i = 0; i < 5; ++i) {
1641     messageData[i] = static_cast<std::byte>(i + 1);
1642   }
1643 
1644   bool wasMessageReceivedCalled1 = false;
1645   bool wasMessageReceivedCalled2 = false;
1646   bool wasMessageReceivedCalled3 = false;
1647   pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback1 =
1648       pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
1649           &wasMessageReceivedCalled1,
1650           /* wasSessionClosedCalled= */ nullptr);
1651   pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback2 =
1652       pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
1653           &wasMessageReceivedCalled2,
1654           /* wasSessionClosedCalled= */ nullptr);
1655   pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback3 =
1656       pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
1657           &wasMessageReceivedCalled3,
1658           /* wasSessionClosedCalled= */ nullptr);
1659 
1660   std::optional<MessageRouter::MessageHub> messageHub =
1661       router.registerMessageHub("hub1", /* id= */ 1, callback1);
1662   EXPECT_TRUE(messageHub.has_value());
1663   std::optional<MessageRouter::MessageHub> messageHub2 =
1664       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1665   EXPECT_TRUE(messageHub2.has_value());
1666   std::optional<MessageRouter::MessageHub> messageHub3 =
1667       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1668   EXPECT_TRUE(messageHub3.has_value());
1669 
1670   // Open session from messageHub:1 to messageHub2:2
1671   SessionId sessionId = messageHub->openSession(
1672       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1673   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1674   messageHub2->onSessionOpenComplete(sessionId);
1675 
1676   // Open session from messageHub2:2 to messageHub3:3
1677   SessionId sessionId2 = messageHub2->openSession(
1678       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1679   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1680   messageHub3->onSessionOpenComplete(sessionId2);
1681 
1682   // Open session from messageHub3:3 to messageHub1:1
1683   SessionId sessionId3 = messageHub3->openSession(
1684       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1685   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1686   messageHub->onSessionOpenComplete(sessionId3);
1687 
1688   // Send message from messageHub2:2 to messageHub3:3
1689   EXPECT_FALSE(wasMessageReceivedCalled1);
1690   EXPECT_FALSE(wasMessageReceivedCalled2);
1691   EXPECT_FALSE(wasMessageReceivedCalled3);
1692   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
1693   EXPECT_TRUE(messageHub2->getSessionWithId(sessionId2).has_value());
1694   EXPECT_TRUE(messageHub3->getSessionWithId(sessionId2).has_value());
1695 
1696   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData),
1697                                         /* messageType= */ 1,
1698                                         /* messagePermissions= */ 0,
1699                                         sessionId2));
1700   EXPECT_FALSE(wasMessageReceivedCalled1);
1701   EXPECT_FALSE(wasMessageReceivedCalled2);
1702   EXPECT_TRUE(wasMessageReceivedCalled3);
1703   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
1704   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId2).has_value());
1705   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId2).has_value());
1706 
1707   // Try to send a message on the same session - should fail
1708   wasMessageReceivedCalled1 = false;
1709   wasMessageReceivedCalled2 = false;
1710   wasMessageReceivedCalled3 = false;
1711   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1712   for (size_t i = 0; i < 5; ++i) {
1713     messageData[i] = static_cast<std::byte>(i + 1);
1714   }
1715   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData),
1716                                         /* messageType= */ 1,
1717                                         /* messagePermissions= */ 0,
1718                                         sessionId2));
1719   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1720   for (size_t i = 0; i < 5; ++i) {
1721     messageData[i] = static_cast<std::byte>(i + 1);
1722   }
1723   EXPECT_FALSE(messageHub3->sendMessage(std::move(messageData),
1724                                         /* messageType= */ 1,
1725                                         /* messagePermissions= */ 0,
1726                                         sessionId2));
1727   EXPECT_FALSE(wasMessageReceivedCalled1);
1728   EXPECT_FALSE(wasMessageReceivedCalled2);
1729   EXPECT_FALSE(wasMessageReceivedCalled3);
1730 }
1731 
TEST_F(MessageRouterTest,MessageHubCallbackCanCallOtherMessageHubAPIs)1732 TEST_F(MessageRouterTest, MessageHubCallbackCanCallOtherMessageHubAPIs) {
1733   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1734   constexpr size_t kMessageSize = 5;
1735   pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
1736   pw::UniquePtr<std::byte[]> messageData =
1737       allocator.MakeUniqueArray<std::byte>(kMessageSize);
1738   for (size_t i = 0; i < 5; ++i) {
1739     messageData[i] = static_cast<std::byte>(i + 1);
1740   }
1741 
1742   pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
1743       callback = pw::MakeRefCounted<
1744           MessageHubCallbackCallsMessageHubApisDuringCallback>();
1745   pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
1746       callback2 = pw::MakeRefCounted<
1747           MessageHubCallbackCallsMessageHubApisDuringCallback>();
1748   pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
1749       callback3 = pw::MakeRefCounted<
1750           MessageHubCallbackCallsMessageHubApisDuringCallback>();
1751 
1752   std::optional<MessageRouter::MessageHub> messageHub =
1753       router.registerMessageHub("hub1", /* id= */ 1, callback);
1754   EXPECT_TRUE(messageHub.has_value());
1755   callback->setMessageHub(&messageHub.value());
1756   std::optional<MessageRouter::MessageHub> messageHub2 =
1757       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1758   EXPECT_TRUE(messageHub2.has_value());
1759   callback2->setMessageHub(&messageHub2.value());
1760   std::optional<MessageRouter::MessageHub> messageHub3 =
1761       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1762   EXPECT_TRUE(messageHub3.has_value());
1763   callback3->setMessageHub(&messageHub3.value());
1764 
1765   // Open session from messageHub:1 to messageHub2:2
1766   SessionId sessionId = messageHub->openSession(
1767       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1768   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1769   messageHub2->onSessionOpenComplete(sessionId);
1770 
1771   // Open session from messageHub2:2 to messageHub3:3
1772   SessionId sessionId2 = messageHub2->openSession(
1773       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1774   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1775   messageHub3->onSessionOpenComplete(sessionId2);
1776 
1777   // Open session from messageHub3:3 to messageHub1:1
1778   SessionId sessionId3 = messageHub3->openSession(
1779       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1780   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1781   messageHub->onSessionOpenComplete(sessionId3);
1782 
1783   // Send message from messageHub:1 to messageHub2:2
1784   EXPECT_TRUE(messageHub->sendMessage(std::move(messageData),
1785                                       /* messageType= */ 1,
1786                                       /* messagePermissions= */ 0, sessionId));
1787 
1788   // Send message from messageHub2:2 to messageHub:1
1789   messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
1790   for (size_t i = 0; i < 5; ++i) {
1791     messageData[i] = static_cast<std::byte>(i + 1);
1792   }
1793   EXPECT_TRUE(messageHub2->sendMessage(std::move(messageData),
1794                                        /* messageType= */ 2,
1795                                        /* messagePermissions= */ 3, sessionId));
1796 
1797   // Close all sessions
1798   EXPECT_TRUE(messageHub->closeSession(sessionId));
1799   EXPECT_TRUE(messageHub2->closeSession(sessionId2));
1800   EXPECT_TRUE(messageHub3->closeSession(sessionId3));
1801 
1802   // If we finish the test, both callbacks should have been called
1803   // If the router holds the lock during the callback, this test will timeout
1804 }
1805 
// Verifies that forEachEndpointOfHub reports every endpoint of a registered
// hub, in the same order and with the same fields as kEndpointInfos.
TEST_F(MessageRouterTest, ForEachEndpointOfHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());

  DynamicVector<EndpointInfo> endpoints;
  EXPECT_TRUE(router.forEachEndpointOfHub(
      /* messageHubId= */ 1, [&endpoints](const EndpointInfo &info) {
        endpoints.push_back(info);
        return false;
      }));

  // Fatal on mismatch: the loop below indexes kEndpointInfos by position in
  // endpoints, which would read past kEndpointInfos if more were reported.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1831 
// Verifies that forEachEndpoint reports every endpoint of every registered hub
// together with its hub's info; with a single hub registered, each entry must
// carry that hub's id/name and match kEndpointInfos in order.
TEST_F(MessageRouterTest, ForEachEndpoint) {
  constexpr const char *kHubName = "hub1";
  constexpr MessageHubId kHubId = 1;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub(kHubName, kHubId, callback);
  ASSERT_TRUE(messageHub.has_value());

  DynamicVector<std::pair<MessageHubInfo, EndpointInfo>> endpoints;
  router.forEachEndpoint(
      [&endpoints](const MessageHubInfo &hubInfo, const EndpointInfo &info) {
        endpoints.push_back(std::make_pair(hubInfo, info));
      });

  // Fatal on mismatch: the loop below indexes kEndpointInfos by position in
  // endpoints, which would read past kEndpointInfos if more were reported.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].first.id, kHubId);
    EXPECT_STREQ(endpoints[i].first.name, kHubName);

    EXPECT_EQ(endpoints[i].second.id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].second.name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].second.version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].second.type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].second.requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1862 
// Asking for the endpoints of a hub ID the router does not know about must
// report failure and must never hand any endpoint to the visitor.
TEST_F(MessageRouterTest, ForEachEndpointOfHubInvalidHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub.has_value());

  // Only hub 1 is registered, so hub 2 is unknown to the router.
  DynamicVector<EndpointInfo> visited;
  const bool found = router.forEachEndpointOfHub(
      /* messageHubId= */ 2, [&visited](const EndpointInfo &info) {
        visited.push_back(info);
        return false;
      });

  EXPECT_FALSE(found);
  EXPECT_EQ(visited.size(), 0u);
}
1880 
// Registering an endpoint on one hub must notify the other registered hub
// through its callback.
TEST_F(MessageRouterTest, RegisterEndpointCallbacksAreCalled) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): messageHub is dereferenced below, and an empty
  // optional must not be dereferenced.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  // The second hub must stay registered so that callback2 observes the
  // registration; check it rather than re-checking the first hub.
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Register the endpoint and verify that the other hub's callback was called
  EXPECT_TRUE(messageHub->registerEndpoint(kEndpointInfos[0].id));
  EXPECT_TRUE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
}
1901 
// Registering and then unregistering an endpoint on one hub must be reported
// only to the other registered hub, never back to the owning hub.
TEST_F(MessageRouterTest, UnregisterEndpointCallbacksAreCalled) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT (not EXPECT): messageHub is dereferenced below, and an empty
  // optional must not be dereferenced.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  // The second hub must stay registered so that callback2 observes the
  // (un)registration; check it rather than re-checking the first hub.
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Register the endpoint and verify that the callbacks were called
  // only on the other hub
  EXPECT_TRUE(messageHub->registerEndpoint(kEndpointInfos[0].id));
  EXPECT_FALSE(callback->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
  EXPECT_TRUE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));

  // Unregister the endpoint and verify that neither callback still reports
  // the endpoint as registered. The other hub's record is cleared by the
  // unregistration notification.
  EXPECT_TRUE(messageHub->unregisterEndpoint(kEndpointInfos[0].id));
  EXPECT_FALSE(callback->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
  EXPECT_FALSE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                    kEndpointInfos[0].id));
}
1933 
1934 MATCHER_P(HubMatcher, id, "Matches id in MessageHubInfo") {
1935   return arg.id == id;
1936 }
1937 
// An already-registered hub is told when a second hub registers
// (onHubRegistered). It is told again when that second hub's handle is
// released (onHubUnregistered).
TEST_F(MessageRouterTest, OnRegisterAndUnregisterHub) {
  constexpr MessageHubId kFirstHubId = 1;
  constexpr MessageHubId kSecondHubId = 2;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MockMessageHubCallback> firstHubCallback =
      pw::MakeRefCounted<MockMessageHubCallback>();
  pw::IntrusivePtr<MockMessageHubCallback> secondHubCallback =
      pw::MakeRefCounted<MockMessageHubCallback>();

  std::optional<MessageRouter::MessageHub> firstHub =
      router.registerMessageHub("hub1", kFirstHubId, firstHubCallback);
  ASSERT_TRUE(firstHub.has_value());

  // gMock expectations must be in place before the triggering action.
  EXPECT_CALL(*firstHubCallback, onHubRegistered(HubMatcher(kSecondHubId)));
  std::optional<MessageRouter::MessageHub> secondHub =
      router.registerMessageHub("hub2", kSecondHubId, secondHubCallback);
  ASSERT_TRUE(secondHub.has_value());

  // Releasing the handle is what unregisters the second hub here.
  EXPECT_CALL(*firstHubCallback, onHubUnregistered(kSecondHubId));
  secondHub.reset();
}
1957 
1958 MATCHER_P(SessionIdMatcher, id, "Matches id in Session") {
1959   return arg.sessionId == id;
1960 }
1961 
// When both ends of a session live on the same hub, onSessionOpened and
// onSessionClosed must each be invoked exactly once per session. This holds
// both for two distinct endpoints and for an endpoint talking to itself.
TEST_F(MessageRouterTest, SessionCallbacksAreCalledOnceSameHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MockMessageHubCallback> hub1Callback =
      pw::MakeRefCounted<MockMessageHubCallback>();
  MessageHubId hub1Id = 1;
  std::optional<MessageRouter::MessageHub> hub1 =
      router.registerMessageHub("hub1", hub1Id, hub1Callback);
  ASSERT_TRUE(hub1.has_value());

  // Delegate the mock's endpoint enumeration to the test's forEachEndpoint
  // helper. Presumably the router needs it to resolve the endpoints passed to
  // openSession; confirm against the mock.
  ON_CALL(*hub1Callback, forEachEndpoint).WillByDefault(forEachEndpoint);

  // Try with different endpoints
  SessionId sessionId = hub1->openSession(kEndpointInfos[0].id, hub1->getId(),
                                          kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  EXPECT_CALL(*hub1Callback, onSessionOpened(_)).Times(1);
  hub1->onSessionOpenComplete(sessionId);

  EXPECT_CALL(*hub1Callback, onSessionClosed(SessionIdMatcher(sessionId), _))
      .Times(1);
  // Check the result instead of ignoring it: a failed close would otherwise
  // show up only as a confusing unmet-expectation error.
  EXPECT_TRUE(hub1->closeSession(sessionId));

  // Try with the same endpoint
  SessionId sessionId2 = hub1->openSession(kEndpointInfos[1].id, hub1->getId(),
                                           kEndpointInfos[1].id);
  ASSERT_NE(sessionId2, SESSION_ID_INVALID);

  EXPECT_CALL(*hub1Callback, onSessionOpened(_)).Times(1);
  hub1->onSessionOpenComplete(sessionId2);

  EXPECT_CALL(*hub1Callback, onSessionClosed(SessionIdMatcher(sessionId2), _))
      .Times(1);
  EXPECT_TRUE(hub1->closeSession(sessionId2));
}
1997 
1998 }  // namespace
1999 }  // namespace chre::message
2000