1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "chre/util/system/message_router.h"
18 #include "chre/util/dynamic_vector.h"
19 #include "chre/util/lock_guard.h"
20 #include "chre/util/system/message_common.h"
21
22 #include <inttypes.h>
23 #include <cstring>
24 #include <optional>
25 #include <utility>
26
27 namespace chre::message {
28
MessageHub()29 MessageRouter::MessageHub::MessageHub()
30 : mRouter(nullptr), mHubId(MESSAGE_HUB_ID_INVALID) {}
31
MessageHub(MessageRouter & router,MessageHubId id)32 MessageRouter::MessageHub::MessageHub(MessageRouter &router, MessageHubId id)
33 : mRouter(&router), mHubId(id) {}
34
MessageHub(MessageHub && other)35 MessageRouter::MessageHub::MessageHub(MessageHub &&other)
36 : mRouter(other.mRouter), mHubId(other.mHubId) {
37 other.mRouter = nullptr;
38 other.mHubId = MESSAGE_HUB_ID_INVALID;
39 }
40
operator =(MessageHub && other)41 MessageRouter::MessageHub &MessageRouter::MessageHub::operator=(
42 MessageHub &&other) {
43 unregister();
44 mRouter = other.mRouter;
45 mHubId = other.mHubId;
46 other.mRouter = nullptr;
47 other.mHubId = MESSAGE_HUB_ID_INVALID;
48 return *this;
49 }
50
~MessageHub()51 MessageRouter::MessageHub::~MessageHub() {
52 unregister();
53 }
54
onSessionOpenComplete(SessionId sessionId)55 void MessageRouter::MessageHub::onSessionOpenComplete(SessionId sessionId) {
56 if (mRouter != nullptr) {
57 mRouter->onSessionOpenComplete(mHubId, sessionId);
58 }
59 }
60
openSession(EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId,const char * serviceDescriptor,SessionId sessionId)61 SessionId MessageRouter::MessageHub::openSession(EndpointId fromEndpointId,
62 MessageHubId toMessageHubId,
63 EndpointId toEndpointId,
64 const char *serviceDescriptor,
65 SessionId sessionId) {
66 return mRouter == nullptr
67 ? SESSION_ID_INVALID
68 : mRouter->openSession(mHubId, fromEndpointId, toMessageHubId,
69 toEndpointId, serviceDescriptor, sessionId);
70 }
71
closeSession(SessionId sessionId,Reason reason)72 bool MessageRouter::MessageHub::closeSession(SessionId sessionId,
73 Reason reason) {
74 return mRouter != nullptr && mRouter->closeSession(mHubId, sessionId, reason);
75 }
76
getSessionWithId(SessionId sessionId)77 std::optional<Session> MessageRouter::MessageHub::getSessionWithId(
78 SessionId sessionId) {
79 return mRouter == nullptr ? std::nullopt
80 : mRouter->getSessionWithId(mHubId, sessionId);
81 }
82
sendMessage(pw::UniquePtr<std::byte[]> && data,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId,EndpointId fromEndpointId)83 bool MessageRouter::MessageHub::sendMessage(pw::UniquePtr<std::byte[]> &&data,
84 uint32_t messageType,
85 uint32_t messagePermissions,
86 SessionId sessionId,
87 EndpointId fromEndpointId) {
88 return mRouter != nullptr &&
89 mRouter->sendMessage(std::move(data), messageType, messagePermissions,
90 sessionId, fromEndpointId, mHubId);
91 }
92
registerEndpoint(EndpointId endpointId)93 bool MessageRouter::MessageHub::registerEndpoint(EndpointId endpointId) {
94 return mRouter != nullptr && mRouter->registerEndpoint(mHubId, endpointId);
95 }
96
unregisterEndpoint(EndpointId endpointId)97 bool MessageRouter::MessageHub::unregisterEndpoint(EndpointId endpointId) {
98 return mRouter != nullptr && mRouter->unregisterEndpoint(mHubId, endpointId);
99 }
100
getId()101 MessageHubId MessageRouter::MessageHub::getId() {
102 return mHubId;
103 }
104
isRegistered()105 bool MessageRouter::MessageHub::isRegistered() {
106 return mRouter != nullptr;
107 }
108
unregister()109 void MessageRouter::MessageHub::unregister() {
110 if (mRouter != nullptr) {
111 mRouter->unregisterMessageHub(mHubId);
112 }
113 mRouter = nullptr;
114 }
115
116 std::optional<typename MessageRouter::MessageHub>
registerMessageHub(const char * name,MessageHubId id,pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback)117 MessageRouter::registerMessageHub(
118 const char *name, MessageHubId id,
119 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback) {
120 DynamicVector<MessageHubRecord> hubsToNotify;
121 std::optional<MessageHub> newHub;
122 MessageHubInfo newHubInfo;
123 {
124 LockGuard<Mutex> lock(mMutex);
125 if (mMessageHubs.full()) {
126 LOGE(
127 "Message hub '%s' not registered: maximum number of message hubs "
128 "reached",
129 name);
130 return std::nullopt;
131 }
132
133 for (MessageHubRecord &messageHub : mMessageHubs) {
134 if (std::strcmp(messageHub.info.name, name) == 0 ||
135 messageHub.info.id == id) {
136 LOGE(
137 "Message hub '%s' not registered: hub with same name or ID already "
138 "exists",
139 name);
140 return std::nullopt;
141 }
142 }
143
144 if (auto hubRecords = getMessageHubRecordsLocked(); hubRecords) {
145 hubsToNotify = std::move(*hubRecords);
146 } else {
147 return std::nullopt;
148 }
149
150 MessageHubRecord messageHubRecord = {
151 .info = {.id = id, .name = name},
152 .callback = std::move(callback),
153 };
154 newHubInfo = messageHubRecord.info;
155 mMessageHubs.push_back(std::move(messageHubRecord));
156 newHub = MessageHub(*this, id);
157 }
158
159 // NOTE: newHubInfo is guaranteed to be valid while we have newHub.
160 for (const auto &hubRecord : hubsToNotify) {
161 hubRecord.callback->onHubRegistered(newHubInfo);
162 }
163 return newHub;
164 }
165
forEachEndpointOfHub(MessageHubId messageHubId,const pw::Function<bool (const EndpointInfo &)> & function)166 bool MessageRouter::forEachEndpointOfHub(
167 MessageHubId messageHubId,
168 const pw::Function<bool(const EndpointInfo &)> &function) {
169 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
170 getCallbackFromMessageHubId(messageHubId);
171 if (callback == nullptr) {
172 LOGE("Failed to find message hub with ID 0x%" PRIx64, messageHubId);
173 return false;
174 }
175
176 callback->forEachEndpoint(function);
177 return true;
178 }
179
forEachEndpoint(const pw::Function<void (const MessageHubInfo &,const EndpointInfo &)> & function)180 bool MessageRouter::forEachEndpoint(
181 const pw::Function<void(const MessageHubInfo &, const EndpointInfo &)>
182 &function) {
183 std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
184 getMessageHubRecords();
185 if (!messageHubRecords.has_value()) {
186 return false;
187 }
188
189 struct Context {
190 decltype(function) function;
191 const MessageHubInfo &messageHubInfo;
192 };
193 for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
194 Context context = {
195 .function = function,
196 .messageHubInfo = messageHubRecord.info,
197 };
198
199 messageHubRecord.callback->forEachEndpoint(
200 [&context](const EndpointInfo &endpointInfo) {
201 context.function(context.messageHubInfo, endpointInfo);
202 return false;
203 });
204 }
205 return true;
206 }
207
getEndpointInfo(MessageHubId messageHubId,EndpointId endpointId)208 std::optional<EndpointInfo> MessageRouter::getEndpointInfo(
209 MessageHubId messageHubId, EndpointId endpointId) {
210 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
211 getCallbackFromMessageHubId(messageHubId);
212 if (callback == nullptr) {
213 LOGE("Failed to get endpoint info for message hub with ID 0x%" PRIx64
214 " and endpoint ID 0x%" PRIx64 ": hub not found",
215 messageHubId, endpointId);
216 return std::nullopt;
217 }
218
219 return callback->getEndpointInfo(endpointId);
220 }
221
getEndpointForService(MessageHubId messageHubId,const char * serviceDescriptor)222 std::optional<Endpoint> MessageRouter::getEndpointForService(
223 MessageHubId messageHubId, const char *serviceDescriptor) {
224 if (serviceDescriptor == nullptr) {
225 LOGE("Failed to get endpoint for service: service descriptor is null");
226 return std::nullopt;
227 }
228
229 std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
230 getMessageHubRecords();
231 if (!messageHubRecords.has_value()) {
232 return std::nullopt;
233 }
234
235 for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
236 if ((messageHubId == MESSAGE_HUB_ID_ANY ||
237 messageHubId == messageHubRecord.info.id) &&
238 messageHubRecord.callback != nullptr) {
239 std::optional<EndpointId> endpointId =
240 messageHubRecord.callback->getEndpointForService(serviceDescriptor);
241 if (endpointId.has_value()) {
242 return Endpoint(messageHubRecord.info.id, *endpointId);
243 }
244
245 // Only searching this message hub, so return early if not found
246 if (messageHubId != MESSAGE_HUB_ID_ANY) {
247 return std::nullopt;
248 }
249 }
250 }
251 return std::nullopt;
252 }
253
doesEndpointHaveService(MessageHubId messageHubId,EndpointId endpointId,const char * serviceDescriptor)254 bool MessageRouter::doesEndpointHaveService(MessageHubId messageHubId,
255 EndpointId endpointId,
256 const char *serviceDescriptor) {
257 if (serviceDescriptor == nullptr) {
258 LOGE("Failed to check if endpoint has service: service descriptor is null");
259 return false;
260 }
261
262 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
263 getCallbackFromMessageHubId(messageHubId);
264 if (callback == nullptr) {
265 LOGE(
266 "Failed to check if endpoint has service for message hub with ID "
267 "0x%" PRIx64 " and endpoint ID 0x%" PRIx64 ": hub not found",
268 messageHubId, endpointId);
269 return false;
270 }
271 return callback->doesEndpointHaveService(endpointId, serviceDescriptor);
272 }
273
forEachService(const pw::Function<bool (const MessageHubInfo &,const EndpointInfo &,const ServiceInfo &)> & function)274 bool MessageRouter::forEachService(
275 const pw::Function<bool(const MessageHubInfo &, const EndpointInfo &,
276 const ServiceInfo &)> &function) {
277 std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
278 getMessageHubRecords();
279 if (!messageHubRecords.has_value()) {
280 return false;
281 }
282
283 struct Context {
284 decltype(function) &function;
285 const MessageHubInfo *messageHubInfo;
286 };
287 Context context = {
288 .function = function,
289 .messageHubInfo = nullptr,
290 };
291 for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
292 context.messageHubInfo = &messageHubRecord.info;
293 messageHubRecord.callback->forEachService(
294 [&context](const EndpointInfo &endpointInfo,
295 const ServiceInfo &serviceInfo) {
296 return context.function(*context.messageHubInfo, endpointInfo,
297 serviceInfo);
298 });
299 }
300 return true;
301 }
302
forEachMessageHub(const pw::Function<bool (const MessageHubInfo &)> & function)303 bool MessageRouter::forEachMessageHub(
304 const pw::Function<bool(const MessageHubInfo &)> &function) {
305 std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
306 getMessageHubRecords();
307 if (!messageHubRecords.has_value()) {
308 return false;
309 }
310
311 for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
312 function(messageHubRecord.info);
313 }
314 return true;
315 }
316
unregisterMessageHub(MessageHubId fromMessageHubId)317 bool MessageRouter::unregisterMessageHub(MessageHubId fromMessageHubId) {
318 DynamicVector<std::pair<pw::IntrusivePtr<MessageHubCallback>, Session>>
319 sessionsToDestroy;
320 DynamicVector<pw::IntrusivePtr<MessageHubCallback>> hubsToNotify;
321
322 {
323 LockGuard<Mutex> lock(mMutex);
324
325 if (!mMessageHubs.empty() &&
326 !hubsToNotify.reserve(mMessageHubs.size() - 1)) {
327 LOG_OOM();
328 return false;
329 }
330
331 bool success = false;
332 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
333 if (messageHubRecord.info.id == fromMessageHubId) {
334 mMessageHubs.erase(&messageHubRecord);
335 success = true;
336 } else {
337 hubsToNotify.push_back(messageHubRecord.callback);
338 }
339 }
340 if (!success) {
341 return false;
342 }
343
344 for (size_t i = 0; i < mSessions.size();) {
345 Session &session = mSessions[i];
346 bool initiatorIsFromHub =
347 session.initiator.messageHubId == fromMessageHubId;
348 bool peerIsFromHub = session.peer.messageHubId == fromMessageHubId;
349
350 if (initiatorIsFromHub || peerIsFromHub) {
351 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
352 getCallbackFromMessageHubIdLocked(
353 initiatorIsFromHub ? session.peer.messageHubId
354 : session.initiator.messageHubId);
355 sessionsToDestroy.push_back(std::make_pair(callback, session));
356 mSessions.erase(&mSessions[i]);
357 } else {
358 ++i;
359 }
360 }
361 }
362
363 for (auto [callback, session] : sessionsToDestroy) {
364 if (callback != nullptr) {
365 callback->onSessionClosed(session, Reason::HUB_RESET);
366 }
367 }
368 for (auto callback : hubsToNotify) {
369 if (callback != nullptr) {
370 callback->onHubUnregistered(fromMessageHubId);
371 }
372 }
373 return true;
374 }
375
onSessionOpenComplete(MessageHubId fromMessageHubId,SessionId sessionId)376 void MessageRouter::onSessionOpenComplete(MessageHubId fromMessageHubId,
377 SessionId sessionId) {
378 finalizeSession(fromMessageHubId, sessionId, /* reason = */ std::nullopt);
379 }
380
openSession(MessageHubId fromMessageHubId,EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId,const char * serviceDescriptor,SessionId sessionId)381 SessionId MessageRouter::openSession(MessageHubId fromMessageHubId,
382 EndpointId fromEndpointId,
383 MessageHubId toMessageHubId,
384 EndpointId toEndpointId,
385 const char *serviceDescriptor,
386 SessionId sessionId) {
387 if (sessionId != SESSION_ID_INVALID && sessionId < kReservedSessionId) {
388 LOGE("Failed to open session: session ID %" PRIu16
389 " is not in the reserved range",
390 sessionId);
391 return SESSION_ID_INVALID;
392 }
393
394 pw::IntrusivePtr<MessageRouter::MessageHubCallback> initiatorCallback =
395 getCallbackFromMessageHubId(fromMessageHubId);
396 pw::IntrusivePtr<MessageRouter::MessageHubCallback> peerCallback =
397 getCallbackFromMessageHubId(toMessageHubId);
398 if (initiatorCallback == nullptr || peerCallback == nullptr) {
399 LOGE("Failed to open session: %s message hub not found",
400 initiatorCallback == nullptr ? "initiator" : "peer");
401 return SESSION_ID_INVALID;
402 }
403
404 if (!checkIfEndpointExists(initiatorCallback, fromEndpointId)) {
405 LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
406 " not found in message hub with ID 0x%" PRIx64,
407 fromEndpointId, fromMessageHubId);
408 return SESSION_ID_INVALID;
409 }
410
411 if (!checkIfEndpointExists(peerCallback, toEndpointId)) {
412 LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
413 " not found in message hub with ID 0x%" PRIx64,
414 toEndpointId, toMessageHubId);
415 return SESSION_ID_INVALID;
416 }
417
418 if (serviceDescriptor != nullptr &&
419 !peerCallback->doesEndpointHaveService(toEndpointId, serviceDescriptor)) {
420 LOGE("Failed to open session: endpoint with ID 0x%" PRIx64
421 " does not have service descriptor '%s'",
422 toEndpointId, serviceDescriptor);
423 return SESSION_ID_INVALID;
424 }
425
426 Session session(SESSION_ID_INVALID,
427 Endpoint(fromMessageHubId, fromEndpointId),
428 Endpoint(toMessageHubId, toEndpointId), serviceDescriptor);
429 {
430 LockGuard<Mutex> lock(mMutex);
431 if (mSessions.full()) {
432 LOGE("Failed to open session: maximum number of sessions reached");
433 return SESSION_ID_INVALID;
434 }
435
436 bool foundSession = false;
437 for (Session &existingSession : mSessions) {
438 if (existingSession.isEquivalent(session)) {
439 LOGD("Session with ID %" PRIu16 " already exists",
440 existingSession.sessionId);
441 session = existingSession;
442 foundSession = true;
443 break;
444 }
445 }
446
447 if (!foundSession) {
448 if (sessionId == SESSION_ID_INVALID) {
449 sessionId = getNextSessionIdLocked();
450 if (sessionId == SESSION_ID_INVALID) {
451 LOGE("Failed to open session: no available session ID");
452 return SESSION_ID_INVALID;
453 }
454 }
455
456 session.sessionId = sessionId;
457 mSessions.push_back(session);
458 }
459 }
460
461 peerCallback->onSessionOpenRequest(session);
462 return session.sessionId;
463 }
464
closeSession(MessageHubId fromMessageHubId,SessionId sessionId,Reason reason)465 bool MessageRouter::closeSession(MessageHubId fromMessageHubId,
466 SessionId sessionId, Reason reason) {
467 return finalizeSession(fromMessageHubId, sessionId, reason);
468 }
469
// Shared implementation of onSessionOpenComplete() and closeSession().
// When |reason| is empty, the session is marked active and the participants
// receive onSessionOpened(). Otherwise the session is removed and the
// participants receive onSessionClosed(*reason). If both participants share
// the same callback object, it is notified only once. Callbacks run after
// mMutex is released.
// Returns false if |sessionId| is not found for |fromMessageHubId| or if either
// participant's hub is no longer registered (in which case the session is
// removed without notifying anyone).
bool MessageRouter::finalizeSession(MessageHubId fromMessageHubId,
                                    SessionId sessionId,
                                    std::optional<Reason> reason) {
  pw::IntrusivePtr<MessageRouter::MessageHubCallback> peerCallback = nullptr;
  pw::IntrusivePtr<MessageRouter::MessageHubCallback> initiatorCallback =
      nullptr;
  Session session;
  {
    LockGuard<Mutex> lock(mMutex);
    std::optional<size_t> index =
        findSessionIndexLocked(fromMessageHubId, sessionId);
    if (!index.has_value()) {
      LOGE("Failed to %s session with ID %" PRIu16 " not found",
           reason.has_value() ? "close" : "open", sessionId);
      return false;
    }

    // Copy the session so it stays usable after the lock is released and,
    // when closing, after the table entry is erased.
    session = mSessions[*index];
    if (reason.has_value()) {
      mSessions.erase(&mSessions[*index]);
    } else {
      // Opening: update both the stored entry and the local copy.
      mSessions[*index].isActive = true;
      session.isActive = true;
    }

    initiatorCallback =
        getCallbackFromMessageHubIdLocked(session.initiator.messageHubId);
    peerCallback = getCallbackFromMessageHubIdLocked(session.peer.messageHubId);

    if (initiatorCallback == nullptr || peerCallback == nullptr) {
      LOGE("Failed to finalize session: %s message hub with ID 0x%" PRIx64
           " not found",
           initiatorCallback == nullptr ? "initiator" : "peer",
           initiatorCallback == nullptr ? session.initiator.messageHubId
                                        : session.peer.messageHubId);
      if (!reason.has_value()) {
        // Only erase if it was not erased above
        mSessions.erase(&mSessions[*index]);
      }
      return false;
    }
  }

  // Notify outside the lock so callbacks may call back into the router.
  if (reason.has_value()) {
    initiatorCallback->onSessionClosed(session, reason.value());
    if (initiatorCallback != peerCallback) {
      peerCallback->onSessionClosed(session, reason.value());
    }
  } else {
    initiatorCallback->onSessionOpened(session);
    if (initiatorCallback != peerCallback) {
      peerCallback->onSessionOpened(session);
    }
  }
  return true;
}
526
getSessionWithId(MessageHubId fromMessageHubId,SessionId sessionId)527 std::optional<Session> MessageRouter::getSessionWithId(
528 MessageHubId fromMessageHubId, SessionId sessionId) {
529 LockGuard<Mutex> lock(mMutex);
530
531 std::optional<size_t> index =
532 findSessionIndexLocked(fromMessageHubId, sessionId);
533 return index.has_value() ? std::optional<Session>(mSessions[*index])
534 : std::nullopt;
535 }
536
// Delivers |data| over session |sessionId| from endpoint |fromEndpointId| on
// hub |fromMessageHubId| to the session's other participant.
// |fromEndpointId| may be ENDPOINT_ID_ANY, in which case the sender endpoint is
// inferred from the sending hub. That inference is rejected when both
// participants are on the same hub. The session must exist for the sending hub
// and be active. If the receiver's hub is gone or its onMessageReceived()
// returns false, the session is closed with Reason::UNSPECIFIED.
// Returns whether the receiver accepted the message.
bool MessageRouter::sendMessage(pw::UniquePtr<std::byte[]> &&data,
                                uint32_t messageType,
                                uint32_t messagePermissions,
                                SessionId sessionId, EndpointId fromEndpointId,
                                MessageHubId fromMessageHubId) {
  pw::IntrusivePtr<MessageRouter::MessageHubCallback> receiverCallback =
      nullptr;
  Session session;
  bool sentBySessionInitiator;
  {
    LockGuard<Mutex> lock(mMutex);

    std::optional<size_t> index =
        findSessionIndexLocked(fromMessageHubId, sessionId);
    if (!index.has_value()) {
      LOGE("Failed to send message: session with ID %" PRIu16 " not found",
           sessionId);
      return false;
    }

    session = mSessions[*index];
    if (!session.isActive) {
      LOGE("Failed to send message: session with ID %" PRIu16 " is inactive",
           sessionId);
      return false;
    }

    Endpoint sender(fromMessageHubId, fromEndpointId);
    if (fromEndpointId == ENDPOINT_ID_ANY) {
      // The sending hub alone identifies the sender only when the two
      // participants are on different hubs.
      if (session.initiator.messageHubId == session.peer.messageHubId) {
        LOGE("Unable to infer sender endpoint ID: session with ID %" PRIu16
             " is between endpoints on the same message hub with ID 0x%" PRIx64,
             sessionId, fromMessageHubId);
        return false;
      }
      sender.endpointId = session.initiator.messageHubId == fromMessageHubId
                              ? session.initiator.endpointId
                              : session.peer.endpointId;
    }

    if (sender != session.initiator && sender != session.peer) {
      LOGE("Failed to send message: session with ID %" PRIu16
           " does not contain endpoint with hub ID 0x%" PRIx64
           " and endpoint ID 0x%" PRIx64,
           sessionId, fromMessageHubId, fromEndpointId);
      return false;
    }
    sentBySessionInitiator = sender == session.initiator;
    // The receiver is whichever participant did not send the message.
    receiverCallback = getCallbackFromMessageHubIdLocked(
        sentBySessionInitiator ? session.peer.messageHubId
                               : session.initiator.messageHubId);
  }

  // Delivery happens without holding mMutex.
  bool success = false;
  if (receiverCallback != nullptr) {
    success = receiverCallback->onMessageReceived(std::move(data), messageType,
                                                  messagePermissions, session,
                                                  sentBySessionInitiator);
  }

  if (!success) {
    // Tear the session down on delivery failure (closeSession() notifies the
    // participants through finalizeSession()).
    closeSession(fromMessageHubId, sessionId, Reason::UNSPECIFIED);
  }
  return success;
}
602
registerEndpoint(MessageHubId messageHubId,EndpointId endpointId)603 bool MessageRouter::registerEndpoint(MessageHubId messageHubId,
604 EndpointId endpointId) {
605 return onEndpointRegistrationStateChanged(messageHubId, endpointId,
606 /* isRegistered = */ true);
607 }
608
unregisterEndpoint(MessageHubId messageHubId,EndpointId endpointId)609 bool MessageRouter::unregisterEndpoint(MessageHubId messageHubId,
610 EndpointId endpointId) {
611 return onEndpointRegistrationStateChanged(messageHubId, endpointId,
612 /* isRegistered = */ false);
613 }
614
onEndpointRegistrationStateChanged(MessageHubId messageHubId,EndpointId endpointId,bool isRegistered)615 bool MessageRouter::onEndpointRegistrationStateChanged(
616 MessageHubId messageHubId, EndpointId endpointId, bool isRegistered) {
617 pw::IntrusivePtr<MessageRouter::MessageHubCallback> callback =
618 getCallbackFromMessageHubId(messageHubId);
619 if (callback == nullptr) {
620 LOGE("Failed to register endpoint with ID 0x%" PRIx64
621 " to message hub with ID 0x%" PRIx64 ": hub not found",
622 endpointId, messageHubId);
623 return false;
624 }
625
626 std::optional<DynamicVector<MessageHubRecord>> messageHubRecords =
627 getMessageHubRecords();
628 if (!messageHubRecords.has_value()) {
629 return false;
630 }
631
632 for (const MessageHubRecord &messageHubRecord : *messageHubRecords) {
633 if (messageHubRecord.info.id == messageHubId) {
634 continue;
635 }
636
637 if (isRegistered) {
638 messageHubRecord.callback->onEndpointRegistered(messageHubId, endpointId);
639 } else {
640 messageHubRecord.callback->onEndpointUnregistered(messageHubId,
641 endpointId);
642 }
643 }
644
645 return true;
646 }
647
648 std::optional<DynamicVector<MessageRouter::MessageHubRecord>>
getMessageHubRecords()649 MessageRouter::getMessageHubRecords() {
650 LockGuard<Mutex> lock(mMutex);
651 return getMessageHubRecordsLocked();
652 }
653
654 std::optional<DynamicVector<MessageRouter::MessageHubRecord>>
getMessageHubRecordsLocked()655 MessageRouter::getMessageHubRecordsLocked() {
656 DynamicVector<MessageHubRecord> messageHubRecords;
657 if (!messageHubRecords.reserve(mMessageHubs.size())) {
658 LOG_OOM();
659 return std::nullopt;
660 }
661
662 for (const MessageHubRecord &messageHubRecord : mMessageHubs) {
663 // Will not fail because we reserved space above
664 messageHubRecords.push_back(messageHubRecord);
665 }
666 return messageHubRecords;
667 }
668
getMessageHubRecordLocked(MessageHubId messageHubId)669 const MessageRouter::MessageHubRecord *MessageRouter::getMessageHubRecordLocked(
670 MessageHubId messageHubId) {
671 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
672 if (messageHubRecord.info.id == messageHubId) {
673 return &messageHubRecord;
674 }
675 }
676 return nullptr;
677 }
678
findSessionIndexLocked(MessageHubId fromMessageHubId,SessionId sessionId)679 std::optional<size_t> MessageRouter::findSessionIndexLocked(
680 MessageHubId fromMessageHubId, SessionId sessionId) {
681 for (size_t i = 0; i < mSessions.size(); ++i) {
682 if (mSessions[i].sessionId == sessionId) {
683 if (mSessions[i].initiator.messageHubId == fromMessageHubId ||
684 mSessions[i].peer.messageHubId == fromMessageHubId) {
685 return i;
686 }
687
688 LOGE("Hub mismatch for session with ID %" PRIu16
689 ": requesting hub ID 0x%" PRIx64
690 " but session is between hubs 0x%" PRIx64 " and 0x%" PRIx64,
691 sessionId, fromMessageHubId, mSessions[i].initiator.messageHubId,
692 mSessions[i].peer.messageHubId);
693 break;
694 }
695 }
696 return std::nullopt;
697 }
698
699 pw::IntrusivePtr<MessageRouter::MessageHubCallback>
getCallbackFromMessageHubId(MessageHubId messageHubId)700 MessageRouter::getCallbackFromMessageHubId(MessageHubId messageHubId) {
701 LockGuard<Mutex> lock(mMutex);
702 return getCallbackFromMessageHubIdLocked(messageHubId);
703 }
704
705 pw::IntrusivePtr<MessageRouter::MessageHubCallback>
getCallbackFromMessageHubIdLocked(MessageHubId messageHubId)706 MessageRouter::getCallbackFromMessageHubIdLocked(MessageHubId messageHubId) {
707 const MessageHubRecord *messageHubRecord =
708 getMessageHubRecordLocked(messageHubId);
709 return messageHubRecord == nullptr ? nullptr : messageHubRecord->callback;
710 }
711
checkIfEndpointExists(const pw::IntrusivePtr<MessageRouter::MessageHubCallback> & callback,EndpointId endpointId)712 bool MessageRouter::checkIfEndpointExists(
713 const pw::IntrusivePtr<MessageRouter::MessageHubCallback> &callback,
714 EndpointId endpointId) {
715 struct EndpointContext {
716 EndpointId endpointId;
717 bool foundEndpoint = false;
718 };
719 EndpointContext context = {
720 .endpointId = endpointId,
721 };
722
723 callback->forEachEndpoint([&context](const EndpointInfo &endpointInfo) {
724 if (context.endpointId == endpointInfo.id) {
725 context.foundEndpoint = true;
726 return true;
727 }
728 return false;
729 });
730 return context.foundEndpoint;
731 }
732
getNextSessionIdLocked()733 SessionId MessageRouter::getNextSessionIdLocked() {
734 constexpr size_t kMaxIterations = 10;
735
736 if (mNextSessionId >= kReservedSessionId) {
737 mNextSessionId = 0;
738 }
739
740 bool foundSessionIdConflict;
741 size_t iterations = 0;
742 do {
743 foundSessionIdConflict = false;
744 for (const Session &session : mSessions) {
745 if (session.sessionId == mNextSessionId) {
746 ++mNextSessionId;
747 if (mNextSessionId >= kReservedSessionId) {
748 mNextSessionId = 0;
749 }
750 foundSessionIdConflict = true;
751 break;
752 }
753 }
754 ++iterations;
755 } while (foundSessionIdConflict && iterations < kMaxIterations);
756
757 return foundSessionIdConflict ? SESSION_ID_INVALID : mNextSessionId++;
758 }
759
760 } // namespace chre::message
761