1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns/DnsTlsTransport.h"
18
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <stdlib.h>
25
26 #define LOG_TAG "DnsTlsTransport"
27 #define DBG 0
28
29 #include "log/log.h"
30 #include "Fwmark.h"
31 #undef ADD // already defined in nameser.h
32 #include "NetdConstants.h"
33 #include "Permission.h"
34
35
36 namespace android {
37 namespace net {
38
39 namespace {
40
// Sets or clears the O_NONBLOCK status flag on |fd|, leaving the other status
// flags as they were.
// Returns true when both the flag read and the flag write succeed.
bool setNonBlocking(int fd, bool enabled) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0) {
        return false;
    }

    const int updated = enabled ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    return fcntl(fd, F_SETFL, updated) == 0;
}
52
waitForReading(int fd)53 int waitForReading(int fd) {
54 fd_set fds;
55 FD_ZERO(&fds);
56 FD_SET(fd, &fds);
57 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
58 if (DBG && ret <= 0) {
59 ALOGD("select");
60 }
61 return ret;
62 }
63
waitForWriting(int fd)64 int waitForWriting(int fd) {
65 fd_set fds;
66 FD_ZERO(&fds);
67 FD_SET(fd, &fds);
68 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
69 if (DBG && ret <= 0) {
70 ALOGD("select");
71 }
72 return ret;
73 }
74
75 } // namespace
76
makeConnectedSocket() const77 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
78 android::base::unique_fd fd;
79 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
80 switch (mProtocol) {
81 case IPPROTO_TCP:
82 type |= SOCK_STREAM;
83 break;
84 default:
85 errno = EPROTONOSUPPORT;
86 return fd;
87 }
88
89 fd.reset(socket(mAddr.ss_family, type, mProtocol));
90 if (fd.get() == -1) {
91 return fd;
92 }
93
94 const socklen_t len = sizeof(mMark);
95 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
96 fd.reset();
97 } else if (connect(fd.get(),
98 reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
99 && errno != EINPROGRESS) {
100 fd.reset();
101 }
102
103 return fd;
104 }
105
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)106 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
107 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
108 unsigned char spki[spki_len];
109 unsigned char* temp = spki;
110 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
111 ALOGW("SPKI length mismatch");
112 return false;
113 }
114 out->resize(SHA256_SIZE);
115 unsigned int digest_len = 0;
116 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
117 if (ret != 1) {
118 ALOGW("Server cert digest extraction failed");
119 return false;
120 }
121 if (digest_len != out->size()) {
122 ALOGW("Wrong digest length: %d", digest_len);
123 return false;
124 }
125 return true;
126 }
127
sslConnect(int fd)128 SSL* DnsTlsTransport::sslConnect(int fd) {
129 if (fd < 0) {
130 ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
131 return nullptr;
132 }
133
134 // Set up TLS context.
135 bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
136 if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
137 !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
138 ALOGD("failed to min/max TLS versions");
139 return nullptr;
140 }
141
142 bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
143 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
144 SSL_set_bio(ssl.get(), bio.get(), bio.get());
145 bio.release();
146
147 if (!setNonBlocking(fd, false)) {
148 ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
149 return nullptr;
150 }
151
152 for (;;) {
153 if (DBG) {
154 ALOGD("%u Calling SSL_connect", mMark);
155 }
156 int ret = SSL_connect(ssl.get());
157 if (DBG) {
158 ALOGD("%u SSL_connect returned %d", mMark, ret);
159 }
160 if (ret == 1) break; // SSL handshake complete;
161
162 const int ssl_err = SSL_get_error(ssl.get(), ret);
163 switch (ssl_err) {
164 case SSL_ERROR_WANT_READ:
165 if (waitForReading(fd) != 1) {
166 ALOGW("SSL_connect read error");
167 return nullptr;
168 }
169 break;
170 case SSL_ERROR_WANT_WRITE:
171 if (waitForWriting(fd) != 1) {
172 ALOGW("SSL_connect write error");
173 return nullptr;
174 }
175 break;
176 default:
177 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
178 return nullptr;
179 }
180 }
181
182 if (!mFingerprints.empty()) {
183 if (DBG) {
184 ALOGD("Checking DNS over TLS fingerprint");
185 }
186 // TODO: Follow the cert chain and check all the way up.
187 bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
188 if (!cert) {
189 ALOGW("Server has null certificate");
190 return nullptr;
191 }
192 std::vector<uint8_t> digest;
193 if (!getSPKIDigest(cert.get(), &digest)) {
194 ALOGE("Digest computation failed");
195 return nullptr;
196 }
197
198 if (mFingerprints.count(digest) == 0) {
199 ALOGW("No matching fingerprint");
200 return nullptr;
201 }
202 if (DBG) {
203 ALOGD("DNS over TLS fingerprint is correct");
204 }
205 }
206
207 if (DBG) {
208 ALOGD("%u handshake complete", mMark);
209 }
210 return ssl.release();
211 }
212
sslWrite(int fd,SSL * ssl,const uint8_t * buffer,int len)213 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
214 if (DBG) {
215 ALOGD("%u Writing %d bytes", mMark, len);
216 }
217 for (;;) {
218 int ret = SSL_write(ssl, buffer, len);
219 if (ret == len) break; // SSL write complete;
220
221 if (ret < 1) {
222 const int ssl_err = SSL_get_error(ssl, ret);
223 switch (ssl_err) {
224 case SSL_ERROR_WANT_WRITE:
225 if (waitForWriting(fd) != 1) {
226 if (DBG) {
227 ALOGW("SSL_write error");
228 }
229 return false;
230 }
231 continue;
232 case 0:
233 break; // SSL write complete;
234 default:
235 if (DBG) {
236 ALOGW("SSL_write error %d", ssl_err);
237 }
238 return false;
239 }
240 }
241 }
242 if (DBG) {
243 ALOGD("%u Wrote %d bytes", mMark, len);
244 }
245 return true;
246 }
247
248 // Read exactly len bytes into buffer or fail
sslRead(int fd,SSL * ssl,uint8_t * buffer,int len)249 bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
250 int remaining = len;
251 while (remaining > 0) {
252 int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
253 if (ret == 0) {
254 ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
255 return false;
256 }
257
258 if (ret < 0) {
259 const int ssl_err = SSL_get_error(ssl, ret);
260 if (ssl_err == SSL_ERROR_WANT_READ) {
261 if (waitForReading(fd) != 1) {
262 if (DBG) {
263 ALOGW("SSL_read error");
264 }
265 return false;
266 }
267 continue;
268 } else {
269 if (DBG) {
270 ALOGW("SSL_read error %d", ssl_err);
271 }
272 return false;
273 }
274 }
275
276 remaining -= ret;
277 }
278 return true;
279 }
280
doQuery(const uint8_t * query,size_t qlen,uint8_t * response,size_t limit,int * resplen)281 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
282 uint8_t *response, size_t limit, int *resplen) {
283 *resplen = 0; // Zero indicates an error.
284
285 if (DBG) {
286 ALOGD("%u connecting TCP socket", mMark);
287 }
288 android::base::unique_fd fd(makeConnectedSocket());
289 if (DBG) {
290 ALOGD("%u connecting SSL", mMark);
291 }
292 bssl::UniquePtr<SSL> ssl(sslConnect(fd));
293 if (ssl == nullptr) {
294 if (DBG) {
295 ALOGW("%u SSL connection failed", mMark);
296 }
297 return Response::network_error;
298 }
299
300 uint8_t queryHeader[2];
301 queryHeader[0] = qlen >> 8;
302 queryHeader[1] = qlen;
303 if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
304 return Response::network_error;
305 }
306 if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
307 return Response::network_error;
308 }
309 if (DBG) {
310 ALOGD("%u SSL_write complete", mMark);
311 }
312
313 uint8_t responseHeader[2];
314 if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
315 if (DBG) {
316 ALOGW("%u Failed to read 2-byte length header", mMark);
317 }
318 return Response::network_error;
319 }
320 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
321 if (DBG) {
322 ALOGD("%u Expecting response of size %i", mMark, responseSize);
323 }
324 if (responseSize > limit) {
325 ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
326 return Response::limit_error;
327 }
328 if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
329 if (DBG) {
330 ALOGW("%u Failed to read %i bytes", mMark, responseSize);
331 }
332 return Response::network_error;
333 }
334 if (DBG) {
335 ALOGD("%u SSL_read complete", mMark);
336 }
337
338 if (response[0] != query[0] || response[1] != query[1]) {
339 ALOGE("reply query ID != query ID");
340 return Response::internal_error;
341 }
342
343 SSL_shutdown(ssl.get());
344
345 *resplen = responseSize;
346 return Response::success;
347 }
348
validateDnsTlsServer(unsigned netid,const struct sockaddr_storage & ss,const std::set<std::vector<uint8_t>> & fingerprints)349 bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
350 const std::set<std::vector<uint8_t>>& fingerprints) {
351 if (DBG) {
352 ALOGD("Beginning validation on %u", netid);
353 }
354 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
355 // order to prove that it is actually a working DNS over TLS server.
356 static const char kDnsSafeChars[] =
357 "abcdefhijklmnopqrstuvwxyz"
358 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
359 "0123456789";
360 const auto c = [](uint8_t rnd) -> uint8_t {
361 return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
362 };
363 uint8_t rnd[8];
364 arc4random_buf(rnd, ARRAY_SIZE(rnd));
365 // We could try to use res_mkquery() here, but it's basically the same.
366 uint8_t query[] = {
367 rnd[6], rnd[7], // [0-1] query ID
368 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
369 0, 1, // [4-5] QDCOUNT (number of queries)
370 0, 0, // [6-7] ANCOUNT (number of answers)
371 0, 0, // [8-9] NSCOUNT (number of name server records)
372 0, 0, // [10-11] ARCOUNT (number of additional records)
373 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
374 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
375 6, 'm', 'e', 't', 'r', 'i', 'c',
376 7, 'g', 's', 't', 'a', 't', 'i', 'c',
377 3, 'c', 'o', 'm',
378 0, // null terminator of FQDN (root TLD)
379 0, ns_t_aaaa, // QTYPE
380 0, ns_c_in // QCLASS
381 };
382 const int qlen = ARRAY_SIZE(query);
383
384 const int kRecvBufSize = 4 * 1024;
385 uint8_t recvbuf[kRecvBufSize];
386
387 // At validation time, we only know the netId, so we have to guess/compute the
388 // corresponding socket mark.
389 Fwmark fwmark;
390 fwmark.permission = PERMISSION_SYSTEM;
391 fwmark.explicitlySelected = true;
392 fwmark.protectedFromVpn = true;
393 fwmark.netId = netid;
394 unsigned mark = fwmark.intValue;
395 DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
396 int replylen = 0;
397 xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
398 if (replylen == 0) {
399 if (DBG) {
400 ALOGD("doQuery failed");
401 }
402 return false;
403 }
404
405 if (replylen < NS_HFIXEDSZ) {
406 if (DBG) {
407 ALOGW("short response: %d", replylen);
408 }
409 return false;
410 }
411
412 const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
413 if (qdcount != 1) {
414 ALOGW("reply query count != 1: %d", qdcount);
415 return false;
416 }
417
418 const int ancount = (recvbuf[6] << 8) | recvbuf[7];
419 if (DBG) {
420 ALOGD("%u answer count: %d", netid, ancount);
421 }
422
423 // TODO: Further validate the response contents (check for valid AAAA record, ...).
424 // Note that currently, integration tests rely on this function accepting a
425 // response with zero records.
426 #if 0
427 for (int i = 0; i < resplen; i++) {
428 ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
429 }
430 #endif
431 return true;
432 }
433
434 } // namespace net
435 } // namespace android
436