1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "PrivateDnsConfiguration"
18 #define DBG 0
19
20 #include "PrivateDnsConfiguration.h"
21
22 #include <log/log.h>
23 #include <netdb.h>
24 #include <sys/socket.h>
25
26 #include "DnsTlsTransport.h"
27 #include "ResolverEventReporter.h"
28 #include "netd_resolv/resolv.h"
29 #include "netdutils/BackoffSequence.h"
30
31 namespace android {
32 namespace net {
33
// Formats the IP address held in |addr| as a numeric host string (no DNS lookup is made,
// because NI_NUMERICHOST is requested).
//
// Returns the textual address, e.g. "192.0.2.1" or "2001:db8::1". Returns an empty string
// when |addr| is null or when getnameinfo() reports an error.
std::string addrToString(const sockaddr_storage* addr) {
    if (addr == nullptr) return {};

    // Hand getnameinfo() the length of the concrete address type rather than the size of the
    // generic storage: some implementations reject a length that does not match the family.
    // Unknown families keep the generic size and are left for getnameinfo() to judge.
    socklen_t addrLen = sizeof(sockaddr_storage);
    switch (addr->ss_family) {
        case AF_INET:
            addrLen = sizeof(sockaddr_in);
            break;
        case AF_INET6:
            addrLen = sizeof(sockaddr_in6);
            break;
        default:
            break;
    }

    char out[INET6_ADDRSTRLEN] = {0};
    const int err = getnameinfo(reinterpret_cast<const sockaddr*>(addr), addrLen, out, sizeof(out),
                                nullptr, 0, NI_NUMERICHOST);
    if (err != 0) {
        // Do not trust the buffer contents after a failed conversion.
        return {};
    }
    return std::string(out);
}
40
parseServer(const char * server,sockaddr_storage * parsed)41 bool parseServer(const char* server, sockaddr_storage* parsed) {
42 addrinfo hints = {.ai_family = AF_UNSPEC, .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV};
43 addrinfo* res;
44
45 int err = getaddrinfo(server, "853", &hints, &res);
46 if (err != 0) {
47 ALOGW("Failed to parse server address (%s): %s", server, gai_strerror(err));
48 return false;
49 }
50
51 memcpy(parsed, res->ai_addr, res->ai_addrlen);
52 freeaddrinfo(res);
53 return true;
54 }
55
set(int32_t netId,uint32_t mark,const std::vector<std::string> & servers,const std::string & name,const std::set<std::vector<uint8_t>> & fingerprints)56 int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
57 const std::vector<std::string>& servers, const std::string& name,
58 const std::set<std::vector<uint8_t>>& fingerprints) {
59 if (DBG) {
60 ALOGD("PrivateDnsConfiguration::set(%u, 0x%x, %zu, %s, %zu)", netId, mark, servers.size(),
61 name.c_str(), fingerprints.size());
62 }
63
64 const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
65
66 // Parse the list of servers that has been passed in
67 std::set<DnsTlsServer> tlsServers;
68 for (size_t i = 0; i < servers.size(); ++i) {
69 sockaddr_storage parsed;
70 if (!parseServer(servers[i].c_str(), &parsed)) {
71 return -EINVAL;
72 }
73 DnsTlsServer server(parsed);
74 server.name = name;
75 server.fingerprints = fingerprints;
76 tlsServers.insert(server);
77 }
78
79 std::lock_guard guard(mPrivateDnsLock);
80 if (explicitlyConfigured) {
81 mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
82 } else if (!tlsServers.empty()) {
83 mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
84 } else {
85 mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
86 mPrivateDnsTransports.erase(netId);
87 return 0;
88 }
89
90 // Create the tracker if it was not present
91 auto netPair = mPrivateDnsTransports.find(netId);
92 if (netPair == mPrivateDnsTransports.end()) {
93 // No TLS tracker yet for this netId.
94 bool added;
95 std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
96 if (!added) {
97 ALOGE("Memory error while recording private DNS for netId %d", netId);
98 return -ENOMEM;
99 }
100 }
101 auto& tracker = netPair->second;
102
103 // Remove any servers from the tracker that are not in |servers| exactly.
104 for (auto it = tracker.begin(); it != tracker.end();) {
105 if (tlsServers.count(it->first) == 0) {
106 it = tracker.erase(it);
107 } else {
108 ++it;
109 }
110 }
111
112 // Add any new or changed servers to the tracker, and initiate async checks for them.
113 for (const auto& server : tlsServers) {
114 if (needsValidation(tracker, server)) {
115 validatePrivateDnsProvider(server, tracker, netId, mark);
116 }
117 }
118 return 0;
119 }
120
getStatus(unsigned netId)121 PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
122 PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
123 std::lock_guard guard(mPrivateDnsLock);
124
125 const auto mode = mPrivateDnsModes.find(netId);
126 if (mode == mPrivateDnsModes.end()) return status;
127 status.mode = mode->second;
128
129 const auto netPair = mPrivateDnsTransports.find(netId);
130 if (netPair != mPrivateDnsTransports.end()) {
131 for (const auto& serverPair : netPair->second) {
132 if (serverPair.second == Validation::success) {
133 status.validatedServers.push_back(serverPair.first);
134 }
135 }
136 }
137
138 return status;
139 }
140
getStatus(unsigned netId,ExternalPrivateDnsStatus * status)141 void PrivateDnsConfiguration::getStatus(unsigned netId, ExternalPrivateDnsStatus* status) {
142 std::lock_guard guard(mPrivateDnsLock);
143
144 const auto mode = mPrivateDnsModes.find(netId);
145 if (mode == mPrivateDnsModes.end()) return;
146 status->mode = mode->second;
147
148 const auto netPair = mPrivateDnsTransports.find(netId);
149 if (netPair != mPrivateDnsTransports.end()) {
150 int count = 0;
151 for (const auto& serverPair : netPair->second) {
152 status->serverStatus[count].ss = serverPair.first.ss;
153 status->serverStatus[count].hostname =
154 serverPair.first.name.empty() ? "" : serverPair.first.name.c_str();
155 status->serverStatus[count].validation = serverPair.second;
156 count++;
157 if (count >= MAXNS) break; // Lose the rest
158 }
159 status->numServers = count;
160 }
161 }
162
clear(unsigned netId)163 void PrivateDnsConfiguration::clear(unsigned netId) {
164 if (DBG) {
165 ALOGD("PrivateDnsConfiguration::clear(%u)", netId);
166 }
167 std::lock_guard guard(mPrivateDnsLock);
168 mPrivateDnsModes.erase(netId);
169 mPrivateDnsTransports.erase(netId);
170 }
171
validatePrivateDnsProvider(const DnsTlsServer & server,PrivateDnsTracker & tracker,unsigned netId,uint32_t mark)172 void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
173 PrivateDnsTracker& tracker, unsigned netId,
174 uint32_t mark) REQUIRES(mPrivateDnsLock) {
175 if (DBG) {
176 ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&server.ss).c_str(), netId);
177 }
178
179 tracker[server] = Validation::in_process;
180 if (DBG) {
181 ALOGD("Server %s marked as in_process. Tracker now has size %zu",
182 addrToString(&server.ss).c_str(), tracker.size());
183 }
184 // Note that capturing |server| and |netId| in this lambda create copies.
185 std::thread validate_thread([this, server, netId, mark] {
186 // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
187 //
188 // Start with a 1 minute delay and backoff to once per hour.
189 //
190 // Assumptions:
191 // [1] Each TLS validation is ~10KB of certs+handshake+payload.
192 // [2] Network typically provision clients with <=4 nameservers.
193 // [3] Average month has 30 days.
194 //
195 // Each validation pass in a given hour is ~1.2MB of data. And 24
196 // such validation passes per day is about ~30MB per month, in the
197 // worst case. Otherwise, this will cost ~600 SYNs per month
198 // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
199 auto backoff = netdutils::BackoffSequence<>::Builder()
200 .withInitialRetransmissionTime(std::chrono::seconds(60))
201 .withMaximumRetransmissionTime(std::chrono::seconds(3600))
202 .build();
203
204 while (true) {
205 // ::validate() is a blocking call that performs network operations.
206 // It can take milliseconds to minutes, up to the SYN retry limit.
207 const bool success = DnsTlsTransport::validate(server, netId, mark);
208 if (DBG) {
209 ALOGD("validateDnsTlsServer returned %d for %s", success,
210 addrToString(&server.ss).c_str());
211 }
212
213 const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
214 if (!needs_reeval) {
215 break;
216 }
217
218 if (backoff.hasNextTimeout()) {
219 std::this_thread::sleep_for(backoff.getNextTimeout());
220 } else {
221 break;
222 }
223 }
224 });
225 validate_thread.detach();
226 }
227
recordPrivateDnsValidation(const DnsTlsServer & server,unsigned netId,bool success)228 bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
229 bool success) {
230 constexpr bool NEEDS_REEVALUATION = true;
231 constexpr bool DONT_REEVALUATE = false;
232
233 std::lock_guard guard(mPrivateDnsLock);
234
235 auto netPair = mPrivateDnsTransports.find(netId);
236 if (netPair == mPrivateDnsTransports.end()) {
237 ALOGW("netId %u was erased during private DNS validation", netId);
238 return DONT_REEVALUATE;
239 }
240
241 const auto mode = mPrivateDnsModes.find(netId);
242 if (mode == mPrivateDnsModes.end()) {
243 ALOGW("netId %u has no private DNS validation mode", netId);
244 return DONT_REEVALUATE;
245 }
246 const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
247
248 bool reevaluationStatus =
249 (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
250
251 auto& tracker = netPair->second;
252 auto serverPair = tracker.find(server);
253 if (serverPair == tracker.end()) {
254 ALOGW("Server %s was removed during private DNS validation",
255 addrToString(&server.ss).c_str());
256 success = false;
257 reevaluationStatus = DONT_REEVALUATE;
258 } else if (!(serverPair->first == server)) {
259 // TODO: It doesn't seem correct to overwrite the tracker entry for
260 // |server| down below in this circumstance... Fix this.
261 ALOGW("Server %s was changed during private DNS validation",
262 addrToString(&server.ss).c_str());
263 success = false;
264 reevaluationStatus = DONT_REEVALUATE;
265 }
266
267 // Send a validation event to NetdEventListenerService.
268 const auto& listeners = ResolverEventReporter::getInstance().getListeners();
269 if (listeners.size() != 0) {
270 for (const auto& it : listeners) {
271 it->onPrivateDnsValidationEvent(netId, addrToString(&server.ss), server.name, success);
272 }
273 if (DBG) {
274 ALOGD("Sent validation %s event on netId %u for %s with hostname %s",
275 success ? "success" : "failure", netId, addrToString(&server.ss).c_str(),
276 server.name.c_str());
277 }
278 } else {
279 ALOGE("Validation event not sent since no INetdEventListener receiver is available.");
280 }
281
282 if (success) {
283 tracker[server] = Validation::success;
284 if (DBG) {
285 ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
286 addrToString(&server.ss).c_str(), tracker.size());
287 }
288 } else {
289 // Validation failure is expected if a user is on a captive portal.
290 // TODO: Trigger a second validation attempt after captive portal login
291 // succeeds.
292 tracker[server] = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
293 : Validation::fail;
294 if (DBG) {
295 ALOGD("Validation failed for %s!", addrToString(&server.ss).c_str());
296 }
297 }
298
299 return reevaluationStatus;
300 }
301
302 // Start validation for newly added servers as well as any servers that have
303 // landed in Validation::fail state. Note that servers that have failed
304 // multiple validation attempts but for which there is still a validating
305 // thread running are marked as being Validation::in_process.
needsValidation(const PrivateDnsTracker & tracker,const DnsTlsServer & server)306 bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
307 const DnsTlsServer& server) {
308 const auto& iter = tracker.find(server);
309 return (iter == tracker.end()) || (iter->second == Validation::fail);
310 }
311
// Global PrivateDnsConfiguration instance, defined at android::net namespace scope.
PrivateDnsConfiguration gPrivateDnsConfiguration;
313
314 } // namespace net
315 } // namespace android
316