1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "chre/util/system/message_router.h"
18
19 #include "chre/util/dynamic_vector.h"
20 #include "chre/util/system/callback_allocator.h"
21 #include "chre/util/system/message_common.h"
22 #include "chre/util/system/message_router_mocks.h"
23 #include "chre_api/chre.h"
24
25 #include "pw_allocator/libc_allocator.h"
26 #include "pw_allocator/unique_ptr.h"
27 #include "pw_intrusive_ptr/intrusive_ptr.h"
28
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <set>
#include <utility>
36
37 using ::testing::_;
38
39 namespace chre::message {
40 namespace {
41
//! Capacity of the router under test: number of message hubs it can hold.
constexpr size_t kMaxMessageHubs = 3;
//! Capacity of the router under test: number of concurrent sessions.
constexpr size_t kMaxSessions = 10;
//! Sized at twice the session capacity.
constexpr size_t kMaxFreeCallbackRecords = 2 * kMaxSessions;
//! Number of entries in kEndpointInfos.
constexpr size_t kNumEndpoints = 3;
46
47 const EndpointInfo kEndpointInfos[kNumEndpoints] = {
48 EndpointInfo(/* id= */ 1, /* name= */ "endpoint1", /* version= */ 1,
49 EndpointType::NANOAPP, CHRE_MESSAGE_PERMISSION_NONE),
50 EndpointInfo(/* id= */ 2, /* name= */ "endpoint2", /* version= */ 10,
51 EndpointType::HOST_NATIVE, CHRE_MESSAGE_PERMISSION_BLE),
52 EndpointInfo(/* id= */ 3, /* name= */ "endpoint3", /* version= */ 100,
53 EndpointType::GENERIC, CHRE_MESSAGE_PERMISSION_AUDIO)};
54 const char kServiceDescriptorForEndpoint2[] = "TEST_SERVICE.TEST";
55
56 class MessageRouterTest : public ::testing::Test {};
57
58 //! Iterates over the endpoints
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)59 void forEachEndpoint(const pw::Function<bool(const EndpointInfo &)> &function) {
60 for (const EndpointInfo &endpointInfo : kEndpointInfos) {
61 if (function(endpointInfo)) {
62 return;
63 }
64 }
65 }
66
67 //! Base class for MessageHubCallbacks used in tests
68 class MessageHubCallbackBase : public MessageRouter::MessageHubCallback {
69 public:
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)70 void forEachEndpoint(
71 const pw::Function<bool(const EndpointInfo &)> &function) override {
72 ::chre::message::forEachEndpoint(function);
73 }
74
getEndpointInfo(EndpointId endpointId)75 std::optional<EndpointInfo> getEndpointInfo(EndpointId endpointId) override {
76 for (const EndpointInfo &endpointInfo : kEndpointInfos) {
77 if (endpointInfo.id == endpointId) {
78 return endpointInfo;
79 }
80 }
81 return std::nullopt;
82 }
83
onSessionOpenRequest(const Session &)84 void onSessionOpenRequest(const Session & /* session */) override {}
85
getEndpointForService(const char * serviceDescriptor)86 std::optional<EndpointId> getEndpointForService(
87 const char *serviceDescriptor) override {
88 if (serviceDescriptor != nullptr &&
89 std::strcmp(serviceDescriptor, kServiceDescriptorForEndpoint2) == 0) {
90 return kEndpointInfos[1].id;
91 }
92 return std::nullopt;
93 }
94
doesEndpointHaveService(EndpointId endpointId,const char * serviceDescriptor)95 bool doesEndpointHaveService(EndpointId endpointId,
96 const char *serviceDescriptor) override {
97 return serviceDescriptor != nullptr && endpointId == kEndpointInfos[1].id &&
98 std::strcmp(serviceDescriptor, kServiceDescriptorForEndpoint2) == 0;
99 }
100
forEachService(const pw::Function<bool (const EndpointInfo &,const ServiceInfo &)> & function)101 void forEachService(
102 const pw::Function<bool(const EndpointInfo &, const ServiceInfo &)>
103 &function) override {
104 function(kEndpointInfos[1],
105 ServiceInfo(kServiceDescriptorForEndpoint2, /* majorVersion= */ 1,
106 /* minorVersion= */ 0, RpcFormat::CUSTOM));
107 }
108
onHubRegistered(const MessageHubInfo &)109 void onHubRegistered(const MessageHubInfo & /* info */) override {}
110
onHubUnregistered(MessageHubId)111 void onHubUnregistered(MessageHubId /* id */) override {}
112
onEndpointRegistered(MessageHubId,EndpointId)113 void onEndpointRegistered(MessageHubId /* messageHubId */,
114 EndpointId /* endpointId */) override {}
115
onEndpointUnregistered(MessageHubId,EndpointId)116 void onEndpointUnregistered(MessageHubId /* messageHubId */,
117 EndpointId /* endpointId */) override {}
118
pw_recycle()119 void pw_recycle() override {
120 delete this;
121 }
122 };
123
124 //! MessageHubCallback that stores the data passed to onMessageReceived and
125 //! onSessionClosed
126 class MessageHubCallbackStoreData : public MessageHubCallbackBase {
127 public:
MessageHubCallbackStoreData(Message * message,Session * session,Reason * reason=nullptr,Session * openedSession=nullptr)128 MessageHubCallbackStoreData(Message *message, Session *session,
129 Reason *reason = nullptr,
130 Session *openedSession = nullptr)
131 : mMessage(message),
132 mSession(session),
133 mReason(reason),
134 mOpenedSession(openedSession) {}
135
onMessageReceived(pw::UniquePtr<std::byte[]> && data,uint32_t messageType,uint32_t messagePermissions,const Session & session,bool sentBySessionInitiator)136 bool onMessageReceived(pw::UniquePtr<std::byte[]> &&data,
137 uint32_t messageType, uint32_t messagePermissions,
138 const Session &session,
139 bool sentBySessionInitiator) override {
140 if (mMessage != nullptr) {
141 mMessage->sender = sentBySessionInitiator ? session.initiator
142 : session.peer;
143 mMessage->recipient =
144 sentBySessionInitiator ? session.peer : session.initiator;
145 mMessage->sessionId = session.sessionId;
146 mMessage->data = std::move(data);
147 mMessage->messageType = messageType;
148 mMessage->messagePermissions = messagePermissions;
149 }
150 return true;
151 }
152
onSessionClosed(const Session & session,Reason reason)153 void onSessionClosed(const Session &session, Reason reason) override {
154 if (mSession != nullptr) {
155 *mSession = session;
156 }
157 if (mReason != nullptr) {
158 *mReason = reason;
159 }
160 }
161
onSessionOpened(const Session & session)162 void onSessionOpened(const Session &session) override {
163 if (mOpenedSession != nullptr) {
164 *mOpenedSession = session;
165 }
166 }
167
onEndpointRegistered(MessageHubId messageHubId,EndpointId endpointId)168 void onEndpointRegistered(MessageHubId messageHubId,
169 EndpointId endpointId) override {
170 mRegisteredEndpoints.insert(std::make_pair(messageHubId, endpointId));
171 }
172
onEndpointUnregistered(MessageHubId messageHubId,EndpointId endpointId)173 void onEndpointUnregistered(MessageHubId messageHubId,
174 EndpointId endpointId) override {
175 mRegisteredEndpoints.erase(std::make_pair(messageHubId, endpointId));
176 }
177
hasEndpointBeenRegistered(MessageHubId messageHubId,EndpointId endpointId)178 bool hasEndpointBeenRegistered(MessageHubId messageHubId,
179 EndpointId endpointId) {
180 return mRegisteredEndpoints.find(std::make_pair(
181 messageHubId, endpointId)) != mRegisteredEndpoints.end();
182 }
183
184 private:
185 Message *mMessage;
186 Session *mSession;
187 Reason *mReason;
188 Session *mOpenedSession;
189 std::set<std::pair<MessageHubId, EndpointId>> mRegisteredEndpoints;
190 };
191
192 //! MessageHubCallback that always fails to process messages
193 class MessageHubCallbackAlwaysFails : public MessageHubCallbackBase {
194 public:
MessageHubCallbackAlwaysFails(bool * wasMessageReceivedCalled,bool * wasSessionClosedCalled)195 MessageHubCallbackAlwaysFails(bool *wasMessageReceivedCalled,
196 bool *wasSessionClosedCalled)
197 : mWasMessageReceivedCalled(wasMessageReceivedCalled),
198 mWasSessionClosedCalled(wasSessionClosedCalled) {}
199
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)200 bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
201 uint32_t /* messageType */,
202 uint32_t /* messagePermissions */,
203 const Session & /* session */,
204 bool /* sentBySessionInitiator */) override {
205 if (mWasMessageReceivedCalled != nullptr) {
206 *mWasMessageReceivedCalled = true;
207 }
208 return false;
209 }
210
onSessionClosed(const Session &,Reason)211 void onSessionClosed(const Session & /* session */,
212 Reason /* reason */) override {
213 if (mWasSessionClosedCalled != nullptr) {
214 *mWasSessionClosedCalled = true;
215 }
216 }
217
onSessionOpened(const Session &)218 void onSessionOpened(const Session & /* session */) override {}
219
220 private:
221 bool *mWasMessageReceivedCalled;
222 bool *mWasSessionClosedCalled;
223 };
224
225 //! MessageHubCallback that tracks open session requests calls
226 class MessageHubCallbackOpenSessionRequest : public MessageHubCallbackBase {
227 public:
MessageHubCallbackOpenSessionRequest(bool * wasSessionOpenRequestCalled)228 MessageHubCallbackOpenSessionRequest(bool *wasSessionOpenRequestCalled)
229 : mWasSessionOpenRequestCalled(wasSessionOpenRequestCalled) {}
230
onSessionOpenRequest(const Session &)231 void onSessionOpenRequest(const Session & /* session */) override {
232 if (mWasSessionOpenRequestCalled != nullptr) {
233 *mWasSessionOpenRequestCalled = true;
234 }
235 }
236
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)237 bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
238 uint32_t /* messageType */,
239 uint32_t /* messagePermissions */,
240 const Session & /* session */,
241 bool /* sentBySessionInitiator */) override {
242 return true;
243 }
244
onSessionClosed(const Session &,Reason)245 void onSessionClosed(const Session & /* session */,
246 Reason /* reason */) override {}
247
onSessionOpened(const Session &)248 void onSessionOpened(const Session & /* session */) override {}
249
250 private:
251 bool *mWasSessionOpenRequestCalled;
252 };
253
254 //! MessageHubCallback that calls MessageHub APIs during callbacks
255 class MessageHubCallbackCallsMessageHubApisDuringCallback
256 : public MessageHubCallbackBase {
257 public:
onMessageReceived(pw::UniquePtr<std::byte[]> &&,uint32_t,uint32_t,const Session &,bool)258 bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
259 uint32_t /* messageType */,
260 uint32_t /* messagePermissions */,
261 const Session & /* session */,
262 bool /* sentBySessionInitiator */) override {
263 if (mMessageHub != nullptr) {
264 // Call a function that locks the MessageRouter mutex
265 mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
266 kEndpointInfos[1].id);
267 }
268 return true;
269 }
270
onSessionClosed(const Session &,Reason)271 void onSessionClosed(const Session & /* session */,
272 Reason /* reason */) override {
273 if (mMessageHub != nullptr) {
274 // Call a function that locks the MessageRouter mutex
275 mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
276 kEndpointInfos[1].id);
277 }
278 }
279
onSessionOpened(const Session &)280 void onSessionOpened(const Session & /* session */) override {
281 if (mMessageHub != nullptr) {
282 // Call a function that locks the MessageRouter mutex
283 mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
284 kEndpointInfos[1].id);
285 }
286 }
287
setMessageHub(MessageRouter::MessageHub * messageHub)288 void setMessageHub(MessageRouter::MessageHub *messageHub) {
289 mMessageHub = messageHub;
290 }
291
292 private:
293 MessageRouter::MessageHub *mMessageHub = nullptr;
294 };
295
//! Registering a hub whose name is already taken must fail even when its id is
//! unused, so the rejection can only come from the name-uniqueness check.
TEST_F(MessageRouterTest, RegisterMessageHubNameIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  EXPECT_TRUE(messageHub2.has_value());

  // Reuse only the name. Id 3 is fresh and capacity (kMaxMessageHubs == 3) is
  // not exhausted; previously id 1 was reused too, which masked the check.
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub1", /* id= */ 3, callback);
  EXPECT_FALSE(messageHub3.has_value());
}
313
//! Registering a hub whose id is already taken must fail even when its name
//! is new.
TEST_F(MessageRouterTest, RegisterMessageHubIdIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  auto callback = pw::MakeRefCounted<MessageHubCallbackStoreData>(
      /* message= */ nullptr, /* session= */ nullptr);

  std::optional<MessageRouter::MessageHub> firstHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(firstHub.has_value());

  std::optional<MessageRouter::MessageHub> secondHub =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  EXPECT_TRUE(secondHub.has_value());

  // "hub3" is a fresh name, but id 1 already belongs to firstHub.
  std::optional<MessageRouter::MessageHub> duplicateIdHub =
      router.registerMessageHub("hub3", /* id= */ 1, callback);
  EXPECT_FALSE(duplicateIdHub.has_value());
}
331
//! forEachMessageHub must report every registered hub, in the order checked
//! below, with matching names and ids.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT: the optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT so the indexing below can never read past the end.
  ASSERT_EQ(messageHubs.size(), 3u);
  // StrEq compares string contents; EXPECT_EQ would compare addresses if
  // `name` is a C string.
  EXPECT_THAT(messageHubs[0].name, ::testing::StrEq("hub1"));
  EXPECT_THAT(messageHubs[1].name, ::testing::StrEq("hub2"));
  EXPECT_THAT(messageHubs[2].name, ::testing::StrEq("hub3"));
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());
}
365
//! Destroying a MessageHub (resetting its optional) must remove it from the
//! router's hub list while leaving the others intact.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubsWithUnregister) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT: the optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT so the indexing below can never read past the end.
  ASSERT_EQ(messageHubs.size(), 3u);
  // StrEq compares string contents; EXPECT_EQ would compare addresses if
  // `name` is a C string.
  EXPECT_THAT(messageHubs[0].name, ::testing::StrEq("hub1"));
  EXPECT_THAT(messageHubs[1].name, ::testing::StrEq("hub2"));
  EXPECT_THAT(messageHubs[2].name, ::testing::StrEq("hub3"));
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());

  // Clear messageHubs and reset messageHub2
  messageHubs.clear();
  messageHub2.reset();

  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  ASSERT_EQ(messageHubs.size(), 2u);
  EXPECT_THAT(messageHubs[0].name, ::testing::StrEq("hub1"));
  EXPECT_THAT(messageHubs[1].name, ::testing::StrEq("hub3"));
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub3->getId());
}
416
//! Once kMaxMessageHubs hubs are registered, further registrations must fail
//! even with a unique name and id.
TEST_F(MessageRouterTest, RegisterMessageHubTooManyFails) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  static_assert(kMaxMessageHubs == 3);
  constexpr const char *kNames[3] = {"hub1", "hub2", "hub3"};

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  MessageRouter::MessageHub messageHubs[kMaxMessageHubs];
  for (size_t i = 0; i < kMaxMessageHubs; ++i) {
    std::optional<MessageRouter::MessageHub> messageHub =
        router.registerMessageHub(kNames[i], /* id= */ i, callback);
    // ASSERT: dereferencing an empty optional below would be UB.
    ASSERT_TRUE(messageHub.has_value());
    messageHubs[i] = std::move(*messageHub);
  }

  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("shouldfail", /* id= */ kMaxMessageHubs * 2,
                                callback);
  EXPECT_FALSE(messageHub.has_value());
}
438
//! The router must return each endpoint's info through every registered hub
//! (all test hubs share the same endpoint table).
TEST_F(MessageRouterTest, GetEndpointInfo) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  // ASSERT: the optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  for (size_t i = 0; i < kNumEndpoints; ++i) {
    EXPECT_EQ(
        router.getEndpointInfo(messageHub1->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub2->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub3->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
  }
}
467
//! Looking up the advertised service across all hubs (MESSAGE_HUB_ID_INVALID)
//! must resolve to kEndpointInfos[1] on the only registered hub.
TEST_F(MessageRouterTest, GetEndpointForService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());

  std::optional<Endpoint> endpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, kServiceDescriptorForEndpoint2);
  // ASSERT: `endpoint->` below would be UB on an empty optional.
  ASSERT_TRUE(endpoint.has_value());

  EXPECT_EQ(endpoint->messageHubId, messageHub1->getId());
  EXPECT_EQ(endpoint->endpointId, kEndpointInfos[1].id);
}
485
//! The router must confirm that kEndpointInfos[1] on the registered hub offers
//! the advertised service.
TEST_F(MessageRouterTest, DoesEndpointHaveService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  // ASSERT: messageHub1 is dereferenced below.
  ASSERT_TRUE(messageHub1.has_value());

  EXPECT_TRUE(router.doesEndpointHaveService(messageHub1->getId(),
                                             kEndpointInfos[1].id,
                                             kServiceDescriptorForEndpoint2));
}
500
//! forEachService must surface the single service advertised by the hub's
//! callback, tagged with the owning hub and endpoint.
TEST_F(MessageRouterTest, ForEachService) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());

  // Without this flag the test would pass vacuously if the visitor were never
  // invoked, since all checks live inside it.
  bool wasServiceVisited = false;
  router.forEachService([&wasServiceVisited](const MessageHubInfo &hub,
                                             const EndpointInfo &endpoint,
                                             const ServiceInfo &service) {
    wasServiceVisited = true;
    EXPECT_EQ(hub.id, 1);
    EXPECT_EQ(endpoint.id, kEndpointInfos[1].id);
    EXPECT_STREQ(service.serviceDescriptor, kServiceDescriptorForEndpoint2);
    EXPECT_EQ(service.majorVersion, 1);
    EXPECT_EQ(service.minorVersion, 0);
    EXPECT_EQ(service.format, RpcFormat::CUSTOM);
    return true;
  });
  EXPECT_TRUE(wasServiceVisited);
}
523
//! Service lookup must fail both for an unknown descriptor and for a null one.
TEST_F(MessageRouterTest, GetEndpointForServiceBadServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  auto callback = pw::MakeRefCounted<MessageHubCallbackStoreData>(
      /* message= */ nullptr, /* session= */ nullptr);

  std::optional<MessageRouter::MessageHub> hub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(hub.has_value());

  // Unknown descriptor.
  std::optional<Endpoint> unknownServiceEndpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, "SERVICE_THAT_DOES_NOT_EXIST");
  EXPECT_FALSE(unknownServiceEndpoint.has_value());

  // Null descriptor.
  std::optional<Endpoint> nullServiceEndpoint = router.getEndpointForService(
      MESSAGE_HUB_ID_INVALID, /* serviceDescriptor= */ nullptr);
  EXPECT_FALSE(nullServiceEndpoint.has_value());
}
542
//! A session opened between two hubs is visible identically from both.
//! Closing it notifies both callbacks and removes it from both hubs.
TEST_F(MessageRouterTest, RegisterSessionTwoDifferentMessageHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT: the hub optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
592
//! Opening and then closing a session must invoke onSessionOpened and
//! onSessionClosed (with the given reason) on both participating hubs.
TEST_F(MessageRouterTest, RegisterSessionVerifyAllCallbacksAreCalled) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionClosedFromCallback1;
  Session sessionClosedFromCallback2;
  Session sessionOpenedFromCallback1;
  Session sessionOpenedFromCallback2;
  // Value-initialized so a missed onSessionClosed yields a failing comparison
  // instead of a read of an indeterminate value.
  Reason sessionCloseReason1{};
  Reason sessionCloseReason2{};
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(
          /* message= */ nullptr, &sessionClosedFromCallback1,
          &sessionCloseReason1, &sessionOpenedFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(
          /* message= */ nullptr, &sessionClosedFromCallback2,
          &sessionCloseReason2, &sessionOpenedFromCallback2);

  // ASSERT: the hub optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Verify that onSessionOpened is called on both message hubs
  EXPECT_EQ(sessionOpenedFromCallback1.sessionId, sessionId);
  EXPECT_EQ(sessionOpenedFromCallback1.initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionOpenedFromCallback1.initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionOpenedFromCallback1.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionOpenedFromCallback1.peer.endpointId, kEndpointInfos[1].id);

  EXPECT_EQ(sessionOpenedFromCallback2.sessionId, sessionId);
  EXPECT_EQ(sessionOpenedFromCallback2.initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionOpenedFromCallback2.initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionOpenedFromCallback2.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionOpenedFromCallback2.peer.endpointId, kEndpointInfos[1].id);

  // Close the session with a reason
  Reason reason = Reason::TIMEOUT;
  EXPECT_TRUE(messageHub->closeSession(sessionId, reason));

  // Verify that onSessionClosed is called on both message hubs
  EXPECT_EQ(sessionClosedFromCallback1.sessionId, sessionId);
  EXPECT_EQ(sessionClosedFromCallback1.initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionClosedFromCallback1.initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionClosedFromCallback1.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionClosedFromCallback1.peer.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(sessionCloseReason1, reason);

  EXPECT_EQ(sessionClosedFromCallback2.sessionId, sessionId);
  EXPECT_EQ(sessionClosedFromCallback2.initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionClosedFromCallback2.initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionClosedFromCallback2.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionClosedFromCallback2.peer.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(sessionCloseReason2, reason);
}
663
//! If the peer hub rejects a pending session by closing it, the session
//! disappears from both hubs. The initiator's callback sees the rejection
//! reason.
TEST_F(MessageRouterTest, RegisterSessionGetsRejectedAndClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  // Value-initialized so a missed onSessionClosed yields a failing comparison
  // instead of a read of an indeterminate value.
  Reason sessionCloseReason{};
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(
          /* message= */ nullptr, &sessionFromCallback1, &sessionCloseReason);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT: the hub optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);
  Reason reason = Reason::OPEN_ENDPOINT_SESSION_REQUEST_REJECTED;
  messageHub2->closeSession(sessionId, reason);

  // Get session from messageHub and ensure it is deleted
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  EXPECT_FALSE(sessionAfterRegistering.has_value());
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  EXPECT_FALSE(sessionAfterRegistering2.has_value());

  // Verify the initiator's callback observed the close and its reason
  EXPECT_EQ(sessionFromCallback1.sessionId, sessionId);
  EXPECT_EQ(sessionFromCallback1.initiator.messageHubId, messageHub->getId());
  EXPECT_EQ(sessionFromCallback1.initiator.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(sessionFromCallback1.peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionFromCallback1.peer.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(sessionCloseReason, reason);
}
706
//! While the peer hub has not responded to an open request, re-opening the
//! same session returns the same pending session id. It is expected to
//! re-deliver the open request to the peer hub only.
TEST_F(MessageRouterTest, RegisterSessionSecondHubDoesNotRespond) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  bool wasOpenSessionRequestCalled = false;
  bool wasOpenSessionRequestCalled2 = false;
  pw::IntrusivePtr<MessageHubCallbackOpenSessionRequest> callback =
      pw::MakeRefCounted<MessageHubCallbackOpenSessionRequest>(
          &wasOpenSessionRequestCalled);
  pw::IntrusivePtr<MessageHubCallbackOpenSessionRequest> callback2 =
      pw::MakeRefCounted<MessageHubCallbackOpenSessionRequest>(
          &wasOpenSessionRequestCalled2);

  // ASSERT: the hub optionals are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Message Hub 2 does not respond - verify only its callback was called
  EXPECT_FALSE(wasOpenSessionRequestCalled);
  EXPECT_TRUE(wasOpenSessionRequestCalled2);

  // Open session from messageHub:1 to messageHub2:2 - try again. Reset both
  // flags so the checks below observe this second attempt; previously hub2's
  // flag stayed true from the first attempt, making its check vacuous.
  wasOpenSessionRequestCalled = false;
  wasOpenSessionRequestCalled2 = false;
  SessionId sessionId2 = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  // Check the id returned by the retry (previously re-checked sessionId).
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  EXPECT_EQ(sessionId, sessionId2);
  EXPECT_FALSE(wasOpenSessionRequestCalled);
  EXPECT_TRUE(wasOpenSessionRequestCalled2);

  // Respond then close the session
  messageHub2->onSessionOpenComplete(sessionId2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
}
747
// Verifies that a session opened with a service descriptor records that
// descriptor, looks identical from both hubs, and that closing it notifies
// both hubs' callbacks and removes it from both hubs.
TEST_F(MessageRouterTest, RegisterSessionWithServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: every optional checked with ASSERT here is
  // dereferenced afterwards, and dereferencing an empty std::optional is UB.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      kServiceDescriptorForEndpoint2);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  EXPECT_TRUE(sessionAfterRegistering->hasServiceDescriptor);
  EXPECT_STREQ(sessionAfterRegistering->serviceDescriptor,
               kServiceDescriptorForEndpoint2);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs; the
  // callbacks must not have seen the session before the close.
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
800
// Verifies that the same endpoint pair can hold two distinct sessions: one
// opened with a service descriptor and one opened without, and that each is
// reported consistently by both hubs.
TEST_F(MessageRouterTest,
       RegisterSessionWithAndWithoutServiceDescriptorSameEndpoints) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: every optional checked with ASSERT here is
  // dereferenced afterwards, and dereferencing an empty std::optional is UB.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2 with service descriptor
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      kServiceDescriptorForEndpoint2);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub:1 to messageHub2:2 without service descriptor
  SessionId sessionId2 = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  EXPECT_NE(sessionId2, sessionId);

  // Get the first session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  EXPECT_TRUE(sessionAfterRegistering->hasServiceDescriptor);
  EXPECT_STREQ(sessionAfterRegistering->serviceDescriptor,
               kServiceDescriptorForEndpoint2);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Get the second session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering3 =
      messageHub->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering3.has_value());
  EXPECT_EQ(sessionAfterRegistering3->sessionId, sessionId2);
  EXPECT_EQ(sessionAfterRegistering3->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering3->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering3->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering3->peer.endpointId, kEndpointInfos[1].id);
  EXPECT_FALSE(sessionAfterRegistering3->hasServiceDescriptor);
  EXPECT_STREQ(sessionAfterRegistering3->serviceDescriptor, "");
  std::optional<Session> sessionAfterRegistering4 =
      messageHub2->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering4.has_value());
  EXPECT_EQ(*sessionAfterRegistering3, *sessionAfterRegistering4);
}
869
// Verifies that opening a session with a service descriptor the target
// endpoint does not match is rejected with SESSION_ID_INVALID.
TEST_F(MessageRouterTest, RegisterSessionWithBadServiceDescriptor) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:3 using the descriptor
  // meant for endpoint 2 - the open must be rejected
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[2].id,
      kServiceDescriptorForEndpoint2);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
894
// Verifies that destroying one hub's handle closes the sessions it
// participates in: the surviving hub's callback receives the closed session
// and the session can no longer be looked up.
TEST_F(MessageRouterTest, UnregisterMessageHubCausesSessionClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: every optional checked with ASSERT here is
  // dereferenced afterwards, and dereferencing an empty std::optional is UB.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Destroy messageHub2's handle (unregistering it) and verify the session
  // is reported closed to, and removed from, the remaining hub
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  messageHub2.reset();
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
}
941
// Verifies that sessions whose initiator and peer live on the same hub are
// accepted, both for an endpoint talking to itself and for two different
// endpoints on that hub.
TEST_F(MessageRouterTest, RegisterSessionSameMessageHubIsValid) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: the hub is dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:2 to messageHub:2 (an endpoint to itself)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[1].id, messageHub->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub:1 to messageHub:3 (two endpoints, same hub)
  sessionId = messageHub->openSession(kEndpointInfos[0].id, messageHub->getId(),
                                      kEndpointInfos[2].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
}
970
// Verifies that router-generated session IDs never enter the caller-reserved
// range [kReservedSessionId, ...), even after the generator wraps around.
TEST_F(MessageRouterTest, RegisterSessionReservedSessionIdAreRespected) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open and close sessions from messageHub:1 to messageHub2:2 twice as many
  // times as there are non-reserved IDs, so ID generation must wrap around
  for (size_t i = 0; i < kReservedSessionId * 2; ++i) {
    SessionId sessionId = messageHub->openSession(
        kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
    EXPECT_NE(sessionId, SESSION_ID_INVALID);
    // Generated IDs must stay below the reserved range.
    EXPECT_LT(sessionId, kReservedSessionId);
    messageHub2->onSessionOpenComplete(sessionId);
    EXPECT_TRUE(messageHub->closeSession(sessionId));
  }
}
999
// Verifies that a caller-supplied session ID below the reserved range is
// rejected with SESSION_ID_INVALID.
TEST_F(MessageRouterTest, RegisterSessionOpenSessionNotReservedRegionRejected) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2 and provide an invalid
  // session ID (not in the reserved range)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      /* serviceDescriptor= */ nullptr, kReservedSessionId / 2);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
1025
// Verifies that a caller-supplied session ID inside the reserved range is
// accepted and used as the session's ID, and that the session can be
// completed and closed.
TEST_F(MessageRouterTest, RegisterSessionOpenSessionWithReservedSessionId) {
  constexpr SessionId kReservedSessionId = 25;
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router(
      kReservedSessionId);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2 and provide a reserved
  // session ID
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id,
      /* serviceDescriptor= */ nullptr, kReservedSessionId);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  // The requested reserved ID should be the one assigned.
  EXPECT_EQ(sessionId, kReservedSessionId);
  messageHub2->onSessionOpenComplete(sessionId);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
}
1053
// Verifies that a session may connect two endpoints that share the same
// endpoint ID when they live on different hubs.
TEST_F(MessageRouterTest, RegisterSessionDifferentMessageHubsSameEndpoints) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:1 (same endpoint ID on
  // both hubs)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);
}
1078
// Verifies that opening a session to an endpoint ID the peer hub does not
// expose is rejected with SESSION_ID_INVALID.
TEST_F(MessageRouterTest,
       RegisterSessionTwoDifferentMessageHubsInvalidEndpoint) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT rather than EXPECT: both hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub with other non-registered endpoint - not
  // valid
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), /* toEndpointId= */ 10);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
1102
// Verifies that a hub not participating in a session can neither look it up
// nor close it, that its attempts leave the session intact, and that only the
// participants' callbacks are notified when the session is closed.
TEST_F(MessageRouterTest, ThirdMessageHubTriesToFindOthersSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback2);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      &sessionFromCallback3);

  // ASSERT rather than EXPECT: every optional checked with ASSERT here is
  // dereferenced afterwards, and dereferencing an empty std::optional is UB.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Third message hub tries to find the session - not found
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
  // Third message hub tries to close the session - not found
  EXPECT_FALSE(messageHub3->closeSession(sessionId));

  // Get session from messageHub and compare it with messageHub2 again
  sessionAfterRegistering = messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  sessionAfterRegistering2 = messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs, while the
  // third hub's callback never sees it
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
1178
// Verifies three sessions arranged in a ring across three hubs (1->2, 2->3,
// 3->1): each is visible only on its two participating hubs, and closing each
// from its receiving hub removes it from every hub.
TEST_F(MessageRouterTest, ThreeMessageHubsAndThreeSessions) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);

  // ASSERT rather than EXPECT: every optional checked with ASSERT here is
  // dereferenced afterwards, and dereferencing an empty std::optional is UB.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // Get sessions and compare
  // Find session: MessageHub1:1 -> MessageHub2:2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Find session: MessageHub2:2 -> MessageHub3:3
  sessionAfterRegistering = messageHub2->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  sessionAfterRegistering2 = messageHub3->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Find session: MessageHub3:3 -> MessageHub1:1
  sessionAfterRegistering = messageHub3->getSessionWithId(sessionId3);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  sessionAfterRegistering2 = messageHub->getSessionWithId(sessionId3);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId3).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close sessions from receivers and verify they are closed on all hubs
  EXPECT_TRUE(messageHub2->closeSession(sessionId));
  EXPECT_TRUE(messageHub3->closeSession(sessionId2));
  EXPECT_TRUE(messageHub->closeSession(sessionId3));
  for (SessionId id : {sessionId, sessionId2, sessionId3}) {
    EXPECT_FALSE(messageHub->getSessionWithId(id).has_value());
    EXPECT_FALSE(messageHub2->getSessionWithId(id).has_value());
    EXPECT_FALSE(messageHub3->getSessionWithId(id).has_value());
  }
}
1256
// Verifies that messages sent in both directions over an open session are
// delivered to the opposite hub's callback with the correct session, sender,
// recipient, type, permissions and payload.
TEST_F(MessageRouterTest, SendMessageToSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();
  pw::UniquePtr<std::byte[]> messageData =
      allocator.MakeUniqueArray<std::byte>(kMessageSize);
  // Fill the payload with 1..kMessageSize (tied to the constant, not a
  // duplicated literal).
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
                                                      &sessionFromCallback2);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
                                                      &sessionFromCallback3);

  // ASSERT rather than EXPECT: the hubs are dereferenced below, and
  // dereferencing an empty std::optional is undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // Send message from messageHub:1 to messageHub2:2
  ASSERT_TRUE(messageHub->sendMessage(std::move(messageData),
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));
  EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback2.messageType, 1);
  EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
  // ASSERT: the loop below indexes data[i] up to kMessageSize.
  ASSERT_EQ(messageFromCallback2.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
  }

  messageData = allocator.MakeUniqueArray<std::byte>(kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  // Send message from messageHub2:2 to messageHub:1
  ASSERT_TRUE(messageHub2->sendMessage(std::move(messageData),
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));
  EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback1.messageType, 2);
  EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
  // ASSERT: the loop below indexes data[i] up to kMessageSize.
  ASSERT_EQ(messageFromCallback1.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
  }
}
1348
// Verifies that a session opened by messageHub:1 but not yet completed by the
// peer (messageHub2 has not called onSessionOpenComplete()) rejects messages,
// and that the same session carries messages in both directions once it has
// been completed.
TEST_F(MessageRouterTest, SendMessageOnHalfOpenSessionIsRejected) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();

  // Allocates a fresh kMessageSize-byte payload holding 1, 2, ..., kMessageSize.
  // Each send needs its own payload because ownership is moved into
  // sendMessage().
  auto makeMessageData = [&]() {
    pw::UniquePtr<std::byte[]> data =
        allocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  Message messageFromCallback1;
  Message messageFromCallback2;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
                                                      &sessionFromCallback2);

  // ASSERT rather than EXPECT: both optionals are dereferenced below, which
  // would be undefined behavior if registration failed.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2 but do not complete it.
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Try to send a message from messageHub:1 to messageHub2:2 - should fail.
  EXPECT_FALSE(messageHub->sendMessage(makeMessageData(),
                                       /* messageType= */ 1,
                                       /* messagePermissions= */ 0, sessionId));

  // Now complete the session.
  messageHub2->onSessionOpenComplete(sessionId);

  // Send message from messageHub:1 to messageHub2:2.
  ASSERT_TRUE(messageHub->sendMessage(makeMessageData(),
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));
  EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback2.messageType, 1);
  EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
  EXPECT_EQ(messageFromCallback2.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
  }

  // Send message from messageHub2:2 to messageHub:1 on the same session.
  ASSERT_TRUE(messageHub2->sendMessage(makeMessageData(),
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));
  EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback1.messageType, 2);
  EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
  EXPECT_EQ(messageFromCallback1.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
  }
}
1432
// Verifies that a payload wrapped by a CallbackAllocator is delivered without
// copying (the receiver sees the original buffer pointer). Also verifies that
// the free callback fires only when the receiver releases the data, for a
// message in each direction of a session.
TEST_F(MessageRouterTest, SendMessageToSessionUsingPointerAndFreeCallback) {
  // Context handed to the free callback: where to report the result, plus the
  // buffer and length the callback expects to be asked to free.
  struct FreeCallbackContext {
    bool *freeCallbackCalled;
    std::byte *message;
    size_t length;
  };

  pw::Vector<CallbackAllocator<FreeCallbackContext>::CallbackRecord, 10>
      freeCallbackRecords;
  // The callback reports "called" only if it is invoked with exactly the
  // buffer and length recorded in its context.
  CallbackAllocator<FreeCallbackContext> allocator(
      [](std::byte *message, size_t length, FreeCallbackContext &&context) {
        *context.freeCallbackCalled =
            message == context.message && length == context.length;
      },
      freeCallbackRecords);

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  std::byte messageData[kMessageSize];
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
                                                      &sessionFromCallback2);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
                                                      &sessionFromCallback3);

  // ASSERT rather than EXPECT: every optional is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // Send message from messageHub:1 to messageHub2:2
  bool freeCallbackCalled = false;
  FreeCallbackContext freeCallbackContext = {
      .freeCallbackCalled = &freeCallbackCalled,
      .message = messageData,
      .length = kMessageSize,
  };
  pw::UniquePtr<std::byte[]> data = allocator.MakeUniqueArrayWithCallback(
      messageData, kMessageSize, std::move(freeCallbackContext));
  ASSERT_NE(data.get(), nullptr);

  ASSERT_TRUE(messageHub->sendMessage(std::move(data),
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));
  EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback2.messageType, 1);
  EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
  EXPECT_EQ(messageFromCallback2.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
  }

  // The receiver holds the original buffer; releasing it must run the free
  // callback, and not before.
  EXPECT_FALSE(freeCallbackCalled);
  EXPECT_EQ(messageFromCallback2.data.get(), messageData);
  messageFromCallback2.data.Reset();
  EXPECT_TRUE(freeCallbackCalled);

  // Send message from messageHub2:2 to messageHub:1
  freeCallbackCalled = false;
  FreeCallbackContext freeCallbackContext2 = {
      .freeCallbackCalled = &freeCallbackCalled,
      .message = messageData,
      .length = kMessageSize,
  };
  data = allocator.MakeUniqueArrayWithCallback(messageData, kMessageSize,
                                               std::move(freeCallbackContext2));
  ASSERT_NE(data.get(), nullptr);

  ASSERT_TRUE(messageHub2->sendMessage(std::move(data),
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));
  EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback1.messageType, 2);
  EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
  EXPECT_EQ(messageFromCallback1.data.size(), kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
  }

  // Same ownership check for the reverse direction.
  EXPECT_FALSE(freeCallbackCalled);
  EXPECT_EQ(messageFromCallback1.data.get(), messageData);
  messageFromCallback1.data.Reset();
  EXPECT_TRUE(freeCallbackCalled);
}
1564
// Verifies that a hub cannot send on a session it does not participate in.
// Three sessions are opened in a ring (1->2, 2->3, 3->1). Each hub then tries
// to send on the one session that excludes it, and every attempt must fail.
TEST_F(MessageRouterTest, SendMessageToSessionInvalidHubAndSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();

  // Allocates a fresh kMessageSize-byte payload holding 1, 2, ..., kMessageSize.
  // A new payload is used for every send: ownership moves into sendMessage(),
  // so reusing one moved-from pointer would send null data and the rejection
  // could come from the missing payload rather than the invalid session.
  auto makeMessageData = [&]() {
    pw::UniquePtr<std::byte[]> data =
        allocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback1,
                                                      &sessionFromCallback1);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback2,
                                                      &sessionFromCallback2);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback3 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(&messageFromCallback3,
                                                      &sessionFromCallback3);

  // ASSERT rather than EXPECT: every optional is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // messageHub:1 is not a party to sessionId2 (messageHub2:2 <-> messageHub3:3).
  EXPECT_FALSE(messageHub->sendMessage(makeMessageData(),
                                       /* messageType= */ 1,
                                       /* messagePermissions= */ 0,
                                       sessionId2));
  // messageHub2:2 is not a party to sessionId3 (messageHub3:3 <-> messageHub:1).
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(),
                                        /* messageType= */ 2,
                                        /* messagePermissions= */ 3,
                                        sessionId3));
  // messageHub3:3 is not a party to sessionId (messageHub:1 <-> messageHub2:2).
  EXPECT_FALSE(messageHub3->sendMessage(makeMessageData(),
                                        /* messageType= */ 2,
                                        /* messagePermissions= */ 3,
                                        sessionId));
}
1633
// Verifies the case where the receiving hub's message callback reports
// failure. The test expects sendMessage() to return false and the session to
// disappear from every hub. Later sends on that session, from either former
// participant, must fail without reaching any callback.
TEST_F(MessageRouterTest, SendMessageToSessionCallbackFailureClosesSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();

  // Allocates a fresh kMessageSize-byte payload holding 1, 2, ..., kMessageSize.
  // Each send needs its own payload because ownership is moved into
  // sendMessage().
  auto makeMessageData = [&]() {
    pw::UniquePtr<std::byte[]> data =
        allocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  bool wasMessageReceivedCalled1 = false;
  bool wasMessageReceivedCalled2 = false;
  bool wasMessageReceivedCalled3 = false;
  pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback1 =
      pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
          &wasMessageReceivedCalled1,
          /* wasSessionClosedCalled= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback2 =
      pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
          &wasMessageReceivedCalled2,
          /* wasSessionClosedCalled= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackAlwaysFails> callback3 =
      pw::MakeRefCounted<MessageHubCallbackAlwaysFails>(
          &wasMessageReceivedCalled3,
          /* wasSessionClosedCalled= */ nullptr);

  // ASSERT rather than EXPECT: every optional is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback1);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // Before sending: no callback has fired, and sessionId2 is visible only to
  // its two participants.
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_FALSE(wasMessageReceivedCalled3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_TRUE(messageHub2->getSessionWithId(sessionId2).has_value());
  EXPECT_TRUE(messageHub3->getSessionWithId(sessionId2).has_value());

  // Send message from messageHub2:2 to messageHub3:3. messageHub3's callback
  // is invoked and fails, so the send reports failure and the session is gone
  // from every hub.
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(),
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_TRUE(wasMessageReceivedCalled3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId2).has_value());
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId2).has_value());

  // Try to send a message on the same session from either former
  // participant - should fail without reaching any callback.
  wasMessageReceivedCalled1 = false;
  wasMessageReceivedCalled2 = false;
  wasMessageReceivedCalled3 = false;
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(),
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(messageHub3->sendMessage(makeMessageData(),
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_FALSE(wasMessageReceivedCalled3);
}
1731
// Verifies that MessageHub APIs can be called from inside a MessageHub
// callback without deadlocking. Each hub's callback is handed a pointer to its
// own MessageHub so it can call back into the router while being invoked.
TEST_F(MessageRouterTest, MessageHubCallbackCanCallOtherMessageHubAPIs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::allocator::LibCAllocator allocator = pw::allocator::GetLibCAllocator();

  // Allocates a fresh kMessageSize-byte payload holding 1, 2, ..., kMessageSize.
  // Each send needs its own payload because ownership is moved into
  // sendMessage().
  auto makeMessageData = [&]() {
    pw::UniquePtr<std::byte[]> data =
        allocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
      callback = pw::MakeRefCounted<
          MessageHubCallbackCallsMessageHubApisDuringCallback>();
  pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
      callback2 = pw::MakeRefCounted<
          MessageHubCallbackCallsMessageHubApisDuringCallback>();
  pw::IntrusivePtr<MessageHubCallbackCallsMessageHubApisDuringCallback>
      callback3 = pw::MakeRefCounted<
          MessageHubCallbackCallsMessageHubApisDuringCallback>();

  // ASSERT rather than EXPECT: value() is called on each optional right away.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  callback->setMessageHub(&messageHub.value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  callback2->setMessageHub(&messageHub2.value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());
  callback3->setMessageHub(&messageHub3.value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
  messageHub2->onSessionOpenComplete(sessionId);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);
  messageHub3->onSessionOpenComplete(sessionId2);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);
  messageHub->onSessionOpenComplete(sessionId3);

  // Send message from messageHub:1 to messageHub2:2
  EXPECT_TRUE(messageHub->sendMessage(makeMessageData(),
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));

  // Send message from messageHub2:2 to messageHub:1
  EXPECT_TRUE(messageHub2->sendMessage(makeMessageData(),
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));

  // Close all sessions
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_TRUE(messageHub2->closeSession(sessionId2));
  EXPECT_TRUE(messageHub3->closeSession(sessionId3));

  // Reaching this point means every callback invoked above was able to call
  // back into its MessageHub. If the router held its lock while running a
  // callback, this test would deadlock and time out instead.
}
1805
// Verifies that forEachEndpointOfHub() on a registered hub succeeds and
// reports exactly the endpoints in kEndpointInfos, in order.
TEST_F(MessageRouterTest, ForEachEndpointOfHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());

  // The visitor collects every endpoint. Returning false presumably means
  // "keep iterating" (the test expects all endpoints) - confirm against the
  // forEachEndpointOfHub() contract.
  DynamicVector<EndpointInfo> endpoints;
  EXPECT_TRUE(router.forEachEndpointOfHub(
      /* messageHubId= */ 1, [&endpoints](const EndpointInfo &info) {
        endpoints.push_back(info);
        return false;
      }));

  // ASSERT rather than EXPECT: the loop below indexes kEndpointInfos by
  // positions in `endpoints`, which would read out of bounds if more
  // endpoints than kNumEndpoints were reported.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1831
// Verifies that forEachEndpoint() visits every endpoint of the single
// registered hub. Each visit must carry that hub's id and name alongside an
// endpoint matching kEndpointInfos, in order.
TEST_F(MessageRouterTest, ForEachEndpoint) {
  constexpr const char *kHubName = "hub1";
  constexpr MessageHubId kHubId = 1;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub(kHubName, kHubId, callback);
  ASSERT_TRUE(messageHub.has_value());

  DynamicVector<std::pair<MessageHubInfo, EndpointInfo>> endpoints;
  router.forEachEndpoint(
      [&endpoints](const MessageHubInfo &hubInfo, const EndpointInfo &info) {
        endpoints.push_back(std::make_pair(hubInfo, info));
      });

  // ASSERT rather than EXPECT: the loop below indexes kEndpointInfos by
  // positions in `endpoints`, which would read out of bounds if more
  // endpoints than kNumEndpoints were reported.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].first.id, kHubId);
    EXPECT_STREQ(endpoints[i].first.name, kHubName);

    EXPECT_EQ(endpoints[i].second.id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].second.name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].second.version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].second.type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].second.requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1862
// Verifies that forEachEndpointOfHub() reports failure and never invokes the
// visitor when asked about a hub id (2) that was never registered.
TEST_F(MessageRouterTest, ForEachEndpointOfHubInvalidHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> hubCallback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> registeredHub =
      router.registerMessageHub("hub1", /* id= */ 1, hubCallback);
  EXPECT_TRUE(registeredHub.has_value());

  // Only hub 1 exists, so querying hub 2 must fail without visiting anything.
  DynamicVector<EndpointInfo> visited;
  const bool hubFound = router.forEachEndpointOfHub(
      /* messageHubId= */ 2, [&visited](const EndpointInfo &info) {
        visited.push_back(info);
        return false;
      });
  EXPECT_FALSE(hubFound);
  EXPECT_EQ(visited.size(), 0u);
}
1880
// Verifies that registering an endpoint on hub1 notifies the other registered
// hub (hub2) through its callback.
TEST_F(MessageRouterTest, RegisterEndpointCallbacksAreCalled) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  // messageHub2 is not used directly. Holding the handle keeps hub2, and with
  // it callback2, registered for the rest of the test.
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Register the endpoint and verify that the other hub's callback was called.
  EXPECT_TRUE(messageHub->registerEndpoint(kEndpointInfos[0].id));
  EXPECT_TRUE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
}
1901
// Verifies that endpoint registration and unregistration on hub1 are reported
// only to the other hub's callback, never to hub1's own callback.
TEST_F(MessageRouterTest, UnregisterEndpointCallbacksAreCalled) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  pw::IntrusivePtr<MessageHubCallbackStoreData> callback2 =
      pw::MakeRefCounted<MessageHubCallbackStoreData>(/* message= */ nullptr,
                                                      /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  // messageHub2 is not used directly. Holding the handle keeps hub2, and with
  // it callback2, registered for the rest of the test.
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Register the endpoint and verify that the callbacks were called
  // only on the other hub.
  EXPECT_TRUE(messageHub->registerEndpoint(kEndpointInfos[0].id));
  EXPECT_FALSE(callback->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
  EXPECT_TRUE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));

  // Unregister the endpoint and verify that the callbacks were called
  // only on the other hub.
  EXPECT_TRUE(messageHub->unregisterEndpoint(kEndpointInfos[0].id));
  EXPECT_FALSE(callback->hasEndpointBeenRegistered(messageHub->getId(),
                                                   kEndpointInfos[0].id));
  EXPECT_FALSE(callback2->hasEndpointBeenRegistered(messageHub->getId(),
                                                    kEndpointInfos[0].id));
}
1933
1934 MATCHER_P(HubMatcher, id, "Matches id in MessageHubInfo") {
1935 return arg.id == id;
1936 }
1937
// Verifies that an already-registered hub is told, through its callback, when
// another hub registers and again when that hub's MessageHub handle is
// destroyed.
TEST_F(MessageRouterTest, OnRegisterAndUnregisterHub) {
  constexpr MessageHubId kFirstHubId = 1;
  constexpr MessageHubId kSecondHubId = 2;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MockMessageHubCallback> firstCallback =
      pw::MakeRefCounted<MockMessageHubCallback>();
  pw::IntrusivePtr<MockMessageHubCallback> secondCallback =
      pw::MakeRefCounted<MockMessageHubCallback>();

  std::optional<MessageRouter::MessageHub> firstHub =
      router.registerMessageHub("hub1", kFirstHubId, firstCallback);
  ASSERT_TRUE(firstHub.has_value());

  // Registering the second hub must be announced to the first hub.
  EXPECT_CALL(*firstCallback, onHubRegistered(HubMatcher(kSecondHubId)));
  std::optional<MessageRouter::MessageHub> secondHub =
      router.registerMessageHub("hub2", kSecondHubId, secondCallback);
  ASSERT_TRUE(secondHub.has_value());

  // Dropping the second hub's handle must be announced as an unregistration.
  EXPECT_CALL(*firstCallback, onHubUnregistered(kSecondHubId));
  secondHub.reset();
}
1957
1958 MATCHER_P(SessionIdMatcher, id, "Matches id in Session") {
1959 return arg.sessionId == id;
1960 }
1961
// Verifies the case where both ends of a session live on the same hub. The
// hub's onSessionOpened and onSessionClosed callbacks must each fire exactly
// once, both for two distinct endpoints and for an endpoint paired with
// itself.
TEST_F(MessageRouterTest, SessionCallbacksAreCalledOnceSameHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  pw::IntrusivePtr<MockMessageHubCallback> hub1Callback =
      pw::MakeRefCounted<MockMessageHubCallback>();
  constexpr MessageHubId hub1Id = 1;
  std::optional<MessageRouter::MessageHub> hub1 =
      router.registerMessageHub("hub1", hub1Id, hub1Callback);
  ASSERT_TRUE(hub1.has_value());

  // NOTE(review): presumably backs the mock's endpoint enumeration with the
  // helper from message_router_mocks.h so openSession() can find the
  // endpoints - confirm there.
  ON_CALL(*hub1Callback, forEachEndpoint).WillByDefault(forEachEndpoint);

  // Try with different endpoints
  SessionId sessionId = hub1->openSession(kEndpointInfos[0].id, hub1->getId(),
                                          kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  EXPECT_CALL(*hub1Callback, onSessionOpened(_)).Times(1);
  hub1->onSessionOpenComplete(sessionId);

  EXPECT_CALL(*hub1Callback, onSessionClosed(SessionIdMatcher(sessionId), _))
      .Times(1);
  EXPECT_TRUE(hub1->closeSession(sessionId));

  // Try with the same endpoint
  SessionId sessionId2 = hub1->openSession(kEndpointInfos[1].id, hub1->getId(),
                                           kEndpointInfos[1].id);
  ASSERT_NE(sessionId2, SESSION_ID_INVALID);

  EXPECT_CALL(*hub1Callback, onSessionOpened(_)).Times(1);
  hub1->onSessionOpenComplete(sessionId2);

  EXPECT_CALL(*hub1Callback, onSessionClosed(SessionIdMatcher(sessionId2), _))
      .Times(1);
  EXPECT_TRUE(hub1->closeSession(sessionId2));
}
1997
1998 } // namespace
1999 } // namespace chre::message
2000