• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre/util/system/message_router.h"
18 #include "chre/util/dynamic_vector.h"
19 #include "chre/util/lock_guard.h"
20 #include "chre/util/system/message_common.h"
21 
22 #include <inttypes.h>
23 #include <cstring>
24 #include <optional>
25 #include <utility>
26 
27 namespace chre::message {
28 
MessageHub()29 MessageRouter::MessageHub::MessageHub()
30     : mRouter(nullptr), mHubId(MESSAGE_HUB_ID_INVALID) {}
31 
MessageHub(MessageRouter & router,MessageHubId id)32 MessageRouter::MessageHub::MessageHub(MessageRouter &router, MessageHubId id)
33     : mRouter(&router), mHubId(id) {}
34 
MessageHub(MessageHub && other)35 MessageRouter::MessageHub::MessageHub(MessageHub &&other)
36     : mRouter(other.mRouter), mHubId(other.mHubId) {
37   other.mRouter = nullptr;
38   other.mHubId = MESSAGE_HUB_ID_INVALID;
39 }
40 
operator =(MessageHub && other)41 MessageRouter::MessageHub &MessageRouter::MessageHub::operator=(
42     MessageHub &&other) {
43   unregister();
44   mRouter = other.mRouter;
45   mHubId = other.mHubId;
46   other.mRouter = nullptr;
47   other.mHubId = MESSAGE_HUB_ID_INVALID;
48   return *this;
49 }
50 
~MessageHub()51 MessageRouter::MessageHub::~MessageHub() {
52   unregister();
53 }
54 
onSessionOpenComplete(SessionId sessionId)55 void MessageRouter::MessageHub::onSessionOpenComplete(SessionId sessionId) {
56   if (mRouter != nullptr) {
57     mRouter->onSessionOpenComplete(mHubId, sessionId);
58   }
59 }
60 
openSession(EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId,const char * serviceDescriptor,SessionId sessionId)61 SessionId MessageRouter::MessageHub::openSession(EndpointId fromEndpointId,
62                                                  MessageHubId toMessageHubId,
63                                                  EndpointId toEndpointId,
64                                                  const char *serviceDescriptor,
65                                                  SessionId sessionId) {
66   return mRouter == nullptr
67              ? SESSION_ID_INVALID
68              : mRouter->openSession(mHubId, fromEndpointId, toMessageHubId,
69                                     toEndpointId, serviceDescriptor, sessionId);
70 }
71 
closeSession(SessionId sessionId,Reason reason)72 bool MessageRouter::MessageHub::closeSession(SessionId sessionId,
73                                              Reason reason) {
74   return mRouter != nullptr && mRouter->closeSession(mHubId, sessionId, reason);
75 }
76 
getSessionWithId(SessionId sessionId)77 std::optional<Session> MessageRouter::MessageHub::getSessionWithId(
78     SessionId sessionId) {
79   return mRouter == nullptr ? std::nullopt
80                             : mRouter->getSessionWithId(mHubId, sessionId);
81 }
82 
sendMessage(pw::UniquePtr<std::byte[]> && data,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId,EndpointId fromEndpointId)83 bool MessageRouter::MessageHub::sendMessage(pw::UniquePtr<std::byte[]> &&data,
84                                             uint32_t messageType,
85                                             uint32_t messagePermissions,
86                                             SessionId sessionId,
87                                             EndpointId fromEndpointId) {
88   return mRouter != nullptr &&
89          mRouter->sendMessage(std::move(data), messageType, messagePermissions,
90                               sessionId, fromEndpointId, mHubId);
91 }
92 
registerEndpoint(EndpointId endpointId)93 bool MessageRouter::MessageHub::registerEndpoint(EndpointId endpointId) {
94   return mRouter != nullptr && mRouter->registerEndpoint(mHubId, endpointId);
95 }
96 
unregisterEndpoint(EndpointId endpointId)97 bool MessageRouter::MessageHub::unregisterEndpoint(EndpointId endpointId) {
98   return mRouter != nullptr && mRouter->unregisterEndpoint(mHubId, endpointId);
99 }
100 
getId()101 MessageHubId MessageRouter::MessageHub::getId() {
102   return mHubId;
103 }
104 
isRegistered()105 bool MessageRouter::MessageHub::isRegistered() {
106   return mRouter != nullptr;
107 }
108 
unregister()109 void MessageRouter::MessageHub::unregister() {
110   if (mRouter != nullptr) {
111     mRouter->unregisterMessageHub(mHubId);
112   }
113   mRouter = nullptr;
114 }
115 
116 std::optional<typename MessageRouter::MessageHub>
registerMessageHub(const char * name,MessageHubId id,pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback)117 MessageRouter::registerMessageHub(
118     const char *name, MessageHubId id,
119     pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback) {
120   DynamicVector<MessageHubRecord> hubsToNotify;
121   std::optional<MessageHub> newHub;
122   MessageHubInfo newHubInfo;
123   {
124     LockGuard<Mutex> lock(mMutex);
125     if (mMessageHubs.full()) {
126       LOGE(
127           "Message hub '%s' not registered: maximum number of message hubs "
128           "reached",
129           name);
130       return std::nullopt;
131     }
132 
133     for (MessageHubRecord &messageHub : mMessageHubs) {
134       if (std::strcmp(messageHub.info.name, name) == 0 ||
135           messageHub.info.id == id) {
136         LOGE(
137             "Message hub '%s' not registered: hub with same name or ID already "
138             "exists",
139             name);
140         return std::nullopt;
141       }
142     }
143 
144     if (auto hubRecords = getMessageHubRecordsLocked(); hubRecords) {
145       hubsToNotify = std::move(*hubRecords);
146     } else {
147       return std::nullopt;
148     }
149 
150     MessageHubRecord messageHubRecord = {
151         .info = {.id = id, .name = name},
152         .callback = std::move(callback),
153     };
154     newHubInfo = messageHubRecord.info;
155     mMessageHubs.push_back(std::move(messageHubRecord));
156     newHub = MessageHub(*this, id);
157   }
158 
159   // NOTE: newHubInfo is guaranteed to be valid while we have newHub.
160   for (const auto &hubRecord : hubsToNotify) {
161     hubRecord.callback->onHubRegistered(newHubInfo);
162   }
163   return newHub;
164 }
165 
forEachEndpointOfHub(MessageHubId messageHubId,const pw::Function<bool (const EndpointInfo &)> & function)166 bool MessageRouter::forEachEndpointOfHub(
167     MessageHubId messageHubId,
168     const pw::Function<bool(const EndpointInfo &)> &function) {
169   pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
170       getCallbackFromMessageHubId(messageHubId);
171   if (callback == nullptr) {
172     LOGE("Failed to find message hub with ID 0x%" PRIx64, messageHubId);
173     return false;
174   }
175 
176   callback->forEachEndpoint(function);
177   return true;
178 }
179 
forEachEndpoint(const pw::Function<void (const MessageHubInfo &,const EndpointInfo &)> & function)180 bool MessageRouter::forEachEndpoint(
181     const pw::Function<void(const MessageHubInfo &, const EndpointInfo &)>
182         &function) {
183   std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
184       getMessageHubRecords();
185   if (!messageHubRecords.has_value()) {
186     return false;
187   }
188 
189   struct Context {
190     decltype(function) function;
191     const MessageHubInfo &messageHubInfo;
192   };
193   for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
194     Context context = {
195         .function = function,
196         .messageHubInfo = messageHubRecord.info,
197     };
198 
199     messageHubRecord.callback->forEachEndpoint(
200         [&context](const EndpointInfo &endpointInfo) {
201           context.function(context.messageHubInfo, endpointInfo);
202           return false;
203         });
204   }
205   return true;
206 }
207 
getEndpointInfo(MessageHubId messageHubId,EndpointId endpointId)208 std::optional<EndpointInfo> MessageRouter::getEndpointInfo(
209     MessageHubId messageHubId, EndpointId endpointId) {
210   pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
211       getCallbackFromMessageHubId(messageHubId);
212   if (callback == nullptr) {
213     LOGE("Failed to get endpoint info for message hub with ID 0x%" PRIx64
214          " and endpoint ID 0x%" PRIx64 ": hub not found",
215          messageHubId, endpointId);
216     return std::nullopt;
217   }
218 
219   return callback->getEndpointInfo(endpointId);
220 }
221 
getEndpointForService(MessageHubId messageHubId,const char * serviceDescriptor)222 std::optional<Endpoint> MessageRouter::getEndpointForService(
223     MessageHubId messageHubId, const char *serviceDescriptor) {
224   if (serviceDescriptor == nullptr) {
225     LOGE("Failed to get endpoint for service: service descriptor is null");
226     return std::nullopt;
227   }
228 
229   std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
230       getMessageHubRecords();
231   if (!messageHubRecords.has_value()) {
232     return std::nullopt;
233   }
234 
235   for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
236     if ((messageHubId == MESSAGE_HUB_ID_ANY ||
237          messageHubId == messageHubRecord.info.id) &&
238         messageHubRecord.callback != nullptr) {
239       std::optional<EndpointId> endpointId =
240           messageHubRecord.callback->getEndpointForService(serviceDescriptor);
241       if (endpointId.has_value()) {
242         return Endpoint(messageHubRecord.info.id, *endpointId);
243       }
244 
245       // Only searching this message hub, so return early if not found
246       if (messageHubId != MESSAGE_HUB_ID_ANY) {
247         return std::nullopt;
248       }
249     }
250   }
251   return std::nullopt;
252 }
253 
doesEndpointHaveService(MessageHubId messageHubId,EndpointId endpointId,const char * serviceDescriptor)254 bool MessageRouter::doesEndpointHaveService(MessageHubId messageHubId,
255                                             EndpointId endpointId,
256                                             const char *serviceDescriptor) {
257   if (serviceDescriptor == nullptr) {
258     LOGE("Failed to check if endpoint has service: service descriptor is null");
259     return false;
260   }
261 
262   pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
263       getCallbackFromMessageHubId(messageHubId);
264   if (callback == nullptr) {
265     LOGE(
266         "Failed to check if endpoint has service for message hub with ID "
267         "0x%" PRIx64 " and endpoint ID 0x%" PRIx64 ": hub not found",
268         messageHubId, endpointId);
269     return false;
270   }
271   return callback->doesEndpointHaveService(endpointId, serviceDescriptor);
272 }
273 
forEachService(const pw::Function<bool (const MessageHubInfo &,const EndpointInfo &,const ServiceInfo &)> & function)274 bool MessageRouter::forEachService(
275     const pw::Function<bool(const MessageHubInfo &, const EndpointInfo &,
276                             const ServiceInfo &)> &function) {
277   std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
278       getMessageHubRecords();
279   if (!messageHubRecords.has_value()) {
280     return false;
281   }
282 
283   struct Context {
284     decltype(function) &function;
285     const MessageHubInfo *messageHubInfo;
286   };
287   Context context = {
288       .function = function,
289       .messageHubInfo = nullptr,
290   };
291   for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
292     context.messageHubInfo = &messageHubRecord.info;
293     messageHubRecord.callback->forEachService(
294         [&context](const EndpointInfo &endpointInfo,
295                    const ServiceInfo &serviceInfo) {
296           return context.function(*context.messageHubInfo, endpointInfo,
297                                   serviceInfo);
298         });
299   }
300   return true;
301 }
302 
forEachMessageHub(const pw::Function<bool (const MessageHubInfo &)> & function)303 bool MessageRouter::forEachMessageHub(
304     const pw::Function<bool(const MessageHubInfo &)> &function) {
305   std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
306       getMessageHubRecords();
307   if (!messageHubRecords.has_value()) {
308     return false;
309   }
310 
311   for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
312     function(messageHubRecord.info);
313   }
314   return true;
315 }
316 
unregisterMessageHub(MessageHubId fromMessageHubId)317 bool MessageRouter::unregisterMessageHub(MessageHubId fromMessageHubId) {
318   DynamicVector<std::pair<pw::IntrusivePtr<MessageHubCallback>, Session>>
319       sessionsToDestroy;
320   DynamicVector<pw::IntrusivePtr<MessageHubCallback>> hubsToNotify;
321 
322   {
323     LockGuard<Mutex> lock(mMutex);
324 
325     if (!mMessageHubs.empty() &&
326         !hubsToNotify.reserve(mMessageHubs.size() - 1)) {
327       LOG_OOM();
328       return false;
329     }
330 
331     bool success = false;
332     for (MessageHubRecord &messageHubRecord : mMessageHubs) {
333       if (messageHubRecord.info.id == fromMessageHubId) {
334         mMessageHubs.erase(&messageHubRecord);
335         success = true;
336       } else {
337         hubsToNotify.push_back(messageHubRecord.callback);
338       }
339     }
340     if (!success) {
341       return false;
342     }
343 
344     for (size_t i = 0; i < mSessions.size();) {
345       Session &session = mSessions[i];
346       bool initiatorIsFromHub =
347           session.initiator.messageHubId == fromMessageHubId;
348       bool peerIsFromHub = session.peer.messageHubId == fromMessageHubId;
349 
350       if (initiatorIsFromHub || peerIsFromHub) {
351         pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
352             getCallbackFromMessageHubIdLocked(
353                 initiatorIsFromHub ? session.peer.messageHubId
354                                    : session.initiator.messageHubId);
355         sessionsToDestroy.push_back(std::make_pair(callback, session));
356         mSessions.erase(&mSessions[i]);
357       } else {
358         ++i;
359       }
360     }
361   }
362 
363   for (auto [callback, session] : sessionsToDestroy) {
364     if (callback != nullptr) {
365       callback->onSessionClosed(session, Reason::HUB_RESET);
366     }
367   }
368   for (auto callback : hubsToNotify) {
369     if (callback != nullptr) {
370       callback->onHubUnregistered(fromMessageHubId);
371     }
372   }
373   return true;
374 }
375 
onSessionOpenComplete(MessageHubId fromMessageHubId,SessionId sessionId)376 void MessageRouter::onSessionOpenComplete(MessageHubId fromMessageHubId,
377                                           SessionId sessionId) {
378   finalizeSession(fromMessageHubId, sessionId, /* reason = */ std::nullopt);
379 }
380 
openSession(MessageHubId fromMessageHubId,EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId,const char * serviceDescriptor,SessionId sessionId)381 SessionId MessageRouter::openSession(MessageHubId fromMessageHubId,
382                                      EndpointId fromEndpointId,
383                                      MessageHubId toMessageHubId,
384                                      EndpointId toEndpointId,
385                                      const char *serviceDescriptor,
386                                      SessionId sessionId) {
387   if (sessionId != SESSION_ID_INVALID && sessionId < kReservedSessionId) {
388     LOGE("Failed to open session: session ID %" PRIu16
389          " is not in the reserved range",
390          sessionId);
391     return SESSION_ID_INVALID;
392   }
393 
394   pw::IntrusivePtr<MessageRouter::MessageHubCallback> initiatorCallback =
395       getCallbackFromMessageHubId(fromMessageHubId);
396   pw::IntrusivePtr<MessageRouter::MessageHubCallback> peerCallback =
397       getCallbackFromMessageHubId(toMessageHubId);
398   if (initiatorCallback == nullptr || peerCallback == nullptr) {
399     LOGE("Failed to open session: %s message hub not found",
400          initiatorCallback == nullptr ? "initiator" : "peer");
401     return SESSION_ID_INVALID;
402   }
403 
404   if (!checkIfEndpointExists(initiatorCallback, fromEndpointId)) {
405     LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
406          " not found in message hub with ID 0x%" PRIx64,
407          fromEndpointId, fromMessageHubId);
408     return SESSION_ID_INVALID;
409   }
410 
411   if (!checkIfEndpointExists(peerCallback, toEndpointId)) {
412     LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
413          " not found in message hub with ID 0x%" PRIx64,
414          toEndpointId, toMessageHubId);
415     return SESSION_ID_INVALID;
416   }
417 
418   if (serviceDescriptor != nullptr &&
419       !peerCallback->doesEndpointHaveService(toEndpointId, serviceDescriptor)) {
420     LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
421          " does not have service descriptor '%s'",
422          toEndpointId, serviceDescriptor);
423     return SESSION_ID_INVALID;
424   }
425 
426   Session session(SESSION_ID_INVALID,
427                   Endpoint(fromMessageHubId, fromEndpointId),
428                   Endpoint(toMessageHubId, toEndpointId), serviceDescriptor);
429   {
430     LockGuard<Mutex> lock(mMutex);
431     if (mSessions.full()) {
432       LOGE("Failed to open session: maximum number of sessions reached");
433       return SESSION_ID_INVALID;
434     }
435 
436     bool foundSession = false;
437     for (Session &existingSession : mSessions) {
438       if (existingSession.isEquivalent(session)) {
439         LOGD("Session with ID %" PRIu16 " already exists",
440              existingSession.sessionId);
441         session = existingSession;
442         foundSession = true;
443         break;
444       }
445     }
446 
447     if (!foundSession) {
448       if (sessionId == SESSION_ID_INVALID) {
449         sessionId = getNextSessionIdLocked();
450         if (sessionId == SESSION_ID_INVALID) {
451           LOGE("Failed to open session: no available session ID");
452           return SESSION_ID_INVALID;
453         }
454       }
455 
456       session.sessionId = sessionId;
457       mSessions.push_back(session);
458     }
459   }
460 
461   peerCallback->onSessionOpenRequest(session);
462   return session.sessionId;
463 }
464 
closeSession(MessageHubId fromMessageHubId,SessionId sessionId,Reason reason)465 bool MessageRouter::closeSession(MessageHubId fromMessageHubId,
466                                  SessionId sessionId, Reason reason) {
467   return finalizeSession(fromMessageHubId, sessionId, reason);
468 }
469 
finalizeSession(MessageHubId fromMessageHubId,SessionId sessionId,std::optional<Reason> reason)470 bool MessageRouter::finalizeSession(MessageHubId fromMessageHubId,
471                                     SessionId sessionId,
472                                     std::optional<Reason> reason) {
473   pw::IntrusivePtr<MessageRouter::MessageHubCallback> peerCallback = nullptr;
474   pw::IntrusivePtr<MessageRouter::MessageHubCallback> initiatorCallback =
475       nullptr;
476   Session session;
477   {
478     LockGuard<Mutex> lock(mMutex);
479     std::optional<size_t> index =
480         findSessionIndexLocked(fromMessageHubId, sessionId);
481     if (!index.has_value()) {
482       LOGE("Failed to %s session with ID %" PRIu16 " not found",
483            reason.has_value() ? "close" : "open", sessionId);
484       return false;
485     }
486 
487     session = mSessions[*index];
488     if (reason.has_value()) {
489       mSessions.erase(&mSessions[*index]);
490     } else {
491       mSessions[*index].isActive = true;
492       session.isActive = true;
493     }
494 
495     initiatorCallback =
496         getCallbackFromMessageHubIdLocked(session.initiator.messageHubId);
497     peerCallback = getCallbackFromMessageHubIdLocked(session.peer.messageHubId);
498 
499     if (initiatorCallback == nullptr || peerCallback == nullptr) {
500       LOGE("Failed to finalize session: %s message hub with ID 0x%" PRIx64
501            " not found",
502            initiatorCallback == nullptr ? "initiator" : "peer",
503            initiatorCallback == nullptr ? session.initiator.messageHubId
504                                         : session.peer.messageHubId);
505       if (!reason.has_value()) {
506         // Only erase if it was not erased above
507         mSessions.erase(&mSessions[*index]);
508       }
509       return false;
510     }
511   }
512 
513   if (reason.has_value()) {
514     initiatorCallback->onSessionClosed(session, reason.value());
515     if (initiatorCallback != peerCallback) {
516       peerCallback->onSessionClosed(session, reason.value());
517     }
518   } else {
519     initiatorCallback->onSessionOpened(session);
520     if (initiatorCallback != peerCallback) {
521       peerCallback->onSessionOpened(session);
522     }
523   }
524   return true;
525 }
526 
getSessionWithId(MessageHubId fromMessageHubId,SessionId sessionId)527 std::optional<Session> MessageRouter::getSessionWithId(
528     MessageHubId fromMessageHubId, SessionId sessionId) {
529   LockGuard<Mutex> lock(mMutex);
530 
531   std::optional<size_t> index =
532       findSessionIndexLocked(fromMessageHubId, sessionId);
533   return index.has_value() ? std::optional<Session>(mSessions[*index])
534                            : std::nullopt;
535 }
536 
sendMessage(pw::UniquePtr<std::byte[]> && data,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId,EndpointId fromEndpointId,MessageHubId fromMessageHubId)537 bool MessageRouter::sendMessage(pw::UniquePtr<std::byte[]> &&data,
538                                 uint32_t messageType,
539                                 uint32_t messagePermissions,
540                                 SessionId sessionId, EndpointId fromEndpointId,
541                                 MessageHubId fromMessageHubId) {
542   pw::IntrusivePtr<MessageRouter::MessageHubCallback> receiverCallback =
543       nullptr;
544   Session session;
545   bool sentBySessionInitiator;
546   {
547     LockGuard<Mutex> lock(mMutex);
548 
549     std::optional<size_t> index =
550         findSessionIndexLocked(fromMessageHubId, sessionId);
551     if (!index.has_value()) {
552       LOGE("Failed to send message: session with ID %" PRIu16 " not found",
553            sessionId);
554       return false;
555     }
556 
557     session = mSessions[*index];
558     if (!session.isActive) {
559       LOGE("Failed to send message: session with ID %" PRIu16 " is inactive",
560            sessionId);
561       return false;
562     }
563 
564     Endpoint sender(fromMessageHubId, fromEndpointId);
565     if (fromEndpointId == ENDPOINT_ID_ANY) {
566       if (session.initiator.messageHubId == session.peer.messageHubId) {
567         LOGE("Unable to infer sender endpoint ID: session with ID %" PRIu16
568              " is between endpoints on the same message hub with ID 0x%" PRIx64,
569              sessionId, fromMessageHubId);
570         return false;
571       }
572       sender.endpointId = session.initiator.messageHubId == fromMessageHubId
573                               ? session.initiator.endpointId
574                               : session.peer.endpointId;
575     }
576 
577     if (sender != session.initiator && sender != session.peer) {
578       LOGE("Failed to send message: session with ID %" PRIu16
579            " does not contain endpoint with hub ID 0x%" PRIx64
580            " and endpoint ID 0x%" PRIx64,
581            sessionId, fromMessageHubId, fromEndpointId);
582       return false;
583     }
584     sentBySessionInitiator = sender == session.initiator;
585     receiverCallback = getCallbackFromMessageHubIdLocked(
586         sentBySessionInitiator ? session.peer.messageHubId
587                                : session.initiator.messageHubId);
588   }
589 
590   bool success = false;
591   if (receiverCallback != nullptr) {
592     success = receiverCallback->onMessageReceived(std::move(data), messageType,
593                                                   messagePermissions, session,
594                                                   sentBySessionInitiator);
595   }
596 
597   if (!success) {
598     closeSession(fromMessageHubId, sessionId, Reason::UNSPECIFIED);
599   }
600   return success;
601 }
602 
registerEndpoint(MessageHubId messageHubId,EndpointId endpointId)603 bool MessageRouter::registerEndpoint(MessageHubId messageHubId,
604                                      EndpointId endpointId) {
605   return onEndpointRegistrationStateChanged(messageHubId, endpointId,
606                                             /* isRegistered = */ true);
607 }
608 
unregisterEndpoint(MessageHubId messageHubId,EndpointId endpointId)609 bool MessageRouter::unregisterEndpoint(MessageHubId messageHubId,
610                                        EndpointId endpointId) {
611   return onEndpointRegistrationStateChanged(messageHubId, endpointId,
612                                             /* isRegistered = */ false);
613 }
614 
onEndpointRegistrationStateChanged(MessageHubId messageHubId,EndpointId endpointId,bool isRegistered)615 bool MessageRouter::onEndpointRegistrationStateChanged(
616     MessageHubId messageHubId, EndpointId endpointId, bool isRegistered) {
617   pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
618       getCallbackFromMessageHubId(messageHubId);
619   if (callback == nullptr) {
620     LOGE("Failed to register endpoint with ID 0x%" PRIx64
621          " to message hub with ID 0x%" PRIx64 ": hub not found",
622          endpointId, messageHubId);
623     return false;
624   }
625 
626   std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
627       getMessageHubRecords();
628   if (!messageHubRecords.has_value()) {
629     return false;
630   }
631 
632   for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
633     if (messageHubRecord.info.id == messageHubId) {
634       continue;
635     }
636 
637     if (isRegistered) {
638       messageHubRecord.callback->onEndpointRegistered(messageHubId, endpointId);
639     } else {
640       messageHubRecord.callback->onEndpointUnregistered(messageHubId,
641                                                         endpointId);
642     }
643   }
644 
645   return true;
646 }
647 
648 std::optional<DynamicVector<MessageRouter::MessageHubRecord>>
getMessageHubRecords()649 MessageRouter::getMessageHubRecords() {
650   LockGuard<Mutex> lock(mMutex);
651   return getMessageHubRecordsLocked();
652 }
653 
654 std::optional<DynamicVector<MessageRouter::MessageHubRecord>>
getMessageHubRecordsLocked()655 MessageRouter::getMessageHubRecordsLocked() {
656   DynamicVector<MessageHubRecord> messageHubRecords;
657   if (!messageHubRecords.reserve(mMessageHubs.size())) {
658     LOG_OOM();
659     return std::nullopt;
660   }
661 
662   for (const MessageHubRecord &messageHubRecord : mMessageHubs) {
663     // Will not fail because we reserved space above
664     messageHubRecords.push_back(messageHubRecord);
665   }
666   return messageHubRecords;
667 }
668 
getMessageHubRecordLocked(MessageHubId messageHubId)669 const MessageRouter::MessageHubRecord *MessageRouter::getMessageHubRecordLocked(
670     MessageHubId messageHubId) {
671   for (MessageHubRecord &messageHubRecord : mMessageHubs) {
672     if (messageHubRecord.info.id == messageHubId) {
673       return &messageHubRecord;
674     }
675   }
676   return nullptr;
677 }
678 
findSessionIndexLocked(MessageHubId fromMessageHubId,SessionId sessionId)679 std::optional<size_t> MessageRouter::findSessionIndexLocked(
680     MessageHubId fromMessageHubId, SessionId sessionId) {
681   for (size_t i = 0; i < mSessions.size(); ++i) {
682     if (mSessions[i].sessionId == sessionId) {
683       if (mSessions[i].initiator.messageHubId == fromMessageHubId ||
684           mSessions[i].peer.messageHubId == fromMessageHubId) {
685         return i;
686       }
687 
688       LOGE("Hub mismatch for session with ID %" PRIu16
689            ": requesting hub ID 0x%" PRIx64
690            " but session is between hubs 0x%" PRIx64 " and 0x%" PRIx64,
691            sessionId, fromMessageHubId, mSessions[i].initiator.messageHubId,
692            mSessions[i].peer.messageHubId);
693       break;
694     }
695   }
696   return std::nullopt;
697 }
698 
699 pw::IntrusivePtr<MessageRouter::MessageHubCallback>
getCallbackFromMessageHubId(MessageHubId messageHubId)700 MessageRouter::getCallbackFromMessageHubId(MessageHubId messageHubId) {
701   LockGuard<Mutex> lock(mMutex);
702   return getCallbackFromMessageHubIdLocked(messageHubId);
703 }
704 
705 pw::IntrusivePtr<MessageRouter::MessageHubCallback>
getCallbackFromMessageHubIdLocked(MessageHubId messageHubId)706 MessageRouter::getCallbackFromMessageHubIdLocked(MessageHubId messageHubId) {
707   const MessageHubRecord *messageHubRecord =
708       getMessageHubRecordLocked(messageHubId);
709   return messageHubRecord == nullptr ? nullptr : messageHubRecord->callback;
710 }
711 
checkIfEndpointExists(const pw::IntrusivePtr<MessageRouter::MessageHubCallback> & callback,EndpointId endpointId)712 bool MessageRouter::checkIfEndpointExists(
713     const pw::IntrusivePtr<MessageRouter::MessageHubCallback> &callback,
714     EndpointId endpointId) {
715   struct EndpointContext {
716     EndpointId endpointId;
717     bool foundEndpoint = false;
718   };
719   EndpointContext context = {
720       .endpointId = endpointId,
721   };
722 
723   callback->forEachEndpoint([&context](const EndpointInfo &endpointInfo) {
724     if (context.endpointId == endpointInfo.id) {
725       context.foundEndpoint = true;
726       return true;
727     }
728     return false;
729   });
730   return context.foundEndpoint;
731 }
732 
getNextSessionIdLocked()733 SessionId MessageRouter::getNextSessionIdLocked() {
734   constexpr size_t kMaxIterations = 10;
735 
736   if (mNextSessionId >= kReservedSessionId) {
737     mNextSessionId = 0;
738   }
739 
740   bool foundSessionIdConflict;
741   size_t iterations = 0;
742   do {
743     foundSessionIdConflict = false;
744     for (const Session &session : mSessions) {
745       if (session.sessionId == mNextSessionId) {
746         ++mNextSessionId;
747         if (mNextSessionId >= kReservedSessionId) {
748           mNextSessionId = 0;
749         }
750         foundSessionIdConflict = true;
751         break;
752       }
753     }
754     ++iterations;
755   } while (foundSessionIdConflict && iterations < kMaxIterations);
756 
757   return foundSessionIdConflict ? SESSION_ID_INVALID : mNextSessionId++;
758 }
759 
760 }  // namespace chre::message
761