• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "dns_tls_test"
18 #define LOG_NDEBUG 1  // Set to 0 to enable verbose debug logging
19 
20 #include <gtest/gtest.h>
21 
22 #include "DnsTlsDispatcher.h"
23 #include "DnsTlsQueryMap.h"
24 #include "DnsTlsServer.h"
25 #include "DnsTlsSessionCache.h"
26 #include "DnsTlsSocket.h"
27 #include "DnsTlsTransport.h"
28 #include "IDnsTlsSocket.h"
29 #include "IDnsTlsSocketFactory.h"
30 #include "IDnsTlsSocketObserver.h"
31 
32 #include "dns_responder/dns_tls_frontend.h"
33 
34 #include <chrono>
35 #include <arpa/inet.h>
36 #include <android-base/macros.h>
37 #include <netdutils/Slice.h>
38 
39 #include "log/log.h"
40 
41 namespace android {
42 namespace net {
43 
44 using netdutils::Slice;
45 using netdutils::makeSlice;
46 
47 typedef std::vector<uint8_t> bytevec;
48 
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)49 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
50     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
51     if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
52         // IPv4 parse succeeded, so it's IPv4
53         sin->sin_family = AF_INET;
54         sin->sin_port = htons(port);
55         return;
56     }
57     sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
58     if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
59         // IPv6 parse succeeded, so it's IPv6.
60         sin6->sin6_family = AF_INET6;
61         sin6->sin6_port = htons(port);
62         return;
63     }
64     ALOGE("Failed to parse server address: %s", server);
65 }
66 
67 bytevec FINGERPRINT1 = { 1 };
68 bytevec FINGERPRINT2 = { 2 };
69 
70 std::string SERVERNAME1 = "dns.example.com";
71 std::string SERVERNAME2 = "dns.example.org";
72 
73 // BaseTest just provides constants that are useful for the tests.
74 class BaseTest : public ::testing::Test {
75   protected:
BaseTest()76     BaseTest() {
77         parseServer("192.0.2.1", 853, &V4ADDR1);
78         parseServer("192.0.2.2", 853, &V4ADDR2);
79         parseServer("2001:db8::1", 853, &V6ADDR1);
80         parseServer("2001:db8::2", 853, &V6ADDR2);
81 
82         SERVER1 = DnsTlsServer(V4ADDR1);
83         SERVER1.fingerprints.insert(FINGERPRINT1);
84         SERVER1.name = SERVERNAME1;
85     }
86 
87     sockaddr_storage V4ADDR1;
88     sockaddr_storage V4ADDR2;
89     sockaddr_storage V6ADDR1;
90     sockaddr_storage V6ADDR2;
91 
92     DnsTlsServer SERVER1;
93 };
94 
make_query(uint16_t id,size_t size)95 bytevec make_query(uint16_t id, size_t size) {
96     bytevec vec(size);
97     vec[0] = id >> 8;
98     vec[1] = id;
99     // Arbitrarily fill the query body with unique data.
100     for (size_t i = 2; i < size; ++i) {
101         vec[i] = id + i;
102     }
103     return vec;
104 }
105 
106 // Query constants
107 const unsigned MARK = 123;
108 const uint16_t ID = 52;
109 const uint16_t SIZE = 22;
110 const bytevec QUERY = make_query(ID, SIZE);
111 
112 template <class T>
113 class FakeSocketFactory : public IDnsTlsSocketFactory {
114   public:
FakeSocketFactory()115     FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)116     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
117             const DnsTlsServer& server ATTRIBUTE_UNUSED,
118             unsigned mark ATTRIBUTE_UNUSED,
119             IDnsTlsSocketObserver* observer,
120             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
121         return std::make_unique<T>(observer);
122     }
123 };
124 
make_echo(uint16_t id,const Slice query)125 bytevec make_echo(uint16_t id, const Slice query) {
126     bytevec response(query.size() + 2);
127     response[0] = id >> 8;
128     response[1] = id;
129     // Echo the query as the fake response.
130     memcpy(response.data() + 2, query.base(), query.size());
131     return response;
132 }
133 
134 // Simplest possible fake server.  This just echoes the query as the response.
135 class FakeSocketEcho : public IDnsTlsSocket {
136   public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)137     explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)138     bool query(uint16_t id, const Slice query) override {
139         // Return the response immediately (asynchronously).
140         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
141         return true;
142     }
143 
144   private:
145     IDnsTlsSocketObserver* const mObserver;
146 };
147 
148 class TransportTest : public BaseTest {};
149 
TEST_F(TransportTest,Query)150 TEST_F(TransportTest, Query) {
151     FakeSocketFactory<FakeSocketEcho> factory;
152     DnsTlsTransport transport(SERVER1, MARK, &factory);
153     auto r = transport.query(makeSlice(QUERY)).get();
154 
155     EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
156     EXPECT_EQ(QUERY, r.response);
157 }
158 
159 // Fake Socket that echoes the observed query ID as the response body.
160 class FakeSocketId : public IDnsTlsSocket {
161   public:
FakeSocketId(IDnsTlsSocketObserver * observer)162     explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query ATTRIBUTE_UNUSED)163     bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
164         // Return the response immediately (asynchronously).
165         bytevec response(4);
166         // Echo the ID in the header to match the response to the query.
167         // This will be overwritten by DnsTlsQueryMap.
168         response[0] = id >> 8;
169         response[1] = id;
170         // Echo the ID in the body, so that the test can verify which ID was used by
171         // DnsTlsQueryMap.
172         response[2] = id >> 8;
173         response[3] = id;
174         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
175         return true;
176     }
177 
178   private:
179     IDnsTlsSocketObserver* const mObserver;
180 };
181 
182 // Test that IDs are properly reused
TEST_F(TransportTest,IdReuse)183 TEST_F(TransportTest, IdReuse) {
184     FakeSocketFactory<FakeSocketId> factory;
185     DnsTlsTransport transport(SERVER1, MARK, &factory);
186     for (int i = 0; i < 100; ++i) {
187         // Send a query.
188         std::future<DnsTlsServer::Result> f = transport.query(makeSlice(QUERY));
189         // Wait for the response.
190         DnsTlsServer::Result r = f.get();
191         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
192 
193         // All queries should have an observed ID of zero, because it is returned to the ID pool
194         // after each use.
195         EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
196     }
197 }
198 
199 // These queries might be handled in serial or parallel as they race the
200 // responses.
TEST_F(TransportTest,RacingQueries_10000)201 TEST_F(TransportTest, RacingQueries_10000) {
202     FakeSocketFactory<FakeSocketEcho> factory;
203     DnsTlsTransport transport(SERVER1, MARK, &factory);
204     std::vector<std::future<DnsTlsTransport::Result>> results;
205     // Fewer than 65536 queries to avoid ID exhaustion.
206     const int num_queries = 10000;
207     results.reserve(num_queries);
208     for (int i = 0; i < num_queries; ++i) {
209         results.push_back(transport.query(makeSlice(QUERY)));
210     }
211     for (auto& result : results) {
212         auto r = result.get();
213         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
214         EXPECT_EQ(QUERY, r.response);
215     }
216 }
217 
218 // A server that waits until sDelay queries are queued before responding.
219 class FakeSocketDelay : public IDnsTlsSocket {
220   public:
FakeSocketDelay(IDnsTlsSocketObserver * observer)221     explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
~FakeSocketDelay()222     ~FakeSocketDelay() { std::lock_guard guard(mLock); }
223     static size_t sDelay;
224     static bool sReverse;
225 
query(uint16_t id,const Slice query)226     bool query(uint16_t id, const Slice query) override {
227         ALOGV("FakeSocketDelay got query with ID %d", int(id));
228         std::lock_guard guard(mLock);
229         // Check for duplicate IDs.
230         EXPECT_EQ(0U, mIds.count(id));
231         mIds.insert(id);
232 
233         // Store response.
234         mResponses.push_back(make_echo(id, query));
235 
236         ALOGV("Up to %zu out of %zu queries", mResponses.size(), sDelay);
237         if (mResponses.size() == sDelay) {
238             std::thread(&FakeSocketDelay::sendResponses, this).detach();
239         }
240         return true;
241     }
242 
243   private:
sendResponses()244     void sendResponses() {
245         std::lock_guard guard(mLock);
246         if (sReverse) {
247             std::reverse(std::begin(mResponses), std::end(mResponses));
248         }
249         for (auto& response : mResponses) {
250             mObserver->onResponse(response);
251         }
252         mIds.clear();
253         mResponses.clear();
254     }
255 
256     std::mutex mLock;
257     IDnsTlsSocketObserver* const mObserver;
258     std::set<uint16_t> mIds GUARDED_BY(mLock);
259     std::vector<bytevec> mResponses GUARDED_BY(mLock);
260 };
261 
262 size_t FakeSocketDelay::sDelay;
263 bool FakeSocketDelay::sReverse;
264 
TEST_F(TransportTest,ParallelColliding)265 TEST_F(TransportTest, ParallelColliding) {
266     FakeSocketDelay::sDelay = 10;
267     FakeSocketDelay::sReverse = false;
268     FakeSocketFactory<FakeSocketDelay> factory;
269     DnsTlsTransport transport(SERVER1, MARK, &factory);
270     std::vector<std::future<DnsTlsTransport::Result>> results;
271     // Fewer than 65536 queries to avoid ID exhaustion.
272     results.reserve(FakeSocketDelay::sDelay);
273     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
274         results.push_back(transport.query(makeSlice(QUERY)));
275     }
276     for (auto& result : results) {
277         auto r = result.get();
278         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
279         EXPECT_EQ(QUERY, r.response);
280     }
281 }
282 
TEST_F(TransportTest,ParallelColliding_Max)283 TEST_F(TransportTest, ParallelColliding_Max) {
284     FakeSocketDelay::sDelay = 65536;
285     FakeSocketDelay::sReverse = false;
286     FakeSocketFactory<FakeSocketDelay> factory;
287     DnsTlsTransport transport(SERVER1, MARK, &factory);
288     std::vector<std::future<DnsTlsTransport::Result>> results;
289     // Exactly 65536 queries should still be possible in parallel,
290     // even if they all have the same original ID.
291     results.reserve(FakeSocketDelay::sDelay);
292     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
293         results.push_back(transport.query(makeSlice(QUERY)));
294     }
295     for (auto& result : results) {
296         auto r = result.get();
297         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
298         EXPECT_EQ(QUERY, r.response);
299     }
300 }
301 
TEST_F(TransportTest,ParallelUnique)302 TEST_F(TransportTest, ParallelUnique) {
303     FakeSocketDelay::sDelay = 10;
304     FakeSocketDelay::sReverse = false;
305     FakeSocketFactory<FakeSocketDelay> factory;
306     DnsTlsTransport transport(SERVER1, MARK, &factory);
307     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
308     std::vector<std::future<DnsTlsTransport::Result>> results;
309     results.reserve(FakeSocketDelay::sDelay);
310     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
311         queries[i] = make_query(i, SIZE);
312         results.push_back(transport.query(makeSlice(queries[i])));
313     }
314     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
315         auto r = results[i].get();
316         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
317         EXPECT_EQ(queries[i], r.response);
318     }
319 }
320 
TEST_F(TransportTest,ParallelUnique_Max)321 TEST_F(TransportTest, ParallelUnique_Max) {
322     FakeSocketDelay::sDelay = 65536;
323     FakeSocketDelay::sReverse = false;
324     FakeSocketFactory<FakeSocketDelay> factory;
325     DnsTlsTransport transport(SERVER1, MARK, &factory);
326     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
327     std::vector<std::future<DnsTlsTransport::Result>> results;
328     // Exactly 65536 queries should still be possible in parallel,
329     // and they should all be mapped correctly back to the original ID.
330     results.reserve(FakeSocketDelay::sDelay);
331     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
332         queries[i] = make_query(i, SIZE);
333         results.push_back(transport.query(makeSlice(queries[i])));
334     }
335     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
336         auto r = results[i].get();
337         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
338         EXPECT_EQ(queries[i], r.response);
339     }
340 }
341 
TEST_F(TransportTest,IdExhaustion)342 TEST_F(TransportTest, IdExhaustion) {
343     const int num_queries = 65536;
344     // A delay of 65537 is unreachable, because the maximum number
345     // of outstanding queries is 65536.
346     FakeSocketDelay::sDelay = num_queries + 1;
347     FakeSocketDelay::sReverse = false;
348     FakeSocketFactory<FakeSocketDelay> factory;
349     DnsTlsTransport transport(SERVER1, MARK, &factory);
350     std::vector<std::future<DnsTlsTransport::Result>> results;
351     // Issue the maximum number of queries.
352     results.reserve(num_queries);
353     for (int i = 0; i < num_queries; ++i) {
354         results.push_back(transport.query(makeSlice(QUERY)));
355     }
356 
357     // The ID space is now full, so subsequent queries should fail immediately.
358     auto r = transport.query(makeSlice(QUERY)).get();
359     EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
360     EXPECT_TRUE(r.response.empty());
361 
362     for (auto& result : results) {
363         // All other queries should remain outstanding.
364         EXPECT_EQ(std::future_status::timeout,
365                 result.wait_for(std::chrono::duration<int>::zero()));
366     }
367 }
368 
369 // Responses can come back from the server in any order.  This should have no
370 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)371 TEST_F(TransportTest, ReverseOrder) {
372     FakeSocketDelay::sDelay = 10;
373     FakeSocketDelay::sReverse = true;
374     FakeSocketFactory<FakeSocketDelay> factory;
375     DnsTlsTransport transport(SERVER1, MARK, &factory);
376     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
377     std::vector<std::future<DnsTlsTransport::Result>> results;
378     results.reserve(FakeSocketDelay::sDelay);
379     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
380         queries[i] = make_query(i, SIZE);
381         results.push_back(transport.query(makeSlice(queries[i])));
382     }
383     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
384         auto r = results[i].get();
385         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
386         EXPECT_EQ(queries[i], r.response);
387     }
388 }
389 
TEST_F(TransportTest,ReverseOrder_Max)390 TEST_F(TransportTest, ReverseOrder_Max) {
391     FakeSocketDelay::sDelay = 65536;
392     FakeSocketDelay::sReverse = true;
393     FakeSocketFactory<FakeSocketDelay> factory;
394     DnsTlsTransport transport(SERVER1, MARK, &factory);
395     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
396     std::vector<std::future<DnsTlsTransport::Result>> results;
397     results.reserve(FakeSocketDelay::sDelay);
398     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
399         queries[i] = make_query(i, SIZE);
400         results.push_back(transport.query(makeSlice(queries[i])));
401     }
402     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
403         auto r = results[i].get();
404         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
405         EXPECT_EQ(queries[i], r.response);
406     }
407 }
408 
409 // Returning null from the factory indicates a connection failure.
410 class NullSocketFactory : public IDnsTlsSocketFactory {
411   public:
NullSocketFactory()412     NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)413     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
414             const DnsTlsServer& server ATTRIBUTE_UNUSED,
415             unsigned mark ATTRIBUTE_UNUSED,
416             IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
417             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
418         return nullptr;
419     }
420 };
421 
TEST_F(TransportTest,ConnectFail)422 TEST_F(TransportTest, ConnectFail) {
423     NullSocketFactory factory;
424     DnsTlsTransport transport(SERVER1, MARK, &factory);
425     auto r = transport.query(makeSlice(QUERY)).get();
426 
427     EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
428     EXPECT_TRUE(r.response.empty());
429 }
430 
431 // Simulate a socket that connects but then immediately receives a server
432 // close notification.
433 class FakeSocketClose : public IDnsTlsSocket {
434   public:
FakeSocketClose(IDnsTlsSocketObserver * observer)435     explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
436         : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()437     ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)438     bool query(uint16_t id ATTRIBUTE_UNUSED,
439                const Slice query ATTRIBUTE_UNUSED) override {
440         return true;
441     }
442 
443   private:
444     std::thread mCloser;
445 };
446 
TEST_F(TransportTest,CloseRetryFail)447 TEST_F(TransportTest, CloseRetryFail) {
448     FakeSocketFactory<FakeSocketClose> factory;
449     DnsTlsTransport transport(SERVER1, MARK, &factory);
450     auto r = transport.query(makeSlice(QUERY)).get();
451 
452     EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
453     EXPECT_TRUE(r.response.empty());
454 }
455 
456 // Simulate a server that occasionally closes the connection and silently
457 // drops some queries.
458 class FakeSocketLimited : public IDnsTlsSocket {
459   public:
460     static int sLimit;  // Number of queries to answer per socket.
461     static size_t sMaxSize;  // Silently discard queries greater than this size.
FakeSocketLimited(IDnsTlsSocketObserver * observer)462     explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
463         : mObserver(observer), mQueries(0) {}
~FakeSocketLimited()464     ~FakeSocketLimited() {
465         {
466             ALOGV("~FakeSocketLimited acquiring mLock");
467             std::lock_guard guard(mLock);
468             ALOGV("~FakeSocketLimited acquired mLock");
469             for (auto& thread : mThreads) {
470                 ALOGV("~FakeSocketLimited joining response thread");
471                 thread.join();
472                 ALOGV("~FakeSocketLimited joined response thread");
473             }
474             mThreads.clear();
475         }
476 
477         if (mCloser) {
478             ALOGV("~FakeSocketLimited joining closer thread");
479             mCloser->join();
480             ALOGV("~FakeSocketLimited joined closer thread");
481         }
482     }
query(uint16_t id,const Slice query)483     bool query(uint16_t id, const Slice query) override {
484         ALOGV("FakeSocketLimited::query acquiring mLock");
485         std::lock_guard guard(mLock);
486         ALOGV("FakeSocketLimited::query acquired mLock");
487         ++mQueries;
488 
489         if (mQueries <= sLimit) {
490             ALOGV("size %zu vs. limit of %zu", query.size(), sMaxSize);
491             if (query.size() <= sMaxSize) {
492                 // Return the response immediately (asynchronously).
493                 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
494             }
495         }
496         if (mQueries == sLimit) {
497             mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
498         }
499         return mQueries <= sLimit;
500     }
501 
502   private:
sendClose()503     void sendClose() {
504         {
505             ALOGV("FakeSocketLimited::sendClose acquiring mLock");
506             std::lock_guard guard(mLock);
507             ALOGV("FakeSocketLimited::sendClose acquired mLock");
508             for (auto& thread : mThreads) {
509                 ALOGV("FakeSocketLimited::sendClose joining response thread");
510                 thread.join();
511                 ALOGV("FakeSocketLimited::sendClose joined response thread");
512             }
513             mThreads.clear();
514         }
515         mObserver->onClosed();
516     }
517     std::mutex mLock;
518     IDnsTlsSocketObserver* const mObserver;
519     int mQueries GUARDED_BY(mLock);
520     std::vector<std::thread> mThreads GUARDED_BY(mLock);
521     std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
522 };
523 
524 int FakeSocketLimited::sLimit;
525 size_t FakeSocketLimited::sMaxSize;
526 
// Every query is silently dropped, and each socket closes after sLimit queries.
// The transport keeps retrying until every query hits the retry limit and
// fails with network_error.
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;   // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(FakeSocketLimited::sLimit);
    for (int sent = 0; sent < FakeSocketLimited::sLimit; ++sent) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& result : results) {
        const auto r = result.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
        EXPECT_TRUE(r.response.empty());
    }
}
547 
// Alternates "short" queries, which the server answers, with "long" ones, which
// it silently drops. The sockets also close periodically. The short queries must
// still all succeed.
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;          // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // 100 queries: even indices are SIZE bytes long, odd indices are SIZE + 1.
    const int num_queries = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> queries(num_queries);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(num_queries);
    for (int idx = 0; idx < num_queries; ++idx) {
        const bool isLong = (idx % 2) != 0;
        queries[idx] = make_query(idx, SIZE + (isLong ? 1 : 0));
        results.push_back(transport.query(makeSlice(queries[idx])));
    }
    // Only the short (even-index) queries are checked.
    for (int idx = 0; idx < num_queries; idx += 2) {
        const auto r = results[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[idx], r.response);
    }
}
571 
572 // Simulate a malfunctioning server that injects extra miscellaneous
573 // responses to queries that were not asked.  This will cause wrong answers but
574 // must not crash the Transport.
575 class FakeSocketGarbage : public IDnsTlsSocket {
576   public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)577     explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
578         // Inject a garbage event.
579         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
580     }
~FakeSocketGarbage()581     ~FakeSocketGarbage() {
582         std::lock_guard guard(mLock);
583         for (auto& thread : mThreads) {
584             thread.join();
585         }
586     }
query(uint16_t id,const Slice query)587     bool query(uint16_t id, const Slice query) override {
588         std::lock_guard guard(mLock);
589         // Return the response twice.
590         auto echo = make_echo(id, query);
591         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
592         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
593         // Also return some other garbage
594         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
595         return true;
596     }
597 
598   private:
599     std::mutex mLock;
600     std::vector<std::thread> mThreads GUARDED_BY(mLock);
601     IDnsTlsSocketObserver* const mObserver;
602 };
603 
// Despite duplicate and unsolicited replies, every query still completes with
// success.
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    for (int attempt = 0; attempt < 10; ++attempt) {
        const auto r = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        // The response body is not checked because this server is malfunctioning.
    }
}
614 
615 // Dispatcher tests
616 class DispatcherTest : public BaseTest {};
617 
TEST_F(DispatcherTest,Query)618 TEST_F(DispatcherTest, Query) {
619     bytevec ans(4096);
620     int resplen = 0;
621 
622     auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
623     DnsTlsDispatcher dispatcher(std::move(factory));
624     auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
625                               makeSlice(ans), &resplen);
626 
627     EXPECT_EQ(DnsTlsTransport::Response::success, r);
628     EXPECT_EQ(int(QUERY.size()), resplen);
629     ans.resize(resplen);
630     EXPECT_EQ(QUERY, ans);
631 }
632 
// An answer buffer one byte shorter than the SIZE-byte echoed reply must yield
// limit_error.
TEST_F(DispatcherTest, AnswerTooLarge) {
    bytevec tooSmall(SIZE - 1);  // Too small to hold the answer
    int resplen = 0;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(tooSmall),
                                       &resplen);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, code);
}
644 
645 template<class T>
646 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
647   public:
TrackingFakeSocketFactory()648     TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)649     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
650             const DnsTlsServer& server,
651             unsigned mark,
652             IDnsTlsSocketObserver* observer,
653             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
654         std::lock_guard guard(mLock);
655         keys.emplace(mark, server);
656         return std::make_unique<T>(observer);
657     }
658     std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
659 
660   private:
661     std::mutex mLock;
662 };
663 
// Concurrent queries across four (mark, server) keys all succeed, and the
// dispatcher creates exactly one socket per key.
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Two servers x two socket marks = four combinations.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // sDelay queries per key, each issued from its own thread; all must succeed.
    const size_t totalQueries = FakeSocketDelay::sDelay * keys.size();
    std::vector<std::thread> workers;
    for (size_t i = 0; i < totalQueries; ++i) {
        const auto key = keys[i % keys.size()];
        workers.emplace_back([key, i](DnsTlsDispatcher* d) {
            const bytevec q = make_query(i, SIZE);
            bytevec ans(4096);
            int resplen = 0;
            const unsigned mark = key.first;
            const DnsTlsServer& server = key.second;
            const auto r = d->query(server, mark, makeSlice(q), makeSlice(ans), &resplen);
            EXPECT_EQ(DnsTlsTransport::Response::success, r);
            EXPECT_EQ(int(q.size()), resplen);
            ans.resize(resplen);
            EXPECT_EQ(q, ans);
        }, &dispatcher);
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), weak_factory->keys.size());
    for (const auto& key : keys) {
        EXPECT_EQ(1U, weak_factory->keys.count(key));
    }
}
706 
707 // Check DnsTlsServer's comparison logic.
708 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)709 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
710     bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
711     bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
712     EXPECT_FALSE(cmp1 && cmp2);
713     return !cmp1 && !cmp2;
714 }
715 
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)716 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
717     EXPECT_TRUE(s1 == s1);
718     EXPECT_TRUE(s2 == s2);
719     EXPECT_TRUE(isAddressEqual(s1, s1));
720     EXPECT_TRUE(isAddressEqual(s2, s2));
721 
722     EXPECT_TRUE(s1 < s2 ^ s2 < s1);
723     EXPECT_FALSE(s1 == s2);
724     EXPECT_FALSE(s2 == s1);
725 }
726 
727 class ServerTest : public BaseTest {};
728 
TEST_F(ServerTest,IPv4)729 TEST_F(ServerTest, IPv4) {
730     checkUnequal(V4ADDR1, V4ADDR2);
731     EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
732 }
733 
TEST_F(ServerTest,IPv6)734 TEST_F(ServerTest, IPv6) {
735     checkUnequal(V6ADDR1, V6ADDR2);
736     EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
737 }
738 
TEST_F(ServerTest,MixedAddressFamily)739 TEST_F(ServerTest, MixedAddressFamily) {
740     checkUnequal(V6ADDR1, V4ADDR1);
741     EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
742 }
743 
TEST_F(ServerTest, IPv6ScopeId) {
    DnsTlsServer first(V6ADDR1);
    DnsTlsServer second(V6ADDR1);

    // Same IPv6 address, different scope IDs: expected to count as different
    // servers and different addresses.
    reinterpret_cast<sockaddr_in6*>(&first.ss)->sin6_scope_id = 1;
    reinterpret_cast<sockaddr_in6*>(&second.ss)->sin6_scope_id = 2;

    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));

    // Only the socket address was populated (no name, no fingerprints); the
    // test expects that not to count as explicit configuration.
    EXPECT_FALSE(first.wasExplicitlyConfigured());
    EXPECT_FALSE(second.wasExplicitlyConfigured());
}
756 
TEST_F(ServerTest, IPv6FlowInfo) {
    DnsTlsServer first(V6ADDR1);
    DnsTlsServer second(V6ADDR1);

    // Differ only in the IPv6 flow label.
    reinterpret_cast<sockaddr_in6*>(&first.ss)->sin6_flowinfo = 1;
    reinterpret_cast<sockaddr_in6*>(&second.ss)->sin6_flowinfo = 2;

    // Both operator== and the address comparator are expected to ignore flowinfo.
    EXPECT_EQ(first, second);
    EXPECT_TRUE(isAddressEqual(first, second));

    // Only the socket address was populated, so neither server is expected to
    // report explicit configuration.
    EXPECT_FALSE(first.wasExplicitlyConfigured());
    EXPECT_FALSE(second.wasExplicitlyConfigured());
}
770 
// Servers that differ only in port are different servers, but the address-only
// comparator is expected to treat them as the same address. Checked for both
// IPv4 and IPv6.
TEST_F(ServerTest, Port) {
    DnsTlsServer s1, s2;
    parseServer("192.0.2.1", 853, &s1.ss);
    parseServer("192.0.2.1", 854, &s2.ss);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));

    DnsTlsServer s3, s4;
    parseServer("2001:db8::1", 853, &s3.ss);
    parseServer("2001:db8::1", 852, &s4.ss);
    checkUnequal(s3, s4);
    EXPECT_TRUE(isAddressEqual(s3, s4));

    // None of the servers has a name or fingerprints, so none should report
    // explicit configuration. Previously only the IPv4 pair was checked.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
    EXPECT_FALSE(s3.wasExplicitlyConfigured());
    EXPECT_FALSE(s4.wasExplicitlyConfigured());
}
787 
TEST_F(ServerTest, Name) {
    DnsTlsServer first(V4ADDR1);
    DnsTlsServer second(V4ADDR1);

    // Naming just one of two otherwise-identical servers makes them unequal.
    first.name = SERVERNAME1;
    checkUnequal(first, second);

    // Two different names keep them unequal, while the address-only comparator
    // still sees the same address.
    second.name = SERVERNAME2;
    checkUnequal(first, second);
    EXPECT_TRUE(isAddressEqual(first, second));

    // The test expects a server with a name to report explicit configuration.
    EXPECT_TRUE(first.wasExplicitlyConfigured());
    EXPECT_TRUE(second.wasExplicitlyConfigured());
}
799 
TEST_F(ServerTest, Fingerprint) {
    DnsTlsServer lhs(V4ADDR1);
    DnsTlsServer rhs(V4ADDR1);

    // While the fingerprint sets differ, the servers must be unequal, but the
    // address-only comparator must still see them as the same address.
    auto expectUnequalSameAddress = [&]() {
        checkUnequal(lhs, rhs);
        EXPECT_TRUE(isAddressEqual(lhs, rhs));
    };

    lhs.fingerprints.insert(FINGERPRINT1);  // lhs = {1}, rhs = {}
    expectUnequalSameAddress();

    rhs.fingerprints.insert(FINGERPRINT2);  // lhs = {1}, rhs = {2}
    expectUnequalSameAddress();

    rhs.fingerprints.insert(FINGERPRINT1);  // lhs = {1}, rhs = {1, 2}
    expectUnequalSameAddress();

    // Once the sets match, the servers compare equal.
    lhs.fingerprints.insert(FINGERPRINT2);  // lhs = {1, 2}, rhs = {1, 2}
    EXPECT_EQ(lhs, rhs);
    EXPECT_TRUE(isAddressEqual(lhs, rhs));

    // The test expects servers with fingerprints to report explicit configuration.
    EXPECT_TRUE(lhs.wasExplicitlyConfigured());
    EXPECT_TRUE(rhs.wasExplicitlyConfigured());
}
822 
TEST(QueryMapTest,Basic)823 TEST(QueryMapTest, Basic) {
824     DnsTlsQueryMap map;
825 
826     EXPECT_TRUE(map.empty());
827 
828     bytevec q0 = make_query(999, SIZE);
829     bytevec q1 = make_query(888, SIZE);
830     bytevec q2 = make_query(777, SIZE);
831 
832     auto f0 = map.recordQuery(makeSlice(q0));
833     auto f1 = map.recordQuery(makeSlice(q1));
834     auto f2 = map.recordQuery(makeSlice(q2));
835 
836     // Check return values of recordQuery
837     EXPECT_EQ(0, f0->query.newId);
838     EXPECT_EQ(1, f1->query.newId);
839     EXPECT_EQ(2, f2->query.newId);
840 
841     // Check side effects of recordQuery
842     EXPECT_FALSE(map.empty());
843 
844     auto all = map.getAll();
845     EXPECT_EQ(3U, all.size());
846 
847     EXPECT_EQ(0, all[0].newId);
848     EXPECT_EQ(1, all[1].newId);
849     EXPECT_EQ(2, all[2].newId);
850 
851     EXPECT_EQ(makeSlice(q0), all[0].query);
852     EXPECT_EQ(makeSlice(q1), all[1].query);
853     EXPECT_EQ(makeSlice(q2), all[2].query);
854 
855     bytevec a0 = make_query(0, SIZE);
856     bytevec a1 = make_query(1, SIZE);
857     bytevec a2 = make_query(2, SIZE);
858 
859     // Return responses out of order
860     map.onResponse(a2);
861     map.onResponse(a0);
862     map.onResponse(a1);
863 
864     EXPECT_TRUE(map.empty());
865 
866     auto r0 = f0->result.get();
867     auto r1 = f1->result.get();
868     auto r2 = f2->result.get();
869 
870     EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
871     EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
872     EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
873 
874     const bytevec& d0 = r0.response;
875     const bytevec& d1 = r1.response;
876     const bytevec& d2 = r2.response;
877 
878     // The ID should match the query
879     EXPECT_EQ(999, d0[0] << 8 | d0[1]);
880     EXPECT_EQ(888, d1[0] << 8 | d1[1]);
881     EXPECT_EQ(777, d2[0] << 8 | d2[1]);
882     // The body should match the answer
883     EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
884     EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
885     EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
886 }
887 
// Fills all 2^16 query IDs, checks that the map then refuses new queries,
// answers one query in the middle of the ID space, and checks that the freed
// ID is the one handed out to the next query.
TEST(QueryMapTest, FillHole) {
    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
        futures[i] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[i]);  // futures[i] should be nonnull.
        EXPECT_EQ(i, futures[i]->query.newId);
    }

    // The map should now be full.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());

    // Trying to add another query should fail because the map is full.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Answer a single query in the middle of the ID space to open a hole.
    constexpr uint16_t kHoleId = 40000;
    auto answer = make_query(kHoleId, SIZE);
    map.onResponse(answer);
    auto result = futures[kHoleId]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    // The response ID should be the original query's ID (presumably ID is the
    // ID embedded in QUERY), followed by the unchanged answer body.
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // There should now be room in the map, and the freed ID should be reused.
    EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    auto f = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(f);
    EXPECT_EQ(kHoleId, f->query.newId);

    // The map should now be full again.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
922 
923 class StubObserver : public IDnsTlsSocketObserver {
924   public:
925     bool closed = false;
onResponse(std::vector<uint8_t>)926     void onResponse(std::vector<uint8_t>) override {}
927 
onClosed()928     void onClosed() override { closed = true; }
929 };
930 
// Verifies that destroying an initialized DnsTlsSocket is fast and notifies
// its observer via onClosed().
TEST(DnsTlsSocketTest, SlowDestructor) {
    constexpr char tls_addr[] = "127.0.0.3";
    // High-numbered port so root isn't required. The string form is used by the
    // frontend and the numeric form by parseServer(); the two must stay in sync.
    constexpr char tls_port[] = "8530";
    constexpr in_port_t tls_port_num = 8530;
    // This test doesn't perform any queries, so the backend address can be invalid.
    constexpr char backend_addr[] = "192.0.2.1";
    constexpr char backend_port[] = "1";

    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
    ASSERT_TRUE(tls.startServer());

    DnsTlsServer server;
    parseServer(tls_addr, tls_port_num, &server.ss);

    StubObserver observer;
    ASSERT_FALSE(observer.closed);
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
    ASSERT_TRUE(socket->initialize());

    // Test: Time the socket destructor.  This should be fast.
    auto before = std::chrono::steady_clock::now();
    socket.reset();
    auto after = std::chrono::steady_clock::now();
    auto delay = after - before;
    // The tick count's type is the clock's rep (e.g. `long` on LP64 libstdc++),
    // so cast explicitly to match the %lld conversion.
    ALOGV("Shutdown took %lld ns", static_cast<long long>(delay / std::chrono::nanoseconds{1}));
    EXPECT_TRUE(observer.closed);
    // Shutdown should complete in milliseconds, but if the shutdown signal is lost
    // it will wait for the timeout, which is expected to take 20 seconds.
    EXPECT_LT(delay, std::chrono::seconds{5});
}
961 
962 } // end of namespace net
963 } // end of namespace android
964