1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "mojo/public/cpp/bindings/message.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <limits>
#include <utility>

#include "base/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_local.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/lib/array_internal.h"
21
22 namespace mojo {
23
24 namespace {
25
// Per-thread pointer to the innermost active MessageDispatchContext. Each
// MessageDispatchContext stores itself here on construction (remembering the
// previous value as its outer context) and restores that previous value on
// destruction.
base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
    DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;

// Per-thread pointer to the innermost active SyncMessageResponseContext,
// maintained the same way by SyncMessageResponseContext's constructor and
// destructor.
base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::
    DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
31
DoNotifyBadMessage(Message message,const std::string & error)32 void DoNotifyBadMessage(Message message, const std::string& error) {
33 message.NotifyBadMessage(error);
34 }
35
36 } // namespace
37
Message()38 Message::Message() {
39 }
40
Message(Message && other)41 Message::Message(Message&& other)
42 : buffer_(std::move(other.buffer_)),
43 handles_(std::move(other.handles_)),
44 associated_endpoint_handles_(
45 std::move(other.associated_endpoint_handles_)) {}
46
~Message()47 Message::~Message() {
48 CloseHandles();
49 }
50
operator =(Message && other)51 Message& Message::operator=(Message&& other) {
52 Reset();
53 std::swap(other.buffer_, buffer_);
54 std::swap(other.handles_, handles_);
55 std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_);
56 return *this;
57 }
58
Reset()59 void Message::Reset() {
60 CloseHandles();
61 handles_.clear();
62 associated_endpoint_handles_.clear();
63 buffer_.reset();
64 }
65
Initialize(size_t capacity,bool zero_initialized)66 void Message::Initialize(size_t capacity, bool zero_initialized) {
67 DCHECK(!buffer_);
68 buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized));
69 }
70
InitializeFromMojoMessage(ScopedMessageHandle message,uint32_t num_bytes,std::vector<Handle> * handles)71 void Message::InitializeFromMojoMessage(ScopedMessageHandle message,
72 uint32_t num_bytes,
73 std::vector<Handle>* handles) {
74 DCHECK(!buffer_);
75 buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes));
76 handles_.swap(*handles);
77 }
78
payload() const79 const uint8_t* Message::payload() const {
80 if (version() < 2)
81 return data() + header()->num_bytes;
82
83 return static_cast<const uint8_t*>(header_v2()->payload.Get());
84 }
85
payload_num_bytes() const86 uint32_t Message::payload_num_bytes() const {
87 DCHECK_GE(data_num_bytes(), header()->num_bytes);
88 size_t num_bytes;
89 if (version() < 2) {
90 num_bytes = data_num_bytes() - header()->num_bytes;
91 } else {
92 auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
93 if (!payload) {
94 num_bytes = 0;
95 } else {
96 auto payload_end =
97 reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
98 if (!payload_end)
99 payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
100 DCHECK_GE(payload_end, payload);
101 num_bytes = payload_end - payload;
102 }
103 }
104 DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max());
105 return static_cast<uint32_t>(num_bytes);
106 }
107
payload_num_interface_ids() const108 uint32_t Message::payload_num_interface_ids() const {
109 auto* array_pointer =
110 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
111 return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
112 }
113
payload_interface_ids() const114 const uint32_t* Message::payload_interface_ids() const {
115 auto* array_pointer =
116 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
117 return array_pointer ? array_pointer->storage() : nullptr;
118 }
119
TakeMojoMessage()120 ScopedMessageHandle Message::TakeMojoMessage() {
121 // If there are associated endpoints transferred,
122 // SerializeAssociatedEndpointHandles() must be called before this method.
123 DCHECK(associated_endpoint_handles_.empty());
124
125 if (handles_.empty()) // Fast path for the common case: No handles.
126 return buffer_->TakeMessage();
127
128 // Allocate a new message with space for the handles, then copy the buffer
129 // contents into it.
130 //
131 // TODO(rockot): We could avoid this copy by extending GetSerializedSize()
132 // behavior to collect handles. It's unoptimized for now because it's much
133 // more common to have messages with no handles.
134 ScopedMessageHandle new_message;
135 MojoResult rv = AllocMessage(
136 data_num_bytes(),
137 handles_.empty() ? nullptr
138 : reinterpret_cast<const MojoHandle*>(handles_.data()),
139 handles_.size(),
140 MOJO_ALLOC_MESSAGE_FLAG_NONE,
141 &new_message);
142 CHECK_EQ(rv, MOJO_RESULT_OK);
143 handles_.clear();
144
145 void* new_buffer = nullptr;
146 rv = GetMessageBuffer(new_message.get(), &new_buffer);
147 CHECK_EQ(rv, MOJO_RESULT_OK);
148
149 memcpy(new_buffer, data(), data_num_bytes());
150 buffer_.reset();
151
152 return new_message;
153 }
154
NotifyBadMessage(const std::string & error)155 void Message::NotifyBadMessage(const std::string& error) {
156 DCHECK(buffer_);
157 buffer_->NotifyBadMessage(error);
158 }
159
CloseHandles()160 void Message::CloseHandles() {
161 for (std::vector<Handle>::iterator it = handles_.begin();
162 it != handles_.end(); ++it) {
163 if (it->is_valid())
164 CloseRaw(*it);
165 }
166 }
167
SerializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)168 void Message::SerializeAssociatedEndpointHandles(
169 AssociatedGroupController* group_controller) {
170 if (associated_endpoint_handles_.empty())
171 return;
172
173 DCHECK_GE(version(), 2u);
174 DCHECK(header_v2()->payload_interface_ids.is_null());
175
176 size_t size = associated_endpoint_handles_.size();
177 auto* data = internal::Array_Data<uint32_t>::New(size, buffer());
178 header_v2()->payload_interface_ids.Set(data);
179
180 for (size_t i = 0; i < size; ++i) {
181 ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
182
183 DCHECK(handle.pending_association());
184 data->storage()[i] =
185 group_controller->AssociateInterface(std::move(handle));
186 }
187 associated_endpoint_handles_.clear();
188 }
189
DeserializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)190 bool Message::DeserializeAssociatedEndpointHandles(
191 AssociatedGroupController* group_controller) {
192 associated_endpoint_handles_.clear();
193
194 uint32_t num_ids = payload_num_interface_ids();
195 if (num_ids == 0)
196 return true;
197
198 associated_endpoint_handles_.reserve(num_ids);
199 uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
200 bool result = true;
201 for (uint32_t i = 0; i < num_ids; ++i) {
202 auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
203 if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
204 // |ids[i]| itself is valid but handle creation failed. In that case, mark
205 // deserialization as failed but continue to deserialize the rest of
206 // handles.
207 result = false;
208 }
209
210 associated_endpoint_handles_.push_back(std::move(handle));
211 ids[i] = kInvalidInterfaceId;
212 }
213 return result;
214 }
215
PassThroughFilter()216 PassThroughFilter::PassThroughFilter() {}
217
~PassThroughFilter()218 PassThroughFilter::~PassThroughFilter() {}
219
Accept(Message * message)220 bool PassThroughFilter::Accept(Message* message) { return true; }
221
SyncMessageResponseContext()222 SyncMessageResponseContext::SyncMessageResponseContext()
223 : outer_context_(current()) {
224 g_tls_sync_response_context.Get().Set(this);
225 }
226
~SyncMessageResponseContext()227 SyncMessageResponseContext::~SyncMessageResponseContext() {
228 DCHECK_EQ(current(), this);
229 g_tls_sync_response_context.Get().Set(outer_context_);
230 }
231
232 // static
current()233 SyncMessageResponseContext* SyncMessageResponseContext::current() {
234 return g_tls_sync_response_context.Get().Get();
235 }
236
ReportBadMessage(const std::string & error)237 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
238 GetBadMessageCallback().Run(error);
239 }
240
241 const ReportBadMessageCallback&
GetBadMessageCallback()242 SyncMessageResponseContext::GetBadMessageCallback() {
243 if (bad_message_callback_.is_null()) {
244 bad_message_callback_ =
245 base::Bind(&DoNotifyBadMessage, base::Passed(&response_));
246 }
247 return bad_message_callback_;
248 }
249
ReadMessage(MessagePipeHandle handle,Message * message)250 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
251 MojoResult rv;
252
253 std::vector<Handle> handles;
254 ScopedMessageHandle mojo_message;
255 uint32_t num_bytes = 0, num_handles = 0;
256 rv = ReadMessageNew(handle,
257 &mojo_message,
258 &num_bytes,
259 nullptr,
260 &num_handles,
261 MOJO_READ_MESSAGE_FLAG_NONE);
262 if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
263 DCHECK_GT(num_handles, 0u);
264 handles.resize(num_handles);
265 rv = ReadMessageNew(handle,
266 &mojo_message,
267 &num_bytes,
268 reinterpret_cast<MojoHandle*>(handles.data()),
269 &num_handles,
270 MOJO_READ_MESSAGE_FLAG_NONE);
271 }
272
273 if (rv != MOJO_RESULT_OK)
274 return rv;
275
276 message->InitializeFromMojoMessage(
277 std::move(mojo_message), num_bytes, &handles);
278 return MOJO_RESULT_OK;
279 }
280
ReportBadMessage(const std::string & error)281 void ReportBadMessage(const std::string& error) {
282 internal::MessageDispatchContext* context =
283 internal::MessageDispatchContext::current();
284 DCHECK(context);
285 context->GetBadMessageCallback().Run(error);
286 }
287
GetBadMessageCallback()288 ReportBadMessageCallback GetBadMessageCallback() {
289 internal::MessageDispatchContext* context =
290 internal::MessageDispatchContext::current();
291 DCHECK(context);
292 return context->GetBadMessageCallback();
293 }
294
295 namespace internal {
296
// Out-of-line defaulted constructor for the version 2 message header.
MessageHeaderV2::MessageHeaderV2() = default;
298
MessageDispatchContext(Message * message)299 MessageDispatchContext::MessageDispatchContext(Message* message)
300 : outer_context_(current()), message_(message) {
301 g_tls_message_dispatch_context.Get().Set(this);
302 }
303
~MessageDispatchContext()304 MessageDispatchContext::~MessageDispatchContext() {
305 DCHECK_EQ(current(), this);
306 g_tls_message_dispatch_context.Get().Set(outer_context_);
307 }
308
309 // static
current()310 MessageDispatchContext* MessageDispatchContext::current() {
311 return g_tls_message_dispatch_context.Get().Get();
312 }
313
314 const ReportBadMessageCallback&
GetBadMessageCallback()315 MessageDispatchContext::GetBadMessageCallback() {
316 if (bad_message_callback_.is_null()) {
317 bad_message_callback_ =
318 base::Bind(&DoNotifyBadMessage, base::Passed(message_));
319 }
320 return bad_message_callback_;
321 }
322
323 // static
SetCurrentSyncResponseMessage(Message * message)324 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
325 SyncMessageResponseContext* context = SyncMessageResponseContext::current();
326 if (context)
327 context->response_ = std::move(*message);
328 }
329
330 } // namespace internal
331
332 } // namespace mojo
333