• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/dnssd/impl/publisher_impl.h"
6 
7 #include <utility>
8 #include <vector>
9 
10 #include "discovery/common/testing/mock_reporting_client.h"
11 #include "discovery/dnssd/testing/fake_network_interface_config.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14 #include "platform/test/fake_clock.h"
15 #include "platform/test/fake_task_runner.h"
16 
17 namespace openscreen {
18 namespace discovery {
19 namespace {
20 
21 using testing::_;
22 using testing::Return;
23 using testing::StrictMock;
24 
25 class MockClient : public DnsSdPublisher::Client {
26  public:
27   MOCK_METHOD2(OnEndpointClaimed,
28                void(const DnsSdInstance&, const DnsSdInstanceEndpoint&));
29 };
30 
31 class MockMdnsService : public MdnsService {
32  public:
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)33   void StartQuery(const DomainName& name,
34                   DnsType dns_type,
35                   DnsClass dns_class,
36                   MdnsRecordChangedCallback* callback) override {
37     FAIL();
38   }
39 
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)40   void StopQuery(const DomainName& name,
41                  DnsType dns_type,
42                  DnsClass dns_class,
43                  MdnsRecordChangedCallback* callback) override {
44     FAIL();
45   }
46 
ReinitializeQueries(const DomainName & name)47   void ReinitializeQueries(const DomainName& name) override { FAIL(); }
48 
49   MOCK_METHOD3(StartProbe,
50                Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress));
51   MOCK_METHOD2(UpdateRegisteredRecord,
52                Error(const MdnsRecord&, const MdnsRecord&));
53   MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord& record));
54   MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord& record));
55 };
56 
57 class PublisherImplTest : public testing::Test {
58  public:
PublisherImplTest()59   PublisherImplTest()
60       : clock_(Clock::now()),
61         task_runner_(&clock_),
62         publisher_(&mock_service_,
63                    &reporting_client_,
64                    &task_runner_,
65                    &network_config_) {}
66 
mdns_service()67   MockMdnsService* mdns_service() { return &mock_service_; }
task_runner()68   TaskRunner* task_runner() { return &task_runner_; }
publisher()69   PublisherImpl* publisher() { return &publisher_; }
70 
71   // Calls PublisherImpl::OnDomainFound() through the public interface it
72   // implements.
CallOnDomainFound(const DomainName & domain,const DomainName & domain2)73   void CallOnDomainFound(const DomainName& domain, const DomainName& domain2) {
74     static_cast<MdnsDomainConfirmedProvider&>(publisher_)
75         .OnDomainFound(domain, domain2);
76   }
77 
78  protected:
79   FakeNetworkInterfaceConfig network_config_;
80   FakeClock clock_;
81   FakeTaskRunner task_runner_;
82   StrictMock<MockMdnsService> mock_service_;
83   StrictMock<MockReportingClient> reporting_client_;
84   PublisherImpl publisher_;
85 };
86 
TEST_F(PublisherImplTest,TestRegistrationAndDegrestration)87 TEST_F(PublisherImplTest, TestRegistrationAndDegrestration) {
88   IPAddress address = IPAddress(192, 168, 0, 0);
89   network_config_.set_address_v4(address);
90   const DomainName domain{"instance", "_service", "_udp", "domain"};
91   const DomainName domain2{"instance2", "_service", "_udp", "domain"};
92   const DnsSdInstance instance("instance", "_service._udp", "domain", {}, 80);
93   const DnsSdInstance instance2("instance2", "_service._udp", "domain", {}, 80);
94   MockClient client;
95 
96   EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
97   publisher()->Register(instance, &client);
98   testing::Mock::VerifyAndClearExpectations(mdns_service());
99 
100   int seen = 0;
101   EXPECT_CALL(*mdns_service(), RegisterRecord(_))
102       .Times(4)
103       .WillRepeatedly([&seen, &address,
104                        &domain2](const MdnsRecord& record) mutable -> Error {
105         if (record.dns_type() == DnsType::kA) {
106           const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata());
107           if (data.ipv4_address() == address) {
108             seen++;
109           }
110         } else if (record.dns_type() == DnsType::kSRV) {
111           const SrvRecordRdata& data =
112               absl::get<SrvRecordRdata>(record.rdata());
113           if (data.port() == 80) {
114             seen++;
115           }
116         }
117 
118         if (record.dns_type() != DnsType::kPTR) {
119           EXPECT_EQ(record.name(), domain2);
120         }
121         return Error::None();
122       });
123   EXPECT_CALL(client, OnEndpointClaimed(instance, _))
124       .WillOnce([instance2](const DnsSdInstance& requested,
125                             const DnsSdInstanceEndpoint& claimed) {
126         EXPECT_EQ(instance2, claimed);
127       });
128   CallOnDomainFound(domain, domain2);
129   EXPECT_EQ(seen, 2);
130   testing::Mock::VerifyAndClearExpectations(mdns_service());
131   testing::Mock::VerifyAndClearExpectations(&client);
132 
133   seen = 0;
134   EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
135       .Times(4)
136       .WillRepeatedly([&seen,
137                        &address](const MdnsRecord& record) mutable -> Error {
138         if (record.dns_type() == DnsType::kA) {
139           const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata());
140           if (data.ipv4_address() == address) {
141             seen++;
142           }
143         } else if (record.dns_type() == DnsType::kSRV) {
144           const SrvRecordRdata& data =
145               absl::get<SrvRecordRdata>(record.rdata());
146           if (data.port() == 80) {
147             seen++;
148           }
149         }
150         return Error::None();
151       });
152   publisher()->DeregisterAll("_service._udp");
153   EXPECT_EQ(seen, 2);
154 }
155 
TEST_F(PublisherImplTest,TestUpdate)156 TEST_F(PublisherImplTest, TestUpdate) {
157   IPAddress address = IPAddress(192, 168, 0, 0);
158   network_config_.set_address_v4(address);
159   DomainName domain{"instance", "_service", "_udp", "domain"};
160   DnsSdTxtRecord txt;
161   txt.SetFlag("id", true);
162   DnsSdInstance instance("instance", "_service._udp", "domain", std::move(txt),
163                          80);
164   MockClient client;
165 
166   // Update a non-existent instance
167   EXPECT_FALSE(publisher()->UpdateRegistration(instance).ok());
168 
169   // Update an instance during the probing phase
170   EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
171   EXPECT_EQ(publisher()->Register(instance, &client), Error::None());
172   testing::Mock::VerifyAndClearExpectations(mdns_service());
173 
174   IPAddress address2 = IPAddress(1, 2, 3, 4, 5, 6, 7, 8);
175   network_config_.set_address_v4(IPAddress{});
176   network_config_.set_address_v6(address2);
177   DnsSdTxtRecord txt2;
178   txt2.SetFlag("id2", true);
179   DnsSdInstance instance2("instance", "_service._udp", "domain",
180                           std::move(txt2), 80);
181   EXPECT_EQ(publisher()->UpdateRegistration(instance2), Error::None());
182 
183   bool seen_v6 = false;
184   EXPECT_CALL(*mdns_service(), RegisterRecord(_))
185       .Times(4)
186       .WillRepeatedly([&seen_v6](const MdnsRecord& record) mutable -> Error {
187         EXPECT_NE(record.dns_type(), DnsType::kA);
188         if (record.dns_type() == DnsType::kAAAA) {
189           seen_v6 = true;
190         }
191         return Error::None();
192       });
193   EXPECT_CALL(client, OnEndpointClaimed(instance2, _))
194       .WillOnce([instance2](const DnsSdInstance& requested,
195                             const DnsSdInstanceEndpoint& claimed) {
196         EXPECT_EQ(instance2, claimed);
197       });
198   CallOnDomainFound(domain, domain);
199   EXPECT_TRUE(seen_v6);
200   testing::Mock::VerifyAndClearExpectations(mdns_service());
201   testing::Mock::VerifyAndClearExpectations(&client);
202 
203   // Update an instance once it has been published.
204   network_config_.set_address_v4(address);
205   network_config_.set_address_v6(IPAddress{});
206   EXPECT_CALL(*mdns_service(), RegisterRecord(_))
207       .WillOnce([](const MdnsRecord& record) -> Error {
208         EXPECT_EQ(record.dns_type(), DnsType::kA);
209         return Error::None();
210       });
211   EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
212       .WillOnce([](const MdnsRecord& record) -> Error {
213         EXPECT_EQ(record.dns_type(), DnsType::kAAAA);
214         return Error::None();
215       });
216   EXPECT_CALL(*mdns_service(), UpdateRegisteredRecord(_, _))
217       .WillOnce(
218           [](const MdnsRecord& record, const MdnsRecord& record2) -> Error {
219             EXPECT_EQ(record.dns_type(), DnsType::kTXT);
220             EXPECT_EQ(record2.dns_type(), DnsType::kTXT);
221             return Error::None();
222           });
223   EXPECT_EQ(publisher()->UpdateRegistration(instance), Error::None());
224   testing::Mock::VerifyAndClearExpectations(mdns_service());
225 }
226 
227 }  // namespace
228 }  // namespace discovery
229 }  // namespace openscreen
230