• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/public/dns_sd_service_watcher.h"
6 
7 #include <algorithm>
8 
9 #include "gmock/gmock.h"
10 #include "gtest/gtest.h"
11 
12 using testing::_;
13 using testing::ContainerEq;
14 using testing::IsSubsetOf;
15 using testing::IsSupersetOf;
16 using testing::StrictMock;
17 
18 namespace openscreen {
19 namespace discovery {
20 namespace {
21 
// Copies the strings referenced by |value| into an owning vector, preserving
// order. Useful for comparing the watcher's reference-based service lists with
// gMock container matchers.
std::vector<std::string> ConvertRefs(
    const std::vector<std::reference_wrapper<const std::string>>& value) {
  // Each std::reference_wrapper<const std::string> converts implicitly to
  // const std::string&, so the range constructor copies the referenced strings
  // directly; no hand-written unwrapping loop is needed.
  return std::vector<std::string>(value.begin(), value.end());
}
32 
33 static const IPAddress kAddressV4(192, 168, 0, 0);
34 static const IPEndpoint kEndpointV4{kAddressV4, 0};
35 constexpr char kCastServiceId[] = "_googlecast._tcp";
36 constexpr char kCastDomainId[] = "local";
37 constexpr NetworkInterfaceIndex kNetworkInterface = 0;
38 
39 class MockDnsSdService : public DnsSdService {
40  public:
MockDnsSdService()41   MockDnsSdService() : querier_(this) {}
42 
GetQuerier()43   DnsSdQuerier* GetQuerier() override { return &querier_; }
GetPublisher()44   DnsSdPublisher* GetPublisher() override { return nullptr; }
45 
46   MOCK_METHOD2(StartQuery,
47                void(const std::string& service, DnsSdQuerier::Callback* cb));
48   MOCK_METHOD2(StopQuery,
49                void(const std::string& service, DnsSdQuerier::Callback* cb));
50   MOCK_METHOD1(ReinitializeQueries, void(const std::string& service));
51 
52  private:
53   class MockQuerier : public DnsSdQuerier {
54    public:
MockQuerier(MockDnsSdService * service)55     explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) {
56       OSP_DCHECK(service);
57     }
58 
StartQuery(const std::string & service,DnsSdQuerier::Callback * cb)59     void StartQuery(const std::string& service,
60                     DnsSdQuerier::Callback* cb) override {
61       mock_service_->StartQuery(service, cb);
62     }
63 
StopQuery(const std::string & service,DnsSdQuerier::Callback * cb)64     void StopQuery(const std::string& service,
65                    DnsSdQuerier::Callback* cb) override {
66       mock_service_->StopQuery(service, cb);
67     }
68 
ReinitializeQueries(const std::string & service)69     void ReinitializeQueries(const std::string& service) override {
70       mock_service_->ReinitializeQueries(service);
71     }
72 
73    private:
74     MockDnsSdService* const mock_service_;
75   };
76 
77   MockQuerier querier_;
78 };
79 
80 }  // namespace
81 
82 class TestServiceWatcher : public DnsSdServiceWatcher<std::string> {
83  public:
84   using DnsSdServiceWatcher<std::string>::ConstRefT;
85 
TestServiceWatcher(MockDnsSdService * service)86   explicit TestServiceWatcher(MockDnsSdService* service)
87       : DnsSdServiceWatcher<std::string>(
88             service,
89             kCastServiceId,
90             [this](const DnsSdInstance& instance) { return Convert(instance); },
__anon1a1e3fb70302(std::vector<ConstRefT> ref, ConstRefT service, ServicesUpdatedState state) 91             [this](std::vector<ConstRefT> ref, ConstRefT service, ServicesUpdatedState state) {
92                     Callback(std::move(ref));
93             }) {}
94 
95   MOCK_METHOD1(Callback, void(std::vector<ConstRefT>));
96 
97   using DnsSdServiceWatcher<std::string>::OnEndpointCreated;
98   using DnsSdServiceWatcher<std::string>::OnEndpointUpdated;
99   using DnsSdServiceWatcher<std::string>::OnEndpointDeleted;
100 
101  private:
Convert(const DnsSdInstance & instance)102   std::string Convert(const DnsSdInstance& instance) {
103     return instance.instance_id();
104   }
105 };
106 
107 class DnsSdServiceWatcherTests : public testing::Test {
108  public:
DnsSdServiceWatcherTests()109   DnsSdServiceWatcherTests() : watcher_(&service_) {
110     // Start service discovery, since all other tests need it
111     EXPECT_FALSE(watcher_.is_running());
112     EXPECT_CALL(service_, StartQuery(kCastServiceId, _));
113     watcher_.StartDiscovery();
114     testing::Mock::VerifyAndClearExpectations(&service_);
115   }
116 
117  protected:
CreateNewInstance(const DnsSdInstanceEndpoint & record)118   void CreateNewInstance(const DnsSdInstanceEndpoint& record) {
119     const std::vector<std::string> services_before =
120         ConvertRefs(watcher_.GetServices());
121     const size_t count = services_before.size();
122 
123     std::vector<std::string> callbacked_services;
124     EXPECT_CALL(watcher_, Callback(_))
125         .WillOnce([services = &callbacked_services](
126                       std::vector<TestServiceWatcher::ConstRefT> value) {
127           *services = ConvertRefs(value);
128         });
129     watcher_.OnEndpointCreated(record);
130     testing::Mock::VerifyAndClearExpectations(&watcher_);
131 
132     std::vector<std::string> fetched_services =
133         ConvertRefs(watcher_.GetServices());
134     EXPECT_EQ(fetched_services.size(), count + 1);
135 
136     EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
137     EXPECT_THAT(fetched_services, IsSupersetOf(services_before));
138   }
139 
CreateExistingInstance(const DnsSdInstanceEndpoint & record)140   void CreateExistingInstance(const DnsSdInstanceEndpoint& record) {
141     const std::vector<std::string> services_before =
142         ConvertRefs(watcher_.GetServices());
143     const size_t count = services_before.size();
144 
145     std::vector<std::string> callbacked_services;
146     EXPECT_CALL(watcher_, Callback(_))
147         .WillOnce([services = &callbacked_services](
148                       std::vector<TestServiceWatcher::ConstRefT> value) {
149           *services = ConvertRefs(value);
150         });
151     watcher_.OnEndpointCreated(record);
152     testing::Mock::VerifyAndClearExpectations(&watcher_);
153 
154     const std::vector<std::string> fetched_services =
155         ConvertRefs(watcher_.GetServices());
156     EXPECT_EQ(fetched_services.size(), count);
157 
158     EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
159     EXPECT_THAT(fetched_services, ContainerEq(services_before));
160   }
161 
UpdateExistingInstance(const DnsSdInstanceEndpoint & record)162   void UpdateExistingInstance(const DnsSdInstanceEndpoint& record) {
163     const std::vector<std::string> services_before =
164         ConvertRefs(watcher_.GetServices());
165     const size_t count = services_before.size();
166 
167     std::vector<std::string> callbacked_services;
168     EXPECT_CALL(watcher_, Callback(_))
169         .WillOnce([services = &callbacked_services](
170                       std::vector<TestServiceWatcher::ConstRefT> value) {
171           *services = ConvertRefs(value);
172         });
173     watcher_.OnEndpointUpdated(record);
174     testing::Mock::VerifyAndClearExpectations(&watcher_);
175 
176     const std::vector<std::string> fetched_services =
177         ConvertRefs(watcher_.GetServices());
178     EXPECT_EQ(fetched_services.size(), count);
179 
180     EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
181     EXPECT_THAT(fetched_services, ContainerEq(services_before));
182   }
183 
DeleteExistingInstance(const DnsSdInstanceEndpoint & record)184   void DeleteExistingInstance(const DnsSdInstanceEndpoint& record) {
185     const std::vector<std::string> services_before =
186         ConvertRefs(watcher_.GetServices());
187     const size_t count = services_before.size();
188 
189     std::vector<std::string> callbacked_services;
190     EXPECT_CALL(watcher_, Callback(_))
191         .WillOnce([services = &callbacked_services](
192                       std::vector<TestServiceWatcher::ConstRefT> value) {
193           *services = ConvertRefs(value);
194         });
195     watcher_.OnEndpointDeleted(record);
196     testing::Mock::VerifyAndClearExpectations(&watcher_);
197 
198     const std::vector<std::string> fetched_services =
199         ConvertRefs(watcher_.GetServices());
200     EXPECT_EQ(fetched_services.size(), count - 1);
201   }
202 
UpdateNonExistingInstance(const DnsSdInstanceEndpoint & record)203   void UpdateNonExistingInstance(const DnsSdInstanceEndpoint& record) {
204     const std::vector<std::string> services_before =
205         ConvertRefs(watcher_.GetServices());
206     const size_t count = services_before.size();
207 
208     EXPECT_CALL(watcher_, Callback(_)).Times(0);
209     watcher_.OnEndpointUpdated(record);
210     testing::Mock::VerifyAndClearExpectations(&watcher_);
211 
212     const std::vector<std::string> fetched_services =
213         ConvertRefs(watcher_.GetServices());
214     EXPECT_EQ(fetched_services.size(), count);
215 
216     EXPECT_THAT(services_before, ContainerEq(fetched_services));
217   }
218 
DeleteNonExistingInstance(const DnsSdInstanceEndpoint & record)219   void DeleteNonExistingInstance(const DnsSdInstanceEndpoint& record) {
220     const std::vector<std::string> services_before =
221         ConvertRefs(watcher_.GetServices());
222     const size_t count = services_before.size();
223 
224     EXPECT_CALL(watcher_, Callback(_)).Times(0);
225     watcher_.OnEndpointDeleted(record);
226     testing::Mock::VerifyAndClearExpectations(&watcher_);
227 
228     const std::vector<std::string> fetched_services =
229         ConvertRefs(watcher_.GetServices());
230     EXPECT_EQ(fetched_services.size(), count);
231 
232     EXPECT_THAT(services_before, ContainerEq(fetched_services));
233   }
234 
ContainsService(const DnsSdInstanceEndpoint & record)235   bool ContainsService(const DnsSdInstanceEndpoint& record) {
236     const std::string& service = record.instance_id();
237     const std::vector<TestServiceWatcher::ConstRefT> services =
238         watcher_.GetServices();
239     return std::find_if(services.begin(), services.end(),
240                         [&service](const std::string& ref) {
241                           return service == ref;
242                         }) != services.end();
243   }
244 
245   StrictMock<MockDnsSdService> service_;
246   StrictMock<TestServiceWatcher> watcher_;
247   std::vector<std::string> fetched_services;
248 };
249 
// Stopping discovery must stop the Cast query and clear the running state.
TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) {
  // The fixture constructor has already started discovery.
  EXPECT_TRUE(watcher_.is_running());

  // Expectation first: StopDiscovery() must stop the query it started.
  EXPECT_CALL(service_, StopQuery(kCastServiceId, _));
  watcher_.StopDiscovery();

  EXPECT_FALSE(watcher_.is_running());
}
256 
// Uses a plain TEST because the fixture starts discovery in its constructor.
// The suite name differs from the fixture's, presumably because gtest rejects
// mixing TEST and TEST_F within one test suite name.
TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) {
  StrictMock<MockDnsSdService> mock_service;
  StrictMock<TestServiceWatcher> unstarted_watcher(&mock_service);

  // Neither refresh flavor may succeed while discovery is not running.
  const bool discover_now_ok = unstarted_watcher.DiscoverNow().ok();
  EXPECT_FALSE(discover_now_ok);

  const bool force_refresh_ok = unstarted_watcher.ForceRefresh().ok();
  EXPECT_FALSE(force_refresh_ok);
}
263 
// Both refresh flavors reinitialize the Cast query. DiscoverNow() keeps the
// known services, whereas ForceRefresh() discards them.
TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) {
  const DnsSdInstanceEndpoint instance("Instance", kCastServiceId,
                                       kCastDomainId, DnsSdTxtRecord{},
                                       kNetworkInterface, kEndpointV4);
  CreateNewInstance(instance);

  // DiscoverNow(): query reinitialized, the single known service is kept.
  EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
  EXPECT_TRUE(watcher_.DiscoverNow().ok());
  EXPECT_EQ(size_t{1}, watcher_.GetServices().size());
  testing::Mock::VerifyAndClearExpectations(&service_);

  // ForceRefresh(): query reinitialized, the known services are dropped.
  EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
  EXPECT_TRUE(watcher_.ForceRefresh().ok());
  EXPECT_EQ(size_t{0}, watcher_.GetServices().size());
  testing::Mock::VerifyAndClearExpectations(&service_);
}
281 
// Walks two instances through create / re-create / update / delete events,
// including events for instances the watcher does not know, and checks which
// instances are reported after every step.
TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) {
  const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId,
                                     DnsSdTxtRecord{}, kNetworkInterface,
                                     kEndpointV4);
  const DnsSdInstanceEndpoint record2("Instance2", kCastServiceId,
                                      kCastDomainId, DnsSdTxtRecord{},
                                      kNetworkInterface, kEndpointV4);

  // Checks which of the two instances the watcher currently reports. |step|
  // is attached to any failure via SCOPED_TRACE, so the failing phase of the
  // sequence is identifiable from the test output.
  const auto expect_services = [this, &record, &record2](
                                   const char* step, bool has_record,
                                   bool has_record2) {
    SCOPED_TRACE(step);
    EXPECT_EQ(has_record, ContainsService(record));
    EXPECT_EQ(has_record2, ContainsService(record2));
  };

  expect_services("initial state", false, false);

  CreateNewInstance(record);
  expect_services("after creating record", true, false);

  CreateExistingInstance(record);
  expect_services("after re-creating record", true, false);

  UpdateNonExistingInstance(record2);
  expect_services("after updating unknown record2", true, false);

  DeleteNonExistingInstance(record2);
  expect_services("after deleting unknown record2", true, false);

  CreateNewInstance(record2);
  expect_services("after creating record2", true, true);

  UpdateExistingInstance(record2);
  expect_services("after updating record2", true, true);

  UpdateExistingInstance(record);
  expect_services("after updating record", true, true);

  DeleteExistingInstance(record);
  expect_services("after deleting record", false, true);

  UpdateNonExistingInstance(record);
  expect_services("after updating deleted record", false, true);

  DeleteNonExistingInstance(record);
  expect_services("after deleting deleted record", false, true);

  DeleteExistingInstance(record2);
  expect_services("after deleting record2", false, false);
}
337 
338 }  // namespace discovery
339 }  // namespace openscreen
340