1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/tools/quic/quic_dispatcher.h"
6
7 #include <string>
8
9 #include "base/strings/string_piece.h"
10 #include "net/quic/crypto/crypto_handshake.h"
11 #include "net/quic/crypto/quic_crypto_server_config.h"
12 #include "net/quic/crypto/quic_random.h"
13 #include "net/quic/quic_crypto_stream.h"
14 #include "net/quic/quic_flags.h"
15 #include "net/quic/quic_utils.h"
16 #include "net/quic/test_tools/quic_test_utils.h"
17 #include "net/tools/epoll_server/epoll_server.h"
18 #include "net/tools/quic/quic_packet_writer_wrapper.h"
19 #include "net/tools/quic/quic_time_wait_list_manager.h"
20 #include "net/tools/quic/test_tools/quic_dispatcher_peer.h"
21 #include "net/tools/quic/test_tools/quic_test_utils.h"
22 #include "testing/gmock/include/gmock/gmock.h"
23 #include "testing/gtest/include/gtest/gtest.h"
24
25 using base::StringPiece;
26 using net::EpollServer;
27 using net::test::ConstructEncryptedPacket;
28 using net::test::MockSession;
29 using net::test::ValueRestore;
30 using net::tools::test::MockConnection;
31 using std::make_pair;
32 using testing::DoAll;
33 using testing::InSequence;
34 using testing::Invoke;
35 using testing::WithoutArgs;
36 using testing::_;
37
38 namespace net {
39 namespace tools {
40 namespace test {
41 namespace {
42
43 class TestDispatcher : public QuicDispatcher {
44 public:
TestDispatcher(const QuicConfig & config,const QuicCryptoServerConfig & crypto_config,EpollServer * eps)45 explicit TestDispatcher(const QuicConfig& config,
46 const QuicCryptoServerConfig& crypto_config,
47 EpollServer* eps)
48 : QuicDispatcher(config,
49 crypto_config,
50 QuicSupportedVersions(),
51 new QuicDispatcher::DefaultPacketWriterFactory(),
52 eps) {
53 }
54
55 MOCK_METHOD3(CreateQuicSession, QuicSession*(
56 QuicConnectionId connection_id,
57 const IPEndPoint& server_address,
58 const IPEndPoint& client_address));
59
60 using QuicDispatcher::current_server_address;
61 using QuicDispatcher::current_client_address;
62 };
63
64 // A Connection class which unregisters the session from the dispatcher
65 // when sending connection close.
66 // It'd be slightly more realistic to do this from the Session but it would
67 // involve a lot more mocking.
68 class MockServerConnection : public MockConnection {
69 public:
MockServerConnection(QuicConnectionId connection_id,QuicDispatcher * dispatcher)70 MockServerConnection(QuicConnectionId connection_id,
71 QuicDispatcher* dispatcher)
72 : MockConnection(connection_id, true),
73 dispatcher_(dispatcher) {}
74
UnregisterOnConnectionClosed()75 void UnregisterOnConnectionClosed() {
76 LOG(ERROR) << "Unregistering " << connection_id();
77 dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR);
78 }
79 private:
80 QuicDispatcher* dispatcher_;
81 };
82
CreateSession(QuicDispatcher * dispatcher,QuicConnectionId connection_id,const IPEndPoint & client_address,MockSession ** session)83 QuicSession* CreateSession(QuicDispatcher* dispatcher,
84 QuicConnectionId connection_id,
85 const IPEndPoint& client_address,
86 MockSession** session) {
87 MockServerConnection* connection =
88 new MockServerConnection(connection_id, dispatcher);
89 *session = new MockSession(connection);
90 ON_CALL(*connection, SendConnectionClose(_)).WillByDefault(
91 WithoutArgs(Invoke(
92 connection, &MockServerConnection::UnregisterOnConnectionClosed)));
93 EXPECT_CALL(*reinterpret_cast<MockConnection*>((*session)->connection()),
94 ProcessUdpPacket(_, client_address, _));
95
96 return *session;
97 }
98
99 class QuicDispatcherTest : public ::testing::Test {
100 public:
QuicDispatcherTest()101 QuicDispatcherTest()
102 : crypto_config_(QuicCryptoServerConfig::TESTING,
103 QuicRandom::GetInstance()),
104 dispatcher_(config_, crypto_config_, &eps_),
105 session1_(NULL),
106 session2_(NULL) {
107 dispatcher_.Initialize(1);
108 }
109
~QuicDispatcherTest()110 virtual ~QuicDispatcherTest() {}
111
connection1()112 MockConnection* connection1() {
113 return reinterpret_cast<MockConnection*>(session1_->connection());
114 }
115
connection2()116 MockConnection* connection2() {
117 return reinterpret_cast<MockConnection*>(session2_->connection());
118 }
119
ProcessPacket(IPEndPoint client_address,QuicConnectionId connection_id,bool has_version_flag,const string & data)120 void ProcessPacket(IPEndPoint client_address,
121 QuicConnectionId connection_id,
122 bool has_version_flag,
123 const string& data) {
124 scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket(
125 connection_id, has_version_flag, false, 1, data));
126 data_ = string(packet->data(), packet->length());
127 dispatcher_.ProcessPacket(server_address_, client_address, *packet);
128 }
129
ValidatePacket(const QuicEncryptedPacket & packet)130 void ValidatePacket(const QuicEncryptedPacket& packet) {
131 EXPECT_EQ(data_.length(), packet.AsStringPiece().length());
132 EXPECT_EQ(data_, packet.AsStringPiece());
133 }
134
135 EpollServer eps_;
136 QuicConfig config_;
137 QuicCryptoServerConfig crypto_config_;
138 IPEndPoint server_address_;
139 TestDispatcher dispatcher_;
140 MockSession* session1_;
141 MockSession* session2_;
142 string data_;
143 };
144
// Packets with new connection ids each create their own session, and a later
// packet for an existing connection id reaches that connection unchanged.
TEST_F(QuicDispatcherTest, ProcessPackets) {
  IPEndPoint client_address(net::test::Loopback4(), 1);
  IPAddressNumber any4;
  CHECK(net::ParseIPLiteralToNumber("0.0.0.0", &any4));
  server_address_ = IPEndPoint(any4, 5);

  EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address))
      .WillOnce(testing::Return(CreateSession(
          &dispatcher_, 1, client_address, &session1_)));
  ProcessPacket(client_address, 1, true, "foo");
  EXPECT_EQ(client_address, dispatcher_.current_client_address());
  EXPECT_EQ(server_address_, dispatcher_.current_server_address());

  EXPECT_CALL(dispatcher_, CreateQuicSession(2, _, client_address))
      .WillOnce(testing::Return(CreateSession(
          &dispatcher_, 2, client_address, &session2_)));
  ProcessPacket(client_address, 2, true, "bar");

  // A follow-up packet (no version flag) for connection 1 must be delivered
  // to the existing connection, byte-for-byte as it was sent.
  EXPECT_CALL(*connection1(), ProcessUdpPacket(_, _, _))
      .WillOnce(testing::WithArgs<2>(Invoke(
          this, &QuicDispatcherTest::ValidatePacket)));
  ProcessPacket(client_address, 1, false, "eep");
}
170
// Shutting the dispatcher down closes live connections with
// QUIC_PEER_GOING_AWAY.
TEST_F(QuicDispatcherTest, Shutdown) {
  IPEndPoint client_address(net::test::Loopback4(), 1);

  EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
      .WillOnce(testing::Return(CreateSession(
          &dispatcher_, 1, client_address, &session1_)));

  ProcessPacket(client_address, 1, true, "foo");

  EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY));

  dispatcher_.Shutdown();
}
185
186 class MockTimeWaitListManager : public QuicTimeWaitListManager {
187 public:
MockTimeWaitListManager(QuicPacketWriter * writer,QuicServerSessionVisitor * visitor,EpollServer * eps)188 MockTimeWaitListManager(QuicPacketWriter* writer,
189 QuicServerSessionVisitor* visitor,
190 EpollServer* eps)
191 : QuicTimeWaitListManager(writer, visitor, eps, QuicSupportedVersions()) {
192 }
193
194 MOCK_METHOD5(ProcessPacket, void(const IPEndPoint& server_address,
195 const IPEndPoint& client_address,
196 QuicConnectionId connection_id,
197 QuicPacketSequenceNumber sequence_number,
198 const QuicEncryptedPacket& packet));
199 };
200
// After a public reset closes a connection, its id is placed in time wait and
// later packets for it are routed to the time-wait list manager.
TEST_F(QuicDispatcherTest, TimeWaitListManager) {
  MockTimeWaitListManager* time_wait_list_manager =
      new MockTimeWaitListManager(
          QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_);
  // dispatcher takes the ownership of time_wait_list_manager.
  QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
                                             time_wait_list_manager);
  // Create a new session.
  IPEndPoint client_address(net::test::Loopback4(), 1);
  QuicConnectionId connection_id = 1;
  EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, _, client_address))
      .WillOnce(testing::Return(CreateSession(
          &dispatcher_, connection_id, client_address, &session1_)));
  ProcessPacket(client_address, connection_id, true, "foo");

  // Close the connection by sending public reset packet.
  QuicPublicResetPacket packet;
  packet.public_header.connection_id = connection_id;
  packet.public_header.reset_flag = true;
  packet.public_header.version_flag = false;
  packet.rejected_sequence_number = 19191;
  packet.nonce_proof = 132232;
  scoped_ptr<QuicEncryptedPacket> encrypted(
      QuicFramer::BuildPublicResetPacket(packet));

  // The session's connection was created by CreateSession() as a
  // MockServerConnection; static_cast is the checked downcast for that.
  MockServerConnection* server_connection =
      static_cast<MockServerConnection*>(session1_->connection());
  EXPECT_CALL(*session1_, OnConnectionClosed(QUIC_PUBLIC_RESET, true))
      .WillOnce(WithoutArgs(Invoke(
          server_connection,
          &MockServerConnection::UnregisterOnConnectionClosed)));
  // Let the connection really parse the reset so the close path above fires.
  EXPECT_CALL(*connection1(), ProcessUdpPacket(_, _, _))
      .WillOnce(Invoke(connection1(),
                       &MockConnection::ReallyProcessUdpPacket));
  dispatcher_.ProcessPacket(IPEndPoint(), client_address, *encrypted);
  EXPECT_TRUE(time_wait_list_manager->IsConnectionIdInTimeWait(connection_id));

  // Dispatcher forwards subsequent packets for this connection_id to the time
  // wait list manager.
  EXPECT_CALL(*time_wait_list_manager,
              ProcessPacket(_, _, connection_id, _, _)).Times(1);
  ProcessPacket(client_address, connection_id, true, "foo");
}
243
// A packet without the version flag for an unknown connection id does not
// create a session; it is forwarded to the time-wait list manager instead.
TEST_F(QuicDispatcherTest, StrayPacketToTimeWaitListManager) {
  MockTimeWaitListManager* time_wait_list_manager =
      new MockTimeWaitListManager(
          QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_);
  // dispatcher takes the ownership of time_wait_list_manager.
  QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
                                             time_wait_list_manager);

  IPEndPoint client_address(net::test::Loopback4(), 1);
  QuicConnectionId connection_id = 1;
  // Dispatcher forwards all packets for this connection_id to the time wait
  // list manager.
  EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0);
  EXPECT_CALL(*time_wait_list_manager,
              ProcessPacket(_, _, connection_id, _, _)).Times(1);
  ProcessPacket(client_address, connection_id, false, "foo");
}
262
263 class BlockingWriter : public QuicPacketWriterWrapper {
264 public:
BlockingWriter()265 BlockingWriter() : write_blocked_(false) {}
266
IsWriteBlocked() const267 virtual bool IsWriteBlocked() const OVERRIDE { return write_blocked_; }
SetWritable()268 virtual void SetWritable() OVERRIDE { write_blocked_ = false; }
269
WritePacket(const char * buffer,size_t buf_len,const IPAddressNumber & self_client_address,const IPEndPoint & peer_client_address)270 virtual WriteResult WritePacket(
271 const char* buffer,
272 size_t buf_len,
273 const IPAddressNumber& self_client_address,
274 const IPEndPoint& peer_client_address) OVERRIDE {
275 // It would be quite possible to actually implement this method here with
276 // the fake blocked status, but it would be significantly more work in
277 // Chromium, and since it's not called anyway, don't bother.
278 LOG(DFATAL) << "Not supported";
279 return WriteResult();
280 }
281
282 bool write_blocked_;
283 };
284
285 class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest {
286 public:
SetUp()287 virtual void SetUp() {
288 writer_ = new BlockingWriter;
289 QuicDispatcherPeer::SetPacketWriterFactory(&dispatcher_,
290 new TestWriterFactory());
291 QuicDispatcherPeer::UseWriter(&dispatcher_, writer_);
292
293 IPEndPoint client_address(net::test::Loopback4(), 1);
294
295 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
296 .WillOnce(testing::Return(CreateSession(
297 &dispatcher_, 1, client_address, &session1_)));
298 ProcessPacket(client_address, 1, true, "foo");
299
300 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
301 .WillOnce(testing::Return(CreateSession(
302 &dispatcher_, 2, client_address, &session2_)));
303 ProcessPacket(client_address, 2, true, "bar");
304
305 blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(&dispatcher_);
306 }
307
TearDown()308 virtual void TearDown() {
309 EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
310 EXPECT_CALL(*connection2(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
311 dispatcher_.Shutdown();
312 }
313
SetBlocked()314 void SetBlocked() {
315 writer_->write_blocked_ = true;
316 }
317
BlockConnection2()318 void BlockConnection2() {
319 writer_->write_blocked_ = true;
320 dispatcher_.OnWriteBlocked(connection2());
321 }
322
323 protected:
324 BlockingWriter* writer_;
325 QuicDispatcher::WriteBlockedList* blocked_list_;
326 };
327
// A blocked connection is notified exactly once per registration, after which
// nothing is left pending.
TEST_F(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite) {
  MockConnection* const first = connection1();

  // Nothing has blocked yet, so there is nobody to notify.
  dispatcher_.OnCanWrite();

  // Block the writer and register the first connection; it gets a callback.
  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  EXPECT_CALL(*first, OnCanWrite());
  dispatcher_.OnCanWrite();

  // Having been served, it must not be called a second time.
  EXPECT_CALL(*first, OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();
  EXPECT_FALSE(dispatcher_.HasPendingWrites());
}
343
// Blocked connections are notified in the order in which they blocked.
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder) {
  InSequence sequence;
  MockConnection* const first = connection1();
  MockConnection* const second = connection2();

  // Connection 1 blocks before connection 2, so it is served first.
  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  dispatcher_.OnWriteBlocked(second);
  EXPECT_CALL(*first, OnCanWrite());
  EXPECT_CALL(*second, OnCanWrite());
  dispatcher_.OnCanWrite();

  // Reverse the blocking order and expect the reverse notification order.
  SetBlocked();
  dispatcher_.OnWriteBlocked(second);
  dispatcher_.OnWriteBlocked(first);
  EXPECT_CALL(*second, OnCanWrite());
  EXPECT_CALL(*first, OnCanWrite());
  dispatcher_.OnCanWrite();
}
362
// Erasing a connection from the blocked list removes it from the next
// OnCanWrite() pass without disturbing other entries.
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove) {
  // Add and remove one connection.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  blocked_list_->erase(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // Add and remove one connection and make sure it doesn't affect others.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  blocked_list_->erase(connection1());
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();

  // Add it, remove it, and add it back and make sure things are OK.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  blocked_list_->erase(connection1());
  dispatcher_.OnWriteBlocked(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
  dispatcher_.OnCanWrite();
}
387
// Registering the same connection twice is idempotent.
TEST_F(QuicDispatcherWriteBlockedListTest, DoubleAdd) {
  MockConnection* const first = connection1();

  // Two registrations are undone by a single erase.
  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  dispatcher_.OnWriteBlocked(first);
  blocked_list_->erase(first);
  EXPECT_CALL(*first, OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // Two registrations still yield just one notification.
  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  dispatcher_.OnWriteBlocked(first);
  EXPECT_CALL(*first, OnCanWrite());
  dispatcher_.OnCanWrite();
}
404
// If a connection re-blocks the writer from inside its OnCanWrite(), the
// dispatcher stops the pass and resumes with the remaining connections later.
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlock) {
  InSequence sequence;
  MockConnection* const first = connection1();
  MockConnection* const second = connection2();

  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  dispatcher_.OnWriteBlocked(second);
  // The first callback blocks the writer again, so the second is skipped.
  EXPECT_CALL(*first, OnCanWrite())
      .WillOnce(Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
  EXPECT_CALL(*second, OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // The next pass picks up where the previous one stopped.
  EXPECT_CALL(*second, OnCanWrite());
  dispatcher_.OnCanWrite();
}
420
// A connection that re-registers during its own OnCanWrite() is not called
// again in the same pass, only on the next one.
TEST_F(QuicDispatcherWriteBlockedListTest, LimitedWrites) {
  // Make sure we call both writers. The second re-registers itself for more
  // writing from inside its callback, but should not be immediately called
  // again in the same pass.
  InSequence s;
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  EXPECT_CALL(*connection1(), OnCanWrite());
  EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce(
      Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2));
  dispatcher_.OnCanWrite();
  EXPECT_TRUE(dispatcher_.HasPendingWrites());

  // Now call OnCanWrite again, and connection2 should get its second chance.
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();
  EXPECT_FALSE(dispatcher_.HasPendingWrites());
}
439
// HasPendingWrites() reports a connection left over when a pass is cut short
// by the writer blocking, and clears once that connection has been served.
TEST_F(QuicDispatcherWriteBlockedListTest, TestWriteLimits) {
  InSequence sequence;
  MockConnection* const first = connection1();
  MockConnection* const second = connection2();

  SetBlocked();
  dispatcher_.OnWriteBlocked(first);
  dispatcher_.OnWriteBlocked(second);
  // The first callback re-blocks the writer, leaving the second pending.
  EXPECT_CALL(*first, OnCanWrite())
      .WillOnce(Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
  EXPECT_CALL(*second, OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();
  EXPECT_TRUE(dispatcher_.HasPendingWrites());

  // Serving the remaining connection drains the list.
  EXPECT_CALL(*second, OnCanWrite());
  dispatcher_.OnCanWrite();
  EXPECT_FALSE(dispatcher_.HasPendingWrites());
}
457
458 } // namespace
459 } // namespace test
460 } // namespace tools
461 } // namespace net
462