1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/core/channel.h"
6
7 #include "base/bind.h"
8 #include "base/memory/ptr_util.h"
9 #include "base/message_loop/message_loop.h"
10 #include "base/threading/thread.h"
11 #include "mojo/core/platform_handle_utils.h"
12 #include "mojo/public/cpp/platform/platform_channel.h"
13 #include "testing/gmock/include/gmock/gmock.h"
14 #include "testing/gtest/include/gtest/gtest.h"
15
16 namespace mojo {
17 namespace core {
18 namespace {
19
20 class TestChannel : public Channel {
21 public:
TestChannel(Channel::Delegate * delegate)22 TestChannel(Channel::Delegate* delegate) : Channel(delegate) {}
23
GetReadBufferTest(size_t * buffer_capacity)24 char* GetReadBufferTest(size_t* buffer_capacity) {
25 return GetReadBuffer(buffer_capacity);
26 }
27
OnReadCompleteTest(size_t bytes_read,size_t * next_read_size_hint)28 bool OnReadCompleteTest(size_t bytes_read, size_t* next_read_size_hint) {
29 return OnReadComplete(bytes_read, next_read_size_hint);
30 }
31
32 MOCK_METHOD7(GetReadPlatformHandles,
33 bool(const void* payload,
34 size_t payload_size,
35 size_t num_handles,
36 const void* extra_header,
37 size_t extra_header_size,
38 std::vector<PlatformHandle>* handles,
39 bool* deferred));
40 MOCK_METHOD0(Start, void());
41 MOCK_METHOD0(ShutDownImpl, void());
42 MOCK_METHOD0(LeakHandle, void());
43
Write(MessagePtr message)44 void Write(MessagePtr message) override {}
45
46 protected:
~TestChannel()47 ~TestChannel() override {}
48 };
49
// Not using GMock for this delegate because it may not support move-only
// argument types such as std::vector<PlatformHandle>.
51 class MockChannelDelegate : public Channel::Delegate {
52 public:
MockChannelDelegate()53 MockChannelDelegate() {}
54
GetReceivedPayloadSize() const55 size_t GetReceivedPayloadSize() const { return payload_size_; }
56
GetReceivedPayload() const57 const void* GetReceivedPayload() const { return payload_.get(); }
58
59 protected:
OnChannelMessage(const void * payload,size_t payload_size,std::vector<PlatformHandle> handles)60 void OnChannelMessage(const void* payload,
61 size_t payload_size,
62 std::vector<PlatformHandle> handles) override {
63 payload_.reset(new char[payload_size]);
64 memcpy(payload_.get(), payload, payload_size);
65 payload_size_ = payload_size;
66 }
67
68 // Notify that an error has occured and the Channel will cease operation.
OnChannelError(Channel::Error error)69 void OnChannelError(Channel::Error error) override {}
70
71 private:
72 size_t payload_size_ = 0;
73 std::unique_ptr<char[]> payload_;
74 };
75
CreateDefaultMessage(bool legacy_message)76 Channel::MessagePtr CreateDefaultMessage(bool legacy_message) {
77 const size_t payload_size = 100;
78 Channel::MessagePtr message = std::make_unique<Channel::Message>(
79 payload_size, 0,
80 legacy_message ? Channel::Message::MessageType::NORMAL_LEGACY
81 : Channel::Message::MessageType::NORMAL);
82 char* payload = static_cast<char*>(message->mutable_payload());
83 for (size_t i = 0; i < payload_size; i++) {
84 payload[i] = static_cast<char>(i);
85 }
86 return message;
87 }
88
TestMemoryEqual(const void * data1,size_t data1_size,const void * data2,size_t data2_size)89 void TestMemoryEqual(const void* data1,
90 size_t data1_size,
91 const void* data2,
92 size_t data2_size) {
93 ASSERT_EQ(data1_size, data2_size);
94 const unsigned char* data1_char = static_cast<const unsigned char*>(data1);
95 const unsigned char* data2_char = static_cast<const unsigned char*>(data2);
96 for (size_t i = 0; i < data1_size; i++) {
97 // ASSERT so we don't log tons of errors if the data is different.
98 ASSERT_EQ(data1_char[i], data2_char[i]);
99 }
100 }
101
TestMessagesAreEqual(Channel::Message * message1,Channel::Message * message2,bool legacy_messages)102 void TestMessagesAreEqual(Channel::Message* message1,
103 Channel::Message* message2,
104 bool legacy_messages) {
105 // If any of the message is null, this is probably not what you wanted to
106 // test.
107 ASSERT_NE(nullptr, message1);
108 ASSERT_NE(nullptr, message2);
109
110 ASSERT_EQ(message1->payload_size(), message2->payload_size());
111 EXPECT_EQ(message1->has_handles(), message2->has_handles());
112
113 TestMemoryEqual(message1->payload(), message1->payload_size(),
114 message2->payload(), message2->payload_size());
115
116 if (legacy_messages)
117 return;
118
119 ASSERT_EQ(message1->extra_header_size(), message2->extra_header_size());
120 TestMemoryEqual(message1->extra_header(), message1->extra_header_size(),
121 message2->extra_header(), message2->extra_header_size());
122 }
123
// Serializes a legacy message, deserializes its raw bytes, and checks that the
// result matches the original.
TEST(ChannelTest, LegacyMessageDeserialization) {
  Channel::MessagePtr original =
      CreateDefaultMessage(true /* legacy_message */);
  const size_t serialized_size = original->data_num_bytes();

  Channel::MessagePtr round_tripped =
      Channel::Message::Deserialize(original->data(), serialized_size);

  TestMessagesAreEqual(original.get(), round_tripped.get(),
                       true /* legacy_message */);
}
131
// Serializes a non-legacy message, deserializes its raw bytes, and checks that
// the result (including the extra header) matches the original.
TEST(ChannelTest, NonLegacyMessageDeserialization) {
  Channel::MessagePtr original =
      CreateDefaultMessage(false /* legacy_message */);
  const size_t serialized_size = original->data_num_bytes();

  Channel::MessagePtr round_tripped =
      Channel::Message::Deserialize(original->data(), serialized_size);

  TestMessagesAreEqual(original.get(), round_tripped.get(),
                       false /* legacy_message */);
}
140
// Copies a serialized legacy message into the channel's read buffer, completes
// the read, and checks that the delegate received the same payload.
TEST(ChannelTest, OnReadLegacyMessage) {
  size_t capacity = 100 * 1024;
  Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */);
  const size_t serialized_size = message->data_num_bytes();

  MockChannelDelegate delegate;
  scoped_refptr<TestChannel> channel = new TestChannel(&delegate);
  char* buffer = channel->GetReadBufferTest(&capacity);
  // A failure here means the test itself is wrong: request a bigger buffer.
  ASSERT_LT(serialized_size, capacity);
  memcpy(buffer, message->data(), serialized_size);

  size_t read_size_hint = 0;
  EXPECT_TRUE(channel->OnReadCompleteTest(serialized_size, &read_size_hint));

  TestMemoryEqual(message->payload(), message->payload_size(),
                  delegate.GetReceivedPayload(),
                  delegate.GetReceivedPayloadSize());
}
161
// Copies a serialized non-legacy message into the channel's read buffer,
// completes the read, and checks that the delegate received the same payload.
TEST(ChannelTest, OnReadNonLegacyMessage) {
  size_t capacity = 100 * 1024;
  Channel::MessagePtr message =
      CreateDefaultMessage(false /* legacy_message */);
  const size_t serialized_size = message->data_num_bytes();

  MockChannelDelegate delegate;
  scoped_refptr<TestChannel> channel = new TestChannel(&delegate);
  char* buffer = channel->GetReadBufferTest(&capacity);
  // A failure here means the test itself is wrong: request a bigger buffer.
  ASSERT_LT(serialized_size, capacity);
  memcpy(buffer, message->data(), serialized_size);

  size_t read_size_hint = 0;
  EXPECT_TRUE(channel->OnReadCompleteTest(serialized_size, &read_size_hint));

  TestMemoryEqual(message->payload(), message->payload_size(),
                  delegate.GetReceivedPayload(),
                  delegate.GetReceivedPayloadSize());
}
183
184 class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate {
185 public:
ChannelTestShutdownAndWriteDelegate(PlatformChannelEndpoint endpoint,scoped_refptr<base::TaskRunner> task_runner,scoped_refptr<Channel> client_channel,std::unique_ptr<base::Thread> client_thread,base::RepeatingClosure quit_closure)186 ChannelTestShutdownAndWriteDelegate(
187 PlatformChannelEndpoint endpoint,
188 scoped_refptr<base::TaskRunner> task_runner,
189 scoped_refptr<Channel> client_channel,
190 std::unique_ptr<base::Thread> client_thread,
191 base::RepeatingClosure quit_closure)
192 : quit_closure_(std::move(quit_closure)),
193 client_channel_(std::move(client_channel)),
194 client_thread_(std::move(client_thread)) {
195 channel_ = Channel::Create(this, ConnectionParams(std::move(endpoint)),
196 std::move(task_runner));
197 channel_->Start();
198 }
~ChannelTestShutdownAndWriteDelegate()199 ~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); }
200
201 // Channel::Delegate implementation
OnChannelMessage(const void * payload,size_t payload_size,std::vector<PlatformHandle> handles)202 void OnChannelMessage(const void* payload,
203 size_t payload_size,
204 std::vector<PlatformHandle> handles) override {
205 ++message_count_;
206
207 // If |client_channel_| exists then close it and its thread.
208 if (client_channel_) {
209 // Write a fresh message, making our channel readable again.
210 Channel::MessagePtr message = CreateDefaultMessage(false);
211 client_thread_->task_runner()->PostTask(
212 FROM_HERE, base::BindOnce(&Channel::Write, client_channel_,
213 base::Passed(&message)));
214
215 // Close the channel and wait for it to shutdown.
216 client_channel_->ShutDown();
217 client_channel_ = nullptr;
218
219 client_thread_->Stop();
220 client_thread_ = nullptr;
221 }
222
223 // Write a message to the channel, to verify whether this triggers an
224 // OnChannelError callback before all messages were read.
225 Channel::MessagePtr message = CreateDefaultMessage(false);
226 channel_->Write(std::move(message));
227 }
228
OnChannelError(Channel::Error error)229 void OnChannelError(Channel::Error error) override {
230 EXPECT_EQ(2, message_count_);
231 quit_closure_.Run();
232 }
233
234 base::RepeatingClosure quit_closure_;
235 int message_count_ = 0;
236 scoped_refptr<Channel> channel_;
237
238 scoped_refptr<Channel> client_channel_;
239 std::unique_ptr<base::Thread> client_thread_;
240 };
241
// Checks that when the peer ("client") shuts down while the "server" side
// still has data to read, the server's delegate sees both messages before the
// channel error is reported (ChannelTestShutdownAndWriteDelegate expects a
// count of two in OnChannelError).
TEST(ChannelTest, PeerShutdownDuringRead) {
  base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
  PlatformChannel channel;

  // Create a "client" Channel with one end of the pipe, and Start() it.
  std::unique_ptr<base::Thread> client_thread =
      std::make_unique<base::Thread>("clientio_thread");
  ASSERT_TRUE(client_thread->StartWithOptions(
      base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));

  scoped_refptr<Channel> client_channel =
      Channel::Create(nullptr, ConnectionParams(channel.TakeRemoteEndpoint()),
                      client_thread->task_runner());
  client_channel->Start();

  // On the "client" IO thread, create and write a message. The move-only
  // message is bound directly; base::Passed is not needed with BindOnce.
  Channel::MessagePtr message = CreateDefaultMessage(false);
  client_thread->task_runner()->PostTask(
      FROM_HERE,
      base::BindOnce(&Channel::Write, client_channel, std::move(message)));

  // Create a "server" Channel with the other end of the pipe, and process the
  // messages from it. The |server_delegate| will ShutDown the client end of
  // the pipe after the first message, and quit the RunLoop when OnChannelError
  // is received.
  base::RunLoop run_loop;
  ChannelTestShutdownAndWriteDelegate server_delegate(
      channel.TakeLocalEndpoint(), message_loop.task_runner(),
      std::move(client_channel), std::move(client_thread),
      run_loop.QuitClosure());

  run_loop.Run();
}
275
276 } // namespace
277 } // namespace core
278 } // namespace mojo
279