1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <stddef.h>
6 #include <stdint.h>
7
8 #include <memory>
9
10 #include "base/memory/raw_ptr.h"
11 #include "base/message_loop/message_pump_type.h"
12 #include "base/notreached.h"
13 #include "base/pickle.h"
14 #include "base/run_loop.h"
15 #include "base/task/single_thread_task_runner.h"
16 #include "base/threading/thread.h"
17 #include "build/build_config.h"
18 #include "ipc/ipc_message.h"
19 #include "ipc/ipc_test_base.h"
20 #include "ipc/message_filter.h"
21
22 // Get basic type definitions.
23 #define IPC_MESSAGE_IMPL
24 #include "ipc/ipc_channel_proxy_unittest_messages.h"
25
26 // Generate constructors.
27 #include "ipc/struct_constructor_macros.h"
28 #include "ipc/ipc_channel_proxy_unittest_messages.h"
29
30 // Generate param traits write methods.
31 #include "ipc/param_traits_write_macros.h"
32 namespace IPC {
33 #include "ipc/ipc_channel_proxy_unittest_messages.h"
34 } // namespace IPC
35
36 // Generate param traits read methods.
37 #include "ipc/param_traits_read_macros.h"
38 namespace IPC {
39 #include "ipc/ipc_channel_proxy_unittest_messages.h"
40 } // namespace IPC
41
42 // Generate param traits log methods.
43 #include "ipc/param_traits_log_macros.h"
44 namespace IPC {
45 #include "ipc/ipc_channel_proxy_unittest_messages.h"
46 } // namespace IPC
47
48
49 namespace {
50
CreateRunLoopAndRun(raw_ptr<base::RunLoop> * run_loop_ptr)51 void CreateRunLoopAndRun(raw_ptr<base::RunLoop>* run_loop_ptr) {
52 base::RunLoop run_loop;
53 *run_loop_ptr = &run_loop;
54 run_loop.Run();
55 *run_loop_ptr = nullptr;
56 }
57
58 class QuitListener : public IPC::Listener {
59 public:
60 QuitListener() = default;
61
OnMessageReceived(const IPC::Message & message)62 bool OnMessageReceived(const IPC::Message& message) override {
63 IPC_BEGIN_MESSAGE_MAP(QuitListener, message)
64 IPC_MESSAGE_HANDLER(WorkerMsg_Quit, OnQuit)
65 IPC_MESSAGE_HANDLER(TestMsg_BadMessage, OnBadMessage)
66 IPC_END_MESSAGE_MAP()
67 return true;
68 }
69
OnBadMessageReceived(const IPC::Message & message)70 void OnBadMessageReceived(const IPC::Message& message) override {
71 bad_message_received_ = true;
72 }
73
OnChannelError()74 void OnChannelError() override { CHECK(quit_message_received_); }
75
OnQuit()76 void OnQuit() {
77 quit_message_received_ = true;
78 run_loop_->QuitWhenIdle();
79 }
80
OnBadMessage(const BadType & bad_type)81 void OnBadMessage(const BadType& bad_type) {
82 // Should never be called since IPC wouldn't be deserialized correctly.
83 NOTREACHED();
84 }
85
86 bool bad_message_received_ = false;
87 bool quit_message_received_ = false;
88 raw_ptr<base::RunLoop> run_loop_ = nullptr;
89 };
90
91 class ChannelReflectorListener : public IPC::Listener {
92 public:
93 ChannelReflectorListener() = default;
94
Init(IPC::Channel * channel)95 void Init(IPC::Channel* channel) {
96 DCHECK(!channel_);
97 channel_ = channel;
98 }
99
OnMessageReceived(const IPC::Message & message)100 bool OnMessageReceived(const IPC::Message& message) override {
101 IPC_BEGIN_MESSAGE_MAP(ChannelReflectorListener, message)
102 IPC_MESSAGE_HANDLER(TestMsg_Bounce, OnTestBounce)
103 IPC_MESSAGE_HANDLER(TestMsg_SendBadMessage, OnSendBadMessage)
104 IPC_MESSAGE_HANDLER(AutomationMsg_Bounce, OnAutomationBounce)
105 IPC_MESSAGE_HANDLER(WorkerMsg_Bounce, OnBounce)
106 IPC_MESSAGE_HANDLER(WorkerMsg_Quit, OnQuit)
107 IPC_END_MESSAGE_MAP()
108 return true;
109 }
110
OnTestBounce()111 void OnTestBounce() {
112 channel_->Send(new TestMsg_Bounce());
113 }
114
OnSendBadMessage()115 void OnSendBadMessage() {
116 channel_->Send(new TestMsg_BadMessage(BadType()));
117 }
118
OnAutomationBounce()119 void OnAutomationBounce() { channel_->Send(new AutomationMsg_Bounce()); }
120
OnBounce()121 void OnBounce() {
122 channel_->Send(new WorkerMsg_Bounce());
123 }
124
OnQuit()125 void OnQuit() {
126 channel_->Send(new WorkerMsg_Quit());
127 run_loop_->QuitWhenIdle();
128 }
129
130 raw_ptr<base::RunLoop> run_loop_ = nullptr;
131
132 private:
133 raw_ptr<IPC::Channel> channel_ = nullptr;
134 };
135
136 class MessageCountFilter : public IPC::MessageFilter {
137 public:
138 enum FilterEvent {
139 NONE,
140 FILTER_ADDED,
141 CHANNEL_CONNECTED,
142 CHANNEL_ERROR,
143 CHANNEL_CLOSING,
144 FILTER_REMOVED
145 };
146
147 MessageCountFilter() = default;
MessageCountFilter(uint32_t supported_message_class)148 MessageCountFilter(uint32_t supported_message_class)
149 : supported_message_class_(supported_message_class),
150 is_global_filter_(false) {}
151
OnFilterAdded(IPC::Channel * channel)152 void OnFilterAdded(IPC::Channel* channel) override {
153 EXPECT_TRUE(channel);
154 EXPECT_EQ(NONE, last_filter_event_);
155 last_filter_event_ = FILTER_ADDED;
156 }
157
OnChannelConnected(int32_t peer_pid)158 void OnChannelConnected(int32_t peer_pid) override {
159 EXPECT_EQ(FILTER_ADDED, last_filter_event_);
160 EXPECT_NE(static_cast<int32_t>(base::kNullProcessId), peer_pid);
161 last_filter_event_ = CHANNEL_CONNECTED;
162 }
163
OnChannelError()164 void OnChannelError() override {
165 EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
166 last_filter_event_ = CHANNEL_ERROR;
167 }
168
OnChannelClosing()169 void OnChannelClosing() override {
170 // We may or may not have gotten OnChannelError; if not, the last event has
171 // to be OnChannelConnected.
172 EXPECT_NE(FILTER_REMOVED, last_filter_event_);
173 if (last_filter_event_ != CHANNEL_ERROR)
174 EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
175 last_filter_event_ = CHANNEL_CLOSING;
176 }
177
OnFilterRemoved()178 void OnFilterRemoved() override {
179 // A filter may be removed at any time, even before the channel is connected
180 // (and thus before OnFilterAdded is ever able to dispatch.) The only time
181 // we won't see OnFilterRemoved is immediately after OnFilterAdded, because
182 // OnChannelConnected is always the next event to fire after that.
183 EXPECT_NE(FILTER_ADDED, last_filter_event_);
184 last_filter_event_ = FILTER_REMOVED;
185 }
186
OnMessageReceived(const IPC::Message & message)187 bool OnMessageReceived(const IPC::Message& message) override {
188 // We should always get the OnFilterAdded and OnChannelConnected events
189 // prior to any messages.
190 EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
191
192 if (!is_global_filter_) {
193 EXPECT_EQ(supported_message_class_, IPC_MESSAGE_CLASS(message));
194 }
195 ++messages_received_;
196
197 if (!message_filtering_enabled_)
198 return false;
199
200 bool handled = true;
201 IPC_BEGIN_MESSAGE_MAP(MessageCountFilter, message)
202 IPC_MESSAGE_HANDLER(TestMsg_BadMessage, OnBadMessage)
203 IPC_MESSAGE_UNHANDLED(handled = false)
204 IPC_END_MESSAGE_MAP()
205 return handled;
206 }
207
OnBadMessage(const BadType & bad_type)208 void OnBadMessage(const BadType& bad_type) {
209 // Should never be called since IPC wouldn't be deserialized correctly.
210 NOTREACHED();
211 }
212
GetSupportedMessageClasses(std::vector<uint32_t> * supported_message_classes) const213 bool GetSupportedMessageClasses(
214 std::vector<uint32_t>* supported_message_classes) const override {
215 if (is_global_filter_)
216 return false;
217 supported_message_classes->push_back(supported_message_class_);
218 return true;
219 }
220
set_message_filtering_enabled(bool enabled)221 void set_message_filtering_enabled(bool enabled) {
222 message_filtering_enabled_ = enabled;
223 }
224
messages_received() const225 size_t messages_received() const { return messages_received_; }
last_filter_event() const226 FilterEvent last_filter_event() const { return last_filter_event_; }
227
228 private:
229 ~MessageCountFilter() override = default;
230
231 size_t messages_received_ = 0;
232 uint32_t supported_message_class_ = 0;
233 bool is_global_filter_ = true;
234
235 FilterEvent last_filter_event_ = NONE;
236 bool message_filtering_enabled_ = false;
237 };
238
239 class IPCChannelProxyTest : public IPCChannelMojoTestBase {
240 public:
241 IPCChannelProxyTest() = default;
242 ~IPCChannelProxyTest() override = default;
243
SetUp()244 void SetUp() override {
245 IPCChannelMojoTestBase::SetUp();
246
247 Init("ChannelProxyClient");
248
249 thread_ = std::make_unique<base::Thread>("ChannelProxyTestServerThread");
250 base::Thread::Options options;
251 options.message_pump_type = base::MessagePumpType::IO;
252 thread_->StartWithOptions(std::move(options));
253
254 listener_ = std::make_unique<QuitListener>();
255 channel_proxy_ = IPC::ChannelProxy::Create(
256 TakeHandle().release(), IPC::Channel::MODE_SERVER, listener_.get(),
257 thread_->task_runner(),
258 base::SingleThreadTaskRunner::GetCurrentDefault());
259 }
260
TearDown()261 void TearDown() override {
262 channel_proxy_.reset();
263 thread_.reset();
264 listener_.reset();
265 IPCChannelMojoTestBase::TearDown();
266 }
267
SendQuitMessageAndWaitForIdle()268 void SendQuitMessageAndWaitForIdle() {
269 sender()->Send(new WorkerMsg_Quit);
270 CreateRunLoopAndRun(&listener_->run_loop_);
271 EXPECT_TRUE(WaitForClientShutdown());
272 }
273
DidListenerGetBadMessage()274 bool DidListenerGetBadMessage() {
275 return listener_->bad_message_received_;
276 }
277
channel_proxy()278 IPC::ChannelProxy* channel_proxy() { return channel_proxy_.get(); }
sender()279 IPC::Sender* sender() { return channel_proxy_.get(); }
280
281 private:
282 std::unique_ptr<base::Thread> thread_;
283 std::unique_ptr<QuitListener> listener_;
284 std::unique_ptr<IPC::ChannelProxy> channel_proxy_;
285 };
286
// Filters registered for one message class must only see messages of that
// class.
TEST_F(IPCChannelProxyTest, MessageClassFilters) {
  // One filter for the Test class and one for the Automation class.
  std::vector<scoped_refptr<MessageCountFilter>> class_filters;
  class_filters.push_back(
      base::MakeRefCounted<MessageCountFilter>(TestMsgStart));
  class_filters.push_back(
      base::MakeRefCounted<MessageCountFilter>(AutomationMsgStart));
  for (const auto& filter : class_filters) {
    channel_proxy()->AddFilter(filter.get());
  }

  // One message per filtered class.
  sender()->Send(new TestMsg_Bounce);
  sender()->Send(new AutomationMsg_Bounce);

  // A message of a class that neither filter was registered for.
  sender()->Send(new WorkerMsg_Bounce);

  // After flushing, every filter must have counted exactly the one message of
  // its own class.
  SendQuitMessageAndWaitForIdle();
  for (const auto& filter : class_filters) {
    EXPECT_EQ(1U, filter->messages_received());
  }
}
310
// A global filter sees every message, while a class filter added alongside it
// only sees messages of its own class.
TEST_F(IPCChannelProxyTest, GlobalAndMessageClassFilters) {
  // Class filter for Test messages; it counts but does not run its message
  // map.
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(class_filter.get());

  // Global filter, configured the same way.
  auto global_filter = base::MakeRefCounted<MessageCountFilter>();
  global_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(global_filter.get());

  // Seen by both the global filter and the Test-specific filter.
  sender()->Send(new TestMsg_Bounce);

  // Seen only by the global filter.
  sender()->Send(new AutomationMsg_Bounce);

  // Flush all messages.
  SendQuitMessageAndWaitForIdle();

  // Only the class-specific message reached the class filter.
  EXPECT_EQ(1U, class_filter->messages_received());

  // Both messages plus the final QUIT message reached the global filter.
  EXPECT_EQ(3U, global_filter->messages_received());
}
339
// Filters that are removed again right after being added must not be offered
// any message and must end in the FILTER_REMOVED state.
TEST_F(IPCChannelProxyTest, FilterRemoval) {
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  auto global_filter = base::MakeRefCounted<MessageCountFilter>();

  // Add both kinds of filter, then remove them in the reverse order.
  channel_proxy()->AddFilter(class_filter.get());
  channel_proxy()->AddFilter(global_filter.get());
  channel_proxy()->RemoveFilter(global_filter.get());
  channel_proxy()->RemoveFilter(class_filter.get());

  // Neither of these should be seen by either filter.
  sender()->Send(new TestMsg_Bounce);
  sender()->Send(new AutomationMsg_Bounce);

  // Flush, then verify the removal notification and the empty counters.
  SendQuitMessageAndWaitForIdle();
  EXPECT_EQ(MessageCountFilter::FILTER_REMOVED,
            global_filter->last_filter_event());
  EXPECT_EQ(MessageCountFilter::FILTER_REMOVED,
            class_filter->last_filter_event());
  EXPECT_EQ(0U, class_filter->messages_received());
  EXPECT_EQ(0U, global_filter->messages_received());
}
365
// With message filtering disabled, the Test-class filter returns false for the
// bad message instead of running its message map. The listener is expected to
// report the bad message.
TEST_F(IPCChannelProxyTest, BadMessageOnListenerThread) {
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(class_filter.get());

  // Ask the client to reply with a TestMsg_BadMessage.
  sender()->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();
  EXPECT_TRUE(DidListenerGetBadMessage());
}
377
// With message filtering enabled, the Test-class filter runs its own message
// map on the bad message. The listener is still expected to be told about the
// bad message.
TEST_F(IPCChannelProxyTest, BadMessageOnIPCThread) {
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(true);
  channel_proxy()->AddFilter(class_filter.get());

  // Ask the client to reply with a TestMsg_BadMessage.
  sender()->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();
  EXPECT_TRUE(DidListenerGetBadMessage());
}
389
390 class IPCChannelBadMessageTest : public IPCChannelMojoTestBase {
391 public:
SetUp()392 void SetUp() override {
393 IPCChannelMojoTestBase::SetUp();
394
395 Init("ChannelProxyClient");
396
397 listener_ = std::make_unique<QuitListener>();
398 CreateChannel(listener_.get());
399 ASSERT_TRUE(ConnectChannel());
400 }
401
TearDown()402 void TearDown() override {
403 IPCChannelMojoTestBase::TearDown();
404 listener_.reset();
405 }
406
SendQuitMessageAndWaitForIdle()407 void SendQuitMessageAndWaitForIdle() {
408 sender()->Send(new WorkerMsg_Quit);
409 CreateRunLoopAndRun(&listener_->run_loop_);
410 EXPECT_TRUE(WaitForClientShutdown());
411 }
412
DidListenerGetBadMessage()413 bool DidListenerGetBadMessage() {
414 return listener_->bad_message_received_;
415 }
416
417 private:
418 std::unique_ptr<QuitListener> listener_;
419 };
420
// The listener must be told about a bad message when no ChannelProxy is
// involved.
TEST_F(IPCChannelBadMessageTest, BadMessage) {
  // Ask the client to reply with a TestMsg_BadMessage.
  sender()->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();

  EXPECT_TRUE(DidListenerGetBadMessage());
}
426
// Test client: connects with a ChannelReflectorListener, runs a loop until
// that listener quits it, and then closes the connection.
DEFINE_IPC_CHANNEL_MOJO_TEST_CLIENT(ChannelProxyClient) {
  ChannelReflectorListener reflector;
  Connect(&reflector);
  // The reflector sends its replies over the channel that was just connected.
  reflector.Init(channel());

  CreateRunLoopAndRun(&reflector.run_loop_);

  Close();
}
436
437 } // namespace
438