1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/publisher_impl.h"
6
7 #include <utility>
8 #include <vector>
9
10 #include "discovery/common/testing/mock_reporting_client.h"
11 #include "discovery/dnssd/testing/fake_network_interface_config.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14 #include "platform/test/fake_clock.h"
15 #include "platform/test/fake_task_runner.h"
16
17 namespace openscreen {
18 namespace discovery {
19 namespace {
20
21 using testing::_;
22 using testing::Return;
23 using testing::StrictMock;
24
25 class MockClient : public DnsSdPublisher::Client {
26 public:
27 MOCK_METHOD2(OnEndpointClaimed,
28 void(const DnsSdInstance&, const DnsSdInstanceEndpoint&));
29 };
30
31 class MockMdnsService : public MdnsService {
32 public:
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)33 void StartQuery(const DomainName& name,
34 DnsType dns_type,
35 DnsClass dns_class,
36 MdnsRecordChangedCallback* callback) override {
37 FAIL();
38 }
39
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)40 void StopQuery(const DomainName& name,
41 DnsType dns_type,
42 DnsClass dns_class,
43 MdnsRecordChangedCallback* callback) override {
44 FAIL();
45 }
46
ReinitializeQueries(const DomainName & name)47 void ReinitializeQueries(const DomainName& name) override { FAIL(); }
48
49 MOCK_METHOD3(StartProbe,
50 Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress));
51 MOCK_METHOD2(UpdateRegisteredRecord,
52 Error(const MdnsRecord&, const MdnsRecord&));
53 MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord& record));
54 MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord& record));
55 };
56
57 class PublisherImplTest : public testing::Test {
58 public:
PublisherImplTest()59 PublisherImplTest()
60 : clock_(Clock::now()),
61 task_runner_(&clock_),
62 publisher_(&mock_service_,
63 &reporting_client_,
64 &task_runner_,
65 &network_config_) {}
66
mdns_service()67 MockMdnsService* mdns_service() { return &mock_service_; }
task_runner()68 TaskRunner* task_runner() { return &task_runner_; }
publisher()69 PublisherImpl* publisher() { return &publisher_; }
70
71 // Calls PublisherImpl::OnDomainFound() through the public interface it
72 // implements.
CallOnDomainFound(const DomainName & domain,const DomainName & domain2)73 void CallOnDomainFound(const DomainName& domain, const DomainName& domain2) {
74 static_cast<MdnsDomainConfirmedProvider&>(publisher_)
75 .OnDomainFound(domain, domain2);
76 }
77
78 protected:
79 FakeNetworkInterfaceConfig network_config_;
80 FakeClock clock_;
81 FakeTaskRunner task_runner_;
82 StrictMock<MockMdnsService> mock_service_;
83 StrictMock<MockReportingClient> reporting_client_;
84 PublisherImpl publisher_;
85 };
86
// Registers an instance, has probing claim a different domain (|domain2|),
// and checks that the published records use the claimed name. It then
// deregisters everything and checks that the same records are unregistered.
TEST_F(PublisherImplTest, TestRegistrationAndDegrestration) {
  IPAddress address = IPAddress(192, 168, 0, 0);
  network_config_.set_address_v4(address);
  const DomainName domain{"instance", "_service", "_udp", "domain"};
  const DomainName domain2{"instance2", "_service", "_udp", "domain"};
  const DnsSdInstance instance("instance", "_service._udp", "domain", {}, 80);
  const DnsSdInstance instance2("instance2", "_service._udp", "domain", {}, 80);
  MockClient client;

  // Registering starts probing for the requested domain.
  EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
  EXPECT_EQ(publisher()->Register(instance, &client), Error::None());
  testing::Mock::VerifyAndClearExpectations(mdns_service());

  // Shared checker for the registration and deregistration phases. It counts
  // the A record holding |address| and the SRV record with port 80. It also
  // requires every non-PTR record to carry the claimed name (|domain2|) rather
  // than the originally requested one.
  int seen = 0;
  const auto check_record = [&seen, &address,
                             &domain2](const MdnsRecord& record) -> Error {
    if (record.dns_type() == DnsType::kA) {
      const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata());
      if (data.ipv4_address() == address) {
        ++seen;
      }
    } else if (record.dns_type() == DnsType::kSRV) {
      const SrvRecordRdata& data = absl::get<SrvRecordRdata>(record.rdata());
      if (data.port() == 80) {
        ++seen;
      }
    }

    if (record.dns_type() != DnsType::kPTR) {
      EXPECT_EQ(record.name(), domain2);
    }
    return Error::None();
  };

  // Probing resolves to |domain2|: four records are published and the client
  // is told that |instance2| was claimed.
  EXPECT_CALL(*mdns_service(), RegisterRecord(_))
      .Times(4)
      .WillRepeatedly(check_record);
  EXPECT_CALL(client, OnEndpointClaimed(instance, _))
      .WillOnce([instance2](const DnsSdInstance& /*requested*/,
                            const DnsSdInstanceEndpoint& claimed) {
        EXPECT_EQ(instance2, claimed);
      });
  CallOnDomainFound(domain, domain2);
  EXPECT_EQ(seen, 2);
  testing::Mock::VerifyAndClearExpectations(mdns_service());
  testing::Mock::VerifyAndClearExpectations(&client);

  // Deregistering the service unregisters the same four records.
  seen = 0;
  EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
      .Times(4)
      .WillRepeatedly(check_record);
  publisher()->DeregisterAll("_service._udp");
  EXPECT_EQ(seen, 2);
  testing::Mock::VerifyAndClearExpectations(mdns_service());
}
155
// Exercises PublisherImpl::UpdateRegistration() in three situations: for an
// instance that was never registered, during probing, and after the instance
// has been published.
TEST_F(PublisherImplTest, TestUpdate) {
  // Initial setup: IPv4 only, with a TXT record carrying the flag "id".
  IPAddress v4_address = IPAddress(192, 168, 0, 0);
  network_config_.set_address_v4(v4_address);
  const DomainName domain{"instance", "_service", "_udp", "domain"};
  DnsSdTxtRecord original_txt;
  original_txt.SetFlag("id", true);
  DnsSdInstance original("instance", "_service._udp", "domain",
                         std::move(original_txt), 80);
  MockClient client;

  // An instance that was never registered cannot be updated.
  EXPECT_FALSE(publisher()->UpdateRegistration(original).ok());

  // Registering begins probing for |domain|.
  EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
  EXPECT_EQ(publisher()->Register(original, &client), Error::None());
  testing::Mock::VerifyAndClearExpectations(mdns_service());

  // While still probing, switch to IPv6 only and a new TXT record, then
  // update. The strict service mock has no expectations, so no mDNS calls may
  // happen here.
  IPAddress v6_address = IPAddress(1, 2, 3, 4, 5, 6, 7, 8);
  network_config_.set_address_v4(IPAddress{});
  network_config_.set_address_v6(v6_address);
  DnsSdTxtRecord updated_txt;
  updated_txt.SetFlag("id2", true);
  DnsSdInstance updated("instance", "_service._udp", "domain",
                        std::move(updated_txt), 80);
  EXPECT_EQ(publisher()->UpdateRegistration(updated), Error::None());

  // When probing completes, the updated instance is published. Four records
  // are registered, including an AAAA record and no A record.
  bool saw_aaaa = false;
  EXPECT_CALL(*mdns_service(), RegisterRecord(_))
      .Times(4)
      .WillRepeatedly([&saw_aaaa](const MdnsRecord& record) -> Error {
        EXPECT_NE(record.dns_type(), DnsType::kA);
        saw_aaaa = saw_aaaa || record.dns_type() == DnsType::kAAAA;
        return Error::None();
      });
  EXPECT_CALL(client, OnEndpointClaimed(updated, _))
      .WillOnce([updated](const DnsSdInstance& /*requested*/,
                          const DnsSdInstanceEndpoint& claimed) {
        EXPECT_EQ(updated, claimed);
      });
  CallOnDomainFound(domain, domain);
  EXPECT_TRUE(saw_aaaa);
  testing::Mock::VerifyAndClearExpectations(mdns_service());
  testing::Mock::VerifyAndClearExpectations(&client);

  // After publication, switching back to IPv4 plus the original TXT record
  // must add an A record, drop the AAAA record and update the TXT record.
  network_config_.set_address_v4(v4_address);
  network_config_.set_address_v6(IPAddress{});
  EXPECT_CALL(*mdns_service(), RegisterRecord(_))
      .WillOnce([](const MdnsRecord& added) -> Error {
        EXPECT_EQ(added.dns_type(), DnsType::kA);
        return Error::None();
      });
  EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
      .WillOnce([](const MdnsRecord& removed) -> Error {
        EXPECT_EQ(removed.dns_type(), DnsType::kAAAA);
        return Error::None();
      });
  EXPECT_CALL(*mdns_service(), UpdateRegisteredRecord(_, _))
      .WillOnce([](const MdnsRecord& old_record,
                   const MdnsRecord& new_record) -> Error {
        EXPECT_EQ(old_record.dns_type(), DnsType::kTXT);
        EXPECT_EQ(new_record.dns_type(), DnsType::kTXT);
        return Error::None();
      });
  EXPECT_EQ(publisher()->UpdateRegistration(original), Error::None());
  testing::Mock::VerifyAndClearExpectations(mdns_service());
}
226
227 } // namespace
228 } // namespace discovery
229 } // namespace openscreen
230