1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/conversion_layer.h"
6
7 #include <utility>
8
9 #include "absl/strings/str_join.h"
10 #include "absl/strings/str_split.h"
11 #include "absl/types/optional.h"
12 #include "absl/types/span.h"
13 #include "discovery/dnssd/impl/constants.h"
14 #include "discovery/dnssd/impl/instance_key.h"
15 #include "discovery/dnssd/impl/service_key.h"
16 #include "discovery/dnssd/public/dns_sd_instance.h"
17 #include "discovery/mdns/mdns_records.h"
18 #include "discovery/mdns/public/mdns_constants.h"
19
20 namespace openscreen {
21 namespace discovery {
22 namespace {
23
AddServiceInfoToLabels(const std::string & service,const std::string & domain,std::vector<std::string> * labels)24 void AddServiceInfoToLabels(const std::string& service,
25 const std::string& domain,
26 std::vector<std::string>* labels) {
27 std::vector<std::string> service_labels = absl::StrSplit(service, '.');
28 labels->insert(labels->end(), service_labels.begin(), service_labels.end());
29
30 std::vector<std::string> domain_labels = absl::StrSplit(domain, '.');
31 labels->insert(labels->end(), domain_labels.begin(), domain_labels.end());
32 }
33
GetPtrDomainName(const std::string & service,const std::string & domain)34 DomainName GetPtrDomainName(const std::string& service,
35 const std::string& domain) {
36 std::vector<std::string> labels;
37 AddServiceInfoToLabels(service, domain, &labels);
38 return DomainName{std::move(labels)};
39 }
40
GetInstanceDomainName(const std::string & instance,const std::string & service,const std::string & domain)41 DomainName GetInstanceDomainName(const std::string& instance,
42 const std::string& service,
43 const std::string& domain) {
44 std::vector<std::string> labels;
45 labels.emplace_back(instance);
46 AddServiceInfoToLabels(service, domain, &labels);
47 return DomainName{std::move(labels)};
48 }
49
GetInstanceDomainName(const InstanceKey & key)50 inline DomainName GetInstanceDomainName(const InstanceKey& key) {
51 return GetInstanceDomainName(key.instance_id(), key.service_id(),
52 key.domain_id());
53 }
54
CreatePtrRecord(const DnsSdInstance & instance,const DomainName & domain)55 MdnsRecord CreatePtrRecord(const DnsSdInstance& instance,
56 const DomainName& domain) {
57 PtrRecordRdata data(domain);
58 auto outer_domain =
59 GetPtrDomainName(instance.service_id(), instance.domain_id());
60 return MdnsRecord(std::move(outer_domain), DnsType::kPTR, DnsClass::kIN,
61 RecordType::kShared, kPtrRecordTtl, std::move(data));
62 }
63
CreateSrvRecord(const DnsSdInstance & instance,const DomainName & domain)64 MdnsRecord CreateSrvRecord(const DnsSdInstance& instance,
65 const DomainName& domain) {
66 uint16_t port = instance.port();
67 SrvRecordRdata data(0, 0, port, domain);
68 return MdnsRecord(domain, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
69 kSrvRecordTtl, std::move(data));
70 }
71
CreateARecords(const DnsSdInstanceEndpoint & endpoint,const DomainName & domain)72 std::vector<MdnsRecord> CreateARecords(const DnsSdInstanceEndpoint& endpoint,
73 const DomainName& domain) {
74 std::vector<MdnsRecord> records;
75 for (const IPAddress& address : endpoint.addresses()) {
76 if (address.IsV4()) {
77 ARecordRdata data(address);
78 records.emplace_back(domain, DnsType::kA, DnsClass::kIN,
79 RecordType::kUnique, kARecordTtl, std::move(data));
80 }
81 }
82
83 return records;
84 }
85
CreateAAAARecords(const DnsSdInstanceEndpoint & endpoint,const DomainName & domain)86 std::vector<MdnsRecord> CreateAAAARecords(const DnsSdInstanceEndpoint& endpoint,
87 const DomainName& domain) {
88 std::vector<MdnsRecord> records;
89 for (const IPAddress& address : endpoint.addresses()) {
90 if (address.IsV6()) {
91 AAAARecordRdata data(address);
92 records.emplace_back(domain, DnsType::kAAAA, DnsClass::kIN,
93 RecordType::kUnique, kAAAARecordTtl,
94 std::move(data));
95 }
96 }
97
98 return records;
99 }
100
CreateTxtRecord(const DnsSdInstance & endpoint,const DomainName & domain)101 MdnsRecord CreateTxtRecord(const DnsSdInstance& endpoint,
102 const DomainName& domain) {
103 TxtRecordRdata data(endpoint.txt().GetData());
104 return MdnsRecord(domain, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
105 kTXTRecordTtl, std::move(data));
106 }
107
108 } // namespace
109
CreateFromDnsTxt(const TxtRecordRdata & txt_data)110 ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) {
111 DnsSdTxtRecord txt;
112 if (txt_data.texts().size() == 1 && txt_data.texts()[0] == "") {
113 return txt;
114 }
115
116 // Iterate backwards so that the first key of each type is the one that is
117 // present at the end, as pet spec.
118 for (auto it = txt_data.texts().rbegin(); it != txt_data.texts().rend();
119 it++) {
120 const std::string& text = *it;
121 size_t index_of_eq = text.find_first_of('=');
122 if (index_of_eq != std::string::npos) {
123 if (index_of_eq == 0) {
124 return Error::Code::kParameterInvalid;
125 }
126 std::string key = text.substr(0, index_of_eq);
127 std::string value = text.substr(index_of_eq + 1);
128 absl::Span<const uint8_t> data(
129 reinterpret_cast<const uint8_t*>(value.data()), value.size());
130 const auto set_result =
131 txt.SetValue(key, std::vector<uint8_t>(data.begin(), data.end()));
132 if (!set_result.ok()) {
133 return set_result;
134 }
135 } else {
136 const auto set_result = txt.SetFlag(text, true);
137 if (!set_result.ok()) {
138 return set_result;
139 }
140 }
141 }
142
143 return txt;
144 }
145
GetDomainName(const InstanceKey & key)146 DomainName GetDomainName(const InstanceKey& key) {
147 return GetInstanceDomainName(key.instance_id(), key.service_id(),
148 key.domain_id());
149 }
150
GetDomainName(const MdnsRecord & record)151 DomainName GetDomainName(const MdnsRecord& record) {
152 return IsPtrRecord(record)
153 ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain()
154 : record.name();
155 }
156
GetInstanceQueryInfo(const InstanceKey & key)157 DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key) {
158 return {GetDomainName(key), DnsType::kANY, DnsClass::kANY};
159 }
160
GetPtrQueryInfo(const ServiceKey & key)161 DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) {
162 auto domain = GetPtrDomainName(key.service_id(), key.domain_id());
163 return {std::move(domain), DnsType::kPTR, DnsClass::kANY};
164 }
165
HasValidDnsRecordAddress(const MdnsRecord & record)166 bool HasValidDnsRecordAddress(const MdnsRecord& record) {
167 return HasValidDnsRecordAddress(GetDomainName(record));
168 }
169
HasValidDnsRecordAddress(const DomainName & domain)170 bool HasValidDnsRecordAddress(const DomainName& domain) {
171 return InstanceKey::TryCreate(domain).is_value() &&
172 IsInstanceValid(domain.labels()[0]);
173 }
174
IsPtrRecord(const MdnsRecord & record)175 bool IsPtrRecord(const MdnsRecord& record) {
176 return record.dns_type() == DnsType::kPTR;
177 }
178
GetDnsRecords(const DnsSdInstance & instance)179 std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstance& instance) {
180 auto domain = GetInstanceDomainName(InstanceKey(instance));
181
182 return {CreatePtrRecord(instance, domain), CreateSrvRecord(instance, domain),
183 CreateTxtRecord(instance, domain)};
184 }
185
GetDnsRecords(const DnsSdInstanceEndpoint & endpoint)186 std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstanceEndpoint& endpoint) {
187 auto domain = GetInstanceDomainName(InstanceKey(endpoint));
188
189 std::vector<MdnsRecord> records =
190 GetDnsRecords(static_cast<DnsSdInstance>(endpoint));
191
192 std::vector<MdnsRecord> v4 = CreateARecords(endpoint, domain);
193 std::vector<MdnsRecord> v6 = CreateAAAARecords(endpoint, domain);
194
195 records.insert(records.end(), v4.begin(), v4.end());
196 records.insert(records.end(), v6.begin(), v6.end());
197
198 return records;
199 }
200
201 } // namespace discovery
202 } // namespace openscreen
203