• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns/DnsTlsTransport.h"
18 
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <stdlib.h>
25 
26 #define LOG_TAG "DnsTlsTransport"
27 #define DBG 0
28 
29 #include "log/log.h"
30 #include "Fwmark.h"
31 #undef ADD  // already defined in nameser.h
32 #include "NetdConstants.h"
33 #include "Permission.h"
34 
35 
36 namespace android {
37 namespace net {
38 
39 namespace {
40 
// Sets or clears O_NONBLOCK on |fd|, leaving every other file status flag unchanged.
// Returns true if the new flags were applied, false if either fcntl() call failed.
bool setNonBlocking(int fd, bool enabled) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0) {
        return false;
    }
    const int updated = enabled ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    return fcntl(fd, F_SETFL, updated) == 0;
}
52 
waitForReading(int fd)53 int waitForReading(int fd) {
54     fd_set fds;
55     FD_ZERO(&fds);
56     FD_SET(fd, &fds);
57     const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
58     if (DBG && ret <= 0) {
59         ALOGD("select");
60     }
61     return ret;
62 }
63 
waitForWriting(int fd)64 int waitForWriting(int fd) {
65     fd_set fds;
66     FD_ZERO(&fds);
67     FD_SET(fd, &fds);
68     const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
69     if (DBG && ret <= 0) {
70         ALOGD("select");
71     }
72     return ret;
73 }
74 
75 }  // namespace
76 
makeConnectedSocket() const77 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
78     android::base::unique_fd fd;
79     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
80     switch (mProtocol) {
81         case IPPROTO_TCP:
82             type |= SOCK_STREAM;
83             break;
84         default:
85             errno = EPROTONOSUPPORT;
86             return fd;
87     }
88 
89     fd.reset(socket(mAddr.ss_family, type, mProtocol));
90     if (fd.get() == -1) {
91         return fd;
92     }
93 
94     const socklen_t len = sizeof(mMark);
95     if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
96         fd.reset();
97     } else if (connect(fd.get(),
98             reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
99         && errno != EINPROGRESS) {
100         fd.reset();
101     }
102 
103     return fd;
104 }
105 
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)106 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
107     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
108     unsigned char spki[spki_len];
109     unsigned char* temp = spki;
110     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
111         ALOGW("SPKI length mismatch");
112         return false;
113     }
114     out->resize(SHA256_SIZE);
115     unsigned int digest_len = 0;
116     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
117     if (ret != 1) {
118         ALOGW("Server cert digest extraction failed");
119         return false;
120     }
121     if (digest_len != out->size()) {
122         ALOGW("Wrong digest length: %d", digest_len);
123         return false;
124     }
125     return true;
126 }
127 
sslConnect(int fd)128 SSL* DnsTlsTransport::sslConnect(int fd) {
129     if (fd < 0) {
130         ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
131         return nullptr;
132     }
133 
134     // Set up TLS context.
135     bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
136     if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
137         !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
138         ALOGD("failed to min/max TLS versions");
139         return nullptr;
140     }
141 
142     bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
143     bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
144     SSL_set_bio(ssl.get(), bio.get(), bio.get());
145     bio.release();
146 
147     if (!setNonBlocking(fd, false)) {
148         ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
149         return nullptr;
150     }
151 
152     for (;;) {
153         if (DBG) {
154             ALOGD("%u Calling SSL_connect", mMark);
155         }
156         int ret = SSL_connect(ssl.get());
157         if (DBG) {
158             ALOGD("%u SSL_connect returned %d", mMark, ret);
159         }
160         if (ret == 1) break;  // SSL handshake complete;
161 
162         const int ssl_err = SSL_get_error(ssl.get(), ret);
163         switch (ssl_err) {
164             case SSL_ERROR_WANT_READ:
165                 if (waitForReading(fd) != 1) {
166                     ALOGW("SSL_connect read error");
167                     return nullptr;
168                 }
169                 break;
170             case SSL_ERROR_WANT_WRITE:
171                 if (waitForWriting(fd) != 1) {
172                     ALOGW("SSL_connect write error");
173                     return nullptr;
174                 }
175                 break;
176             default:
177                 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
178                 return nullptr;
179         }
180     }
181 
182     if (!mFingerprints.empty()) {
183         if (DBG) {
184             ALOGD("Checking DNS over TLS fingerprint");
185         }
186         // TODO: Follow the cert chain and check all the way up.
187         bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
188         if (!cert) {
189             ALOGW("Server has null certificate");
190             return nullptr;
191         }
192         std::vector<uint8_t> digest;
193         if (!getSPKIDigest(cert.get(), &digest)) {
194             ALOGE("Digest computation failed");
195             return nullptr;
196         }
197 
198         if (mFingerprints.count(digest) == 0) {
199             ALOGW("No matching fingerprint");
200             return nullptr;
201         }
202         if (DBG) {
203             ALOGD("DNS over TLS fingerprint is correct");
204         }
205     }
206 
207     if (DBG) {
208         ALOGD("%u handshake complete", mMark);
209     }
210     return ssl.release();
211 }
212 
sslWrite(int fd,SSL * ssl,const uint8_t * buffer,int len)213 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
214     if (DBG) {
215         ALOGD("%u Writing %d bytes", mMark, len);
216     }
217     for (;;) {
218         int ret = SSL_write(ssl, buffer, len);
219         if (ret == len) break;  // SSL write complete;
220 
221         if (ret < 1) {
222             const int ssl_err = SSL_get_error(ssl, ret);
223             switch (ssl_err) {
224                 case SSL_ERROR_WANT_WRITE:
225                     if (waitForWriting(fd) != 1) {
226                         if (DBG) {
227                             ALOGW("SSL_write error");
228                         }
229                         return false;
230                     }
231                     continue;
232                 case 0:
233                     break;  // SSL write complete;
234                 default:
235                     if (DBG) {
236                         ALOGW("SSL_write error %d", ssl_err);
237                     }
238                     return false;
239             }
240         }
241     }
242     if (DBG) {
243         ALOGD("%u Wrote %d bytes", mMark, len);
244     }
245     return true;
246 }
247 
248 // Read exactly len bytes into buffer or fail
sslRead(int fd,SSL * ssl,uint8_t * buffer,int len)249 bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
250     int remaining = len;
251     while (remaining > 0) {
252         int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
253         if (ret == 0) {
254             ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
255             return false;
256         }
257 
258         if (ret < 0) {
259             const int ssl_err = SSL_get_error(ssl, ret);
260             if (ssl_err == SSL_ERROR_WANT_READ) {
261                 if (waitForReading(fd) != 1) {
262                     if (DBG) {
263                         ALOGW("SSL_read error");
264                     }
265                     return false;
266                 }
267                 continue;
268             } else {
269                 if (DBG) {
270                     ALOGW("SSL_read error %d", ssl_err);
271                 }
272                 return false;
273             }
274         }
275 
276         remaining -= ret;
277     }
278     return true;
279 }
280 
doQuery(const uint8_t * query,size_t qlen,uint8_t * response,size_t limit,int * resplen)281 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
282         uint8_t *response, size_t limit, int *resplen) {
283     *resplen = 0;  // Zero indicates an error.
284 
285     if (DBG) {
286         ALOGD("%u connecting TCP socket", mMark);
287     }
288     android::base::unique_fd fd(makeConnectedSocket());
289     if (DBG) {
290         ALOGD("%u connecting SSL", mMark);
291     }
292     bssl::UniquePtr<SSL> ssl(sslConnect(fd));
293     if (ssl == nullptr) {
294         if (DBG) {
295             ALOGW("%u SSL connection failed", mMark);
296         }
297         return Response::network_error;
298     }
299 
300     uint8_t queryHeader[2];
301     queryHeader[0] = qlen >> 8;
302     queryHeader[1] = qlen;
303     if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
304         return Response::network_error;
305     }
306     if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
307         return Response::network_error;
308     }
309     if (DBG) {
310         ALOGD("%u SSL_write complete", mMark);
311     }
312 
313     uint8_t responseHeader[2];
314     if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
315         if (DBG) {
316             ALOGW("%u Failed to read 2-byte length header", mMark);
317         }
318         return Response::network_error;
319     }
320     const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
321     if (DBG) {
322         ALOGD("%u Expecting response of size %i", mMark, responseSize);
323     }
324     if (responseSize > limit) {
325         ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
326         return Response::limit_error;
327     }
328     if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
329         if (DBG) {
330             ALOGW("%u Failed to read %i bytes", mMark, responseSize);
331         }
332         return Response::network_error;
333     }
334     if (DBG) {
335         ALOGD("%u SSL_read complete", mMark);
336     }
337 
338     if (response[0] != query[0] || response[1] != query[1]) {
339         ALOGE("reply query ID != query ID");
340         return Response::internal_error;
341     }
342 
343     SSL_shutdown(ssl.get());
344 
345     *resplen = responseSize;
346     return Response::success;
347 }
348 
validateDnsTlsServer(unsigned netid,const struct sockaddr_storage & ss,const std::set<std::vector<uint8_t>> & fingerprints)349 bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
350         const std::set<std::vector<uint8_t>>& fingerprints) {
351     if (DBG) {
352         ALOGD("Beginning validation on %u", netid);
353     }
354     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
355     // order to prove that it is actually a working DNS over TLS server.
356     static const char kDnsSafeChars[] =
357             "abcdefhijklmnopqrstuvwxyz"
358             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
359             "0123456789";
360     const auto c = [](uint8_t rnd) -> uint8_t {
361         return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
362     };
363     uint8_t rnd[8];
364     arc4random_buf(rnd, ARRAY_SIZE(rnd));
365     // We could try to use res_mkquery() here, but it's basically the same.
366     uint8_t query[] = {
367         rnd[6], rnd[7],  // [0-1]   query ID
368         1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
369         0, 1,  // [4-5]   QDCOUNT (number of queries)
370         0, 0,  // [6-7]   ANCOUNT (number of answers)
371         0, 0,  // [8-9]   NSCOUNT (number of name server records)
372         0, 0,  // [10-11] ARCOUNT (number of additional records)
373         17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
374             '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
375         6, 'm', 'e', 't', 'r', 'i', 'c',
376         7, 'g', 's', 't', 'a', 't', 'i', 'c',
377         3, 'c', 'o', 'm',
378         0,  // null terminator of FQDN (root TLD)
379         0, ns_t_aaaa,  // QTYPE
380         0, ns_c_in     // QCLASS
381     };
382     const int qlen = ARRAY_SIZE(query);
383 
384     const int kRecvBufSize = 4 * 1024;
385     uint8_t recvbuf[kRecvBufSize];
386 
387     // At validation time, we only know the netId, so we have to guess/compute the
388     // corresponding socket mark.
389     Fwmark fwmark;
390     fwmark.permission = PERMISSION_SYSTEM;
391     fwmark.explicitlySelected = true;
392     fwmark.protectedFromVpn = true;
393     fwmark.netId = netid;
394     unsigned mark = fwmark.intValue;
395     DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
396     int replylen = 0;
397     xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
398     if (replylen == 0) {
399         if (DBG) {
400             ALOGD("doQuery failed");
401         }
402         return false;
403     }
404 
405     if (replylen < NS_HFIXEDSZ) {
406         if (DBG) {
407             ALOGW("short response: %d", replylen);
408         }
409         return false;
410     }
411 
412     const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
413     if (qdcount != 1) {
414         ALOGW("reply query count != 1: %d", qdcount);
415         return false;
416     }
417 
418     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
419     if (DBG) {
420         ALOGD("%u answer count: %d", netid, ancount);
421     }
422 
423     // TODO: Further validate the response contents (check for valid AAAA record, ...).
424     // Note that currently, integration tests rely on this function accepting a
425     // response with zero records.
426 #if 0
427     for (int i = 0; i < resplen; i++) {
428         ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
429     }
430 #endif
431     return true;
432 }
433 
434 }  // namespace net
435 }  // namespace android
436