1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/channel_multiplexer.h"
6
7 #include "base/bind.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/run_loop.h"
10 #include "net/base/net_errors.h"
11 #include "net/socket/socket.h"
12 #include "net/socket/stream_socket.h"
13 #include "remoting/base/constants.h"
14 #include "remoting/protocol/connection_tester.h"
15 #include "remoting/protocol/fake_session.h"
16 #include "testing/gmock/include/gmock/gmock.h"
17 #include "testing/gtest/include/gtest/gtest.h"
18
19 using testing::_;
20 using testing::AtMost;
21 using testing::InvokeWithoutArgs;
22
23 namespace remoting {
24 namespace protocol {
25
26 namespace {
27
// Base channel name handed to every ChannelMultiplexer created by the fixture.
const char kMuxChannelName[] = "mux";

// Names used for two of the multiplexed channels created by the tests.
const char kTestChannelName[] = "test";
const char kTestChannelName2[] = "test2";

// Message size and message count passed to each StreamConnectionTester.
const int kMessageSize = 1024;
const int kMessages = 100;
34
35
QuitCurrentThread()36 void QuitCurrentThread() {
37 base::MessageLoop::current()->PostTask(FROM_HERE,
38 base::MessageLoop::QuitClosure());
39 }
40
41 class MockSocketCallback {
42 public:
43 MOCK_METHOD1(OnDone, void(int result));
44 };
45
46 class MockConnectCallback {
47 public:
48 MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
OnConnected(scoped_ptr<net::StreamSocket> socket)49 void OnConnected(scoped_ptr<net::StreamSocket> socket) {
50 OnConnectedPtr(socket.release());
51 }
52 };
53
54 } // namespace
55
56 class ChannelMultiplexerTest : public testing::Test {
57 public:
DeleteAll()58 void DeleteAll() {
59 host_socket1_.reset();
60 host_socket2_.reset();
61 client_socket1_.reset();
62 client_socket2_.reset();
63 host_mux_.reset();
64 client_mux_.reset();
65 }
66
DeleteAfterSessionFail()67 void DeleteAfterSessionFail() {
68 host_mux_->CancelChannelCreation(kTestChannelName2);
69 DeleteAll();
70 }
71
72 protected:
SetUp()73 virtual void SetUp() OVERRIDE {
74 // Create pair of multiplexers and connect them to each other.
75 host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
76 client_mux_.reset(new ChannelMultiplexer(&client_session_,
77 kMuxChannelName));
78 }
79
80 // Connect sockets to each other. Must be called after we've created at least
81 // one channel with each multiplexer.
ConnectSockets()82 void ConnectSockets() {
83 FakeSocket* host_socket =
84 host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
85 FakeSocket* client_socket =
86 client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
87 host_socket->PairWith(client_socket);
88
89 // Make writes asynchronous in one direction.
90 host_socket->set_async_write(true);
91 }
92
CreateChannel(const std::string & name,scoped_ptr<net::StreamSocket> * host_socket,scoped_ptr<net::StreamSocket> * client_socket)93 void CreateChannel(const std::string& name,
94 scoped_ptr<net::StreamSocket>* host_socket,
95 scoped_ptr<net::StreamSocket>* client_socket) {
96 int counter = 2;
97 host_mux_->CreateStreamChannel(name, base::Bind(
98 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
99 host_socket, &counter));
100 client_mux_->CreateStreamChannel(name, base::Bind(
101 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
102 client_socket, &counter));
103
104 message_loop_.Run();
105
106 EXPECT_TRUE(host_socket->get());
107 EXPECT_TRUE(client_socket->get());
108 }
109
OnChannelConnected(scoped_ptr<net::StreamSocket> * storage,int * counter,scoped_ptr<net::StreamSocket> socket)110 void OnChannelConnected(
111 scoped_ptr<net::StreamSocket>* storage,
112 int* counter,
113 scoped_ptr<net::StreamSocket> socket) {
114 *storage = socket.Pass();
115 --(*counter);
116 EXPECT_GE(*counter, 0);
117 if (*counter == 0)
118 QuitCurrentThread();
119 }
120
CreateTestBuffer(int size)121 scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
122 scoped_refptr<net::IOBufferWithSize> result =
123 new net::IOBufferWithSize(size);
124 for (int i = 0; i< size; ++i) {
125 result->data()[i] = rand() % 256;
126 }
127 return result;
128 }
129
130 base::MessageLoop message_loop_;
131
132 FakeSession host_session_;
133 FakeSession client_session_;
134
135 scoped_ptr<ChannelMultiplexer> host_mux_;
136 scoped_ptr<ChannelMultiplexer> client_mux_;
137
138 scoped_ptr<net::StreamSocket> host_socket1_;
139 scoped_ptr<net::StreamSocket> client_socket1_;
140 scoped_ptr<net::StreamSocket> host_socket2_;
141 scoped_ptr<net::StreamSocket> client_socket2_;
142 };
143
144
// A single multiplexed channel delivers the StreamConnectionTester traffic
// intact.
TEST_F(ChannelMultiplexerTest, OneChannel) {
  scoped_ptr<net::StreamSocket> host;
  scoped_ptr<net::StreamSocket> client;
  ASSERT_NO_FATAL_FAILURE(CreateChannel(kTestChannelName, &host, &client));

  ConnectSockets();

  StreamConnectionTester connection_tester(host.get(), client.get(),
                                           kMessageSize, kMessages);
  connection_tester.Start();
  message_loop_.Run();
  connection_tester.CheckResults();
}
159
// Two multiplexed channels over the same base channel, with both testers
// started before the message loop runs.
TEST_F(ChannelMultiplexerTest, TwoChannels) {
  // These are locals, so they are named without the trailing underscore; the
  // previous member-style names shadowed the fixture members.
  scoped_ptr<net::StreamSocket> host_socket1;
  scoped_ptr<net::StreamSocket> client_socket1;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1, &client_socket1));

  scoped_ptr<net::StreamSocket> host_socket2;
  scoped_ptr<net::StreamSocket> client_socket2;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2, &client_socket2));

  ConnectSockets();

  StreamConnectionTester tester1(host_socket1.get(), client_socket1.get(),
                                 kMessageSize, kMessages);
  StreamConnectionTester tester2(host_socket2.get(), client_socket2.get(),
                                 kMessageSize, kMessages);
  tester1.Start();
  tester2.Start();
  // Keep spinning the loop until both testers report completion.
  while (!tester1.done() || !tester2.done()) {
    message_loop_.Run();
  }
  tester1.CheckResults();
  tester2.CheckResults();
}
185
186 // Four channels, two in each direction
TEST_F(ChannelMultiplexerTest,FourChannels)187 TEST_F(ChannelMultiplexerTest, FourChannels) {
188 scoped_ptr<net::StreamSocket> host_socket1_;
189 scoped_ptr<net::StreamSocket> client_socket1_;
190 ASSERT_NO_FATAL_FAILURE(
191 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
192
193 scoped_ptr<net::StreamSocket> host_socket2_;
194 scoped_ptr<net::StreamSocket> client_socket2_;
195 ASSERT_NO_FATAL_FAILURE(
196 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
197
198 scoped_ptr<net::StreamSocket> host_socket3;
199 scoped_ptr<net::StreamSocket> client_socket3;
200 ASSERT_NO_FATAL_FAILURE(
201 CreateChannel("test3", &host_socket3, &client_socket3));
202
203 scoped_ptr<net::StreamSocket> host_socket4;
204 scoped_ptr<net::StreamSocket> client_socket4;
205 ASSERT_NO_FATAL_FAILURE(
206 CreateChannel("ch4", &host_socket4, &client_socket4));
207
208 ConnectSockets();
209
210 StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
211 kMessageSize, kMessages);
212 StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
213 kMessageSize, kMessages);
214 StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
215 kMessageSize, kMessages);
216 StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
217 kMessageSize, kMessages);
218 tester1.Start();
219 tester2.Start();
220 tester3.Start();
221 tester4.Start();
222 while (!tester1.done() || !tester2.done() ||
223 !tester3.done() || !tester4.done()) {
224 message_loop_.Run();
225 }
226 tester1.CheckResults();
227 tester2.CheckResults();
228 tester3.CheckResults();
229 tester4.CheckResults();
230 }
231
// A synchronous write error on the host's base channel is returned directly
// from Write() on every multiplexed channel, and no completion callback runs.
TEST_F(ChannelMultiplexerTest, WriteFailSync) {
  // Locals are named without the trailing underscore so they do not shadow
  // the fixture members of the same name.
  scoped_ptr<net::StreamSocket> host_socket1;
  scoped_ptr<net::StreamSocket> client_socket1;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1, &client_socket1));

  scoped_ptr<net::StreamSocket> host_socket2;
  scoped_ptr<net::StreamSocket> client_socket2;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2, &client_socket2));

  ConnectSockets();

  // Make the next base-channel write fail, and complete synchronously
  // (ConnectSockets() had switched host writes to async).
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_next_write_error(net::ERR_FAILED);
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_async_write(false);

  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);

  MockSocketCallback cb1;
  MockSocketCallback cb2;

  EXPECT_CALL(cb1, OnDone(_))
      .Times(0);
  EXPECT_CALL(cb2, OnDone(_))
      .Times(0);

  EXPECT_EQ(net::ERR_FAILED,
            host_socket1->Write(buf.get(),
                                buf->size(),
                                base::Bind(&MockSocketCallback::OnDone,
                                           base::Unretained(&cb1))));
  EXPECT_EQ(net::ERR_FAILED,
            host_socket2->Write(buf.get(),
                                buf->size(),
                                base::Bind(&MockSocketCallback::OnDone,
                                           base::Unretained(&cb2))));

  // Drain any posted tasks so a stray callback would be caught.
  base::RunLoop().RunUntilIdle();
}
273
// An asynchronous write error on the host's base channel makes Write() return
// ERR_IO_PENDING and later reports ERR_FAILED to each channel's callback.
TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));

  ConnectSockets();

  // Arrange for the next write on the host's base channel to fail
  // asynchronously.
  FakeSocket* base_channel = host_session_.GetStreamChannel(kMuxChannelName);
  base_channel->set_next_write_error(net::ERR_FAILED);
  base_channel->set_async_write(true);

  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);

  MockSocketCallback cb1;
  MockSocketCallback cb2;
  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));

  const int first_result = host_socket1_->Write(
      buf.get(), buf->size(),
      base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb1)));
  EXPECT_EQ(net::ERR_IO_PENDING, first_result);

  const int second_result = host_socket2_->Write(
      buf.get(), buf->size(),
      base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb2)));
  EXPECT_EQ(net::ERR_IO_PENDING, second_result);

  base::RunLoop().RunUntilIdle();
}
308
// Destroying every socket and both multiplexers from inside a failed write's
// completion callback. Whichever callback runs first calls DeleteAll(); the
// other may then never run, hence AtMost(1).
TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));

  ConnectSockets();

  host_session_.GetStreamChannel(kMuxChannelName)->
      set_next_write_error(net::ERR_FAILED);
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_async_write(true);

  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);

  MockSocketCallback cb1;
  MockSocketCallback cb2;

  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));

  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket1_->Write(buf.get(),
                                 buf->size(),
                                 base::Bind(&MockSocketCallback::OnDone,
                                            base::Unretained(&cb1))));
  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket2_->Write(buf.get(),
                                 buf->size(),
                                 base::Bind(&MockSocketCallback::OnDone,
                                            base::Unretained(&cb2))));

  base::RunLoop().RunUntilIdle();

  // Check that DeleteAll() ran: both multiplexers and all sockets are gone.
  EXPECT_FALSE(host_mux_.get());
  EXPECT_FALSE(client_mux_.get());
  EXPECT_FALSE(host_socket1_.get());
  EXPECT_FALSE(host_socket2_.get());
  EXPECT_FALSE(client_socket1_.get());
  EXPECT_FALSE(client_socket2_.get());
}
350
// With the host session configured to fail (asynchronously) with
// AUTHENTICATION_FAILED, the first pending channel request is expected to
// complete with a NULL socket. Its handler cancels the second request and
// deletes everything, so the second callback must never run.
TEST_F(ChannelMultiplexerTest, SessionFail) {
  host_session_.set_async_creation(true);
  host_session_.set_error(AUTHENTICATION_FAILED);

  MockConnectCallback cb1;
  MockConnectCallback cb2;

  // gMock requires expectations to be in place before the mock can be
  // invoked, so set them before handing the callbacks to the multiplexer.
  EXPECT_CALL(cb1, OnConnectedPtr(NULL))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(
          this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
  EXPECT_CALL(cb2, OnConnectedPtr(_))
      .Times(0);

  host_mux_->CreateStreamChannel(kTestChannelName, base::Bind(
      &MockConnectCallback::OnConnected, base::Unretained(&cb1)));
  host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind(
      &MockConnectCallback::OnConnected, base::Unretained(&cb2)));

  base::RunLoop().RunUntilIdle();
}
372
373 } // namespace protocol
374 } // namespace remoting
375