1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/public/cpp/bindings/message.h"
6
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <stdlib.h>
10
11 #include <algorithm>
12 #include <utility>
13
14 #include "base/bind.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/numerics/safe_math.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/threading/thread_local.h"
20 #include "mojo/public/cpp/bindings/associated_group_controller.h"
21 #include "mojo/public/cpp/bindings/lib/array_internal.h"
22 #include "mojo/public/cpp/bindings/lib/unserialized_message_context.h"
23
24 namespace mojo {
25
26 namespace {
27
// Per-thread pointer to the innermost active internal::MessageDispatchContext.
// MessageDispatchContext's constructor installs itself here and its destructor
// restores the previously installed (outer) context. The Leaky trait means the
// LazyInstance is never destroyed at process exit.
base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
    Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;

// Per-thread pointer to the innermost active SyncMessageResponseContext,
// maintained by SyncMessageResponseContext's constructor/destructor in the same
// push/restore fashion.
base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky
    g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
33
DoNotifyBadMessage(Message message,const std::string & error)34 void DoNotifyBadMessage(Message message, const std::string& error) {
35 message.NotifyBadMessage(error);
36 }
37
38 template <typename HeaderType>
AllocateHeaderFromBuffer(internal::Buffer * buffer,HeaderType ** header)39 void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) {
40 *header = buffer->AllocateAndGet<HeaderType>();
41 (*header)->num_bytes = sizeof(HeaderType);
42 }
43
WriteMessageHeader(uint32_t name,uint32_t flags,size_t payload_interface_id_count,internal::Buffer * payload_buffer)44 void WriteMessageHeader(uint32_t name,
45 uint32_t flags,
46 size_t payload_interface_id_count,
47 internal::Buffer* payload_buffer) {
48 if (payload_interface_id_count > 0) {
49 // Version 2
50 internal::MessageHeaderV2* header;
51 AllocateHeaderFromBuffer(payload_buffer, &header);
52 header->version = 2;
53 header->name = name;
54 header->flags = flags;
55 // The payload immediately follows the header.
56 header->payload.Set(header + 1);
57 } else if (flags &
58 (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
59 // Version 1
60 internal::MessageHeaderV1* header;
61 AllocateHeaderFromBuffer(payload_buffer, &header);
62 header->version = 1;
63 header->name = name;
64 header->flags = flags;
65 } else {
66 internal::MessageHeader* header;
67 AllocateHeaderFromBuffer(payload_buffer, &header);
68 header->version = 0;
69 header->name = name;
70 header->flags = flags;
71 }
72 }
73
CreateSerializedMessageObject(uint32_t name,uint32_t flags,size_t payload_size,size_t payload_interface_id_count,std::vector<ScopedHandle> * handles,ScopedMessageHandle * out_handle,internal::Buffer * out_buffer)74 void CreateSerializedMessageObject(uint32_t name,
75 uint32_t flags,
76 size_t payload_size,
77 size_t payload_interface_id_count,
78 std::vector<ScopedHandle>* handles,
79 ScopedMessageHandle* out_handle,
80 internal::Buffer* out_buffer) {
81 ScopedMessageHandle handle;
82 MojoResult rv = mojo::CreateMessage(&handle);
83 DCHECK_EQ(MOJO_RESULT_OK, rv);
84 DCHECK(handle.is_valid());
85
86 void* buffer;
87 uint32_t buffer_size;
88 size_t total_size = internal::ComputeSerializedMessageSize(
89 flags, payload_size, payload_interface_id_count);
90 DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size));
91 DCHECK(!handles ||
92 base::IsValueInRangeForNumericType<uint32_t>(handles->size()));
93 rv = MojoAppendMessageData(
94 handle->value(), static_cast<uint32_t>(total_size),
95 handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr,
96 handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer,
97 &buffer_size);
98 DCHECK_EQ(MOJO_RESULT_OK, rv);
99 if (handles) {
100 // Handle ownership has been taken by MojoAppendMessageData.
101 for (size_t i = 0; i < handles->size(); ++i)
102 ignore_result(handles->at(i).release());
103 }
104
105 internal::Buffer payload_buffer(handle.get(), total_size, buffer,
106 buffer_size);
107
108 // Make sure we zero the memory first!
109 memset(payload_buffer.data(), 0, total_size);
110 WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer);
111
112 *out_handle = std::move(handle);
113 *out_buffer = std::move(payload_buffer);
114 }
115
SerializeUnserializedContext(MojoMessageHandle message,uintptr_t context_value)116 void SerializeUnserializedContext(MojoMessageHandle message,
117 uintptr_t context_value) {
118 auto* context =
119 reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
120 void* buffer;
121 uint32_t buffer_size;
122 MojoResult attach_result = MojoAppendMessageData(
123 message, 0, nullptr, 0, nullptr, &buffer, &buffer_size);
124 if (attach_result != MOJO_RESULT_OK)
125 return;
126
127 internal::Buffer payload_buffer(MessageHandle(message), 0, buffer,
128 buffer_size);
129 WriteMessageHeader(context->message_name(), context->message_flags(),
130 0 /* payload_interface_id_count */, &payload_buffer);
131
132 // We need to copy additional header data which may have been set after
133 // message construction, as this codepath may be reached at some arbitrary
134 // time between message send and message dispatch.
135 static_cast<internal::MessageHeader*>(buffer)->interface_id =
136 context->header()->interface_id;
137 if (context->header()->flags &
138 (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
139 DCHECK_GE(context->header()->version, 1u);
140 static_cast<internal::MessageHeaderV1*>(buffer)->request_id =
141 context->header()->request_id;
142 }
143
144 internal::SerializationContext serialization_context;
145 context->Serialize(&serialization_context, &payload_buffer);
146
147 // TODO(crbug.com/753433): Support lazy serialization of associated endpoint
148 // handles. See corresponding TODO in the bindings generator for proof that
149 // this DCHECK is indeed valid.
150 DCHECK(serialization_context.associated_endpoint_handles()->empty());
151 if (!serialization_context.handles()->empty())
152 payload_buffer.AttachHandles(serialization_context.mutable_handles());
153 payload_buffer.Seal();
154 }
155
DestroyUnserializedContext(uintptr_t context)156 void DestroyUnserializedContext(uintptr_t context) {
157 delete reinterpret_cast<internal::UnserializedMessageContext*>(context);
158 }
159
CreateUnserializedMessageObject(std::unique_ptr<internal::UnserializedMessageContext> context)160 ScopedMessageHandle CreateUnserializedMessageObject(
161 std::unique_ptr<internal::UnserializedMessageContext> context) {
162 ScopedMessageHandle handle;
163 MojoResult rv = mojo::CreateMessage(&handle);
164 DCHECK_EQ(MOJO_RESULT_OK, rv);
165 DCHECK(handle.is_valid());
166
167 rv = MojoSetMessageContext(
168 handle->value(), reinterpret_cast<uintptr_t>(context.release()),
169 &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr);
170 DCHECK_EQ(MOJO_RESULT_OK, rv);
171 return handle;
172 }
173
174 } // namespace
175
// Constructs a null Message that owns no message object; all members take the
// defaults given in the class declaration.
Message::Message() = default;
177
Message(Message && other)178 Message::Message(Message&& other)
179 : handle_(std::move(other.handle_)),
180 payload_buffer_(std::move(other.payload_buffer_)),
181 handles_(std::move(other.handles_)),
182 associated_endpoint_handles_(
183 std::move(other.associated_endpoint_handles_)),
184 transferable_(other.transferable_),
185 serialized_(other.serialized_) {
186 other.transferable_ = false;
187 other.serialized_ = false;
188 #if defined(ENABLE_IPC_FUZZER)
189 interface_name_ = other.interface_name_;
190 method_name_ = other.method_name_;
191 #endif
192 }
193
Message(std::unique_ptr<internal::UnserializedMessageContext> context)194 Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context)
195 : Message(CreateUnserializedMessageObject(std::move(context))) {}
196
Message(uint32_t name,uint32_t flags,size_t payload_size,size_t payload_interface_id_count,std::vector<ScopedHandle> * handles)197 Message::Message(uint32_t name,
198 uint32_t flags,
199 size_t payload_size,
200 size_t payload_interface_id_count,
201 std::vector<ScopedHandle>* handles) {
202 CreateSerializedMessageObject(name, flags, payload_size,
203 payload_interface_id_count, handles, &handle_,
204 &payload_buffer_);
205 transferable_ = true;
206 serialized_ = true;
207 }
208
Message(ScopedMessageHandle handle)209 Message::Message(ScopedMessageHandle handle) {
210 DCHECK(handle.is_valid());
211
212 uintptr_t context_value = 0;
213 MojoResult get_context_result =
214 MojoGetMessageContext(handle->value(), nullptr, &context_value);
215 if (get_context_result == MOJO_RESULT_NOT_FOUND) {
216 // It's a serialized message. Extract handles if possible.
217 uint32_t num_bytes;
218 void* buffer;
219 uint32_t num_handles = 0;
220 MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer,
221 &num_bytes, nullptr, &num_handles);
222 if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
223 handles_.resize(num_handles);
224 rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes,
225 reinterpret_cast<MojoHandle*>(handles_.data()),
226 &num_handles);
227 } else {
228 // No handles, so it's safe to retransmit this message if the caller
229 // really wants to.
230 transferable_ = true;
231 }
232
233 if (rv != MOJO_RESULT_OK) {
234 // Failed to deserialize handles. Leave the Message uninitialized.
235 return;
236 }
237
238 payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes);
239 serialized_ = true;
240 } else {
241 DCHECK_EQ(MOJO_RESULT_OK, get_context_result);
242 auto* context =
243 reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
244 // Dummy data address so common header accessors still behave properly. The
245 // choice is V1 reflects unserialized message capabilities: we may or may
246 // not need to support request IDs (which require at least V1), but we never
247 // (for now, anyway) need to support associated interface handles (V2).
248 payload_buffer_ =
249 internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1),
250 sizeof(internal::MessageHeaderV1));
251 transferable_ = true;
252 serialized_ = false;
253 }
254
255 handle_ = std::move(handle);
256 }
257
// Owned members (message handle, payload buffer, attached and associated
// endpoint handles) are released by their own destructors.
Message::~Message() = default;
259
operator =(Message && other)260 Message& Message::operator=(Message&& other) {
261 handle_ = std::move(other.handle_);
262 payload_buffer_ = std::move(other.payload_buffer_);
263 handles_ = std::move(other.handles_);
264 associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_);
265 transferable_ = other.transferable_;
266 other.transferable_ = false;
267 serialized_ = other.serialized_;
268 other.serialized_ = false;
269 #if defined(ENABLE_IPC_FUZZER)
270 interface_name_ = other.interface_name_;
271 method_name_ = other.method_name_;
272 #endif
273 return *this;
274 }
275
Reset()276 void Message::Reset() {
277 handle_.reset();
278 payload_buffer_.Reset();
279 handles_.clear();
280 associated_endpoint_handles_.clear();
281 transferable_ = false;
282 serialized_ = false;
283 }
284
payload() const285 const uint8_t* Message::payload() const {
286 if (version() < 2)
287 return data() + header()->num_bytes;
288
289 DCHECK(!header_v2()->payload.is_null());
290 return static_cast<const uint8_t*>(header_v2()->payload.Get());
291 }
292
payload_num_bytes() const293 uint32_t Message::payload_num_bytes() const {
294 DCHECK_GE(data_num_bytes(), header()->num_bytes);
295 size_t num_bytes;
296 if (version() < 2) {
297 num_bytes = data_num_bytes() - header()->num_bytes;
298 } else {
299 auto payload_begin =
300 reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
301 auto payload_end =
302 reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
303 if (!payload_end)
304 payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
305 DCHECK_GE(payload_end, payload_begin);
306 num_bytes = payload_end - payload_begin;
307 }
308 DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes));
309 return static_cast<uint32_t>(num_bytes);
310 }
311
payload_num_interface_ids() const312 uint32_t Message::payload_num_interface_ids() const {
313 auto* array_pointer =
314 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
315 return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
316 }
317
payload_interface_ids() const318 const uint32_t* Message::payload_interface_ids() const {
319 auto* array_pointer =
320 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
321 return array_pointer ? array_pointer->storage() : nullptr;
322 }
323
AttachHandlesFromSerializationContext(internal::SerializationContext * context)324 void Message::AttachHandlesFromSerializationContext(
325 internal::SerializationContext* context) {
326 if (context->handles()->empty() &&
327 context->associated_endpoint_handles()->empty()) {
328 // No handles attached, so no extra serialization work.
329 return;
330 }
331
332 if (context->associated_endpoint_handles()->empty()) {
333 // Attaching only non-associated handles is easier since we don't have to
334 // modify the message header. Faster path for that.
335 payload_buffer_.AttachHandles(context->mutable_handles());
336 return;
337 }
338
339 // Allocate a new message with enough space to hold all attached handles. Copy
340 // this message's contents into the new one and use it to replace ourself.
341 //
342 // TODO(rockot): We could avoid the extra full message allocation by instead
343 // growing the buffer and carefully moving its contents around. This errs on
344 // the side of less complexity with probably only marginal performance cost.
345 uint32_t payload_size = payload_num_bytes();
346 mojo::Message new_message(name(), header()->flags, payload_size,
347 context->associated_endpoint_handles()->size(),
348 context->mutable_handles());
349 std::swap(*context->mutable_associated_endpoint_handles(),
350 new_message.associated_endpoint_handles_);
351 memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(),
352 payload_size);
353 *this = std::move(new_message);
354 }
355
TakeMojoMessage()356 ScopedMessageHandle Message::TakeMojoMessage() {
357 // If there are associated endpoints transferred,
358 // SerializeAssociatedEndpointHandles() must be called before this method.
359 DCHECK(associated_endpoint_handles_.empty());
360 DCHECK(transferable_);
361 payload_buffer_.Seal();
362 auto handle = std::move(handle_);
363 Reset();
364 return handle;
365 }
366
NotifyBadMessage(const std::string & error)367 void Message::NotifyBadMessage(const std::string& error) {
368 DCHECK(handle_.is_valid());
369 mojo::NotifyBadMessage(handle_.get(), error);
370 }
371
SerializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)372 void Message::SerializeAssociatedEndpointHandles(
373 AssociatedGroupController* group_controller) {
374 if (associated_endpoint_handles_.empty())
375 return;
376
377 DCHECK_GE(version(), 2u);
378 DCHECK(header_v2()->payload_interface_ids.is_null());
379 DCHECK(payload_buffer_.is_valid());
380 DCHECK(handle_.is_valid());
381
382 size_t size = associated_endpoint_handles_.size();
383
384 internal::Array_Data<uint32_t>::BufferWriter handle_writer;
385 handle_writer.Allocate(size, &payload_buffer_);
386 header_v2()->payload_interface_ids.Set(handle_writer.data());
387
388 for (size_t i = 0; i < size; ++i) {
389 ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
390
391 DCHECK(handle.pending_association());
392 handle_writer->storage()[i] =
393 group_controller->AssociateInterface(std::move(handle));
394 }
395 associated_endpoint_handles_.clear();
396 }
397
DeserializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)398 bool Message::DeserializeAssociatedEndpointHandles(
399 AssociatedGroupController* group_controller) {
400 if (!serialized_)
401 return true;
402
403 associated_endpoint_handles_.clear();
404
405 uint32_t num_ids = payload_num_interface_ids();
406 if (num_ids == 0)
407 return true;
408
409 associated_endpoint_handles_.reserve(num_ids);
410 uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
411 bool result = true;
412 for (uint32_t i = 0; i < num_ids; ++i) {
413 auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
414 if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
415 // |ids[i]| itself is valid but handle creation failed. In that case, mark
416 // deserialization as failed but continue to deserialize the rest of
417 // handles.
418 result = false;
419 }
420
421 associated_endpoint_handles_.push_back(std::move(handle));
422 ids[i] = kInvalidInterfaceId;
423 }
424 return result;
425 }
426
SerializeIfNecessary()427 void Message::SerializeIfNecessary() {
428 MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr);
429 if (rv == MOJO_RESULT_FAILED_PRECONDITION)
430 return;
431
432 // Reconstruct this Message instance from the serialized message's handle.
433 *this = Message(std::move(handle_));
434 }
435
436 std::unique_ptr<internal::UnserializedMessageContext>
TakeUnserializedContext(const internal::UnserializedMessageContext::Tag * tag)437 Message::TakeUnserializedContext(
438 const internal::UnserializedMessageContext::Tag* tag) {
439 DCHECK(handle_.is_valid());
440 uintptr_t context_value = 0;
441 MojoResult rv =
442 MojoGetMessageContext(handle_->value(), nullptr, &context_value);
443 if (rv == MOJO_RESULT_NOT_FOUND)
444 return nullptr;
445 DCHECK_EQ(MOJO_RESULT_OK, rv);
446
447 auto* context =
448 reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
449 if (context->tag() != tag)
450 return nullptr;
451
452 // Detach the context from the message.
453 rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr);
454 DCHECK_EQ(MOJO_RESULT_OK, rv);
455 return base::WrapUnique(context);
456 }
457
PrefersSerializedMessages()458 bool MessageReceiver::PrefersSerializedMessages() {
459 return false;
460 }
461
PassThroughFilter()462 PassThroughFilter::PassThroughFilter() {}
463
~PassThroughFilter()464 PassThroughFilter::~PassThroughFilter() {}
465
Accept(Message * message)466 bool PassThroughFilter::Accept(Message* message) {
467 return true;
468 }
469
SyncMessageResponseContext()470 SyncMessageResponseContext::SyncMessageResponseContext()
471 : outer_context_(current()) {
472 g_tls_sync_response_context.Get().Set(this);
473 }
474
~SyncMessageResponseContext()475 SyncMessageResponseContext::~SyncMessageResponseContext() {
476 DCHECK_EQ(current(), this);
477 g_tls_sync_response_context.Get().Set(outer_context_);
478 }
479
480 // static
current()481 SyncMessageResponseContext* SyncMessageResponseContext::current() {
482 return g_tls_sync_response_context.Get().Get();
483 }
484
ReportBadMessage(const std::string & error)485 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
486 GetBadMessageCallback().Run(error);
487 }
488
GetBadMessageCallback()489 ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() {
490 DCHECK(!response_.IsNull());
491 return base::BindOnce(&DoNotifyBadMessage, std::move(response_));
492 }
493
ReadMessage(MessagePipeHandle handle,Message * message)494 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
495 ScopedMessageHandle message_handle;
496 MojoResult rv =
497 ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE);
498 if (rv != MOJO_RESULT_OK)
499 return rv;
500
501 *message = Message(std::move(message_handle));
502 return MOJO_RESULT_OK;
503 }
504
ReportBadMessage(const std::string & error)505 void ReportBadMessage(const std::string& error) {
506 internal::MessageDispatchContext* context =
507 internal::MessageDispatchContext::current();
508 DCHECK(context);
509 context->GetBadMessageCallback().Run(error);
510 }
511
GetBadMessageCallback()512 ReportBadMessageCallback GetBadMessageCallback() {
513 internal::MessageDispatchContext* context =
514 internal::MessageDispatchContext::current();
515 DCHECK(context);
516 return context->GetBadMessageCallback();
517 }
518
519 namespace internal {
520
// Out-of-line defaulted constructor for the V2 message header.
MessageHeaderV2::MessageHeaderV2() = default;
522
MessageDispatchContext(Message * message)523 MessageDispatchContext::MessageDispatchContext(Message* message)
524 : outer_context_(current()), message_(message) {
525 g_tls_message_dispatch_context.Get().Set(this);
526 }
527
~MessageDispatchContext()528 MessageDispatchContext::~MessageDispatchContext() {
529 DCHECK_EQ(current(), this);
530 g_tls_message_dispatch_context.Get().Set(outer_context_);
531 }
532
533 // static
current()534 MessageDispatchContext* MessageDispatchContext::current() {
535 return g_tls_message_dispatch_context.Get().Get();
536 }
537
GetBadMessageCallback()538 ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() {
539 DCHECK(!message_->IsNull());
540 return base::BindOnce(&DoNotifyBadMessage, std::move(*message_));
541 }
542
543 // static
SetCurrentSyncResponseMessage(Message * message)544 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
545 SyncMessageResponseContext* context = SyncMessageResponseContext::current();
546 if (context)
547 context->response_ = std::move(*message);
548 }
549
550 } // namespace internal
551
552 } // namespace mojo
553