• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2011 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ResolverController"
18 #define DBG 0
19 
20 #include <algorithm>
21 #include <cstdlib>
22 #include <map>
23 #include <mutex>
24 #include <set>
25 #include <string>
26 #include <thread>
27 #include <utility>
28 #include <vector>
29 #include <cutils/log.h>
30 #include <net/if.h>
31 #include <sys/socket.h>
32 #include <netdb.h>
33 
34 #include <arpa/inet.h>
35 // NOTE: <resolv_netid.h> is a private C library header that provides
36 //       declarations for _resolv_set_nameservers_for_net and
37 //       _resolv_flush_cache_for_net
38 #include <resolv_netid.h>
39 #include <resolv_params.h>
40 #include <resolv_stats.h>
41 
42 #include <android-base/strings.h>
43 #include <android/net/INetd.h>
44 
45 #include "DumpWriter.h"
46 #include "NetdConstants.h"
47 #include "ResolverController.h"
48 #include "ResolverStats.h"
49 #include "dns/DnsTlsTransport.h"
50 
51 namespace android {
52 namespace net {
53 
54 namespace {
55 
// One Private DNS (DNS-over-TLS) server: its socket address and the set of
// certificate fingerprints it is expected to present (may be empty).
struct PrivateDnsServer {
    // Intentionally not `explicit`: code in this file converts a bare
    // sockaddr_storage into a PrivateDnsServer implicitly for set/map lookups.
    PrivateDnsServer(const sockaddr_storage& addr) : ss(addr) {}

    const sockaddr_storage ss;

    // For now, the fingerprints are always SHA-256.  This is the only digest algorithm
    // that is mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
    std::set<std::vector<uint8_t>> fingerprints;
};

// Ordering used by std::set / std::map keys: sorts by address family, then by
// the raw IP address bytes.  Ports and fingerprints are ignored, so two
// entries with the same IP address compare equivalent.
bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
    const auto family = x.ss.ss_family;
    if (family != y.ss.ss_family) {
        return family < y.ss.ss_family;
    }
    // Same address family from here on; compare IP addresses.
    switch (family) {
        case AF_INET: {
            const auto& lhs = reinterpret_cast<const sockaddr_in&>(x.ss);
            const auto& rhs = reinterpret_cast<const sockaddr_in&>(y.ss);
            return lhs.sin_addr.s_addr < rhs.sin_addr.s_addr;
        }
        case AF_INET6: {
            const auto& lhs = reinterpret_cast<const sockaddr_in6&>(x.ss);
            const auto& rhs = reinterpret_cast<const sockaddr_in6&>(y.ss);
            return std::memcmp(lhs.sin6_addr.s6_addr, rhs.sin6_addr.s6_addr,
                               sizeof(lhs.sin6_addr.s6_addr)) < 0;
        }
        default:
            return false;  // Unknown address type.  This is an error.
    }
}
81 
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)82 bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
83     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
84     if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
85         // IPv4 parse succeeded, so it's IPv4
86         sin->sin_family = AF_INET;
87         sin->sin_port = htons(port);
88         return true;
89     }
90     sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
91     if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
92         // IPv6 parse succeeded, so it's IPv6.
93         sin6->sin6_family = AF_INET6;
94         sin6->sin6_port = htons(port);
95         return true;
96     }
97     if (DBG) {
98         ALOGW("Failed to parse server address: %s", server);
99     }
100     return false;
101 }
102 
103 // Structure for tracking the entire set of known Private DNS servers.
104 std::mutex privateDnsLock;
105 typedef std::set<PrivateDnsServer> PrivateDnsSet;
106 PrivateDnsSet privateDnsServers;
107 
108 // Structure for tracking the validation status of servers on a specific netid.
109 // Servers that fail validation are removed from the tracker, and can be retried.
110 enum class Validation : bool { in_process, success };
111 typedef std::map<PrivateDnsServer, Validation> PrivateDnsTracker;
112 std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
113 
parseServers(const char ** servers,int numservers,in_port_t port)114 PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
115     PrivateDnsSet set;
116     for (int i = 0; i < numservers; ++i) {
117         sockaddr_storage parsed;
118         if (parseServer(servers[i], port, &parsed)) {
119             set.insert(parsed);
120         }
121     }
122     return set;
123 }
124 
checkPrivateDnsProviders(const unsigned netId,const char ** servers,int numservers)125 void checkPrivateDnsProviders(const unsigned netId, const char** servers, int numservers) {
126     if (DBG) {
127         ALOGD("checkPrivateDnsProviders(%u)", netId);
128     }
129 
130     std::lock_guard<std::mutex> guard(privateDnsLock);
131     if (privateDnsServers.empty()) {
132         return;
133     }
134 
135     // First compute the intersection of the servers to check with the
136     // servers that are permitted to use DNS over TLS.  The intersection
137     // will contain the port number to be used for Private DNS.
138     PrivateDnsSet serversToCheck = parseServers(servers, numservers, 53);
139     PrivateDnsSet intersection;
140     std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
141         serversToCheck.begin(), serversToCheck.end(),
142         std::inserter(intersection, intersection.begin()));
143     if (intersection.empty()) {
144         return;
145     }
146 
147     auto netPair = privateDnsTransports.find(netId);
148     if (netPair == privateDnsTransports.end()) {
149         // New netId
150         bool added;
151         std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
152         if (!added) {
153             ALOGE("Memory error while checking private DNS for netId %d", netId);
154             return;
155         }
156     }
157 
158     auto& tracker = netPair->second;
159     for (const auto& privateServer : intersection) {
160         if (tracker.count(privateServer) != 0) {
161             continue;
162         }
163         tracker[privateServer] = Validation::in_process;
164         std::thread validate_thread([privateServer, netId] {
165             // validateDnsTlsServer() is a blocking call that performs network operations.
166             // It can take milliseconds to minutes, up to the SYN retry limit.
167             bool success = validateDnsTlsServer(netId,
168                     privateServer.ss, privateServer.fingerprints);
169             std::lock_guard<std::mutex> guard(privateDnsLock);
170             auto netPair = privateDnsTransports.find(netId);
171             if (netPair == privateDnsTransports.end()) {
172                 ALOGW("netId %u was erased during private DNS validation", netId);
173                 return;
174             }
175             auto& tracker = netPair->second;
176             if (privateDnsServers.count(privateServer) == 0) {
177                 ALOGW("Server was removed during private DNS validation");
178                 success = false;
179             }
180             if (success) {
181                 tracker[privateServer] = Validation::success;
182             } else {
183                 // Validation failure is expected if a user is on a captive portal.
184                 // TODO: Trigger a second validation attempt after captive portal login
185                 // succeeds.
186                 tracker.erase(privateServer);
187             }
188         });
189         validate_thread.detach();
190     }
191 }
192 
clearPrivateDnsProviders(unsigned netId)193 void clearPrivateDnsProviders(unsigned netId) {
194     if (DBG) {
195         ALOGD("clearPrivateDnsProviders(%u)", netId);
196     }
197     std::lock_guard<std::mutex> guard(privateDnsLock);
198     privateDnsTransports.erase(netId);
199 }
200 
201 }  // namespace
202 
setDnsServers(unsigned netId,const char * searchDomains,const char ** servers,int numservers,const __res_params * params)203 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
204         const char** servers, int numservers, const __res_params* params) {
205     if (DBG) {
206         ALOGD("setDnsServers netId = %u\n", netId);
207     }
208     checkPrivateDnsProviders(netId, servers, numservers);
209     return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
210 }
211 
shouldUseTls(unsigned netId,const sockaddr_storage & insecureServer,sockaddr_storage * secureServer,std::set<std::vector<uint8_t>> * fingerprints)212 bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
213         sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
214     // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
215     // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
216     // an atomic_int32_t counter of validated connections, and returning early if it's zero.
217     std::lock_guard<std::mutex> guard(privateDnsLock);
218     const auto netPair = privateDnsTransports.find(netId);
219     if (netPair == privateDnsTransports.end()) {
220         return false;
221     }
222     const auto& tracker = netPair->second;
223     const auto serverPair = tracker.find(insecureServer);
224     if (serverPair == tracker.end() || serverPair->second != Validation::success) {
225         return false;
226     }
227     const auto& validatedServer = serverPair->first;
228     *secureServer = validatedServer.ss;
229     *fingerprints = validatedServer.fingerprints;
230     return true;
231 }
232 
clearDnsServers(unsigned netId)233 int ResolverController::clearDnsServers(unsigned netId) {
234     _resolv_set_nameservers_for_net(netId, NULL, 0, "", NULL);
235     if (DBG) {
236         ALOGD("clearDnsServers netId = %u\n", netId);
237     }
238     clearPrivateDnsProviders(netId);
239     return 0;
240 }
241 
flushDnsCache(unsigned netId)242 int ResolverController::flushDnsCache(unsigned netId) {
243     if (DBG) {
244         ALOGD("flushDnsCache netId = %u\n", netId);
245     }
246 
247     _resolv_flush_cache_for_net(netId);
248 
249     return 0;
250 }
251 
getDnsInfo(unsigned netId,std::vector<std::string> * servers,std::vector<std::string> * domains,__res_params * params,std::vector<android::net::ResolverStats> * stats)252 int ResolverController::getDnsInfo(unsigned netId, std::vector<std::string>* servers,
253         std::vector<std::string>* domains, __res_params* params,
254         std::vector<android::net::ResolverStats>* stats) {
255     using android::net::ResolverStats;
256     using android::net::INetd;
257     static_assert(ResolverStats::STATS_SUCCESSES == INetd::RESOLVER_STATS_SUCCESSES &&
258             ResolverStats::STATS_ERRORS == INetd::RESOLVER_STATS_ERRORS &&
259             ResolverStats::STATS_TIMEOUTS == INetd::RESOLVER_STATS_TIMEOUTS &&
260             ResolverStats::STATS_INTERNAL_ERRORS == INetd::RESOLVER_STATS_INTERNAL_ERRORS &&
261             ResolverStats::STATS_RTT_AVG == INetd::RESOLVER_STATS_RTT_AVG &&
262             ResolverStats::STATS_LAST_SAMPLE_TIME == INetd::RESOLVER_STATS_LAST_SAMPLE_TIME &&
263             ResolverStats::STATS_USABLE == INetd::RESOLVER_STATS_USABLE &&
264             ResolverStats::STATS_COUNT == INetd::RESOLVER_STATS_COUNT,
265             "AIDL and ResolverStats.h out of sync");
266     int nscount = -1;
267     sockaddr_storage res_servers[MAXNS];
268     int dcount = -1;
269     char res_domains[MAXDNSRCH][MAXDNSRCHPATH];
270     __res_stats res_stats[MAXNS];
271     servers->clear();
272     domains->clear();
273     *params = __res_params{};
274     stats->clear();
275     int revision_id = android_net_res_stats_get_info_for_net(netId, &nscount, res_servers, &dcount,
276             res_domains, params, res_stats);
277 
278     // If the netId is unknown (which can happen for valid net IDs for which no DNS servers have
279     // yet been configured), there is no revision ID. In this case there is no data to return.
280     if (revision_id < 0) {
281         return 0;
282     }
283 
284     // Verify that the returned data is sane.
285     if (nscount < 0 || nscount > MAXNS || dcount < 0 || dcount > MAXDNSRCH) {
286         ALOGE("%s: nscount=%d, dcount=%d", __FUNCTION__, nscount, dcount);
287         return -ENOTRECOVERABLE;
288     }
289 
290     // Determine which servers are considered usable by the resolver.
291     bool valid_servers[MAXNS];
292     std::fill_n(valid_servers, MAXNS, false);
293     android_net_res_stats_get_usable_servers(params, res_stats, nscount, valid_servers);
294 
295     // Convert the server sockaddr structures to std::string.
296     stats->resize(nscount);
297     for (int i = 0 ; i < nscount ; ++i) {
298         char hbuf[NI_MAXHOST];
299         int rv = getnameinfo(reinterpret_cast<const sockaddr*>(&res_servers[i]),
300                 sizeof(res_servers[i]), hbuf, sizeof(hbuf), nullptr, 0, NI_NUMERICHOST);
301         std::string server_str;
302         if (rv == 0) {
303             server_str.assign(hbuf);
304         } else {
305             ALOGE("getnameinfo() failed for server #%d: %s", i, gai_strerror(rv));
306             server_str.assign("<invalid>");
307         }
308         servers->push_back(std::move(server_str));
309         android::net::ResolverStats& cur_stats = (*stats)[i];
310         android_net_res_stats_aggregate(&res_stats[i], &cur_stats.successes, &cur_stats.errors,
311                 &cur_stats.timeouts, &cur_stats.internal_errors, &cur_stats.rtt_avg,
312                 &cur_stats.last_sample_time);
313         cur_stats.usable = valid_servers[i];
314     }
315 
316     // Convert the stack-allocated search domain strings to std::string.
317     for (int i = 0 ; i < dcount ; ++i) {
318         domains->push_back(res_domains[i]);
319     }
320     return 0;
321 }
322 
setResolverConfiguration(int32_t netId,const std::vector<std::string> & servers,const std::vector<std::string> & domains,const std::vector<int32_t> & params)323 int ResolverController::setResolverConfiguration(int32_t netId,
324         const std::vector<std::string>& servers, const std::vector<std::string>& domains,
325         const std::vector<int32_t>& params) {
326     using android::net::INetd;
327     if (params.size() != INetd::RESOLVER_PARAMS_COUNT) {
328         ALOGE("%s: params.size()=%zu", __FUNCTION__, params.size());
329         return -EINVAL;
330     }
331 
332     auto server_count = std::min<size_t>(MAXNS, servers.size());
333     std::vector<const char*> server_ptrs;
334     for (size_t i = 0 ; i < server_count ; ++i) {
335         server_ptrs.push_back(servers[i].c_str());
336     }
337 
338     std::string domains_str;
339     if (!domains.empty()) {
340         domains_str = domains[0];
341         for (size_t i = 1 ; i < domains.size() ; ++i) {
342             domains_str += " " + domains[i];
343         }
344     }
345 
346     __res_params res_params;
347     res_params.sample_validity = params[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY];
348     res_params.success_threshold = params[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD];
349     res_params.min_samples = params[INetd::RESOLVER_PARAMS_MIN_SAMPLES];
350     res_params.max_samples = params[INetd::RESOLVER_PARAMS_MAX_SAMPLES];
351 
352     return setDnsServers(netId, domains_str.c_str(), server_ptrs.data(), server_ptrs.size(),
353             &res_params);
354 }
355 
getResolverInfo(int32_t netId,std::vector<std::string> * servers,std::vector<std::string> * domains,std::vector<int32_t> * params,std::vector<int32_t> * stats)356 int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* servers,
357         std::vector<std::string>* domains, std::vector<int32_t>* params,
358         std::vector<int32_t>* stats) {
359     using android::net::ResolverStats;
360     using android::net::INetd;
361     __res_params res_params;
362     std::vector<ResolverStats> res_stats;
363     int ret = getDnsInfo(netId, servers, domains, &res_params, &res_stats);
364     if (ret != 0) {
365         return ret;
366     }
367 
368     // Serialize the information for binder.
369     ResolverStats::encodeAll(res_stats, stats);
370 
371     params->resize(INetd::RESOLVER_PARAMS_COUNT);
372     (*params)[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY] = res_params.sample_validity;
373     (*params)[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD] = res_params.success_threshold;
374     (*params)[INetd::RESOLVER_PARAMS_MIN_SAMPLES] = res_params.min_samples;
375     (*params)[INetd::RESOLVER_PARAMS_MAX_SAMPLES] = res_params.max_samples;
376     return 0;
377 }
378 
dump(DumpWriter & dw,unsigned netId)379 void ResolverController::dump(DumpWriter& dw, unsigned netId) {
380     // No lock needed since Bionic's resolver locks all accessed data structures internally.
381     using android::net::ResolverStats;
382     std::vector<std::string> servers;
383     std::vector<std::string> domains;
384     __res_params params;
385     std::vector<ResolverStats> stats;
386     time_t now = time(nullptr);
387     int rv = getDnsInfo(netId, &servers, &domains, &params, &stats);
388     dw.incIndent();
389     if (rv != 0) {
390         dw.println("getDnsInfo() failed for netid %u", netId);
391     } else {
392         if (servers.empty()) {
393             dw.println("No DNS servers defined");
394         } else {
395             dw.println("DNS servers: # IP (total, successes, errors, timeouts, internal errors, "
396                     "RTT avg, last sample)");
397             dw.incIndent();
398             for (size_t i = 0 ; i < servers.size() ; ++i) {
399                 if (i < stats.size()) {
400                     const ResolverStats& s = stats[i];
401                     int total = s.successes + s.errors + s.timeouts + s.internal_errors;
402                     if (total > 0) {
403                         int time_delta = (s.last_sample_time > 0) ? now - s.last_sample_time : -1;
404                         dw.println("%s (%d, %d, %d, %d, %d, %dms, %ds)%s", servers[i].c_str(),
405                                 total, s.successes, s.errors, s.timeouts, s.internal_errors,
406                                 s.rtt_avg, time_delta, s.usable ? "" : " BROKEN");
407                     } else {
408                         dw.println("%s <no data>", servers[i].c_str());
409                     }
410                 } else {
411                     dw.println("%s <no stats>", servers[i].c_str());
412                 }
413             }
414             dw.decIndent();
415         }
416         if (domains.empty()) {
417             dw.println("No search domains defined");
418         } else {
419             std::string domains_str = android::base::Join(domains, ", ");
420             dw.println("search domains: %s", domains_str.c_str());
421         }
422         if (params.sample_validity != 0) {
423             dw.println("DNS parameters: sample validity = %us, success threshold = %u%%, "
424                     "samples (min, max) = (%u, %u)", params.sample_validity,
425                     static_cast<unsigned>(params.success_threshold),
426                     static_cast<unsigned>(params.min_samples),
427                     static_cast<unsigned>(params.max_samples));
428         }
429     }
430     dw.decIndent();
431 }
432 
addPrivateDnsServer(const std::string & server,int32_t port,const std::string & fingerprintAlgorithm,const std::set<std::vector<uint8_t>> & fingerprints)433 int ResolverController::addPrivateDnsServer(const std::string& server, int32_t port,
434         const std::string& fingerprintAlgorithm,
435         const std::set<std::vector<uint8_t>>& fingerprints) {
436     using android::net::INetd;
437     if (fingerprintAlgorithm.empty()) {
438         if (!fingerprints.empty()) {
439             return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
440         }
441     } else if (fingerprintAlgorithm.compare("SHA-256") == 0) {
442         if (fingerprints.empty()) {
443             return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
444         }
445         for (const auto& fingerprint : fingerprints) {
446             if (fingerprint.size() != SHA256_SIZE) {
447                 return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
448             }
449         }
450     } else {
451         return INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM;
452     }
453     if (port <= 0 || port > 0xFFFF) {
454         return INetd::PRIVATE_DNS_BAD_PORT;
455     }
456     sockaddr_storage parsed;
457     if (!parseServer(server.c_str(), port, &parsed)) {
458         return INetd::PRIVATE_DNS_BAD_ADDRESS;
459     }
460     PrivateDnsServer privateServer(parsed);
461     privateServer.fingerprints = fingerprints;
462     std::lock_guard<std::mutex> guard(privateDnsLock);
463     // Ensure we overwrite any previous matching server.  This is necessary because equality is
464     // based only on the IP address, not the port or fingerprints.
465     privateDnsServers.erase(privateServer);
466     privateDnsServers.insert(privateServer);
467     return INetd::PRIVATE_DNS_SUCCESS;
468 }
469 
removePrivateDnsServer(const std::string & server)470 int ResolverController::removePrivateDnsServer(const std::string& server) {
471     using android::net::INetd;
472     sockaddr_storage parsed;
473     if (!parseServer(server.c_str(), 0, &parsed)) {
474         return INetd::PRIVATE_DNS_BAD_ADDRESS;
475     }
476     std::lock_guard<std::mutex> guard(privateDnsLock);
477     privateDnsServers.erase(parsed);
478     for (auto& pair : privateDnsTransports) {
479         pair.second.erase(parsed);
480     }
481     return INetd::PRIVATE_DNS_SUCCESS;
482 }
483 
484 }  // namespace net
485 }  // namespace android
486