1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "dns_tls_test"
18
#include <gtest/gtest.h>

#include "dns/DnsTlsDispatcher.h"
#include "dns/DnsTlsQueryMap.h"
#include "dns/DnsTlsServer.h"
#include "dns/DnsTlsSessionCache.h"
#include "dns/DnsTlsSocket.h"
#include "dns/DnsTlsTransport.h"
#include "dns/IDnsTlsSocket.h"
#include "dns/IDnsTlsSocketFactory.h"
#include "dns/IDnsTlsSocketObserver.h"

#include "dns_responder/dns_tls_frontend.h"

#include <arpa/inet.h>

#include <chrono>
#include <cstring>

#include <android-base/macros.h>
#include <netdutils/Slice.h>

#include "log/log.h"
39
40 namespace android {
41 namespace net {
42
43 using netdutils::Slice;
44 using netdutils::makeSlice;
45
// Raw byte buffer used for DNS messages and certificate fingerprints in these tests.
using bytevec = std::vector<uint8_t>;
47
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)48 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
49 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
50 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
51 // IPv4 parse succeeded, so it's IPv4
52 sin->sin_family = AF_INET;
53 sin->sin_port = htons(port);
54 return;
55 }
56 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
57 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
58 // IPv6 parse succeeded, so it's IPv6.
59 sin6->sin6_family = AF_INET6;
60 sin6->sin6_port = htons(port);
61 return;
62 }
63 ALOGE("Failed to parse server address: %s", server);
64 }
65
66 bytevec FINGERPRINT1 = { 1 };
67 bytevec FINGERPRINT2 = { 2 };
68
69 std::string SERVERNAME1 = "dns.example.com";
70 std::string SERVERNAME2 = "dns.example.org";
71
72 // BaseTest just provides constants that are useful for the tests.
73 class BaseTest : public ::testing::Test {
74 protected:
BaseTest()75 BaseTest() {
76 parseServer("192.0.2.1", 853, &V4ADDR1);
77 parseServer("192.0.2.2", 853, &V4ADDR2);
78 parseServer("2001:db8::1", 853, &V6ADDR1);
79 parseServer("2001:db8::2", 853, &V6ADDR2);
80
81 SERVER1 = DnsTlsServer(V4ADDR1);
82 SERVER1.fingerprints.insert(FINGERPRINT1);
83 SERVER1.name = SERVERNAME1;
84 }
85
86 sockaddr_storage V4ADDR1;
87 sockaddr_storage V4ADDR2;
88 sockaddr_storage V6ADDR1;
89 sockaddr_storage V6ADDR2;
90
91 DnsTlsServer SERVER1;
92 };
93
make_query(uint16_t id,size_t size)94 bytevec make_query(uint16_t id, size_t size) {
95 bytevec vec(size);
96 vec[0] = id >> 8;
97 vec[1] = id;
98 // Arbitrarily fill the query body with unique data.
99 for (size_t i = 2; i < size; ++i) {
100 vec[i] = id + i;
101 }
102 return vec;
103 }
104
105 // Query constants
106 const unsigned MARK = 123;
107 const uint16_t ID = 52;
108 const uint16_t SIZE = 22;
109 const bytevec QUERY = make_query(ID, SIZE);
110
111 template <class T>
112 class FakeSocketFactory : public IDnsTlsSocketFactory {
113 public:
FakeSocketFactory()114 FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)115 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
116 const DnsTlsServer& server ATTRIBUTE_UNUSED,
117 unsigned mark ATTRIBUTE_UNUSED,
118 IDnsTlsSocketObserver* observer,
119 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
120 return std::make_unique<T>(observer);
121 }
122 };
123
make_echo(uint16_t id,const Slice query)124 bytevec make_echo(uint16_t id, const Slice query) {
125 bytevec response(query.size() + 2);
126 response[0] = id >> 8;
127 response[1] = id;
128 // Echo the query as the fake response.
129 memcpy(response.data() + 2, query.base(), query.size());
130 return response;
131 }
132
133 // Simplest possible fake server. This just echoes the query as the response.
134 class FakeSocketEcho : public IDnsTlsSocket {
135 public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)136 FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)137 bool query(uint16_t id, const Slice query) override {
138 // Return the response immediately (asynchronously).
139 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
140 return true;
141 }
142 private:
143 IDnsTlsSocketObserver* const mObserver;
144 };
145
146 class TransportTest : public BaseTest {};
147
// A single query against an echoing socket succeeds and returns the query itself.
TEST_F(TransportTest, Query) {
    FakeSocketFactory<FakeSocketEcho> echoFactory;
    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);

    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    EXPECT_EQ(QUERY, result.response);
}
156
// Issuing more queries than there are 16-bit IDs, one at a time, must not
// exhaust the ID space: every query succeeds.
TEST_F(TransportTest, SerialQueries_100000) {
    FakeSocketFactory<FakeSocketEcho> echoFactory;
    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);

    constexpr int kQueries = 100000;  // Deliberately more than 65536.
    for (int n = 0; n < kQueries; ++n) {
        const auto result = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
168
169 // These queries might be handled in serial or parallel as they race the
170 // responses.
TEST_F(TransportTest,RacingQueries_10000)171 TEST_F(TransportTest, RacingQueries_10000) {
172 FakeSocketFactory<FakeSocketEcho> factory;
173 DnsTlsTransport transport(SERVER1, MARK, &factory);
174 std::vector<std::future<DnsTlsTransport::Result>> results;
175 // Fewer than 65536 queries to avoid ID exhaustion.
176 for (int i = 0; i < 10000; ++i) {
177 results.push_back(transport.query(makeSlice(QUERY)));
178 }
179 for (auto& result : results) {
180 auto r = result.get();
181 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
182 EXPECT_EQ(QUERY, r.response);
183 }
184 }
185
186 // A server that waits until sDelay queries are queued before responding.
187 class FakeSocketDelay : public IDnsTlsSocket {
188 public:
FakeSocketDelay(IDnsTlsSocketObserver * observer)189 FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
~FakeSocketDelay()190 ~FakeSocketDelay() { std::lock_guard<std::mutex> guard(mLock); }
191 static size_t sDelay;
192 static bool sReverse;
193
query(uint16_t id,const Slice query)194 bool query(uint16_t id, const Slice query) override {
195 ALOGD("FakeSocketDelay got query with ID %d", int(id));
196 std::lock_guard<std::mutex> guard(mLock);
197 // Check for duplicate IDs.
198 EXPECT_EQ(0U, mIds.count(id));
199 mIds.insert(id);
200
201 // Store response.
202 mResponses.push_back(make_echo(id, query));
203
204 ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay);
205 if (mResponses.size() == sDelay) {
206 std::thread(&FakeSocketDelay::sendResponses, this).detach();
207 }
208 return true;
209 }
210 private:
sendResponses()211 void sendResponses() {
212 std::lock_guard<std::mutex> guard(mLock);
213 if (sReverse) {
214 std::reverse(std::begin(mResponses), std::end(mResponses));
215 }
216 for (auto& response : mResponses) {
217 mObserver->onResponse(response);
218 }
219 mIds.clear();
220 mResponses.clear();
221 }
222
223 std::mutex mLock;
224 IDnsTlsSocketObserver* const mObserver;
225 std::set<uint16_t> mIds GUARDED_BY(mLock);
226 std::vector<bytevec> mResponses GUARDED_BY(mLock);
227 };
228
229 size_t FakeSocketDelay::sDelay;
230 bool FakeSocketDelay::sReverse;
231
// Several in-flight queries that all carry the same original ID must all
// succeed and receive their echoed query.
TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<std::future<DnsTlsTransport::Result>> pending;
    // Fewer than 65536 queries to avoid ID exhaustion.
    while (pending.size() < FakeSocketDelay::sDelay) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const auto result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
248
TEST_F(TransportTest, ParallelColliding_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    // Exactly 65536 queries should still be possible in parallel,
    // even if they all have the same original ID.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    while (pending.size() < FakeSocketDelay::sDelay) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const auto result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
266
// In-flight queries with distinct original IDs are each answered with their own echo.
TEST_F(TransportTest, ParallelUnique) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    const size_t count = FakeSocketDelay::sDelay;
    // Sized up front so the buffers behind each Slice never move.
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    for (size_t n = 0; n < count; ++n) {
        queries[n] = make_query(n, SIZE);
        pending.push_back(transport.query(makeSlice(queries[n])));
    }

    for (size_t n = 0; n < count; ++n) {
        const auto result = pending[n].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[n], result.response);
    }
}
284
TEST_F(TransportTest, ParallelUnique_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    const size_t count = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    // Exactly 65536 queries should still be possible in parallel,
    // and they should all be mapped correctly back to the original ID.
    for (size_t n = 0; n < count; ++n) {
        queries[n] = make_query(n, SIZE);
        pending.push_back(transport.query(makeSlice(queries[n])));
    }

    for (size_t n = 0; n < count; ++n) {
        const auto result = pending[n].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[n], result.response);
    }
}
304
TEST_F(TransportTest, IdExhaustion) {
    // A delay of 65537 is unreachable, because the maximum number of
    // outstanding queries is 65536. The fake server therefore never answers.
    FakeSocketDelay::sDelay = 65537;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    // Issue the maximum number of queries.
    constexpr int kMaxOutstanding = 65536;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    for (int n = 0; n < kMaxOutstanding; ++n) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so one more query must fail immediately.
    const auto overflow = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, overflow.code);
    EXPECT_TRUE(overflow.response.empty());

    // All other queries should remain outstanding.
    const auto noWait = std::chrono::duration<int>::zero();
    for (auto& future : pending) {
        EXPECT_EQ(std::future_status::timeout, future.wait_for(noWait));
    }
}
329
330 // Responses can come back from the server in any order. This should have no
331 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)332 TEST_F(TransportTest, ReverseOrder) {
333 FakeSocketDelay::sDelay = 10;
334 FakeSocketDelay::sReverse = true;
335 FakeSocketFactory<FakeSocketDelay> factory;
336 DnsTlsTransport transport(SERVER1, MARK, &factory);
337 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
338 std::vector<std::future<DnsTlsTransport::Result>> results;
339 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
340 queries[i] = make_query(i, SIZE);
341 results.push_back(transport.query(makeSlice(queries[i])));
342 }
343 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
344 auto r = results[i].get();
345 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
346 EXPECT_EQ(queries[i], r.response);
347 }
348 }
349
// Like ReverseOrder, but with the whole 16-bit ID space in flight at once.
TEST_F(TransportTest, ReverseOrder_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    const size_t count = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    for (size_t n = 0; n < count; ++n) {
        queries[n] = make_query(n, SIZE);
        pending.push_back(transport.query(makeSlice(queries[n])));
    }

    for (size_t n = 0; n < count; ++n) {
        const auto result = pending[n].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[n], result.response);
    }
}
367
368 // Returning null from the factory indicates a connection failure.
369 class NullSocketFactory : public IDnsTlsSocketFactory {
370 public:
NullSocketFactory()371 NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)372 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
373 const DnsTlsServer& server ATTRIBUTE_UNUSED,
374 unsigned mark ATTRIBUTE_UNUSED,
375 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
376 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
377 return nullptr;
378 }
379 };
380
// When no socket can be created, the query fails with a network error and no data.
TEST_F(TransportTest, ConnectFail) {
    NullSocketFactory nullFactory;
    DnsTlsTransport transport(SERVER1, MARK, &nullFactory);

    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
}
389
390 // Simulate a socket that connects but then immediately receives a server
391 // close notification.
392 class FakeSocketClose : public IDnsTlsSocket {
393 public:
FakeSocketClose(IDnsTlsSocketObserver * observer)394 FakeSocketClose(IDnsTlsSocketObserver* observer) :
395 mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()396 ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)397 bool query(uint16_t id ATTRIBUTE_UNUSED,
398 const Slice query ATTRIBUTE_UNUSED) override {
399 return true;
400 }
401 private:
402 std::thread mCloser;
403 };
404
// If every connection closes before answering, the query ends with a network error.
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> closingFactory;
    DnsTlsTransport transport(SERVER1, MARK, &closingFactory);

    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
}
413
414 // Simulate a server that occasionally closes the connection and silently
415 // drops some queries.
416 class FakeSocketLimited : public IDnsTlsSocket {
417 public:
418 static int sLimit; // Number of queries to answer per socket.
419 static size_t sMaxSize; // Silently discard queries greater than this size.
FakeSocketLimited(IDnsTlsSocketObserver * observer)420 FakeSocketLimited(IDnsTlsSocketObserver* observer) :
421 mObserver(observer), mQueries(0) {}
~FakeSocketLimited()422 ~FakeSocketLimited() {
423 {
424 ALOGD("~FakeSocketLimited acquiring mLock");
425 std::lock_guard<std::mutex> guard(mLock);
426 ALOGD("~FakeSocketLimited acquired mLock");
427 for (auto& thread : mThreads) {
428 ALOGD("~FakeSocketLimited joining response thread");
429 thread.join();
430 ALOGD("~FakeSocketLimited joined response thread");
431 }
432 mThreads.clear();
433 }
434
435 if (mCloser) {
436 ALOGD("~FakeSocketLimited joining closer thread");
437 mCloser->join();
438 ALOGD("~FakeSocketLimited joined closer thread");
439 }
440 }
query(uint16_t id,const Slice query)441 bool query(uint16_t id, const Slice query) override {
442 ALOGD("FakeSocketLimited::query acquiring mLock");
443 std::lock_guard<std::mutex> guard(mLock);
444 ALOGD("FakeSocketLimited::query acquired mLock");
445 ++mQueries;
446
447 if (mQueries <= sLimit) {
448 ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize);
449 if (query.size() <= sMaxSize) {
450 // Return the response immediately (asynchronously).
451 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
452 }
453 }
454 if (mQueries == sLimit) {
455 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
456 }
457 return mQueries <= sLimit;
458 }
459 private:
sendClose()460 void sendClose() {
461 {
462 ALOGD("FakeSocketLimited::sendClose acquiring mLock");
463 std::lock_guard<std::mutex> guard(mLock);
464 ALOGD("FakeSocketLimited::sendClose acquired mLock");
465 for (auto& thread : mThreads) {
466 ALOGD("FakeSocketLimited::sendClose joining response thread");
467 thread.join();
468 ALOGD("FakeSocketLimited::sendClose joined response thread");
469 }
470 mThreads.clear();
471 }
472 mObserver->onClosed();
473 }
474 std::mutex mLock;
475 IDnsTlsSocketObserver* const mObserver;
476 int mQueries GUARDED_BY(mLock);
477 std::vector<std::thread> mThreads GUARDED_BY(mLock);
478 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
479 };
480
481 int FakeSocketLimited::sLimit;
482 size_t FakeSocketLimited::sMaxSize;
483
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;   // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries.
    FakeSocketFactory<FakeSocketLimited> limitedFactory;
    DnsTlsTransport transport(SERVER1, MARK, &limitedFactory);

    // Queue up sLimit queries. They will all be ignored, and after the last one
    // the socket will close. Transport will retry them all, until they all hit
    // the retry limit and expire.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    for (int n = 0; n < FakeSocketLimited::sLimit; ++n) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const auto result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
        EXPECT_TRUE(result.response.empty());
    }
}
503
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;          // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries.
    FakeSocketFactory<FakeSocketLimited> limitedFactory;
    DnsTlsTransport transport(SERVER1, MARK, &limitedFactory);

    // Queue up 10 * sLimit (= 100) queries. They alternate between "short"
    // ones (even indices), which will be served, and "long" ones (odd
    // indices, one byte longer), which will be dropped.
    const int queryCount = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> queries(queryCount);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    for (int n = 0; n < queryCount; ++n) {
        const bool isLong = (n % 2) != 0;
        queries[n] = make_query(n, SIZE + (isLong ? 1 : 0));
        pending.push_back(transport.query(makeSlice(queries[n])));
    }

    // Just check the short queries, which are at the even indices.
    for (int n = 0; n < queryCount; n += 2) {
        const auto result = pending[n].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[n], result.response);
    }
}
526
527 // Simulate a malfunctioning server that injects extra miscellaneous
528 // responses to queries that were not asked. This will cause wrong answers but
529 // must not crash the Transport.
530 class FakeSocketGarbage : public IDnsTlsSocket {
531 public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)532 FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
533 // Inject a garbage event.
534 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
535 }
~FakeSocketGarbage()536 ~FakeSocketGarbage() {
537 std::lock_guard<std::mutex> guard(mLock);
538 for (auto& thread : mThreads) {
539 thread.join();
540 }
541 }
query(uint16_t id,const Slice query)542 bool query(uint16_t id, const Slice query) override {
543 std::lock_guard<std::mutex> guard(mLock);
544 // Return the response twice.
545 auto echo = make_echo(id, query);
546 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
547 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
548 // Also return some other garbage
549 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
550 return true;
551 }
552 private:
553 std::mutex mLock;
554 std::vector<std::thread> mThreads GUARDED_BY(mLock);
555 IDnsTlsSocketObserver* const mObserver;
556 };
557
// Even with duplicate and unsolicited responses from the server, every query
// completes successfully.
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> garbageFactory;
    DnsTlsTransport transport(SERVER1, MARK, &garbageFactory);

    for (int attempt = 0; attempt < 10; ++attempt) {
        const auto result = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        // Don't check the response because this server is malfunctioning.
    }
}
568
569 // Dispatcher tests
570 class DispatcherTest : public BaseTest {};
571
// A query through the dispatcher copies the echoed answer into the caller's buffer.
TEST_F(DispatcherTest, Query) {
    bytevec answer(4096);
    int answerLen = 0;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(answer),
                                       &answerLen);

    EXPECT_EQ(DnsTlsTransport::Response::success, code);
    EXPECT_EQ(int(QUERY.size()), answerLen);
    answer.resize(answerLen);
    EXPECT_EQ(QUERY, answer);
}
586
// An answer that does not fit in the caller's buffer is reported as limit_error.
TEST_F(DispatcherTest, AnswerTooLarge) {
    bytevec answer(SIZE - 1);  // One byte too small to hold the answer.
    int answerLen = 0;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(answer),
                                       &answerLen);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, code);
}
598
599 template<class T>
600 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
601 public:
TrackingFakeSocketFactory()602 TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)603 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
604 const DnsTlsServer& server,
605 unsigned mark,
606 IDnsTlsSocketObserver* observer,
607 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
608 std::lock_guard<std::mutex> guard(mLock);
609 keys.emplace(mark, server);
610 return std::make_unique<T>(observer);
611 }
612 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
613 private:
614 std::mutex mLock;
615 };
616
// Concurrent queries for several (mark, server) keys all succeed, and the
// dispatcher opens exactly one socket per key.
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    // Non-owning; valid as long as |dispatcher|, which takes ownership, is in scope.
    auto* const tracker = factory.get();
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Two servers times two socket marks: four combinations in total.
    const std::vector<std::pair<unsigned, DnsTlsServer>> keys = {
            {MARK, SERVER1},
            {MARK + 1, SERVER1},
            {MARK, DnsTlsServer(V4ADDR2)},
            {MARK + 1, DnsTlsServer(V4ADDR2)},
    };

    // Do several queries on each server, each from its own thread. They
    // should all succeed.
    std::vector<std::thread> workers;
    const size_t totalQueries = FakeSocketDelay::sDelay * keys.size();
    for (size_t i = 0; i < totalQueries; ++i) {
        const auto& key = keys[i % keys.size()];
        workers.emplace_back(
                [key, i](DnsTlsDispatcher* d) {
                    const bytevec query = make_query(i, SIZE);
                    bytevec answer(4096);
                    int answerLen = 0;
                    const auto code = d->query(key.second, key.first, makeSlice(query),
                                               makeSlice(answer), &answerLen);
                    EXPECT_EQ(DnsTlsTransport::Response::success, code);
                    EXPECT_EQ(int(query.size()), answerLen);
                    answer.resize(answerLen);
                    EXPECT_EQ(query, answer);
                },
                &dispatcher);
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), tracker->keys.size());
    for (const auto& key : keys) {
        EXPECT_EQ(1U, tracker->keys.count(key));
    }
}
659
660 // Check DnsTlsServer's comparison logic.
661 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)662 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
663 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
664 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
665 EXPECT_FALSE(cmp1 && cmp2);
666 return !cmp1 && !cmp2;
667 }
668
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)669 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
670 EXPECT_TRUE(s1 == s1);
671 EXPECT_TRUE(s2 == s2);
672 EXPECT_TRUE(isAddressEqual(s1, s1));
673 EXPECT_TRUE(isAddressEqual(s2, s2));
674
675 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
676 EXPECT_FALSE(s1 == s2);
677 EXPECT_FALSE(s2 == s1);
678 }
679
680 class ServerTest : public BaseTest {};
681
// Distinct IPv4 addresses differ under both operator== and AddressComparator.
TEST_F(ServerTest, IPv4) {
    const DnsTlsServer first(V4ADDR1);
    const DnsTlsServer second(V4ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
686
// Distinct IPv6 addresses differ under both operator== and AddressComparator.
TEST_F(ServerTest, IPv6) {
    const DnsTlsServer first(V6ADDR1);
    const DnsTlsServer second(V6ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
691
// An IPv6 server and an IPv4 server are never equal.
TEST_F(ServerTest, MixedAddressFamily) {
    const DnsTlsServer v6(V6ADDR1);
    const DnsTlsServer v4(V4ADDR1);
    checkUnequal(v6, v4);
    EXPECT_FALSE(isAddressEqual(v6, v4));
}
696
// Servers that differ only in IPv6 scope ID are different under both
// operator== and AddressComparator.
TEST_F(ServerTest, IPv6ScopeId) {
    DnsTlsServer s1(V6ADDR1);
    DnsTlsServer s2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_scope_id = 1;
    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_scope_id = 2;
    checkUnequal(s1, s2);
    EXPECT_FALSE(isAddressEqual(s1, s2));

    // Only the address was set, so neither server counts as explicitly configured.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
709
// Servers that differ only in IPv6 flow info compare equal: all comparisons
// ignore sin6_flowinfo.
TEST_F(ServerTest, IPv6FlowInfo) {
    DnsTlsServer s1(V6ADDR1);
    DnsTlsServer s2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_flowinfo = 1;
    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_flowinfo = 2;
    EXPECT_EQ(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // Only the address was set, so neither server counts as explicitly configured.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
723
// Servers that differ only in port are distinct DnsTlsServers, but
// AddressComparator ignores the port and treats them as the same address.
TEST_F(ServerTest, Port) {
    DnsTlsServer s1, s2;
    parseServer("192.0.2.1", 853, &s1.ss);
    parseServer("192.0.2.1", 854, &s2.ss);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    DnsTlsServer s3, s4;
    parseServer("2001:db8::1", 853, &s3.ss);
    parseServer("2001:db8::1", 852, &s4.ss);
    checkUnequal(s3, s4);
    EXPECT_TRUE(isAddressEqual(s3, s4));

    // Only the address was set on these servers, so none of them counts as
    // explicitly configured. The IPv6 pair is checked as well, not just IPv4.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
    EXPECT_FALSE(s3.wasExplicitlyConfigured());
    EXPECT_FALSE(s4.wasExplicitlyConfigured());
}
740
// The hostname takes part in operator== and ordering but not in AddressComparator.
TEST_F(ServerTest, Name) {
    DnsTlsServer named(V4ADDR1);
    DnsTlsServer other(V4ADDR1);

    named.name = SERVERNAME1;
    checkUnequal(named, other);  // Named vs. unnamed.

    other.name = SERVERNAME2;
    checkUnequal(named, other);  // Two different names.
    EXPECT_TRUE(isAddressEqual(named, other));

    // A configured name makes a server explicitly configured.
    EXPECT_TRUE(named.wasExplicitlyConfigured());
    EXPECT_TRUE(other.wasExplicitlyConfigured());
}
752
// Fingerprint sets take part in operator== and ordering but not in
// AddressComparator. Servers become equal once their fingerprint sets match.
TEST_F(ServerTest, Fingerprint) {
    DnsTlsServer left(V4ADDR1);
    DnsTlsServer right(V4ADDR1);

    // {FP1} vs. {}
    left.fingerprints.insert(FINGERPRINT1);
    checkUnequal(left, right);
    EXPECT_TRUE(isAddressEqual(left, right));

    // {FP1} vs. {FP2}
    right.fingerprints.insert(FINGERPRINT2);
    checkUnequal(left, right);
    EXPECT_TRUE(isAddressEqual(left, right));

    // {FP1} vs. {FP1, FP2}
    right.fingerprints.insert(FINGERPRINT1);
    checkUnequal(left, right);
    EXPECT_TRUE(isAddressEqual(left, right));

    // {FP1, FP2} vs. {FP1, FP2}
    left.fingerprints.insert(FINGERPRINT2);
    EXPECT_EQ(left, right);
    EXPECT_TRUE(isAddressEqual(left, right));

    // Configured fingerprints make a server explicitly configured.
    EXPECT_TRUE(left.wasExplicitlyConfigured());
    EXPECT_TRUE(right.wasExplicitlyConfigured());
}
775
// The query map assigns new IDs sequentially, accepts answers in any order,
// and restores each response's original ID while keeping the answer body.
TEST(QueryMapTest, Basic) {
    DnsTlsQueryMap map;
    EXPECT_TRUE(map.empty());

    // Three queries with arbitrary, distinct original IDs. All buffers are
    // built before any Slice is taken, so the Slices stay valid.
    const std::vector<uint16_t> originalIds = {999, 888, 777};
    const size_t count = originalIds.size();
    std::vector<bytevec> queries;
    for (uint16_t id : originalIds) {
        queries.push_back(make_query(id, SIZE));
    }

    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures;
    for (const bytevec& query : queries) {
        futures.push_back(map.recordQuery(makeSlice(query)));
    }

    // Check return values of recordQuery: new IDs start at 0.
    for (size_t k = 0; k < count; ++k) {
        EXPECT_EQ(k, futures[k]->query.newId);
    }

    // Check side effects of recordQuery.
    EXPECT_FALSE(map.empty());
    const auto all = map.getAll();
    EXPECT_EQ(count, all.size());
    for (size_t k = 0; k < count; ++k) {
        EXPECT_EQ(k, all[k].newId);
    }
    for (size_t k = 0; k < count; ++k) {
        EXPECT_EQ(makeSlice(queries[k]), all[k].query);
    }

    // Answers are addressed by the new IDs.
    std::vector<bytevec> answers;
    for (size_t k = 0; k < count; ++k) {
        answers.push_back(make_query(k, SIZE));
    }

    // Return responses out of order.
    map.onResponse(answers[2]);
    map.onResponse(answers[0]);
    map.onResponse(answers[1]);
    EXPECT_TRUE(map.empty());

    for (size_t k = 0; k < count; ++k) {
        const auto result = futures[k]->result.get();
        EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);

        const bytevec& response = result.response;
        // The ID should match the query...
        EXPECT_EQ(originalIds[k], response[0] << 8 | response[1]);
        // ...while the body should match the answer.
        EXPECT_EQ(bytevec(answers[k].begin() + 2, answers[k].end()),
                  bytevec(response.begin() + 2, response.end()));
    }
}
840
// Fill the entire 16-bit ID space, free one ID in the middle by answering its
// query, and check that the next recorded query reuses exactly that ID.
TEST(QueryMapTest, FillHole) {
    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
        futures[i] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[i]);  // futures[i] should be nonnull.
        EXPECT_EQ(i, futures[i]->query.newId);
    }

    // The map should now be full.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());

    // Trying to add another query should fail because the map is full.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Answer one query in the middle of the ID space to open a hole.
    constexpr uint16_t kHoleId = 40000;
    bytevec answer = make_query(kHoleId, SIZE);
    map.onResponse(answer);
    auto result = futures[kHoleId]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    // The response carries the original query ID and the answer's body.
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // There should now be room in the map, and the freed ID is reused.
    EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    auto f = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(f);
    EXPECT_EQ(kHoleId, f->query.newId);

    // The map should now be full again.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
875
876 class StubObserver : public IDnsTlsSocketObserver {
877 public:
878 bool closed = false;
onResponse(std::vector<uint8_t>)879 void onResponse(std::vector<uint8_t>) override {}
880
onClosed()881 void onClosed() override { closed = true; }
882 };
883
// Destroying a connected DnsTlsSocket must be quick and must report onClosed()
// to the observer, rather than waiting for a timeout.
TEST(DnsTlsSocketTest, SlowDestructor) {
    constexpr char tls_addr[] = "127.0.0.3";
    // High-numbered port so root isn't required. Must match the numeric port
    // passed to parseServer() below.
    constexpr char tls_port[] = "8530";
    // This test doesn't perform any queries, so the backend address can be invalid.
    constexpr char backend_addr[] = "192.0.2.1";
    constexpr char backend_port[] = "1";

    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
    ASSERT_TRUE(tls.startServer());

    DnsTlsServer server;
    parseServer(tls_addr, 8530, &server.ss);

    StubObserver observer;
    ASSERT_FALSE(observer.closed);
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
    ASSERT_TRUE(socket->initialize());

    // Test: Time the socket destructor. This should be fast.
    const auto before = std::chrono::steady_clock::now();
    socket.reset();
    const auto after = std::chrono::steady_clock::now();
    const auto delay = after - before;
    // duration::rep is implementation-defined (e.g. long with libstdc++ on
    // LP64), so convert explicitly to match the %lld format specifier.
    ALOGV("Shutdown took %lld ns",
          static_cast<long long>(
                  std::chrono::duration_cast<std::chrono::nanoseconds>(delay).count()));
    EXPECT_TRUE(observer.closed);
    // Shutdown should complete in milliseconds, but if the shutdown signal is lost
    // it will wait for the timeout, which is expected to take 20 seconds.
    EXPECT_LT(delay, std::chrono::seconds{5});
}
914
915 } // end of namespace net
916 } // end of namespace android
917