• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/system/raw_channel.h"
6 
7 #include <stdint.h>
8 
9 #include <vector>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/scoped_vector.h"
17 #include "base/rand_util.h"
18 #include "base/synchronization/lock.h"
19 #include "base/synchronization/waitable_event.h"
20 #include "base/test/test_io_thread.h"
21 #include "base/threading/platform_thread.h"  // For |Sleep()|.
22 #include "base/threading/simple_thread.h"
23 #include "base/time/time.h"
24 #include "build/build_config.h"
25 #include "mojo/common/test/test_utils.h"
26 #include "mojo/embedder/platform_channel_pair.h"
27 #include "mojo/embedder/platform_handle.h"
28 #include "mojo/embedder/scoped_platform_handle.h"
29 #include "mojo/system/message_in_transit.h"
30 #include "mojo/system/test_utils.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32 
33 namespace mojo {
34 namespace system {
35 namespace {
36 
MakeTestMessage(uint32_t num_bytes)37 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
38   std::vector<unsigned char> bytes(num_bytes, 0);
39   for (size_t i = 0; i < num_bytes; i++)
40     bytes[i] = static_cast<unsigned char>(i + num_bytes);
41   return make_scoped_ptr(
42       new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
43                            MessageInTransit::kSubtypeMessagePipeEndpointData,
44                            num_bytes,
45                            bytes.empty() ? nullptr : &bytes[0]));
46 }
47 
// Returns true iff the |num_bytes| bytes at |bytes| follow the pattern written
// by |MakeTestMessage()|: byte |k| equals the low 8 bits of |k + num_bytes|.
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  uint32_t index = 0;
  // Advance over the matching prefix; we matched everything iff we reached the
  // end.
  while (index < num_bytes &&
         data[index] == static_cast<unsigned char>(index + num_bytes)) {
    ++index;
  }
  return index == num_bytes;
}
56 
InitOnIOThread(RawChannel * raw_channel,RawChannel::Delegate * delegate)57 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
58   CHECK(raw_channel->Init(delegate));
59 }
60 
WriteTestMessageToHandle(const embedder::PlatformHandle & handle,uint32_t num_bytes)61 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
62                               uint32_t num_bytes) {
63   scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
64 
65   size_t write_size = 0;
66   mojo::test::BlockingWrite(
67       handle, message->main_buffer(), message->main_buffer_size(), &write_size);
68   return write_size == message->main_buffer_size();
69 }
70 
71 // -----------------------------------------------------------------------------
72 
73 class RawChannelTest : public testing::Test {
74  public:
RawChannelTest()75   RawChannelTest() : io_thread_(base::TestIOThread::kManualStart) {}
~RawChannelTest()76   virtual ~RawChannelTest() {}
77 
SetUp()78   virtual void SetUp() OVERRIDE {
79     embedder::PlatformChannelPair channel_pair;
80     handles[0] = channel_pair.PassServerHandle();
81     handles[1] = channel_pair.PassClientHandle();
82     io_thread_.Start();
83   }
84 
TearDown()85   virtual void TearDown() OVERRIDE {
86     io_thread_.Stop();
87     handles[0].reset();
88     handles[1].reset();
89   }
90 
91  protected:
io_thread()92   base::TestIOThread* io_thread() { return &io_thread_; }
93 
94   embedder::ScopedPlatformHandle handles[2];
95 
96  private:
97   base::TestIOThread io_thread_;
98 
99   DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
100 };
101 
102 // RawChannelTest.WriteMessage -------------------------------------------------
103 
104 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
105  public:
WriteOnlyRawChannelDelegate()106   WriteOnlyRawChannelDelegate() {}
~WriteOnlyRawChannelDelegate()107   virtual ~WriteOnlyRawChannelDelegate() {}
108 
109   // |RawChannel::Delegate| implementation:
OnReadMessage(const MessageInTransit::View &,embedder::ScopedPlatformHandleVectorPtr)110   virtual void OnReadMessage(
111       const MessageInTransit::View& /*message_view*/,
112       embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
113     CHECK(false);  // Should not get called.
114   }
OnError(Error error)115   virtual void OnError(Error error) OVERRIDE {
116     // We'll get a read (shutdown) error when the connection is closed.
117     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
118   }
119 
120  private:
121   DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
122 };
123 
// Polling parameters for |TestMessageReaderAndChecker|: the sleep (in ms) taken
// after each read that does not fill the read buffer, and how many such short
// reads are allowed before giving up on a message. (Namespace-scope consts
// inside the anonymous namespace already have internal linkage.)
const int64_t kMessageReaderSleepMs = 1;
const size_t kMessageReaderMaxPollIterations = 3000;
126 
127 class TestMessageReaderAndChecker {
128  public:
TestMessageReaderAndChecker(embedder::PlatformHandle handle)129   explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
130       : handle_(handle) {}
~TestMessageReaderAndChecker()131   ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
132 
ReadAndCheckNextMessage(uint32_t expected_size)133   bool ReadAndCheckNextMessage(uint32_t expected_size) {
134     unsigned char buffer[4096];
135 
136     for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
137       size_t read_size = 0;
138       CHECK(mojo::test::NonBlockingRead(
139           handle_, buffer, sizeof(buffer), &read_size));
140 
141       // Append newly-read data to |bytes_|.
142       bytes_.insert(bytes_.end(), buffer, buffer + read_size);
143 
144       // If we have the header....
145       size_t message_size;
146       if (MessageInTransit::GetNextMessageSize(
147               bytes_.empty() ? nullptr : &bytes_[0],
148               bytes_.size(),
149               &message_size)) {
150         // If we've read the whole message....
151         if (bytes_.size() >= message_size) {
152           bool rv = true;
153           MessageInTransit::View message_view(message_size, &bytes_[0]);
154           CHECK_EQ(message_view.main_buffer_size(), message_size);
155 
156           if (message_view.num_bytes() != expected_size) {
157             LOG(ERROR) << "Wrong size: " << message_size << " instead of "
158                        << expected_size << " bytes.";
159             rv = false;
160           } else if (!CheckMessageData(message_view.bytes(),
161                                        message_view.num_bytes())) {
162             LOG(ERROR) << "Incorrect message bytes.";
163             rv = false;
164           }
165 
166           // Erase message data.
167           bytes_.erase(bytes_.begin(),
168                        bytes_.begin() + message_view.main_buffer_size());
169           return rv;
170         }
171       }
172 
173       if (static_cast<size_t>(read_size) < sizeof(buffer)) {
174         i++;
175         base::PlatformThread::Sleep(
176             base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
177       }
178     }
179 
180     LOG(ERROR) << "Too many iterations.";
181     return false;
182   }
183 
184  private:
185   const embedder::PlatformHandle handle_;
186 
187   // The start of the received data should always be on a message boundary.
188   std::vector<unsigned char> bytes_;
189 
190   DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
191 };
192 
193 // Tests writing (and verifies reading using our own custom reader).
TEST_F(RawChannelTest,WriteMessage)194 TEST_F(RawChannelTest, WriteMessage) {
195   WriteOnlyRawChannelDelegate delegate;
196   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
197   TestMessageReaderAndChecker checker(handles[1].get());
198   io_thread()->PostTaskAndWait(
199       FROM_HERE,
200       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
201 
202   // Write and read, for a variety of sizes.
203   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
204     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
205     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
206   }
207 
208   // Write/queue and read afterwards, for a variety of sizes.
209   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
210     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
211   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
212     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
213 
214   io_thread()->PostTaskAndWait(
215       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
216 }
217 
218 // RawChannelTest.OnReadMessage ------------------------------------------------
219 
220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
221  public:
ReadCheckerRawChannelDelegate()222   ReadCheckerRawChannelDelegate() : done_event_(false, false), position_(0) {}
~ReadCheckerRawChannelDelegate()223   virtual ~ReadCheckerRawChannelDelegate() {}
224 
225   // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)226   virtual void OnReadMessage(
227       const MessageInTransit::View& message_view,
228       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
229     EXPECT_FALSE(platform_handles);
230 
231     size_t position;
232     size_t expected_size;
233     bool should_signal = false;
234     {
235       base::AutoLock locker(lock_);
236       CHECK_LT(position_, expected_sizes_.size());
237       position = position_;
238       expected_size = expected_sizes_[position];
239       position_++;
240       if (position_ >= expected_sizes_.size())
241         should_signal = true;
242     }
243 
244     EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
245     if (message_view.num_bytes() == expected_size) {
246       EXPECT_TRUE(
247           CheckMessageData(message_view.bytes(), message_view.num_bytes()))
248           << position;
249     }
250 
251     if (should_signal)
252       done_event_.Signal();
253   }
OnError(Error error)254   virtual void OnError(Error error) OVERRIDE {
255     // We'll get a read (shutdown) error when the connection is closed.
256     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
257   }
258 
259   // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
Wait()260   void Wait() { done_event_.Wait(); }
261 
SetExpectedSizes(const std::vector<uint32_t> & expected_sizes)262   void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
263     base::AutoLock locker(lock_);
264     CHECK_EQ(position_, expected_sizes_.size());
265     expected_sizes_ = expected_sizes;
266     position_ = 0;
267   }
268 
269  private:
270   base::WaitableEvent done_event_;
271 
272   base::Lock lock_;  // Protects the following members.
273   std::vector<uint32_t> expected_sizes_;
274   size_t position_;
275 
276   DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
277 };
278 
279 // Tests reading (writing using our own custom writer).
TEST_F(RawChannelTest,OnReadMessage)280 TEST_F(RawChannelTest, OnReadMessage) {
281   ReadCheckerRawChannelDelegate delegate;
282   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
283   io_thread()->PostTaskAndWait(
284       FROM_HERE,
285       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
286 
287   // Write and read, for a variety of sizes.
288   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
289     delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
290 
291     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
292 
293     delegate.Wait();
294   }
295 
296   // Set up reader and write as fast as we can.
297   // Write/queue and read afterwards, for a variety of sizes.
298   std::vector<uint32_t> expected_sizes;
299   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
300     expected_sizes.push_back(size);
301   delegate.SetExpectedSizes(expected_sizes);
302   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
303     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
304   delegate.Wait();
305 
306   io_thread()->PostTaskAndWait(
307       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
308 }
309 
310 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
311 
312 class RawChannelWriterThread : public base::SimpleThread {
313  public:
RawChannelWriterThread(RawChannel * raw_channel,size_t write_count)314   RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
315       : base::SimpleThread("raw_channel_writer_thread"),
316         raw_channel_(raw_channel),
317         left_to_write_(write_count) {}
318 
~RawChannelWriterThread()319   virtual ~RawChannelWriterThread() { Join(); }
320 
321  private:
Run()322   virtual void Run() OVERRIDE {
323     static const int kMaxRandomMessageSize = 25000;
324 
325     while (left_to_write_-- > 0) {
326       EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
327           static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
328     }
329   }
330 
331   RawChannel* const raw_channel_;
332   size_t left_to_write_;
333 
334   DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
335 };
336 
337 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
338  public:
ReadCountdownRawChannelDelegate(size_t expected_count)339   explicit ReadCountdownRawChannelDelegate(size_t expected_count)
340       : done_event_(false, false), expected_count_(expected_count), count_(0) {}
~ReadCountdownRawChannelDelegate()341   virtual ~ReadCountdownRawChannelDelegate() {}
342 
343   // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)344   virtual void OnReadMessage(
345       const MessageInTransit::View& message_view,
346       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
347     EXPECT_FALSE(platform_handles);
348 
349     EXPECT_LT(count_, expected_count_);
350     count_++;
351 
352     EXPECT_TRUE(
353         CheckMessageData(message_view.bytes(), message_view.num_bytes()));
354 
355     if (count_ >= expected_count_)
356       done_event_.Signal();
357   }
OnError(Error error)358   virtual void OnError(Error error) OVERRIDE {
359     // We'll get a read (shutdown) error when the connection is closed.
360     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
361   }
362 
363   // Waits for all the messages to have been seen.
Wait()364   void Wait() { done_event_.Wait(); }
365 
366  private:
367   base::WaitableEvent done_event_;
368   size_t expected_count_;
369   size_t count_;
370 
371   DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
372 };
373 
// Tests concurrent writes from many threads through one |RawChannel|, read by
// a second |RawChannel| on the other end of the pair.
TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
  static const size_t kNumWriterThreads = 10;
  static const size_t kNumWriteMessagesPerThread = 4000;

  // Writing side: |handles[0]|.
  WriteOnlyRawChannelDelegate writer_delegate;
  scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread,
                 writer_rc.get(),
                 base::Unretained(&writer_delegate)));

  // Reading side: |handles[1]|, expecting every message from every writer.
  ReadCountdownRawChannelDelegate reader_delegate(kNumWriterThreads *
                                                  kNumWriteMessagesPerThread);
  scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread,
                 reader_rc.get(),
                 base::Unretained(&reader_delegate)));

  {
    // Create all the writers first, then start them all.
    ScopedVector<RawChannelWriterThread> writers;
    while (writers.size() < kNumWriterThreads) {
      writers.push_back(new RawChannelWriterThread(
          writer_rc.get(), kNumWriteMessagesPerThread));
    }
    for (size_t i = 0; i < writers.size(); i++)
      writers[i]->Start();
  }  // Destroying |writers| joins every writer thread.

  // Give any extraneous reads a chance to be processed. There shouldn't be
  // any, but we want to know about them.
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));

  // Block until the reader has seen every expected message.
  reader_delegate.Wait();

  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(reader_rc.get())));

  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(writer_rc.get())));
}
418 
419 // RawChannelTest.OnError ------------------------------------------------------
420 
421 class ErrorRecordingRawChannelDelegate
422     : public ReadCountdownRawChannelDelegate {
423  public:
ErrorRecordingRawChannelDelegate(size_t expected_read_count,bool expect_read_error,bool expect_write_error)424   ErrorRecordingRawChannelDelegate(size_t expected_read_count,
425                                    bool expect_read_error,
426                                    bool expect_write_error)
427       : ReadCountdownRawChannelDelegate(expected_read_count),
428         got_read_error_event_(false, false),
429         got_write_error_event_(false, false),
430         expecting_read_error_(expect_read_error),
431         expecting_write_error_(expect_write_error) {}
432 
~ErrorRecordingRawChannelDelegate()433   virtual ~ErrorRecordingRawChannelDelegate() {}
434 
OnError(Error error)435   virtual void OnError(Error error) OVERRIDE {
436     switch (error) {
437       case ERROR_READ_SHUTDOWN:
438         ASSERT_TRUE(expecting_read_error_);
439         expecting_read_error_ = false;
440         got_read_error_event_.Signal();
441         break;
442       case ERROR_READ_BROKEN:
443         // TODO(vtl): Test broken connections.
444         CHECK(false);
445         break;
446       case ERROR_READ_BAD_MESSAGE:
447         // TODO(vtl): Test reception/detection of bad messages.
448         CHECK(false);
449         break;
450       case ERROR_READ_UNKNOWN:
451         // TODO(vtl): Test however it is we might get here.
452         CHECK(false);
453         break;
454       case ERROR_WRITE:
455         ASSERT_TRUE(expecting_write_error_);
456         expecting_write_error_ = false;
457         got_write_error_event_.Signal();
458         break;
459     }
460   }
461 
WaitForReadError()462   void WaitForReadError() { got_read_error_event_.Wait(); }
WaitForWriteError()463   void WaitForWriteError() { got_write_error_event_.Wait(); }
464 
465  private:
466   base::WaitableEvent got_read_error_event_;
467   base::WaitableEvent got_write_error_event_;
468 
469   bool expecting_read_error_;
470   bool expecting_write_error_;
471 
472   DISALLOW_COPY_AND_ASSIGN(ErrorRecordingRawChannelDelegate);
473 };
474 
475 // Tests (fatal) errors.
TEST_F(RawChannelTest,OnError)476 TEST_F(RawChannelTest, OnError) {
477   ErrorRecordingRawChannelDelegate delegate(0, true, true);
478   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
479   io_thread()->PostTaskAndWait(
480       FROM_HERE,
481       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
482 
483   // Close the handle of the other end, which should make writing fail.
484   handles[1].reset();
485 
486   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
487 
488   // We should get a write error.
489   delegate.WaitForWriteError();
490 
491   // We should also get a read error.
492   delegate.WaitForReadError();
493 
494   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
495 
496   // Sleep a bit, to make sure we don't get another |OnError()|
497   // notification. (If we actually get another one, |OnError()| crashes.)
498   base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));
499 
500   io_thread()->PostTaskAndWait(
501       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
502 }
503 
504 // RawChannelTest.ReadUnaffectedByWriteError -----------------------------------
505 
// Tests that a write error does not prevent already-buffered data from being
// read.
TEST_F(RawChannelTest, ReadUnaffectedByWriteError) {
  const size_t kMessageCount = 5;

  // Write a few messages of growing size into the other end.
  uint32_t message_size = 1;
  for (size_t i = 0; i < kMessageCount; i++) {
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
    message_size += message_size / 2 + 1;
  }

  // Close the other end, which should make writing fail.
  handles[1].reset();

  // Only now create the reading channel. The system buffer should still hold
  // the messages written above.
  ErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // The write error must definitely be reported...
  delegate.WaitForWriteError();

  // ...yet every buffered message is still delivered...
  delegate.Wait();

  // ...followed by the read error.
  delegate.WaitForReadError();

  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
540 
541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
542 
543 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
544 // correctly.
TEST_F(RawChannelTest,WriteMessageAfterShutdown)545 TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
546   WriteOnlyRawChannelDelegate delegate;
547   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
548   io_thread()->PostTaskAndWait(
549       FROM_HERE,
550       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
551   io_thread()->PostTaskAndWait(
552       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
553 
554   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
555 }
556 
557 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
558 
559 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
560  public:
ShutdownOnReadMessageRawChannelDelegate(RawChannel * raw_channel)561   explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
562       : raw_channel_(raw_channel),
563         done_event_(false, false),
564         did_shutdown_(false) {}
~ShutdownOnReadMessageRawChannelDelegate()565   virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
566 
567   // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View & message_view,embedder::ScopedPlatformHandleVectorPtr platform_handles)568   virtual void OnReadMessage(
569       const MessageInTransit::View& message_view,
570       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
571     EXPECT_FALSE(platform_handles);
572     EXPECT_FALSE(did_shutdown_);
573     EXPECT_TRUE(
574         CheckMessageData(message_view.bytes(), message_view.num_bytes()));
575     raw_channel_->Shutdown();
576     did_shutdown_ = true;
577     done_event_.Signal();
578   }
OnError(Error)579   virtual void OnError(Error /*error*/) OVERRIDE {
580     CHECK(false);  // Should not get called.
581   }
582 
583   // Waits for shutdown.
Wait()584   void Wait() {
585     done_event_.Wait();
586     EXPECT_TRUE(did_shutdown_);
587   }
588 
589  private:
590   RawChannel* const raw_channel_;
591   base::WaitableEvent done_event_;
592   bool did_shutdown_;
593 
594   DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
595 };
596 
// Tests shutting a |RawChannel| down from within its delegate's
// |OnReadMessage()|, while more messages are still pending.
TEST_F(RawChannelTest, ShutdownOnReadMessage) {
  const size_t kMessageCount = 5;

  // Queue several messages on the other end before the channel even exists.
  for (size_t count = 0; count < kMessageCount; count++)
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));

  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // The delegate shuts the |RawChannel| down on the first message it sees.
  delegate.Wait();
}
611 
612 // RawChannelTest.ShutdownOnError{Read, Write} ---------------------------------
613 
614 class ShutdownOnErrorRawChannelDelegate : public RawChannel::Delegate {
615  public:
ShutdownOnErrorRawChannelDelegate(RawChannel * raw_channel,Error shutdown_on_error_type)616   ShutdownOnErrorRawChannelDelegate(RawChannel* raw_channel,
617                                     Error shutdown_on_error_type)
618       : raw_channel_(raw_channel),
619         shutdown_on_error_type_(shutdown_on_error_type),
620         done_event_(false, false),
621         did_shutdown_(false) {}
~ShutdownOnErrorRawChannelDelegate()622   virtual ~ShutdownOnErrorRawChannelDelegate() {}
623 
624   // |RawChannel::Delegate| implementation (called on the I/O thread):
OnReadMessage(const MessageInTransit::View &,embedder::ScopedPlatformHandleVectorPtr)625   virtual void OnReadMessage(
626       const MessageInTransit::View& /*message_view*/,
627       embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
628     CHECK(false);  // Should not get called.
629   }
OnError(Error error)630   virtual void OnError(Error error) OVERRIDE {
631     EXPECT_FALSE(did_shutdown_);
632     if (error != shutdown_on_error_type_)
633       return;
634     raw_channel_->Shutdown();
635     did_shutdown_ = true;
636     done_event_.Signal();
637   }
638 
639   // Waits for shutdown.
Wait()640   void Wait() {
641     done_event_.Wait();
642     EXPECT_TRUE(did_shutdown_);
643   }
644 
645  private:
646   RawChannel* const raw_channel_;
647   const Error shutdown_on_error_type_;
648   base::WaitableEvent done_event_;
649   bool did_shutdown_;
650 
651   DISALLOW_COPY_AND_ASSIGN(ShutdownOnErrorRawChannelDelegate);
652 };
653 
// Tests shutting a |RawChannel| down from within |OnError()| for a read
// (shutdown) error.
TEST_F(RawChannelTest, ShutdownOnErrorRead) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(
      rc.get(), RawChannel::Delegate::ERROR_READ_SHUTDOWN);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Close the handle of the other end, which should make reading fail.
  handles[1].reset();

  // The delegate shuts the |RawChannel| down when the read error arrives.
  delegate.Wait();
}
668 
// Tests shutting a |RawChannel| down from within |OnError()| for a write
// error.
TEST_F(RawChannelTest, ShutdownOnErrorWrite) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(rc.get(),
                                             RawChannel::Delegate::ERROR_WRITE);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Close the handle of the other end, which should make writing fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // The delegate shuts the |RawChannel| down when the write error arrives.
  delegate.Wait();
}
685 
686 }  // namespace
687 }  // namespace system
688 }  // namespace mojo
689