• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #include <arpa/inet.h>
19 #include <errno.h>
20 #include <netdb.h>
21 #include <stdarg.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <unistd.h>
25 
26 #include <cutils/sockets.h>
27 #include <android-base/stringprintf.h>
28 #include <private/android_filesystem_config.h>
29 
30 #include <openssl/base64.h>
31 
32 #include <algorithm>
33 #include <chrono>
34 #include <iterator>
35 #include <numeric>
36 #include <thread>
37 
38 #define LOG_TAG "netd_test"
39 // TODO: make this dynamic and stop depending on implementation details.
40 #define TEST_OEM_NETWORK "oem29"
41 #define TEST_NETID 30
42 
43 #include "NetdClient.h"
44 
45 #include <gtest/gtest.h>
46 
47 #include <utils/Log.h>
48 
49 #include "dns_responder.h"
50 #include "dns_responder_client.h"
51 #include "dns_tls_frontend.h"
52 #include "resolv_params.h"
53 #include "ResolverStats.h"
54 
55 #include "android/net/INetd.h"
56 #include "android/net/metrics/INetdEventListener.h"
57 #include "binder/IServiceManager.h"
58 
59 using android::base::StringPrintf;
60 using android::base::StringAppendF;
61 using android::net::ResolverStats;
62 using android::net::metrics::INetdEventListener;
63 
// Returns true if |a| and |b| hold the same elements with the same multiplicities, in any
// order (i.e. |b| is a permutation of |a|). Elements are compared with operator==.
// Emulates the behavior of UnorderedElementsAreArray, which currently cannot be used.
// TODO: Use UnorderedElementsAreArray, which depends on being able to compile libgmock_host.
template <class A, class B>
bool UnorderedCompareArray(const A& a, const B& b) {
    // Cheap early exit; std::is_permutation would also reject differing lengths.
    if (a.size() != b.size()) return false;
    // Standard-library multiset comparison, replacing the hand-rolled count-and-compare loop.
    return std::is_permutation(std::begin(a), std::end(a), std::begin(b), std::end(b));
}
85 
// Owns the addrinfo list produced by getaddrinfo() and releases it with freeaddrinfo() when
// the object is re-initialized, cleared, or destroyed.
class AddrInfo {
  public:
    AddrInfo() = default;

    // Resolves |node|/|service| using |hints|; the getaddrinfo() result is available via error().
    AddrInfo(const char* node, const char* service, const addrinfo& hints) {
        init(node, service, hints);
    }

    // Resolves |node|/|service| with default hints; the result is available via error().
    AddrInfo(const char* node, const char* service) {
        init(node, service);
    }

    ~AddrInfo() { clear(); }

    // The object owns a raw addrinfo list: an implicit copy would make two owners and the
    // list would be freed twice, so copying is disabled.
    AddrInfo(const AddrInfo&) = delete;
    AddrInfo& operator=(const AddrInfo&) = delete;

    // Releases any previous result, then calls getaddrinfo(). Returns its error code (0 on success).
    int init(const char* node, const char* service, const addrinfo& hints) {
        clear();
        error_ = getaddrinfo(node, service, &hints, &ai_);
        return error_;
    }

    // Same as above, with default (null) hints.
    int init(const char* node, const char* service) {
        clear();
        error_ = getaddrinfo(node, service, nullptr, &ai_);
        return error_;
    }

    // Frees the owned list, if any, and resets the stored error code. The error is reset even
    // when no list is held, so a failed lookup does not leave a stale error behind.
    void clear() {
        if (ai_ != nullptr) {
            freeaddrinfo(ai_);
            ai_ = nullptr;
        }
        error_ = 0;
    }

    // First entry of the list; only valid after a successful lookup (non-null list).
    const addrinfo& operator*() const { return *ai_; }
    const addrinfo* get() const { return ai_; }
    // NOTE: yields the owned list pointer, not the address of this object.
    const addrinfo* operator&() const { return ai_; }
    int error() const { return error_; }

  private:
    addrinfo* ai_ = nullptr;
    int error_ = 0;
};
129 
130 class ResolverTest : public ::testing::Test, public DnsResponderClient {
131 private:
132     int mOriginalMetricsLevel;
133 
134 protected:
SetUp()135     virtual void SetUp() {
136         // Ensure resolutions go via proxy.
137         DnsResponderClient::SetUp();
138 
139         // If DNS reporting is off: turn it on so we run through everything.
140         auto rv = mNetdSrv->getMetricsReportingLevel(&mOriginalMetricsLevel);
141         ASSERT_TRUE(rv.isOk());
142         if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
143             rv = mNetdSrv->setMetricsReportingLevel(INetdEventListener::REPORTING_LEVEL_FULL);
144             ASSERT_TRUE(rv.isOk());
145         }
146     }
147 
TearDown()148     virtual void TearDown() {
149         if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
150             auto rv = mNetdSrv->setMetricsReportingLevel(mOriginalMetricsLevel);
151             ASSERT_TRUE(rv.isOk());
152         }
153 
154         DnsResponderClient::TearDown();
155     }
156 
GetResolverInfo(std::vector<std::string> * servers,std::vector<std::string> * domains,__res_params * params,std::vector<ResolverStats> * stats)157     bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
158             __res_params* params, std::vector<ResolverStats>* stats) {
159         using android::net::INetd;
160         std::vector<int32_t> params32;
161         std::vector<int32_t> stats32;
162         auto rv = mNetdSrv->getResolverInfo(TEST_NETID, servers, domains, &params32, &stats32);
163         if (!rv.isOk() || params32.size() != INetd::RESOLVER_PARAMS_COUNT) {
164             return false;
165         }
166         *params = __res_params {
167             .sample_validity = static_cast<uint16_t>(
168                     params32[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
169             .success_threshold = static_cast<uint8_t>(
170                     params32[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
171             .min_samples = static_cast<uint8_t>(
172                     params32[INetd::RESOLVER_PARAMS_MIN_SAMPLES]),
173             .max_samples = static_cast<uint8_t>(
174                     params32[INetd::RESOLVER_PARAMS_MAX_SAMPLES])
175         };
176         return ResolverStats::decodeAll(stats32, stats);
177     }
178 
ToString(const hostent * he) const179     std::string ToString(const hostent* he) const {
180         if (he == nullptr) return "<null>";
181         char buffer[INET6_ADDRSTRLEN];
182         if (!inet_ntop(he->h_addrtype, he->h_addr_list[0], buffer, sizeof(buffer))) {
183             return "<invalid>";
184         }
185         return buffer;
186     }
187 
ToString(const addrinfo * ai) const188     std::string ToString(const addrinfo* ai) const {
189         if (!ai)
190             return "<null>";
191         for (const auto* aip = ai ; aip != nullptr ; aip = aip->ai_next) {
192             char host[NI_MAXHOST];
193             int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0,
194                     NI_NUMERICHOST);
195             if (rv != 0)
196                 return gai_strerror(rv);
197             return host;
198         }
199         return "<invalid>";
200     }
201 
GetNumQueries(const test::DNSResponder & dns,const char * name) const202     size_t GetNumQueries(const test::DNSResponder& dns, const char* name) const {
203         auto queries = dns.queries();
204         size_t found = 0;
205         for (const auto& p : queries) {
206             if (p.first == name) {
207                 ++found;
208             }
209         }
210         return found;
211     }
212 
GetNumQueriesForType(const test::DNSResponder & dns,ns_type type,const char * name) const213     size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type,
214             const char* name) const {
215         auto queries = dns.queries();
216         size_t found = 0;
217         for (const auto& p : queries) {
218             if (p.second == type && p.first == name) {
219                 ++found;
220             }
221         }
222         return found;
223     }
224 
RunGetAddrInfoStressTest_Binder(unsigned num_hosts,unsigned num_threads,unsigned num_queries)225     void RunGetAddrInfoStressTest_Binder(unsigned num_hosts, unsigned num_threads,
226             unsigned num_queries) {
227         std::vector<std::string> domains = { "example.com" };
228         std::vector<std::unique_ptr<test::DNSResponder>> dns;
229         std::vector<std::string> servers;
230         std::vector<DnsResponderClient::Mapping> mappings;
231         ASSERT_NO_FATAL_FAILURE(SetupMappings(num_hosts, domains, &mappings));
232         ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS, mappings, &dns, &servers));
233 
234         ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
235 
236         auto t0 = std::chrono::steady_clock::now();
237         std::vector<std::thread> threads(num_threads);
238         for (std::thread& thread : threads) {
239            thread = std::thread([this, &mappings, num_queries]() {
240                 for (unsigned i = 0 ; i < num_queries ; ++i) {
241                     uint32_t ofs = arc4random_uniform(mappings.size());
242                     auto& mapping = mappings[ofs];
243                     addrinfo* result = nullptr;
244                     int rv = getaddrinfo(mapping.host.c_str(), nullptr, nullptr, &result);
245                     EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
246                     if (rv == 0) {
247                         std::string result_str = ToString(result);
248                         EXPECT_TRUE(result_str == mapping.ip4 || result_str == mapping.ip6)
249                             << "result='" << result_str << "', ip4='" << mapping.ip4
250                             << "', ip6='" << mapping.ip6;
251                     }
252                     if (result) {
253                         freeaddrinfo(result);
254                         result = nullptr;
255                     }
256                 }
257             });
258         }
259 
260         for (std::thread& thread : threads) {
261             thread.join();
262         }
263         auto t1 = std::chrono::steady_clock::now();
264         ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
265                 std::chrono::duration<double>(t1 - t0).count());
266         ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
267     }
268 
269     const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
270     // <sample validity in s> <success threshold in percent> <min samples> <max samples>
271     const std::string mDefaultParams = "300 25 8 8";
272     const std::vector<int> mDefaultParams_Binder = { 300, 25, 8, 8 };
273 };
274 
TEST_F(ResolverTest, GetHostByName) {
    // One local responder that only knows an A record for "hello.example.com.".
    const char* listen_addr = "127.0.0.3";
    const char* listen_srv = "53";
    const char* host_name = "hello.example.com.";
    const char* nonexistent_host_name = "nonexistent.example.com.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers{listen_addr};
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    // Unknown short name: one A query for the search-domain-qualified name, then failure.
    dns.clearQueries();
    const hostent* missing = gethostbyname("nonexistent");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, nonexistent_host_name));
    ASSERT_TRUE(missing == nullptr);
    ASSERT_EQ(HOST_NOT_FOUND, h_errno);

    // Known short name: exactly one IPv4 address is returned.
    dns.clearQueries();
    const hostent* found = gethostbyname("hello");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
    ASSERT_FALSE(found == nullptr);
    ASSERT_EQ(4, found->h_length);
    ASSERT_FALSE(found->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(found));
    EXPECT_TRUE(found->h_addr_list[1] == nullptr);

    dns.stopServer();
}
305 
TEST_F(ResolverTest, TestBinderSerialization) {
    // The RESOLVER_PARAMS_* offsets must cover exactly 0..RESOLVER_PARAMS_COUNT-1.
    using android::net::INetd;
    std::vector<int> offsets{
        INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY,
        INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
        INetd::RESOLVER_PARAMS_MIN_SAMPLES,
        INetd::RESOLVER_PARAMS_MAX_SAMPLES,
    };
    const int count = static_cast<int>(offsets.size());
    EXPECT_EQ(count, INetd::RESOLVER_PARAMS_COUNT);
    std::sort(offsets.begin(), offsets.end());
    int expected = 0;
    for (const int offset : offsets) {
        EXPECT_EQ(offset, expected);
        ++expected;
    }
}
321 
TEST_F(ResolverTest, GetHostByName_Binder) {
    using android::net::INetd;

    // One host mapping served by four responders, configured through the binder interface.
    std::vector<std::string> domains = { "example.com" };
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(4, mappings, &dns, &servers));
    ASSERT_EQ(1U, mappings.size());
    const Mapping& mapping = mappings[0];

    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    const hostent* result = gethostbyname(mapping.host.c_str());
    // Sum the A queries for this name across all responders. The seed is size_t so the running
    // total stays size_t instead of being converted back to int at every step.
    const size_t total_queries = std::accumulate(dns.begin(), dns.end(), size_t{0},
            [this, &mapping](size_t total, const auto& d) {
                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
            });

    EXPECT_LE(1U, total_queries);
    ASSERT_FALSE(result == nullptr);
    ASSERT_EQ(4, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ(mapping.ip4, ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);

    // The configuration read back from netd must match what was set.
    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    __res_params res_params;
    std::vector<ResolverStats> res_stats;
    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
    EXPECT_EQ(servers.size(), res_servers.size());
    EXPECT_EQ(domains.size(), res_domains.size());
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY],
            res_params.sample_validity);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
            res_params.success_threshold);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
    EXPECT_EQ(servers.size(), res_stats.size());

    EXPECT_TRUE(UnorderedCompareArray(res_servers, servers));
    EXPECT_TRUE(UnorderedCompareArray(res_domains, domains));

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
370 
TEST_F(ResolverTest,GetAddrInfo)371 TEST_F(ResolverTest, GetAddrInfo) {
372     addrinfo* result = nullptr;
373 
374     const char* listen_addr = "127.0.0.4";
375     const char* listen_addr2 = "127.0.0.5";
376     const char* listen_srv = "53";
377     const char* host_name = "howdy.example.com.";
378     test::DNSResponder dns(listen_addr, listen_srv, 250,
379                            ns_rcode::ns_r_servfail, 1.0);
380     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
381     dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
382     ASSERT_TRUE(dns.startServer());
383 
384     test::DNSResponder dns2(listen_addr2, listen_srv, 250,
385                             ns_rcode::ns_r_servfail, 1.0);
386     dns2.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
387     dns2.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
388     ASSERT_TRUE(dns2.startServer());
389 
390 
391     std::vector<std::string> servers = { listen_addr };
392     ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
393     dns.clearQueries();
394     dns2.clearQueries();
395 
396     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
397     size_t found = GetNumQueries(dns, host_name);
398     EXPECT_LE(1U, found);
399     // Could be A or AAAA
400     std::string result_str = ToString(result);
401     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
402         << ", result_str='" << result_str << "'";
403     // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
404     if (result) {
405         freeaddrinfo(result);
406         result = nullptr;
407     }
408 
409     // Verify that the name is cached.
410     size_t old_found = found;
411     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
412     found = GetNumQueries(dns, host_name);
413     EXPECT_LE(1U, found);
414     EXPECT_EQ(old_found, found);
415     result_str = ToString(result);
416     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
417         << result_str;
418     if (result) {
419         freeaddrinfo(result);
420         result = nullptr;
421     }
422 
423     // Change the DNS resolver, ensure that queries are still cached.
424     servers = { listen_addr2 };
425     ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
426     dns.clearQueries();
427     dns2.clearQueries();
428 
429     EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
430     found = GetNumQueries(dns, host_name);
431     size_t found2 = GetNumQueries(dns2, host_name);
432     EXPECT_EQ(0U, found);
433     EXPECT_LE(0U, found2);
434 
435     // Could be A or AAAA
436     result_str = ToString(result);
437     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
438         << ", result_str='" << result_str << "'";
439     if (result) {
440         freeaddrinfo(result);
441         result = nullptr;
442     }
443 
444     dns.stopServer();
445     dns2.stopServer();
446 }
447 
TEST_F(ResolverTest, GetAddrInfoV4) {
    // One responder with only an A record for "hola.example.com.".
    const char* listen_addr = "127.0.0.5";
    const char* listen_srv = "53";
    const char* host_name = "hola.example.com.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers{listen_addr};
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    // Restrict the lookup to IPv4; exactly one query for the name is expected.
    addrinfo hints{};
    hints.ai_family = AF_INET;
    addrinfo* result = nullptr;
    EXPECT_EQ(0, getaddrinfo("hola", nullptr, &hints, &result));
    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
    EXPECT_EQ("1.2.3.5", ToString(result));
    if (result != nullptr) {
        freeaddrinfo(result);
    }
}
472 
TEST_F(ResolverTest, MultidomainResolution) {
    // Three search domains; the responder only has a record under the second one.
    std::vector<std::string> searchDomains{"example1.com", "example2.com", "example3.com"};
    const char* listen_addr = "127.0.0.6";
    const char* listen_srv = "53";
    const char* host_name = "nihao.example2.com.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers{listen_addr};
    ASSERT_TRUE(SetResolversForNetwork(searchDomains, servers, mDefaultParams));

    dns.clearQueries();
    const hostent* he = gethostbyname("nihao");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);
    dns.stopServer();
}
495 
TEST_F(ResolverTest, GetAddrInfoV6_failing) {
    const char* listen_addr0 = "127.0.0.7";
    const char* listen_addr1 = "127.0.0.8";
    const char* listen_srv = "53";
    const char* host_name = "ohayou.example.com.";
    // dns0 is created with 0.0 and dns1 with 1.0 as the final argument; the test relies on
    // dns0 failing its queries while dns1 answers.
    test::DNSResponder dns0(listen_addr0, listen_srv, 250, ns_rcode::ns_r_servfail, 0.0);
    test::DNSResponder dns1(listen_addr1, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
    ASSERT_TRUE(dns0.startServer());
    ASSERT_TRUE(dns1.startServer());
    std::vector<std::string> servers{listen_addr0, listen_addr1};

    // <sample validity in s> <success threshold in percent> <min samples> <max samples>
    const unsigned sample_validity = 300;
    const int success_threshold = 25;
    const int sample_count = 8;
    const std::string params = StringPrintf("%u %d %d %d", sample_validity, success_threshold,
            sample_count, sample_count);
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, params));

    // Resolve sample_count non-existing names so that the failing dns0 accumulates a full set of
    // failed samples; it should then be skipped until the samples expire.
    // TODO: This approach is implementation-dependent, change once metrics reporting is available.
    addrinfo hints{};
    hints.ai_family = AF_INET6;
    addrinfo* result = nullptr;
    for (int attempt = 0; attempt < sample_count; ++attempt) {
        const std::string domain = StringPrintf("nonexistent%d", attempt);
        getaddrinfo(domain.c_str(), nullptr, &hints, &result);
        if (result != nullptr) {
            freeaddrinfo(result);
            result = nullptr;
        }
    }

    // With 100% errors on dns0, only dns1 should be queried from now on.
    dns0.clearQueries();
    dns1.clearQueries();
    EXPECT_EQ(0, getaddrinfo("ohayou", nullptr, &hints, &result));
    EXPECT_EQ(0U, GetNumQueries(dns0, host_name));
    EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
    if (result != nullptr) {
        freeaddrinfo(result);
    }
}
547 
TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
    const char* listen_addr0 = "127.0.0.9";
    const char* listen_addr1 = "127.0.0.10";
    const char* listen_addr2 = "127.0.0.11";
    const char* listen_srv = "53";
    const char* host_name = "konbanha.example.com.";
    test::DNSResponder dns0(listen_addr0, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns1(listen_addr1, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns2(listen_addr2, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
    dns2.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::7");
    ASSERT_TRUE(dns0.startServer());
    ASSERT_TRUE(dns1.startServer());
    ASSERT_TRUE(dns2.startServer());
    const std::vector<std::string> servers{listen_addr0, listen_addr1, listen_addr2};

    // Each worker reconfigures the network with a random non-empty subset of the servers and
    // then resolves the name over IPv6, racing against the other workers.
    constexpr int kNumWorkers = 10;
    std::vector<std::thread> workers;
    workers.reserve(kNumWorkers);
    for (int w = 0; w < kNumWorkers; ++w) {
        workers.emplace_back([this, &servers]() {
            // Random start delay of less than 1s.
            usleep(arc4random_uniform(1 * 1000 * 1000));
            std::vector<std::string> serverSubset;
            std::copy_if(servers.begin(), servers.end(), std::back_inserter(serverSubset),
                    [](const std::string&) { return arc4random_uniform(2) != 0; });
            if (serverSubset.empty()) serverSubset = servers;
            ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, serverSubset,
                    mDefaultParams));
            addrinfo hints{};
            hints.ai_family = AF_INET6;
            addrinfo* result = nullptr;
            const int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
            EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
            if (result != nullptr) {
                freeaddrinfo(result);
            }
        });
    }
    for (auto& worker : workers) {
        worker.join();
    }
}
597 
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
    // 100 hosts, resolved by 100 threads issuing 100 queries each.
    constexpr unsigned kHosts = 100;
    constexpr unsigned kThreads = 100;
    constexpr unsigned kQueriesPerThread = 100;
    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(kHosts, kThreads, kQueriesPerThread));
}
604 
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100000) {
    // 100000 hosts, resolved by 100 threads issuing 100 queries each.
    constexpr unsigned kHosts = 100000;
    constexpr unsigned kThreads = 100;
    constexpr unsigned kQueriesPerThread = 100;
    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(kHosts, kThreads, kQueriesPerThread));
}
611 
TEST_F(ResolverTest, EmptySetup) {
    using android::net::INetd;
    // Configure the network with neither servers nor search domains.
    std::vector<std::string> servers;
    std::vector<std::string> domains;
    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    // Reading it back yields empty lists but the configured parameters.
    std::vector<std::string> got_servers;
    std::vector<std::string> got_domains;
    __res_params got_params;
    std::vector<ResolverStats> got_stats;
    ASSERT_TRUE(GetResolverInfo(&got_servers, &got_domains, &got_params, &got_stats));
    EXPECT_EQ(0U, got_servers.size());
    EXPECT_EQ(0U, got_domains.size());
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
    const std::vector<int>& expected = mDefaultParams_Binder;
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY], got_params.sample_validity);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD], got_params.success_threshold);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_MIN_SAMPLES], got_params.min_samples);
    EXPECT_EQ(expected[INetd::RESOLVER_PARAMS_MAX_SAMPLES], got_params.max_samples);
}
632 
TEST_F(ResolverTest, SearchPathChange) {
    addrinfo* result = nullptr;

    // One responder with AAAA records for "test13" under two different domains.
    const char* listen_addr = "127.0.0.13";
    const char* listen_srv = "53";
    const char* host_name1 = "test13.domain1.org.";
    const char* host_name2 = "test13.domain2.org.";
    test::DNSResponder dns(listen_addr, listen_srv, 250,
                           ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_aaaa, "2001:db8::13");
    dns.addMapping(host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    std::vector<std::string> domains = { "domain1.org" };
    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));

    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET6;
    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
    EXPECT_EQ("2001:db8::13", ToString(result));
    // Reset after freeing. POSIX only defines |result| when getaddrinfo() succeeds, so if the
    // next call fails a stale pointer would otherwise be formatted and freed a second time.
    if (result) {
        freeaddrinfo(result);
        result = nullptr;
    }

    // Test that changing the domain search path on its own works.
    domains = { "domain2.org" };
    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
    dns.clearQueries();

    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
    EXPECT_EQ("2001:db8::1:13", ToString(result));
    if (result) {
        freeaddrinfo(result);
        result = nullptr;
    }
}
669 
TEST_F(ResolverTest, MaxServerPrune_Binder) {
    // Configure one more server than MAXNS; the resolver is expected to keep only MAXNS.
    std::vector<std::string> domains{"example.com"};
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));
    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    std::vector<std::string> configured_servers;
    std::vector<std::string> configured_domains;
    __res_params configured_params;
    std::vector<ResolverStats> configured_stats;
    ASSERT_TRUE(GetResolverInfo(&configured_servers, &configured_domains, &configured_params,
            &configured_stats));
    EXPECT_EQ(static_cast<size_t>(MAXNS), configured_servers.size());

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
691 
base64Encode(const std::vector<uint8_t> & input)692 static std::string base64Encode(const std::vector<uint8_t>& input) {
693     size_t out_len;
694     EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));
695     // out_len includes the trailing NULL.
696     uint8_t output_bytes[out_len];
697     EXPECT_EQ(out_len - 1, EVP_EncodeBlock(output_bytes, input.data(), input.size()));
698     return std::string(reinterpret_cast<char*>(output_bytes));
699 }
700 
701 // Test what happens if the specified TLS server is nonexistent.
TEST_F(ResolverTest, GetHostByName_TlsMissing) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_srv = "53";
    const char* host_name = "tlsmissing.example.com.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    // There's nothing listening on this address, so validation will either fail or
    // hang.  Either way, queries will continue to flow to the DNSResponder.
    // The binder status is checked so a rejected registration is reported instead of
    // being silently discarded.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result = gethostbyname("tlsmissing");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    dns.stopServer();
}
725 
726 // Test what happens if the specified TLS server replies with garbage.
TEST_F(ResolverTest, GetHostByName_TlsBroken) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_srv = "53";
    const char* host_name1 = "tlsbroken1.example.com.";
    const char* host_name2 = "tlsbroken2.example.com.";
    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    // Owns a socket descriptor and closes it on scope exit. An ASSERT_* failure returns
    // early, and without this the socket bound to port 853 would leak into later tests
    // that bind the same address and port.
    struct ScopedFd {
        int fd;
        ~ScopedFd() {
            if (fd >= 0) close(fd);
        }
    };

    // Bind the specified private DNS socket but don't respond to any client sockets yet.
    ScopedFd listener{socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)};
    ASSERT_GE(listener.fd, 0);
    struct sockaddr_in tlsServer = {
        .sin_family = AF_INET,
        .sin_port = htons(853),
    };
    ASSERT_TRUE(inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
    ASSERT_FALSE(bind(listener.fd, reinterpret_cast<struct sockaddr*>(&tlsServer),
                      sizeof(tlsServer)));
    ASSERT_FALSE(listen(listener.fd, 1));

    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    // SetResolversForNetwork should have triggered a validation connection to this address.
    // NOTE(review): the listener is blocking and has no timeout, so this accept() hangs if
    // no validation connection is ever made.
    struct sockaddr_storage cliaddr;
    socklen_t sin_size = sizeof(cliaddr);
    ScopedFd client{accept(listener.fd, reinterpret_cast<struct sockaddr*>(&cliaddr),
                           &sin_size)};
    // accept() returns -1 on error; 0 is a valid descriptor.
    ASSERT_GE(client.fd, 0);

    // We've received the new file descriptor but not written to it or closed, so the
    // validation is still pending.  Queries should still flow correctly because the
    // server is not used until validation succeeds.
    const hostent* result;
    result = gethostbyname("tlsbroken1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Now we cause the validation to fail.
    const std::string garbage = "definitely not a valid TLS ServerHello";
    EXPECT_EQ(static_cast<ssize_t>(garbage.size()),
              write(client.fd, garbage.data(), garbage.size()));
    close(client.fd);
    client.fd = -1;  // Already closed; keep the guard from closing it again.

    // Validation failure shouldn't interfere with lookups, because lookups won't be sent
    // to the TLS server unless validation succeeds.
    result = gethostbyname("tlsbroken2");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.2", ToString(result));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    dns.stopServer();
    // |listener| is closed by its guard when the test returns.
}
781 
// Verifies that once a TLS frontend is validated, lookups are routed through it and keep
// going to it (and fail) after it stops. Once the private DNS server is removed, lookups
// go to the UDP endpoint again.
TEST_F(ResolverTest, GetHostByName_Tls) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name1 = "tls1.example.com.";
    const char* host_name2 = "tls2.example.com.";
    const char* host_name3 = "tls3.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
    dns.addMapping(host_name3, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    // Check the binder status instead of silently discarding it.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("tls1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    // Stop the TLS server.  Since it's already been validated, queries will
    // continue to be routed to it.
    tls.stopServer();

    result = gethostbyname("tls2");
    EXPECT_TRUE(result == nullptr);
    EXPECT_EQ(HOST_NOT_FOUND, h_errno);

    // Remove the TLS server setting.  Queries should now be routed to the
    // UDP endpoint.
    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());

    result = gethostbyname("tls3");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));

    dns.stopServer();
}
831 
// Verifies that pinning the frontend's correct SHA-256 fingerprint lets validation succeed
// and that the subsequent lookup is counted at the TLS frontend.
TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "tlsfingerprint.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    // Check the binder status instead of silently discarding it.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
            { base64Encode(tls.fingerprint()) });
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("tlsfingerprint");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    tls.stopServer();
    dns.stopServer();
}
864 
// Verifies that a wrong pinned fingerprint makes validation fail before any query is sent.
// Lookups then bypass the TLS frontend and are answered over UDP.
TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "badtlsfingerprint.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
    bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
    // Registration itself is still checked; the mismatch only shows up during validation.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
            { base64Encode(bad_fingerprint) });
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result;

    // The initial validation should fail at the fingerprint check before
    // issuing a query.
    EXPECT_FALSE(tls.waitForQueries(1, 500));

    result = gethostbyname("badtlsfingerprint");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // The query should have bypassed the TLS frontend, because validation
    // failed.
    EXPECT_FALSE(tls.waitForQueries(1, 500));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    tls.stopServer();
    dns.stopServer();
}
901 
902 // Test that we can pass two different fingerprints, and connection succeeds as long as
903 // at least one of them matches the server.
TEST_F(ResolverTest, GetHostByName_TwoTlsFingerprints) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "twotlsfingerprints.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
    bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
    // One wrong and one correct fingerprint; check the binder status instead of
    // silently discarding it.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
            { base64Encode(bad_fingerprint), base64Encode(tls.fingerprint()) });
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("twotlsfingerprints");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    tls.stopServer();
    dns.stopServer();
}
938 
// Verifies that after a validated server restarts with a certificate whose fingerprint no
// longer matches the pinned one, lookups fail instead of being answered.
TEST_F(ResolverTest, GetHostByName_TlsFingerprintGoesBad) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name1 = "tlsfingerprintgoesbad1.example.com.";
    const char* host_name2 = "tlsfingerprintgoesbad2.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    // Check the binder status instead of silently discarding it.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
            { base64Encode(tls.fingerprint()) });
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    const hostent* result;

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    result = gethostbyname("tlsfingerprintgoesbad1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2, 5000));

    // Restart the TLS server.  This will generate a new certificate whose fingerprint
    // no longer matches the stored fingerprint.
    tls.stopServer();
    // The restart must be checked. If it failed, the lookup below would fail simply
    // because nothing is listening, and the fingerprint mismatch would go untested.
    // EXPECT rather than ASSERT keeps the private DNS cleanup below reachable.
    EXPECT_TRUE(tls.startServer());

    result = gethostbyname("tlsfingerprintgoesbad2");
    // Nothing dereferences |result| afterwards, so EXPECT (not ASSERT) is enough and keeps
    // the cleanup below reachable on failure.
    EXPECT_TRUE(result == nullptr);
    EXPECT_EQ(HOST_NOT_FOUND, h_errno);

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    tls.stopServer();
    dns.stopServer();
}
982 
TEST_F(ResolverTest,GetHostByName_TlsFailover)983 TEST_F(ResolverTest, GetHostByName_TlsFailover) {
984     const char* listen_addr1 = "127.0.0.3";
985     const char* listen_addr2 = "127.0.0.4";
986     const char* listen_udp = "53";
987     const char* listen_tls = "853";
988     const char* host_name1 = "tlsfailover1.example.com.";
989     const char* host_name2 = "tlsfailover2.example.com.";
990     test::DNSResponder dns1(listen_addr1, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
991     test::DNSResponder dns2(listen_addr2, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
992     dns1.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
993     dns1.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
994     dns2.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.3");
995     dns2.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.4");
996     ASSERT_TRUE(dns1.startServer());
997     ASSERT_TRUE(dns2.startServer());
998     std::vector<std::string> servers = { listen_addr1, listen_addr2 };
999 
1000     test::DnsTlsFrontend tls1(listen_addr1, listen_tls, listen_addr1, listen_udp);
1001     test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
1002     ASSERT_TRUE(tls1.startServer());
1003     ASSERT_TRUE(tls2.startServer());
1004     auto rv = mNetdSrv->addPrivateDnsServer(listen_addr1, 853, "SHA-256",
1005             { base64Encode(tls1.fingerprint()) });
1006     rv = mNetdSrv->addPrivateDnsServer(listen_addr2, 853, "SHA-256",
1007             { base64Encode(tls2.fingerprint()) });
1008     ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
1009 
1010     const hostent* result;
1011 
1012     // Wait for validation to complete.
1013     EXPECT_TRUE(tls1.waitForQueries(1, 5000));
1014     EXPECT_TRUE(tls2.waitForQueries(1, 5000));
1015 
1016     result = gethostbyname("tlsfailover1");
1017     ASSERT_FALSE(result == nullptr);
1018     EXPECT_EQ("1.2.3.1", ToString(result));
1019 
1020     // Wait for query to get counted.
1021     EXPECT_TRUE(tls1.waitForQueries(2, 5000));
1022     // No new queries should have reached tls2.
1023     EXPECT_EQ(1, tls2.queries());
1024 
1025     // Stop tls1.  Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
1026     tls1.stopServer();
1027 
1028     result = gethostbyname("tlsfailover2");
1029     EXPECT_EQ("1.2.3.4", ToString(result));
1030 
1031     // Wait for query to get counted.
1032     EXPECT_TRUE(tls2.waitForQueries(2, 5000));
1033 
1034     // No additional queries should have reached the insecure servers.
1035     EXPECT_EQ(2U, dns1.queries().size());
1036     EXPECT_EQ(2U, dns2.queries().size());
1037 
1038     rv = mNetdSrv->removePrivateDnsServer(listen_addr1);
1039     rv = mNetdSrv->removePrivateDnsServer(listen_addr2);
1040     tls2.stopServer();
1041     dns1.stopServer();
1042     dns2.stopServer();
1043 }
1044 
// Verifies that getaddrinfo() resolves through a validated TLS frontend. Both the A and
// AAAA queries should be counted at the frontend.
TEST_F(ResolverTest, GetAddrInfo_Tls) {
    const char* listen_addr = "127.0.0.3";
    const char* listen_udp = "53";
    const char* listen_tls = "853";
    const char* host_name = "addrinfotls.example.com.";
    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    // Check the binder status instead of silently discarding it.
    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
            { base64Encode(tls.fingerprint()) });
    EXPECT_TRUE(rv.isOk());
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));

    // Wait for validation to complete.
    EXPECT_TRUE(tls.waitForQueries(1, 5000));

    dns.clearQueries();
    addrinfo* result = nullptr;
    EXPECT_EQ(0, getaddrinfo("addrinfotls", nullptr, nullptr, &result));
    size_t found = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, found);
    // Could be A or AAAA
    std::string result_str = ToString(result);
    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
        << ", result_str='" << result_str << "'";
    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
    // Only non-fatal EXPECT_* checks sit between getaddrinfo() and here, so this is
    // always reached.
    if (result) {
        freeaddrinfo(result);
        result = nullptr;
    }
    // Wait for both A and AAAA queries to get counted.
    EXPECT_TRUE(tls.waitForQueries(3, 5000));

    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
    EXPECT_TRUE(rv.isOk());
    tls.stopServer();
    dns.stopServer();
}
1086