• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "dns_tls_test"
18 
19 #include <gtest/gtest.h>
20 
21 #include "dns/DnsTlsDispatcher.h"
22 #include "dns/DnsTlsQueryMap.h"
23 #include "dns/DnsTlsServer.h"
24 #include "dns/DnsTlsSessionCache.h"
25 #include "dns/DnsTlsSocket.h"
26 #include "dns/DnsTlsTransport.h"
27 #include "dns/IDnsTlsSocket.h"
28 #include "dns/IDnsTlsSocketFactory.h"
29 #include "dns/IDnsTlsSocketObserver.h"
30 
31 #include "dns_responder/dns_tls_frontend.h"
32 
33 #include <chrono>
34 #include <arpa/inet.h>
35 #include <android-base/macros.h>
36 #include <netdutils/Slice.h>
37 
38 #include "log/log.h"
39 
40 namespace android {
41 namespace net {
42 
43 using netdutils::Slice;
44 using netdutils::makeSlice;
45 
46 typedef std::vector<uint8_t> bytevec;
47 
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)48 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
49     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
50     if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
51         // IPv4 parse succeeded, so it's IPv4
52         sin->sin_family = AF_INET;
53         sin->sin_port = htons(port);
54         return;
55     }
56     sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
57     if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
58         // IPv6 parse succeeded, so it's IPv6.
59         sin6->sin6_family = AF_INET6;
60         sin6->sin6_port = htons(port);
61         return;
62     }
63     ALOGE("Failed to parse server address: %s", server);
64 }
65 
66 bytevec FINGERPRINT1 = { 1 };
67 bytevec FINGERPRINT2 = { 2 };
68 
69 std::string SERVERNAME1 = "dns.example.com";
70 std::string SERVERNAME2 = "dns.example.org";
71 
72 // BaseTest just provides constants that are useful for the tests.
73 class BaseTest : public ::testing::Test {
74 protected:
BaseTest()75     BaseTest() {
76         parseServer("192.0.2.1", 853, &V4ADDR1);
77         parseServer("192.0.2.2", 853, &V4ADDR2);
78         parseServer("2001:db8::1", 853, &V6ADDR1);
79         parseServer("2001:db8::2", 853, &V6ADDR2);
80 
81         SERVER1 = DnsTlsServer(V4ADDR1);
82         SERVER1.fingerprints.insert(FINGERPRINT1);
83         SERVER1.name = SERVERNAME1;
84     }
85 
86     sockaddr_storage V4ADDR1;
87     sockaddr_storage V4ADDR2;
88     sockaddr_storage V6ADDR1;
89     sockaddr_storage V6ADDR2;
90 
91     DnsTlsServer SERVER1;
92 };
93 
make_query(uint16_t id,size_t size)94 bytevec make_query(uint16_t id, size_t size) {
95     bytevec vec(size);
96     vec[0] = id >> 8;
97     vec[1] = id;
98     // Arbitrarily fill the query body with unique data.
99     for (size_t i = 2; i < size; ++i) {
100         vec[i] = id + i;
101     }
102     return vec;
103 }
104 
105 // Query constants
106 const unsigned MARK = 123;
107 const uint16_t ID = 52;
108 const uint16_t SIZE = 22;
109 const bytevec QUERY = make_query(ID, SIZE);
110 
111 template <class T>
112 class FakeSocketFactory : public IDnsTlsSocketFactory {
113 public:
FakeSocketFactory()114     FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)115     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
116             const DnsTlsServer& server ATTRIBUTE_UNUSED,
117             unsigned mark ATTRIBUTE_UNUSED,
118             IDnsTlsSocketObserver* observer,
119             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
120         return std::make_unique<T>(observer);
121     }
122 };
123 
make_echo(uint16_t id,const Slice query)124 bytevec make_echo(uint16_t id, const Slice query) {
125     bytevec response(query.size() + 2);
126     response[0] = id >> 8;
127     response[1] = id;
128     // Echo the query as the fake response.
129     memcpy(response.data() + 2, query.base(), query.size());
130     return response;
131 }
132 
133 // Simplest possible fake server.  This just echoes the query as the response.
134 class FakeSocketEcho : public IDnsTlsSocket {
135 public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)136     FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)137     bool query(uint16_t id, const Slice query) override {
138         // Return the response immediately (asynchronously).
139         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
140         return true;
141     }
142 private:
143     IDnsTlsSocketObserver* const mObserver;
144 };
145 
146 class TransportTest : public BaseTest {};
147 
TEST_F(TransportTest, Query) {
    FakeSocketFactory<FakeSocketEcho> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // A single query against the echo server must come back unchanged.
    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    EXPECT_EQ(QUERY, result.response);
}
156 
TEST_F(TransportTest, SerialQueries_100000) {
    FakeSocketFactory<FakeSocketEcho> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Send more than 65536 queries serially, waiting for each answer before
    // issuing the next query.
    constexpr int kNumQueries = 100000;
    for (int n = 0; n < kNumQueries; ++n) {
        const auto result = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
168 
169 // These queries might be handled in serial or parallel as they race the
170 // responses.
TEST_F(TransportTest,RacingQueries_10000)171 TEST_F(TransportTest, RacingQueries_10000) {
172     FakeSocketFactory<FakeSocketEcho> factory;
173     DnsTlsTransport transport(SERVER1, MARK, &factory);
174     std::vector<std::future<DnsTlsTransport::Result>> results;
175     // Fewer than 65536 queries to avoid ID exhaustion.
176     for (int i = 0; i < 10000; ++i) {
177         results.push_back(transport.query(makeSlice(QUERY)));
178     }
179     for (auto& result : results) {
180         auto r = result.get();
181         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
182         EXPECT_EQ(QUERY, r.response);
183     }
184 }
185 
186 // A server that waits until sDelay queries are queued before responding.
187 class FakeSocketDelay : public IDnsTlsSocket {
188 public:
FakeSocketDelay(IDnsTlsSocketObserver * observer)189     FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
~FakeSocketDelay()190     ~FakeSocketDelay() { std::lock_guard<std::mutex> guard(mLock); }
191     static size_t sDelay;
192     static bool sReverse;
193 
query(uint16_t id,const Slice query)194     bool query(uint16_t id, const Slice query) override {
195         ALOGD("FakeSocketDelay got query with ID %d", int(id));
196         std::lock_guard<std::mutex> guard(mLock);
197         // Check for duplicate IDs.
198         EXPECT_EQ(0U, mIds.count(id));
199         mIds.insert(id);
200 
201         // Store response.
202         mResponses.push_back(make_echo(id, query));
203 
204         ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay);
205         if (mResponses.size() == sDelay) {
206             std::thread(&FakeSocketDelay::sendResponses, this).detach();
207         }
208         return true;
209     }
210 private:
sendResponses()211     void sendResponses() {
212         std::lock_guard<std::mutex> guard(mLock);
213         if (sReverse) {
214             std::reverse(std::begin(mResponses), std::end(mResponses));
215         }
216         for (auto& response : mResponses) {
217             mObserver->onResponse(response);
218         }
219         mIds.clear();
220         mResponses.clear();
221     }
222 
223     std::mutex mLock;
224     IDnsTlsSocketObserver* const mObserver;
225     std::set<uint16_t> mIds GUARDED_BY(mLock);
226     std::vector<bytevec> mResponses GUARDED_BY(mLock);
227 };
228 
229 size_t FakeSocketDelay::sDelay;
230 bool FakeSocketDelay::sReverse;
231 
TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Fewer than 65536 queries to avoid ID exhaustion. Every query carries
    // the same original ID.
    const size_t batch = FakeSocketDelay::sDelay;
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (size_t n = 0; n < batch; ++n) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& pending : results) {
        const auto result = pending.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
248 
TEST_F(TransportTest, ParallelColliding_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Exactly 65536 queries should still be possible in parallel,
    // even if they all have the same original ID.
    const size_t batch = FakeSocketDelay::sDelay;
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (size_t n = 0; n < batch; ++n) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& pending : results) {
        const auto result = pending.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
}
266 
TEST_F(TransportTest, ParallelUnique) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Each query gets its own ID. The buffers are sized up front so earlier
    // slices are never invalidated by a reallocation.
    const size_t batch = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(batch);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (size_t idx = 0; idx < batch; ++idx) {
        queries[idx] = make_query(static_cast<uint16_t>(idx), SIZE);
        results.push_back(transport.query(makeSlice(queries[idx])));
    }

    // Every answer must match the query it was issued for.
    for (size_t idx = 0; idx < batch; ++idx) {
        const auto result = results[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[idx], result.response);
    }
}
284 
TEST_F(TransportTest, ParallelUnique_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Exactly 65536 queries should still be possible in parallel,
    // and they should all be mapped correctly back to the original ID.
    const size_t batch = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(batch);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (size_t idx = 0; idx < batch; ++idx) {
        queries[idx] = make_query(static_cast<uint16_t>(idx), SIZE);
        results.push_back(transport.query(makeSlice(queries[idx])));
    }

    for (size_t idx = 0; idx < batch; ++idx) {
        const auto result = results[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[idx], result.response);
    }
}
304 
TEST_F(TransportTest, IdExhaustion) {
    // A delay of 65537 is unreachable, because the maximum number
    // of outstanding queries is 65536.
    FakeSocketDelay::sDelay = 65537;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Issue the maximum number of queries.
    constexpr int kMaxOutstanding = 65536;
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (int n = 0; n < kMaxOutstanding; ++n) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so subsequent queries should fail immediately.
    const auto overflow = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, overflow.code);
    EXPECT_TRUE(overflow.response.empty());

    // All other queries should remain outstanding.
    for (auto& pending : results) {
        EXPECT_EQ(std::future_status::timeout,
                  pending.wait_for(std::chrono::seconds::zero()));
    }
}
329 
330 // Responses can come back from the server in any order.  This should have no
331 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)332 TEST_F(TransportTest, ReverseOrder) {
333     FakeSocketDelay::sDelay = 10;
334     FakeSocketDelay::sReverse = true;
335     FakeSocketFactory<FakeSocketDelay> factory;
336     DnsTlsTransport transport(SERVER1, MARK, &factory);
337     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
338     std::vector<std::future<DnsTlsTransport::Result>> results;
339     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
340         queries[i] = make_query(i, SIZE);
341         results.push_back(transport.query(makeSlice(queries[i])));
342     }
343     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
344         auto r = results[i].get();
345         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
346         EXPECT_EQ(queries[i], r.response);
347     }
348 }
349 
TEST_F(TransportTest, ReverseOrder_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // The full ID space, answered in reverse order.
    const size_t batch = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(batch);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (size_t idx = 0; idx < batch; ++idx) {
        queries[idx] = make_query(static_cast<uint16_t>(idx), SIZE);
        results.push_back(transport.query(makeSlice(queries[idx])));
    }

    for (size_t idx = 0; idx < batch; ++idx) {
        const auto result = results[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[idx], result.response);
    }
}
367 
368 // Returning null from the factory indicates a connection failure.
369 class NullSocketFactory : public IDnsTlsSocketFactory {
370 public:
NullSocketFactory()371     NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)372     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
373             const DnsTlsServer& server ATTRIBUTE_UNUSED,
374             unsigned mark ATTRIBUTE_UNUSED,
375             IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
376             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
377         return nullptr;
378     }
379 };
380 
TEST_F(TransportTest, ConnectFail) {
    NullSocketFactory factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // With no socket available, the query fails with a network error and no data.
    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
}
389 
390 // Simulate a socket that connects but then immediately receives a server
391 // close notification.
392 class FakeSocketClose : public IDnsTlsSocket {
393 public:
FakeSocketClose(IDnsTlsSocketObserver * observer)394     FakeSocketClose(IDnsTlsSocketObserver* observer) :
395             mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()396     ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)397     bool query(uint16_t id ATTRIBUTE_UNUSED,
398                const Slice query ATTRIBUTE_UNUSED) override {
399         return true;
400     }
401 private:
402     std::thread mCloser;
403 };
404 
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Every socket closes right away, so the query ends in a network error.
    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
}
413 
414 // Simulate a server that occasionally closes the connection and silently
415 // drops some queries.
416 class FakeSocketLimited : public IDnsTlsSocket {
417 public:
418     static int sLimit;  // Number of queries to answer per socket.
419     static size_t sMaxSize;  // Silently discard queries greater than this size.
FakeSocketLimited(IDnsTlsSocketObserver * observer)420     FakeSocketLimited(IDnsTlsSocketObserver* observer) :
421             mObserver(observer), mQueries(0) {}
~FakeSocketLimited()422     ~FakeSocketLimited() {
423         {
424             ALOGD("~FakeSocketLimited acquiring mLock");
425             std::lock_guard<std::mutex> guard(mLock);
426             ALOGD("~FakeSocketLimited acquired mLock");
427             for (auto& thread : mThreads) {
428                 ALOGD("~FakeSocketLimited joining response thread");
429                 thread.join();
430                 ALOGD("~FakeSocketLimited joined response thread");
431             }
432             mThreads.clear();
433         }
434 
435         if (mCloser) {
436             ALOGD("~FakeSocketLimited joining closer thread");
437             mCloser->join();
438             ALOGD("~FakeSocketLimited joined closer thread");
439         }
440     }
query(uint16_t id,const Slice query)441     bool query(uint16_t id, const Slice query) override {
442         ALOGD("FakeSocketLimited::query acquiring mLock");
443         std::lock_guard<std::mutex> guard(mLock);
444         ALOGD("FakeSocketLimited::query acquired mLock");
445         ++mQueries;
446 
447         if (mQueries <= sLimit) {
448             ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize);
449             if (query.size() <= sMaxSize) {
450                 // Return the response immediately (asynchronously).
451                 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
452             }
453         }
454         if (mQueries == sLimit) {
455             mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
456         }
457         return mQueries <= sLimit;
458     }
459 private:
sendClose()460     void sendClose() {
461         {
462             ALOGD("FakeSocketLimited::sendClose acquiring mLock");
463             std::lock_guard<std::mutex> guard(mLock);
464             ALOGD("FakeSocketLimited::sendClose acquired mLock");
465             for (auto& thread : mThreads) {
466                 ALOGD("FakeSocketLimited::sendClose joining response thread");
467                 thread.join();
468                 ALOGD("FakeSocketLimited::sendClose joined response thread");
469             }
470             mThreads.clear();
471         }
472         mObserver->onClosed();
473     }
474     std::mutex mLock;
475     IDnsTlsSocketObserver* const mObserver;
476     int mQueries GUARDED_BY(mLock);
477     std::vector<std::thread> mThreads GUARDED_BY(mLock);
478     std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
479 };
480 
481 int FakeSocketLimited::sLimit;
482 size_t FakeSocketLimited::sMaxSize;
483 
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 10 queries.  They will all be ignored, and after the 10th,
    // the socket will close.  Transport will retry them all, until they
    // all hit the retry limit and expire.
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (int n = 0; n < FakeSocketLimited::sLimit; ++n) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& pending : results) {
        const auto result = pending.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
        EXPECT_TRUE(result.response.empty());
    }
}
503 
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 100 queries, alternating "short" (even index, served) and
    // "long" (odd index, one byte longer, dropped).
    const int numQueries = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> queries(numQueries);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    for (int idx = 0; idx < numQueries; ++idx) {
        const bool isLong = (idx % 2) != 0;
        queries[idx] = make_query(idx, SIZE + (isLong ? 1 : 0));
        results.push_back(transport.query(makeSlice(queries[idx])));
    }

    // Just check the short queries, which are at the even indices.
    for (int idx = 0; idx < numQueries; idx += 2) {
        const auto result = results[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[idx], result.response);
    }
}
526 
527 // Simulate a malfunctioning server that injects extra miscellaneous
528 // responses to queries that were not asked.  This will cause wrong answers but
529 // must not crash the Transport.
530 class FakeSocketGarbage : public IDnsTlsSocket {
531 public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)532     FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
533         // Inject a garbage event.
534         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
535     }
~FakeSocketGarbage()536     ~FakeSocketGarbage() {
537         std::lock_guard<std::mutex> guard(mLock);
538         for (auto& thread : mThreads) {
539             thread.join();
540         }
541     }
query(uint16_t id,const Slice query)542     bool query(uint16_t id, const Slice query) override {
543         std::lock_guard<std::mutex> guard(mLock);
544         // Return the response twice.
545         auto echo = make_echo(id, query);
546         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
547         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
548         // Also return some other garbage
549         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
550         return true;
551     }
552 private:
553     std::mutex mLock;
554     std::vector<std::thread> mThreads GUARDED_BY(mLock);
555     IDnsTlsSocketObserver* const mObserver;
556 };
557 
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    constexpr int kNumQueries = 10;
    for (int n = 0; n < kNumQueries; ++n) {
        const auto result = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        // Don't check the response because this server is malfunctioning.
    }
}
568 
569 // Dispatcher tests
570 class DispatcherTest : public BaseTest {};
571 
TEST_F(DispatcherTest, Query) {
    bytevec answer(4096);
    int answerLength = 0;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto rcode = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
                                        makeSlice(answer), &answerLength);

    EXPECT_EQ(DnsTlsTransport::Response::success, rcode);
    EXPECT_EQ(int(QUERY.size()), answerLength);
    // Only the first answerLength bytes of the buffer hold the answer.
    answer.resize(answerLength);
    EXPECT_EQ(QUERY, answer);
}
586 
TEST_F(DispatcherTest, AnswerTooLarge) {
    // One byte too small to hold the echoed QUERY.
    bytevec answer(SIZE - 1);
    int answerLength = 0;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto rcode = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
                                        makeSlice(answer), &answerLength);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, rcode);
}
598 
599 template<class T>
600 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
601 public:
TrackingFakeSocketFactory()602     TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)603     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
604             const DnsTlsServer& server,
605             unsigned mark,
606             IDnsTlsSocketObserver* observer,
607             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
608         std::lock_guard<std::mutex> guard(mLock);
609         keys.emplace(mark, server);
610         return std::make_unique<T>(observer);
611     }
612     std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
613 private:
614     std::mutex mLock;
615 };
616 
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    // Non-owning; valid as long as the dispatcher is in scope.
    auto* tracker = factory.get();
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Two servers and two socket marks: four combinations in total.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // Do several queries on each server, each from its own thread.
    // They should all succeed.
    const size_t totalQueries = FakeSocketDelay::sDelay * keys.size();
    std::vector<std::thread> workers;
    for (size_t i = 0; i < totalQueries; ++i) {
        const auto key = keys[i % keys.size()];
        workers.emplace_back([&dispatcher, key, i] {
            bytevec q = make_query(i, SIZE);
            bytevec ans(4096);
            int resplen = 0;
            const unsigned mark = key.first;
            const DnsTlsServer& server = key.second;
            const auto rcode = dispatcher.query(server, mark, makeSlice(q),
                                                makeSlice(ans), &resplen);
            EXPECT_EQ(DnsTlsTransport::Response::success, rcode);
            EXPECT_EQ(int(q.size()), resplen);
            ans.resize(resplen);
            EXPECT_EQ(q, ans);
        });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), tracker->keys.size());
    for (const auto& key : keys) {
        EXPECT_EQ(1U, tracker->keys.count(key));
    }
}
659 
660 // Check DnsTlsServer's comparison logic.
661 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)662 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
663     bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
664     bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
665     EXPECT_FALSE(cmp1 && cmp2);
666     return !cmp1 && !cmp2;
667 }
668 
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)669 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
670     EXPECT_TRUE(s1 == s1);
671     EXPECT_TRUE(s2 == s2);
672     EXPECT_TRUE(isAddressEqual(s1, s1));
673     EXPECT_TRUE(isAddressEqual(s2, s2));
674 
675     EXPECT_TRUE(s1 < s2 ^ s2 < s1);
676     EXPECT_FALSE(s1 == s2);
677     EXPECT_FALSE(s2 == s1);
678 }
679 
680 class ServerTest : public BaseTest {};
681 
TEST_F(ServerTest, IPv4) {
    // Different IPv4 addresses: distinct servers and distinct addresses.
    const DnsTlsServer first(V4ADDR1);
    const DnsTlsServer second(V4ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
686 
TEST_F(ServerTest, IPv6) {
    // Different IPv6 addresses: distinct servers and distinct addresses.
    const DnsTlsServer first(V6ADDR1);
    const DnsTlsServer second(V6ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
691 
TEST_F(ServerTest, MixedAddressFamily) {
    // An IPv6 and an IPv4 server never compare equal.
    const DnsTlsServer v6(V6ADDR1);
    const DnsTlsServer v4(V4ADDR1);
    checkUnequal(v6, v4);
    EXPECT_FALSE(isAddressEqual(v6, v4));
}
696 
TEST_F(ServerTest, IPv6ScopeId) {
    DnsTlsServer s1(V6ADDR1);
    DnsTlsServer s2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_scope_id = 1;
    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_scope_id = 2;

    // Differing only in scope ID still makes both the servers and the
    // addresses distinct.
    checkUnequal(s1, s2);
    EXPECT_FALSE(isAddressEqual(s1, s2));

    // Built from a bare address, so neither counts as explicitly configured.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
709 
TEST_F(ServerTest, IPv6FlowInfo) {
    DnsTlsServer s1(V6ADDR1);
    DnsTlsServer s2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_flowinfo = 1;
    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_flowinfo = 2;

    // All comparisons ignore flowinfo.
    EXPECT_EQ(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // Built from a bare address, so neither counts as explicitly configured.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
723 
TEST_F(ServerTest, Port) {
    // Same IPv4 address on different ports: distinct servers, equivalent
    // addresses.
    DnsTlsServer s1, s2;
    parseServer("192.0.2.1", 853, &s1.ss);
    parseServer("192.0.2.1", 854, &s2.ss);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // The same holds for an IPv6 address.
    DnsTlsServer s3, s4;
    parseServer("2001:db8::1", 853, &s3.ss);
    parseServer("2001:db8::1", 852, &s4.ss);
    checkUnequal(s3, s4);
    EXPECT_TRUE(isAddressEqual(s3, s4));

    // None of these servers was given a name or fingerprint, so none of them
    // counts as explicitly configured. Check the IPv6 pair too, not just the
    // IPv4 one.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
    EXPECT_FALSE(s3.wasExplicitlyConfigured());
    EXPECT_FALSE(s4.wasExplicitlyConfigured());
}
740 
TEST_F(ServerTest, Name) {
    DnsTlsServer s1(V4ADDR1);
    DnsTlsServer s2(V4ADDR1);

    // A name on only one side makes the servers distinct...
    s1.name = SERVERNAME1;
    checkUnequal(s1, s2);
    // ...and so do two different names.
    s2.name = SERVERNAME2;
    checkUnequal(s1, s2);
    // Address comparison ignores the name.
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // Having a name makes a server explicitly configured.
    EXPECT_TRUE(s1.wasExplicitlyConfigured());
    EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
752 
TEST_F(ServerTest, Fingerprint) {
    DnsTlsServer s1(V4ADDR1);
    DnsTlsServer s2(V4ADDR1);

    // {F1} vs {}: distinct servers, equivalent addresses.
    s1.fingerprints.insert(FINGERPRINT1);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // {F1} vs {F2}.
    s2.fingerprints.insert(FINGERPRINT2);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // {F1} vs {F1, F2}.
    s2.fingerprints.insert(FINGERPRINT1);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // {F1, F2} vs {F1, F2}: now equal.
    s1.fingerprints.insert(FINGERPRINT2);
    EXPECT_EQ(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    // Having fingerprints makes a server explicitly configured.
    EXPECT_TRUE(s1.wasExplicitlyConfigured());
    EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
775 
TEST(QueryMapTest,Basic)776 TEST(QueryMapTest, Basic) {
777     DnsTlsQueryMap map;
778 
779     EXPECT_TRUE(map.empty());
780 
781     bytevec q0 = make_query(999, SIZE);
782     bytevec q1 = make_query(888, SIZE);
783     bytevec q2 = make_query(777, SIZE);
784 
785     auto f0 = map.recordQuery(makeSlice(q0));
786     auto f1 = map.recordQuery(makeSlice(q1));
787     auto f2 = map.recordQuery(makeSlice(q2));
788 
789     // Check return values of recordQuery
790     EXPECT_EQ(0, f0->query.newId);
791     EXPECT_EQ(1, f1->query.newId);
792     EXPECT_EQ(2, f2->query.newId);
793 
794     // Check side effects of recordQuery
795     EXPECT_FALSE(map.empty());
796 
797     auto all = map.getAll();
798     EXPECT_EQ(3U, all.size());
799 
800     EXPECT_EQ(0, all[0].newId);
801     EXPECT_EQ(1, all[1].newId);
802     EXPECT_EQ(2, all[2].newId);
803 
804     EXPECT_EQ(makeSlice(q0), all[0].query);
805     EXPECT_EQ(makeSlice(q1), all[1].query);
806     EXPECT_EQ(makeSlice(q2), all[2].query);
807 
808     bytevec a0 = make_query(0, SIZE);
809     bytevec a1 = make_query(1, SIZE);
810     bytevec a2 = make_query(2, SIZE);
811 
812     // Return responses out of order
813     map.onResponse(a2);
814     map.onResponse(a0);
815     map.onResponse(a1);
816 
817     EXPECT_TRUE(map.empty());
818 
819     auto r0 = f0->result.get();
820     auto r1 = f1->result.get();
821     auto r2 = f2->result.get();
822 
823     EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
824     EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
825     EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
826 
827     const bytevec& d0 = r0.response;
828     const bytevec& d1 = r1.response;
829     const bytevec& d2 = r2.response;
830 
831     // The ID should match the query
832     EXPECT_EQ(999, d0[0] << 8 | d0[1]);
833     EXPECT_EQ(888, d1[0] << 8 | d1[1]);
834     EXPECT_EQ(777, d2[0] << 8 | d2[1]);
835     // The body should match the answer
836     EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
837     EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
838     EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
839 }
840 
TEST(QueryMapTest, FillHole) {
    // Fill every one of the UINT16_MAX + 1 available IDs and free one in the
    // middle by answering it. Then check that the freed ID is reused by the
    // next recorded query.
    constexpr uint16_t kAnsweredId = 40000;

    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
        futures[i] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[i]);  // futures[i] should be nonnull.
        EXPECT_EQ(i, futures[i]->query.newId);
    }

    // The map should now be full.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());

    // Trying to add another query should fail because the map is full.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Send an answer to query kAnsweredId. The result carries QUERY's original ID
    // and the answer's body.
    auto answer = make_query(kAnsweredId, SIZE);
    map.onResponse(answer);
    auto result = futures[kAnsweredId]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // There should now be room in the map, and the freed ID is the one reused.
    EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    auto f = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(f);
    EXPECT_EQ(kAnsweredId, f->query.newId);

    // The map should now be full again.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
875 
876 class StubObserver : public IDnsTlsSocketObserver {
877   public:
878     bool closed = false;
onResponse(std::vector<uint8_t>)879     void onResponse(std::vector<uint8_t>) override {}
880 
onClosed()881     void onClosed() override { closed = true; }
882 };
883 
TEST(DnsTlsSocketTest, SlowDestructor) {
    constexpr char tls_addr[] = "127.0.0.3";
    // High-numbered port so root isn't required. tls_port is passed to the
    // frontend; tls_port_number builds the client-side DnsTlsServer. Both must
    // name the same port.
    constexpr char tls_port[] = "8530";
    constexpr in_port_t tls_port_number = 8530;
    // This test doesn't perform any queries, so the backend address can be invalid.
    constexpr char backend_addr[] = "192.0.2.1";
    constexpr char backend_port[] = "1";

    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
    ASSERT_TRUE(tls.startServer());

    DnsTlsServer server;
    parseServer(tls_addr, tls_port_number, &server.ss);

    StubObserver observer;
    ASSERT_FALSE(observer.closed);
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
    ASSERT_TRUE(socket->initialize());

    // Test: Time the socket destructor.  This should be fast.
    const auto before = std::chrono::steady_clock::now();
    socket.reset();
    const auto after = std::chrono::steady_clock::now();
    const auto delay = after - before;
    // A duration's rep type is implementation-defined (only guaranteed to be a
    // signed integer of at least 64 bits), so convert explicitly to match %lld.
    ALOGV("Shutdown took %lld ns",
          static_cast<long long>(
                  std::chrono::duration_cast<std::chrono::nanoseconds>(delay).count()));
    EXPECT_TRUE(observer.closed);
    // Shutdown should complete in milliseconds, but if the shutdown signal is lost
    // it will wait for the timeout, which is expected to take 20 seconds.
    EXPECT_LT(delay, std::chrono::seconds{5});
}
914 
915 } // end of namespace net
916 } // end of namespace android
917