1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/system/raw_channel.h"
6
7 #include <stdint.h>
8
9 #include <vector>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/scoped_vector.h"
17 #include "base/rand_util.h"
18 #include "base/synchronization/lock.h"
19 #include "base/synchronization/waitable_event.h"
20 #include "base/test/test_io_thread.h"
21 #include "base/threading/platform_thread.h" // For |Sleep()|.
22 #include "base/threading/simple_thread.h"
23 #include "base/time/time.h"
24 #include "build/build_config.h"
25 #include "mojo/common/test/test_utils.h"
26 #include "mojo/embedder/platform_channel_pair.h"
27 #include "mojo/embedder/platform_handle.h"
28 #include "mojo/embedder/scoped_platform_handle.h"
29 #include "mojo/system/message_in_transit.h"
30 #include "mojo/system/test_utils.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32
33 namespace mojo {
34 namespace system {
35 namespace {
36
MakeTestMessage(uint32_t num_bytes)37 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
38 std::vector<unsigned char> bytes(num_bytes, 0);
39 for (size_t i = 0; i < num_bytes; i++)
40 bytes[i] = static_cast<unsigned char>(i + num_bytes);
41 return make_scoped_ptr(
42 new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
43 MessageInTransit::kSubtypeMessagePipeEndpointData,
44 num_bytes,
45 bytes.empty() ? nullptr : &bytes[0]));
46 }
47
// Returns true iff the |num_bytes| bytes at |bytes| follow the test pattern
// written by |MakeTestMessage(num_bytes)|: byte |i| equals the low 8 bits of
// |i + num_bytes|. A zero-length buffer trivially matches (and is not read).
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  uint32_t offset = 0;
  while (offset < num_bytes) {
    const unsigned char expected =
        static_cast<unsigned char>(offset + num_bytes);
    if (data[offset] != expected)
      return false;
    ++offset;
  }
  return true;
}
56
InitOnIOThread(RawChannel * raw_channel,RawChannel::Delegate * delegate)57 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
58 CHECK(raw_channel->Init(delegate));
59 }
60
WriteTestMessageToHandle(const embedder::PlatformHandle & handle,uint32_t num_bytes)61 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
62 uint32_t num_bytes) {
63 scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
64
65 size_t write_size = 0;
66 mojo::test::BlockingWrite(
67 handle, message->main_buffer(), message->main_buffer_size(), &write_size);
68 return write_size == message->main_buffer_size();
69 }
70
71 // -----------------------------------------------------------------------------
72
73 class RawChannelTest : public testing::Test {
74 public:
RawChannelTest()75 RawChannelTest() : io_thread_(base::TestIOThread::kManualStart) {}
~RawChannelTest()76 virtual ~RawChannelTest() {}
77
SetUp()78 virtual void SetUp() OVERRIDE {
79 embedder::PlatformChannelPair channel_pair;
80 handles[0] = channel_pair.PassServerHandle();
81 handles[1] = channel_pair.PassClientHandle();
82 io_thread_.Start();
83 }
84
TearDown()85 virtual void TearDown() OVERRIDE {
86 io_thread_.Stop();
87 handles[0].reset();
88 handles[1].reset();
89 }
90
91 protected:
io_thread()92 base::TestIOThread* io_thread() { return &io_thread_; }
93
94 embedder::ScopedPlatformHandle handles[2];
95
96 private:
97 base::TestIOThread io_thread_;
98
99 DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
100 };
101
102 // RawChannelTest.WriteMessage -------------------------------------------------
103
104 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
105 public:
WriteOnlyRawChannelDelegate()106 WriteOnlyRawChannelDelegate() {}
~WriteOnlyRawChannelDelegate()107 virtual ~WriteOnlyRawChannelDelegate() {}
108
109 // |RawChannel::Delegate| implementation:
OnReadMessage(const MessageInTransit::View &,embedder::ScopedPlatformHandleVectorPtr)110 virtual void OnReadMessage(
111 const MessageInTransit::View& /*message_view*/,
112 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
113 CHECK(false); // Should not get called.
114 }
OnError(Error error)115 virtual void OnError(Error error) OVERRIDE {
116 // We'll get a read (shutdown) error when the connection is closed.
117 CHECK_EQ(error, ERROR_READ_SHUTDOWN);
118 }
119
120 private:
121 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
122 };
123
// Polling parameters for |TestMessageReaderAndChecker|: the sleep after each
// short (partial) read, and how many such short reads are allowed before
// giving up. (Namespace-scope consts already have internal linkage here.)
const int64_t kMessageReaderSleepMs = 1;
const size_t kMessageReaderMaxPollIterations = 3000;
126
127 class TestMessageReaderAndChecker {
128 public:
TestMessageReaderAndChecker(embedder::PlatformHandle handle)129 explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
130 : handle_(handle) {}
~TestMessageReaderAndChecker()131 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
132
ReadAndCheckNextMessage(uint32_t expected_size)133 bool ReadAndCheckNextMessage(uint32_t expected_size) {
134 unsigned char buffer[4096];
135
136 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
137 size_t read_size = 0;
138 CHECK(mojo::test::NonBlockingRead(
139 handle_, buffer, sizeof(buffer), &read_size));
140
141 // Append newly-read data to |bytes_|.
142 bytes_.insert(bytes_.end(), buffer, buffer + read_size);
143
144 // If we have the header....
145 size_t message_size;
146 if (MessageInTransit::GetNextMessageSize(
147 bytes_.empty() ? nullptr : &bytes_[0],
148 bytes_.size(),
149 &message_size)) {
150 // If we've read the whole message....
151 if (bytes_.size() >= message_size) {
152 bool rv = true;
153 MessageInTransit::View message_view(message_size, &bytes_[0]);
154 CHECK_EQ(message_view.main_buffer_size(), message_size);
155
156 if (message_view.num_bytes() != expected_size) {
157 LOG(ERROR) << "Wrong size: " << message_size << " instead of "
158 << expected_size << " bytes.";
159 rv = false;
160 } else if (!CheckMessageData(message_view.bytes(),
161 message_view.num_bytes())) {
162 LOG(ERROR) << "Incorrect message bytes.";
163 rv = false;
164 }
165
166 // Erase message data.
167 bytes_.erase(bytes_.begin(),
168 bytes_.begin() + message_view.main_buffer_size());
169 return rv;
170 }
171 }
172
173 if (static_cast<size_t>(read_size) < sizeof(buffer)) {
174 i++;
175 base::PlatformThread::Sleep(
176 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
177 }
178 }
179
180 LOG(ERROR) << "Too many iterations.";
181 return false;
182 }
183
184 private:
185 const embedder::PlatformHandle handle_;
186
187 // The start of the received data should always be on a message boundary.
188 std::vector<unsigned char> bytes_;
189
190 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
191 };
192
193 // Tests writing (and verifies reading using our own custom reader).
TEST_F(RawChannelTest,WriteMessage)194 TEST_F(RawChannelTest, WriteMessage) {
195 WriteOnlyRawChannelDelegate delegate;
196 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
197 TestMessageReaderAndChecker checker(handles[1].get());
198 io_thread()->PostTaskAndWait(
199 FROM_HERE,
200 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
201
202 // Write and read, for a variety of sizes.
203 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
204 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
205 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
206 }
207
208 // Write/queue and read afterwards, for a variety of sizes.
209 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
210 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
211 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
212 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
213
214 io_thread()->PostTaskAndWait(
215 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
216 }
217
218 // RawChannelTest.OnReadMessage ------------------------------------------------
219
220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
221 public:
ReadCheckerRawChannelDelegate()222 ReadCheckerRawChannelDelegate() : done_event_(false, false), position_(0) {}
~ReadCheckerRawChannelDelegate()223 virtual ~ReadCheckerRawChannelDelegate() {}
224
225 // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)226 virtual void OnReadMessage(
227 const MessageInTransit::View& message_view,
228 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
229 EXPECT_FALSE(platform_handles);
230
231 size_t position;
232 size_t expected_size;
233 bool should_signal = false;
234 {
235 base::AutoLock locker(lock_);
236 CHECK_LT(position_, expected_sizes_.size());
237 position = position_;
238 expected_size = expected_sizes_[position];
239 position_++;
240 if (position_ >= expected_sizes_.size())
241 should_signal = true;
242 }
243
244 EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
245 if (message_view.num_bytes() == expected_size) {
246 EXPECT_TRUE(
247 CheckMessageData(message_view.bytes(), message_view.num_bytes()))
248 << position;
249 }
250
251 if (should_signal)
252 done_event_.Signal();
253 }
OnError(Error error)254 virtual void OnError(Error error) OVERRIDE {
255 // We'll get a read (shutdown) error when the connection is closed.
256 CHECK_EQ(error, ERROR_READ_SHUTDOWN);
257 }
258
259 // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
Wait()260 void Wait() { done_event_.Wait(); }
261
SetExpectedSizes(const std::vector<uint32_t> & expected_sizes)262 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
263 base::AutoLock locker(lock_);
264 CHECK_EQ(position_, expected_sizes_.size());
265 expected_sizes_ = expected_sizes;
266 position_ = 0;
267 }
268
269 private:
270 base::WaitableEvent done_event_;
271
272 base::Lock lock_; // Protects the following members.
273 std::vector<uint32_t> expected_sizes_;
274 size_t position_;
275
276 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
277 };
278
279 // Tests reading (writing using our own custom writer).
TEST_F(RawChannelTest,OnReadMessage)280 TEST_F(RawChannelTest, OnReadMessage) {
281 ReadCheckerRawChannelDelegate delegate;
282 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
283 io_thread()->PostTaskAndWait(
284 FROM_HERE,
285 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
286
287 // Write and read, for a variety of sizes.
288 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
289 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
290
291 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
292
293 delegate.Wait();
294 }
295
296 // Set up reader and write as fast as we can.
297 // Write/queue and read afterwards, for a variety of sizes.
298 std::vector<uint32_t> expected_sizes;
299 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
300 expected_sizes.push_back(size);
301 delegate.SetExpectedSizes(expected_sizes);
302 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
303 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
304 delegate.Wait();
305
306 io_thread()->PostTaskAndWait(
307 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
308 }
309
310 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
311
312 class RawChannelWriterThread : public base::SimpleThread {
313 public:
RawChannelWriterThread(RawChannel * raw_channel,size_t write_count)314 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
315 : base::SimpleThread("raw_channel_writer_thread"),
316 raw_channel_(raw_channel),
317 left_to_write_(write_count) {}
318
~RawChannelWriterThread()319 virtual ~RawChannelWriterThread() { Join(); }
320
321 private:
Run()322 virtual void Run() OVERRIDE {
323 static const int kMaxRandomMessageSize = 25000;
324
325 while (left_to_write_-- > 0) {
326 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
327 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
328 }
329 }
330
331 RawChannel* const raw_channel_;
332 size_t left_to_write_;
333
334 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
335 };
336
337 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
338 public:
ReadCountdownRawChannelDelegate(size_t expected_count)339 explicit ReadCountdownRawChannelDelegate(size_t expected_count)
340 : done_event_(false, false), expected_count_(expected_count), count_(0) {}
~ReadCountdownRawChannelDelegate()341 virtual ~ReadCountdownRawChannelDelegate() {}
342
343 // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)344 virtual void OnReadMessage(
345 const MessageInTransit::View& message_view,
346 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
347 EXPECT_FALSE(platform_handles);
348
349 EXPECT_LT(count_, expected_count_);
350 count_++;
351
352 EXPECT_TRUE(
353 CheckMessageData(message_view.bytes(), message_view.num_bytes()));
354
355 if (count_ >= expected_count_)
356 done_event_.Signal();
357 }
OnError(Error error)358 virtual void OnError(Error error) OVERRIDE {
359 // We'll get a read (shutdown) error when the connection is closed.
360 CHECK_EQ(error, ERROR_READ_SHUTDOWN);
361 }
362
363 // Waits for all the messages to have been seen.
Wait()364 void Wait() { done_event_.Wait(); }
365
366 private:
367 base::WaitableEvent done_event_;
368 size_t expected_count_;
369 size_t count_;
370
371 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
372 };
373
// Stress test. Several threads concurrently write randomly sized messages
// through one |RawChannel|. The |RawChannel| on the other end must receive
// exactly that many well-formed messages.
TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
  static const size_t kNumWriterThreads = 10;
  static const size_t kNumWriteMessagesPerThread = 4000;

  WriteOnlyRawChannelDelegate writer_delegate;
  scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread,
                                          writer_rc.get(),
                                          base::Unretained(&writer_delegate)));

  ReadCountdownRawChannelDelegate reader_delegate(kNumWriterThreads *
                                                  kNumWriteMessagesPerThread);
  scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread,
                                          reader_rc.get(),
                                          base::Unretained(&reader_delegate)));

  {
    ScopedVector<RawChannelWriterThread> writer_threads;
    for (size_t i = 0; i < kNumWriterThreads; i++) {
      writer_threads.push_back(new RawChannelWriterThread(
          writer_rc.get(), kNumWriteMessagesPerThread));
    }
    for (size_t i = 0; i < writer_threads.size(); i++)
      writer_threads[i]->Start();
  }  // Joins all the writer threads.

  // Wait for reading to finish.
  reader_delegate.Wait();

  // Sleep a bit, to let any extraneous reads be processed before the reader
  // is shut down. (There shouldn't be any, but we want to know about them;
  // the reader delegate fails on any message past the expected count.) This
  // must come after |Wait()| so the window follows the last expected message.
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));

  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(reader_rc.get())));

  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(writer_rc.get())));
}
418
419 // RawChannelTest.OnError ------------------------------------------------------
420
421 class ErrorRecordingRawChannelDelegate
422 : public ReadCountdownRawChannelDelegate {
423 public:
ErrorRecordingRawChannelDelegate(size_t expected_read_count,bool expect_read_error,bool expect_write_error)424 ErrorRecordingRawChannelDelegate(size_t expected_read_count,
425 bool expect_read_error,
426 bool expect_write_error)
427 : ReadCountdownRawChannelDelegate(expected_read_count),
428 got_read_error_event_(false, false),
429 got_write_error_event_(false, false),
430 expecting_read_error_(expect_read_error),
431 expecting_write_error_(expect_write_error) {}
432
~ErrorRecordingRawChannelDelegate()433 virtual ~ErrorRecordingRawChannelDelegate() {}
434
OnError(Error error)435 virtual void OnError(Error error) OVERRIDE {
436 switch (error) {
437 case ERROR_READ_SHUTDOWN:
438 ASSERT_TRUE(expecting_read_error_);
439 expecting_read_error_ = false;
440 got_read_error_event_.Signal();
441 break;
442 case ERROR_READ_BROKEN:
443 // TODO(vtl): Test broken connections.
444 CHECK(false);
445 break;
446 case ERROR_READ_BAD_MESSAGE:
447 // TODO(vtl): Test reception/detection of bad messages.
448 CHECK(false);
449 break;
450 case ERROR_READ_UNKNOWN:
451 // TODO(vtl): Test however it is we might get here.
452 CHECK(false);
453 break;
454 case ERROR_WRITE:
455 ASSERT_TRUE(expecting_write_error_);
456 expecting_write_error_ = false;
457 got_write_error_event_.Signal();
458 break;
459 }
460 }
461
WaitForReadError()462 void WaitForReadError() { got_read_error_event_.Wait(); }
WaitForWriteError()463 void WaitForWriteError() { got_write_error_event_.Wait(); }
464
465 private:
466 base::WaitableEvent got_read_error_event_;
467 base::WaitableEvent got_write_error_event_;
468
469 bool expecting_read_error_;
470 bool expecting_write_error_;
471
472 DISALLOW_COPY_AND_ASSIGN(ErrorRecordingRawChannelDelegate);
473 };
474
475 // Tests (fatal) errors.
TEST_F(RawChannelTest,OnError)476 TEST_F(RawChannelTest, OnError) {
477 ErrorRecordingRawChannelDelegate delegate(0, true, true);
478 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
479 io_thread()->PostTaskAndWait(
480 FROM_HERE,
481 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
482
483 // Close the handle of the other end, which should make writing fail.
484 handles[1].reset();
485
486 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
487
488 // We should get a write error.
489 delegate.WaitForWriteError();
490
491 // We should also get a read error.
492 delegate.WaitForReadError();
493
494 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
495
496 // Sleep a bit, to make sure we don't get another |OnError()|
497 // notification. (If we actually get another one, |OnError()| crashes.)
498 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));
499
500 io_thread()->PostTaskAndWait(
501 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
502 }
503
504 // RawChannelTest.ReadUnaffectedByWriteError -----------------------------------
505
// Checks that a write failure does not disturb reading. Messages already
// buffered from a since-closed peer are still delivered, followed by a read
// (shutdown) error.
TEST_F(RawChannelTest, ReadUnaffectedByWriteError) {
  const size_t kMessageCount = 5;

  // Write a few messages of growing size into the other end.
  uint32_t message_size = 1;
  for (size_t i = 0; i < kMessageCount; i++) {
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
    message_size += message_size / 2 + 1;
  }

  // Closing the other end should make writing fail.
  handles[1].reset();

  // Start reading only now; the system buffer should still hold the messages
  // written above.
  ErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // A write error must be reported...
  delegate.WaitForWriteError();

  // ...yet every buffered message is still read...
  delegate.Wait();

  // ...and only then does the read error arrive.
  delegate.WaitForReadError();

  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
540
541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
542
543 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
544 // correctly.
TEST_F(RawChannelTest,WriteMessageAfterShutdown)545 TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
546 WriteOnlyRawChannelDelegate delegate;
547 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
548 io_thread()->PostTaskAndWait(
549 FROM_HERE,
550 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
551 io_thread()->PostTaskAndWait(
552 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
553
554 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
555 }
556
557 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
558
559 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
560 public:
ShutdownOnReadMessageRawChannelDelegate(RawChannel * raw_channel)561 explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
562 : raw_channel_(raw_channel),
563 done_event_(false, false),
564 did_shutdown_(false) {}
~ShutdownOnReadMessageRawChannelDelegate()565 virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
566
567 // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)568 virtual void OnReadMessage(
569 const MessageInTransit::View& message_view,
570 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
571 EXPECT_FALSE(platform_handles);
572 EXPECT_FALSE(did_shutdown_);
573 EXPECT_TRUE(
574 CheckMessageData(message_view.bytes(), message_view.num_bytes()));
575 raw_channel_->Shutdown();
576 did_shutdown_ = true;
577 done_event_.Signal();
578 }
OnError(Error)579 virtual void OnError(Error /*error*/) OVERRIDE {
580 CHECK(false); // Should not get called.
581 }
582
583 // Waits for shutdown.
Wait()584 void Wait() {
585 done_event_.Wait();
586 EXPECT_TRUE(did_shutdown_);
587 }
588
589 private:
590 RawChannel* const raw_channel_;
591 base::WaitableEvent done_event_;
592 bool did_shutdown_;
593
594 DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
595 };
596
// Checks that a delegate may shut its |RawChannel| down from within
// |OnReadMessage()| while more messages are still queued.
TEST_F(RawChannelTest, ShutdownOnReadMessage) {
  // Queue several small messages on the other end before reading starts.
  for (int i = 0; i < 5; ++i)
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));

  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // The delegate shuts the |RawChannel| down on the first message it reads.
  delegate.Wait();
}
611
612 // RawChannelTest.ShutdownOnError{Read, Write} ---------------------------------
613
614 class ShutdownOnErrorRawChannelDelegate : public RawChannel::Delegate {
615 public:
ShutdownOnErrorRawChannelDelegate(RawChannel * raw_channel,Error shutdown_on_error_type)616 ShutdownOnErrorRawChannelDelegate(RawChannel* raw_channel,
617 Error shutdown_on_error_type)
618 : raw_channel_(raw_channel),
619 shutdown_on_error_type_(shutdown_on_error_type),
620 done_event_(false, false),
621 did_shutdown_(false) {}
~ShutdownOnErrorRawChannelDelegate()622 virtual ~ShutdownOnErrorRawChannelDelegate() {}
623
624 // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View &,embedder::ScopedPlatformHandleVectorPtr)625 virtual void OnReadMessage(
626 const MessageInTransit::View& /*message_view*/,
627 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
628 CHECK(false); // Should not get called.
629 }
OnError(Error error)630 virtual void OnError(Error error) OVERRIDE {
631 EXPECT_FALSE(did_shutdown_);
632 if (error != shutdown_on_error_type_)
633 return;
634 raw_channel_->Shutdown();
635 did_shutdown_ = true;
636 done_event_.Signal();
637 }
638
639 // Waits for shutdown.
Wait()640 void Wait() {
641 done_event_.Wait();
642 EXPECT_TRUE(did_shutdown_);
643 }
644
645 private:
646 RawChannel* const raw_channel_;
647 const Error shutdown_on_error_type_;
648 base::WaitableEvent done_event_;
649 bool did_shutdown_;
650
651 DISALLOW_COPY_AND_ASSIGN(ShutdownOnErrorRawChannelDelegate);
652 };
653
// Checks that a delegate may shut its |RawChannel| down from within
// |OnError()| in response to a read (shutdown) error.
TEST_F(RawChannelTest, ShutdownOnErrorRead) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(
      rc.get(), RawChannel::Delegate::ERROR_READ_SHUTDOWN);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Closing the other end's handle should make things fail.
  handles[1].reset();

  // The delegate shuts the |RawChannel| down when the read error arrives.
  delegate.Wait();
}
668
// Checks that a delegate may shut its |RawChannel| down from within
// |OnError()| in response to a write error.
TEST_F(RawChannelTest, ShutdownOnErrorWrite) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(rc.get(),
                                             RawChannel::Delegate::ERROR_WRITE);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Closing the other end's handle should make things fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // The delegate shuts the |RawChannel| down when the write error arrives.
  delegate.Wait();
}
685
686 } // namespace
687 } // namespace system
688 } // namespace mojo
689