• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "mojo/public/cpp/bindings/message.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <limits>
#include <utility>

#include "base/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_local.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/lib/array_internal.h"
21 
22 namespace mojo {
23 
24 namespace {
25 
26 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
27     DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
28 
29 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::
30     DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
31 
DoNotifyBadMessage(Message message,const std::string & error)32 void DoNotifyBadMessage(Message message, const std::string& error) {
33   message.NotifyBadMessage(error);
34 }
35 
36 }  // namespace
37 
Message()38 Message::Message() {
39 }
40 
Message(Message && other)41 Message::Message(Message&& other)
42     : buffer_(std::move(other.buffer_)),
43       handles_(std::move(other.handles_)),
44       associated_endpoint_handles_(
45           std::move(other.associated_endpoint_handles_)) {}
46 
~Message()47 Message::~Message() {
48   CloseHandles();
49 }
50 
operator =(Message && other)51 Message& Message::operator=(Message&& other) {
52   Reset();
53   std::swap(other.buffer_, buffer_);
54   std::swap(other.handles_, handles_);
55   std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_);
56   return *this;
57 }
58 
Reset()59 void Message::Reset() {
60   CloseHandles();
61   handles_.clear();
62   associated_endpoint_handles_.clear();
63   buffer_.reset();
64 }
65 
Initialize(size_t capacity,bool zero_initialized)66 void Message::Initialize(size_t capacity, bool zero_initialized) {
67   DCHECK(!buffer_);
68   buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized));
69 }
70 
InitializeFromMojoMessage(ScopedMessageHandle message,uint32_t num_bytes,std::vector<Handle> * handles)71 void Message::InitializeFromMojoMessage(ScopedMessageHandle message,
72                                         uint32_t num_bytes,
73                                         std::vector<Handle>* handles) {
74   DCHECK(!buffer_);
75   buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes));
76   handles_.swap(*handles);
77 }
78 
payload() const79 const uint8_t* Message::payload() const {
80   if (version() < 2)
81     return data() + header()->num_bytes;
82 
83   return static_cast<const uint8_t*>(header_v2()->payload.Get());
84 }
85 
payload_num_bytes() const86 uint32_t Message::payload_num_bytes() const {
87   DCHECK_GE(data_num_bytes(), header()->num_bytes);
88   size_t num_bytes;
89   if (version() < 2) {
90     num_bytes = data_num_bytes() - header()->num_bytes;
91   } else {
92     auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
93     if (!payload) {
94       num_bytes = 0;
95     } else {
96       auto payload_end =
97           reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
98       if (!payload_end)
99         payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
100       DCHECK_GE(payload_end, payload);
101       num_bytes = payload_end - payload;
102     }
103   }
104   DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max());
105   return static_cast<uint32_t>(num_bytes);
106 }
107 
payload_num_interface_ids() const108 uint32_t Message::payload_num_interface_ids() const {
109   auto* array_pointer =
110       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
111   return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
112 }
113 
payload_interface_ids() const114 const uint32_t* Message::payload_interface_ids() const {
115   auto* array_pointer =
116       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
117   return array_pointer ? array_pointer->storage() : nullptr;
118 }
119 
TakeMojoMessage()120 ScopedMessageHandle Message::TakeMojoMessage() {
121   // If there are associated endpoints transferred,
122   // SerializeAssociatedEndpointHandles() must be called before this method.
123   DCHECK(associated_endpoint_handles_.empty());
124 
125   if (handles_.empty())  // Fast path for the common case: No handles.
126     return buffer_->TakeMessage();
127 
128   // Allocate a new message with space for the handles, then copy the buffer
129   // contents into it.
130   //
131   // TODO(rockot): We could avoid this copy by extending GetSerializedSize()
132   // behavior to collect handles. It's unoptimized for now because it's much
133   // more common to have messages with no handles.
134   ScopedMessageHandle new_message;
135   MojoResult rv = AllocMessage(
136       data_num_bytes(),
137       handles_.empty() ? nullptr
138                        : reinterpret_cast<const MojoHandle*>(handles_.data()),
139       handles_.size(),
140       MOJO_ALLOC_MESSAGE_FLAG_NONE,
141       &new_message);
142   CHECK_EQ(rv, MOJO_RESULT_OK);
143   handles_.clear();
144 
145   void* new_buffer = nullptr;
146   rv = GetMessageBuffer(new_message.get(), &new_buffer);
147   CHECK_EQ(rv, MOJO_RESULT_OK);
148 
149   memcpy(new_buffer, data(), data_num_bytes());
150   buffer_.reset();
151 
152   return new_message;
153 }
154 
NotifyBadMessage(const std::string & error)155 void Message::NotifyBadMessage(const std::string& error) {
156   DCHECK(buffer_);
157   buffer_->NotifyBadMessage(error);
158 }
159 
CloseHandles()160 void Message::CloseHandles() {
161   for (std::vector<Handle>::iterator it = handles_.begin();
162        it != handles_.end(); ++it) {
163     if (it->is_valid())
164       CloseRaw(*it);
165   }
166 }
167 
SerializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)168 void Message::SerializeAssociatedEndpointHandles(
169     AssociatedGroupController* group_controller) {
170   if (associated_endpoint_handles_.empty())
171     return;
172 
173   DCHECK_GE(version(), 2u);
174   DCHECK(header_v2()->payload_interface_ids.is_null());
175 
176   size_t size = associated_endpoint_handles_.size();
177   auto* data = internal::Array_Data<uint32_t>::New(size, buffer());
178   header_v2()->payload_interface_ids.Set(data);
179 
180   for (size_t i = 0; i < size; ++i) {
181     ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
182 
183     DCHECK(handle.pending_association());
184     data->storage()[i] =
185         group_controller->AssociateInterface(std::move(handle));
186   }
187   associated_endpoint_handles_.clear();
188 }
189 
DeserializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)190 bool Message::DeserializeAssociatedEndpointHandles(
191     AssociatedGroupController* group_controller) {
192   associated_endpoint_handles_.clear();
193 
194   uint32_t num_ids = payload_num_interface_ids();
195   if (num_ids == 0)
196     return true;
197 
198   associated_endpoint_handles_.reserve(num_ids);
199   uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
200   bool result = true;
201   for (uint32_t i = 0; i < num_ids; ++i) {
202     auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
203     if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
204       // |ids[i]| itself is valid but handle creation failed. In that case, mark
205       // deserialization as failed but continue to deserialize the rest of
206       // handles.
207       result = false;
208     }
209 
210     associated_endpoint_handles_.push_back(std::move(handle));
211     ids[i] = kInvalidInterfaceId;
212   }
213   return result;
214 }
215 
PassThroughFilter()216 PassThroughFilter::PassThroughFilter() {}
217 
~PassThroughFilter()218 PassThroughFilter::~PassThroughFilter() {}
219 
Accept(Message * message)220 bool PassThroughFilter::Accept(Message* message) { return true; }
221 
SyncMessageResponseContext()222 SyncMessageResponseContext::SyncMessageResponseContext()
223     : outer_context_(current()) {
224   g_tls_sync_response_context.Get().Set(this);
225 }
226 
~SyncMessageResponseContext()227 SyncMessageResponseContext::~SyncMessageResponseContext() {
228   DCHECK_EQ(current(), this);
229   g_tls_sync_response_context.Get().Set(outer_context_);
230 }
231 
232 // static
current()233 SyncMessageResponseContext* SyncMessageResponseContext::current() {
234   return g_tls_sync_response_context.Get().Get();
235 }
236 
ReportBadMessage(const std::string & error)237 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
238   GetBadMessageCallback().Run(error);
239 }
240 
241 const ReportBadMessageCallback&
GetBadMessageCallback()242 SyncMessageResponseContext::GetBadMessageCallback() {
243   if (bad_message_callback_.is_null()) {
244     bad_message_callback_ =
245         base::Bind(&DoNotifyBadMessage, base::Passed(&response_));
246   }
247   return bad_message_callback_;
248 }
249 
ReadMessage(MessagePipeHandle handle,Message * message)250 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
251   MojoResult rv;
252 
253   std::vector<Handle> handles;
254   ScopedMessageHandle mojo_message;
255   uint32_t num_bytes = 0, num_handles = 0;
256   rv = ReadMessageNew(handle,
257                       &mojo_message,
258                       &num_bytes,
259                       nullptr,
260                       &num_handles,
261                       MOJO_READ_MESSAGE_FLAG_NONE);
262   if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
263     DCHECK_GT(num_handles, 0u);
264     handles.resize(num_handles);
265     rv = ReadMessageNew(handle,
266                         &mojo_message,
267                         &num_bytes,
268                         reinterpret_cast<MojoHandle*>(handles.data()),
269                         &num_handles,
270                         MOJO_READ_MESSAGE_FLAG_NONE);
271   }
272 
273   if (rv != MOJO_RESULT_OK)
274     return rv;
275 
276   message->InitializeFromMojoMessage(
277       std::move(mojo_message), num_bytes, &handles);
278   return MOJO_RESULT_OK;
279 }
280 
ReportBadMessage(const std::string & error)281 void ReportBadMessage(const std::string& error) {
282   internal::MessageDispatchContext* context =
283       internal::MessageDispatchContext::current();
284   DCHECK(context);
285   context->GetBadMessageCallback().Run(error);
286 }
287 
GetBadMessageCallback()288 ReportBadMessageCallback GetBadMessageCallback() {
289   internal::MessageDispatchContext* context =
290       internal::MessageDispatchContext::current();
291   DCHECK(context);
292   return context->GetBadMessageCallback();
293 }
294 
295 namespace internal {
296 
297 MessageHeaderV2::MessageHeaderV2() = default;
298 
MessageDispatchContext(Message * message)299 MessageDispatchContext::MessageDispatchContext(Message* message)
300     : outer_context_(current()), message_(message) {
301   g_tls_message_dispatch_context.Get().Set(this);
302 }
303 
~MessageDispatchContext()304 MessageDispatchContext::~MessageDispatchContext() {
305   DCHECK_EQ(current(), this);
306   g_tls_message_dispatch_context.Get().Set(outer_context_);
307 }
308 
309 // static
current()310 MessageDispatchContext* MessageDispatchContext::current() {
311   return g_tls_message_dispatch_context.Get().Get();
312 }
313 
314 const ReportBadMessageCallback&
GetBadMessageCallback()315 MessageDispatchContext::GetBadMessageCallback() {
316   if (bad_message_callback_.is_null()) {
317     bad_message_callback_ =
318         base::Bind(&DoNotifyBadMessage, base::Passed(message_));
319   }
320   return bad_message_callback_;
321 }
322 
323 // static
SetCurrentSyncResponseMessage(Message * message)324 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
325   SyncMessageResponseContext* context = SyncMessageResponseContext::current();
326   if (context)
327     context->response_ = std::move(*message);
328 }
329 
330 }  // namespace internal
331 
332 }  // namespace mojo
333