• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/system/raw_channel.h"
6 
7 #include <string.h>
8 
9 #include <algorithm>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/message_loop/message_loop.h"
15 #include "base/stl_util.h"
16 #include "mojo/system/message_in_transit.h"
17 #include "mojo/system/transport_data.h"
18 
19 namespace mojo {
20 namespace system {
21 
// Amount of free space, in bytes, offered to each individual read (see
// |RawChannel::ReadBuffer::GetBuffer()|). It is also the initial size of a
// |ReadBuffer|'s backing store.
const size_t kReadSize = 4 * 1024;
23 
24 // RawChannel::ReadBuffer ------------------------------------------------------
25 
ReadBuffer()26 RawChannel::ReadBuffer::ReadBuffer()
27     : buffer_(kReadSize),
28       num_valid_bytes_(0) {
29 }
30 
~ReadBuffer()31 RawChannel::ReadBuffer::~ReadBuffer() {
32 }
33 
GetBuffer(char ** addr,size_t * size)34 void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) {
35   DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize);
36   *addr = &buffer_[0] + num_valid_bytes_;
37   *size = kReadSize;
38 }
39 
40 // RawChannel::WriteBuffer -----------------------------------------------------
41 
WriteBuffer(size_t serialized_platform_handle_size)42 RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size)
43     : serialized_platform_handle_size_(serialized_platform_handle_size),
44       platform_handles_offset_(0),
45       data_offset_(0) {
46 }
47 
~WriteBuffer()48 RawChannel::WriteBuffer::~WriteBuffer() {
49   STLDeleteElements(&message_queue_);
50 }
51 
HavePlatformHandlesToSend() const52 bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const {
53   if (message_queue_.empty())
54     return false;
55 
56   const TransportData* transport_data =
57       message_queue_.front()->transport_data();
58   if (!transport_data)
59     return false;
60 
61   const embedder::PlatformHandleVector* all_platform_handles =
62       transport_data->platform_handles();
63   if (!all_platform_handles) {
64     DCHECK_EQ(platform_handles_offset_, 0u);
65     return false;
66   }
67   if (platform_handles_offset_ >= all_platform_handles->size()) {
68     DCHECK_EQ(platform_handles_offset_, all_platform_handles->size());
69     return false;
70   }
71 
72   return true;
73 }
74 
GetPlatformHandlesToSend(size_t * num_platform_handles,embedder::PlatformHandle ** platform_handles,void ** serialization_data)75 void RawChannel::WriteBuffer::GetPlatformHandlesToSend(
76     size_t* num_platform_handles,
77     embedder::PlatformHandle** platform_handles,
78     void** serialization_data) {
79   DCHECK(HavePlatformHandlesToSend());
80 
81   TransportData* transport_data = message_queue_.front()->transport_data();
82   embedder::PlatformHandleVector* all_platform_handles =
83       transport_data->platform_handles();
84   *num_platform_handles =
85       all_platform_handles->size() - platform_handles_offset_;
86   *platform_handles = &(*all_platform_handles)[platform_handles_offset_];
87   size_t serialization_data_offset =
88       transport_data->platform_handle_table_offset();
89   DCHECK_GT(serialization_data_offset, 0u);
90   serialization_data_offset +=
91       platform_handles_offset_ * serialized_platform_handle_size_;
92   *serialization_data =
93       static_cast<char*>(transport_data->buffer()) + serialization_data_offset;
94 }
95 
GetBuffers(std::vector<Buffer> * buffers) const96 void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const {
97   buffers->clear();
98 
99   if (message_queue_.empty())
100     return;
101 
102   MessageInTransit* message = message_queue_.front();
103   DCHECK_LT(data_offset_, message->total_size());
104   size_t bytes_to_write = message->total_size() - data_offset_;
105 
106   size_t transport_data_buffer_size = message->transport_data() ?
107       message->transport_data()->buffer_size() : 0;
108 
109   if (!transport_data_buffer_size) {
110     // Only write from the main buffer.
111     DCHECK_LT(data_offset_, message->main_buffer_size());
112     DCHECK_LE(bytes_to_write, message->main_buffer_size());
113     Buffer buffer = {
114         static_cast<const char*>(message->main_buffer()) + data_offset_,
115         bytes_to_write};
116     buffers->push_back(buffer);
117     return;
118   }
119 
120   if (data_offset_ >= message->main_buffer_size()) {
121     // Only write from the transport data buffer.
122     DCHECK_LT(data_offset_ - message->main_buffer_size(),
123               transport_data_buffer_size);
124     DCHECK_LE(bytes_to_write, transport_data_buffer_size);
125     Buffer buffer = {
126         static_cast<const char*>(message->transport_data()->buffer()) +
127             (data_offset_ - message->main_buffer_size()),
128         bytes_to_write};
129     buffers->push_back(buffer);
130     return;
131   }
132 
133   // TODO(vtl): We could actually send out buffers from multiple messages, with
134   // the "stopping" condition being reaching a message with platform handles
135   // attached.
136 
137   // Write from both buffers.
138   DCHECK_EQ(bytes_to_write, message->main_buffer_size() - data_offset_ +
139                                 transport_data_buffer_size);
140   Buffer buffer1 = {
141     static_cast<const char*>(message->main_buffer()) + data_offset_,
142     message->main_buffer_size() - data_offset_
143   };
144   buffers->push_back(buffer1);
145   Buffer buffer2 = {
146     static_cast<const char*>(message->transport_data()->buffer()),
147     transport_data_buffer_size
148   };
149   buffers->push_back(buffer2);
150 }
151 
152 // RawChannel ------------------------------------------------------------------
153 
RawChannel()154 RawChannel::RawChannel()
155     : message_loop_for_io_(NULL),
156       delegate_(NULL),
157       read_stopped_(false),
158       write_stopped_(false),
159       weak_ptr_factory_(this) {
160 }
161 
~RawChannel()162 RawChannel::~RawChannel() {
163   DCHECK(!read_buffer_);
164   DCHECK(!write_buffer_);
165 
166   // No need to take the |write_lock_| here -- if there are still weak pointers
167   // outstanding, then we're hosed anyway (since we wouldn't be able to
168   // invalidate them cleanly, since we might not be on the I/O thread).
169   DCHECK(!weak_ptr_factory_.HasWeakPtrs());
170 }
171 
Init(Delegate * delegate)172 bool RawChannel::Init(Delegate* delegate) {
173   DCHECK(delegate);
174 
175   DCHECK(!delegate_);
176   delegate_ = delegate;
177 
178   CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO);
179   DCHECK(!message_loop_for_io_);
180   message_loop_for_io_ =
181       static_cast<base::MessageLoopForIO*>(base::MessageLoop::current());
182 
183   // No need to take the lock. No one should be using us yet.
184   DCHECK(!read_buffer_);
185   read_buffer_.reset(new ReadBuffer);
186   DCHECK(!write_buffer_);
187   write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize()));
188 
189   if (!OnInit()) {
190     delegate_ = NULL;
191     message_loop_for_io_ = NULL;
192     read_buffer_.reset();
193     write_buffer_.reset();
194     return false;
195   }
196 
197   if (ScheduleRead() != IO_PENDING) {
198     // This will notify the delegate about the read failure. Although we're on
199     // the I/O thread, don't call it in the nested context.
200     message_loop_for_io_->PostTask(
201         FROM_HERE,
202         base::Bind(&RawChannel::OnReadCompleted, weak_ptr_factory_.GetWeakPtr(),
203                    false, 0));
204   }
205 
206   // ScheduleRead() failure is treated as a read failure (by notifying the
207   // delegate), not as an init failure.
208   return true;
209 }
210 
Shutdown()211 void RawChannel::Shutdown() {
212   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
213 
214   base::AutoLock locker(write_lock_);
215 
216   LOG_IF(WARNING, !write_buffer_->message_queue_.empty())
217       << "Shutting down RawChannel with write buffer nonempty";
218 
219   // Reset the delegate so that it won't receive further calls.
220   delegate_ = NULL;
221   read_stopped_ = true;
222   write_stopped_ = true;
223   weak_ptr_factory_.InvalidateWeakPtrs();
224 
225   OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass());
226 }
227 
228 // Reminder: This must be thread-safe.
WriteMessage(scoped_ptr<MessageInTransit> message)229 bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
230   DCHECK(message);
231 
232   base::AutoLock locker(write_lock_);
233   if (write_stopped_)
234     return false;
235 
236   if (!write_buffer_->message_queue_.empty()) {
237     EnqueueMessageNoLock(message.Pass());
238     return true;
239   }
240 
241   EnqueueMessageNoLock(message.Pass());
242   DCHECK_EQ(write_buffer_->data_offset_, 0u);
243 
244   size_t platform_handles_written = 0;
245   size_t bytes_written = 0;
246   IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
247   if (io_result == IO_PENDING)
248     return true;
249 
250   bool result = OnWriteCompletedNoLock(io_result == IO_SUCCEEDED,
251                                        platform_handles_written,
252                                        bytes_written);
253   if (!result) {
254     // Even if we're on the I/O thread, don't call |OnFatalError()| in the
255     // nested context.
256     message_loop_for_io_->PostTask(
257         FROM_HERE,
258         base::Bind(&RawChannel::CallOnFatalError,
259                    weak_ptr_factory_.GetWeakPtr(),
260                    Delegate::FATAL_ERROR_WRITE));
261   }
262 
263   return result;
264 }
265 
266 // Reminder: This must be thread-safe.
IsWriteBufferEmpty()267 bool RawChannel::IsWriteBufferEmpty() {
268   base::AutoLock locker(write_lock_);
269   return write_buffer_->message_queue_.empty();
270 }
271 
OnReadCompleted(bool result,size_t bytes_read)272 void RawChannel::OnReadCompleted(bool result, size_t bytes_read) {
273   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
274 
275   if (read_stopped_) {
276     NOTREACHED();
277     return;
278   }
279 
280   IOResult io_result = result ? IO_SUCCEEDED : IO_FAILED;
281 
282   // Keep reading data in a loop, and dispatch messages if enough data is
283   // received. Exit the loop if any of the following happens:
284   //   - one or more messages were dispatched;
285   //   - the last read failed, was a partial read or would block;
286   //   - |Shutdown()| was called.
287   do {
288     if (io_result != IO_SUCCEEDED) {
289       read_stopped_ = true;
290       CallOnFatalError(Delegate::FATAL_ERROR_READ);
291       return;
292     }
293 
294     read_buffer_->num_valid_bytes_ += bytes_read;
295 
296     // Dispatch all the messages that we can.
297     bool did_dispatch_message = false;
298     // Tracks the offset of the first undispatched message in |read_buffer_|.
299     // Currently, we copy data to ensure that this is zero at the beginning.
300     size_t read_buffer_start = 0;
301     size_t remaining_bytes = read_buffer_->num_valid_bytes_;
302     size_t message_size;
303     // Note that we rely on short-circuit evaluation here:
304     //   - |read_buffer_start| may be an invalid index into
305     //     |read_buffer_->buffer_| if |remaining_bytes| is zero.
306     //   - |message_size| is only valid if |GetNextMessageSize()| returns true.
307     // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
308     // next read).
309     // TODO(vtl): Validate that |message_size| is sane.
310     while (remaining_bytes > 0 &&
311            MessageInTransit::GetNextMessageSize(
312                &read_buffer_->buffer_[read_buffer_start], remaining_bytes,
313                &message_size) &&
314            remaining_bytes >= message_size) {
315       MessageInTransit::View
316           message_view(message_size, &read_buffer_->buffer_[read_buffer_start]);
317       DCHECK_EQ(message_view.total_size(), message_size);
318 
319       const char* error_message = NULL;
320       if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
321                                 &error_message)) {
322         DCHECK(error_message);
323         LOG(WARNING) << "Received invalid message: " << error_message;
324         read_stopped_ = true;
325         CallOnFatalError(Delegate::FATAL_ERROR_READ);
326         return;
327       }
328 
329       if (message_view.type() == MessageInTransit::kTypeRawChannel) {
330         if (!OnReadMessageForRawChannel(message_view)) {
331           read_stopped_ = true;
332           CallOnFatalError(Delegate::FATAL_ERROR_READ);
333           return;
334         }
335       } else {
336         embedder::ScopedPlatformHandleVectorPtr platform_handles;
337         if (message_view.transport_data_buffer()) {
338           size_t num_platform_handles;
339           const void* platform_handle_table;
340           TransportData::GetPlatformHandleTable(
341               message_view.transport_data_buffer(),
342               &num_platform_handles,
343               &platform_handle_table);
344 
345           if (num_platform_handles > 0) {
346             platform_handles =
347                 GetReadPlatformHandles(num_platform_handles,
348                                        platform_handle_table).Pass();
349             if (!platform_handles) {
350               LOG(WARNING) << "Invalid number of platform handles received";
351               read_stopped_ = true;
352               CallOnFatalError(Delegate::FATAL_ERROR_READ);
353               return;
354             }
355           }
356         }
357 
358         // TODO(vtl): In the case that we aren't expecting any platform handles,
359         // for the POSIX implementation, we should confirm that none are stored.
360 
361         // Dispatch the message.
362         DCHECK(delegate_);
363         delegate_->OnReadMessage(message_view, platform_handles.Pass());
364         if (read_stopped_) {
365           // |Shutdown()| was called in |OnReadMessage()|.
366           // TODO(vtl): Add test for this case.
367           return;
368         }
369       }
370 
371       did_dispatch_message = true;
372 
373       // Update our state.
374       read_buffer_start += message_size;
375       remaining_bytes -= message_size;
376     }
377 
378     if (read_buffer_start > 0) {
379       // Move data back to start.
380       read_buffer_->num_valid_bytes_ = remaining_bytes;
381       if (read_buffer_->num_valid_bytes_ > 0) {
382         memmove(&read_buffer_->buffer_[0],
383                 &read_buffer_->buffer_[read_buffer_start], remaining_bytes);
384       }
385       read_buffer_start = 0;
386     }
387 
388     if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
389             kReadSize) {
390       // Use power-of-2 buffer sizes.
391       // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
392       // maximum message size to whatever extent necessary).
393       // TODO(vtl): We may often be able to peek at the header and get the real
394       // required extra space (which may be much bigger than |kReadSize|).
395       size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
396       while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
397         new_size *= 2;
398 
399       // TODO(vtl): It's suboptimal to zero out the fresh memory.
400       read_buffer_->buffer_.resize(new_size, 0);
401     }
402 
403     // (1) If we dispatched any messages, stop reading for now (and let the
404     // message loop do its thing for another round).
405     // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
406     // a single message. Risks: slower, more complex if we want to avoid lots of
407     // copying. ii. Keep reading until there's no more data and dispatch all the
408     // messages we can. Risks: starvation of other users of the message loop.)
409     // (2) If we didn't max out |kReadSize|, stop reading for now.
410     bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
411     bytes_read = 0;
412     io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
413   } while (io_result != IO_PENDING);
414 }
415 
OnWriteCompleted(bool result,size_t platform_handles_written,size_t bytes_written)416 void RawChannel::OnWriteCompleted(bool result,
417                                   size_t platform_handles_written,
418                                   size_t bytes_written) {
419   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
420 
421   bool did_fail = false;
422   {
423     base::AutoLock locker(write_lock_);
424     DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());
425 
426     if (write_stopped_) {
427       NOTREACHED();
428       return;
429     }
430 
431     did_fail = !OnWriteCompletedNoLock(result,
432                                        platform_handles_written,
433                                        bytes_written);
434   }
435 
436   if (did_fail)
437     CallOnFatalError(Delegate::FATAL_ERROR_WRITE);
438 }
439 
EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message)440 void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
441   write_lock_.AssertAcquired();
442   write_buffer_->message_queue_.push_back(message.release());
443 }
444 
OnReadMessageForRawChannel(const MessageInTransit::View & message_view)445 bool RawChannel::OnReadMessageForRawChannel(
446     const MessageInTransit::View& message_view) {
447   // No non-implementation specific |RawChannel| control messages.
448   LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
449              << ")";
450   return false;
451 }
452 
CallOnFatalError(Delegate::FatalError fatal_error)453 void RawChannel::CallOnFatalError(Delegate::FatalError fatal_error) {
454   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
455   // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
456   if (delegate_)
457     delegate_->OnFatalError(fatal_error);
458 }
459 
OnWriteCompletedNoLock(bool result,size_t platform_handles_written,size_t bytes_written)460 bool RawChannel::OnWriteCompletedNoLock(bool result,
461                                         size_t platform_handles_written,
462                                         size_t bytes_written) {
463   write_lock_.AssertAcquired();
464 
465   DCHECK(!write_stopped_);
466   DCHECK(!write_buffer_->message_queue_.empty());
467 
468   if (result) {
469     write_buffer_->platform_handles_offset_ += platform_handles_written;
470     write_buffer_->data_offset_ += bytes_written;
471 
472     MessageInTransit* message = write_buffer_->message_queue_.front();
473     if (write_buffer_->data_offset_ >= message->total_size()) {
474       // Complete write.
475       DCHECK_EQ(write_buffer_->data_offset_, message->total_size());
476       write_buffer_->message_queue_.pop_front();
477       delete message;
478       write_buffer_->platform_handles_offset_ = 0;
479       write_buffer_->data_offset_ = 0;
480 
481       if (write_buffer_->message_queue_.empty())
482         return true;
483     }
484 
485     // Schedule the next write.
486     IOResult io_result = ScheduleWriteNoLock();
487     if (io_result == IO_PENDING)
488       return true;
489     DCHECK_EQ(io_result, IO_FAILED);
490   }
491 
492   write_stopped_ = true;
493   STLDeleteElements(&write_buffer_->message_queue_);
494   write_buffer_->platform_handles_offset_ = 0;
495   write_buffer_->data_offset_ = 0;
496   return false;
497 }
498 
499 }  // namespace system
500 }  // namespace mojo
501