• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "jingle/glue/pseudotcp_adapter.h"
6 
7 #include <vector>
8 
9 #include "base/bind.h"
10 #include "base/bind_helpers.h"
11 #include "base/compiler_specific.h"
12 #include "jingle/glue/thread_wrapper.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/udp/udp_socket.h"
17 #include "testing/gmock/include/gmock/gmock.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 
20 
namespace jingle_glue {
namespace {
// Forward declaration only; FakeSocket is defined later in this same
// anonymous namespace.
class FakeSocket;
}  // namespace
}  // namespace jingle_glue
26 
27 namespace jingle_glue {
28 
29 namespace {
30 
// Upper bound, in bytes, on a single Write()/Read() issued by the tester.
const int kMessageSize = 1024;
// Number of full-size messages that make up the test payload.
const int kMessages = 100;
// Total payload pushed through the channel, in bytes.
const int kTestDataSize = kMessageSize * kMessages;
34 
// Abstract policy consulted for every incoming packet to decide whether the
// simulated network loses it.
class RateLimiter {
 public:
  virtual ~RateLimiter() {}

  // Returns true if the new packet needs to be dropped, false otherwise.
  virtual bool DropNextPacket() = 0;
};
41 
42 class LeakyBucket : public RateLimiter {
43  public:
44   // |rate| is in drops per second.
LeakyBucket(double volume,double rate)45   LeakyBucket(double volume, double rate)
46       : volume_(volume),
47         rate_(rate),
48         level_(0.0),
49         last_update_(base::TimeTicks::HighResNow()) {
50   }
51 
~LeakyBucket()52   virtual ~LeakyBucket() { }
53 
DropNextPacket()54   virtual bool DropNextPacket() OVERRIDE {
55     base::TimeTicks now = base::TimeTicks::HighResNow();
56     double interval = (now - last_update_).InSecondsF();
57     last_update_ = now;
58     level_ = level_ + 1.0 - interval * rate_;
59     if (level_ > volume_) {
60       level_ = volume_;
61       return true;
62     } else if (level_ < 0.0) {
63       level_ = 0.0;
64     }
65     return false;
66   }
67 
68  private:
69   double volume_;
70   double rate_;
71   double level_;
72   base::TimeTicks last_update_;
73 };
74 
75 class FakeSocket : public net::Socket {
76  public:
FakeSocket()77   FakeSocket()
78       : rate_limiter_(NULL),
79         latency_ms_(0) {
80   }
~FakeSocket()81   virtual ~FakeSocket() { }
82 
AppendInputPacket(const std::vector<char> & data)83   void AppendInputPacket(const std::vector<char>& data) {
84     if (rate_limiter_ && rate_limiter_->DropNextPacket())
85       return;  // Lose the packet.
86 
87     if (!read_callback_.is_null()) {
88       int size = std::min(read_buffer_size_, static_cast<int>(data.size()));
89       memcpy(read_buffer_->data(), &data[0], data.size());
90       net::CompletionCallback cb = read_callback_;
91       read_callback_.Reset();
92       read_buffer_ = NULL;
93       cb.Run(size);
94     } else {
95       incoming_packets_.push_back(data);
96     }
97   }
98 
Connect(FakeSocket * peer_socket)99   void Connect(FakeSocket* peer_socket) {
100     peer_socket_ = peer_socket;
101   }
102 
set_rate_limiter(RateLimiter * rate_limiter)103   void set_rate_limiter(RateLimiter* rate_limiter) {
104     rate_limiter_ = rate_limiter;
105   };
106 
set_latency(int latency_ms)107   void set_latency(int latency_ms) { latency_ms_ = latency_ms; };
108 
109   // net::Socket interface.
Read(net::IOBuffer * buf,int buf_len,const net::CompletionCallback & callback)110   virtual int Read(net::IOBuffer* buf, int buf_len,
111                    const net::CompletionCallback& callback) OVERRIDE {
112     CHECK(read_callback_.is_null());
113     CHECK(buf);
114 
115     if (incoming_packets_.size() > 0) {
116       scoped_refptr<net::IOBuffer> buffer(buf);
117       int size = std::min(
118           static_cast<int>(incoming_packets_.front().size()), buf_len);
119       memcpy(buffer->data(), &*incoming_packets_.front().begin(), size);
120       incoming_packets_.pop_front();
121       return size;
122     } else {
123       read_callback_ = callback;
124       read_buffer_ = buf;
125       read_buffer_size_ = buf_len;
126       return net::ERR_IO_PENDING;
127     }
128   }
129 
Write(net::IOBuffer * buf,int buf_len,const net::CompletionCallback & callback)130   virtual int Write(net::IOBuffer* buf, int buf_len,
131                     const net::CompletionCallback& callback) OVERRIDE {
132     DCHECK(buf);
133     if (peer_socket_) {
134       base::MessageLoop::current()->PostDelayedTask(
135           FROM_HERE,
136           base::Bind(&FakeSocket::AppendInputPacket,
137                      base::Unretained(peer_socket_),
138                      std::vector<char>(buf->data(), buf->data() + buf_len)),
139           base::TimeDelta::FromMilliseconds(latency_ms_));
140     }
141 
142     return buf_len;
143   }
144 
SetReceiveBufferSize(int32 size)145   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
146     NOTIMPLEMENTED();
147     return false;
148   }
SetSendBufferSize(int32 size)149   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
150     NOTIMPLEMENTED();
151     return false;
152   }
153 
154  private:
155   scoped_refptr<net::IOBuffer> read_buffer_;
156   int read_buffer_size_;
157   net::CompletionCallback read_callback_;
158 
159   std::deque<std::vector<char> > incoming_packets_;
160 
161   FakeSocket* peer_socket_;
162   RateLimiter* rate_limiter_;
163   int latency_ms_;
164 };
165 
166 class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> {
167  public:
TCPChannelTester(base::MessageLoop * message_loop,net::Socket * client_socket,net::Socket * host_socket)168   TCPChannelTester(base::MessageLoop* message_loop,
169                    net::Socket* client_socket,
170                    net::Socket* host_socket)
171       : message_loop_(message_loop),
172         host_socket_(host_socket),
173         client_socket_(client_socket),
174         done_(false),
175         write_errors_(0),
176         read_errors_(0) {}
177 
Start()178   void Start() {
179     message_loop_->PostTask(
180         FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this));
181   }
182 
CheckResults()183   void CheckResults() {
184     EXPECT_EQ(0, write_errors_);
185     EXPECT_EQ(0, read_errors_);
186 
187     ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity());
188 
189     output_buffer_->SetOffset(0);
190     ASSERT_EQ(kTestDataSize, output_buffer_->size());
191 
192     EXPECT_EQ(0, memcmp(output_buffer_->data(),
193                         input_buffer_->StartOfBuffer(), kTestDataSize));
194   }
195 
196  protected:
~TCPChannelTester()197   virtual ~TCPChannelTester() {}
198 
Done()199   void Done() {
200     done_ = true;
201     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
202   }
203 
DoStart()204   void DoStart() {
205     InitBuffers();
206     DoRead();
207     DoWrite();
208   }
209 
InitBuffers()210   void InitBuffers() {
211     output_buffer_ = new net::DrainableIOBuffer(
212         new net::IOBuffer(kTestDataSize), kTestDataSize);
213     memset(output_buffer_->data(), 123, kTestDataSize);
214 
215     input_buffer_ = new net::GrowableIOBuffer();
216     // Always keep kMessageSize bytes available at the end of the input buffer.
217     input_buffer_->SetCapacity(kMessageSize);
218   }
219 
DoWrite()220   void DoWrite() {
221     int result = 1;
222     while (result > 0) {
223       if (output_buffer_->BytesRemaining() == 0)
224         break;
225 
226       int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
227                                     kMessageSize);
228       result = client_socket_->Write(
229           output_buffer_.get(),
230           bytes_to_write,
231           base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this)));
232       HandleWriteResult(result);
233     }
234   }
235 
OnWritten(int result)236   void OnWritten(int result) {
237     HandleWriteResult(result);
238     DoWrite();
239   }
240 
HandleWriteResult(int result)241   void HandleWriteResult(int result) {
242     if (result <= 0 && result != net::ERR_IO_PENDING) {
243       LOG(ERROR) << "Received error " << result << " when trying to write";
244       write_errors_++;
245       Done();
246     } else if (result > 0) {
247       output_buffer_->DidConsume(result);
248     }
249   }
250 
DoRead()251   void DoRead() {
252     int result = 1;
253     while (result > 0) {
254       input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize);
255 
256       result = host_socket_->Read(
257           input_buffer_.get(),
258           kMessageSize,
259           base::Bind(&TCPChannelTester::OnRead, base::Unretained(this)));
260       HandleReadResult(result);
261     };
262   }
263 
OnRead(int result)264   void OnRead(int result) {
265     HandleReadResult(result);
266     DoRead();
267   }
268 
HandleReadResult(int result)269   void HandleReadResult(int result) {
270     if (result <= 0 && result != net::ERR_IO_PENDING) {
271       if (!done_) {
272         LOG(ERROR) << "Received error " << result << " when trying to read";
273         read_errors_++;
274         Done();
275       }
276     } else if (result > 0) {
277       // Allocate memory for the next read.
278       input_buffer_->SetCapacity(input_buffer_->capacity() + result);
279       if (input_buffer_->capacity() == kTestDataSize + kMessageSize)
280         Done();
281     }
282   }
283 
284  private:
285   friend class base::RefCountedThreadSafe<TCPChannelTester>;
286 
287   base::MessageLoop* message_loop_;
288   net::Socket* host_socket_;
289   net::Socket* client_socket_;
290   bool done_;
291 
292   scoped_refptr<net::DrainableIOBuffer> output_buffer_;
293   scoped_refptr<net::GrowableIOBuffer> input_buffer_;
294 
295   int write_errors_;
296   int read_errors_;
297 };
298 
299 class PseudoTcpAdapterTest : public testing::Test {
300  protected:
SetUp()301   virtual void SetUp() OVERRIDE {
302     JingleThreadWrapper::EnsureForCurrentMessageLoop();
303 
304     host_socket_ = new FakeSocket();
305     client_socket_ = new FakeSocket();
306 
307     host_socket_->Connect(client_socket_);
308     client_socket_->Connect(host_socket_);
309 
310     host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_));
311     client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_));
312   }
313 
314   FakeSocket* host_socket_;
315   FakeSocket* client_socket_;
316 
317   scoped_ptr<PseudoTcpAdapter> host_pseudotcp_;
318   scoped_ptr<PseudoTcpAdapter> client_pseudotcp_;
319   base::MessageLoop message_loop_;
320 };
321 
// Transfers kTestDataSize bytes over an ideal (lossless, zero-latency)
// channel and checks that they arrive intact.
TEST_F(PseudoTcpAdapterTest, DataTransfer) {
  net::TestCompletionCallback host_connect_cb;
  net::TestCompletionCallback client_connect_cb;

  // Both connects are started before either one is waited on.
  const int host_rv = host_pseudotcp_->Connect(host_connect_cb.callback());
  const int client_rv =
      client_pseudotcp_->Connect(client_connect_cb.callback());

  const int host_result = (host_rv == net::ERR_IO_PENDING) ?
      host_connect_cb.WaitForResult() : host_rv;
  const int client_result = (client_rv == net::ERR_IO_PENDING) ?
      client_connect_cb.WaitForResult() : client_rv;
  ASSERT_EQ(net::OK, host_result);
  ASSERT_EQ(net::OK, client_result);

  // The tester writes into the host adapter and reads from the client one.
  scoped_refptr<TCPChannelTester> tester(new TCPChannelTester(
      &message_loop_, host_pseudotcp_.get(), client_pseudotcp_.get()));

  tester->Start();
  message_loop_.Run();
  tester->CheckResults();
}
344 
// Same transfer as DataTransfer, but over a constrained channel. Both
// directions add kLatencyMs of delay, and each side drops incoming packets
// that overflow a leaky bucket of kBurstPackets draining at
// kPacketsPerSecond.
TEST_F(PseudoTcpAdapterTest, LimitedChannel) {
  const int kLatencyMs = 20;
  const int kPacketsPerSecond = 400;
  const int kBurstPackets = 10;

  // Latency applies to the packets a socket writes; the rate limiter applies
  // to the packets a socket receives. The sockets keep only raw pointers to
  // the limiters.
  LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond);
  host_socket_->set_latency(kLatencyMs);
  host_socket_->set_rate_limiter(&host_limiter);

  LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond);
  // Configure the client side too, so client->host traffic is also delayed.
  client_socket_->set_latency(kLatencyMs);
  client_socket_->set_rate_limiter(&client_limiter);

  net::TestCompletionCallback host_connect_cb;
  net::TestCompletionCallback client_connect_cb;

  int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
  int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());

  if (rv1 == net::ERR_IO_PENDING)
    rv1 = host_connect_cb.WaitForResult();
  if (rv2 == net::ERR_IO_PENDING)
    rv2 = client_connect_cb.WaitForResult();
  ASSERT_EQ(net::OK, rv1);
  ASSERT_EQ(net::OK, rv2);

  scoped_refptr<TCPChannelTester> tester =
      new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
                           client_pseudotcp_.get());

  tester->Start();
  message_loop_.Run();
  tester->CheckResults();
}
379 
380 class DeleteOnConnected {
381  public:
DeleteOnConnected(base::MessageLoop * message_loop,scoped_ptr<PseudoTcpAdapter> * adapter)382   DeleteOnConnected(base::MessageLoop* message_loop,
383                     scoped_ptr<PseudoTcpAdapter>* adapter)
384       : message_loop_(message_loop), adapter_(adapter) {}
OnConnected(int error)385   void OnConnected(int error) {
386     adapter_->reset();
387     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
388   }
389   base::MessageLoop* message_loop_;
390   scoped_ptr<PseudoTcpAdapter>* adapter_;
391 };
392 
TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) {
  // Deleting the host adapter from inside its own Connect() callback must not
  // touch freed structures while the stack unwinds. A regression therefore
  // shows up as a crash, not as an ordinary test failure.
  net::TestCompletionCallback client_connect_cb;
  DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_);

  net::CompletionCallback delete_host_on_connect =
      base::Bind(&DeleteOnConnected::OnConnected,
                 base::Unretained(&host_delete));
  host_pseudotcp_->Connect(delete_host_on_connect);
  client_pseudotcp_->Connect(client_connect_cb.callback());

  // OnConnected() resets |host_pseudotcp_| and then posts the quit task.
  message_loop_.Run();

  ASSERT_EQ(NULL, host_pseudotcp_.get());
}
407 
408 // Verify that we can send/receive data with the write-waits-for-send
409 // flag set.
TEST_F(PseudoTcpAdapterTest,WriteWaitsForSendLetsDataThrough)410 TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) {
411   net::TestCompletionCallback host_connect_cb;
412   net::TestCompletionCallback client_connect_cb;
413 
414   host_pseudotcp_->SetWriteWaitsForSend(true);
415   client_pseudotcp_->SetWriteWaitsForSend(true);
416 
417   // Disable Nagle's algorithm because the test is slow when it is
418   // enabled.
419   host_pseudotcp_->SetNoDelay(true);
420 
421   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
422   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
423 
424   if (rv1 == net::ERR_IO_PENDING)
425     rv1 = host_connect_cb.WaitForResult();
426   if (rv2 == net::ERR_IO_PENDING)
427     rv2 = client_connect_cb.WaitForResult();
428   ASSERT_EQ(net::OK, rv1);
429   ASSERT_EQ(net::OK, rv2);
430 
431   scoped_refptr<TCPChannelTester> tester =
432       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
433                            client_pseudotcp_.get());
434 
435   tester->Start();
436   message_loop_.Run();
437   tester->CheckResults();
438 }
439 
440 }  // namespace
441 
442 }  // namespace jingle_glue
443