• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/memory/weak_ptr.h"
6 #include "base/run_loop.h"
7 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
8 #include "net/base/net_errors.h"
9 #include "net/dns/dns_protocol.h"
10 #include "net/dns/mdns_client_impl.h"
11 #include "net/dns/mock_mdns_socket_factory.h"
12 #include "testing/gmock/include/gmock/gmock.h"
13 #include "testing/gtest/include/gtest/gtest.h"
14 
15 using ::testing::_;
16 using ::testing::Invoke;
17 using ::testing::StrictMock;
18 using ::testing::NiceMock;
19 using ::testing::Mock;
20 using ::testing::SaveArg;
21 using ::testing::SetArgPointee;
22 using ::testing::Return;
23 using ::testing::Exactly;
24 
25 namespace local_discovery {
26 
27 namespace {
28 
29 const uint8 kSamplePacketPTR[] = {
30   // Header
31   0x00, 0x00,               // ID is zeroed out
32   0x81, 0x80,               // Standard query response, RA, no error
33   0x00, 0x00,               // No questions (for simplicity)
34   0x00, 0x01,               // 1 RR (answers)
35   0x00, 0x00,               // 0 authority RRs
36   0x00, 0x00,               // 0 additional RRs
37 
38   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
39   0x04, '_', 't', 'c', 'p',
40   0x05, 'l', 'o', 'c', 'a', 'l',
41   0x00,
42   0x00, 0x0c,        // TYPE is PTR.
43   0x00, 0x01,        // CLASS is IN.
44   0x00, 0x00,        // TTL (4 bytes) is 1 second.
45   0x00, 0x01,
46   0x00, 0x08,        // RDLENGTH is 8 bytes.
47   0x05, 'h', 'e', 'l', 'l', 'o',
48   0xc0, 0x0c
49 };
50 
51 const uint8 kSamplePacketSRV[] = {
52   // Header
53   0x00, 0x00,               // ID is zeroed out
54   0x81, 0x80,               // Standard query response, RA, no error
55   0x00, 0x00,               // No questions (for simplicity)
56   0x00, 0x01,               // 1 RR (answers)
57   0x00, 0x00,               // 0 authority RRs
58   0x00, 0x00,               // 0 additional RRs
59 
60   0x05, 'h', 'e', 'l', 'l', 'o',
61   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
62   0x04, '_', 't', 'c', 'p',
63   0x05, 'l', 'o', 'c', 'a', 'l',
64   0x00,
65   0x00, 0x21,        // TYPE is SRV.
66   0x00, 0x01,        // CLASS is IN.
67   0x00, 0x00,        // TTL (4 bytes) is 1 second.
68   0x00, 0x01,
69   0x00, 0x15,        // RDLENGTH is 21 bytes.
70   0x00, 0x00,
71   0x00, 0x00,
72   0x22, 0xb8,  // port 8888
73   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
74   0x05, 'l', 'o', 'c', 'a', 'l',
75   0x00,
76 };
77 
78 const uint8 kSamplePacketTXT[] = {
79   // Header
80   0x00, 0x00,               // ID is zeroed out
81   0x81, 0x80,               // Standard query response, RA, no error
82   0x00, 0x00,               // No questions (for simplicity)
83   0x00, 0x01,               // 1 RR (answers)
84   0x00, 0x00,               // 0 authority RRs
85   0x00, 0x00,               // 0 additional RRs
86 
87   0x05, 'h', 'e', 'l', 'l', 'o',
88   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
89   0x04, '_', 't', 'c', 'p',
90   0x05, 'l', 'o', 'c', 'a', 'l',
91   0x00,
92   0x00, 0x10,        // TYPE is PTR.
93   0x00, 0x01,        // CLASS is IN.
94   0x00, 0x00,        // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
95   0x00, 0x01,
96   0x00, 0x06,        // RDLENGTH is 21 bytes.
97   0x05, 'h', 'e', 'l', 'l', 'o'
98 };
99 
100 const uint8 kSamplePacketSRVA[] = {
101   // Header
102   0x00, 0x00,               // ID is zeroed out
103   0x81, 0x80,               // Standard query response, RA, no error
104   0x00, 0x00,               // No questions (for simplicity)
105   0x00, 0x02,               // 2 RR (answers)
106   0x00, 0x00,               // 0 authority RRs
107   0x00, 0x00,               // 0 additional RRs
108 
109   0x05, 'h', 'e', 'l', 'l', 'o',
110   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
111   0x04, '_', 't', 'c', 'p',
112   0x05, 'l', 'o', 'c', 'a', 'l',
113   0x00,
114   0x00, 0x21,        // TYPE is SRV.
115   0x00, 0x01,        // CLASS is IN.
116   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
117   0x00, 0x10,
118   0x00, 0x15,        // RDLENGTH is 21 bytes.
119   0x00, 0x00,
120   0x00, 0x00,
121   0x22, 0xb8,  // port 8888
122   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
123   0x05, 'l', 'o', 'c', 'a', 'l',
124   0x00,
125 
126   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
127   0x05, 'l', 'o', 'c', 'a', 'l',
128   0x00,
129   0x00, 0x01,        // TYPE is A.
130   0x00, 0x01,        // CLASS is IN.
131   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
132   0x00, 0x10,
133   0x00, 0x04,        // RDLENGTH is 4 bytes.
134   0x01, 0x02,
135   0x03, 0x04,
136 };
137 
138 const uint8 kSamplePacketPTR2[] = {
139   // Header
140   0x00, 0x00,               // ID is zeroed out
141   0x81, 0x80,               // Standard query response, RA, no error
142   0x00, 0x00,               // No questions (for simplicity)
143   0x00, 0x02,               // 2 RR (answers)
144   0x00, 0x00,               // 0 authority RRs
145   0x00, 0x00,               // 0 additional RRs
146 
147   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
148   0x04, '_', 't', 'c', 'p',
149   0x05, 'l', 'o', 'c', 'a', 'l',
150   0x00,
151   0x00, 0x0c,        // TYPE is PTR.
152   0x00, 0x01,        // CLASS is IN.
153   0x02, 0x00,        // TTL (4 bytes) is 1 second.
154   0x00, 0x01,
155   0x00, 0x08,        // RDLENGTH is 8 bytes.
156   0x05, 'g', 'd', 'b', 'y', 'e',
157   0xc0, 0x0c,
158 
159   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
160   0x04, '_', 't', 'c', 'p',
161   0x05, 'l', 'o', 'c', 'a', 'l',
162   0x00,
163   0x00, 0x0c,        // TYPE is PTR.
164   0x00, 0x01,        // CLASS is IN.
165   0x02, 0x00,        // TTL (4 bytes) is 1 second.
166   0x00, 0x01,
167   0x00, 0x08,        // RDLENGTH is 8 bytes.
168   0x05, 'h', 'e', 'l', 'l', 'o',
169   0xc0, 0x0c
170 };
171 
172 const uint8 kSamplePacketQuerySRV[] = {
173   // Header
174   0x00, 0x00,               // ID is zeroed out
175   0x00, 0x00,               // No flags.
176   0x00, 0x01,               // One question.
177   0x00, 0x00,               // 0 RRs (answers)
178   0x00, 0x00,               // 0 authority RRs
179   0x00, 0x00,               // 0 additional RRs
180 
181   // Question
182   0x05, 'h', 'e', 'l', 'l', 'o',
183   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
184   0x04, '_', 't', 'c', 'p',
185   0x05, 'l', 'o', 'c', 'a', 'l',
186   0x00,
187   0x00, 0x21,        // TYPE is SRV.
188   0x00, 0x01,        // CLASS is IN.
189 };
190 
191 
192 class MockServiceWatcherClient {
193  public:
194   MOCK_METHOD2(OnServiceUpdated,
195                void(ServiceWatcher::UpdateType, const std::string&));
196 
GetCallback()197   ServiceWatcher::UpdatedCallback GetCallback() {
198     return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
199                       base::Unretained(this));
200   }
201 };
202 
203 class ServiceDiscoveryTest : public ::testing::Test {
204  public:
ServiceDiscoveryTest()205   ServiceDiscoveryTest()
206       : service_discovery_client_(&mdns_client_) {
207     mdns_client_.StartListening(&socket_factory_);
208   }
209 
~ServiceDiscoveryTest()210   virtual ~ServiceDiscoveryTest() {
211   }
212 
213  protected:
RunFor(base::TimeDelta time_period)214   void RunFor(base::TimeDelta time_period) {
215     base::CancelableCallback<void()> callback(base::Bind(
216         &ServiceDiscoveryTest::Stop, base::Unretained(this)));
217     base::MessageLoop::current()->PostDelayedTask(
218         FROM_HERE, callback.callback(), time_period);
219 
220     base::MessageLoop::current()->Run();
221     callback.Cancel();
222   }
223 
Stop()224   void Stop() {
225     base::MessageLoop::current()->Quit();
226   }
227 
228   net::MockMDnsSocketFactory socket_factory_;
229   net::MDnsClientImpl mdns_client_;
230   ServiceDiscoveryClientImpl service_discovery_client_;
231   base::MessageLoop loop_;
232 };
233 
// A PTR answer for "hello" yields UPDATE_ADDED. Once the record's 1-second
// TTL lapses while the loop runs, the watcher reports UPDATE_REMOVED.
TEST_F(ServiceDiscoveryTest, AddRemoveService) {
  StrictMock<MockServiceWatcherClient> delegate;

  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  // Long enough for the 1-second TTL to expire.
  RunFor(base::TimeDelta::FromSeconds(2));
}
255 
// DiscoverNewServices() sends queries immediately, and more queries follow
// while the loop runs.
TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
  StrictMock<MockServiceWatcherClient> delegate;

  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  // Sends made synchronously inside DiscoverNewServices() are matched by this
  // expectation, because the second one below does not exist yet.
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  watcher->DiscoverNewServices(false);

  // gmock tries newer expectations first, so sends issued while the loop
  // runs are matched here.
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  RunFor(base::TimeDelta::FromSeconds(2));
}
273 
// A PTR record received before the watcher exists is picked up from the
// cache when the watcher starts, producing UPDATE_ADDED.
TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
  socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));

  StrictMock<MockServiceWatcherClient> delegate;

  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  base::MessageLoop::current()->RunUntilIdle();
}
291 
292 
// Two cached PTR records ("hello" and "gdbye") both produce UPDATE_ADDED
// when a watcher starts after they were received.
TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
  socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2));

  StrictMock<MockServiceWatcherClient> delegate;
  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "gdbye._privet._tcp.local"))
      .Times(Exactly(1));

  base::MessageLoop::current()->RunUntilIdle();
}
313 
314 
// After a service is added, SRV and TXT records for it that arrive
// back-to-back result in exactly one UPDATE_CHANGED notification.
TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
  StrictMock<MockServiceWatcherClient> delegate;
  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));

  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));

  socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));

  base::MessageLoop::current()->RunUntilIdle();
}
341 
// SRV and TXT records received without an intervening loop turn are reported
// as a single UPDATE_CHANGED.
// NOTE(review): this body is currently identical to OnServiceChanged. The name
// suggests the SRV and TXT records were meant to arrive in one packet;
// consider adding a combined sample packet.
TEST_F(ServiceDiscoveryTest, SinglePacket) {
  StrictMock<MockServiceWatcherClient> delegate;
  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));

  // Reset the "already updated" flag.
  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));

  socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));

  base::MessageLoop::current()->RunUntilIdle();
}
369 
// With active refresh enabled, the watcher sends SRV queries for a discovered
// service (see the 85%/95% note below). When no fresh answer arrives before
// the records expire, it reports UPDATE_REMOVED.
TEST_F(ServiceDiscoveryTest, ActivelyRefreshServices) {
  StrictMock<MockServiceWatcherClient> delegate;
  scoped_ptr<ServiceWatcher> watcher(
      service_discovery_client_.CreateServiceWatcher(
          "_privet._tcp.local", delegate.GetCallback()));

  watcher->Start();
  watcher->SetActivelyRefreshServices(true);

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  // Exact bytes of the SRV query the watcher is expected to send.
  const std::string query_packet(
      reinterpret_cast<const char*>(kSamplePacketQuerySRV),
      sizeof(kSamplePacketQuerySRV));

  EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
      .Times(2);

  socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));

  base::MessageLoop::current()->RunUntilIdle();

  socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));

  EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
      .Times(4);  // IPv4 and IPv6 at 85% and 95%

  EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
                                         "hello._privet._tcp.local"))
      .Times(Exactly(1));

  RunFor(base::TimeDelta::FromSeconds(2));

  base::MessageLoop::current()->RunUntilIdle();
}
406 
407 
408 class ServiceResolverTest : public ServiceDiscoveryTest {
409  public:
ServiceResolverTest()410   ServiceResolverTest() {
411     metadata_expected_.push_back("hello");
412     address_expected_ = net::HostPortPair("myhello.local", 8888);
413     ip_address_expected_.push_back(1);
414     ip_address_expected_.push_back(2);
415     ip_address_expected_.push_back(3);
416     ip_address_expected_.push_back(4);
417   }
418 
~ServiceResolverTest()419   ~ServiceResolverTest() {
420   }
421 
SetUp()422   void SetUp()  {
423     resolver_ = service_discovery_client_.CreateServiceResolver(
424                     "hello._privet._tcp.local",
425                      base::Bind(&ServiceResolverTest::OnFinishedResolving,
426                                 base::Unretained(this)));
427   }
428 
OnFinishedResolving(ServiceResolver::RequestStatus request_status,const ServiceDescription & service_description)429   void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
430                            const ServiceDescription& service_description) {
431     OnFinishedResolvingInternal(request_status,
432                                 service_description.address.ToString(),
433                                 service_description.metadata,
434                                 service_description.ip_address);
435   }
436 
437   MOCK_METHOD4(OnFinishedResolvingInternal,
438                void(ServiceResolver::RequestStatus,
439                     const std::string&,
440                     const std::vector<std::string>&,
441                     const net::IPAddressNumber&));
442 
443  protected:
444   scoped_ptr<ServiceResolver> resolver_;
445   net::IPAddressNumber ip_address_;
446   net::HostPortPair address_expected_;
447   std::vector<std::string> metadata_expected_;
448   net::IPAddressNumber ip_address_expected_;
449 };
450 
// SRV then TXT with no A record: resolution succeeds with the address and
// metadata filled in and an empty IP address.
TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);

  resolver_->StartResolving();

  socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));

  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_CALL(*this,
              OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
                                          address_expected_.ToString(),
                                          metadata_expected_,
                                          net::IPAddressNumber()));

  socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
}
468 
// TXT, then a packet carrying both SRV and A: resolution succeeds with the
// address, metadata and IP address all populated.
TEST_F(ServiceResolverTest, TxtSrvAndA) {
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);

  resolver_->StartResolving();

  EXPECT_CALL(*this,
              OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
                                          address_expected_.ToString(),
                                          metadata_expected_,
                                          ip_address_expected_));

  socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));

  socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
}
484 
// SRV and A but no TXT: after waiting out the TXT query, resolution succeeds
// with empty metadata.
TEST_F(ServiceResolverTest, JustSrv) {
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);

  resolver_->StartResolving();

  EXPECT_CALL(*this,
              OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
                                          address_expected_.ToString(),
                                          std::vector<std::string>(),
                                          ip_address_expected_));

  socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));

  // TODO(noamsml): When NSEC record support is added, change this to use an
  // NSEC record.
  RunFor(base::TimeDelta::FromSeconds(4));
}
502 
// No answers at all: the resolver reports STATUS_REQUEST_TIMEOUT.
TEST_F(ServiceResolverTest, WithNothing) {
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);

  resolver_->StartResolving();

  EXPECT_CALL(*this, OnFinishedResolvingInternal(
                         ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));

  // TODO(noamsml): When NSEC record support is added, change this to use an
  // NSEC record.
  RunFor(base::TimeDelta::FromSeconds(4));
}
515 
516 }  // namespace
517 
518 }  // namespace local_discovery
519