1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include "DnsTlsTransport.h"
20
21 #include <span>
22
23 #include <android-base/format.h>
24 #include <android-base/logging.h>
25 #include <android-base/result.h>
26 #include <android-base/stringprintf.h>
27 #include <arpa/inet.h>
28 #include <arpa/nameser.h>
29 #include <netdutils/Stopwatch.h>
30 #include <netdutils/ThreadUtil.h>
31 #include <private/android_filesystem_config.h> // AID_DNS
32 #include <sys/poll.h>
33
34 #include "DnsTlsSocketFactory.h"
35 #include "Experiments.h"
36 #include "IDnsTlsSocketFactory.h"
37 #include "resolv_private.h"
38 #include "util.h"
39
40 using android::base::StringPrintf;
41 using android::netdutils::setThreadName;
42
43 namespace android {
44 namespace net {
45
46 namespace {
47
48 // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com".
makeDnsQuery()49 std::vector<uint8_t> makeDnsQuery() {
50 static const char kDnsSafeChars[] =
51 "abcdefhijklmnopqrstuvwxyz"
52 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
53 "0123456789";
54 const auto c = [](uint8_t rnd) -> uint8_t {
55 return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
56 };
57 uint8_t rnd[8];
58 arc4random_buf(rnd, std::size(rnd));
59
60 return std::vector<uint8_t>{
61 rnd[6], rnd[7], // [0-1] query ID
62 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
63 0, 1, // [4-5] QDCOUNT (number of queries)
64 0, 0, // [6-7] ANCOUNT (number of answers)
65 0, 0, // [8-9] NSCOUNT (number of name server records)
66 0, 0, // [10-11] ARCOUNT (number of additional records)
67 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
68 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm',
69 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a',
70 't', 'i', 'c', 3, 'c', 'o', 'm',
71 0, // null terminator of FQDN (root TLD)
72 0, ns_t_aaaa, // QTYPE
73 0, ns_c_in // QCLASS
74 };
75 }
76
checkDnsResponse(const std::span<const uint8_t> answer)77 base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
78 if (answer.size() < NS_HFIXEDSZ) {
79 return Errorf("short response: {}", answer.size());
80 }
81
82 const int qdcount = (answer[4] << 8) | answer[5];
83 if (qdcount != 1) {
84 return Errorf("reply query count != 1: {}", qdcount);
85 }
86
87 const int ancount = (answer[6] << 8) | answer[7];
88 LOG(DEBUG) << "answer count: " << ancount;
89
90 // TODO: Further validate the response contents (check for valid AAAA record, ...).
91 // Note that currently, integration tests rely on this function accepting a
92 // response with zero records.
93
94 return {};
95 }
96
97 // Sends |query| to the given server, and returns the DNS response.
sendUdpQuery(netdutils::IPAddress ip,uint32_t mark,std::span<const uint8_t> query)98 base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
99 std::span<const uint8_t> query) {
100 const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
101 const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
102 const int nsaplen = sockaddrSize(nsap);
103 const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
104 android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
105 if (fd < 0) {
106 return ErrnoErrorf("socket failed");
107 }
108
109 resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
110 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
111 return ErrnoErrorf("setsockopt failed");
112 }
113
114 if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) {
115 return ErrnoErrorf("connect failed");
116 }
117
118 if (send(fd, query.data(), query.size(), 0) != query.size()) {
119 return ErrnoErrorf("send failed");
120 }
121
122 const int timeoutMs = 3000;
123 while (true) {
124 pollfd fds = {.fd = fd, .events = POLLIN};
125
126 const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
127 if (n == 0) {
128 return Errorf("poll timed out");
129 }
130 if (n < 0) {
131 return ErrnoErrorf("poll failed");
132 }
133 if (fds.revents & (POLLIN | POLLERR)) {
134 std::vector<uint8_t> buf(MAXPACKET);
135 const int resplen = recv(fd, buf.data(), buf.size(), 0);
136
137 if (resplen < 0) {
138 return ErrnoErrorf("recvfrom failed");
139 }
140
141 buf.resize(resplen);
142 if (auto result = checkDnsResponse(buf); !result.ok()) {
143 return Errorf("checkDnsResponse failed: {}", result.error().message());
144 }
145
146 return {};
147 }
148 }
149 }
150
151 } // namespace
152
query(const netdutils::Slice query)153 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
154 std::lock_guard guard(mLock);
155
156 auto record = mQueries.recordQuery(query);
157 if (!record) {
158 return std::async(std::launch::deferred, []{
159 return (Result) { .code = Response::internal_error };
160 });
161 }
162
163 if (!mSocket) {
164 LOG(DEBUG) << "No socket for query. Opening socket and sending.";
165 doConnect();
166 } else {
167 sendQuery(record->query);
168 }
169
170 return std::move(record->result);
171 }
172
getConnectCounter() const173 int DnsTlsTransport::getConnectCounter() const {
174 std::lock_guard guard(mLock);
175 return mConnectCounter;
176 }
177
sendQuery(const DnsTlsQueryMap::Query & q)178 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
179 // Strip off the ID number and send the new ID instead.
180 const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2));
181 if (sent) {
182 mQueries.markTried(q.newId);
183 }
184 return sent;
185 }
186
doConnect()187 void DnsTlsTransport::doConnect() {
188 LOG(DEBUG) << "Constructing new socket";
189 mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
190
191 bool success = true;
192 if (mSocket.get() == nullptr || !mSocket->startHandshake()) {
193 success = false;
194 }
195 mConnectCounter++;
196
197 if (success) {
198 auto queries = mQueries.getAll();
199 LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries.";
200 for(auto& q : queries) {
201 if (!sendQuery(q)) {
202 break;
203 }
204 }
205 } else {
206 LOG(DEBUG) << "Initialization failed.";
207 mSocket.reset();
208 LOG(DEBUG) << "Failing all pending queries.";
209 mQueries.clear();
210 }
211 }
212
onResponse(std::vector<uint8_t> response)213 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
214 mQueries.onResponse(std::move(response));
215 }
216
onClosed()217 void DnsTlsTransport::onClosed() {
218 std::lock_guard guard(mLock);
219 if (mClosing) {
220 return;
221 }
222 // Move remaining operations to a new thread.
223 // This is necessary because
224 // 1. onClosed is currently running on a thread that blocks mSocket's destructor
225 // 2. doReconnect will call that destructor
226 if (mReconnectThread) {
227 // Complete cleanup of a previous reconnect thread, if present.
228 mReconnectThread->join();
229 // Joining a thread that is trying to acquire mLock, while holding mLock,
230 // looks like it risks a deadlock. However, a deadlock will not occur because
231 // once onClosed is called, it cannot be called again until after doReconnect
232 // acquires mLock.
233 }
234 mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
235 }
236
doReconnect()237 void DnsTlsTransport::doReconnect() {
238 std::lock_guard guard(mLock);
239 setThreadName(StringPrintf("TlsReconn_%u", mMark & 0xffff).c_str());
240 if (mClosing) {
241 return;
242 }
243 mQueries.cleanup();
244 if (!mQueries.empty()) {
245 LOG(DEBUG) << "Fast reconnect to retry remaining queries";
246 doConnect();
247 } else {
248 LOG(DEBUG) << "No pending queries. Going idle.";
249 mSocket.reset();
250 }
251 }
252
~DnsTlsTransport()253 DnsTlsTransport::~DnsTlsTransport() {
254 LOG(DEBUG) << "Destructor";
255 {
256 std::lock_guard guard(mLock);
257 LOG(DEBUG) << "Locked destruction procedure";
258 mQueries.clear();
259 mClosing = true;
260 }
261 // It's possible that a reconnect thread was spawned and waiting for mLock.
262 // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
263 // but we need to wait for it to finish before allowing destruction to proceed.
264 if (mReconnectThread) {
265 LOG(DEBUG) << "Waiting for reconnect thread to terminate";
266 mReconnectThread->join();
267 mReconnectThread.reset();
268 }
269 // Ensure that the socket is destroyed, and can clean up its callback threads,
270 // before any of this object's fields become invalid.
271 mSocket.reset();
272 LOG(DEBUG) << "Destructor completed";
273 }
274
275 // static
276 // TODO: Use this function to preheat the session cache.
277 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,uint32_t mark)278 bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
279 LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
280
281 const std::vector<uint8_t> query = makeDnsQuery();
282 DnsTlsSocketFactory factory;
283 DnsTlsTransport transport(server, mark, &factory);
284
285 // Send the initial query to warm up the connection.
286 auto r = transport.query(netdutils::makeSlice(query)).get();
287 if (r.code != Response::success) {
288 LOG(WARNING) << "query failed";
289 return false;
290 }
291
292 if (auto result = checkDnsResponse(r.response); !result.ok()) {
293 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
294 return false;
295 }
296
297 // If this validation is not for opportunistic mode, or the flags are not properly set,
298 // the validation is done. If not, the validation will compare DoT probe latency and
299 // UDP probe latency, and it will pass if:
300 // dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs
301 //
302 // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
303 // DoT probe latency must less than 25 ms.
304 const bool versionHigherThanAndroidR = getApiLevel() >= 30;
305 int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor",
306 (versionHigherThanAndroidR ? 3 : -1));
307 int latencyOffsetMs = Experiments::getInstance()->getFlag(
308 "dot_validation_latency_offset_ms", (versionHigherThanAndroidR ? 100 : -1));
309 const bool shouldCompareUdpLatency =
310 server.name.empty() &&
311 (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
312 if (!shouldCompareUdpLatency) {
313 return true;
314 }
315
316 LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
317 latencyOffsetMs);
318
319 int64_t udpProbeTimeUs = 0;
320 bool udpProbeGotAnswer = false;
321 std::thread udpProbeThread([&] {
322 // Can issue another probe if the first one fails or is lost.
323 for (int i = 1; i < 3; i++) {
324 netdutils::Stopwatch stopwatch;
325 auto result = sendUdpQuery(server.addr().ip(), mark, query);
326 udpProbeTimeUs = stopwatch.timeTakenUs();
327 udpProbeGotAnswer = result.ok();
328 LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
329 (udpProbeGotAnswer ? "succeeded" : "failed"),
330 udpProbeTimeUs / 1000.0);
331
332 if (udpProbeGotAnswer) {
333 break;
334 }
335 LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
336 }
337 });
338
339 int64_t dotProbeTimeUs = 0;
340 bool dotProbeGotAnswer = false;
341 std::thread dotProbeThread([&] {
342 netdutils::Stopwatch stopwatch;
343 auto r = transport.query(netdutils::makeSlice(query)).get();
344 dotProbeTimeUs = stopwatch.timeTakenUs();
345
346 if (r.code != Response::success) {
347 LOG(WARNING) << "query failed";
348 } else {
349 if (auto result = checkDnsResponse(r.response); !result.ok()) {
350 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
351 } else {
352 dotProbeGotAnswer = true;
353 }
354 }
355
356 LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
357 (dotProbeGotAnswer ? "succeeded" : "failed"),
358 dotProbeTimeUs / 1000.0);
359 });
360
361 // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
362 // actively cancel UDP probe thread.
363 dotProbeThread.join();
364 udpProbeThread.join();
365
366 if (!dotProbeGotAnswer) return false;
367 if (!udpProbeGotAnswer) return true;
368 return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
369 }
370
371 } // end of namespace net
372 } // end of namespace android
373