1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/system/raw_channel.h"
6
7 #include <string.h>
8
9 #include <algorithm>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/message_loop/message_loop.h"
15 #include "base/stl_util.h"
16 #include "mojo/system/message_in_transit.h"
17 #include "mojo/system/transport_data.h"
18
19 namespace mojo {
20 namespace system {
21
// Number of bytes requested per read. In this file it is the initial size of
// |ReadBuffer|'s storage, the size |ReadBuffer::GetBuffer()| reports, and the
// amount of free space |OnReadCompleted()| keeps available after the valid
// bytes before issuing the next read.
const size_t kReadSize = 4096;
23
24 // RawChannel::ReadBuffer ------------------------------------------------------
25
ReadBuffer()26 RawChannel::ReadBuffer::ReadBuffer() : buffer_(kReadSize), num_valid_bytes_(0) {
27 }
28
~ReadBuffer()29 RawChannel::ReadBuffer::~ReadBuffer() {
30 }
31
GetBuffer(char ** addr,size_t * size)32 void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) {
33 DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize);
34 *addr = &buffer_[0] + num_valid_bytes_;
35 *size = kReadSize;
36 }
37
38 // RawChannel::WriteBuffer -----------------------------------------------------
39
WriteBuffer(size_t serialized_platform_handle_size)40 RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size)
41 : serialized_platform_handle_size_(serialized_platform_handle_size),
42 platform_handles_offset_(0),
43 data_offset_(0) {
44 }
45
~WriteBuffer()46 RawChannel::WriteBuffer::~WriteBuffer() {
47 STLDeleteElements(&message_queue_);
48 }
49
HavePlatformHandlesToSend() const50 bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const {
51 if (message_queue_.empty())
52 return false;
53
54 const TransportData* transport_data =
55 message_queue_.front()->transport_data();
56 if (!transport_data)
57 return false;
58
59 const embedder::PlatformHandleVector* all_platform_handles =
60 transport_data->platform_handles();
61 if (!all_platform_handles) {
62 DCHECK_EQ(platform_handles_offset_, 0u);
63 return false;
64 }
65 if (platform_handles_offset_ >= all_platform_handles->size()) {
66 DCHECK_EQ(platform_handles_offset_, all_platform_handles->size());
67 return false;
68 }
69
70 return true;
71 }
72
GetPlatformHandlesToSend(size_t * num_platform_handles,embedder::PlatformHandle ** platform_handles,void ** serialization_data)73 void RawChannel::WriteBuffer::GetPlatformHandlesToSend(
74 size_t* num_platform_handles,
75 embedder::PlatformHandle** platform_handles,
76 void** serialization_data) {
77 DCHECK(HavePlatformHandlesToSend());
78
79 TransportData* transport_data = message_queue_.front()->transport_data();
80 embedder::PlatformHandleVector* all_platform_handles =
81 transport_data->platform_handles();
82 *num_platform_handles =
83 all_platform_handles->size() - platform_handles_offset_;
84 *platform_handles = &(*all_platform_handles)[platform_handles_offset_];
85 size_t serialization_data_offset =
86 transport_data->platform_handle_table_offset();
87 DCHECK_GT(serialization_data_offset, 0u);
88 serialization_data_offset +=
89 platform_handles_offset_ * serialized_platform_handle_size_;
90 *serialization_data =
91 static_cast<char*>(transport_data->buffer()) + serialization_data_offset;
92 }
93
GetBuffers(std::vector<Buffer> * buffers) const94 void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const {
95 buffers->clear();
96
97 if (message_queue_.empty())
98 return;
99
100 MessageInTransit* message = message_queue_.front();
101 DCHECK_LT(data_offset_, message->total_size());
102 size_t bytes_to_write = message->total_size() - data_offset_;
103
104 size_t transport_data_buffer_size =
105 message->transport_data() ? message->transport_data()->buffer_size() : 0;
106
107 if (!transport_data_buffer_size) {
108 // Only write from the main buffer.
109 DCHECK_LT(data_offset_, message->main_buffer_size());
110 DCHECK_LE(bytes_to_write, message->main_buffer_size());
111 Buffer buffer = {
112 static_cast<const char*>(message->main_buffer()) + data_offset_,
113 bytes_to_write};
114 buffers->push_back(buffer);
115 return;
116 }
117
118 if (data_offset_ >= message->main_buffer_size()) {
119 // Only write from the transport data buffer.
120 DCHECK_LT(data_offset_ - message->main_buffer_size(),
121 transport_data_buffer_size);
122 DCHECK_LE(bytes_to_write, transport_data_buffer_size);
123 Buffer buffer = {
124 static_cast<const char*>(message->transport_data()->buffer()) +
125 (data_offset_ - message->main_buffer_size()),
126 bytes_to_write};
127 buffers->push_back(buffer);
128 return;
129 }
130
131 // TODO(vtl): We could actually send out buffers from multiple messages, with
132 // the "stopping" condition being reaching a message with platform handles
133 // attached.
134
135 // Write from both buffers.
136 DCHECK_EQ(
137 bytes_to_write,
138 message->main_buffer_size() - data_offset_ + transport_data_buffer_size);
139 Buffer buffer1 = {
140 static_cast<const char*>(message->main_buffer()) + data_offset_,
141 message->main_buffer_size() - data_offset_};
142 buffers->push_back(buffer1);
143 Buffer buffer2 = {
144 static_cast<const char*>(message->transport_data()->buffer()),
145 transport_data_buffer_size};
146 buffers->push_back(buffer2);
147 }
148
149 // RawChannel ------------------------------------------------------------------
150
RawChannel()151 RawChannel::RawChannel()
152 : message_loop_for_io_(nullptr),
153 delegate_(nullptr),
154 read_stopped_(false),
155 write_stopped_(false),
156 weak_ptr_factory_(this) {
157 }
158
~RawChannel()159 RawChannel::~RawChannel() {
160 DCHECK(!read_buffer_);
161 DCHECK(!write_buffer_);
162
163 // No need to take the |write_lock_| here -- if there are still weak pointers
164 // outstanding, then we're hosed anyway (since we wouldn't be able to
165 // invalidate them cleanly, since we might not be on the I/O thread).
166 DCHECK(!weak_ptr_factory_.HasWeakPtrs());
167 }
168
Init(Delegate * delegate)169 bool RawChannel::Init(Delegate* delegate) {
170 DCHECK(delegate);
171
172 DCHECK(!delegate_);
173 delegate_ = delegate;
174
175 CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO);
176 DCHECK(!message_loop_for_io_);
177 message_loop_for_io_ =
178 static_cast<base::MessageLoopForIO*>(base::MessageLoop::current());
179
180 // No need to take the lock. No one should be using us yet.
181 DCHECK(!read_buffer_);
182 read_buffer_.reset(new ReadBuffer);
183 DCHECK(!write_buffer_);
184 write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize()));
185
186 if (!OnInit()) {
187 delegate_ = nullptr;
188 message_loop_for_io_ = nullptr;
189 read_buffer_.reset();
190 write_buffer_.reset();
191 return false;
192 }
193
194 IOResult io_result = ScheduleRead();
195 if (io_result != IO_PENDING) {
196 // This will notify the delegate about the read failure. Although we're on
197 // the I/O thread, don't call it in the nested context.
198 message_loop_for_io_->PostTask(FROM_HERE,
199 base::Bind(&RawChannel::OnReadCompleted,
200 weak_ptr_factory_.GetWeakPtr(),
201 io_result,
202 0));
203 }
204
205 // ScheduleRead() failure is treated as a read failure (by notifying the
206 // delegate), not as an init failure.
207 return true;
208 }
209
Shutdown()210 void RawChannel::Shutdown() {
211 DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
212
213 base::AutoLock locker(write_lock_);
214
215 LOG_IF(WARNING, !write_buffer_->message_queue_.empty())
216 << "Shutting down RawChannel with write buffer nonempty";
217
218 // Reset the delegate so that it won't receive further calls.
219 delegate_ = nullptr;
220 read_stopped_ = true;
221 write_stopped_ = true;
222 weak_ptr_factory_.InvalidateWeakPtrs();
223
224 OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass());
225 }
226
227 // Reminder: This must be thread-safe.
WriteMessage(scoped_ptr<MessageInTransit> message)228 bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
229 DCHECK(message);
230
231 base::AutoLock locker(write_lock_);
232 if (write_stopped_)
233 return false;
234
235 if (!write_buffer_->message_queue_.empty()) {
236 EnqueueMessageNoLock(message.Pass());
237 return true;
238 }
239
240 EnqueueMessageNoLock(message.Pass());
241 DCHECK_EQ(write_buffer_->data_offset_, 0u);
242
243 size_t platform_handles_written = 0;
244 size_t bytes_written = 0;
245 IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
246 if (io_result == IO_PENDING)
247 return true;
248
249 bool result = OnWriteCompletedNoLock(
250 io_result, platform_handles_written, bytes_written);
251 if (!result) {
252 // Even if we're on the I/O thread, don't call |OnError()| in the nested
253 // context.
254 message_loop_for_io_->PostTask(FROM_HERE,
255 base::Bind(&RawChannel::CallOnError,
256 weak_ptr_factory_.GetWeakPtr(),
257 Delegate::ERROR_WRITE));
258 }
259
260 return result;
261 }
262
263 // Reminder: This must be thread-safe.
IsWriteBufferEmpty()264 bool RawChannel::IsWriteBufferEmpty() {
265 base::AutoLock locker(write_lock_);
266 return write_buffer_->message_queue_.empty();
267 }
268
OnReadCompleted(IOResult io_result,size_t bytes_read)269 void RawChannel::OnReadCompleted(IOResult io_result, size_t bytes_read) {
270 DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
271
272 if (read_stopped_) {
273 NOTREACHED();
274 return;
275 }
276
277 // Keep reading data in a loop, and dispatch messages if enough data is
278 // received. Exit the loop if any of the following happens:
279 // - one or more messages were dispatched;
280 // - the last read failed, was a partial read or would block;
281 // - |Shutdown()| was called.
282 do {
283 switch (io_result) {
284 case IO_SUCCEEDED:
285 break;
286 case IO_FAILED_SHUTDOWN:
287 case IO_FAILED_BROKEN:
288 case IO_FAILED_UNKNOWN:
289 read_stopped_ = true;
290 CallOnError(ReadIOResultToError(io_result));
291 return;
292 case IO_PENDING:
293 NOTREACHED();
294 return;
295 }
296
297 read_buffer_->num_valid_bytes_ += bytes_read;
298
299 // Dispatch all the messages that we can.
300 bool did_dispatch_message = false;
301 // Tracks the offset of the first undispatched message in |read_buffer_|.
302 // Currently, we copy data to ensure that this is zero at the beginning.
303 size_t read_buffer_start = 0;
304 size_t remaining_bytes = read_buffer_->num_valid_bytes_;
305 size_t message_size;
306 // Note that we rely on short-circuit evaluation here:
307 // - |read_buffer_start| may be an invalid index into
308 // |read_buffer_->buffer_| if |remaining_bytes| is zero.
309 // - |message_size| is only valid if |GetNextMessageSize()| returns true.
310 // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
311 // next read).
312 // TODO(vtl): Validate that |message_size| is sane.
313 while (remaining_bytes > 0 && MessageInTransit::GetNextMessageSize(
314 &read_buffer_->buffer_[read_buffer_start],
315 remaining_bytes,
316 &message_size) &&
317 remaining_bytes >= message_size) {
318 MessageInTransit::View message_view(
319 message_size, &read_buffer_->buffer_[read_buffer_start]);
320 DCHECK_EQ(message_view.total_size(), message_size);
321
322 const char* error_message = nullptr;
323 if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
324 &error_message)) {
325 DCHECK(error_message);
326 LOG(ERROR) << "Received invalid message: " << error_message;
327 read_stopped_ = true;
328 CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
329 return;
330 }
331
332 if (message_view.type() == MessageInTransit::kTypeRawChannel) {
333 if (!OnReadMessageForRawChannel(message_view)) {
334 read_stopped_ = true;
335 CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
336 return;
337 }
338 } else {
339 embedder::ScopedPlatformHandleVectorPtr platform_handles;
340 if (message_view.transport_data_buffer()) {
341 size_t num_platform_handles;
342 const void* platform_handle_table;
343 TransportData::GetPlatformHandleTable(
344 message_view.transport_data_buffer(),
345 &num_platform_handles,
346 &platform_handle_table);
347
348 if (num_platform_handles > 0) {
349 platform_handles =
350 GetReadPlatformHandles(num_platform_handles,
351 platform_handle_table).Pass();
352 if (!platform_handles) {
353 LOG(ERROR) << "Invalid number of platform handles received";
354 read_stopped_ = true;
355 CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
356 return;
357 }
358 }
359 }
360
361 // TODO(vtl): In the case that we aren't expecting any platform handles,
362 // for the POSIX implementation, we should confirm that none are stored.
363
364 // Dispatch the message.
365 DCHECK(delegate_);
366 delegate_->OnReadMessage(message_view, platform_handles.Pass());
367 if (read_stopped_) {
368 // |Shutdown()| was called in |OnReadMessage()|.
369 // TODO(vtl): Add test for this case.
370 return;
371 }
372 }
373
374 did_dispatch_message = true;
375
376 // Update our state.
377 read_buffer_start += message_size;
378 remaining_bytes -= message_size;
379 }
380
381 if (read_buffer_start > 0) {
382 // Move data back to start.
383 read_buffer_->num_valid_bytes_ = remaining_bytes;
384 if (read_buffer_->num_valid_bytes_ > 0) {
385 memmove(&read_buffer_->buffer_[0],
386 &read_buffer_->buffer_[read_buffer_start],
387 remaining_bytes);
388 }
389 read_buffer_start = 0;
390 }
391
392 if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
393 kReadSize) {
394 // Use power-of-2 buffer sizes.
395 // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
396 // maximum message size to whatever extent necessary).
397 // TODO(vtl): We may often be able to peek at the header and get the real
398 // required extra space (which may be much bigger than |kReadSize|).
399 size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
400 while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
401 new_size *= 2;
402
403 // TODO(vtl): It's suboptimal to zero out the fresh memory.
404 read_buffer_->buffer_.resize(new_size, 0);
405 }
406
407 // (1) If we dispatched any messages, stop reading for now (and let the
408 // message loop do its thing for another round).
409 // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
410 // a single message. Risks: slower, more complex if we want to avoid lots of
411 // copying. ii. Keep reading until there's no more data and dispatch all the
412 // messages we can. Risks: starvation of other users of the message loop.)
413 // (2) If we didn't max out |kReadSize|, stop reading for now.
414 bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
415 bytes_read = 0;
416 io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
417 } while (io_result != IO_PENDING);
418 }
419
OnWriteCompleted(IOResult io_result,size_t platform_handles_written,size_t bytes_written)420 void RawChannel::OnWriteCompleted(IOResult io_result,
421 size_t platform_handles_written,
422 size_t bytes_written) {
423 DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
424 DCHECK_NE(io_result, IO_PENDING);
425
426 bool did_fail = false;
427 {
428 base::AutoLock locker(write_lock_);
429 DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());
430
431 if (write_stopped_) {
432 NOTREACHED();
433 return;
434 }
435
436 did_fail = !OnWriteCompletedNoLock(
437 io_result, platform_handles_written, bytes_written);
438 }
439
440 if (did_fail)
441 CallOnError(Delegate::ERROR_WRITE);
442 }
443
EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message)444 void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
445 write_lock_.AssertAcquired();
446 write_buffer_->message_queue_.push_back(message.release());
447 }
448
OnReadMessageForRawChannel(const MessageInTransit::View & message_view)449 bool RawChannel::OnReadMessageForRawChannel(
450 const MessageInTransit::View& message_view) {
451 // No non-implementation specific |RawChannel| control messages.
452 LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
453 << ")";
454 return false;
455 }
456
457 // static
ReadIOResultToError(IOResult io_result)458 RawChannel::Delegate::Error RawChannel::ReadIOResultToError(
459 IOResult io_result) {
460 switch (io_result) {
461 case IO_FAILED_SHUTDOWN:
462 return Delegate::ERROR_READ_SHUTDOWN;
463 case IO_FAILED_BROKEN:
464 return Delegate::ERROR_READ_BROKEN;
465 case IO_FAILED_UNKNOWN:
466 return Delegate::ERROR_READ_UNKNOWN;
467 case IO_SUCCEEDED:
468 case IO_PENDING:
469 NOTREACHED();
470 break;
471 }
472 return Delegate::ERROR_READ_UNKNOWN;
473 }
474
CallOnError(Delegate::Error error)475 void RawChannel::CallOnError(Delegate::Error error) {
476 DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
477 // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
478 if (delegate_)
479 delegate_->OnError(error);
480 }
481
OnWriteCompletedNoLock(IOResult io_result,size_t platform_handles_written,size_t bytes_written)482 bool RawChannel::OnWriteCompletedNoLock(IOResult io_result,
483 size_t platform_handles_written,
484 size_t bytes_written) {
485 write_lock_.AssertAcquired();
486
487 DCHECK(!write_stopped_);
488 DCHECK(!write_buffer_->message_queue_.empty());
489
490 if (io_result == IO_SUCCEEDED) {
491 write_buffer_->platform_handles_offset_ += platform_handles_written;
492 write_buffer_->data_offset_ += bytes_written;
493
494 MessageInTransit* message = write_buffer_->message_queue_.front();
495 if (write_buffer_->data_offset_ >= message->total_size()) {
496 // Complete write.
497 CHECK_EQ(write_buffer_->data_offset_, message->total_size());
498 write_buffer_->message_queue_.pop_front();
499 delete message;
500 write_buffer_->platform_handles_offset_ = 0;
501 write_buffer_->data_offset_ = 0;
502
503 if (write_buffer_->message_queue_.empty())
504 return true;
505 }
506
507 // Schedule the next write.
508 io_result = ScheduleWriteNoLock();
509 if (io_result == IO_PENDING)
510 return true;
511 DCHECK_NE(io_result, IO_SUCCEEDED);
512 }
513
514 write_stopped_ = true;
515 STLDeleteElements(&write_buffer_->message_queue_);
516 write_buffer_->platform_handles_offset_ = 0;
517 write_buffer_->data_offset_ = 0;
518 return false;
519 }
520
521 } // namespace system
522 } // namespace mojo
523