1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_tls_frontend.h"
18
19 #include <arpa/inet.h>
20 #include <netdb.h>
21 #include <openssl/err.h>
22 #include <openssl/evp.h>
23 #include <openssl/ssl.h>
24 #include <sys/eventfd.h>
25 #include <sys/poll.h>
26 #include <sys/socket.h>
27 #include <sys/types.h>
28 #include <unistd.h>
29
30 #define LOG_TAG "DnsTlsFrontend"
31 #include <log/log.h>
32 #include <netdutils/SocketOption.h>
33
34 #include "NetdConstants.h" // SHA256_SIZE
35
36 using android::netdutils::enableSockopt;
37
38 namespace {
39
40 // Copied from DnsTlsTransport.
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)41 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
42 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
43 unsigned char spki[spki_len];
44 unsigned char* temp = spki;
45 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
46 ALOGE("SPKI length mismatch");
47 return false;
48 }
49 out->resize(SHA256_SIZE);
50 unsigned int digest_len = 0;
51 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr);
52 if (ret != 1) {
53 ALOGE("Server cert digest extraction failed");
54 return false;
55 }
56 if (digest_len != out->size()) {
57 ALOGE("Wrong digest length: %d", digest_len);
58 return false;
59 }
60 return true;
61 }
62
errno2str()63 std::string errno2str() {
64 char error_msg[512] = { 0 };
65 return strerror_r(errno, error_msg, sizeof(error_msg));
66 }
67
68 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
69
addr2str(const sockaddr * sa,socklen_t sa_len)70 std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
71 char host_str[NI_MAXHOST] = { 0 };
72 int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
73 NI_NUMERICHOST);
74 if (rv == 0) return std::string(host_str);
75 return std::string();
76 }
77
make_private_key()78 bssl::UniquePtr<EVP_PKEY> make_private_key() {
79 bssl::UniquePtr<BIGNUM> e(BN_new());
80 if (!e) {
81 ALOGE("BN_new failed");
82 return nullptr;
83 }
84 if (!BN_set_word(e.get(), RSA_F4)) {
85 ALOGE("BN_set_word failed");
86 return nullptr;
87 }
88
89 bssl::UniquePtr<RSA> rsa(RSA_new());
90 if (!rsa) {
91 ALOGE("RSA_new failed");
92 return nullptr;
93 }
94 if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr)) {
95 ALOGE("RSA_generate_key_ex failed");
96 return nullptr;
97 }
98
99 bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
100 if (!privkey) {
101 ALOGE("EVP_PKEY_new failed");
102 return nullptr;
103 }
104 if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
105 ALOGE("EVP_PKEY_assign_RSA failed");
106 return nullptr;
107 }
108
109 // |rsa| is now owned by |privkey|, so no need to free it.
110 rsa.release();
111 return privkey;
112 }
113
make_cert(EVP_PKEY * privkey,EVP_PKEY * parent_key)114 bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
115 bssl::UniquePtr<X509> cert(X509_new());
116 if (!cert) {
117 ALOGE("X509_new failed");
118 return nullptr;
119 }
120
121 ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
122
123 // Set one hour expiration.
124 X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
125 X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
126
127 X509_set_pubkey(cert.get(), privkey);
128
129 if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
130 ALOGE("X509_sign failed");
131 return nullptr;
132 }
133
134 return cert;
135 }
136
137 }
138
namespace test {

// Starts the TLS frontend.
//
// Builds a new SSL_CTX with a freshly generated chain of |chain_length_|
// certificates, binds and listens on a TCP socket at
// |listen_address_|:|listen_service_|, creates a UDP socket connected to
// |backend_address_|:|backend_service_|, creates the eventfd that stopServer()
// uses to wake the handler, and finally launches requestHandler() on
// |handler_thread_|.
//
// Returns true once the handler thread has been started. On any failure an
// error is logged and false is returned.
// NOTE(review): members set up before a failure (ctx_, socket_, ...) are left
// in place here; presumably stopServer() or the destructor reclaims them — confirm.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    // Reset queries_ to 0 on every start so that waitForQueries() only counts
    // queries served by this run of the server.
    queries_ = 0;

    // The leaf certificate certs[0] is always needed below.
    if (chain_length_ < 1) {
        ALOGE("Invalid certificate chain length: %d", chain_length_);
        return false;
    }

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    // Make certificate chain: certs[i] carries keys[i] and is signed by
    // keys[i + 1]; the last certificate is self-signed.
    std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        keys[i] = make_private_key();
        if (!keys[i]) {
            ALOGE("Failed to generate private key %d", i);
            return false;
        }
    }
    std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        const int next = std::min(i + 1, chain_length_ - 1);
        certs[i] = make_cert(keys[i].get(), keys[next].get());
        if (!certs[i]) {
            ALOGE("Failed to generate certificate %d", i);
            return false;
        }
    }

    // Install certificate chain.
    if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }
    for (int i = 1; i < chain_length_; ++i) {
        if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
            ALOGE("SSL_CTX_add1_chain_cert failed");
            return false;
        }
    }

    // Report the fingerprint of the "middle" cert. For N = 2, this is the root.
    const int fp_index = chain_length_ / 2;
    if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    // Set up TCP server socket for clients. Fields are assigned individually
    // (rest value-initialized) to avoid out-of-order designated initializers.
    addrinfo frontend_ai_hints{};
    frontend_ai_hints.ai_flags = AI_PASSIVE;
    frontend_ai_hints.ai_family = AF_UNSPEC;
    frontend_ai_hints.ai_socktype = SOCK_STREAM;
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &frontend_ai_hints, &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }

    // Use the first resolved address that can be socket()ed and bind()ed.
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            APLOGI("ignore creating socket failed %d", s.get());
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind TCP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        socket_ = std::move(s);
        break;
    }
    if (socket_.get() < 0) {
        ALOGE("no usable TCP address for %s:%s", listen_address_.c_str(),
              listen_service_.c_str());
        return false;
    }

    if (listen(socket_.get(), 1) < 0) {
        APLOGI("failed to listen socket %d", socket_.get());
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{};
    backend_ai_hints.ai_family = AF_UNSPEC;
    backend_ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
                     &backend_ai_hints, &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Report the backend endpoint that actually failed to resolve.
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", backend_address_.c_str(),
              backend_service_.c_str(), gai_strerror(rv));
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        APLOGI("backend socket %d creation failed", backend_socket_.get());
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}

// Body of |handler_thread_|. Waits on the eventfd and the listening socket.
// Each accepted client gets a TLS handshake followed by one call to
// handleOneRequest(); |queries_| is incremented for each fully served request.
// The loop ends when the eventfd becomes readable (written by stopServer()),
// or when poll() or accept4() fails.
void DnsTlsFrontend::requestHandler() {
    ALOGD("Request handler started");
    enum { EVENT_FD = 0, LISTEN_FD = 1 };
    pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
                     {.fd = socket_.get(), .events = POLLIN}};

    while (true) {
        int poll_code = poll(fds, std::size(fds), -1);
        if (poll_code < 0 && errno == EINTR) {
            // A signal interrupted the wait before any event; just poll again
            // instead of silently shutting the server down.
            continue;
        }
        if (poll_code <= 0) {
            APLOGI("Poll failed with error %d", poll_code);
            break;
        }

        if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
            handleEventFd();
            break;
        }
        if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
            sockaddr_storage addr;
            socklen_t len = sizeof(addr);

            ALOGD("Trying to accept a client");
            android::base::unique_fd client(
                    accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
            if (client.get() < 0) {
                // Stop
                APLOGI("failed to accept client socket %d", client.get());
                break;
            }

            bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
            if (!ssl || SSL_set_fd(ssl.get(), client.get()) != 1) {
                // Drop this client (|client| is closed on scope exit) but keep serving.
                ALOGE("Failed to set up SSL for client");
                continue;
            }

            ALOGD("Doing SSL handshake");
            bool success = false;
            if (SSL_accept(ssl.get()) <= 0) {
                ALOGI("SSL negotiation failure");
            } else {
                ALOGD("SSL handshake complete");
                success = handleOneRequest(ssl.get());
            }

            if (success) {
                // Increment queries_ as late as possible, because it represents
                // a query that is fully processed, and the response returned to the
                // client, including cleanup actions.
                ++queries_;
            }
        }
    }
    ALOGD("Ending loop");
}

// Serves a single DNS query on the established TLS session |ssl|.
//
// Reads a 2-byte big-endian length prefix and that many query bytes from |ssl|,
// sends the query to the backend over |backend_socket_|, receives one response
// (at most 4096 bytes) and writes it back to |ssl| with the same 2-byte length
// framing.
//
// Returns true if the whole response was written back. Returns false (after
// logging) on an empty query or any read, send, receive or write failure.
bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
    // Reads exactly |len| bytes into |buf|. SSL_read() may return fewer bytes
    // than requested, so keep reading until done or an error/EOF occurs.
    const auto read_fully = [ssl](uint8_t* buf, size_t len) {
        size_t done = 0;
        while (done < len) {
            const int ret = SSL_read(ssl, buf + done, static_cast<int>(len - done));
            if (ret <= 0) return false;
            done += static_cast<size_t>(ret);
        }
        return true;
    };

    uint8_t queryHeader[2];
    if (!read_fully(queryHeader, sizeof(queryHeader))) {
        ALOGI("Not enough header bytes");
        return false;
    }
    const uint16_t qlen = static_cast<uint16_t>((queryHeader[0] << 8) | queryHeader[1]);
    if (qlen == 0) {
        ALOGI("Empty query");
        return false;
    }
    // Heap buffer instead of a variable-length array (not standard C++).
    std::vector<uint8_t> query(qlen);
    if (!read_fully(query.data(), query.size())) {
        ALOGI("Error while reading query");
        return false;
    }
    const ssize_t sent = send(backend_socket_.get(), query.data(), query.size(), 0);
    if (sent != static_cast<ssize_t>(query.size())) {
        ALOGI("Failed to send query");
        return false;
    }
    constexpr size_t max_size = 4096;
    uint8_t recv_buffer[max_size];
    const ssize_t rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
    if (rlen <= 0) {
        ALOGI("Failed to receive response");
        return false;
    }
    // rlen <= max_size, so it always fits in the 16-bit length prefix.
    const uint8_t responseHeader[2] = {static_cast<uint8_t>(rlen >> 8),
                                       static_cast<uint8_t>(rlen)};
    if (SSL_write(ssl, responseHeader, 2) != 2) {
        ALOGI("Failed to write response header");
        return false;
    }
    if (SSL_write(ssl, recv_buffer, static_cast<int>(rlen)) != rlen) {
        ALOGI("Failed to write response body");
        return false;
    }
    return true;
}

// Stops a running frontend. It wakes the handler thread through the eventfd,
// joins it, then releases the sockets, the eventfd, the SSL context and the
// reported fingerprint. |update_mutex_| is held throughout, the same mutex
// startServer() holds while creating the thread.
//
// Returns false if the server is not running or the eventfd write fails.
bool DnsTlsFrontend::stopServer() {
    std::lock_guard lock(update_mutex_);
    if (!running()) {
        ALOGI("server not running");
        return false;
    }

    ALOGI("stopping frontend");
    if (!sendToEventFd()) {
        return false;
    }
    handler_thread_.join();
    socket_.reset();
    backend_socket_.reset();
    event_fd_.reset();
    ctx_.reset();
    fingerprint_.clear();
    ALOGI("frontend stopped successfully");
    return true;
}

// Polls |queries_| every 20 ms until it reaches |number| or about |timeoutMs|
// has elapsed. Returns true if the target was reached, false on timeout.
bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
    constexpr int intervalMs = 20;
    const int limit = timeoutMs / intervalMs;
    for (int count = 0; count <= limit; ++count) {
        const bool done = queries_ >= number;
        // Always sleep at least one more interval after we are done, to wait for
        // any immediate post-query actions that the client may take (such as
        // marking this server as reachable during validation).
        usleep(intervalMs * 1000);
        if (done) {
            // For ensuring that calls have sufficient headroom for slow machines
            ALOGD("Query arrived in %d/%d of allotted time", count, limit);
            return true;
        }
    }
    return false;
}

// Adds 1 to the eventfd counter so that requestHandler()'s poll() wakes up.
// Returns false (after logging) if the 8-byte write does not complete.
bool DnsTlsFrontend::sendToEventFd() {
    const uint64_t data = 1;
    if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("failed to write eventfd, rt=%zd", rt);
        return false;
    }
    return true;
}

// Drains the eventfd counter. A failed read is only logged.
void DnsTlsFrontend::handleEventFd() {
    // eventfd counters are unsigned 64-bit values.
    uint64_t data;
    if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("ignore reading eventfd failed, rt=%zd", rt);
    }
}

}  // namespace test
422