• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2025 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cstddef>
#include <cstring>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "test_base.h"
25 
26 #include "chre/core/event_loop_manager.h"
27 #include "chre/core/host_message_hub_manager.h"
28 #include "chre/platform/memory.h"
29 #include "chre/util/system/message_common.h"
30 #include "chre/util/system/message_router.h"
31 #include "chre/util/system/message_router_mocks.h"
32 #include "chre_api/chre/event.h"
33 
34 #include "pw_allocator/libc_allocator.h"
35 #include "pw_allocator/unique_ptr.h"
36 #include "pw_function/function.h"
37 
38 namespace chre {
39 namespace {
40 
41 using ::chre::message::EndpointId;
42 using ::chre::message::EndpointInfo;
43 using ::chre::message::EndpointType;
44 using ::chre::message::Message;
45 using ::chre::message::MessageHubId;
46 using ::chre::message::MessageHubInfo;
47 using ::chre::message::MessageRouter;
48 using ::chre::message::MessageRouterSingleton;
49 using ::chre::message::MockMessageHubCallback;
50 using ::chre::message::Reason;
51 using ::chre::message::RpcFormat;
52 using ::chre::message::ServiceInfo;
53 using ::chre::message::Session;
54 using ::chre::message::SessionId;
55 using ::testing::_;
56 using ::testing::AnyNumber;
57 using ::testing::Expectation;
58 using ::testing::NiceMock;
59 using ::testing::UnorderedElementsAreArray;
60 
61 class MockHostCallback : public HostMessageHubManager::HostCallback {
62  public:
63   MOCK_METHOD(void, onReset, (), (override));
64   MOCK_METHOD(void, onHubRegistered, (const MessageHubInfo &), (override));
65   MOCK_METHOD(void, onHubUnregistered, (MessageHubId), (override));
66   MOCK_METHOD(void, onEndpointRegistered, (MessageHubId, const EndpointInfo &),
67               (override));
68   MOCK_METHOD(void, onEndpointService,
69               (MessageHubId, EndpointId, const ServiceInfo &), (override));
70   MOCK_METHOD(void, onEndpointReady, (MessageHubId, EndpointId), (override));
71   MOCK_METHOD(void, onEndpointUnregistered, (MessageHubId, EndpointId),
72               (override));
73   MOCK_METHOD(bool, onMessageReceived,
74               (MessageHubId, SessionId, pw::UniquePtr<std::byte[]> &&, uint32_t,
75                uint32_t),
76               (override));
77   MOCK_METHOD(void, onSessionOpenRequest, (const Session &), (override));
78   MOCK_METHOD(void, onSessionOpened, (MessageHubId, SessionId), (override));
79   MOCK_METHOD(void, onSessionClosed, (MessageHubId, SessionId, Reason),
80               (override));
81 };
82 
getManager()83 HostMessageHubManager &getManager() {
84   return EventLoopManagerSingleton::get()->getHostMessageHubManager();
85 }
86 
getRouter()87 MessageRouter &getRouter() {
88   return *MessageRouterSingleton::get();
89 }
90 
91 const char *kServiceName = "test_service";
92 const ServiceInfo kService(kServiceName, 0, 0, RpcFormat::CUSTOM);
93 const EndpointInfo kEndpoints[] = {
94     EndpointInfo(0x1, nullptr, 0, EndpointType::GENERIC, 0),
95     EndpointInfo(0x2, nullptr, 0, EndpointType::GENERIC, 0)};
96 const EndpointInfo kExtraEndpoint(0x3, nullptr, 0, EndpointType::GENERIC, 0);
97 const EndpointId kEndpointIds[] = {0x1, 0x2};
98 const char *kEmbeddedHubName = "embedded hub";
99 const MessageHubInfo kEmbeddedHub{.id = CHRE_PLATFORM_ID + 1,
100                                   .name = kEmbeddedHubName};
101 const char *kHostHubName = "host hub";
102 const MessageHubInfo kHostHub{.id = kEmbeddedHub.id + 1, .name = kHostHubName};
103 
104 class HostMessageHubTest : public TestBase {
105  public:
HostMessageHubTest()106   HostMessageHubTest() : TestBase() {
107     for (const auto &endpoint : kEndpoints) {
108       std::vector<ServiceInfo> services;
109       if (endpoint.id > 0x1) services.push_back(kService);
110       mEmbeddedEndpoints.push_back({endpoint, std::move(services)});
111     }
112   }
113 
SetUp()114   void SetUp() override {
115     TestBase::SetUp();
116 
117     mEmbeddedHubCb = pw::MakeRefCounted<NiceMock<MockMessageHubCallback>>();
118     ASSERT_NE(mEmbeddedHubCb.get(), nullptr);
119 
120     // Specify uninteresting behaviors for the mock embedded hub callback.
121     ON_CALL(*mEmbeddedHubCb, forEachEndpoint(_))
122         .WillByDefault(
123             [this](const pw::Function<bool(const EndpointInfo &)> &fn) {
124               for (const auto &endpoint : mEmbeddedEndpoints)
125                 if (fn(endpoint.first)) return;
126             });
127     ON_CALL(*mEmbeddedHubCb, getEndpointInfo(_))
128         .WillByDefault([this](EndpointId id) -> std::optional<EndpointInfo> {
129           for (const auto &endpoint : mEmbeddedEndpoints)
130             if (endpoint.first.id == id) return endpoint.first;
131           return {};
132         });
133     ON_CALL(*mEmbeddedHubCb, getEndpointForService(_))
134         .WillByDefault(
135             [this](const char *service) -> std::optional<EndpointId> {
136               for (const auto &endpoint : mEmbeddedEndpoints) {
137                 for (const auto &serviceInfo : endpoint.second) {
138                   if (!std::strcmp(serviceInfo.serviceDescriptor, service))
139                     return endpoint.first.id;
140                 }
141               }
142               return {};
143             });
144     ON_CALL(*mEmbeddedHubCb, doesEndpointHaveService(_, _))
145         .WillByDefault([this](EndpointId id, const char *service) {
146           for (const auto &endpoint : mEmbeddedEndpoints) {
147             if (endpoint.first.id != id) continue;
148             for (const auto &serviceInfo : endpoint.second) {
149               if (!std::strcmp(serviceInfo.serviceDescriptor, service))
150                 return true;
151             }
152           }
153           return false;
154         });
155     ON_CALL(*mEmbeddedHubCb, forEachService(_))
156         .WillByDefault(
157             [this](const pw::Function<bool(const EndpointInfo &,
158                                            const message::ServiceInfo &)> &fn) {
159               for (const auto &endpoint : mEmbeddedEndpoints) {
160                 for (const auto &serviceInfo : endpoint.second) {
161                   if (fn(endpoint.first, serviceInfo)) return;
162                 }
163               }
164             });
165 
166     // We mostly don't care about this. Individual tests may override this
167     // behavior.
168     EXPECT_CALL(*mEmbeddedHubCb, onHubRegistered(_)).Times(AnyNumber());
169     EXPECT_CALL(*mEmbeddedHubCb, onHubUnregistered(_)).Times(AnyNumber());
170     EXPECT_CALL(mHostCallback, onHubRegistered(_)).Times(AnyNumber());
171     EXPECT_CALL(mHostCallback, onHubUnregistered(_)).Times(AnyNumber());
172 
173     // Register the embedded message hub with MessageRouter.
174     auto maybeEmbeddedHub = getRouter().registerMessageHub(
175         kEmbeddedHubName, kEmbeddedHub.id, mEmbeddedHubCb);
176     if (maybeEmbeddedHub) {
177       mEmbeddedHubIntf = std::move(*maybeEmbeddedHub);
178     } else {
179       FAIL() << "Failed to register test embedded message hub";
180     }
181 
182     // Initialize the manager with a mock HostCallback.
183     getManager().onHostTransportReady(mHostCallback);
184   }
185 
TearDown()186   void TearDown() override {
187     EXPECT_CALL(mHostCallback, onReset());
188     EXPECT_CALL(mHostCallback, onHubRegistered(_)).Times(AnyNumber());
189     EXPECT_CALL(mHostCallback, onEndpointRegistered(_, _)).Times(AnyNumber());
190     EXPECT_CALL(mHostCallback, onEndpointService(_, _, _)).Times(AnyNumber());
191     EXPECT_CALL(mHostCallback, onEndpointReady(_, _)).Times(AnyNumber());
192     getManager().reset();
193     mEmbeddedHubIntf.unregister();
194 
195     TestBase::TearDown();
196   }
197 
getHostEndpointServices()198   DynamicVector<ServiceInfo> getHostEndpointServices() {
199     auto serviceName =
200         static_cast<char *>(memoryAlloc(std::strlen(kServiceName) + 1));
201     std::memcpy(serviceName, kServiceName, std::strlen(kServiceName) + 1);
202     DynamicVector<ServiceInfo> services;
203     services.emplace_back(serviceName, kService.majorVersion,
204                           kService.minorVersion, kService.format);
205     return services;
206   }
207 
expectOnEmbeddedEndpoint(const std::pair<EndpointInfo,std::vector<ServiceInfo>> & endpoint,Expectation * sequence)208   void expectOnEmbeddedEndpoint(
209       const std::pair<EndpointInfo, std::vector<ServiceInfo>> &endpoint,
210       Expectation *sequence) {
211     Expectation previous;
212     if (sequence) {
213       previous =
214           EXPECT_CALL(mHostCallback,
215                       onEndpointRegistered(kEmbeddedHub.id, endpoint.first))
216               .After(*sequence)
217               .RetiresOnSaturation();
218     } else {
219       previous =
220           EXPECT_CALL(mHostCallback,
221                       onEndpointRegistered(kEmbeddedHub.id, endpoint.first))
222               .RetiresOnSaturation();
223     }
224     for (const auto &service : endpoint.second) {
225       previous = EXPECT_CALL(mHostCallback,
226                              onEndpointService(kEmbeddedHub.id,
227                                                endpoint.first.id, service))
228                      .After(previous)
229                      .RetiresOnSaturation();
230     }
231     EXPECT_CALL(mHostCallback,
232                 onEndpointReady(kEmbeddedHub.id, endpoint.first.id))
233         .After(previous)
234         .RetiresOnSaturation();
235   }
236 
237  protected:
238   pw::IntrusivePtr<NiceMock<MockMessageHubCallback>> mEmbeddedHubCb;
239   MessageRouter::MessageHub mEmbeddedHubIntf;
240   MockHostCallback mHostCallback;
241 
242   std::vector<std::pair<EndpointInfo, std::vector<ServiceInfo>>>
243       mEmbeddedEndpoints;
244 };
245 
246 MATCHER_P(HubIdMatcher, id, "Matches a MessageHubInfo by id") {
247   return arg.id == id;
248 }
249 
TEST_F(HostMessageHubTest, Reset) {
  // Each reset() must produce onReset(), then onHubRegistered() for both the
  // default CHRE hub and the embedded test hub, then the full endpoint
  // announcement for every embedded endpoint (after the embedded hub).
  auto expectResetSequence = [this] {
    const Expectation resetCall =
        EXPECT_CALL(mHostCallback, onReset()).RetiresOnSaturation();
    EXPECT_CALL(mHostCallback, onHubRegistered(HubIdMatcher(CHRE_PLATFORM_ID)))
        .After(resetCall)
        .RetiresOnSaturation();
    Expectation embeddedHubCall =
        EXPECT_CALL(mHostCallback, onHubRegistered(kEmbeddedHub))
            .After(resetCall)
            .RetiresOnSaturation();
    for (const auto &endpoint : mEmbeddedEndpoints) {
      expectOnEmbeddedEndpoint(endpoint, &embeddedHubCall);
    }
  };

  // After a reset, MessageRouter may only report embedded-hub endpoints.
  auto expectOnlyEmbeddedEndpoints = [] {
    getRouter().forEachEndpoint(
        [](const MessageHubInfo &hub, const EndpointInfo &) {
          EXPECT_EQ(hub.id, kEmbeddedHub.id);
        });
  };

  // First reset(): no host hubs or endpoints exist yet.
  expectResetSequence();
  getManager().reset();
  expectOnlyEmbeddedEndpoints();

  // Register a host hub with one endpoint; the second reset() must remove
  // both from MessageRouter's view.
  getManager().registerHub(kHostHub);
  getManager().registerEndpoint(kHostHub.id, kEndpoints[0], {});
  expectResetSequence();
  getManager().reset();
  expectOnlyEmbeddedEndpoints();
}
288 
TEST_F(HostMessageHubTest, RegisterAndUnregisterHub) {
  HostMessageHubManager &manager = getManager();
  MessageRouter &router = getRouter();

  // Before registration MessageRouter does not know the host hub at all.
  EXPECT_FALSE(router.forEachEndpointOfHub(
      kHostHub.id, [](const EndpointInfo &) { return true; }));

  // Registration is announced to the embedded hub and the hub becomes known.
  EXPECT_CALL(*mEmbeddedHubCb, onHubRegistered(kHostHub));
  manager.registerHub(kHostHub);
  EXPECT_TRUE(router.forEachEndpointOfHub(
      kHostHub.id, [](const EndpointInfo &) { return true; }));

  // Unregistration is announced too. The hub record stays registered with
  // MessageRouter to avoid races with unregistering message hubs, but none of
  // its endpoints may be reported any longer.
  EXPECT_CALL(*mEmbeddedHubCb, onHubUnregistered(kHostHub.id));
  manager.unregisterHub(kHostHub.id);
  router.forEachEndpointOfHub(kHostHub.id, [](const EndpointInfo &) {
    ADD_FAILURE();
    return true;
  });
}
307 
308 // Hubs are expected to be static over the runtime, i.e. regardless of when a
309 // hub is registered, the total set of hubs is fixed. A different hub cannot
310 // take the slot of an unregistered hub.
TEST_F(HostMessageHubTest, RegisterHubStaticHubLimit) {
  // Register a hub to occupy a slot.
  getManager().registerHub(kHostHub);

  // Attempt to register a hub for each slot. The final registration should fail
  // due to the occupied slot.
  std::vector<std::string> hubNames;
  // Reserve up front so push_back() never reallocates: each name's c_str() is
  // handed to registerHub(), and relocating the strings (especially short,
  // SSO-buffered ones) would leave earlier pointers dangling.
  hubNames.reserve(CHRE_MESSAGE_ROUTER_MAX_HOST_HUBS);
  for (uint64_t i = 1; i <= CHRE_MESSAGE_ROUTER_MAX_HOST_HUBS; ++i) {
    const MessageHubId id = kHostHub.id + i;
    // std::to_string keeps names unique and readable even with 10+ hubs.
    hubNames.push_back(std::string(kHostHubName) + std::to_string(i));
    getManager().registerHub({.id = id, .name = hubNames.back().c_str()});
    if (i < CHRE_MESSAGE_ROUTER_MAX_HOST_HUBS) {
      EXPECT_TRUE(getRouter().forEachEndpointOfHub(
          id, [](const EndpointInfo &) { return true; }));
    } else {
      EXPECT_FALSE(getRouter().forEachEndpointOfHub(
          id, [](const EndpointInfo &) { return true; }));
    }
  }
}
332 
// Matches a MessageHubInfo whose id equals `id`. This was a second,
// independently written copy of HubIdMatcher; it now delegates so the two
// matchers cannot drift apart.
template <typename IdType>
auto HubMatcher(IdType id) {
  return HubIdMatcher(id);
}
336 
TEST_F(HostMessageHubTest, OnHubRegisteredAndUnregistered) {
  getManager().registerHub(kHostHub);

  // Registering another embedded hub with MessageRouter must reach the host.
  const MessageHubId kHubId = kHostHub.id + 1;
  EXPECT_CALL(mHostCallback, onHubRegistered(HubMatcher(kHubId)));
  pw::IntrusivePtr<MockMessageHubCallback> newHubCb =
      pw::MakeRefCounted<MockMessageHubCallback>();
  const char *name = "test embedded hub";
  auto newHub = getRouter().registerMessageHub(name, kHubId, newHubCb);
  // The rest of the test needs the hub. Stop here rather than also reporting
  // a misleading unmet onHubUnregistered() expectation.
  ASSERT_TRUE(newHub);

  // Resetting the optional destroys the hub handle, which is expected to
  // unregister the hub and notify the host.
  EXPECT_CALL(mHostCallback, onHubUnregistered(kHubId));
  newHub.reset();
}
351 
TEST_F(HostMessageHubTest, RegisterAndUnregisterEndpoint) {
  getManager().registerHub(kHostHub);

  // Registering a host endpoint notifies the embedded hub and makes the
  // endpoint visible through MessageRouter.
  EXPECT_CALL(*mEmbeddedHubCb,
              onEndpointRegistered(kHostHub.id, kEndpoints[0].id));
  getManager().registerEndpoint(kHostHub.id, kEndpoints[0], {});
  // Record whether the visitor ran at all. Otherwise a failed registration
  // would leave the EXPECT_EQ unevaluated and the check would pass vacuously.
  bool found = false;
  getRouter().forEachEndpointOfHub(
      kHostHub.id, [&found](const EndpointInfo &info) {
        EXPECT_EQ(info.id, kEndpoints[0].id);
        found = true;
        return true;
      });
  EXPECT_TRUE(found);

  // Unregistering notifies the embedded hub and hides the endpoint again.
  EXPECT_CALL(*mEmbeddedHubCb,
              onEndpointUnregistered(kHostHub.id, kEndpoints[0].id));
  getManager().unregisterEndpoint(kHostHub.id, kEndpoints[0].id);
  found = false;
  getRouter().forEachEndpointOfHub(kHostHub.id, [&found](const EndpointInfo &) {
    found = true;
    return true;
  });
  EXPECT_FALSE(found);
}
373 
TEST_F(HostMessageHubTest, RegisterAndUnregisterEndpointWithService) {
  HostMessageHubManager &manager = getManager();
  MessageRouter &router = getRouter();
  manager.registerHub(kHostHub);

  // Registering a host endpoint with a service must expose that service
  // through MessageRouter under the host hub and endpoint ids.
  EXPECT_CALL(*mEmbeddedHubCb,
              onEndpointRegistered(kHostHub.id, kEndpoints[0].id));
  manager.registerEndpoint(kHostHub.id, kEndpoints[0],
                           getHostEndpointServices());
  bool serviceFound = false;
  router.forEachService([&serviceFound](const MessageHubInfo &hub,
                                        const EndpointInfo &endpoint,
                                        const ServiceInfo &service) {
    const bool isHostService =
        hub.id == kHostHub.id && endpoint.id == kEndpoints[0].id &&
        std::strcmp(service.serviceDescriptor, kServiceName) == 0;
    if (isHostService) serviceFound = true;
    return isHostService;
  });
  EXPECT_TRUE(serviceFound);

  // After unregistration the host hub reports no endpoints.
  EXPECT_CALL(*mEmbeddedHubCb,
              onEndpointUnregistered(kHostHub.id, kEndpoints[0].id));
  manager.unregisterEndpoint(kHostHub.id, kEndpoints[0].id);
  bool endpointFound = false;
  router.forEachEndpointOfHub(kHostHub.id,
                              [&endpointFound](const EndpointInfo &) {
                                endpointFound = true;
                                return true;
                              });
  EXPECT_FALSE(endpointFound);
}
404 
TEST_F(HostMessageHubTest, OnEndpointRegisteredAndUnregistered) {
  getManager().registerHub(kHostHub);

  // Make the service-less extra endpoint discoverable through the embedded
  // hub's mock callback, then expect the host to be told about it.
  const auto &extraEndpoint = mEmbeddedEndpoints.emplace_back(
      kExtraEndpoint, std::vector<ServiceInfo>{});
  expectOnEmbeddedEndpoint(extraEndpoint, nullptr);
  mEmbeddedHubIntf.registerEndpoint(kExtraEndpoint.id);

  // Unregistering it on the embedded side must be forwarded to the host.
  EXPECT_CALL(mHostCallback,
              onEndpointUnregistered(kEmbeddedHub.id, kExtraEndpoint.id));
  mEmbeddedHubIntf.unregisterEndpoint(kExtraEndpoint.id);
}
416 
TEST_F(HostMessageHubTest, OnEndpointWithServiceRegisteredAndUnregistered) {
  getManager().registerHub(kHostHub);

  // Same as above, but the extra endpoint also advertises kService, so the
  // host should additionally see an onEndpointService() notification.
  const auto &extraEndpoint = mEmbeddedEndpoints.emplace_back(
      kExtraEndpoint, std::vector<ServiceInfo>{kService});
  expectOnEmbeddedEndpoint(extraEndpoint, nullptr);
  mEmbeddedHubIntf.registerEndpoint(kExtraEndpoint.id);

  EXPECT_CALL(mHostCallback,
              onEndpointUnregistered(kEmbeddedHub.id, kExtraEndpoint.id));
  mEmbeddedHubIntf.unregisterEndpoint(kExtraEndpoint.id);
}
428 
TEST_F(HostMessageHubTest, RegisterMaximumEndpoints) {
  HostMessageHubManager &manager = getManager();
  MessageRouter &router = getRouter();
  manager.registerHub(kHostHub);

  // Offer one endpoint more than the host hub can hold; the last is rejected.
  for (int i = 0; i <= CHRE_MESSAGE_ROUTER_MAX_HOST_ENDPOINTS; ++i) {
    manager.registerEndpoint(
        kHostHub.id,
        EndpointInfo(0x1 + i, nullptr, 0, EndpointType::GENERIC, 0), {});
  }

  int visibleEndpoints = 0;
  router.forEachEndpointOfHub(kHostHub.id,
                              [&visibleEndpoints](const EndpointInfo &) {
                                ++visibleEndpoints;
                                return false;
                              });
  EXPECT_EQ(visibleEndpoints, CHRE_MESSAGE_ROUTER_MAX_HOST_ENDPOINTS);

  // Freeing one slot lets the previously rejected endpoint id register.
  manager.unregisterEndpoint(kHostHub.id, 0x1);
  EndpointInfo replacement(0x1 + CHRE_MESSAGE_ROUTER_MAX_HOST_ENDPOINTS,
                           nullptr, 0, EndpointType::GENERIC, 0);
  manager.registerEndpoint(kHostHub.id, replacement, {});
  bool replacementVisible = false;
  router.forEachEndpointOfHub(
      kHostHub.id, [&replacementVisible](const EndpointInfo &info) {
        const bool isReplacement =
            info.id == 0x1 + CHRE_MESSAGE_ROUTER_MAX_HOST_ENDPOINTS;
        if (isReplacement) replacementVisible = true;
        return isReplacement;
      });
  EXPECT_TRUE(replacementVisible);
}
461 
TEST_F(HostMessageHubTest, OpenAndCloseSession) {
  HostMessageHubManager &manager = getManager();
  manager.registerHub(kHostHub);
  manager.registerEndpoint(kHostHub.id, kEndpoints[0], {});

  // The embedded hub accepts the request immediately, so the host must be
  // told the session is open.
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpened(kHostHub.id, sessionId));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &request) {
        mEmbeddedHubIntf.onSessionOpenComplete(request.sessionId);
      });
  manager.openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                      kEndpoints[1].id, sessionId,
                      /*serviceDescriptor=*/nullptr);

  // A host-initiated close is delivered to the embedded hub with its reason.
  EXPECT_CALL(*mEmbeddedHubCb,
              onSessionClosed(_, Reason::CLOSE_ENDPOINT_SESSION_REQUESTED));
  manager.closeSession(kHostHub.id, sessionId,
                       Reason::CLOSE_ENDPOINT_SESSION_REQUESTED);
}
482 
TEST_F(HostMessageHubTest, OpenSessionAndHandleClose) {
  HostMessageHubManager &manager = getManager();
  manager.registerHub(kHostHub);
  manager.registerEndpoint(kHostHub.id, kEndpoints[0], {});

  // Open a session that the embedded hub accepts right away.
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpened(kHostHub.id, sessionId));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &request) {
        mEmbeddedHubIntf.onSessionOpenComplete(request.sessionId);
      });
  manager.openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                      kEndpoints[1].id, sessionId,
                      /*serviceDescriptor=*/nullptr);

  // An embedded-initiated close must reach the host with the same reason.
  EXPECT_CALL(mHostCallback,
              onSessionClosed(kHostHub.id, sessionId,
                              Reason::CLOSE_ENDPOINT_SESSION_REQUESTED));
  mEmbeddedHubIntf.closeSession(sessionId,
                                Reason::CLOSE_ENDPOINT_SESSION_REQUESTED);
}
504 
TEST_F(HostMessageHubTest, OpenSessionRejected) {
  HostMessageHubManager &manager = getManager();
  manager.registerHub(kHostHub);
  manager.registerEndpoint(kHostHub.id, kEndpoints[0], {});

  // The embedded hub rejects the request by closing the session. The host
  // must observe this as a close carrying the rejection reason.
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback,
              onSessionClosed(kHostHub.id, sessionId,
                              Reason::OPEN_ENDPOINT_SESSION_REQUEST_REJECTED));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &request) {
        mEmbeddedHubIntf.closeSession(
            request.sessionId, Reason::OPEN_ENDPOINT_SESSION_REQUEST_REJECTED);
      });
  manager.openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                      kEndpoints[1].id, sessionId,
                      /*serviceDescriptor=*/nullptr);
}
523 
TEST_F(HostMessageHubTest, OpenSessionWithService) {
  HostMessageHubManager &manager = getManager();
  manager.registerHub(kHostHub);
  manager.registerEndpoint(kHostHub.id, kEndpoints[0],
                           getHostEndpointServices());

  // Open a session bound to kService; kEndpoints[1] on the embedded hub
  // advertises it and accepts the request immediately.
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpened(kHostHub.id, sessionId));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &request) {
        mEmbeddedHubIntf.onSessionOpenComplete(request.sessionId);
      });
  manager.openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                      kEndpoints[1].id, sessionId, kServiceName);
}
538 
TEST_F(HostMessageHubTest, OnOpenSessionWithService) {
  getManager().registerHub(kHostHub);
  getManager().registerEndpoint(kHostHub.id, kEndpoints[0],
                                getHostEndpointServices());

  // std::optional makes a missing onSessionOpenRequest() callback an explicit
  // failure instead of a comparison against an uninitialized value (UB).
  std::optional<SessionId> receivedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpenRequest(_))
      .WillOnce([&receivedSessionId](const Session &session) {
        receivedSessionId = session.sessionId;
      });
  auto sessionId = mEmbeddedHubIntf.openSession(kEndpoints[1].id, kHostHub.id,
                                                kEndpoints[0].id, kServiceName);
  ASSERT_TRUE(receivedSessionId.has_value());
  EXPECT_EQ(sessionId, *receivedSessionId);
}
553 
TEST_F(HostMessageHubTest, AckSession) {
  getManager().registerHub(kHostHub);
  getManager().registerEndpoint(kHostHub.id, kEndpoints[0], {});

  // std::optional makes a missing onSessionOpenRequest() callback an explicit
  // failure instead of a comparison against an uninitialized value (UB).
  std::optional<SessionId> receivedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpenRequest(_))
      .WillOnce([&receivedSessionId](const Session &session) {
        receivedSessionId = session.sessionId;
      });
  auto sessionId = mEmbeddedHubIntf.openSession(kEndpoints[1].id, kHostHub.id,
                                                kEndpoints[0].id);
  // Acking a session the host never saw is meaningless; stop here if so.
  ASSERT_TRUE(receivedSessionId.has_value());
  EXPECT_EQ(sessionId, *receivedSessionId);

  // The host's acknowledgement must complete the session on the embedded hub.
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpened(_));
  getManager().ackSession(kHostHub.id, sessionId);
}
570 
571 MATCHER_P(DataMatcher, data, "matches data in pw::UniquePtr<std::byte[]>") {
572   return arg != nullptr && !std::memcmp(arg.get(), data, arg.size());
573 }
574 
575 MATCHER_P(SessionIdMatcher, session, "matches the session id in Session") {
576   return arg.sessionId == session;
577 }
578 
TEST_F(HostMessageHubTest, SendMessage) {
  HostMessageHubManager &manager = getManager();
  manager.registerHub(kHostHub);
  manager.registerEndpoint(kHostHub.id, kEndpoints[0], {});

  // Open a session that the embedded hub accepts right away.
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpened(kHostHub.id, sessionId));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &request) {
        mEmbeddedHubIntf.onSessionOpenComplete(request.sessionId);
      });
  manager.openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                      kEndpoints[1].id, sessionId,
                      /*serviceDescriptor=*/nullptr);

  // A host message must arrive at the embedded hub with the same payload,
  // message type (1), permissions (2) and session.
  std::byte data[] = {std::byte{0xde}, std::byte{0xad}, std::byte{0xbe},
                      std::byte{0xef}};
  EXPECT_CALL(*mEmbeddedHubCb,
              onMessageReceived(DataMatcher(data), 1, 2,
                                SessionIdMatcher(sessionId), _));
  manager.sendMessage(kHostHub.id, sessionId, {data, sizeof(data)}, 1, 2);
}
600 
TEST_F(HostMessageHubTest, ReceiveMessage) {
  getManager().registerHub(kHostHub);
  getManager().registerEndpoint(kHostHub.id, kEndpoints[0], {});
  constexpr auto sessionId = MessageRouter::kDefaultReservedSessionId;
  EXPECT_CALL(mHostCallback, onSessionOpened(kHostHub.id, sessionId));
  EXPECT_CALL(*mEmbeddedHubCb, onSessionOpenRequest(_))
      .WillOnce([this](const Session &session) {
        mEmbeddedHubIntf.onSessionOpenComplete(session.sessionId);
      });
  getManager().openSession(kHostHub.id, kEndpoints[0].id, kEmbeddedHub.id,
                           kEndpoints[1].id, sessionId,
                           /*serviceDescriptor=*/nullptr);

  std::byte bytes[] = {std::byte{0xde}, std::byte{0xad}, std::byte{0xbe},
                       std::byte{0xef}};
  // Size the buffer from `bytes` so the two cannot drift apart, and make sure
  // the allocation succeeded before copying into it.
  auto data = pw::allocator::GetLibCAllocator().MakeUniqueArray<std::byte>(
      sizeof(bytes));
  ASSERT_NE(data.get(), nullptr);
  std::memcpy(data.get(), bytes, sizeof(bytes));

  // An embedded message must be delivered to the host unchanged.
  EXPECT_CALL(mHostCallback, onMessageReceived(kHostHub.id, sessionId,
                                               DataMatcher(bytes), 1, 2));
  mEmbeddedHubIntf.sendMessage(std::move(data), 1, 2, sessionId);
}
623 
624 }  // namespace
625 }  // namespace chre
626