• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "dns_tls_frontend.h"

#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <vector>

#define LOG_TAG "DnsTlsFrontend"
#include <log/log.h>
#include <netdutils/SocketOption.h>

#include "NetdConstants.h"  // SHA256_SIZE

using android::netdutils::enableSockopt;
37 
38 namespace {
39 
40 // Copied from DnsTlsTransport.
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)41 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
42     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
43     unsigned char spki[spki_len];
44     unsigned char* temp = spki;
45     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
46         ALOGE("SPKI length mismatch");
47         return false;
48     }
49     out->resize(SHA256_SIZE);
50     unsigned int digest_len = 0;
51     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr);
52     if (ret != 1) {
53         ALOGE("Server cert digest extraction failed");
54         return false;
55     }
56     if (digest_len != out->size()) {
57         ALOGE("Wrong digest length: %d", digest_len);
58         return false;
59     }
60     return true;
61 }
62 
// Returns the text description of the current errno value, as produced by the
// (GNU-flavoured) strerror_r, which returns a pointer to the message.
std::string errno2str() {
    char scratch[512] = {};
    const char* const message = strerror_r(errno, scratch, sizeof(scratch));
    return std::string(message);
}

// Logs at INFO level, appending ": [<errno>] <errno description>" to the message.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
69 
// Renders the host part of |sa| (|sa_len| bytes long) as a numeric address string,
// e.g. "192.0.2.1" or "::1". No name lookup is performed. Returns an empty string if
// getnameinfo() rejects the address.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    const int status =
            getnameinfo(sa, sa_len, host, sizeof(host), nullptr, 0, NI_NUMERICHOST);
    if (status != 0) {
        return {};
    }
    return host;
}
77 
make_private_key()78 bssl::UniquePtr<EVP_PKEY> make_private_key() {
79     bssl::UniquePtr<BIGNUM> e(BN_new());
80     if (!e) {
81         ALOGE("BN_new failed");
82         return nullptr;
83     }
84     if (!BN_set_word(e.get(), RSA_F4)) {
85         ALOGE("BN_set_word failed");
86         return nullptr;
87     }
88 
89     bssl::UniquePtr<RSA> rsa(RSA_new());
90     if (!rsa) {
91         ALOGE("RSA_new failed");
92         return nullptr;
93     }
94     if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr)) {
95         ALOGE("RSA_generate_key_ex failed");
96         return nullptr;
97     }
98 
99     bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
100     if (!privkey) {
101         ALOGE("EVP_PKEY_new failed");
102         return nullptr;
103     }
104     if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
105         ALOGE("EVP_PKEY_assign_RSA failed");
106         return nullptr;
107     }
108 
109     // |rsa| is now owned by |privkey|, so no need to free it.
110     rsa.release();
111     return privkey;
112 }
113 
make_cert(EVP_PKEY * privkey,EVP_PKEY * parent_key)114 bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
115     bssl::UniquePtr<X509> cert(X509_new());
116     if (!cert) {
117         ALOGE("X509_new failed");
118         return nullptr;
119     }
120 
121     ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
122 
123     // Set one hour expiration.
124     X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
125     X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
126 
127     X509_set_pubkey(cert.get(), privkey);
128 
129     if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
130         ALOGE("X509_sign failed");
131         return nullptr;
132     }
133 
134     return cert;
135 }
136 
137 }
138 
139 namespace test {
140 
// Starts the DNS-over-TLS test frontend. It:
//  1. builds a throwaway chain of |chain_length_| RSA certificates, where certs[i]
//     carries keys[i] and is signed by keys[i + 1], and the last one is signed by its
//     own key;
//  2. stores the SPKI fingerprint of the middle certificate in |fingerprint_|;
//  3. binds and listens on the first usable address for
//     |listen_address_|:|listen_service_|;
//  4. creates a UDP socket towards |backend_address_|:|backend_service_|;
//  5. creates the eventfd and launches requestHandler() on |handler_thread_|.
// Returns true once the handler thread has been started, false on any setup failure.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    // Reset queries_ to 0 every time startServer is called, so that waitForQueries()
    // only counts queries handled since this start.
    queries_ = 0;

    // certs[0] is installed unconditionally below, so at least one certificate is needed.
    if (chain_length_ < 1) {
        ALOGE("Invalid certificate chain length %d", chain_length_);
        return false;
    }

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    // Make certificate chain.
    std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        keys[i] = make_private_key();
        if (!keys[i]) {
            ALOGE("Failed to generate key %d of the certificate chain", i);
            return false;
        }
    }
    std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        const int next = std::min(i + 1, chain_length_ - 1);
        certs[i] = make_cert(keys[i].get(), keys[next].get());
        if (!certs[i]) {
            ALOGE("Failed to create certificate %d of the chain", i);
            return false;
        }
    }

    // Install certificate chain: leaf certificate and its key, then the rest.
    if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }
    for (int i = 1; i < chain_length_; ++i) {
        if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
            ALOGE("SSL_CTX_add1_chain_cert failed");
            return false;
        }
    }

    // Report the fingerprint of the "middle" cert.  For N = 2, this is the root.
    const int fp_index = chain_length_ / 2;
    if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    // Set up TCP server socket for clients. Fields are assigned individually because
    // designated initializers must follow declaration order (ai_flags comes first).
    addrinfo frontend_ai_hints{};
    frontend_ai_hints.ai_flags = AI_PASSIVE;
    frontend_ai_hints.ai_family = AF_UNSPEC;
    frontend_ai_hints.ai_socktype = SOCK_STREAM;
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &frontend_ai_hints, &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }

    // Use the first resolved address for which socket() and bind() both succeed.
    bool bound = false;
    for (const addrinfo* ai = frontend_ai_res; ai != nullptr; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            APLOGI("ignore creating socket failed %d", s.get());
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind TCP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        socket_ = std::move(s);
        bound = true;
        break;
    }
    if (!bound) {
        ALOGE("failed to bind any address for TCP %s:%s", listen_address_.c_str(),
              listen_service_.c_str());
        return false;
    }

    if (listen(socket_.get(), 1) < 0) {
        APLOGI("failed to listen socket %d", socket_.get());
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{};
    backend_ai_hints.ai_family = AF_UNSPEC;
    backend_ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
                     &backend_ai_hints, &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", backend_address_.c_str(),
              backend_service_.c_str(), gai_strerror(rv));
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        APLOGI("backend socket %d creation failed", backend_socket_.get());
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
269 
requestHandler()270 void DnsTlsFrontend::requestHandler() {
271     ALOGD("Request handler started");
272     enum { EVENT_FD = 0, LISTEN_FD = 1 };
273     pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
274                      {.fd = socket_.get(), .events = POLLIN}};
275 
276     while (true) {
277         int poll_code = poll(fds, std::size(fds), -1);
278         if (poll_code <= 0) {
279             APLOGI("Poll failed with error %d", poll_code);
280             break;
281         }
282 
283         if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
284             handleEventFd();
285             break;
286         }
287         if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
288             sockaddr_storage addr;
289             socklen_t len = sizeof(addr);
290 
291             ALOGD("Trying to accept a client");
292             android::base::unique_fd client(
293                     accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
294             if (client.get() < 0) {
295                 // Stop
296                 APLOGI("failed to accept client socket %d", client.get());
297                 break;
298             }
299 
300             bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
301             SSL_set_fd(ssl.get(), client.get());
302 
303             ALOGD("Doing SSL handshake");
304             bool success = false;
305             if (SSL_accept(ssl.get()) <= 0) {
306                 ALOGI("SSL negotiation failure");
307             } else {
308                 ALOGD("SSL handshake complete");
309                 success = handleOneRequest(ssl.get());
310             }
311 
312             if (success) {
313                 // Increment queries_ as late as possible, because it represents
314                 // a query that is fully processed, and the response returned to the
315                 // client, including cleanup actions.
316                 ++queries_;
317             }
318         }
319     }
320     ALOGD("Ending loop");
321 }
322 
handleOneRequest(SSL * ssl)323 bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
324     uint8_t queryHeader[2];
325     if (SSL_read(ssl, &queryHeader, 2) != 2) {
326         ALOGI("Not enough header bytes");
327         return false;
328     }
329     const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
330     uint8_t query[qlen];
331     size_t qbytes = 0;
332     while (qbytes < qlen) {
333         int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
334         if (ret <= 0) {
335             ALOGI("Error while reading query");
336             return false;
337         }
338         qbytes += ret;
339     }
340     int sent = send(backend_socket_.get(), query, qlen, 0);
341     if (sent != qlen) {
342         ALOGI("Failed to send query");
343         return false;
344     }
345     const int max_size = 4096;
346     uint8_t recv_buffer[max_size];
347     int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
348     if (rlen <= 0) {
349         ALOGI("Failed to receive response");
350         return false;
351     }
352     uint8_t responseHeader[2];
353     responseHeader[0] = rlen >> 8;
354     responseHeader[1] = rlen;
355     if (SSL_write(ssl, responseHeader, 2) != 2) {
356         ALOGI("Failed to write response header");
357         return false;
358     }
359     if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
360         ALOGI("Failed to write response body");
361         return false;
362     }
363     return true;
364 }
365 
stopServer()366 bool DnsTlsFrontend::stopServer() {
367     std::lock_guard lock(update_mutex_);
368     if (!running()) {
369         ALOGI("server not running");
370         return false;
371     }
372 
373     ALOGI("stopping frontend");
374     if (!sendToEventFd()) {
375         return false;
376     }
377     handler_thread_.join();
378     socket_.reset();
379     backend_socket_.reset();
380     event_fd_.reset();
381     ctx_.reset();
382     fingerprint_.clear();
383     ALOGI("frontend stopped successfully");
384     return true;
385 }
386 
waitForQueries(int number,int timeoutMs) const387 bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
388     constexpr int intervalMs = 20;
389     int limit = timeoutMs / intervalMs;
390     for (int count = 0; count <= limit; ++count) {
391         bool done = queries_ >= number;
392         // Always sleep at least one more interval after we are done, to wait for
393         // any immediate post-query actions that the client may take (such as
394         // marking this server as reachable during validation).
395         usleep(intervalMs * 1000);
396         if (done) {
397             // For ensuring that calls have sufficient headroom for slow machines
398             ALOGD("Query arrived in %d/%d of allotted time", count, limit);
399             return true;
400         }
401     }
402     return false;
403 }
404 
sendToEventFd()405 bool DnsTlsFrontend::sendToEventFd() {
406     const uint64_t data = 1;
407     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
408         APLOGI("failed to write eventfd, rt=%zd", rt);
409         return false;
410     }
411     return true;
412 }
413 
handleEventFd()414 void DnsTlsFrontend::handleEventFd() {
415     int64_t data;
416     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
417         APLOGI("ignore reading eventfd failed, rt=%zd", rt);
418     }
419 }
420 
421 }  // namespace test
422