• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #include <arpa/inet.h>
19 #include <errno.h>
20 #include <netdb.h>
21 #include <stdarg.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <unistd.h>
25 
26 #include <cutils/sockets.h>
27 #include <android-base/stringprintf.h>
28 #include <private/android_filesystem_config.h>
29 
30 #include <openssl/base64.h>
31 
32 #include <algorithm>
33 #include <chrono>
34 #include <iterator>
35 #include <numeric>
36 #include <thread>
37 
38 #define LOG_TAG "netd_test"
39 // TODO: make this dynamic and stop depending on implementation details.
40 #define TEST_NETID 30
41 
42 #include "resolv_netid.h"
43 #include "NetdClient.h"
44 
45 #include <gtest/gtest.h>
46 
47 #include <utils/Log.h>
48 
49 #include "dns_responder.h"
50 #include "dns_responder_client.h"
51 #include "dns_tls_frontend.h"
52 #include "resolv_params.h"
53 #include "ResolverStats.h"
54 
55 #include "android/net/INetd.h"
56 #include "android/net/metrics/INetdEventListener.h"
57 #include "binder/IServiceManager.h"
58 #include "netdutils/SocketOption.h"
59 
60 using android::base::StringPrintf;
61 using android::base::StringAppendF;
62 using android::net::ResolverStats;
63 using android::net::metrics::INetdEventListener;
64 using android::netdutils::enableSockopt;
65 
// Returns true when `a` and `b` hold the same elements with the same multiplicities, in any
// order (i.e. one is a permutation of the other). Emulates gmock's UnorderedElementsAreArray,
// which currently cannot be used.
// TODO: Use UnorderedElementsAreArray, which depends on being able to compile libgmock_host.
// Elements of A and B must be equality-comparable with each other (std::is_permutation may
// compare them in either argument order). The worst case is O(n**2), which is acceptable for
// the short server/domain lists compared in these tests.
template <class A, class B>
bool UnorderedCompareArray(const A& a, const B& b) {
    // Equal sizes are required for the three-iterator std::is_permutation to be well-defined.
    if (a.size() != b.size()) return false;
    return std::is_permutation(a.begin(), a.end(), b.begin());
}
87 
// Owns the addrinfo list produced by getaddrinfo(3) and releases it with freeaddrinfo(3) on
// clear(), on re-init(), and on destruction. error() holds the return code of the last init().
// Move-only: a copied object would share ai_ and free it a second time.
class AddrInfo {
  public:
    AddrInfo() = default;

    AddrInfo(const char* node, const char* service, const addrinfo& hints) {
        init(node, service, hints);
    }

    AddrInfo(const char* node, const char* service) {
        init(node, service);
    }

    AddrInfo(const AddrInfo&) = delete;
    AddrInfo& operator=(const AddrInfo&) = delete;

    // Transfers ownership of the list; `other` is left empty with error() == 0.
    AddrInfo(AddrInfo&& other) noexcept : ai_(other.ai_), error_(other.error_) {
        other.ai_ = nullptr;
        other.error_ = 0;
    }

    // Releases the currently owned list, then takes over `other`'s. Detaching `other` first
    // makes self-move harmless without comparing addresses (operator& is overloaded below).
    AddrInfo& operator=(AddrInfo&& other) noexcept {
        addrinfo* const incoming_ai = other.ai_;
        const int incoming_error = other.error_;
        other.ai_ = nullptr;
        other.error_ = 0;
        clear();
        ai_ = incoming_ai;
        error_ = incoming_error;
        return *this;
    }

    ~AddrInfo() { clear(); }

    // Frees any previous result, then calls getaddrinfo(node, service, &hints).
    // Returns (and stores) getaddrinfo's return code.
    int init(const char* node, const char* service, const addrinfo& hints) {
        clear();
        error_ = getaddrinfo(node, service, &hints, &ai_);
        return error_;
    }

    // Same as above, but without hints (getaddrinfo receives nullptr).
    int init(const char* node, const char* service) {
        clear();
        error_ = getaddrinfo(node, service, nullptr, &ai_);
        return error_;
    }

    // Frees the owned list (if any) and resets to the default-constructed state, including
    // error() == 0 even when the last init() failed and no list was held.
    void clear() {
        if (ai_ != nullptr) {
            freeaddrinfo(ai_);
            ai_ = nullptr;
        }
        error_ = 0;
    }

    // Accessors to the first entry of the list; get() is nullptr when nothing is held.
    const addrinfo& operator*() const { return *ai_; }
    const addrinfo* get() const { return ai_; }
    // Kept for compatibility with existing callers: yields the owned list, not `this`.
    const addrinfo* operator&() const { return ai_; }
    int error() const { return error_; }

  private:
    addrinfo* ai_ = nullptr;
    int error_ = 0;
};
131 
132 class ResolverTest : public ::testing::Test, public DnsResponderClient {
133 private:
134     int mOriginalMetricsLevel;
135 
136 protected:
SetUp()137     virtual void SetUp() {
138         // Ensure resolutions go via proxy.
139         DnsResponderClient::SetUp();
140 
141         // If DNS reporting is off: turn it on so we run through everything.
142         auto rv = mNetdSrv->getMetricsReportingLevel(&mOriginalMetricsLevel);
143         ASSERT_TRUE(rv.isOk());
144         if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
145             rv = mNetdSrv->setMetricsReportingLevel(INetdEventListener::REPORTING_LEVEL_FULL);
146             ASSERT_TRUE(rv.isOk());
147         }
148     }
149 
TearDown()150     virtual void TearDown() {
151         if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
152             auto rv = mNetdSrv->setMetricsReportingLevel(mOriginalMetricsLevel);
153             ASSERT_TRUE(rv.isOk());
154         }
155 
156         DnsResponderClient::TearDown();
157     }
158 
GetResolverInfo(std::vector<std::string> * servers,std::vector<std::string> * domains,__res_params * params,std::vector<ResolverStats> * stats)159     bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
160             __res_params* params, std::vector<ResolverStats>* stats) {
161         using android::net::INetd;
162         std::vector<int32_t> params32;
163         std::vector<int32_t> stats32;
164         auto rv = mNetdSrv->getResolverInfo(TEST_NETID, servers, domains, &params32, &stats32);
165         if (!rv.isOk() || params32.size() != INetd::RESOLVER_PARAMS_COUNT) {
166             return false;
167         }
168         *params = __res_params {
169             .sample_validity = static_cast<uint16_t>(
170                     params32[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
171             .success_threshold = static_cast<uint8_t>(
172                     params32[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
173             .min_samples = static_cast<uint8_t>(
174                     params32[INetd::RESOLVER_PARAMS_MIN_SAMPLES]),
175             .max_samples = static_cast<uint8_t>(
176                     params32[INetd::RESOLVER_PARAMS_MAX_SAMPLES])
177         };
178         return ResolverStats::decodeAll(stats32, stats);
179     }
180 
ToString(const hostent * he) const181     std::string ToString(const hostent* he) const {
182         if (he == nullptr) return "<null>";
183         char buffer[INET6_ADDRSTRLEN];
184         if (!inet_ntop(he->h_addrtype, he->h_addr_list[0], buffer, sizeof(buffer))) {
185             return "<invalid>";
186         }
187         return buffer;
188     }
189 
ToString(const addrinfo * ai) const190     std::string ToString(const addrinfo* ai) const {
191         if (!ai)
192             return "<null>";
193         for (const auto* aip = ai ; aip != nullptr ; aip = aip->ai_next) {
194             char host[NI_MAXHOST];
195             int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0,
196                     NI_NUMERICHOST);
197             if (rv != 0)
198                 return gai_strerror(rv);
199             return host;
200         }
201         return "<invalid>";
202     }
203 
GetNumQueries(const test::DNSResponder & dns,const char * name) const204     size_t GetNumQueries(const test::DNSResponder& dns, const char* name) const {
205         auto queries = dns.queries();
206         size_t found = 0;
207         for (const auto& p : queries) {
208             if (p.first == name) {
209                 ++found;
210             }
211         }
212         return found;
213     }
214 
GetNumQueriesForType(const test::DNSResponder & dns,ns_type type,const char * name) const215     size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type,
216             const char* name) const {
217         auto queries = dns.queries();
218         size_t found = 0;
219         for (const auto& p : queries) {
220             if (p.second == type && p.first == name) {
221                 ++found;
222             }
223         }
224         return found;
225     }
226 
RunGetAddrInfoStressTest_Binder(unsigned num_hosts,unsigned num_threads,unsigned num_queries)227     void RunGetAddrInfoStressTest_Binder(unsigned num_hosts, unsigned num_threads,
228             unsigned num_queries) {
229         std::vector<std::string> domains = { "example.com" };
230         std::vector<std::unique_ptr<test::DNSResponder>> dns;
231         std::vector<std::string> servers;
232         std::vector<DnsResponderClient::Mapping> mappings;
233         ASSERT_NO_FATAL_FAILURE(SetupMappings(num_hosts, domains, &mappings));
234         ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS, mappings, &dns, &servers));
235 
236         ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
237 
238         auto t0 = std::chrono::steady_clock::now();
239         std::vector<std::thread> threads(num_threads);
240         for (std::thread& thread : threads) {
241            thread = std::thread([this, &mappings, num_queries]() {
242                 for (unsigned i = 0 ; i < num_queries ; ++i) {
243                     uint32_t ofs = arc4random_uniform(mappings.size());
244                     auto& mapping = mappings[ofs];
245                     addrinfo* result = nullptr;
246                     int rv = getaddrinfo(mapping.host.c_str(), nullptr, nullptr, &result);
247                     EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
248                     if (rv == 0) {
249                         std::string result_str = ToString(result);
250                         EXPECT_TRUE(result_str == mapping.ip4 || result_str == mapping.ip6)
251                             << "result='" << result_str << "', ip4='" << mapping.ip4
252                             << "', ip6='" << mapping.ip6;
253                     }
254                     if (result) {
255                         freeaddrinfo(result);
256                         result = nullptr;
257                     }
258                 }
259             });
260         }
261 
262         for (std::thread& thread : threads) {
263             thread.join();
264         }
265         auto t1 = std::chrono::steady_clock::now();
266         ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
267                 std::chrono::duration<double>(t1 - t0).count());
268         ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
269     }
270 
271     const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
272     // <sample validity in s> <success threshold in percent> <min samples> <max samples>
273     const std::string mDefaultParams = "300 25 8 8";
274     const std::vector<int> mDefaultParams_Binder = { 300, 25, 8, 8 };
275 };
276 
// Resolves short names through the "example.com" search domain against one local responder.
// An unmapped name must fail with HOST_NOT_FOUND. A mapped name must yield exactly its single
// IPv4 address. Each lookup must send exactly one A query.
TEST_F(ResolverTest, GetHostByName) {
    const char* const listen_addr = "127.0.0.3";
    const char* const listen_srv = "53";
    const char* const mapped_fqdn = "hello.example.com.";
    const char* const unmapped_fqdn = "nonexistent.example.com.";

    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(mapped_fqdn, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));

    // Negative lookup.
    dns.clearQueries();
    const hostent* const miss = gethostbyname("nonexistent");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, unmapped_fqdn));
    ASSERT_TRUE(miss == nullptr);
    ASSERT_EQ(HOST_NOT_FOUND, h_errno);

    // Positive lookup.
    dns.clearQueries();
    const hostent* const hit = gethostbyname("hello");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, mapped_fqdn));
    ASSERT_FALSE(hit == nullptr);
    ASSERT_EQ(4, hit->h_length);
    ASSERT_FALSE(hit->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(hit));
    EXPECT_TRUE(hit->h_addr_list[1] == nullptr);

    dns.stopServer();
}
307 
// The RESOLVER_PARAMS_* offsets must number exactly RESOLVER_PARAMS_COUNT and, once sorted,
// must be 0, 1, 2, ... (i.e. unique and dense).
TEST_F(ResolverTest, TestBinderSerialization) {
    using android::net::INetd;
    std::vector<int> offsets = {
        INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY,
        INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
        INetd::RESOLVER_PARAMS_MIN_SAMPLES,
        INetd::RESOLVER_PARAMS_MAX_SAMPLES
    };
    const int count = static_cast<int>(offsets.size());
    EXPECT_EQ(count, INetd::RESOLVER_PARAMS_COUNT);

    std::sort(offsets.begin(), offsets.end());
    int expected = 0;
    for (const int offset : offsets) {
        EXPECT_EQ(offset, expected);
        ++expected;
    }
}
323 
// Configures four generated responders through the binder API, resolves one mapped host with
// gethostbyname(), then checks that getResolverInfo() reports back the configured servers,
// domains and parameters.
TEST_F(ResolverTest, GetHostByName_Binder) {
    using android::net::INetd;

    std::vector<std::string> domains = { "example.com" };
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(4, mappings, &dns, &servers));
    ASSERT_EQ(1U, mappings.size());
    const Mapping& mapping = mappings[0];

    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    const hostent* result = gethostbyname(mapping.host.c_str());
    // Sum the A queries for the mapped name over all responders. The seed must be a size_t:
    // with a plain `0`, std::accumulate keeps an int accumulator and narrows every partial sum.
    const size_t total_queries = std::accumulate(dns.begin(), dns.end(), size_t{0},
            [this, &mapping](size_t total, const auto& d) {
                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
            });

    EXPECT_LE(1U, total_queries);
    ASSERT_FALSE(result == nullptr);
    ASSERT_EQ(4, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ(mapping.ip4, ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    __res_params res_params;
    std::vector<ResolverStats> res_stats;
    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
    EXPECT_EQ(servers.size(), res_servers.size());
    EXPECT_EQ(domains.size(), res_domains.size());
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY],
            res_params.sample_validity);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
            res_params.success_threshold);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
    EXPECT_EQ(servers.size(), res_stats.size());

    EXPECT_TRUE(UnorderedCompareArray(res_servers, servers));
    EXPECT_TRUE(UnorderedCompareArray(res_domains, domains));

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
372 
TEST_F(ResolverTest,GetAddrInfo)373 TEST_F(ResolverTest, GetAddrInfo) {
374     addrinfo* result = nullptr;
375 
376     const char* listen_addr = "127.0.0.4";
377     const char* listen_addr2 = "127.0.0.5";
378     const char* listen_srv = "53";
379     const char* host_name = "howdy.example.com.";
380     test::DNSResponder dns(listen_addr, listen_srv, 250,
381                            ns_rcode::ns_r_servfail, 1.0);
382     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
383     dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
384     ASSERT_TRUE(dns.startServer());
385 
386     test::DNSResponder dns2(listen_addr2, listen_srv, 250,
387                             ns_rcode::ns_r_servfail, 1.0);
388     dns2.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
389     dns2.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
390     ASSERT_TRUE(dns2.startServer());
391 
392 
393     std::vector<std::string> servers = { listen_addr };
394     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
395     dns.clearQueries();
396     dns2.clearQueries();
397 
398     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
399     size_t found = GetNumQueries(dns, host_name);
400     EXPECT_LE(1U, found);
401     // Could be A or AAAA
402     std::string result_str = ToString(result);
403     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
404         << ", result_str='" << result_str << "'";
405     // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
406     if (result) {
407         freeaddrinfo(result);
408         result = nullptr;
409     }
410 
411     // Verify that the name is cached.
412     size_t old_found = found;
413     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
414     found = GetNumQueries(dns, host_name);
415     EXPECT_LE(1U, found);
416     EXPECT_EQ(old_found, found);
417     result_str = ToString(result);
418     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
419         << result_str;
420     if (result) {
421         freeaddrinfo(result);
422         result = nullptr;
423     }
424 
425     // Change the DNS resolver, ensure that queries are still cached.
426     servers = { listen_addr2 };
427     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
428     dns.clearQueries();
429     dns2.clearQueries();
430 
431     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
432     found = GetNumQueries(dns, host_name);
433     size_t found2 = GetNumQueries(dns2, host_name);
434     EXPECT_EQ(0U, found);
435     EXPECT_LE(0U, found2);
436 
437     // Could be A or AAAA
438     result_str = ToString(result);
439     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
440         << ", result_str='" << result_str << "'";
441     if (result) {
442         freeaddrinfo(result);
443         result = nullptr;
444     }
445 
446     dns.stopServer();
447     dns2.stopServer();
448 }
449 
// An AF_INET getaddrinfo() lookup of "hola" must send one query and return the mapped IPv4.
TEST_F(ResolverTest, GetAddrInfoV4) {
    const char* listen_addr = "127.0.0.5";
    const char* listen_srv = "53";
    const char* host_name = "hola.example.com.";

    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));

    addrinfo hints = {};
    hints.ai_family = AF_INET;
    AddrInfo result;  // released on scope exit
    EXPECT_EQ(0, result.init("hola", nullptr, hints));
    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
    EXPECT_EQ("1.2.3.5", ToString(result.get()));
}
474 
// Same as the basic gethostbyname() case, except that the responder is told to fail EDNS
// queries (setFailOnEdns). The lookup must still produce one A query and the mapped address.
TEST_F(ResolverTest, GetHostByNameBrokenEdns) {
    const char* const listen_addr = "127.0.0.3";
    const char* const listen_srv = "53";
    const char* const fqdn = "edns.example.com.";

    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.3");
    dns.setFailOnEdns(true);  // This is the only change from the basic test.
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));

    dns.clearQueries();
    const hostent* const he = gethostbyname("edns");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, fqdn));
    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);
}
497 
// Same as GetAddrInfoV4, except that the responder is told to fail EDNS queries
// (setFailOnEdns). The AF_INET lookup must still send one query and return the mapped address.
TEST_F(ResolverTest, GetAddrInfoBrokenEdns) {
    const char* listen_addr = "127.0.0.5";
    const char* listen_srv = "53";
    const char* host_name = "edns2.example.com.";

    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
    dns.setFailOnEdns(true);  // This is the only change from the basic test.
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));

    addrinfo hints = {};
    hints.ai_family = AF_INET;
    AddrInfo result;  // released on scope exit
    EXPECT_EQ(0, result.init("edns2", nullptr, hints));
    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
    EXPECT_EQ("1.2.3.5", ToString(result.get()));
}
523 
// With three search domains configured, the short name "nihao" must resolve through the
// second domain (the only one the responder has a mapping for), sending one A query for it.
TEST_F(ResolverTest, MultidomainResolution) {
    std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };
    const char* const listen_addr = "127.0.0.6";
    const char* const listen_srv = "53";
    const char* const fqdn = "nihao.example2.com.";

    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    ASSERT_TRUE(SetResolversForNetwork(servers, searchDomains, mDefaultParams_Binder));

    dns.clearQueries();
    const hostent* const he = gethostbyname("nihao");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, fqdn));
    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    dns.stopServer();
}
546 
// dns0 (constructed with 0.0 as the last argument) is the server set to fail; dns1 answers.
// After enough failed samples have been recorded against dns0, the resolver is expected to stop
// using it, so the final lookup must reach only dns1.
TEST_F(ResolverTest, GetAddrInfoV6_failing) {
    const char* listen_addr0 = "127.0.0.7";
    const char* listen_addr1 = "127.0.0.8";
    const char* listen_srv = "53";
    const char* host_name = "ohayou.example.com.";

    test::DNSResponder dns0(listen_addr0, listen_srv, 250, ns_rcode::ns_r_servfail, 0.0);
    test::DNSResponder dns1(listen_addr1, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
    ASSERT_TRUE(dns0.startServer());
    ASSERT_TRUE(dns1.startServer());

    std::vector<std::string> servers = { listen_addr0, listen_addr1 };
    // <sample validity in s> <success threshold in percent> <min samples> <max samples>
    const int sample_count = 8;
    const std::vector<int> params = { 300, 25, sample_count, sample_count };
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, params));

    addrinfo hints = {};
    hints.ai_family = AF_INET6;
    AddrInfo result;  // each init() frees the previous result

    // Repeatedly resolve non-existing names until MAXNSSAMPLES resolutions have reached dns0.
    // No more requests should then arrive at that server for the next sample_lifetime seconds.
    // The lookup results are deliberately ignored; only the samples they generate matter.
    // TODO: This approach is implementation-dependent, change once metrics reporting is available.
    for (int i = 0; i < sample_count; ++i) {
        result.init(StringPrintf("nonexistent%d", i).c_str(), nullptr, hints);
    }

    // Due to 100% errors for all possible samples, the server should be ignored from now on and
    // only the second one used for all following queries, until NSSAMPLE_VALIDITY is reached.
    dns0.clearQueries();
    dns1.clearQueries();
    EXPECT_EQ(0, result.init("ohayou", nullptr, hints));
    EXPECT_EQ(0U, GetNumQueries(dns0, host_name));
    EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
}
595 
// Ten threads race each other. After a random delay each one reconfigures the network with a
// random non-empty subset of three responders and then resolves "konbanha" as IPv6. Every
// lookup must succeed whichever configuration it observes.
TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
    const char* listen_addr0 = "127.0.0.9";
    const char* listen_addr1 = "127.0.0.10";
    const char* listen_addr2 = "127.0.0.11";
    const char* listen_srv = "53";
    const char* host_name = "konbanha.example.com.";

    test::DNSResponder dns0(listen_addr0, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns1(listen_addr1, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns2(listen_addr2, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
    dns2.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::7");
    ASSERT_TRUE(dns0.startServer());
    ASSERT_TRUE(dns1.startServer());
    ASSERT_TRUE(dns2.startServer());
    const std::vector<std::string> servers = { listen_addr0, listen_addr1, listen_addr2 };

    // The worker captures `this` and `servers` by reference; all threads are joined below.
    const auto worker = [this, &servers]() {
        // Random start delay (below 1s) so reconfigurations and lookups interleave.
        usleep(arc4random_uniform(1 * 1000 * 1000));

        // Keep each server with probability 1/2; fall back to all of them if none was kept.
        std::vector<std::string> chosen;
        for (const std::string& server : servers) {
            if (arc4random_uniform(2) != 0) chosen.push_back(server);
        }
        if (chosen.empty()) chosen = servers;
        ASSERT_TRUE(SetResolversForNetwork(chosen, mDefaultSearchDomains,
                mDefaultParams_Binder));

        addrinfo hints = {};
        hints.ai_family = AF_INET6;
        addrinfo* result = nullptr;
        const int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
        EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
        if (result != nullptr) freeaddrinfo(result);
    };

    constexpr int kNumThreads = 10;
    std::vector<std::thread> threads;
    threads.reserve(kNumThreads);
    for (int i = 0; i < kNumThreads; ++i) threads.emplace_back(worker);
    for (std::thread& t : threads) t.join();
}
645 
// Stress run: 100 mapped hosts, 100 threads, 100 lookups per thread.
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
    constexpr unsigned kNumHosts = 100;
    constexpr unsigned kNumThreads = 100;
    constexpr unsigned kQueriesPerThread = 100;
    ASSERT_NO_FATAL_FAILURE(
            RunGetAddrInfoStressTest_Binder(kNumHosts, kNumThreads, kQueriesPerThread));
}
652 
// Stress run with a large mapping table: 100000 hosts, 100 threads, 100 lookups per thread.
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100000) {
    constexpr unsigned kNumHosts = 100000;
    constexpr unsigned kNumThreads = 100;
    constexpr unsigned kQueriesPerThread = 100;
    ASSERT_NO_FATAL_FAILURE(
            RunGetAddrInfoStressTest_Binder(kNumHosts, kNumThreads, kQueriesPerThread));
}
659 
// Configuring a network with no servers and no search domains must succeed.
// getResolverInfo() must then report both lists as empty while echoing the parameters.
TEST_F(ResolverTest, EmptySetup) {
    using android::net::INetd;
    std::vector<std::string> servers;
    std::vector<std::string> domains;
    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<ResolverStats> res_stats;
    __res_params res_params;
    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
    EXPECT_TRUE(res_servers.empty());
    EXPECT_TRUE(res_domains.empty());

    const std::vector<int>& expected = mDefaultParams_Binder;
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, expected.size());
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY], res_params.sample_validity);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD], res_params.success_threshold);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
}
680 
// Checks that the search domain list is honoured and that replacing it on its own (with the
// same servers) takes effect. "test13" must first resolve via domain1.org, then via domain2.org.
TEST_F(ResolverTest, SearchPathChange) {
    const char* listen_addr = "127.0.0.13";
    const char* listen_srv = "53";
    const char* host_name1 = "test13.domain1.org.";
    const char* host_name2 = "test13.domain2.org.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_aaaa, "2001:db8::13");
    dns.addMapping(host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    std::vector<std::string> domains = { "domain1.org" };
    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET6;

    // AddrInfo owns the result list. init() frees the previous list and resets the pointer
    // before the next lookup. A raw pointer that is freed but not reset could survive a failing
    // second getaddrinfo() (POSIX does not specify what is stored on failure). It would then be
    // printed and freed a second time.
    AddrInfo result;
    EXPECT_EQ(0, result.init("test13", nullptr, hints));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
    EXPECT_EQ("2001:db8::13", ToString(result.get()));

    // Test that changing the domain search path on its own works.
    domains = { "domain2.org" };
    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
    dns.clearQueries();

    EXPECT_EQ(0, result.init("test13", nullptr, hints));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
    EXPECT_EQ("2001:db8::1:13", ToString(result.get()));
}
717 
// Configures MAXNS + 1 servers and checks that getResolverInfo() reports only MAXNS of them.
TEST_F(ResolverTest, MaxServerPrune_Binder) {
    std::vector<std::string> domains = { "example.com" };
    std::vector<Mapping> mappings;
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));

    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    std::vector<std::string> configured_servers;
    std::vector<std::string> configured_domains;
    __res_params configured_params;
    std::vector<ResolverStats> configured_stats;
    ASSERT_TRUE(GetResolverInfo(&configured_servers, &configured_domains, &configured_params,
            &configured_stats));
    EXPECT_EQ(static_cast<size_t>(MAXNS), configured_servers.size());

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
739 
base64Encode(const std::vector<uint8_t> & input)740 static std::string base64Encode(const std::vector<uint8_t>& input) {
741     size_t out_len;
742     EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));
743     // out_len includes the trailing NULL.
744     uint8_t output_bytes[out_len];
745     EXPECT_EQ(out_len - 1, EVP_EncodeBlock(output_bytes, input.data(), input.size()));
746     return std::string(reinterpret_cast<char*>(output_bytes));
747 }
748 
749 // Test what happens if the specified TLS server is nonexistent.
TEST_F(ResolverTest,GetHostByName_TlsMissing)750 TEST_F(ResolverTest, GetHostByName_TlsMissing) {
751     const char* listen_addr = "127.0.0.3";
752     const char* listen_srv = "53";
753     const char* host_name = "tlsmissing.example.com.";
754     test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
755     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
756     ASSERT_TRUE(dns.startServer());
757     std::vector<std::string> servers = { listen_addr };
758 
759     // There's nothing listening on this address, so validation will either fail or
760     /// hang.  Either way, queries will continue to flow to the DNSResponder.
761     ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
762 
763     const hostent* result;
764 
765     result = gethostbyname("tlsmissing");
766     ASSERT_FALSE(result == nullptr);
767     EXPECT_EQ("1.2.3.3", ToString(result));
768 
769     // Clear TLS bit.
770     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
771     dns.stopServer();
772 }
773 
774 // Test what happens if the specified TLS server replies with garbage.
TEST_F(ResolverTest,GetHostByName_TlsBroken)775 TEST_F(ResolverTest, GetHostByName_TlsBroken) {
776     const char* listen_addr = "127.0.0.3";
777     const char* listen_srv = "53";
778     const char* host_name1 = "tlsbroken1.example.com.";
779     const char* host_name2 = "tlsbroken2.example.com.";
780     test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
781     dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
782     dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
783     ASSERT_TRUE(dns.startServer());
784     std::vector<std::string> servers = { listen_addr };
785 
786     // Bind the specified private DNS socket but don't respond to any client sockets yet.
787     int s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
788     ASSERT_TRUE(s >= 0);
789     struct sockaddr_in tlsServer = {
790         .sin_family = AF_INET,
791         .sin_port = htons(853),
792     };
793     ASSERT_TRUE(inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
794     enableSockopt(s, SOL_SOCKET, SO_REUSEPORT);
795     enableSockopt(s, SOL_SOCKET, SO_REUSEADDR);
796     ASSERT_FALSE(bind(s, reinterpret_cast<struct sockaddr*>(&tlsServer), sizeof(tlsServer)));
797     ASSERT_FALSE(listen(s, 1));
798 
799     // Trigger TLS validation.
800     ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
801 
802     struct sockaddr_storage cliaddr;
803     socklen_t sin_size = sizeof(cliaddr);
804     int new_fd = accept(s, reinterpret_cast<struct sockaddr *>(&cliaddr), &sin_size);
805     ASSERT_TRUE(new_fd > 0);
806 
807     // We've received the new file descriptor but not written to it or closed, so the
808     // validation is still pending.  Queries should still flow correctly because the
809     // server is not used until validation succeeds.
810     const hostent* result;
811     result = gethostbyname("tlsbroken1");
812     ASSERT_FALSE(result == nullptr);
813     EXPECT_EQ("1.2.3.1", ToString(result));
814 
815     // Now we cause the validation to fail.
816     std::string garbage = "definitely not a valid TLS ServerHello";
817     write(new_fd, garbage.data(), garbage.size());
818     close(new_fd);
819 
820     // Validation failure shouldn't interfere with lookups, because lookups won't be sent
821     // to the TLS server unless validation succeeds.
822     result = gethostbyname("tlsbroken2");
823     ASSERT_FALSE(result == nullptr);
824     EXPECT_EQ("1.2.3.2", ToString(result));
825 
826     // Clear TLS bit.
827     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
828     dns.stopServer();
829     close(s);
830 }
831 
// Opportunistic mode: lookups go through the validated TLS server, fall back to
// the cleartext nameserver once the TLS server stops, and keep working after
// TLS is disabled again.
TEST_F(ResolverTest, GetHostByName_Tls) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name1 = "tls1.example.com.";
    const char* host_name2 = "tls2.example.com.";
    const char* host_name3 = "tls3.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
    dns.addMapping(host_name3, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("tls1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    // Stop the TLS server.  Since we're in opportunistic mode, queries will
    // fall back to the locally-assigned (clear text) nameservers.
    tls.stopServer();

    dns.clearQueries();
    result = gethostbyname("tls2");
    // Fatal, like every other lookup in this file: a null result cannot be inspected.
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.2", ToString(result));
    const auto queries = dns.queries();
    // Fatal because queries[0] below is only valid when a query was recorded.
    ASSERT_EQ(1U, queries.size());
    EXPECT_EQ("tls2.example.com.", queries[0].first);
    EXPECT_EQ(ns_t_a, queries[0].second);

    // Reset the resolvers without enabling TLS.  Queries should still be routed
    // to the UDP endpoint.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));

    result = gethostbyname("tls3");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));

    dns.stopServer();
}
885 
// Validation against a pinned certificate fingerprint succeeds for TLS servers
// presenting certificate chains of length 1, 2 and 3.
TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    ASSERT_TRUE(dns.startServer());
    // Loop-invariant: the same single server is configured on every iteration.
    const std::vector<std::string> servers = { listen_addr };
    for (int chain_length = 1; chain_length <= 3; ++chain_length) {
        // Keep the formatted name alive in a named std::string. Taking c_str() of the
        // temporary returned by StringPrintf() would leave a dangling pointer as soon
        // as that full-expression ended.
        const std::string host_name =
                StringPrintf("tlsfingerprint%d.example.com.", chain_length);
        dns.addMapping(host_name.c_str(), ns_type::ns_t_a, "1.2.3.1");

        test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
        tls.set_chain_length(chain_length);
        ASSERT_TRUE(tls.startServer());
        ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
                { base64Encode(tls.fingerprint()) }));

        const hostent* result;

        // Wait for validation to complete.
        EXPECT_TRUE(tls.waitForQueries(1, 5000));

        // The temporary from StringPrintf() lives until gethostbyname() returns.
        result = gethostbyname(StringPrintf("tlsfingerprint%d", chain_length).c_str());
        EXPECT_FALSE(result == nullptr);
        if (result) {
            EXPECT_EQ("1.2.3.1", ToString(result));

            // Wait for query to get counted.
            EXPECT_TRUE(tls.waitForQueries(2, 5000));
        }

        // Clear TLS bit to ensure revalidation.
        ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
        tls.stopServer();
    }
    dns.stopServer();
}
923 
// With a single pinned fingerprint that does not match the server, validation
// fails before any query is sent, and the lookup itself fails.
TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "badtlsfingerprint.example.com.";

    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
    ASSERT_TRUE(dns.startServer());
    const std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());

    // Flip one byte of the server's real fingerprint so that it can never match.
    std::vector<uint8_t> wrongFingerprint = tls.fingerprint();
    wrongFingerprint[5] += 1;
    const std::string wrongFingerprintB64 = base64Encode(wrongFingerprint);
    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
            { wrongFingerprintB64 }));

    // Validation should stop at the fingerprint check, before any query is issued.
    EXPECT_FALSE(tls.waitForQueries(1, 500));

    // A fingerprint was provided and failed to match, so the query should fail.
    EXPECT_EQ(nullptr, gethostbyname("badtlsfingerprint"));

    // Clear TLS bit.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
    tls.stopServer();
    dns.stopServer();
}
953 
954 // Test that we can pass two different fingerprints, and connection succeeds as long as
955 // at least one of them matches the server.
TEST_F(ResolverTest,GetHostByName_TwoTlsFingerprints)956 TEST_F(ResolverTest, GetHostByName_TwoTlsFingerprints) {
957     const char* listen_addr = "127.0.0.3";
958     const char* listen_udp = "53";
959     const char* listen_tls = "853";
960     const char* host_name = "twotlsfingerprints.example.com.";
961     test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
962     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
963     ASSERT_TRUE(dns.startServer());
964     std::vector<std::string> servers = { listen_addr };
965 
966     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
967     ASSERT_TRUE(tls.startServer());
968     std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
969     bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
970     ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
971             { base64Encode(bad_fingerprint), base64Encode(tls.fingerprint()) }));
972 
973     const hostent* result;
974 
975     // Wait for validation to complete.
976     EXPECT_TRUE(tls.waitForQueries(1, 5000));
977 
978     result = gethostbyname("twotlsfingerprints");
979     ASSERT_FALSE(result == nullptr);
980     EXPECT_EQ("1.2.3.1", ToString(result));
981 
982     // Wait for query to get counted.
983     EXPECT_TRUE(tls.waitForQueries(2, 5000));
984 
985     // Clear TLS bit.
986     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
987     tls.stopServer();
988     dns.stopServer();
989 }
990 
// After a successful validation, restarting the TLS server (which gives it a
// new certificate) makes the pinned fingerprint stale, so lookups then fail
// with HOST_NOT_FOUND.
TEST_F(ResolverTest, GetHostByName_TlsFingerprintGoesBad) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name1 = "tlsfingerprintgoesbad1.example.com.";
    const char* host_name2 = "tlsfingerprintgoesbad2.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
            { base64Encode(tls.fingerprint()) }));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("tlsfingerprintgoesbad1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    // Restart the TLS server.  This will generate a new certificate whose fingerprint
    // no longer matches the stored fingerprint.
    tls.stopServer();
    // Check the restart: if no server were listening, the failed lookup below would
    // pass for the wrong reason (no server rather than a fingerprint mismatch).
    ASSERT_TRUE(tls.startServer());

    result = gethostbyname("tlsfingerprintgoesbad2");
    ASSERT_TRUE(result == nullptr);
    EXPECT_EQ(HOST_NOT_FOUND, h_errno);

    // Clear TLS bit.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
    tls.stopServer();
    dns.stopServer();
}
1034 
TEST_F(ResolverTest,GetHostByName_TlsFailover)1035 TEST_F(ResolverTest, GetHostByName_TlsFailover) {
1036     const char* listen_addr1 = "127.0.0.3";
1037     const char* listen_addr2 = "127.0.0.4";
1038     const char* listen_udp = "53";
1039     const char* listen_tls = "853";
1040     const char* host_name1 = "tlsfailover1.example.com.";
1041     const char* host_name2 = "tlsfailover2.example.com.";
1042     test::DNSResponder dns1(listen_addr1, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
1043     test::DNSResponder dns2(listen_addr2, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
1044     dns1.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
1045     dns1.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
1046     dns2.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.3");
1047     dns2.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.4");
1048     ASSERT_TRUE(dns1.startServer());
1049     ASSERT_TRUE(dns2.startServer());
1050     std::vector<std::string> servers = { listen_addr1, listen_addr2 };
1051 
1052     test::DnsTlsFrontend tls1(listen_addr1, listen_tls, listen_addr1, listen_udp);
1053     test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
1054     ASSERT_TRUE(tls1.startServer());
1055     ASSERT_TRUE(tls2.startServer());
1056     ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
1057             { base64Encode(tls1.fingerprint()), base64Encode(tls2.fingerprint()) }));
1058 
1059     const hostent* result;
1060 
1061     // Wait for validation to complete.
1062     EXPECT_TRUE(tls1.waitForQueries(1, 5000));
1063     EXPECT_TRUE(tls2.waitForQueries(1, 5000));
1064 
1065     result = gethostbyname("tlsfailover1");
1066     ASSERT_FALSE(result == nullptr);
1067     EXPECT_EQ("1.2.3.1", ToString(result));
1068 
1069     // Wait for query to get counted.
1070     EXPECT_TRUE(tls1.waitForQueries(2, 5000));
1071     // No new queries should have reached tls2.
1072     EXPECT_EQ(1, tls2.queries());
1073 
1074     // Stop tls1.  Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
1075     tls1.stopServer();
1076 
1077     result = gethostbyname("tlsfailover2");
1078     EXPECT_EQ("1.2.3.4", ToString(result));
1079 
1080     // Wait for query to get counted.
1081     EXPECT_TRUE(tls2.waitForQueries(2, 5000));
1082 
1083     // No additional queries should have reached the insecure servers.
1084     EXPECT_EQ(2U, dns1.queries().size());
1085     EXPECT_EQ(2U, dns2.queries().size());
1086 
1087     // Clear TLS bit.
1088     ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
1089     tls2.stopServer();
1090     dns1.stopServer();
1091     dns2.stopServer();
1092 }
1093 
// Requiring a TLS server name that the test server's certificate cannot prove
// makes the handshake fail before any query, and lookups then fail hard.
TEST_F(ResolverTest, GetHostByName_BadTlsName) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "badtlsname.example.com.";

    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
    ASSERT_TRUE(dns.startServer());
    const std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    const std::string requiredName = "www.example.com";
    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder,
            requiredName, {}));

    // The certificate doesn't chain to a known CA and a nonempty name was given,
    // so the client should abandon the TLS handshake before sending any query.
    EXPECT_FALSE(tls.waitForQueries(1, 500));

    // Because a name was specified, the lookup must fail rather than fall back.
    EXPECT_EQ(nullptr, gethostbyname("badtlsname"));

    // Clear TLS bit.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
    tls.stopServer();
    dns.stopServer();
}
1121 
// getaddrinfo() over a fingerprint-pinned TLS server: the answer comes back,
// and both the A and the AAAA query are sent over TLS.
TEST_F(ResolverTest, GetAddrInfo_Tls) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "addrinfotls.example.com.";

    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
    ASSERT_TRUE(dns.startServer());
    const std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
            { base64Encode(tls.fingerprint()) }));

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    dns.clearQueries();
    addrinfo* ai = nullptr;
    EXPECT_EQ(0, getaddrinfo("addrinfotls", nullptr, nullptr, &ai));
    const size_t numQueries = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, numQueries);

    // Either the IPv4 or the IPv6 address may be listed first.
    const std::string addrText = ToString(ai);
    EXPECT_TRUE(addrText == "1.2.3.4" || addrText == "::1.2.3.4")
        << ", result_str='" << addrText << "'";
    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
    if (ai != nullptr) freeaddrinfo(ai);

    // Wait for both A and AAAA queries to get counted.
    EXPECT_TRUE(tls.waitForQueries(3, 5000));

    // Clear TLS bit.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
    tls.stopServer();
    dns.stopServer();
}
1163 
// Checks that lookups made with BYPASS_NETID send no queries to the TLS
// frontend. It covers every combination of Private DNS mode (off /
// opportunistic / strict), TLS server availability, and lookup API
// (gethostbyname / getaddrinfo / android_getaddrinfofornet). In each case the
// lookup must still be answered via the cleartext DNSResponder.
TEST_F(ResolverTest, TlsBypass) {
    const char OFF[] = "off";
    const char OPPORTUNISTIC[] = "opportunistic";
    const char STRICT[] = "strict";

    const char GETHOSTBYNAME[] = "gethostbyname";
    const char GETADDRINFO[] = "getaddrinfo";
    const char GETADDRINFOFORNET[] = "getaddrinfofornet";

    // TEST_NETID with the NETID_USE_LOCAL_NAMESERVERS flag OR'd in. Presumably
    // this asks the resolver to use the network's cleartext nameservers and to
    // skip Private DNS; the per-iteration check below requires zero TLS queries.
    const unsigned BYPASS_NETID = NETID_USE_LOCAL_NAMESERVERS | TEST_NETID;

    // All-zero fingerprint of test::SHA256_SIZE bytes. It is assumed never to
    // match a real certificate, so it selects strict mode without a usable TLS server.
    const std::vector<uint8_t> NOOP_FINGERPRINT(test::SHA256_SIZE, 0U);

    const char ADDR4[] = "192.0.2.1";
    const char ADDR6[] = "2001:db8::1";

    const char cleartext_addr[] = "127.0.0.53";
    const char cleartext_port[] = "53";
    const char tls_port[] = "853";
    const std::vector<std::string> servers = { cleartext_addr };

    test::DNSResponder dns(cleartext_addr, cleartext_port, 250, ns_rcode::ns_r_servfail, 1.0);
    ASSERT_TRUE(dns.startServer());

    // Forwards to the cleartext responder. It is started only for configs with
    // withWorkingTLS == true.
    test::DnsTlsFrontend tls(cleartext_addr, tls_port, cleartext_addr, cleartext_port);

    // One test case: Private DNS mode, whether a working TLS server is up, and
    // which lookup API to exercise.
    struct TestConfig {
        const std::string mode;
        const bool withWorkingTLS;
        const std::string method;

        // Builds a host name unique to this config, e.g. "off.tlsOn.getaddrinfo.".
        std::string asHostName() const {
            return StringPrintf("%s.%s.%s.",
                                mode.c_str(),
                                withWorkingTLS ? "tlsOn" : "tlsOff",
                                method.c_str());
        }
    } testConfigs[]{
        {OFF,           false, GETHOSTBYNAME},
        {OPPORTUNISTIC, false, GETHOSTBYNAME},
        {STRICT,        false, GETHOSTBYNAME},
        {OFF,           true,  GETHOSTBYNAME},
        {OPPORTUNISTIC, true,  GETHOSTBYNAME},
        {STRICT,        true,  GETHOSTBYNAME},
        {OFF,           false, GETADDRINFO},
        {OPPORTUNISTIC, false, GETADDRINFO},
        {STRICT,        false, GETADDRINFO},
        {OFF,           true,  GETADDRINFO},
        {OPPORTUNISTIC, true,  GETADDRINFO},
        {STRICT,        true,  GETADDRINFO},
        {OFF,           false, GETADDRINFOFORNET},
        {OPPORTUNISTIC, false, GETADDRINFOFORNET},
        {STRICT,        false, GETADDRINFOFORNET},
        {OFF,           true,  GETADDRINFOFORNET},
        {OPPORTUNISTIC, true,  GETADDRINFOFORNET},
        {STRICT,        true,  GETADDRINFOFORNET},
    };

    for (const auto& config : testConfigs) {
        const std::string testHostName = config.asHostName();
        SCOPED_TRACE(testHostName);

        // Each config yields a distinct host name, so an answer cached by an
        // earlier iteration cannot satisfy this iteration's lookup.
        const char* host_name = testHostName.c_str();
        dns.addMapping(host_name, ns_type::ns_t_a, ADDR4);
        dns.addMapping(host_name, ns_type::ns_t_aaaa, ADDR6);

        if (config.withWorkingTLS) ASSERT_TRUE(tls.startServer());

        // Configure the network for this config's Private DNS mode.
        if (config.mode == OFF) {
            ASSERT_TRUE(SetResolversForNetwork(
                    servers, mDefaultSearchDomains,  mDefaultParams_Binder));
        } else if (config.mode == OPPORTUNISTIC) {
            ASSERT_TRUE(SetResolversWithTls(
                    servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
            // Wait for validation to complete.
            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
        } else if (config.mode == STRICT) {
            // We use the existence of fingerprints to trigger strict mode,
            // rather than hostname validation.
            const auto& fingerprint =
                    (config.withWorkingTLS) ? tls.fingerprint() : NOOP_FINGERPRINT;
            ASSERT_TRUE(SetResolversWithTls(
                    servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
                    { base64Encode(fingerprint) }));
            // Wait for validation to complete.
            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
        } else {
            FAIL() << "Unsupported Private DNS mode: " << config.mode;
        }

        // Snapshot the TLS query count; the bypassing lookup must not change it.
        const int tlsQueriesBefore = tls.queries();

        const hostent* h_result = nullptr;
        addrinfo* ai_result = nullptr;

        if (config.method == GETHOSTBYNAME) {
            // gethostbyname() has no netid argument, so select the bypass netid
            // for the whole process first.
            ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
            h_result = gethostbyname(host_name);

            EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
            ASSERT_FALSE(h_result == nullptr);
            ASSERT_EQ(4, h_result->h_length);
            ASSERT_FALSE(h_result->h_addr_list[0] == nullptr);
            EXPECT_EQ(ADDR4, ToString(h_result));
            EXPECT_TRUE(h_result->h_addr_list[1] == nullptr);
        } else if (config.method == GETADDRINFO) {
            ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
            EXPECT_EQ(0, getaddrinfo(host_name, nullptr, nullptr, &ai_result));

            EXPECT_LE(1U, GetNumQueries(dns, host_name));
            // Could be A or AAAA
            const std::string result_str = ToString(ai_result);
            EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
                << ", result_str='" << result_str << "'";
        } else if (config.method == GETADDRINFOFORNET) {
            // The netid is passed explicitly here, so no setNetworkForResolv() call.
            EXPECT_EQ(0, android_getaddrinfofornet(
                    host_name, nullptr, nullptr, BYPASS_NETID, MARK_UNSET, &ai_result));

            EXPECT_LE(1U, GetNumQueries(dns, host_name));
            // Could be A or AAAA
            const std::string result_str = ToString(ai_result);
            EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
                << ", result_str='" << result_str << "'";
        } else {
            FAIL() << "Unsupported query method: " << config.method;
        }

        // The bypassing lookup must not have produced any TLS query.
        const int tlsQueriesAfter = tls.queries();
        EXPECT_EQ(0, tlsQueriesAfter - tlsQueriesBefore);

        // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
        if (ai_result != nullptr) freeaddrinfo(ai_result);

        // NOTE(review): a failed ASSERT_* after setNetworkForResolv(BYPASS_NETID)
        // returns from the test before this point. That leaves the per-process
        // resolv netid set and the TLS frontend running; consider a scope guard.
        // Clear per-process resolv netid.
        ASSERT_EQ(0, setNetworkForResolv(NETID_UNSET));
        // Called unconditionally, including for configs that never started it.
        tls.stopServer();
        dns.clearQueries();
    }

    dns.stopServer();
}
1306 
// Strict mode (selected by supplying a fingerprint) with no TLS servers
// configured must not fall back to cleartext. The lookup fails, and no query
// reaches the cleartext DNSResponder.
TEST_F(ResolverTest, StrictMode_NoTlsServers) {
    const std::vector<uint8_t> NOOP_FINGERPRINT(test::SHA256_SIZE, 0U);
    const char cleartext_addr[] = "127.0.0.53";
    const char cleartext_port[] = "53";
    const std::vector<std::string> servers = { cleartext_addr };

    test::DNSResponder dns(cleartext_addr, cleartext_port, 250, ns_rcode::ns_r_servfail, 1.0);
    const char* host_name = "strictmode.notlsips.example.com.";
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
    ASSERT_TRUE(dns.startServer());

    // The empty {} is presumably the TLS server list (empty, per the test name);
    // the fingerprint is what selects strict mode.
    ASSERT_TRUE(SetResolversWithTls(
            servers, mDefaultSearchDomains, mDefaultParams_Binder,
            {}, "", { base64Encode(NOOP_FINGERPRINT) }));

    addrinfo* ai_result = nullptr;
    const int rv = getaddrinfo(host_name, nullptr, nullptr, &ai_result);
    EXPECT_NE(0, rv);
    EXPECT_EQ(0U, GetNumQueries(dns, host_name));
    // Only a (wrongly) successful lookup hands back a list that must be freed.
    if (rv == 0 && ai_result != nullptr) freeaddrinfo(ai_result);

    // Clear TLS bit and stop the server, matching the other TLS tests in this file.
    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
    dns.stopServer();
}
1327