1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_tls_frontend.h"
18
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
30
31 #define LOG_TAG "DnsTlsFrontend"
32 #include <android-base/logging.h>
33 #include <netdutils/InternetAddresses.h>
34 #include <netdutils/SocketOption.h>
35 #include "dns_responder.h"
36 #include "dns_tls_certificate.h"
37
38 using android::netdutils::enableSockopt;
39 using android::netdutils::ScopedAddrinfo;
40
41 namespace {
stringToX509Certs(const char * certs)42 static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) {
43 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs)));
44 return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
45 }
46
47 // Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct.
stringToRSAPrivateKey(const char * key)48 static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) {
49 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key)));
50 return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
51 }
52
// Format the socket address |sa| (|sa_len| bytes long) as a numeric host string,
// e.g. "127.0.0.1" or "::1"; NI_NUMERICHOST means no reverse DNS lookup is done.
// Returns an empty string when getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host),
                                   nullptr, 0, NI_NUMERICHOST);
    if (status != 0) {
        return {};
    }
    return numeric_host;
}
59
60 } // namespace
61
62 namespace test {
63
// Start the DoT test frontend:
//   1. build a TLS server context from the embedded test certificate and RSA key,
//   2. bind and listen on listen_address_:listen_service_ (TCP),
//   3. open a UDP socket towards backend_address_:backend_service_,
//   4. create the eventfd later used by stopServer(),
//   5. launch requestHandler() on handler_thread_.
// Returns true on success; on any failure the reason is logged and false is returned.
bool DnsTlsFrontend::startServer() {
    OpenSSL_add_ssl_algorithms();

    // reset queries_ to 0 every time startServer called
    // which would help us easy to check queries_ via calling waitForQueries
    queries_ = 0;

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        LOG(ERROR) << "SSL context creation failed";
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    bssl::UniquePtr<X509> ca_certs(stringToX509Certs(kCertificate));
    if (!ca_certs) {
        LOG(ERROR) << "StringToX509Certs failed";
        return false;
    }

    if (SSL_CTX_use_certificate(ctx_.get(), ca_certs.get()) <= 0) {
        LOG(ERROR) << "SSL_CTX_use_certificate failed";
        return false;
    }

    bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey));
    if (SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) {
        LOG(ERROR) << "Error loading client RSA Private Key data.";
        return false;
    }

    // Set up TCP server socket for clients.
    addrinfo frontend_ai_hints{
            .ai_flags = AI_PASSIVE,
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_STREAM,
    };
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints,
                         &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        LOG(ERROR) << "frontend getaddrinfo(" << listen_address_.c_str() << ", "
                   << listen_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }

    // Use the first resolved address that we can create a socket for and bind to.
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            PLOG(INFO) << "ignore creating socket failed " << s.get();
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            PLOG(INFO) << "failed to bind TCP " << host_str.c_str() << ":"
                       << listen_service_.c_str();
            continue;
        }
        LOG(INFO) << "bound to TCP " << host_str.c_str() << ":" << listen_service_.c_str();
        socket_ = std::move(s);
        break;
    }

    // Report "nothing could be bound" explicitly instead of letting listen(-1) fail with EBADF.
    if (socket_.get() < 0) {
        LOG(ERROR) << "failed to bind a TCP socket to any address of " << listen_address_.c_str()
                   << ":" << listen_service_.c_str();
        return false;
    }

    if (listen(socket_.get(), 1) < 0) {
        PLOG(INFO) << "failed to listen socket " << socket_.get();
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), &backend_ai_hints,
                     &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Log the address that was actually being resolved (the backend, not the listener).
        LOG(ERROR) << "backend getaddrinfo(" << backend_address_.c_str() << ", "
                   << backend_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        PLOG(INFO) << "backend socket " << backend_socket_.get() << " creation failed";
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    static_cast<void>(
            connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen));

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        PLOG(INFO) << "failed to create eventfd " << event_fd_.get();
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    LOG(INFO) << "server started successfully";
    return true;
}
172
requestHandler()173 void DnsTlsFrontend::requestHandler() {
174 LOG(DEBUG) << "Request handler started";
175 enum { EVENT_FD = 0, LISTEN_FD = 1 };
176 pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
177 {.fd = socket_.get(), .events = POLLIN}};
178 android::base::unique_fd clientFd;
179
180 while (true) {
181 int poll_code = poll(fds, std::size(fds), -1);
182 if (poll_code <= 0) {
183 PLOG(WARNING) << "Poll failed with error " << poll_code;
184 break;
185 }
186
187 if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
188 handleEventFd();
189 break;
190 }
191 if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
192 sockaddr_storage addr;
193 socklen_t len = sizeof(addr);
194
195 LOG(DEBUG) << "Trying to accept a client";
196 android::base::unique_fd client(
197 accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
198 if (client.get() < 0) {
199 // Stop
200 PLOG(INFO) << "failed to accept client socket " << client.get();
201 break;
202 }
203
204 accept_connection_count_++;
205 if (hangOnHandshake_) {
206 LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake";
207
208 // The previous fd already stored in clientFd will be closed automatically.
209 clientFd = std::move(client);
210 continue;
211 }
212
213 bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
214 SSL_set_fd(ssl.get(), client.get());
215
216 LOG(DEBUG) << "Doing SSL handshake";
217 if (SSL_accept(ssl.get()) <= 0) {
218 LOG(INFO) << "SSL negotiation failure";
219 } else {
220 LOG(DEBUG) << "SSL handshake complete";
221 // Increment queries_ as late as possible, because it represents
222 // a query that is fully processed, and the response returned to the
223 // client, including cleanup actions.
224 queries_ += handleRequests(ssl.get(), client.get());
225 }
226
227 if (passiveClose_) {
228 LOG(DEBUG) << "hold the current connection until next connection request";
229 clientFd = std::move(client);
230 }
231 }
232 }
233 LOG(DEBUG) << "Ending loop";
234 }
235
handleRequests(SSL * ssl,int clientFd)236 int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
237 int queryCounts = 0;
238 std::vector<uint8_t> reply;
239 bool isDotProbe = false;
240 pollfd fds = {.fd = clientFd, .events = POLLIN};
241 again:
242 do {
243 uint8_t queryHeader[2];
244 if (SSL_read(ssl, &queryHeader, 2) != 2) {
245 LOG(INFO) << "Not enough header bytes";
246 return queryCounts;
247 }
248 const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
249 uint8_t query[qlen];
250 size_t qbytes = 0;
251 while (qbytes < qlen) {
252 int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
253 if (ret <= 0) {
254 LOG(INFO) << "Error while reading query";
255 return queryCounts;
256 }
257 qbytes += ret;
258 }
259 int sent = send(backend_socket_.get(), query, qlen, 0);
260 if (sent != qlen) {
261 LOG(INFO) << "Failed to send query";
262 return queryCounts;
263 }
264
265 if (!isDotProbe) {
266 DNSHeader dnsHdr;
267 dnsHdr.read((char*)query, (char*)query + qlen);
268 for (const auto& question : dnsHdr.questions) {
269 if (question.qname.name.find("dnsotls-ds.metric.gstatic.com") !=
270 std::string::npos) {
271 isDotProbe = true;
272 break;
273 }
274 }
275 }
276
277 const int max_size = 4096;
278 uint8_t recv_buffer[max_size];
279 int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
280 if (rlen <= 0) {
281 LOG(INFO) << "Failed to receive response";
282 return queryCounts;
283 }
284 uint8_t responseHeader[2];
285 responseHeader[0] = rlen >> 8;
286 responseHeader[1] = rlen;
287 reply.insert(reply.end(), responseHeader, responseHeader + 2);
288 reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);
289
290 ++queryCounts;
291 if (queryCounts >= delayQueries_) {
292 break;
293 }
294 } while (poll(&fds, 1, delayQueriesTimeout_) > 0);
295
296 if (queryCounts < delayQueries_) {
297 LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
298 << queryCounts << " queries";
299 }
300
301 const int replyLen = reply.size();
302 LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen;
303 if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
304 LOG(WARNING) << "Failed to write response body";
305 }
306
307 // Poll again because the same DoT probe might be sent again.
308 if (isDotProbe && queryCounts == 1) {
309 int n = poll(&fds, 1, 50);
310 if (n > 0 && fds.revents & POLLIN) {
311 goto again;
312 }
313 }
314
315 LOG(DEBUG) << __func__ << " return: " << queryCounts;
316 return queryCounts;
317 }
318
stopServer()319 bool DnsTlsFrontend::stopServer() {
320 std::lock_guard lock(update_mutex_);
321 if (!running()) {
322 LOG(INFO) << "server not running";
323 return false;
324 }
325
326 LOG(INFO) << "stopping frontend";
327 if (!sendToEventFd()) {
328 return false;
329 }
330 handler_thread_.join();
331 socket_.reset();
332 backend_socket_.reset();
333 event_fd_.reset();
334 ctx_.reset();
335 LOG(INFO) << "frontend stopped successfully";
336 return true;
337 }
338
339 // TODO: use a condition variable instead of polling
340 // TODO: also clear queries_ to eliminate potential race conditions
waitForQueries(int expected_count) const341 bool DnsTlsFrontend::waitForQueries(int expected_count) const {
342 constexpr int intervalMs = 20;
343 constexpr int timeoutMs = 5000;
344 int limit = timeoutMs / intervalMs;
345 for (int count = 0; count <= limit; ++count) {
346 bool done = queries_ >= expected_count;
347 // Always sleep at least one more interval after we are done, to wait for
348 // any immediate post-query actions that the client may take (such as
349 // marking this server as reachable during validation).
350 usleep(intervalMs * 1000);
351 if (done) {
352 // For ensuring that calls have sufficient headroom for slow machines
353 LOG(DEBUG) << "Query arrived in " << count << "/" << limit << " of allotted time";
354 return true;
355 }
356 }
357 return false;
358 }
359
sendToEventFd()360 bool DnsTlsFrontend::sendToEventFd() {
361 const uint64_t data = 1;
362 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
363 PLOG(INFO) << "failed to write eventfd, rt=" << rt;
364 return false;
365 }
366 return true;
367 }
368
handleEventFd()369 void DnsTlsFrontend::handleEventFd() {
370 int64_t data;
371 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
372 PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
373 }
374 }
375
376 } // namespace test
377