1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/public/dns_sd_service_watcher.h"
6
7 #include <algorithm>
8
9 #include "gmock/gmock.h"
10 #include "gtest/gtest.h"
11
12 using testing::_;
13 using testing::ContainerEq;
14 using testing::IsSubsetOf;
15 using testing::IsSupersetOf;
16 using testing::StrictMock;
17
18 namespace openscreen {
19 namespace discovery {
20 namespace {
21
// Copies the strings referred to by |value| into an owning vector, preserving
// element order. The result can be compared with gmock container matchers
// without depending on the lifetime of the referenced strings.
std::vector<std::string> ConvertRefs(
    const std::vector<std::reference_wrapper<const std::string>>& value) {
  std::vector<std::string> strings;
  // The final size is known up front, so allocate once.
  strings.reserve(value.size());

  // Binding each element to a const std::string& applies the implicit
  // reference_wrapper -> reference conversion, which unwraps the element.
  for (const std::string& val : value) {
    strings.push_back(val);
  }
  return strings;
}
32
33 static const IPAddress kAddressV4(192, 168, 0, 0);
34 static const IPEndpoint kEndpointV4{kAddressV4, 0};
35 constexpr char kCastServiceId[] = "_googlecast._tcp";
36 constexpr char kCastDomainId[] = "local";
37 constexpr NetworkInterfaceIndex kNetworkInterface = 0;
38
39 class MockDnsSdService : public DnsSdService {
40 public:
MockDnsSdService()41 MockDnsSdService() : querier_(this) {}
42
GetQuerier()43 DnsSdQuerier* GetQuerier() override { return &querier_; }
GetPublisher()44 DnsSdPublisher* GetPublisher() override { return nullptr; }
45
46 MOCK_METHOD2(StartQuery,
47 void(const std::string& service, DnsSdQuerier::Callback* cb));
48 MOCK_METHOD2(StopQuery,
49 void(const std::string& service, DnsSdQuerier::Callback* cb));
50 MOCK_METHOD1(ReinitializeQueries, void(const std::string& service));
51
52 private:
53 class MockQuerier : public DnsSdQuerier {
54 public:
MockQuerier(MockDnsSdService * service)55 explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) {
56 OSP_DCHECK(service);
57 }
58
StartQuery(const std::string & service,DnsSdQuerier::Callback * cb)59 void StartQuery(const std::string& service,
60 DnsSdQuerier::Callback* cb) override {
61 mock_service_->StartQuery(service, cb);
62 }
63
StopQuery(const std::string & service,DnsSdQuerier::Callback * cb)64 void StopQuery(const std::string& service,
65 DnsSdQuerier::Callback* cb) override {
66 mock_service_->StopQuery(service, cb);
67 }
68
ReinitializeQueries(const std::string & service)69 void ReinitializeQueries(const std::string& service) override {
70 mock_service_->ReinitializeQueries(service);
71 }
72
73 private:
74 MockDnsSdService* const mock_service_;
75 };
76
77 MockQuerier querier_;
78 };
79
80 } // namespace
81
82 class TestServiceWatcher : public DnsSdServiceWatcher<std::string> {
83 public:
84 using DnsSdServiceWatcher<std::string>::ConstRefT;
85
TestServiceWatcher(MockDnsSdService * service)86 explicit TestServiceWatcher(MockDnsSdService* service)
87 : DnsSdServiceWatcher<std::string>(
88 service,
89 kCastServiceId,
90 [this](const DnsSdInstance& instance) { return Convert(instance); },
__anon1a1e3fb70302(std::vector<ConstRefT> ref, ConstRefT service, ServicesUpdatedState state) 91 [this](std::vector<ConstRefT> ref, ConstRefT service, ServicesUpdatedState state) {
92 Callback(std::move(ref));
93 }) {}
94
95 MOCK_METHOD1(Callback, void(std::vector<ConstRefT>));
96
97 using DnsSdServiceWatcher<std::string>::OnEndpointCreated;
98 using DnsSdServiceWatcher<std::string>::OnEndpointUpdated;
99 using DnsSdServiceWatcher<std::string>::OnEndpointDeleted;
100
101 private:
Convert(const DnsSdInstance & instance)102 std::string Convert(const DnsSdInstance& instance) {
103 return instance.instance_id();
104 }
105 };
106
107 class DnsSdServiceWatcherTests : public testing::Test {
108 public:
DnsSdServiceWatcherTests()109 DnsSdServiceWatcherTests() : watcher_(&service_) {
110 // Start service discovery, since all other tests need it
111 EXPECT_FALSE(watcher_.is_running());
112 EXPECT_CALL(service_, StartQuery(kCastServiceId, _));
113 watcher_.StartDiscovery();
114 testing::Mock::VerifyAndClearExpectations(&service_);
115 }
116
117 protected:
CreateNewInstance(const DnsSdInstanceEndpoint & record)118 void CreateNewInstance(const DnsSdInstanceEndpoint& record) {
119 const std::vector<std::string> services_before =
120 ConvertRefs(watcher_.GetServices());
121 const size_t count = services_before.size();
122
123 std::vector<std::string> callbacked_services;
124 EXPECT_CALL(watcher_, Callback(_))
125 .WillOnce([services = &callbacked_services](
126 std::vector<TestServiceWatcher::ConstRefT> value) {
127 *services = ConvertRefs(value);
128 });
129 watcher_.OnEndpointCreated(record);
130 testing::Mock::VerifyAndClearExpectations(&watcher_);
131
132 std::vector<std::string> fetched_services =
133 ConvertRefs(watcher_.GetServices());
134 EXPECT_EQ(fetched_services.size(), count + 1);
135
136 EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
137 EXPECT_THAT(fetched_services, IsSupersetOf(services_before));
138 }
139
CreateExistingInstance(const DnsSdInstanceEndpoint & record)140 void CreateExistingInstance(const DnsSdInstanceEndpoint& record) {
141 const std::vector<std::string> services_before =
142 ConvertRefs(watcher_.GetServices());
143 const size_t count = services_before.size();
144
145 std::vector<std::string> callbacked_services;
146 EXPECT_CALL(watcher_, Callback(_))
147 .WillOnce([services = &callbacked_services](
148 std::vector<TestServiceWatcher::ConstRefT> value) {
149 *services = ConvertRefs(value);
150 });
151 watcher_.OnEndpointCreated(record);
152 testing::Mock::VerifyAndClearExpectations(&watcher_);
153
154 const std::vector<std::string> fetched_services =
155 ConvertRefs(watcher_.GetServices());
156 EXPECT_EQ(fetched_services.size(), count);
157
158 EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
159 EXPECT_THAT(fetched_services, ContainerEq(services_before));
160 }
161
UpdateExistingInstance(const DnsSdInstanceEndpoint & record)162 void UpdateExistingInstance(const DnsSdInstanceEndpoint& record) {
163 const std::vector<std::string> services_before =
164 ConvertRefs(watcher_.GetServices());
165 const size_t count = services_before.size();
166
167 std::vector<std::string> callbacked_services;
168 EXPECT_CALL(watcher_, Callback(_))
169 .WillOnce([services = &callbacked_services](
170 std::vector<TestServiceWatcher::ConstRefT> value) {
171 *services = ConvertRefs(value);
172 });
173 watcher_.OnEndpointUpdated(record);
174 testing::Mock::VerifyAndClearExpectations(&watcher_);
175
176 const std::vector<std::string> fetched_services =
177 ConvertRefs(watcher_.GetServices());
178 EXPECT_EQ(fetched_services.size(), count);
179
180 EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
181 EXPECT_THAT(fetched_services, ContainerEq(services_before));
182 }
183
DeleteExistingInstance(const DnsSdInstanceEndpoint & record)184 void DeleteExistingInstance(const DnsSdInstanceEndpoint& record) {
185 const std::vector<std::string> services_before =
186 ConvertRefs(watcher_.GetServices());
187 const size_t count = services_before.size();
188
189 std::vector<std::string> callbacked_services;
190 EXPECT_CALL(watcher_, Callback(_))
191 .WillOnce([services = &callbacked_services](
192 std::vector<TestServiceWatcher::ConstRefT> value) {
193 *services = ConvertRefs(value);
194 });
195 watcher_.OnEndpointDeleted(record);
196 testing::Mock::VerifyAndClearExpectations(&watcher_);
197
198 const std::vector<std::string> fetched_services =
199 ConvertRefs(watcher_.GetServices());
200 EXPECT_EQ(fetched_services.size(), count - 1);
201 }
202
UpdateNonExistingInstance(const DnsSdInstanceEndpoint & record)203 void UpdateNonExistingInstance(const DnsSdInstanceEndpoint& record) {
204 const std::vector<std::string> services_before =
205 ConvertRefs(watcher_.GetServices());
206 const size_t count = services_before.size();
207
208 EXPECT_CALL(watcher_, Callback(_)).Times(0);
209 watcher_.OnEndpointUpdated(record);
210 testing::Mock::VerifyAndClearExpectations(&watcher_);
211
212 const std::vector<std::string> fetched_services =
213 ConvertRefs(watcher_.GetServices());
214 EXPECT_EQ(fetched_services.size(), count);
215
216 EXPECT_THAT(services_before, ContainerEq(fetched_services));
217 }
218
DeleteNonExistingInstance(const DnsSdInstanceEndpoint & record)219 void DeleteNonExistingInstance(const DnsSdInstanceEndpoint& record) {
220 const std::vector<std::string> services_before =
221 ConvertRefs(watcher_.GetServices());
222 const size_t count = services_before.size();
223
224 EXPECT_CALL(watcher_, Callback(_)).Times(0);
225 watcher_.OnEndpointDeleted(record);
226 testing::Mock::VerifyAndClearExpectations(&watcher_);
227
228 const std::vector<std::string> fetched_services =
229 ConvertRefs(watcher_.GetServices());
230 EXPECT_EQ(fetched_services.size(), count);
231
232 EXPECT_THAT(services_before, ContainerEq(fetched_services));
233 }
234
ContainsService(const DnsSdInstanceEndpoint & record)235 bool ContainsService(const DnsSdInstanceEndpoint& record) {
236 const std::string& service = record.instance_id();
237 const std::vector<TestServiceWatcher::ConstRefT> services =
238 watcher_.GetServices();
239 return std::find_if(services.begin(), services.end(),
240 [&service](const std::string& ref) {
241 return service == ref;
242 }) != services.end();
243 }
244
245 StrictMock<MockDnsSdService> service_;
246 StrictMock<TestServiceWatcher> watcher_;
247 std::vector<std::string> fetched_services;
248 };
249
// Stopping a running watcher is expected to issue StopQuery for the cast
// service ID and to leave is_running() false.
TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) {
  // Discovery was started by the fixture's constructor.
  EXPECT_TRUE(watcher_.is_running());

  EXPECT_CALL(service_, StopQuery(kCastServiceId, _));
  watcher_.StopDiscovery();

  EXPECT_FALSE(watcher_.is_running());
}
256
// Both refresh entry points are expected to report an error on a watcher whose
// discovery was never started. The test builds its own objects rather than
// using the fixture, because the fixture starts discovery in its constructor.
TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) {
  StrictMock<MockDnsSdService> mock_service;
  StrictMock<TestServiceWatcher> idle_watcher(&mock_service);

  EXPECT_FALSE(idle_watcher.DiscoverNow().ok());
  EXPECT_FALSE(idle_watcher.ForceRefresh().ok());
}
263
// With one service known, DiscoverNow() and ForceRefresh() are each expected
// to reinitialize the queries for the cast service ID. The service list is
// expected to survive DiscoverNow() and to be empty after ForceRefresh().
TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) {
  const DnsSdInstanceEndpoint endpoint("Instance", kCastServiceId,
                                       kCastDomainId, DnsSdTxtRecord{},
                                       kNetworkInterface, kEndpointV4);
  CreateNewInstance(endpoint);

  // DiscoverNow: queries are reinitialized, the known service is kept.
  EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
  EXPECT_TRUE(watcher_.DiscoverNow().ok());
  EXPECT_EQ(watcher_.GetServices().size(), size_t{1});
  testing::Mock::VerifyAndClearExpectations(&service_);

  // ForceRefresh: queries are reinitialized, the known service is dropped.
  EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
  EXPECT_TRUE(watcher_.ForceRefresh().ok());
  EXPECT_EQ(watcher_.GetServices().size(), size_t{0});
  testing::Mock::VerifyAndClearExpectations(&service_);
}
281
// Walks two endpoints through create / update / delete notifications, both
// while they are known to the watcher and while they are not, checking after
// every step which of the two the watcher reports.
TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) {
  const DnsSdInstanceEndpoint first("Instance", kCastServiceId, kCastDomainId,
                                    DnsSdTxtRecord{}, kNetworkInterface,
                                    kEndpointV4);
  const DnsSdInstanceEndpoint second("Instance2", kCastServiceId,
                                     kCastDomainId, DnsSdTxtRecord{},
                                     kNetworkInterface, kEndpointV4);

  // Nothing is known initially.
  EXPECT_FALSE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));

  // Add the first endpoint.
  CreateNewInstance(first);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));

  // Re-creating a known endpoint does not change membership.
  CreateExistingInstance(first);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));

  // Updating or deleting the still-unknown second endpoint changes nothing.
  UpdateNonExistingInstance(second);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));

  DeleteNonExistingInstance(second);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));

  // Add the second endpoint; both are now known.
  CreateNewInstance(second);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  // Updating either known endpoint keeps both.
  UpdateExistingInstance(second);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  UpdateExistingInstance(first);
  EXPECT_TRUE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  // Remove the first endpoint; only the second remains.
  DeleteExistingInstance(first);
  EXPECT_FALSE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  // Updating or deleting the removed first endpoint changes nothing.
  UpdateNonExistingInstance(first);
  EXPECT_FALSE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  DeleteNonExistingInstance(first);
  EXPECT_FALSE(ContainsService(first));
  EXPECT_TRUE(ContainsService(second));

  // Remove the second endpoint; nothing is left.
  DeleteExistingInstance(second);
  EXPECT_FALSE(ContainsService(first));
  EXPECT_FALSE(ContainsService(second));
}
337
338 } // namespace discovery
339 } // namespace openscreen
340