• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "PrivateDnsConfiguration"
18 #define DBG 0
19 
20 #include "PrivateDnsConfiguration.h"
21 
22 #include <log/log.h>
23 #include <netdb.h>
24 #include <sys/socket.h>
25 
26 #include "DnsTlsTransport.h"
27 #include "ResolverEventReporter.h"
28 #include "netd_resolv/resolv.h"
29 #include "netdutils/BackoffSequence.h"
30 
31 namespace android {
32 namespace net {
33 
// Formats the numeric host portion of |addr| (no port, no reverse lookup), for example
// "192.0.2.1" or "2001:db8::1". If getnameinfo() fails, the zero-filled buffer is returned as an
// empty string.
std::string addrToString(const sockaddr_storage* addr) {
    char host[INET6_ADDRSTRLEN] = {};
    const auto* generic = reinterpret_cast<const sockaddr*>(addr);
    getnameinfo(generic, sizeof(*addr), host, sizeof(host), nullptr, 0, NI_NUMERICHOST);
    return host;
}
40 
parseServer(const char * server,sockaddr_storage * parsed)41 bool parseServer(const char* server, sockaddr_storage* parsed) {
42     addrinfo hints = {.ai_family = AF_UNSPEC, .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV};
43     addrinfo* res;
44 
45     int err = getaddrinfo(server, "853", &hints, &res);
46     if (err != 0) {
47         ALOGW("Failed to parse server address (%s): %s", server, gai_strerror(err));
48         return false;
49     }
50 
51     memcpy(parsed, res->ai_addr, res->ai_addrlen);
52     freeaddrinfo(res);
53     return true;
54 }
55 
set(int32_t netId,uint32_t mark,const std::vector<std::string> & servers,const std::string & name,const std::set<std::vector<uint8_t>> & fingerprints)56 int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
57                                  const std::vector<std::string>& servers, const std::string& name,
58                                  const std::set<std::vector<uint8_t>>& fingerprints) {
59     if (DBG) {
60         ALOGD("PrivateDnsConfiguration::set(%u, 0x%x, %zu, %s, %zu)", netId, mark, servers.size(),
61               name.c_str(), fingerprints.size());
62     }
63 
64     const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
65 
66     // Parse the list of servers that has been passed in
67     std::set<DnsTlsServer> tlsServers;
68     for (size_t i = 0; i < servers.size(); ++i) {
69         sockaddr_storage parsed;
70         if (!parseServer(servers[i].c_str(), &parsed)) {
71             return -EINVAL;
72         }
73         DnsTlsServer server(parsed);
74         server.name = name;
75         server.fingerprints = fingerprints;
76         tlsServers.insert(server);
77     }
78 
79     std::lock_guard guard(mPrivateDnsLock);
80     if (explicitlyConfigured) {
81         mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
82     } else if (!tlsServers.empty()) {
83         mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
84     } else {
85         mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
86         mPrivateDnsTransports.erase(netId);
87         return 0;
88     }
89 
90     // Create the tracker if it was not present
91     auto netPair = mPrivateDnsTransports.find(netId);
92     if (netPair == mPrivateDnsTransports.end()) {
93         // No TLS tracker yet for this netId.
94         bool added;
95         std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
96         if (!added) {
97             ALOGE("Memory error while recording private DNS for netId %d", netId);
98             return -ENOMEM;
99         }
100     }
101     auto& tracker = netPair->second;
102 
103     // Remove any servers from the tracker that are not in |servers| exactly.
104     for (auto it = tracker.begin(); it != tracker.end();) {
105         if (tlsServers.count(it->first) == 0) {
106             it = tracker.erase(it);
107         } else {
108             ++it;
109         }
110     }
111 
112     // Add any new or changed servers to the tracker, and initiate async checks for them.
113     for (const auto& server : tlsServers) {
114         if (needsValidation(tracker, server)) {
115             validatePrivateDnsProvider(server, tracker, netId, mark);
116         }
117     }
118     return 0;
119 }
120 
getStatus(unsigned netId)121 PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
122     PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
123     std::lock_guard guard(mPrivateDnsLock);
124 
125     const auto mode = mPrivateDnsModes.find(netId);
126     if (mode == mPrivateDnsModes.end()) return status;
127     status.mode = mode->second;
128 
129     const auto netPair = mPrivateDnsTransports.find(netId);
130     if (netPair != mPrivateDnsTransports.end()) {
131         for (const auto& serverPair : netPair->second) {
132             if (serverPair.second == Validation::success) {
133                 status.validatedServers.push_back(serverPair.first);
134             }
135         }
136     }
137 
138     return status;
139 }
140 
getStatus(unsigned netId,ExternalPrivateDnsStatus * status)141 void PrivateDnsConfiguration::getStatus(unsigned netId, ExternalPrivateDnsStatus* status) {
142     std::lock_guard guard(mPrivateDnsLock);
143 
144     const auto mode = mPrivateDnsModes.find(netId);
145     if (mode == mPrivateDnsModes.end()) return;
146     status->mode = mode->second;
147 
148     const auto netPair = mPrivateDnsTransports.find(netId);
149     if (netPair != mPrivateDnsTransports.end()) {
150         int count = 0;
151         for (const auto& serverPair : netPair->second) {
152             status->serverStatus[count].ss = serverPair.first.ss;
153             status->serverStatus[count].hostname =
154                     serverPair.first.name.empty() ? "" : serverPair.first.name.c_str();
155             status->serverStatus[count].validation = serverPair.second;
156             count++;
157             if (count >= MAXNS) break;  // Lose the rest
158         }
159         status->numServers = count;
160     }
161 }
162 
clear(unsigned netId)163 void PrivateDnsConfiguration::clear(unsigned netId) {
164     if (DBG) {
165         ALOGD("PrivateDnsConfiguration::clear(%u)", netId);
166     }
167     std::lock_guard guard(mPrivateDnsLock);
168     mPrivateDnsModes.erase(netId);
169     mPrivateDnsTransports.erase(netId);
170 }
171 
validatePrivateDnsProvider(const DnsTlsServer & server,PrivateDnsTracker & tracker,unsigned netId,uint32_t mark)172 void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
173                                                          PrivateDnsTracker& tracker, unsigned netId,
174                                                          uint32_t mark) REQUIRES(mPrivateDnsLock) {
175     if (DBG) {
176         ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&server.ss).c_str(), netId);
177     }
178 
179     tracker[server] = Validation::in_process;
180     if (DBG) {
181         ALOGD("Server %s marked as in_process.  Tracker now has size %zu",
182               addrToString(&server.ss).c_str(), tracker.size());
183     }
184     // Note that capturing |server| and |netId| in this lambda create copies.
185     std::thread validate_thread([this, server, netId, mark] {
186         // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
187         //
188         // Start with a 1 minute delay and backoff to once per hour.
189         //
190         // Assumptions:
191         //     [1] Each TLS validation is ~10KB of certs+handshake+payload.
192         //     [2] Network typically provision clients with <=4 nameservers.
193         //     [3] Average month has 30 days.
194         //
195         // Each validation pass in a given hour is ~1.2MB of data. And 24
196         // such validation passes per day is about ~30MB per month, in the
197         // worst case. Otherwise, this will cost ~600 SYNs per month
198         // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
199         auto backoff = netdutils::BackoffSequence<>::Builder()
200                                .withInitialRetransmissionTime(std::chrono::seconds(60))
201                                .withMaximumRetransmissionTime(std::chrono::seconds(3600))
202                                .build();
203 
204         while (true) {
205             // ::validate() is a blocking call that performs network operations.
206             // It can take milliseconds to minutes, up to the SYN retry limit.
207             const bool success = DnsTlsTransport::validate(server, netId, mark);
208             if (DBG) {
209                 ALOGD("validateDnsTlsServer returned %d for %s", success,
210                       addrToString(&server.ss).c_str());
211             }
212 
213             const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
214             if (!needs_reeval) {
215                 break;
216             }
217 
218             if (backoff.hasNextTimeout()) {
219                 std::this_thread::sleep_for(backoff.getNextTimeout());
220             } else {
221                 break;
222             }
223         }
224     });
225     validate_thread.detach();
226 }
227 
recordPrivateDnsValidation(const DnsTlsServer & server,unsigned netId,bool success)228 bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
229                                                          bool success) {
230     constexpr bool NEEDS_REEVALUATION = true;
231     constexpr bool DONT_REEVALUATE = false;
232 
233     std::lock_guard guard(mPrivateDnsLock);
234 
235     auto netPair = mPrivateDnsTransports.find(netId);
236     if (netPair == mPrivateDnsTransports.end()) {
237         ALOGW("netId %u was erased during private DNS validation", netId);
238         return DONT_REEVALUATE;
239     }
240 
241     const auto mode = mPrivateDnsModes.find(netId);
242     if (mode == mPrivateDnsModes.end()) {
243         ALOGW("netId %u has no private DNS validation mode", netId);
244         return DONT_REEVALUATE;
245     }
246     const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
247 
248     bool reevaluationStatus =
249             (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
250 
251     auto& tracker = netPair->second;
252     auto serverPair = tracker.find(server);
253     if (serverPair == tracker.end()) {
254         ALOGW("Server %s was removed during private DNS validation",
255               addrToString(&server.ss).c_str());
256         success = false;
257         reevaluationStatus = DONT_REEVALUATE;
258     } else if (!(serverPair->first == server)) {
259         // TODO: It doesn't seem correct to overwrite the tracker entry for
260         // |server| down below in this circumstance... Fix this.
261         ALOGW("Server %s was changed during private DNS validation",
262               addrToString(&server.ss).c_str());
263         success = false;
264         reevaluationStatus = DONT_REEVALUATE;
265     }
266 
267     // Send a validation event to NetdEventListenerService.
268     const auto& listeners = ResolverEventReporter::getInstance().getListeners();
269     if (listeners.size() != 0) {
270         for (const auto& it : listeners) {
271             it->onPrivateDnsValidationEvent(netId, addrToString(&server.ss), server.name, success);
272         }
273         if (DBG) {
274             ALOGD("Sent validation %s event on netId %u for %s with hostname %s",
275                   success ? "success" : "failure", netId, addrToString(&server.ss).c_str(),
276                   server.name.c_str());
277         }
278     } else {
279         ALOGE("Validation event not sent since no INetdEventListener receiver is available.");
280     }
281 
282     if (success) {
283         tracker[server] = Validation::success;
284         if (DBG) {
285             ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
286                   addrToString(&server.ss).c_str(), tracker.size());
287         }
288     } else {
289         // Validation failure is expected if a user is on a captive portal.
290         // TODO: Trigger a second validation attempt after captive portal login
291         // succeeds.
292         tracker[server] = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
293                                                                      : Validation::fail;
294         if (DBG) {
295             ALOGD("Validation failed for %s!", addrToString(&server.ss).c_str());
296         }
297     }
298 
299     return reevaluationStatus;
300 }
301 
302 // Start validation for newly added servers as well as any servers that have
303 // landed in Validation::fail state. Note that servers that have failed
304 // multiple validation attempts but for which there is still a validating
305 // thread running are marked as being Validation::in_process.
needsValidation(const PrivateDnsTracker & tracker,const DnsTlsServer & server)306 bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
307                                               const DnsTlsServer& server) {
308     const auto& iter = tracker.find(server);
309     return (iter == tracker.end()) || (iter->second == Validation::fail);
310 }
311 
312 PrivateDnsConfiguration gPrivateDnsConfiguration;
313 
314 }  // namespace net
315 }  // namespace android
316