• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #define LOG_TAG "resolv_integration_test"
19 
20 #include <android-base/logging.h>
21 #include <android-base/parseint.h>
22 #include <android-base/properties.h>
23 #include <android-base/result.h>
24 #include <android-base/stringprintf.h>
25 #include <android-base/unique_fd.h>
26 #include <android/multinetwork.h>  // ResNsendFlags
27 #include <arpa/inet.h>
28 #include <arpa/nameser.h>
29 #include <binder/ProcessState.h>
30 #include <cutils/sockets.h>
31 #include <gmock/gmock-matchers.h>
32 #include <gtest/gtest.h>
33 #include <netdb.h>
34 #include <netdutils/InternetAddresses.h>
35 #include <netdutils/NetworkConstants.h>  // SHA256_SIZE
36 #include <netdutils/ResponseCode.h>
37 #include <netdutils/Slice.h>
38 #include <netdutils/SocketOption.h>
39 #include <netdutils/Stopwatch.h>
40 #include <netinet/in.h>
41 #include <poll.h> /* poll */
42 #include <private/android_filesystem_config.h>
43 #include <resolv.h>
44 #include <stdarg.h>
45 #include <stdlib.h>
46 #include <sys/socket.h>
47 #include <sys/un.h>
48 #include <unistd.h>
49 
50 #include <algorithm>
51 #include <chrono>
52 #include <iterator>
53 #include <numeric>
54 #include <thread>
55 #include <unordered_set>
56 
57 #include <DnsProxydProtocol.h>  // NETID_USE_LOCAL_NAMESERVERS
58 #include <aidl/android/net/IDnsResolver.h>
59 #include <android/binder_manager.h>
60 #include <android/binder_process.h>
61 #include <bpf/BpfUtils.h>
62 #include <util.h>  // getApiLevel
63 #include "NetdClient.h"
64 #include "ResolverStats.h"
65 #include "netid_client.h"  // NETID_UNSET
66 #include "params.h"        // MAXNS
67 #include "stats.h"         // RCODE_TIMEOUT
68 #include "tests/dns_metrics_listener/dns_metrics_listener.h"
69 #include "tests/dns_responder/dns_responder.h"
70 #include "tests/dns_responder/dns_responder_client_ndk.h"
71 #include "tests/dns_responder/dns_tls_certificate.h"
72 #include "tests/dns_responder/dns_tls_frontend.h"
73 #include "tests/resolv_test_utils.h"
74 #include "tests/tun_forwarder.h"
75 #include "tests/unsolicited_listener/unsolicited_event_listener.h"
76 
// Valid VPN netId range is 100 ~ 65535; this one sits near the top of that range.
constexpr int TEST_VPN_NETID = 65502;
// 8 KiB. Presumably the size of raw DNS packet buffers used by tests below — confirm at use sites.
constexpr int MAXPACKET = 8 * 1024;

// System property keys for netd_native device-config flags. Each is a plain std::string so it can
// be handed directly to APIs taking `const std::string&` (e.g. ScopedSystemProperties).
const std::string kSortNameserversFlag = "persist.device_config.netd_native.sort_nameservers";
const std::string kDotConnectTimeoutMsFlag =
        "persist.device_config.netd_native.dot_connect_timeout_ms";
const std::string kDotAsyncHandshakeFlag = "persist.device_config.netd_native.dot_async_handshake";
const std::string kDotMaxretriesFlag = "persist.device_config.netd_native.dot_maxtries";
const std::string kDotRevalidationThresholdFlag =
        "persist.device_config.netd_native.dot_revalidation_threshold";
const std::string kDotXportUnusableThresholdFlag =
        "persist.device_config.netd_native.dot_xport_unusable_threshold";
const std::string kDotQueryTimeoutMsFlag = "persist.device_config.netd_native.dot_query_timeout_ms";
const std::string kDotValidationLatencyFactorFlag =
        "persist.device_config.netd_native.dot_validation_latency_factor";
const std::string kDotValidationLatencyOffsetMsFlag =
        "persist.device_config.netd_native.dot_validation_latency_offset_ms";
// Semi-public Bionic hook used by the NDK (frameworks/base/native/android/net.c). It is declared
// here only so these tests can call it directly, for convenience.
extern "C" {
int android_getaddrinfofornet(const char* hostname, const char* servname, const addrinfo* hints,
                              unsigned netid, unsigned mark, addrinfo** result);
}
100 
101 using namespace std::chrono_literals;
102 
103 using aidl::android::net::IDnsResolver;
104 using aidl::android::net::INetd;
105 using aidl::android::net::ResolverOptionsParcel;
106 using aidl::android::net::ResolverParamsParcel;
107 using aidl::android::net::metrics::INetdEventListener;
108 using aidl::android::net::resolv::aidl::DnsHealthEventParcel;
109 using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
110 using aidl::android::net::resolv::aidl::Nat64PrefixEventParcel;
111 using aidl::android::net::resolv::aidl::PrivateDnsValidationEventParcel;
112 using android::base::Error;
113 using android::base::GetProperty;
114 using android::base::ParseInt;
115 using android::base::Result;
116 using android::base::StringPrintf;
117 using android::base::unique_fd;
118 using android::net::ResolverStats;
119 using android::net::TunForwarder;
120 using android::net::metrics::DnsMetricsListener;
121 using android::net::resolv::aidl::UnsolicitedEventListener;
122 using android::netdutils::enableSockopt;
123 using android::netdutils::makeSlice;
124 using android::netdutils::ResponseCode;
125 using android::netdutils::ScopedAddrinfo;
126 using android::netdutils::Stopwatch;
127 using android::netdutils::toHex;
128 
129 // TODO: move into libnetdutils?
namespace {

// Calls getaddrinfo() and wraps the resulting list in a ScopedAddrinfo (assumed, per netdutils,
// to own and free the list). When getaddrinfo() fails the wrapper holds nullptr; the error code
// itself is discarded.
ScopedAddrinfo safe_getaddrinfo(const char* node, const char* service,
                                const struct addrinfo* hints) {
    addrinfo* result = nullptr;
    if (getaddrinfo(node, service, hints, &result) != 0) {
        result = nullptr;  // Should already be the case, but...
    }
    return ScopedAddrinfo(result);
}

// Same as safe_getaddrinfo(), but also returns the elapsed wall time of the lookup in
// milliseconds (Stopwatch microseconds divided by 1000, converted to int).
std::pair<ScopedAddrinfo, int> safe_getaddrinfo_time_taken(const char* node, const char* service,
                                                           const addrinfo& hints) {
    Stopwatch s;
    ScopedAddrinfo result = safe_getaddrinfo(node, service, &hints);
    return {std::move(result), s.timeTakenUs() / 1000};
}

// Expected counters for one nameserver, compared against the resolver's reported stats by
// ResolverTest::expectStatsFromGetResolverInfo(). Setters return *this so expectations chain:
//   NameserverStats("127.0.0.1").setSuccesses(1).setTimeouts(2)
struct NameserverStats {
    NameserverStats() = delete;
    // |server| is taken by value and moved into the member, so callers passing a temporary do
    // not pay for a second copy. Left non-explicit in case callers rely on implicit conversion.
    NameserverStats(std::string server) : server(std::move(server)) {}
    NameserverStats& setSuccesses(int val) {
        successes = val;
        return *this;
    }
    NameserverStats& setErrors(int val) {
        errors = val;
        return *this;
    }
    NameserverStats& setTimeouts(int val) {
        timeouts = val;
        return *this;
    }
    NameserverStats& setInternalErrors(int val) {
        internal_errors = val;
        return *this;
    }

    const std::string server;
    int successes = 0;
    int errors = 0;
    int timeouts = 0;
    int internal_errors = 0;
};

// RAII guard: sets system property |key| to |value| on construction and, on destruction, writes
// back the value read at construction. If the property had no value, GetProperty's "" default is
// what gets stored and later restored.
class ScopedSystemProperties {
  public:
    ScopedSystemProperties(const std::string& key, const std::string& value) : mStoredKey(key) {
        mStoredValue = android::base::GetProperty(key, "");
        android::base::SetProperty(key, value);
    }
    ~ScopedSystemProperties() { android::base::SetProperty(mStoredKey, mStoredValue); }

    // A copy would run the restore a second time (possibly after a later guard changed the
    // property again), so the guard is neither copyable nor movable.
    ScopedSystemProperties(const ScopedSystemProperties&) = delete;
    ScopedSystemProperties& operator=(const ScopedSystemProperties&) = delete;

  private:
    const std::string mStoredKey;
    std::string mStoredValue;
};

// True when getApiLevel() (util.h) reports API level 30 (Android R) or newer.
const bool isAtLeastR = (getApiLevel() >= 30);

}  // namespace
191 
192 class ResolverTest : public ::testing::Test {
193   public:
SetUpTestSuite()194     static void SetUpTestSuite() {
195         // Get binder service.
196         // Note that |mDnsClient| is not used for getting binder service in this static function.
197         // The reason is that wants to keep |mDnsClient| as a non-static data member. |mDnsClient|
198         // which sets up device network configuration could be independent from every test.
199         // TODO: Perhaps add a static function in resolv_test_binder_utils.{cpp,h} to get binder
200         // service.
201 
202         AIBinder* binder = AServiceManager_getService("dnsresolver");
203         sResolvBinder = ndk::SpAIBinder(binder);
204         auto resolvService = aidl::android::net::IDnsResolver::fromBinder(sResolvBinder);
205         ASSERT_NE(nullptr, resolvService.get());
206 
207         // Subscribe the death recipient to the service IDnsResolver for detecting Netd death.
208         // GTEST assertion macros are not invoked for generating a test failure in the death
209         // recipient because the macros can't indicate failed test if Netd died between tests.
210         // Moreover, continuing testing may have no meaningful after Netd death. Therefore, the
211         // death recipient aborts process by GTEST_LOG_(FATAL) once Netd died.
212         sResolvDeathRecipient = AIBinder_DeathRecipient_new([](void*) {
213             constexpr char errorMessage[] = "Netd died";
214             LOG(ERROR) << errorMessage;
215             GTEST_LOG_(FATAL) << errorMessage;
216         });
217         ASSERT_EQ(STATUS_OK, AIBinder_linkToDeath(binder, sResolvDeathRecipient, nullptr));
218 
219         // Subscribe the DNS listener for verifying DNS metrics event contents.
220         sDnsMetricsListener = ndk::SharedRefBase::make<DnsMetricsListener>(
221                 TEST_NETID /*monitor specific network*/);
222         ASSERT_TRUE(resolvService->registerEventListener(sDnsMetricsListener).isOk());
223 
224         // Subscribe the unsolicited event listener for verifying unsolicited event contents.
225         sUnsolicitedEventListener = ndk::SharedRefBase::make<UnsolicitedEventListener>(
226                 TEST_NETID /*monitor specific network*/);
227         ASSERT_TRUE(
228                 resolvService->registerUnsolicitedEventListener(sUnsolicitedEventListener).isOk());
229 
230         // Start the binder thread pool for listening DNS metrics events and receiving death
231         // recipient.
232         ABinderProcess_startThreadPool();
233     }
TearDownTestSuite()234     static void TearDownTestSuite() { AIBinder_DeathRecipient_delete(sResolvDeathRecipient); }
235 
236   protected:
SetUp()237     void SetUp() {
238         mDnsClient.SetUp();
239         sDnsMetricsListener->reset();
240         sUnsolicitedEventListener->reset();
241         mIsResolverOptionIPCSupported =
242                 DnsResponderClient::isRemoteVersionSupported(mDnsClient.resolvService(), 9);
243     }
244 
TearDown()245     void TearDown() {
246         // Ensure the dump works at the end of each test.
247         DumpResolverService();
248 
249         mDnsClient.TearDown();
250     }
251 
resetNetwork()252     void resetNetwork() {
253         mDnsClient.TearDown();
254         mDnsClient.SetupOemNetwork();
255     }
256 
StartDns(test::DNSResponder & dns,const std::vector<DnsRecord> & records)257     void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records) {
258         for (const auto& r : records) {
259             dns.addMapping(r.host_name, r.type, r.addr);
260         }
261 
262         ASSERT_TRUE(dns.startServer());
263         dns.clearQueries();
264     }
265 
DumpResolverService()266     void DumpResolverService() {
267         unique_fd fd(open("/dev/null", O_WRONLY));
268         EXPECT_EQ(mDnsClient.resolvService()->dump(fd, nullptr, 0), 0);
269 
270         const char* querylogCmd[] = {"querylog"};  // Keep it sync with DnsQueryLog::DUMP_KEYWORD.
271         EXPECT_EQ(mDnsClient.resolvService()->dump(fd, querylogCmd, std::size(querylogCmd)), 0);
272     }
273 
WaitForNat64Prefix(ExpectNat64PrefixStatus status,std::chrono::milliseconds timeout=std::chrono::milliseconds (1000))274     bool WaitForNat64Prefix(ExpectNat64PrefixStatus status,
275                             std::chrono::milliseconds timeout = std::chrono::milliseconds(1000)) {
276         return sDnsMetricsListener->waitForNat64Prefix(status, timeout) &&
277                sUnsolicitedEventListener->waitForNat64Prefix(
278                        status == EXPECT_FOUND
279                                ? IDnsResolverUnsolicitedEventListener::PREFIX_OPERATION_ADDED
280                                : IDnsResolverUnsolicitedEventListener::PREFIX_OPERATION_REMOVED,
281                        timeout);
282     }
283 
WaitForPrivateDnsValidation(std::string serverAddr,bool validated)284     bool WaitForPrivateDnsValidation(std::string serverAddr, bool validated) {
285         return sDnsMetricsListener->waitForPrivateDnsValidation(serverAddr, validated) &&
286                sUnsolicitedEventListener->waitForPrivateDnsValidation(
287                        serverAddr,
288                        validated ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
289                                  : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE);
290     }
291 
hasUncaughtPrivateDnsValidation(const std::string & serverAddr)292     bool hasUncaughtPrivateDnsValidation(const std::string& serverAddr) {
293         return sDnsMetricsListener->findValidationRecord(serverAddr) &&
294                sUnsolicitedEventListener->findValidationRecord(serverAddr);
295     }
296 
ExpectDnsEvent(int32_t eventType,int32_t returnCode,const std::string & hostname,const std::vector<std::string> & ipAddresses)297     void ExpectDnsEvent(int32_t eventType, int32_t returnCode, const std::string& hostname,
298                         const std::vector<std::string>& ipAddresses) {
299         const DnsMetricsListener::DnsEvent expect = {
300                 TEST_NETID, eventType,   returnCode,
301                 hostname,   ipAddresses, static_cast<int32_t>(ipAddresses.size())};
302         do {
303             // Blocking call until timeout.
304             const auto dnsEvent = sDnsMetricsListener->popDnsEvent();
305             ASSERT_TRUE(dnsEvent.has_value()) << "Expected DnsEvent " << expect;
306             if (dnsEvent.value() == expect) break;
307             LOG(INFO) << "Skip unexpected DnsEvent: " << dnsEvent.value();
308         } while (true);
309 
310         while (returnCode == 0 || returnCode == RCODE_TIMEOUT) {
311             // Blocking call until timeout.
312             Result<int> result = sUnsolicitedEventListener->popDnsHealthResult();
313             ASSERT_TRUE(result.ok()) << "Expected dns health result is " << returnCode;
314             if ((returnCode == 0 &&
315                  result.value() == IDnsResolverUnsolicitedEventListener::DNS_HEALTH_RESULT_OK) ||
316                 (returnCode == RCODE_TIMEOUT &&
317                  result.value() ==
318                          IDnsResolverUnsolicitedEventListener::DNS_HEALTH_RESULT_TIMEOUT)) {
319                 break;
320             }
321             LOG(INFO) << "Skip unexpected dns health result:" << result.value();
322         }
323     }
324 
325     enum class StatsCmp { LE, EQ };
326 
expectStatsNotGreaterThan(const std::vector<NameserverStats> & nameserversStats)327     bool expectStatsNotGreaterThan(const std::vector<NameserverStats>& nameserversStats) {
328         return expectStatsFromGetResolverInfo(nameserversStats, StatsCmp::LE);
329     }
330 
expectStatsEqualTo(const std::vector<NameserverStats> & nameserversStats)331     bool expectStatsEqualTo(const std::vector<NameserverStats>& nameserversStats) {
332         return expectStatsFromGetResolverInfo(nameserversStats, StatsCmp::EQ);
333     }
334 
expectStatsFromGetResolverInfo(const std::vector<NameserverStats> & nameserversStats,const StatsCmp cmp)335     bool expectStatsFromGetResolverInfo(const std::vector<NameserverStats>& nameserversStats,
336                                         const StatsCmp cmp) {
337         std::vector<std::string> res_servers;
338         std::vector<std::string> res_domains;
339         std::vector<std::string> res_tls_servers;
340         res_params res_params;
341         std::vector<ResolverStats> res_stats;
342         int wait_for_pending_req_timeout_count;
343 
344         if (!DnsResponderClient::GetResolverInfo(mDnsClient.resolvService(), TEST_NETID,
345                                                  &res_servers, &res_domains, &res_tls_servers,
346                                                  &res_params, &res_stats,
347                                                  &wait_for_pending_req_timeout_count)) {
348             ADD_FAILURE() << "GetResolverInfo failed";
349             return false;
350         }
351 
352         if (res_servers.size() != res_stats.size()) {
353             ADD_FAILURE() << fmt::format("res_servers.size() != res_stats.size(): {} != {}",
354                                          res_servers.size(), res_stats.size());
355             return false;
356         }
357         if (res_servers.size() != nameserversStats.size()) {
358             ADD_FAILURE() << fmt::format("res_servers.size() != nameserversStats.size(): {} != {}",
359                                          res_servers.size(), nameserversStats.size());
360             return false;
361         }
362 
363         for (const auto& stats : nameserversStats) {
364             SCOPED_TRACE(stats.server);
365             const auto it = std::find(res_servers.begin(), res_servers.end(), stats.server);
366             if (it == res_servers.end()) {
367                 ADD_FAILURE() << fmt::format("nameserver {} not found in the list {{{}}}",
368                                              stats.server, fmt::join(res_servers, ", "));
369                 return false;
370             }
371             const int index = std::distance(res_servers.begin(), it);
372 
373             // The check excludes rtt_avg, last_sample_time, and usable since they will be obsolete
374             // after |res_stats| is retrieved from NetConfig.dnsStats rather than NetConfig.nsstats.
375             switch (cmp) {
376                 case StatsCmp::EQ:
377                     EXPECT_EQ(res_stats[index].successes, stats.successes);
378                     EXPECT_EQ(res_stats[index].errors, stats.errors);
379                     EXPECT_EQ(res_stats[index].timeouts, stats.timeouts);
380                     EXPECT_EQ(res_stats[index].internal_errors, stats.internal_errors);
381                     break;
382                 case StatsCmp::LE:
383                     EXPECT_LE(res_stats[index].successes, stats.successes);
384                     EXPECT_LE(res_stats[index].errors, stats.errors);
385                     EXPECT_LE(res_stats[index].timeouts, stats.timeouts);
386                     EXPECT_LE(res_stats[index].internal_errors, stats.internal_errors);
387                     break;
388                 default:
389                     ADD_FAILURE() << "Unknown comparator " << static_cast<int>(cmp);
390                     return false;
391             }
392         }
393 
394         return true;
395     }
396 
397     // Since there's no way to terminate private DNS validation threads at any time. Tests that
398     // focus on the results of private DNS validation can interfere with each other if they use the
399     // same IP address for test servers. getUniqueIPv4Address() is a workaround to reduce the
400     // possibility of tests being flaky. A feasible solution is to forbid the validation threads,
401     // which are considered as outdated (e.g. switch the resolver to private DNS OFF mode), updating
402     // the result to the PrivateDnsConfiguration instance.
getUniqueIPv4Address()403     static std::string getUniqueIPv4Address() {
404         static int counter = 0;
405         return fmt::format("127.0.100.{}", (++counter & 0xff));
406     }
407 
408     DnsResponderClient mDnsClient;
409 
410     bool mIsResolverOptionIPCSupported = false;
411 
412     // Use a shared static DNS listener for all tests to avoid registering lots of listeners
413     // which may be released late until process terminated. Currently, registered DNS listener
414     // is removed by binder death notification which is fired when the process hosting an
415     // IBinder has gone away. If every test in ResolverTest registers its DNS listener, Netd
416     // may temporarily hold lots of dead listeners until the unit test process terminates.
417     // TODO: Perhaps add an unregistering listener binder call or fork a listener process which
418     // could be terminated earlier.
419     static std::shared_ptr<DnsMetricsListener>
420             sDnsMetricsListener;  // Initialized in SetUpTestSuite.
421 
422     inline static std::shared_ptr<UnsolicitedEventListener>
423             sUnsolicitedEventListener;  // Initialized in SetUpTestSuite.
424 
425     // Use a shared static death recipient to monitor the service death. The static death
426     // recipient could monitor the death not only during the test but also between tests.
427     static AIBinder_DeathRecipient* sResolvDeathRecipient;  // Initialized in SetUpTestSuite.
428 
429     // The linked AIBinder_DeathRecipient will be automatically unlinked if the binder is deleted.
430     // The binder needs to be retained throughout tests.
431     static ndk::SpAIBinder sResolvBinder;
432 };
433 
434 // Initialize static member of class.
435 std::shared_ptr<DnsMetricsListener> ResolverTest::sDnsMetricsListener;
436 AIBinder_DeathRecipient* ResolverTest::sResolvDeathRecipient;
437 ndk::SpAIBinder ResolverTest::sResolvBinder;
438 
TEST_F(ResolverTest, GetHostByName) {
    constexpr char nonexistent_host_name[] = "nonexistent.example.com.";

    test::DNSResponder dns;
    StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, "1.2.3.3"}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // Unknown short name: a single A query goes out for the search-domain-qualified name, and
    // the lookup fails with HOST_NOT_FOUND.
    const hostent* he = gethostbyname("nonexistent");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, nonexistent_host_name));
    ASSERT_TRUE(he == nullptr);
    EXPECT_EQ(HOST_NOT_FOUND, h_errno);

    // Known short name: one A query for hello.example.com., answered with exactly one IPv4
    // address.
    dns.clearQueries();
    he = gethostbyname("hello");
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, kHelloExampleCom));
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);
}
461 
TEST_F(ResolverTest, GetHostByName_NULL) {
    // Regression test. Most libc implementations would simply crash on gethostbyname(NULL).
    // Bionic instead serializes the null argument over dnsproxyd, which used to crash the
    // server side. Each degenerate name below must fail cleanly with HOST_NOT_FOUND.
    const char* const kNames[] = {nullptr, "", "^"};
    for (const char* name : kNames) {
        const char* const label = (name == nullptr) ? "NULL" : name;
        SCOPED_TRACE(fmt::format("gethostbyname({})", label));
        const hostent* he = gethostbyname(name);
        EXPECT_TRUE(he == nullptr);
        EXPECT_EQ(HOST_NOT_FOUND, h_errno);
    }
}
474 
TEST_F(ResolverTest,GetHostByName_cnames)475 TEST_F(ResolverTest, GetHostByName_cnames) {
476     constexpr char host_name[] = "host.example.com.";
477     size_t cnamecount = 0;
478     test::DNSResponder dns;
479 
480     const std::vector<DnsRecord> records = {
481             {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
482             {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
483             {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
484             {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
485             {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
486             {"e.example.com.", ns_type::ns_t_cname, host_name},
487             {host_name, ns_type::ns_t_a, "1.2.3.3"},
488             {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
489     };
490     StartDns(dns, records);
491     ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
492 
493     // using gethostbyname2() to resolve ipv4 hello.example.com. to 1.2.3.3
494     // Ensure the v4 address and cnames are correct
495     const hostent* result;
496     result = gethostbyname2("hello", AF_INET);
497     ASSERT_FALSE(result == nullptr);
498 
499     for (int i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
500         std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
501         EXPECT_EQ(result->h_aliases[i], domain_name);
502         cnamecount++;
503     }
504     // The size of "Non-cname type" record in DNS records is 2
505     ASSERT_EQ(cnamecount, records.size() - 2);
506     ASSERT_EQ(4, result->h_length);
507     ASSERT_FALSE(result->h_addr_list[0] == nullptr);
508     EXPECT_EQ("1.2.3.3", ToString(result));
509     EXPECT_TRUE(result->h_addr_list[1] == nullptr);
510     EXPECT_EQ(1U, dns.queries().size()) << dns.dumpQueries();
511 
512     // using gethostbyname2() to resolve ipv6 hello.example.com. to 2001:db8::42
513     // Ensure the v6 address and cnames are correct
514     cnamecount = 0;
515     dns.clearQueries();
516     result = gethostbyname2("hello", AF_INET6);
517     for (unsigned i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
518         std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
519         EXPECT_EQ(result->h_aliases[i], domain_name);
520         cnamecount++;
521     }
522     // The size of "Non-cname type" DNS record in records is 2
523     ASSERT_EQ(cnamecount, records.size() - 2);
524     ASSERT_FALSE(result == nullptr);
525     ASSERT_EQ(16, result->h_length);
526     ASSERT_FALSE(result->h_addr_list[0] == nullptr);
527     EXPECT_EQ("2001:db8::42", ToString(result));
528     EXPECT_TRUE(result->h_addr_list[1] == nullptr);
529 }
530 
TEST_F(ResolverTest, GetHostByName_cnamesInfiniteLoop) {
    // hello.example.com. and a.example.com. are CNAMEs of each other. Lookups for either address
    // family are expected to fail instead of following the cycle.
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
            {"a.example.com.", ns_type::ns_t_cname, kHelloExampleCom},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    ASSERT_TRUE(gethostbyname2("hello", AF_INET) == nullptr);

    dns.clearQueries();
    ASSERT_TRUE(gethostbyname2("hello", AF_INET6) == nullptr);
}
548 
TEST_F(ResolverTest, GetHostByName_localhost) {
    constexpr char name_camelcase[] = "LocalHost";
    constexpr char name_ip6_dot[] = "ip6-localhost.";
    constexpr char name_ip6_fqdn[] = "ip6-localhost.example.com.";

    // A nameserver with no records. The hosts-file lookups below should never reach it.
    test::DNSResponder dns;
    StartDns(dns, {});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // localhost comes from /etc/hosts, so no DNS queries are expected.
    const hostent* he = gethostbyname(kLocalHost);
    EXPECT_TRUE(dns.queries().empty()) << dns.dumpQueries();
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ(kLocalHostAddr, ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    // Hostname matching in the hosts file ignores case.
    he = gethostbyname(name_camelcase);
    EXPECT_TRUE(dns.queries().empty()) << dns.dumpQueries();
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ(kLocalHostAddr, ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    // The hosts file also lists ip6-localhost, but gethostbyname() does not return that entry.
    // Instead it issues A queries for the name alone and with the search domain appended.
    // This legacy behavior is intentionally kept (new code should use getaddrinfo()) and is
    // pinned here.
    dns.clearQueries();
    he = gethostbyname(kIp6LocalHost);
    EXPECT_EQ(2U, dns.queries().size()) << dns.dumpQueries();
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, name_ip6_dot)) << dns.dumpQueries();
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, name_ip6_fqdn)) << dns.dumpQueries();
    ASSERT_TRUE(he == nullptr);

    // gethostbyname2(AF_INET6) does resolve ip6-localhost to ::1 from the hosts file, again
    // without DNS traffic.
    dns.clearQueries();
    he = gethostbyname2(kIp6LocalHost, AF_INET6);
    EXPECT_TRUE(dns.queries().empty()) << dns.dumpQueries();
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(16, he->h_length);
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ(kIp6LocalHostAddr, ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);
}
601 
TEST_F(ResolverTest, GetHostByName_numeric) {
    // A nameserver with no records. Numeric input is expected to be handled without it, except
    // for the scoped-address case at the end.
    test::DNSResponder dns;
    StartDns(dns, {});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // IPv4 literal: no DNS queries, and exactly one 4-byte address.
    constexpr char numeric_v4[] = "192.168.0.1";
    const hostent* he = gethostbyname(numeric_v4);
    EXPECT_EQ(0U, dns.queries().size());
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(4, he->h_length);  // v4
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ(numeric_v4, ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    // gethostbyname() recognizes an IPv6 literal and fails without querying DNS.
    constexpr char numeric_v6[] = "2001:db8::42";
    dns.clearQueries();
    he = gethostbyname(numeric_v6);
    EXPECT_EQ(0U, dns.queries().size());
    EXPECT_TRUE(he == nullptr);

    // gethostbyname2(AF_INET6) accepts the same literal, also with no DNS queries.
    dns.clearQueries();
    he = gethostbyname2(numeric_v6, AF_INET6);
    EXPECT_EQ(0U, dns.queries().size());
    ASSERT_TRUE(he != nullptr);
    ASSERT_EQ(16, he->h_length);  // v6
    ASSERT_TRUE(he->h_addr_list[0] != nullptr);
    EXPECT_EQ(numeric_v6, ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    // Scoped IPv6 literals work with getaddrinfo(), but gethostbyname2() does not understand
    // them: it sends two DNS queries and then fails. This hardly ever happens and is not worth
    // fixing. The current (bogus) behavior is pinned to catch regressions such as crashes or
    // leaks.
    constexpr char numeric_v6_scope[] = "fe80::1%lo";
    dns.clearQueries();
    he = gethostbyname2(numeric_v6_scope, AF_INET6);
    EXPECT_EQ(2U, dns.queries().size());  // OUCH!
    ASSERT_TRUE(he == nullptr);
}
646 
// The RESOLVER_PARAMS_* constants are used as indices into the resolver parameter vector (see
// kDefaultParams[IDnsResolver::RESOLVER_PARAMS_...]). Together they must therefore cover exactly
// 0 .. RESOLVER_PARAMS_COUNT - 1.
TEST_F(ResolverTest, BinderSerialization) {
    std::vector<int> indices{
            IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY,
            IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
            IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES,
            IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES,
            IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC,
            IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT,
    };
    const int num_indices = static_cast<int>(indices.size());
    EXPECT_EQ(num_indices, IDnsResolver::RESOLVER_PARAMS_COUNT);

    // Once sorted, a dense, duplicate-free set satisfies indices[k] == k at every position k.
    std::sort(indices.begin(), indices.end());
    int expected = 0;
    for (const int index : indices) {
        EXPECT_EQ(index, expected);
        ++expected;
    }
}
663 
// Resolves one name with gethostbyname() against four configured servers. It then reads the
// resolver configuration back through the resolver service and checks it matches what was set.
TEST_F(ResolverTest, GetHostByName_Binder) {
    std::vector<std::string> domains = {"example.com"};
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<DnsResponderClient::Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupDNSServers(4, mappings, &dns, &servers));
    ASSERT_EQ(1U, mappings.size());
    const DnsResponderClient::Mapping& mapping = mappings[0];

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, domains, kDefaultParams));

    const hostent* result = gethostbyname(mapping.host.c_str());
    // Seed with size_t{0}. std::accumulate's running total has the type of the initial value, so
    // a plain 0 would convert every partial sum back to int.
    const size_t total_queries = std::accumulate(
            dns.begin(), dns.end(), size_t{0}, [&mapping](size_t total, const auto& d) {
                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
            });

    EXPECT_LE(1U, total_queries);
    ASSERT_FALSE(result == nullptr);
    ASSERT_EQ(4, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ(mapping.ip4, ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);

    // Read the configuration back from the resolver service and compare it with what was set.
    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count = 0;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
    EXPECT_EQ(servers.size(), res_servers.size());
    EXPECT_EQ(domains.size(), res_domains.size());
    EXPECT_EQ(0U, res_tls_servers.size());
    ASSERT_EQ(static_cast<size_t>(IDnsResolver::RESOLVER_PARAMS_COUNT), kDefaultParams.size());
    EXPECT_EQ(kDefaultParams[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY],
              res_params.sample_validity);
    EXPECT_EQ(kDefaultParams[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
              res_params.success_threshold);
    EXPECT_EQ(kDefaultParams[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
    EXPECT_EQ(kDefaultParams[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
    EXPECT_EQ(kDefaultParams[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC],
              res_params.base_timeout_msec);
    EXPECT_EQ(servers.size(), res_stats.size());

    EXPECT_THAT(res_servers, testing::UnorderedElementsAreArray(servers));
    EXPECT_THAT(res_domains, testing::UnorderedElementsAreArray(domains));
}
715 
// Checks three things:
//   - getaddrinfo("howdy") resolves through the search domain to howdy.example.com.
//   - A repeated lookup is served from the cache.
//   - After switching to a second nameserver, the first one is no longer queried.
TEST_F(ResolverTest, GetAddrInfo) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr char listen_addr2[] = "127.0.0.5";
    constexpr char host_name[] = "howdy.example.com.";

    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    test::DNSResponder dns(listen_addr);
    test::DNSResponder dns2(listen_addr2);
    StartDns(dns, records);
    StartDns(dns2, records);

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr}));
    dns.clearQueries();
    dns2.clearQueries();

    ScopedAddrinfo result = safe_getaddrinfo("howdy", nullptr, nullptr);
    EXPECT_TRUE(result != nullptr);
    size_t found = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, found);
    // Could be A or AAAA
    std::string result_str = ToString(result);
    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
            << ", result_str='" << result_str << "'";

    // Verify that the name is cached.
    size_t old_found = found;
    result = safe_getaddrinfo("howdy", nullptr, nullptr);
    EXPECT_TRUE(result != nullptr);
    found = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, found);
    EXPECT_EQ(old_found, found);
    result_str = ToString(result);
    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4") << result_str;

    // Switch to the second nameserver. The first one must not be queried any more. Whether dns2
    // sees a query depends on whether the cache survives the server change. That count is
    // intentionally not asserted, because a lower bound of 0 on a size_t can never fail.
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr2}));
    dns.clearQueries();

    result = safe_getaddrinfo("howdy", nullptr, nullptr);
    EXPECT_TRUE(result != nullptr);
    found = GetNumQueries(dns, host_name);
    EXPECT_EQ(0U, found);

    // Could be A or AAAA
    result_str = ToString(result);
    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
            << ", result_str='" << result_str << "'";
}
770 
// With an AF_INET hint, "hello" resolves via the search domain using exactly one DNS query.
TEST_F(ResolverTest, GetAddrInfoV4) {
    test::DNSResponder dns;
    StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, "1.2.3.5"}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    const addrinfo v4_only = {.ai_family = AF_INET};
    ScopedAddrinfo addrs = safe_getaddrinfo("hello", nullptr, &v4_only);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_EQ(1U, GetNumQueries(dns, kHelloExampleCom));
    EXPECT_EQ("1.2.3.5", ToString(addrs));
}
782 
// kLocalHost and kIp6LocalHost are resolved from /etc/hosts, so looking them up must not send any
// DNS query.
TEST_F(ResolverTest, GetAddrInfo_localhost) {
    // No-op nameserver, present only so we can check that it is never queried.
    test::DNSResponder dns;
    StartDns(dns, {});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // IPv4 loopback name.
    ScopedAddrinfo addrs = safe_getaddrinfo(kLocalHost, nullptr, nullptr);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_TRUE(dns.queries().empty()) << dns.dumpQueries();
    EXPECT_EQ(kLocalHostAddr, ToString(addrs));

    // IPv6 loopback name.
    addrs = safe_getaddrinfo(kIp6LocalHost, nullptr, nullptr);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_TRUE(dns.queries().empty()) << dns.dumpQueries();
    EXPECT_EQ(kIp6LocalHostAddr, ToString(addrs));
}
801 
// getaddrinfo() with SOCK_PACKET must fail with EAI_NODATA and return no list. The hint is
// well-formed, but the socket type is not supported, so the query is never sent.
TEST_F(ResolverTest, GetAddrInfo_InvalidSocketType) {
    test::DNSResponder dns;
    StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, "1.2.3.5"}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // TODO: Test other invalid socket types.
    const addrinfo hints = {
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_PACKET,
    };
    addrinfo* res = nullptr;
    const int rv = getaddrinfo("hello", nullptr, &hints, &res);
    ScopedAddrinfo res_owner(res);  // Frees the list if one was unexpectedly returned.
    EXPECT_EQ(EAI_NODATA, rv);
    EXPECT_EQ(nullptr, res);
}
819 
820 // Verify if the resolver correctly handle multiple queries simultaneously
821 // step 1: set dns server#1 into deferred responding mode.
822 // step 2: thread#1 query "hello.example.com." --> resolver send query to server#1.
823 // step 3: thread#2 query "hello.example.com." --> resolver hold the request and wait for
824 //           response of previous pending query sent by thread#1.
825 // step 4: thread#3 query "konbanha.example.com." --> resolver send query to server#3. Server
826 //           respond to resolver immediately.
827 // step 5: check if server#1 get 1 query by thread#1, server#2 get 0 query, server#3 get 1 query.
828 // step 6: resume dns server#1 to respond dns query in step#2.
829 // step 7: thread#1 and #2 should get returned from DNS query after step#6. Also, check the
830 //           number of queries in server#2 is 0 to ensure thread#2 does not wake up unexpectedly
831 //           before signaled by thread#1.
TEST_F(ResolverTest,GetAddrInfoV4_deferred_resp)832 TEST_F(ResolverTest, GetAddrInfoV4_deferred_resp) {
833     const char* listen_addr1 = "127.0.0.9";
834     const char* listen_addr2 = "127.0.0.10";
835     const char* listen_addr3 = "127.0.0.11";
836     const char* listen_srv = "53";
837     const char* host_name_deferred = "hello.example.com.";
838     const char* host_name_normal = "konbanha.example.com.";
839     test::DNSResponder dns1(listen_addr1, listen_srv, ns_rcode::ns_r_servfail);
840     test::DNSResponder dns2(listen_addr2, listen_srv, ns_rcode::ns_r_servfail);
841     test::DNSResponder dns3(listen_addr3, listen_srv, ns_rcode::ns_r_servfail);
842     dns1.addMapping(host_name_deferred, ns_type::ns_t_a, "1.2.3.4");
843     dns2.addMapping(host_name_deferred, ns_type::ns_t_a, "1.2.3.4");
844     dns3.addMapping(host_name_normal, ns_type::ns_t_a, "1.2.3.5");
845     ASSERT_TRUE(dns1.startServer());
846     ASSERT_TRUE(dns2.startServer());
847     ASSERT_TRUE(dns3.startServer());
848     const std::vector<std::string> servers_for_t1 = {listen_addr1};
849     const std::vector<std::string> servers_for_t2 = {listen_addr2};
850     const std::vector<std::string> servers_for_t3 = {listen_addr3};
851     addrinfo hints = {.ai_family = AF_INET};
852     const std::vector<int> params = {300, 25, 8, 8, 5000};
853     bool t3_task_done = false;
854 
855     dns1.setDeferredResp(true);
856     std::thread t1([&, this]() {
857         ASSERT_TRUE(
858                 mDnsClient.SetResolversForNetwork(servers_for_t1, kDefaultSearchDomains, params));
859         ScopedAddrinfo result = safe_getaddrinfo(host_name_deferred, nullptr, &hints);
860         // t3's dns query should got returned first
861         EXPECT_TRUE(t3_task_done);
862         EXPECT_EQ(1U, GetNumQueries(dns1, host_name_deferred));
863         EXPECT_TRUE(result != nullptr);
864         EXPECT_EQ("1.2.3.4", ToString(result));
865     });
866 
867     // ensuring t1 and t2 handler functions are processed in order
868     usleep(100 * 1000);
869     std::thread t2([&, this]() {
870         ASSERT_TRUE(
871                 mDnsClient.SetResolversForNetwork(servers_for_t2, kDefaultSearchDomains, params));
872         ScopedAddrinfo result = safe_getaddrinfo(host_name_deferred, nullptr, &hints);
873         EXPECT_TRUE(t3_task_done);
874         EXPECT_EQ(0U, GetNumQueries(dns2, host_name_deferred));
875         EXPECT_TRUE(result != nullptr);
876         EXPECT_EQ("1.2.3.4", ToString(result));
877 
878         std::vector<std::string> res_servers;
879         std::vector<std::string> res_domains;
880         std::vector<std::string> res_tls_servers;
881         res_params res_params;
882         std::vector<ResolverStats> res_stats;
883         int wait_for_pending_req_timeout_count;
884         ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
885                 mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains,
886                 &res_tls_servers, &res_params, &res_stats, &wait_for_pending_req_timeout_count));
887         EXPECT_EQ(0, wait_for_pending_req_timeout_count);
888     });
889 
890     // ensuring t2 and t3 handler functions are processed in order
891     usleep(100 * 1000);
892     std::thread t3([&, this]() {
893         ASSERT_TRUE(
894                 mDnsClient.SetResolversForNetwork(servers_for_t3, kDefaultSearchDomains, params));
895         ScopedAddrinfo result = safe_getaddrinfo(host_name_normal, nullptr, &hints);
896         EXPECT_EQ(1U, GetNumQueries(dns1, host_name_deferred));
897         EXPECT_EQ(0U, GetNumQueries(dns2, host_name_deferred));
898         EXPECT_EQ(1U, GetNumQueries(dns3, host_name_normal));
899         EXPECT_TRUE(result != nullptr);
900         EXPECT_EQ("1.2.3.5", ToString(result));
901 
902         t3_task_done = true;
903         dns1.setDeferredResp(false);
904     });
905     t3.join();
906     t1.join();
907     t2.join();
908 }
909 
// A six-link CNAME chain (hello -> a -> b -> c -> d -> e -> host) must be followed for both the
// AF_INET and the AF_INET6 lookup.
TEST_F(ResolverTest, GetAddrInfo_cnames) {
    constexpr char host_name[] = "host.example.com.";
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
            {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
            {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
            {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
            {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
            {"e.example.com.", ns_type::ns_t_cname, host_name},
            {host_name, ns_type::ns_t_a, "1.2.3.3"},
            {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // IPv4: the chain ends at host_name's A record.
    const addrinfo v4_hints = {.ai_family = AF_INET};
    ScopedAddrinfo addrs = safe_getaddrinfo("hello", nullptr, &v4_hints);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_EQ("1.2.3.3", ToString(addrs));

    // IPv6: the chain ends at host_name's AAAA record.
    dns.clearQueries();
    const addrinfo v6_hints = {.ai_family = AF_INET6};
    addrs = safe_getaddrinfo("hello", nullptr, &v6_hints);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_EQ("2001:db8::42", ToString(addrs));
}
937 
// A CNAME whose target (a.example.com.) has no records on the server must make both the AF_INET
// and the AF_INET6 lookup fail.
TEST_F(ResolverTest, GetAddrInfo_cnamesNoIpAddress) {
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    const addrinfo v4_hints = {.ai_family = AF_INET};
    ScopedAddrinfo addrs = safe_getaddrinfo("hello", nullptr, &v4_hints);
    EXPECT_TRUE(addrs == nullptr);

    dns.clearQueries();
    const addrinfo v6_hints = {.ai_family = AF_INET6};
    addrs = safe_getaddrinfo("hello", nullptr, &v6_hints);
    EXPECT_TRUE(addrs == nullptr);
}
955 
// A CNAME whose rdata (".!#?") is not a usable domain name must make both the AF_INET and the
// AF_INET6 lookup fail.
TEST_F(ResolverTest, GetAddrInfo_cnamesIllegalRdata) {
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, ".!#?"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    const addrinfo v4_hints = {.ai_family = AF_INET};
    ScopedAddrinfo addrs = safe_getaddrinfo("hello", nullptr, &v4_hints);
    EXPECT_TRUE(addrs == nullptr);

    dns.clearQueries();
    const addrinfo v6_hints = {.ai_family = AF_INET6};
    addrs = safe_getaddrinfo("hello", nullptr, &v6_hints);
    EXPECT_TRUE(addrs == nullptr);
}
973 
// DNS names are case-insensitive. After "howdy" is resolved and cached:
//   - A lookup of "HOWDY" must be served from the cache, with no query for HOWDY.example.com.
//   - "HowDY" must still yield the address of the lower-case record.
TEST_F(ResolverTest, GetAddrInfoForCaseInSensitiveDomains) {
    test::DNSResponder dns;
    const char* host_name = "howdy.example.com.";
    const char* host_name2 = "HOWDY.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
            {host_name2, ns_type::ns_t_a, "1.2.3.5"},
            {host_name2, ns_type::ns_t_aaaa, "::1.2.3.5"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // The answer may be either the A or the AAAA record of the lower-case name.
    const auto is_lower_case_answer = [](const std::string& s) {
        return s == "1.2.3.4" || s == "::1.2.3.4";
    };

    // First lookup populates the cache.
    ScopedAddrinfo lower = safe_getaddrinfo("howdy", nullptr, nullptr);
    EXPECT_TRUE(lower != nullptr);
    const size_t queries_after_lower = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, queries_after_lower);
    EXPECT_TRUE(is_lower_case_answer(ToString(lower)));

    // Upper-case lookup must hit the cache: no new query for howdy.example.com.
    ScopedAddrinfo upper = safe_getaddrinfo("HOWDY", nullptr, nullptr);
    EXPECT_TRUE(upper != nullptr);
    const size_t queries_after_upper = GetNumQueries(dns, host_name);
    EXPECT_LE(1U, queries_after_upper);
    EXPECT_EQ(queries_after_lower, queries_after_upper);

    // If names were compared case-sensitively, HOWDY.example.com. would have been queried at
    // least once.
    EXPECT_EQ(0U, GetNumQueries(dns, host_name2));
    EXPECT_TRUE(is_lower_case_answer(ToString(upper)));

    // A mixed-case spelling that no record uses still yields the cached address.
    ScopedAddrinfo mixed = safe_getaddrinfo("HowDY", nullptr, nullptr);
    EXPECT_TRUE(mixed != nullptr);
    EXPECT_TRUE(is_lower_case_answer(ToString(mixed)));
}
1018 
// With three search domains configured, "nihao" must be resolved as nihao.example2.com. with
// exactly one A query for that name.
TEST_F(ResolverTest, MultidomainResolution) {
    constexpr char host_name[] = "nihao.example2.com.";
    constexpr char kServer[] = "127.0.0.6";
    std::vector<std::string> searchDomains = {"example1.com", "example2.com", "example3.com"};

    test::DNSResponder dns(kServer);
    StartDns(dns, {{host_name, ns_type::ns_t_a, "1.2.3.3"}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({kServer}, searchDomains));

    const hostent* he = gethostbyname("nihao");

    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);
}
1036 
// A scoped numeric IPv6 address is parsed locally. With AI_NUMERICHOST, a non-numeric name fails
// immediately. Neither case may send a DNS query.
TEST_F(ResolverTest, GetAddrInfoV6_numeric) {
    constexpr char host_name[] = "ohayou.example.com.";
    constexpr char numeric_addr[] = "fe80::1%lo";

    // The server is configured never to respond, so nothing may depend on it answering.
    test::DNSResponder dns;
    dns.setResponseProbability(0.0);
    StartDns(dns, {{host_name, ns_type::ns_t_aaaa, "2001:db8::5"}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    const addrinfo v6_hints = {.ai_family = AF_INET6};
    ScopedAddrinfo addrs = safe_getaddrinfo(numeric_addr, nullptr, &v6_hints);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_EQ(numeric_addr, ToString(addrs));
    EXPECT_TRUE(dns.queries().empty());  // Ensure no DNS queries were sent out

    // A non-numeric hostname with AI_NUMERICHOST must fail without sending out a DNS query.
    const addrinfo numeric_only_hints = {.ai_flags = AI_NUMERICHOST, .ai_family = AF_INET6};
    addrs = safe_getaddrinfo(host_name, nullptr, &numeric_only_hints);
    EXPECT_TRUE(addrs == nullptr);
    EXPECT_TRUE(dns.queries().empty());  // Ensure no DNS queries were sent out
}
1059 
// Once dns0 has failed every sample, it must be skipped, and the next lookup must go only to
// dns1.
TEST_F(ResolverTest, GetAddrInfoV6_failing) {
    constexpr char listen_addr0[] = "127.0.0.7";
    constexpr char listen_addr1[] = "127.0.0.8";
    const char* host_name = "ohayou.example.com.";

    test::DNSResponder dns0(listen_addr0);
    test::DNSResponder dns1(listen_addr1);
    dns0.setResponseProbability(0.0);
    StartDns(dns0, {{host_name, ns_type::ns_t_aaaa, "2001:db8::5"}});
    StartDns(dns1, {{host_name, ns_type::ns_t_aaaa, "2001:db8::6"}});

    std::vector<std::string> servers = {listen_addr0, listen_addr1};
    constexpr int kSampleCount = 8;
    // {sample validity in s, success threshold in percent, min samples, max samples}
    const std::vector<int> params = {300, 25, kSampleCount, kSampleCount};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, kDefaultSearchDomains, params));

    // Resolve non-existent names until dns0, which never answers, has accumulated a full window of
    // failed samples. No more requests should then arrive at that server for the next
    // sample-validity period.
    // TODO: This approach is implementation-dependent, change once metrics reporting is available.
    const addrinfo hints = {.ai_family = AF_INET6};
    for (int i = 0; i < kSampleCount; ++i) {
        const std::string domain = StringPrintf("nonexistent%d", i);
        ScopedAddrinfo discarded = safe_getaddrinfo(domain.c_str(), nullptr, &hints);
    }

    // With 100% errors across all samples, dns0 should now be ignored and only dns1 used, until
    // the samples expire.
    dns0.clearQueries();
    dns1.clearQueries();
    ScopedAddrinfo addrs = safe_getaddrinfo("ohayou", nullptr, &hints);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_EQ(0U, GetNumQueries(dns0, host_name));
    EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
}
1095 
TEST_F(ResolverTest,GetAddrInfoV6_nonresponsive)1096 TEST_F(ResolverTest, GetAddrInfoV6_nonresponsive) {
1097     constexpr char listen_addr0[] = "127.0.0.7";
1098     constexpr char listen_addr1[] = "127.0.0.8";
1099     constexpr char listen_srv[] = "53";
1100     constexpr char host_name1[] = "ohayou.example.com.";
1101     constexpr char host_name2[] = "ciao.example.com.";
1102     const std::vector<std::string> defaultSearchDomain = {"example.com"};
1103     // The minimal timeout is 1000ms, so we can't decrease timeout
1104     // So reduce retry count.
1105     const std::vector<int> reduceRetryParams = {
1106             300,      // sample validity in seconds
1107             25,       // success threshod in percent
1108             8,    8,  // {MIN,MAX}_SAMPLES
1109             1000,     // BASE_TIMEOUT_MSEC
1110             1,        // retry count
1111     };
1112     const std::vector<DnsRecord> records0 = {
1113             {host_name1, ns_type::ns_t_aaaa, "2001:db8::5"},
1114             {host_name2, ns_type::ns_t_aaaa, "2001:db8::5"},
1115     };
1116     const std::vector<DnsRecord> records1 = {
1117             {host_name1, ns_type::ns_t_aaaa, "2001:db8::6"},
1118             {host_name2, ns_type::ns_t_aaaa, "2001:db8::6"},
1119     };
1120 
1121     // dns0 does not respond with 100% probability, while
1122     // dns1 responds normally, at least initially.
1123     test::DNSResponder dns0(listen_addr0, listen_srv, static_cast<ns_rcode>(-1));
1124     test::DNSResponder dns1(listen_addr1, listen_srv, static_cast<ns_rcode>(-1));
1125     dns0.setResponseProbability(0.0);
1126     StartDns(dns0, records0);
1127     StartDns(dns1, records1);
1128     ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr0, listen_addr1}, defaultSearchDomain,
1129                                                   reduceRetryParams));
1130 
1131     // Specify ai_socktype to make getaddrinfo will only query 1 time
1132     const addrinfo hints = {.ai_family = AF_INET6, .ai_socktype = SOCK_STREAM};
1133 
1134     // dns0 will ignore the request, and we'll fallback to dns1 after the first
1135     // retry.
1136     ScopedAddrinfo result = safe_getaddrinfo(host_name1, nullptr, &hints);
1137     EXPECT_TRUE(result != nullptr);
1138     EXPECT_EQ(1U, GetNumQueries(dns0, host_name1));
1139     EXPECT_EQ(1U, GetNumQueries(dns1, host_name1));
1140     ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, 0, host_name1, {"2001:db8::6"});
1141 
1142     // Now make dns1 also ignore 100% requests... The resolve should alternate
1143     // queries between the nameservers and fail
1144     dns1.setResponseProbability(0.0);
1145     addrinfo* result2 = nullptr;
1146     EXPECT_EQ(EAI_NODATA, getaddrinfo(host_name2, nullptr, &hints, &result2));
1147     EXPECT_EQ(nullptr, result2);
1148     EXPECT_EQ(1U, GetNumQueries(dns0, host_name2));
1149     EXPECT_EQ(1U, GetNumQueries(dns1, host_name2));
1150     ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, RCODE_TIMEOUT, host_name2, {});
1151 }
1152 
// Ten threads run concurrently. Each reconfigures the network with a random non-empty subset of
// the three servers and resolves the same name. Every lookup must succeed, and afterwards the
// resolver must report zero timeouts waiting on pending requests.
TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
    constexpr char listen_addr0[] = "127.0.0.9";
    constexpr char listen_addr1[] = "127.0.0.10";
    constexpr char listen_addr2[] = "127.0.0.11";
    constexpr char host_name[] = "konbanha.example.com.";

    test::DNSResponder dns0(listen_addr0);
    test::DNSResponder dns1(listen_addr1);
    test::DNSResponder dns2(listen_addr2);
    StartDns(dns0, {{host_name, ns_type::ns_t_aaaa, "2001:db8::5"}});
    StartDns(dns1, {{host_name, ns_type::ns_t_aaaa, "2001:db8::6"}});
    StartDns(dns2, {{host_name, ns_type::ns_t_aaaa, "2001:db8::7"}});

    const std::vector<std::string> servers = {listen_addr0, listen_addr1, listen_addr2};
    const auto worker = [this, &servers]() {
        // Random start delay in [0, 1s) so that the threads interleave.
        const unsigned delay_us = arc4random_uniform(1 * 1000 * 1000);
        usleep(delay_us);

        // Keep each server with probability 1/2; fall back to all of them if none was kept.
        std::vector<std::string> subset;
        std::copy_if(servers.begin(), servers.end(), std::back_inserter(subset),
                     [](const std::string&) { return arc4random_uniform(2) != 0; });
        if (subset.empty()) subset = servers;
        ASSERT_TRUE(mDnsClient.SetResolversForNetwork(subset));

        const addrinfo hints = {.ai_family = AF_INET6};
        addrinfo* raw_result = nullptr;
        const int rv = getaddrinfo("konbanha", nullptr, &hints, &raw_result);
        ScopedAddrinfo owned_result(raw_result);  // Released when the worker returns.
        EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
    };

    std::vector<std::thread> threads;
    threads.reserve(10);
    for (int i = 0; i < 10; ++i) {
        threads.emplace_back(worker);
    }
    for (std::thread& t : threads) {
        t.join();
    }

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
    EXPECT_EQ(0, wait_for_pending_req_timeout_count);
}
1205 
// listen_addr1 and listen_addr2 can only produce internal errors, while listen_addr3 works. For
// every setting of the sort-nameservers experiment flag, the following must hold:
//   - The internal errors recorded against the bad servers stay bounded.
//   - Every query ends up at the working server.
TEST_F(ResolverTest, SkipBadServersDueToInternalError) {
    constexpr char listen_addr1[] = "fe80::1";
    constexpr char listen_addr2[] = "255.255.255.255";
    constexpr char listen_addr3[] = "127.0.0.3";
    constexpr size_t kNumQueries = 100;  // Lookups issued per flag setting.
    int counter = 0;                     // To generate unique hostnames.
    test::DNSResponder dns(listen_addr3);
    ASSERT_TRUE(dns.startServer());

    ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
    setupParams.servers = {listen_addr1, listen_addr2, listen_addr3};
    setupParams.minSamples = 2;  // Recognize bad servers in two attempts when sorting not enabled.

    ResolverParamsParcel cleanupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
    cleanupParams.servers.clear();
    cleanupParams.tlsServers.clear();

    for (const auto& sortNameserversFlag : {"" /* unset */, "0" /* off */, "1" /* on */}) {
        SCOPED_TRACE(fmt::format("sortNameserversFlag_{}", sortNameserversFlag));
        ScopedSystemProperties scopedSystemProperties(kSortNameserversFlag, sortNameserversFlag);

        // Re-setup test network to make experiment flag take effect.
        resetNetwork();

        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));

        // Send the queries one after another, each for a fresh hostname.
        for (size_t i = 0; i < kNumQueries; i++) {
            std::string hostName = StringPrintf("hello%d.com.", counter++);
            dns.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
            const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
            EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
        }

        const std::vector<NameserverStats> targetStats = {
                NameserverStats(listen_addr1).setInternalErrors(5),
                NameserverStats(listen_addr2).setInternalErrors(5),
                NameserverStats(listen_addr3).setSuccesses(setupParams.maxSamples),
        };
        EXPECT_TRUE(expectStatsNotGreaterThan(targetStats));

        // Also verify the number of queries received by the server, because res_stats.successes
        // has a maximum.
        EXPECT_EQ(kNumQueries, dns.queries().size());

        // Reset the state.
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(cleanupParams));
        dns.clearQueries();
    }
}
1255 
TEST_F(ResolverTest,SkipBadServersDueToTimeout)1256 TEST_F(ResolverTest, SkipBadServersDueToTimeout) {
1257     constexpr char listen_addr1[] = "127.0.0.3";
1258     constexpr char listen_addr2[] = "127.0.0.4";
1259     int counter = 0;  // To generate unique hostnames.
1260 
1261     ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
1262     setupParams.servers = {listen_addr1, listen_addr2};
1263     setupParams.minSamples = 2;  // Recognize bad servers in two attempts when sorting not enabled.
1264 
1265     ResolverParamsParcel cleanupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
1266     cleanupParams.servers.clear();
1267     cleanupParams.tlsServers.clear();
1268 
1269     // Set dns1 non-responsive and dns2 workable.
1270     test::DNSResponder dns1(listen_addr1, test::kDefaultListenService, static_cast<ns_rcode>(-1));
1271     test::DNSResponder dns2(listen_addr2);
1272     dns1.setResponseProbability(0.0);
1273     ASSERT_TRUE(dns1.startServer());
1274     ASSERT_TRUE(dns2.startServer());
1275 
1276     for (const auto& sortNameserversFlag : {"" /* unset */, "0" /* off */, "1" /* on */}) {
1277         SCOPED_TRACE(fmt::format("sortNameversFlag_{}", sortNameserversFlag));
1278         ScopedSystemProperties scopedSystemProperties(kSortNameserversFlag, sortNameserversFlag);
1279 
1280         // Re-setup test network to make experiment flag take effect.
1281         resetNetwork();
1282 
1283         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));
1284 
1285         // Start sending synchronized querying.
1286         for (int i = 0; i < 100; i++) {
1287             std::string hostName = StringPrintf("hello%d.com.", counter++);
1288             dns1.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
1289             dns2.addMapping(hostName, ns_type::ns_t_a, "1.2.3.5");
1290             const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
1291             EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
1292         }
1293 
1294         const std::vector<NameserverStats> targetStats = {
1295                 NameserverStats(listen_addr1).setTimeouts(5),
1296                 NameserverStats(listen_addr2).setSuccesses(setupParams.maxSamples),
1297         };
1298         EXPECT_TRUE(expectStatsNotGreaterThan(targetStats));
1299 
1300         // Also verify the number of queries received in the server because res_stats.successes has
1301         // an upper bound.
1302         EXPECT_GT(dns1.queries().size(), 0U);
1303         EXPECT_LT(dns1.queries().size(), 5U);
1304         EXPECT_EQ(dns2.queries().size(), 100U);
1305 
1306         // Reset the state.
1307         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(cleanupParams));
1308         dns1.clearQueries();
1309         dns2.clearQueries();
1310     }
1311 }
1312 
// Invalid entries in the customized hosts table (an empty address, a non-IP address) must not
// produce answers. The DNS server has no records for those names either, so every lookup fails.
TEST_F(ResolverTest, GetAddrInfoFromCustTable_InvalidInput) {
    constexpr char kNoIpHost[] = "noip.example.com.";
    constexpr char kInvalidIpHost[] = "invalidip.example.com.";

    test::DNSResponder dns;
    StartDns(dns, {});

    const std::vector<aidl::android::net::ResolverHostsParcel> badHosts = {
            {"", kNoIpHost},
            {"wrong IP", kInvalidIpHost},
    };
    ResolverOptionsParcel options;
    options.hosts = badHosts;

    // Without the setResolverOptions IPC, the options travel inside the configuration parcel.
    // Otherwise they are sent with a separate call once the configuration is set.
    auto params = DnsResponderClient::GetDefaultResolverParamsParcel();
    const bool useOptionsIpc = mIsResolverOptionIPCSupported;
    if (!useOptionsIpc) params.resolverOptions = options;
    ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(params).isOk());
    if (useOptionsIpc) {
        ASSERT_TRUE(mDnsClient.resolvService()->setResolverOptions(params.netId, options).isOk());
    }

    // kNoIpHost has no address at all; kInvalidIpHost has an unparsable one. Neither should be
    // answered from the customized table, and DNSResponder has nothing for them.
    for (const char* name : {kNoIpHost, kInvalidIpHost}) {
        const addrinfo hints = {.ai_family = AF_UNSPEC};
        ScopedAddrinfo answer = safe_getaddrinfo(name, nullptr, &hints);
        ASSERT_TRUE(answer == nullptr);
        EXPECT_EQ(4U, GetNumQueries(dns, name));
    }
}
1345 
TEST_F(ResolverTest,GetAddrInfoFromCustTable)1346 TEST_F(ResolverTest, GetAddrInfoFromCustTable) {
1347     constexpr char hostnameV4[] = "v4only.example.com.";
1348     constexpr char hostnameV6[] = "v6only.example.com.";
1349     constexpr char hostnameV4V6[] = "v4v6.example.com.";
1350     constexpr char custAddrV4[] = "1.2.3.4";
1351     constexpr char custAddrV6[] = "::1.2.3.4";
1352     constexpr char dnsSvAddrV4[] = "1.2.3.5";
1353     constexpr char dnsSvAddrV6[] = "::1.2.3.5";
1354     const std::vector<aidl::android::net::ResolverHostsParcel> custHostV4 = {
1355             {custAddrV4, hostnameV4},
1356     };
1357     const std::vector<aidl::android::net::ResolverHostsParcel> custHostV6 = {
1358             {custAddrV6, hostnameV6},
1359     };
1360     const std::vector<aidl::android::net::ResolverHostsParcel> custHostV4V6 = {
1361             {custAddrV4, hostnameV4V6},
1362             {custAddrV6, hostnameV4V6},
1363     };
1364     const std::vector<DnsRecord> dnsSvHostV4 = {
1365             {hostnameV4, ns_type::ns_t_a, dnsSvAddrV4},
1366     };
1367     const std::vector<DnsRecord> dnsSvHostV6 = {
1368             {hostnameV6, ns_type::ns_t_aaaa, dnsSvAddrV6},
1369     };
1370     const std::vector<DnsRecord> dnsSvHostV4V6 = {
1371             {hostnameV4V6, ns_type::ns_t_a, dnsSvAddrV4},
1372             {hostnameV4V6, ns_type::ns_t_aaaa, dnsSvAddrV6},
1373     };
1374     struct TestConfig {
1375         const std::string name;
1376         const std::vector<aidl::android::net::ResolverHostsParcel> customizedHosts;
1377         const std::vector<DnsRecord> dnsserverHosts;
1378         const std::vector<std::string> queryResult;
1379         std::string asParameters() const {
1380             return StringPrintf("name: %s, customizedHosts: %s, dnsserverHosts: %s", name.c_str(),
1381                                 customizedHosts.empty() ? "No" : "Yes",
1382                                 dnsserverHosts.empty() ? "No" : "Yes");
1383         }
1384     } testConfigs[]{
1385             // clang-format off
1386             {hostnameV4,    {},            {},             {}},
1387             {hostnameV4,    {},            dnsSvHostV4,    {dnsSvAddrV4}},
1388             {hostnameV4,    custHostV4,    {},             {custAddrV4}},
1389             {hostnameV4,    custHostV4,    dnsSvHostV4,    {custAddrV4}},
1390             {hostnameV6,    {},            {},             {}},
1391             {hostnameV6,    {},            dnsSvHostV6,    {dnsSvAddrV6}},
1392             {hostnameV6,    custHostV6,    {},             {custAddrV6}},
1393             {hostnameV6,    custHostV6,    dnsSvHostV6,    {custAddrV6}},
1394             {hostnameV4V6,  {},            {},             {}},
1395             {hostnameV4V6,  {},            dnsSvHostV4V6,  {dnsSvAddrV4, dnsSvAddrV6}},
1396             {hostnameV4V6,  custHostV4V6,  {},             {custAddrV4, custAddrV6}},
1397             {hostnameV4V6,  custHostV4V6,  dnsSvHostV4V6,  {custAddrV4, custAddrV6}},
1398             // clang-format on
1399     };
1400 
1401     for (const auto& config : testConfigs) {
1402         SCOPED_TRACE(config.asParameters());
1403 
1404         test::DNSResponder dns;
1405         StartDns(dns, config.dnsserverHosts);
1406 
1407         auto resolverParams = DnsResponderClient::GetDefaultResolverParamsParcel();
1408         ResolverOptionsParcel resolverOptions;
1409         resolverOptions.hosts = config.customizedHosts;
1410         if (!mIsResolverOptionIPCSupported) {
1411             resolverParams.resolverOptions = resolverOptions;
1412         }
1413         ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(resolverParams).isOk());
1414 
1415         if (mIsResolverOptionIPCSupported) {
1416             ASSERT_TRUE(mDnsClient.resolvService()
1417                                 ->setResolverOptions(resolverParams.netId, resolverOptions)
1418                                 .isOk());
1419         }
1420         const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM};
1421         ScopedAddrinfo result = safe_getaddrinfo(config.name.c_str(), nullptr, &hints);
1422         if (config.customizedHosts.empty() && config.dnsserverHosts.empty()) {
1423             ASSERT_TRUE(result == nullptr);
1424             EXPECT_EQ(2U, GetNumQueries(dns, config.name.c_str()));
1425         } else {
1426             ASSERT_TRUE(result != nullptr);
1427             EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(config.queryResult));
1428             EXPECT_EQ(config.customizedHosts.empty() ? 2U : 0U,
1429                       GetNumQueries(dns, config.name.c_str()));
1430         }
1431 
1432         EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());
1433     }
1434 }
1435 
// Replacing the customized hosts table takes effect for later lookups.
// - While the table maps the hostname, answers come from the table and DNS is never queried.
// - After the table is emptied, the same lookup is answered by the DNS server.
TEST_F(ResolverTest, GetAddrInfoFromCustTable_Modify) {
    constexpr char hostnameV4V6[] = "v4v6.example.com.";
    constexpr char custAddrV4[] = "1.2.3.4";
    constexpr char custAddrV6[] = "::1.2.3.4";
    constexpr char dnsSvAddrV4[] = "1.2.3.5";
    constexpr char dnsSvAddrV6[] = "::1.2.3.5";

    const std::vector<DnsRecord> serverRecords = {
            {hostnameV4V6, ns_type::ns_t_a, dnsSvAddrV4},
            {hostnameV4V6, ns_type::ns_t_aaaa, dnsSvAddrV6},
    };
    const std::vector<aidl::android::net::ResolverHostsParcel> customHosts = {
            {custAddrV4, hostnameV4V6},
            {custAddrV6, hostnameV4V6},
    };

    test::DNSResponder dns;
    StartDns(dns, serverRecords);

    auto params = DnsResponderClient::GetDefaultResolverParamsParcel();
    ResolverOptionsParcel options;
    options.hosts = customHosts;

    // Initial setup. The configuration is always sent. The options either ride along in it
    // (no setResolverOptions IPC) or are sent separately afterwards.
    if (!mIsResolverOptionIPCSupported) params.resolverOptions = options;
    ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(params).isOk());
    if (mIsResolverOptionIPCSupported) {
        ASSERT_TRUE(mDnsClient.resolvService()->setResolverOptions(params.netId, options).isOk());
    }

    const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM};

    // With the customized table in place, answers come from it and DNS sees no queries.
    {
        ScopedAddrinfo answer = safe_getaddrinfo(hostnameV4V6, nullptr, &hints);
        ASSERT_TRUE(answer != nullptr);
        EXPECT_THAT(ToStrings(answer),
                    testing::UnorderedElementsAreArray({custAddrV4, custAddrV6}));
        EXPECT_EQ(0U, GetNumQueries(dns, hostnameV4V6));
    }

    // Empty the customized table, using whichever channel this resolver supports.
    options.hosts.clear();
    if (mIsResolverOptionIPCSupported) {
        ASSERT_TRUE(mDnsClient.resolvService()->setResolverOptions(params.netId, options).isOk());
    } else {
        params.resolverOptions = options;
        ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(params).isOk());
    }

    // Now the DNS server answers.
    {
        ScopedAddrinfo answer = safe_getaddrinfo(hostnameV4V6, nullptr, &hints);
        ASSERT_TRUE(answer != nullptr);
        EXPECT_THAT(ToStrings(answer),
                    testing::UnorderedElementsAreArray({dnsSvAddrV4, dnsSvAddrV6}));
        EXPECT_EQ(2U, GetNumQueries(dns, hostnameV4V6));
    }
}
1487 
// With no nameservers and no search domains configured, GetResolverInfo() reports:
// - empty server, domain and TLS-server lists, and
// - resolver parameters equal to kDefaultParams.
TEST_F(ResolverTest, EmptySetup) {
    std::vector<std::string> noServers;
    std::vector<std::string> noDomains;
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(noServers, noDomains));

    std::vector<std::string> servers;
    std::vector<std::string> searchDomains;
    std::vector<std::string> tlsServers;
    res_params params;
    std::vector<ResolverStats> stats;
    int waitForPendingReqTimeoutCount;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(mDnsClient.resolvService(), TEST_NETID,
                                                    &servers, &searchDomains, &tlsServers, &params,
                                                    &stats, &waitForPendingReqTimeoutCount));

    EXPECT_TRUE(servers.empty());
    EXPECT_TRUE(searchDomains.empty());
    EXPECT_TRUE(tlsServers.empty());

    // kDefaultParams must cover every parameter slot before it is indexed below.
    ASSERT_EQ(static_cast<size_t>(IDnsResolver::RESOLVER_PARAMS_COUNT), kDefaultParams.size());
    const auto& defaults = kDefaultParams;
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY], params.sample_validity);
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD], params.success_threshold);
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES], params.min_samples);
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES], params.max_samples);
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC], params.base_timeout_msec);
    EXPECT_EQ(defaults[IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT], params.retry_count);
}
1515 
// Changing only the search-domain list must change how a short name ("test13") is expanded.
// It is qualified with domain1.org first, and with domain2.org after reconfiguration.
TEST_F(ResolverTest, SearchPathChange) {
    constexpr char listen_addr[] = "127.0.0.13";
    constexpr char host_name1[] = "test13.domain1.org.";
    constexpr char host_name2[] = "test13.domain2.org.";
    std::vector<std::string> servers = {listen_addr};
    std::vector<std::string> domains = {"domain1.org"};

    test::DNSResponder dns(listen_addr);
    StartDns(dns, {
                          {host_name1, ns_type::ns_t_aaaa, "2001:db8::13"},
                          {host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13"},
                  });

    // Looks up "test13" and expects exactly one query, for |fqdn|, answered with |address|.
    // Only non-fatal expectations are used, so the test body keeps control of early returns.
    const auto expectShortNameResolvesVia = [&dns](const char* fqdn, const char* address) {
        const addrinfo hints = {.ai_family = AF_INET6};
        ScopedAddrinfo answer = safe_getaddrinfo("test13", nullptr, &hints);
        EXPECT_TRUE(answer != nullptr);
        EXPECT_EQ(1U, dns.queries().size());
        EXPECT_EQ(1U, GetNumQueries(dns, fqdn));
        EXPECT_EQ(address, ToString(answer));
    };

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, domains));
    expectShortNameResolvesVia(host_name1, "2001:db8::13");

    // Test that changing the domain search path on its own works.
    domains = {"domain2.org"};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, domains));
    dns.clearQueries();
    expectShortNameResolvesVia(host_name2, "2001:db8::1:13");
}
1549 
1550 namespace {
1551 
getResolverDomains(aidl::android::net::IDnsResolver * dnsResolverService,unsigned netId)1552 std::vector<std::string> getResolverDomains(aidl::android::net::IDnsResolver* dnsResolverService,
1553                                             unsigned netId) {
1554     std::vector<std::string> res_servers;
1555     std::vector<std::string> res_domains;
1556     std::vector<std::string> res_tls_servers;
1557     res_params res_params;
1558     std::vector<ResolverStats> res_stats;
1559     int wait_for_pending_req_timeout_count;
1560     DnsResponderClient::GetResolverInfo(dnsResolverService, netId, &res_servers, &res_domains,
1561                                         &res_tls_servers, &res_params, &res_stats,
1562                                         &wait_for_pending_req_timeout_count);
1563     return res_domains;
1564 }
1565 
1566 }  // namespace
1567 
TEST_F(ResolverTest,SearchPathPrune)1568 TEST_F(ResolverTest, SearchPathPrune) {
1569     constexpr size_t DUPLICATED_DOMAIN_NUM = 3;
1570     constexpr char listen_addr[] = "127.0.0.13";
1571     constexpr char domian_name1[] = "domain13.org.";
1572     constexpr char domian_name2[] = "domain14.org.";
1573     constexpr char host_name1[] = "test13.domain13.org.";
1574     constexpr char host_name2[] = "test14.domain14.org.";
1575     std::vector<std::string> servers = {listen_addr};
1576 
1577     std::vector<std::string> testDomains1;
1578     std::vector<std::string> testDomains2;
1579     // Domain length should be <= 255
1580     // Max number of domains in search path is 6
1581     for (size_t i = 0; i < MAXDNSRCH + 1; i++) {
1582         // Fill up with invalid domain
1583         testDomains1.push_back(std::string(300, i + '0'));
1584         // Fill up with valid but duplicated domain
1585         testDomains2.push_back(StringPrintf("domain%zu.org", i % DUPLICATED_DOMAIN_NUM));
1586     }
1587 
1588     // Add valid domain used for query.
1589     testDomains1.push_back(domian_name1);
1590 
1591     // Add valid domain twice used for query.
1592     testDomains2.push_back(domian_name2);
1593     testDomains2.push_back(domian_name2);
1594 
1595     const std::vector<DnsRecord> records = {
1596             {host_name1, ns_type::ns_t_aaaa, "2001:db8::13"},
1597             {host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13"},
1598     };
1599     test::DNSResponder dns(listen_addr);
1600     StartDns(dns, records);
1601     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains1));
1602 
1603     const addrinfo hints = {.ai_family = AF_INET6};
1604     ScopedAddrinfo result = safe_getaddrinfo("test13", nullptr, &hints);
1605 
1606     EXPECT_TRUE(result != nullptr);
1607 
1608     EXPECT_EQ(1U, dns.queries().size());
1609     EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
1610     EXPECT_EQ("2001:db8::13", ToString(result));
1611 
1612     const auto& res_domains1 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
1613     // Expect 1 valid domain, invalid domains are removed.
1614     ASSERT_EQ(1U, res_domains1.size());
1615     EXPECT_EQ(domian_name1, res_domains1[0]);
1616 
1617     dns.clearQueries();
1618 
1619     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains2));
1620 
1621     result = safe_getaddrinfo("test14", nullptr, &hints);
1622     EXPECT_TRUE(result != nullptr);
1623 
1624     // (3 domains * 2 retries) + 1 success query = 7
1625     EXPECT_EQ(7U, dns.queries().size());
1626     EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
1627     EXPECT_EQ("2001:db8::1:13", ToString(result));
1628 
1629     const auto& res_domains2 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
1630     // Expect 4 valid domain, duplicate domains are removed.
1631     EXPECT_EQ(DUPLICATED_DOMAIN_NUM + 1U, res_domains2.size());
1632     EXPECT_THAT(
1633             std::vector<std::string>({"domain0.org", "domain1.org", "domain2.org", domian_name2}),
1634             testing::ElementsAreArray(res_domains2));
1635 }
1636 
1637 // If we move this function to dns_responder_client, it will complicate the dependency need of
1638 // dns_tls_frontend.h.
setupTlsServers(const std::vector<std::string> & servers,std::vector<std::unique_ptr<test::DnsTlsFrontend>> * tls)1639 static void setupTlsServers(const std::vector<std::string>& servers,
1640                             std::vector<std::unique_ptr<test::DnsTlsFrontend>>* tls) {
1641     constexpr char listen_udp[] = "53";
1642     constexpr char listen_tls[] = "853";
1643 
1644     for (const auto& server : servers) {
1645         auto t = std::make_unique<test::DnsTlsFrontend>(server, listen_tls, server, listen_udp);
1646         t = std::make_unique<test::DnsTlsFrontend>(server, listen_tls, server, listen_udp);
1647         t->startServer();
1648         tls->push_back(std::move(t));
1649     }
1650 }
1651 
// Configures MAXNS + 1 nameservers (each with a TLS frontend) and MAXDNSRCH + 1 search domains.
// Verifies that the resolver reports exactly the first MAXNS servers and the first MAXDNSRCH
// domains, in order.
TEST_F(ResolverTest, MaxServerPrune_Binder) {
    std::vector<std::string> domains;
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::unique_ptr<test::DnsTlsFrontend>> tls;
    std::vector<std::string> servers;
    std::vector<DnsResponderClient::Mapping> mappings;

    for (unsigned i = 0; i < MAXDNSRCH + 1; i++) {
        domains.push_back(StringPrintf("example%u.com", i));
    }
    ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));
    ASSERT_NO_FATAL_FAILURE(setupTlsServers(servers, &tls));

    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, domains, kDefaultParams,
                                               kDefaultPrivateDnsHostName));

    // If the private DNS validation hasn't completed yet before backend DNS servers stop,
    // TLS servers will get stuck in handleOneRequest(), which causes this test stuck in
    // ~DnsTlsFrontend() because the TLS server loop threads can't be terminated.
    // So, wait for private DNS validation done before stopping backend DNS servers.
    for (int i = 0; i < MAXNS; i++) {
        LOG(INFO) << "Waiting for private DNS validation on " << tls[i]->listen_address() << ".";
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls[i]->listen_address(), true));
        LOG(INFO) << "private DNS validation on " << tls[i]->listen_address() << " done.";
    }

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));

    // Check the sizes first and fatally. The std::equal() calls below read MAXNS / MAXDNSRCH
    // elements from the reported lists and would run past the end of a shorter list.
    ASSERT_EQ(static_cast<size_t>(MAXNS), res_servers.size());
    ASSERT_EQ(static_cast<size_t>(MAXNS), res_tls_servers.size());
    ASSERT_EQ(static_cast<size_t>(MAXDNSRCH), res_domains.size());

    // Then check that the kept entries are the leading ones, in configuration order.
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS, res_servers.begin()));
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS, res_tls_servers.begin()));
    EXPECT_TRUE(std::equal(domains.begin(), domains.begin() + MAXDNSRCH, res_domains.begin()));
}
1697 
// Verifies the per-nameserver stats recorded by one lookup over three nameservers:
// - a server that never responds (one timeout),
// - a server that answers with a failure (one error),
// - a working server (one success).
TEST_F(ResolverTest, ResolverStats) {
    constexpr char listen_addr1[] = "127.0.0.4";
    constexpr char listen_addr2[] = "127.0.0.5";
    constexpr char listen_addr3[] = "127.0.0.6";

    // Set server 1 to never respond, so queries to it time out.
    test::DNSResponder dns1(listen_addr1, "53", static_cast<ns_rcode>(-1));
    dns1.setResponseProbability(0.0);
    ASSERT_TRUE(dns1.startServer());

    // Set server 2 to respond with a server failure (recorded as an error).
    test::DNSResponder dns2(listen_addr2);
    dns2.setResponseProbability(0.0);
    ASSERT_TRUE(dns2.startServer());

    // Set server 3 workable.
    test::DNSResponder dns3(listen_addr3);
    dns3.addMapping(kHelloExampleCom, ns_type::ns_t_a, "1.2.3.4");
    ASSERT_TRUE(dns3.startServer());

    std::vector<std::string> servers = {listen_addr1, listen_addr2, listen_addr3};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    dns3.clearQueries();
    const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_LE(1U, GetNumQueries(dns3, kHelloExampleCom));
    // EXPECT_EQ prints both values on mismatch; no hand-built failure message is needed.
    EXPECT_EQ("1.2.3.4", ToString(result));

    const std::vector<NameserverStats> expectedCleartextDnsStats = {
            NameserverStats(listen_addr1).setTimeouts(1),
            NameserverStats(listen_addr2).setErrors(1),
            NameserverStats(listen_addr3).setSuccesses(1),
    };
    EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
}
1736 
TEST_F(ResolverTest,AlwaysUseLatestSetupParamsInLookups)1737 TEST_F(ResolverTest, AlwaysUseLatestSetupParamsInLookups) {
1738     constexpr char listen_addr1[] = "127.0.0.3";
1739     constexpr char listen_addr2[] = "255.255.255.255";
1740     constexpr char listen_addr3[] = "127.0.0.4";
1741     constexpr char hostname[] = "hello";
1742     constexpr char fqdn_with_search_domain[] = "hello.domain2.com.";
1743 
1744     test::DNSResponder dns1(listen_addr1, test::kDefaultListenService, static_cast<ns_rcode>(-1));
1745     dns1.setResponseProbability(0.0);
1746     ASSERT_TRUE(dns1.startServer());
1747 
1748     test::DNSResponder dns3(listen_addr3);
1749     StartDns(dns3, {{fqdn_with_search_domain, ns_type::ns_t_a, "1.2.3.4"}});
1750 
1751     ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
1752     parcel.tlsServers.clear();
1753     parcel.servers = {listen_addr1, listen_addr2};
1754     parcel.domains = {"domain1.com", "domain2.com"};
1755     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
1756 
1757     // Expect the things happening in t1:
1758     //   1. The lookup starts using the first domain for query. It sends queries to the populated
1759     //      nameserver list {listen_addr1, listen_addr2} for the hostname "hello.domain1.com".
1760     //   2. A different list of nameservers is updated to the resolver. Revision ID is incremented.
1761     //   3. The query for the hostname times out. The lookup fails to add the timeout record to the
1762     //      the stats because of the unmatched revision ID.
1763     //   4. The lookup starts using the second domain for query. It sends queries to the populated
1764     //      nameserver list {listen_addr3, listen_addr1, listen_addr2} for another hostname
1765     //      "hello.domain2.com".
1766     //   5. The lookup gets the answer and updates a success record to the stats.
1767     std::thread t1([&hostname]() {
1768         const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
1769         ScopedAddrinfo result = safe_getaddrinfo(hostname, nullptr, &hints);
1770         EXPECT_NE(result.get(), nullptr);
1771         EXPECT_EQ(ToString(result), "1.2.3.4");
1772     });
1773 
1774     // Wait for t1 to start the step 1.
1775     while (dns1.queries().size() == 0) {
1776         usleep(1000);
1777     }
1778 
1779     // Update the resolver with three nameservers. This will increment the revision ID.
1780     parcel.servers = {listen_addr3, listen_addr1, listen_addr2};
1781     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
1782 
1783     t1.join();
1784     EXPECT_EQ(0U, GetNumQueriesForType(dns3, ns_type::ns_t_aaaa, fqdn_with_search_domain));
1785     EXPECT_EQ(1U, GetNumQueriesForType(dns3, ns_type::ns_t_a, fqdn_with_search_domain));
1786 
1787     const std::vector<NameserverStats> expectedCleartextDnsStats = {
1788             NameserverStats(listen_addr1),
1789             NameserverStats(listen_addr2),
1790             NameserverStats(listen_addr3).setSuccesses(1),
1791     };
1792     EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
1793 }
1794 
1795 // Test what happens if the specified TLS server is nonexistent.
TEST_F(ResolverTest,GetHostByName_TlsMissing)1796 TEST_F(ResolverTest, GetHostByName_TlsMissing) {
1797     constexpr char listen_addr[] = "127.0.0.3";
1798     constexpr char host_name[] = "tlsmissing.example.com.";
1799 
1800     test::DNSResponder dns;
1801     StartDns(dns, {{host_name, ns_type::ns_t_a, "1.2.3.3"}});
1802     std::vector<std::string> servers = {listen_addr};
1803 
1804     // There's nothing listening on this address, so validation will either fail or
1805     /// hang.  Either way, queries will continue to flow to the DNSResponder.
1806     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
1807 
1808     const hostent* result;
1809 
1810     result = gethostbyname("tlsmissing");
1811     ASSERT_FALSE(result == nullptr);
1812     EXPECT_EQ("1.2.3.3", ToString(result));
1813 
1814     // Clear TLS bit.
1815     ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
1816 }
1817 
1818 // Test what happens if the specified TLS server replies with garbage.
TEST_F(ResolverTest,GetHostByName_TlsBroken)1819 TEST_F(ResolverTest, GetHostByName_TlsBroken) {
1820     constexpr char listen_addr[] = "127.0.0.3";
1821     constexpr char host_name1[] = "tlsbroken1.example.com.";
1822     constexpr char host_name2[] = "tlsbroken2.example.com.";
1823     const std::vector<DnsRecord> records = {
1824             {host_name1, ns_type::ns_t_a, "1.2.3.1"},
1825             {host_name2, ns_type::ns_t_a, "1.2.3.2"},
1826     };
1827 
1828     test::DNSResponder dns;
1829     StartDns(dns, records);
1830     std::vector<std::string> servers = {listen_addr};
1831 
1832     // Bind the specified private DNS socket but don't respond to any client sockets yet.
1833     int s = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
1834     ASSERT_TRUE(s >= 0);
1835     struct sockaddr_in tlsServer = {
1836             .sin_family = AF_INET,
1837             .sin_port = htons(853),
1838     };
1839     ASSERT_TRUE(inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
1840     ASSERT_TRUE(enableSockopt(s, SOL_SOCKET, SO_REUSEPORT).ok());
1841     ASSERT_TRUE(enableSockopt(s, SOL_SOCKET, SO_REUSEADDR).ok());
1842     ASSERT_FALSE(bind(s, reinterpret_cast<struct sockaddr*>(&tlsServer), sizeof(tlsServer)));
1843     ASSERT_FALSE(listen(s, 1));
1844 
1845     // Trigger TLS validation.
1846     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
1847 
1848     struct sockaddr_storage cliaddr;
1849     socklen_t sin_size = sizeof(cliaddr);
1850     int new_fd = accept4(s, reinterpret_cast<struct sockaddr*>(&cliaddr), &sin_size, SOCK_CLOEXEC);
1851     ASSERT_TRUE(new_fd > 0);
1852 
1853     // We've received the new file descriptor but not written to it or closed, so the
1854     // validation is still pending.  Queries should still flow correctly because the
1855     // server is not used until validation succeeds.
1856     const hostent* result;
1857     result = gethostbyname("tlsbroken1");
1858     ASSERT_FALSE(result == nullptr);
1859     EXPECT_EQ("1.2.3.1", ToString(result));
1860 
1861     // Now we cause the validation to fail.
1862     std::string garbage = "definitely not a valid TLS ServerHello";
1863     write(new_fd, garbage.data(), garbage.size());
1864     close(new_fd);
1865 
1866     // Validation failure shouldn't interfere with lookups, because lookups won't be sent
1867     // to the TLS server unless validation succeeds.
1868     result = gethostbyname("tlsbroken2");
1869     ASSERT_FALSE(result == nullptr);
1870     EXPECT_EQ("1.2.3.2", ToString(result));
1871 
1872     // Clear TLS bit.
1873     ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
1874     close(s);
1875 }
1876 
// Opportunistic-mode private DNS with gethostbyname():
// - a lookup after successful validation reaches the TLS frontend,
// - lookups fall back to cleartext once the TLS server stops,
// - lookups keep working after TLS is turned off.
TEST_F(ResolverTest, GetHostByName_Tls) {
    constexpr char listen_addr[] = "127.0.0.3";
    constexpr char listen_udp[] = "53";
    constexpr char listen_tls[] = "853";
    constexpr char host_name1[] = "tls1.example.com.";
    constexpr char host_name2[] = "tls2.example.com.";
    constexpr char host_name3[] = "tls3.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name1, ns_type::ns_t_a, "1.2.3.1"},
            {host_name2, ns_type::ns_t_a, "1.2.3.2"},
            {host_name3, ns_type::ns_t_a, "1.2.3.3"},
    };

    test::DNSResponder dns;
    StartDns(dns, records);
    std::vector<std::string> servers = {listen_addr};

    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));

    const hostent* result = gethostbyname("tls1");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.1", ToString(result));

    // Wait for query to get counted.
    EXPECT_TRUE(tls.waitForQueries(2));

    // Stop the TLS server.  Since we're in opportunistic mode, queries will
    // fall back to the locally-assigned (clear text) nameservers.
    tls.stopServer();

    dns.clearQueries();
    result = gethostbyname("tls2");
    EXPECT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.2", ToString(result));
    const auto queries = dns.queries();
    EXPECT_EQ(1U, queries.size());
    // EXPECT_EQ above does not stop the test, so guard the index: queries[0] on an empty
    // vector is undefined behavior.
    if (!queries.empty()) {
        EXPECT_EQ(host_name2, queries[0].name);
        EXPECT_EQ(ns_t_a, queries[0].type);
    }

    // Reset the resolvers without enabling TLS.  Queries should still be routed
    // to the UDP endpoint.
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    result = gethostbyname("tls3");
    ASSERT_FALSE(result == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));
}
1927 
TEST_F(ResolverTest,GetHostByName_TlsFailover)1928 TEST_F(ResolverTest, GetHostByName_TlsFailover) {
1929     constexpr char listen_addr1[] = "127.0.0.3";
1930     constexpr char listen_addr2[] = "127.0.0.4";
1931     constexpr char listen_udp[] = "53";
1932     constexpr char listen_tls[] = "853";
1933     constexpr char host_name1[] = "tlsfailover1.example.com.";
1934     constexpr char host_name2[] = "tlsfailover2.example.com.";
1935     const std::vector<DnsRecord> records1 = {
1936             {host_name1, ns_type::ns_t_a, "1.2.3.1"},
1937             {host_name2, ns_type::ns_t_a, "1.2.3.2"},
1938     };
1939     const std::vector<DnsRecord> records2 = {
1940             {host_name1, ns_type::ns_t_a, "1.2.3.3"},
1941             {host_name2, ns_type::ns_t_a, "1.2.3.4"},
1942     };
1943 
1944     test::DNSResponder dns1(listen_addr1);
1945     test::DNSResponder dns2(listen_addr2);
1946     StartDns(dns1, records1);
1947     StartDns(dns2, records2);
1948 
1949     std::vector<std::string> servers = {listen_addr1, listen_addr2};
1950 
1951     test::DnsTlsFrontend tls1(listen_addr1, listen_tls, listen_addr1, listen_udp);
1952     test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
1953     ASSERT_TRUE(tls1.startServer());
1954     ASSERT_TRUE(tls2.startServer());
1955     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
1956                                                kDefaultPrivateDnsHostName));
1957     EXPECT_TRUE(WaitForPrivateDnsValidation(tls1.listen_address(), true));
1958     EXPECT_TRUE(WaitForPrivateDnsValidation(tls2.listen_address(), true));
1959 
1960     const hostent* result = gethostbyname("tlsfailover1");
1961     ASSERT_FALSE(result == nullptr);
1962     EXPECT_EQ("1.2.3.1", ToString(result));
1963 
1964     // Wait for query to get counted.
1965     EXPECT_TRUE(tls1.waitForQueries(2));
1966     // No new queries should have reached tls2.
1967     EXPECT_TRUE(tls2.waitForQueries(1));
1968 
1969     // Stop tls1.  Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
1970     tls1.stopServer();
1971 
1972     result = gethostbyname("tlsfailover2");
1973     EXPECT_EQ("1.2.3.4", ToString(result));
1974 
1975     // Wait for query to get counted.
1976     EXPECT_TRUE(tls2.waitForQueries(2));
1977 
1978     // No additional queries should have reached the insecure servers.
1979     EXPECT_EQ(2U, dns1.queries().size());
1980     EXPECT_EQ(2U, dns2.queries().size());
1981 
1982     // Clear TLS bit.
1983     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
1984 }
1985 
// A private DNS hostname that does not match the TLS server's certificate makes validation fail,
// and a lookup must then fail hard because a name was specified.
TEST_F(ResolverTest, GetHostByName_BadTlsName) {
    constexpr char kServerAddr[] = "127.0.0.3";
    constexpr char kDnsPort[] = "53";
    constexpr char kDotPort[] = "853";
    constexpr char kHostName[] = "badtlsname.example.com.";

    test::DNSResponder responder;
    StartDns(responder, {{kHostName, ns_type::ns_t_a, "1.2.3.1"}});

    test::DnsTlsFrontend tlsFrontend(kServerAddr, kDotPort, kServerAddr, kDnsPort);
    ASSERT_TRUE(tlsFrontend.startServer());

    std::vector<std::string> servers{kServerAddr};
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                               kDefaultIncorrectPrivateDnsHostName));

    // The handshake is rejected: the configured name doesn't match the server's certificate.
    EXPECT_TRUE(WaitForPrivateDnsValidation(tlsFrontend.listen_address(), false));

    // With a name configured, the lookup must fail hard.
    EXPECT_EQ(nullptr, gethostbyname("badtlsname"));

    // Drop the TLS configuration again.
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
}
2011 
// Strict-mode DNS-over-TLS: getaddrinfo() for a name with both A and AAAA records succeeds once
// the TLS frontend is validated, and both queries are counted by the frontend.
TEST_F(ResolverTest, GetAddrInfo_Tls) {
    constexpr char kServerAddr[] = "127.0.0.3";
    constexpr char kDnsPort[] = "53";
    constexpr char kDotPort[] = "853";
    constexpr char kHostName[] = "addrinfotls.example.com.";
    const std::vector<DnsRecord> kRecords{
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
            {kHostName, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };

    test::DNSResponder responder;
    StartDns(responder, kRecords);

    test::DnsTlsFrontend tlsFrontend(kServerAddr, kDotPort, kServerAddr, kDnsPort);
    ASSERT_TRUE(tlsFrontend.startServer());

    std::vector<std::string> servers{kServerAddr};
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                               kDefaultPrivateDnsHostName));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tlsFrontend.listen_address(), true));

    responder.clearQueries();
    ScopedAddrinfo addrs = safe_getaddrinfo("addrinfotls", nullptr, nullptr);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_LE(1U, GetNumQueries(responder, kHostName));

    // The first entry may be either the A or the AAAA answer.
    const std::string first = ToString(addrs);
    EXPECT_TRUE(first == "1.2.3.4" || first == "::1.2.3.4") << ", result_str='" << first << "'";

    // Wait until the frontend has counted both the A and the AAAA query.
    EXPECT_TRUE(tlsFrontend.waitForQueries(3));

    // Drop the TLS configuration again.
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
}
2047 
TEST_F(ResolverTest,TlsBypass)2048 TEST_F(ResolverTest, TlsBypass) {
2049     const char OFF[] = "off";
2050     const char OPPORTUNISTIC[] = "opportunistic";
2051     const char STRICT[] = "strict";
2052 
2053     const char GETHOSTBYNAME[] = "gethostbyname";
2054     const char GETADDRINFO[] = "getaddrinfo";
2055     const char GETADDRINFOFORNET[] = "getaddrinfofornet";
2056 
2057     const unsigned BYPASS_NETID = NETID_USE_LOCAL_NAMESERVERS | TEST_NETID;
2058 
2059     const char ADDR4[] = "192.0.2.1";
2060     const char ADDR6[] = "2001:db8::1";
2061 
2062     const char cleartext_addr[] = "127.0.0.53";
2063     const char cleartext_port[] = "53";
2064     const char tls_port[] = "853";
2065     const std::vector<std::string> servers = {cleartext_addr};
2066 
2067     test::DNSResponder dns(cleartext_addr);
2068     ASSERT_TRUE(dns.startServer());
2069 
2070     test::DnsTlsFrontend tls(cleartext_addr, tls_port, cleartext_addr, cleartext_port);
2071     ASSERT_TRUE(tls.startServer());
2072 
2073     // clang-format off
2074     struct TestConfig {
2075         const std::string mode;
2076         const bool withWorkingTLS;
2077         const std::string method;
2078 
2079         std::string asHostName() const {
2080             return StringPrintf("%s.%s.%s.", mode.c_str(), withWorkingTLS ? "tlsOn" : "tlsOff",
2081                                 method.c_str());
2082         }
2083     } testConfigs[]{
2084         {OFF,           true,  GETHOSTBYNAME},
2085         {OPPORTUNISTIC, true,  GETHOSTBYNAME},
2086         {STRICT,        true,  GETHOSTBYNAME},
2087         {OFF,           true,  GETADDRINFO},
2088         {OPPORTUNISTIC, true,  GETADDRINFO},
2089         {STRICT,        true,  GETADDRINFO},
2090         {OFF,           true,  GETADDRINFOFORNET},
2091         {OPPORTUNISTIC, true,  GETADDRINFOFORNET},
2092         {STRICT,        true,  GETADDRINFOFORNET},
2093         {OFF,           false, GETHOSTBYNAME},
2094         {OPPORTUNISTIC, false, GETHOSTBYNAME},
2095         {STRICT,        false, GETHOSTBYNAME},
2096         {OFF,           false, GETADDRINFO},
2097         {OPPORTUNISTIC, false, GETADDRINFO},
2098         {STRICT,        false, GETADDRINFO},
2099         {OFF,           false, GETADDRINFOFORNET},
2100         {OPPORTUNISTIC, false, GETADDRINFOFORNET},
2101         {STRICT,        false, GETADDRINFOFORNET},
2102     };
2103     // clang-format on
2104 
2105     for (const auto& config : testConfigs) {
2106         const std::string testHostName = config.asHostName();
2107         SCOPED_TRACE(testHostName);
2108 
2109         // Don't tempt test bugs due to caching.
2110         const char* host_name = testHostName.c_str();
2111         dns.addMapping(host_name, ns_type::ns_t_a, ADDR4);
2112         dns.addMapping(host_name, ns_type::ns_t_aaaa, ADDR6);
2113 
2114         if (config.withWorkingTLS) {
2115             if (!tls.running()) {
2116                 ASSERT_TRUE(tls.startServer());
2117             }
2118         } else {
2119             if (tls.running()) {
2120                 ASSERT_TRUE(tls.stopServer());
2121             }
2122         }
2123 
2124         if (config.mode == OFF) {
2125             ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, kDefaultSearchDomains,
2126                                                           kDefaultParams));
2127         } else /* OPPORTUNISTIC or STRICT */ {
2128             const char* tls_hostname = (config.mode == STRICT) ? kDefaultPrivateDnsHostName : "";
2129             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
2130                                                        kDefaultParams, tls_hostname));
2131 
2132             // Wait for the validation event. If the server is running, the validation should
2133             // succeed; otherwise, the validation should fail.
2134             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
2135             if (config.withWorkingTLS) {
2136                 EXPECT_TRUE(tls.waitForQueries(1));
2137                 tls.clearQueries();
2138             }
2139         }
2140 
2141         const hostent* h_result = nullptr;
2142         ScopedAddrinfo ai_result;
2143 
2144         if (config.method == GETHOSTBYNAME) {
2145             ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
2146             h_result = gethostbyname(host_name);
2147 
2148             EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
2149             ASSERT_FALSE(h_result == nullptr);
2150             ASSERT_EQ(4, h_result->h_length);
2151             ASSERT_FALSE(h_result->h_addr_list[0] == nullptr);
2152             EXPECT_EQ(ADDR4, ToString(h_result));
2153             EXPECT_TRUE(h_result->h_addr_list[1] == nullptr);
2154         } else if (config.method == GETADDRINFO) {
2155             ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
2156             ai_result = safe_getaddrinfo(host_name, nullptr, nullptr);
2157             EXPECT_TRUE(ai_result != nullptr);
2158 
2159             EXPECT_LE(1U, GetNumQueries(dns, host_name));
2160             // Could be A or AAAA
2161             const std::string result_str = ToString(ai_result);
2162             EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
2163                     << ", result_str='" << result_str << "'";
2164         } else if (config.method == GETADDRINFOFORNET) {
2165             addrinfo* raw_ai_result = nullptr;
2166             EXPECT_EQ(0, android_getaddrinfofornet(host_name, /*servname=*/nullptr,
2167                                                    /*hints=*/nullptr, BYPASS_NETID, MARK_UNSET,
2168                                                    &raw_ai_result));
2169             ai_result.reset(raw_ai_result);
2170 
2171             EXPECT_LE(1U, GetNumQueries(dns, host_name));
2172             // Could be A or AAAA
2173             const std::string result_str = ToString(ai_result);
2174             EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
2175                     << ", result_str='" << result_str << "'";
2176         }
2177 
2178         EXPECT_EQ(0, tls.queries());
2179 
2180         // Clear per-process resolv netid.
2181         ASSERT_EQ(0, setNetworkForResolv(NETID_UNSET));
2182         dns.clearQueries();
2183     }
2184 }
2185 
// Strict mode (a private DNS hostname is configured) with no DNS-over-TLS server started:
// getaddrinfo() must fail, and no query may reach the cleartext nameserver.
TEST_F(ResolverTest, StrictMode_NoTlsServers) {
    constexpr char cleartext_addr[] = "127.0.0.53";
    const std::vector<std::string> servers = {cleartext_addr};
    constexpr char host_name[] = "strictmode.notlsips.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };

    test::DNSResponder dns(cleartext_addr);
    StartDns(dns, records);

    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                               kDefaultIncorrectPrivateDnsHostName));

    addrinfo* raw_ai_result = nullptr;
    EXPECT_NE(0, getaddrinfo(host_name, nullptr, nullptr, &raw_ai_result));
    // Take ownership so the list is freed even if getaddrinfo() unexpectedly succeeds.
    ScopedAddrinfo ai_result;
    ai_result.reset(raw_ai_result);
    EXPECT_EQ(0U, GetNumQueries(dns, host_name));
}
2205 
2206 namespace {
2207 
getAsyncResponse(int fd,int * rcode,uint8_t * buf,int bufLen)2208 int getAsyncResponse(int fd, int* rcode, uint8_t* buf, int bufLen) {
2209     struct pollfd wait_fd[1];
2210     wait_fd[0].fd = fd;
2211     wait_fd[0].events = POLLIN;
2212     short revents;
2213     int ret;
2214 
2215     ret = poll(wait_fd, 1, -1);
2216     revents = wait_fd[0].revents;
2217     if (revents & POLLIN) {
2218         return resNetworkResult(fd, rcode, buf, bufLen);
2219     }
2220     return -1;
2221 }
2222 
toString(uint8_t * buf,int bufLen,int ipType)2223 std::string toString(uint8_t* buf, int bufLen, int ipType) {
2224     ns_msg handle;
2225     int ancount, n = 0;
2226     ns_rr rr;
2227 
2228     if (ns_initparse((const uint8_t*)buf, bufLen, &handle) >= 0) {
2229         ancount = ns_msg_count(handle, ns_s_an);
2230         if (ns_parserr(&handle, ns_s_an, n, &rr) == 0) {
2231             const uint8_t* rdata = ns_rr_rdata(rr);
2232             char buffer[INET6_ADDRSTRLEN];
2233             if (inet_ntop(ipType, (const char*)rdata, buffer, sizeof(buffer))) {
2234                 return buffer;
2235             }
2236         }
2237     }
2238     return "";
2239 }
2240 
dns_open_proxy()2241 int dns_open_proxy() {
2242     int s = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
2243     if (s == -1) {
2244         return -1;
2245     }
2246     const int one = 1;
2247     setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
2248 
2249     static const struct sockaddr_un proxy_addr = {
2250             .sun_family = AF_UNIX,
2251             .sun_path = "/dev/socket/dnsproxyd",
2252     };
2253 
2254     if (TEMP_FAILURE_RETRY(connect(s, (const struct sockaddr*)&proxy_addr, sizeof(proxy_addr))) !=
2255         0) {
2256         close(s);
2257         return -1;
2258     }
2259 
2260     return s;
2261 }
2262 
expectAnswersValid(int fd,int ipType,const std::string & expectedAnswer)2263 void expectAnswersValid(int fd, int ipType, const std::string& expectedAnswer) {
2264     int rcode = -1;
2265     uint8_t buf[MAXPACKET] = {};
2266 
2267     int res = getAsyncResponse(fd, &rcode, buf, MAXPACKET);
2268     EXPECT_GT(res, 0);
2269     EXPECT_EQ(expectedAnswer, toString(buf, res, ipType));
2270 }
2271 
expectAnswersNotValid(int fd,int expectedErrno)2272 void expectAnswersNotValid(int fd, int expectedErrno) {
2273     int rcode = -1;
2274     uint8_t buf[MAXPACKET] = {};
2275 
2276     int res = getAsyncResponse(fd, &rcode, buf, MAXPACKET);
2277     EXPECT_EQ(expectedErrno, res);
2278 }
2279 
2280 }  // namespace
2281 
// Sends concurrent A and AAAA queries through resNetworkQuery(), then repeats them and checks
// that the repeat is answered without any new query reaching the server (i.e. from the cache).
TEST_F(ResolverTest, Async_NormalQueryV4V6) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };

    test::DNSResponder dns(listen_addr);
    StartDns(dns, records);
    std::vector<std::string> servers = {listen_addr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Treat any negative value as failure, not just -1. This assumes resNetworkQuery() follows
    // the NDK android_res_nquery() contract (negative errno on error). Stop rather than wait on
    // an invalid fd.
    int fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
    int fd2 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_aaaa, 0);
    ASSERT_GE(fd1, 0);
    ASSERT_GE(fd2, 0);

    // Collect the answers in the reverse order of the queries.
    expectAnswersValid(fd2, AF_INET6, "::1.2.3.4");
    expectAnswersValid(fd1, AF_INET, "1.2.3.4");

    EXPECT_EQ(2U, GetNumQueries(dns, host_name));

    // Re-query to verify the cache works: the server must not see any new query.
    fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
    fd2 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_aaaa, 0);
    ASSERT_GE(fd1, 0);
    ASSERT_GE(fd2, 0);

    expectAnswersValid(fd2, AF_INET6, "::1.2.3.4");
    expectAnswersValid(fd1, AF_INET, "1.2.3.4");

    EXPECT_EQ(2U, GetNumQueries(dns, host_name));
}
2329 
// Sends queries for names the responder has no records for (including the empty name) and
// checks that each still gets a response with rcode 0. The test DNSResponder currently answers
// such queries with an empty (question-only) response.
TEST_F(ResolverTest, Async_BadQuery) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };

    test::DNSResponder dns(listen_addr);
    StartDns(dns, records);
    std::vector<std::string> servers = {listen_addr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Not static: |fd| is written on every run, so this table is per-run state.
    struct {
        int fd;
        const char* dname;
        const int queryType;
        const int expectRcode;
    } testCases[] = {
            {-1, "", ns_t_aaaa, 0},
            {-1, "as65ass46", ns_t_aaaa, 0},
            {-1, "454564564564", ns_t_aaaa, 0},
            {-1, "h645235", ns_t_a, 0},
            {-1, "www.google.com", ns_t_a, 0},
    };

    for (auto& td : testCases) {
        SCOPED_TRACE(td.dname);
        td.fd = resNetworkQuery(TEST_NETID, td.dname, ns_c_in, td.queryType, 0);
        // Any negative value is a failure; waiting on it below would never return.
        ASSERT_GE(td.fd, 0);
    }

    // dns_responder return empty resp(packet only contains query part) with no error currently
    for (const auto& td : testCases) {
        SCOPED_TRACE(td.dname);
        uint8_t buf[MAXPACKET] = {};
        // Initialized so a failed read doesn't leave an indeterminate value to compare.
        int rcode = -1;
        int res = getAsyncResponse(td.fd, &rcode, buf, MAXPACKET);
        EXPECT_GT(res, 0);
        EXPECT_EQ(td.expectRcode, rcode);
    }
}
2372 
TEST_F(ResolverTest,Async_EmptyAnswer)2373 TEST_F(ResolverTest, Async_EmptyAnswer) {
2374     constexpr char listen_addr[] = "127.0.0.4";
2375     constexpr char host_name[] = "howdy.example.com.";
2376     const std::vector<DnsRecord> records = {
2377             {host_name, ns_type::ns_t_a, "1.2.3.4"},
2378             {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
2379     };
2380 
2381     test::DNSResponder dns(listen_addr);
2382     StartDns(dns, records);
2383     std::vector<std::string> servers = {listen_addr};
2384     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
2385 
2386     // TODO: Disable retry to make this test explicit.
2387     auto& cv = dns.getCv();
2388     auto& cvMutex = dns.getCvMutex();
2389     int fd1;
2390     // Wait on the condition variable to ensure that the DNS server has handled our first query.
2391     {
2392         std::unique_lock lk(cvMutex);
2393         fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_aaaa, 0);
2394         EXPECT_TRUE(fd1 != -1);
2395         EXPECT_EQ(std::cv_status::no_timeout, cv.wait_for(lk, std::chrono::seconds(1)));
2396     }
2397 
2398     dns.setResponseProbability(0.0);
2399 
2400     int fd2 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
2401     EXPECT_TRUE(fd2 != -1);
2402 
2403     int fd3 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
2404     EXPECT_TRUE(fd3 != -1);
2405 
2406     uint8_t buf[MAXPACKET] = {};
2407     int rcode;
2408 
2409     // expect no response
2410     int res = getAsyncResponse(fd3, &rcode, buf, MAXPACKET);
2411     EXPECT_EQ(-ETIMEDOUT, res);
2412 
2413     // expect no response
2414     memset(buf, 0, MAXPACKET);
2415     res = getAsyncResponse(fd2, &rcode, buf, MAXPACKET);
2416     EXPECT_EQ(-ETIMEDOUT, res);
2417 
2418     dns.setResponseProbability(1.0);
2419 
2420     int fd4 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
2421     EXPECT_TRUE(fd4 != -1);
2422 
2423     memset(buf, 0, MAXPACKET);
2424     res = getAsyncResponse(fd4, &rcode, buf, MAXPACKET);
2425     EXPECT_GT(res, 0);
2426     EXPECT_EQ("1.2.3.4", toString(buf, res, AF_INET));
2427 
2428     memset(buf, 0, MAXPACKET);
2429     res = getAsyncResponse(fd1, &rcode, buf, MAXPACKET);
2430     EXPECT_GT(res, 0);
2431     EXPECT_EQ("::1.2.3.4", toString(buf, res, AF_INET6));
2432 
2433     // Trailing dot is removed. Is it intended?
2434     ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, 0, "howdy.example.com", {"::1.2.3.4"});
2435     ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, RCODE_TIMEOUT, "howdy.example.com", {});
2436     ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, RCODE_TIMEOUT, "howdy.example.com", {});
2437     ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, 0, "howdy.example.com", {"1.2.3.4"});
2438 }
2439 
TEST_F(ResolverTest,Async_MalformedQuery)2440 TEST_F(ResolverTest, Async_MalformedQuery) {
2441     constexpr char listen_addr[] = "127.0.0.4";
2442     constexpr char host_name[] = "howdy.example.com.";
2443     const std::vector<DnsRecord> records = {
2444             {host_name, ns_type::ns_t_a, "1.2.3.4"},
2445             {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
2446     };
2447 
2448     test::DNSResponder dns(listen_addr);
2449     StartDns(dns, records);
2450     std::vector<std::string> servers = {listen_addr};
2451     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
2452 
2453     int fd = dns_open_proxy();
2454     EXPECT_TRUE(fd > 0);
2455 
2456     const std::string badMsg = "16-52512#";
2457     static const struct {
2458         const std::string cmd;
2459         const int expectErr;
2460     } kTestData[] = {
2461             // Too few arguments
2462             {"resnsend " + badMsg + '\0', -EINVAL},
2463             // Bad netId
2464             {"resnsend badnetId 0 " + badMsg + '\0', -EINVAL},
2465             // Bad raw data
2466             {"resnsend " + std::to_string(TEST_NETID) + " 0 " + badMsg + '\0', -EILSEQ},
2467     };
2468 
2469     for (unsigned int i = 0; i < std::size(kTestData); i++) {
2470         auto& td = kTestData[i];
2471         SCOPED_TRACE(td.cmd);
2472         ssize_t rc = TEMP_FAILURE_RETRY(write(fd, td.cmd.c_str(), td.cmd.size()));
2473         EXPECT_EQ(rc, static_cast<ssize_t>(td.cmd.size()));
2474 
2475         int32_t tmp;
2476         rc = TEMP_FAILURE_RETRY(read(fd, &tmp, sizeof(tmp)));
2477         EXPECT_TRUE(rc > 0);
2478         EXPECT_EQ(static_cast<int>(ntohl(tmp)), td.expectErr);
2479     }
2480     // Normal query with answer buffer
2481     // This is raw data of query "howdy.example.com" type 1 class 1
2482     std::string query = "81sBAAABAAAAAAAABWhvd2R5B2V4YW1wbGUDY29tAAABAAE=";
2483     std::string cmd = "resnsend " + std::to_string(TEST_NETID) + " 0 " + query + '\0';
2484     ssize_t rc = TEMP_FAILURE_RETRY(write(fd, cmd.c_str(), cmd.size()));
2485     EXPECT_EQ(rc, static_cast<ssize_t>(cmd.size()));
2486 
2487     uint8_t smallBuf[1] = {};
2488     int rcode;
2489     rc = getAsyncResponse(fd, &rcode, smallBuf, 1);
2490     EXPECT_EQ(-EMSGSIZE, rc);
2491 
2492     // Do the normal test with large buffer again
2493     fd = dns_open_proxy();
2494     EXPECT_TRUE(fd > 0);
2495     rc = TEMP_FAILURE_RETRY(write(fd, cmd.c_str(), cmd.size()));
2496     EXPECT_EQ(rc, static_cast<ssize_t>(cmd.size()));
2497     uint8_t buf[MAXPACKET] = {};
2498     rc = getAsyncResponse(fd, &rcode, buf, MAXPACKET);
2499     EXPECT_EQ("1.2.3.4", toString(buf, rc, AF_INET));
2500 }
2501 
TEST_F(ResolverTest,Async_CacheFlags)2502 TEST_F(ResolverTest, Async_CacheFlags) {
2503     constexpr char listen_addr[] = "127.0.0.4";
2504     constexpr char host_name1[] = "howdy.example.com.";
2505     constexpr char host_name2[] = "howdy.example2.com.";
2506     constexpr char host_name3[] = "howdy.example3.com.";
2507     const std::vector<DnsRecord> records = {
2508             {host_name1, ns_type::ns_t_a, "1.2.3.4"}, {host_name1, ns_type::ns_t_aaaa, "::1.2.3.4"},
2509             {host_name2, ns_type::ns_t_a, "1.2.3.5"}, {host_name2, ns_type::ns_t_aaaa, "::1.2.3.5"},
2510             {host_name3, ns_type::ns_t_a, "1.2.3.6"}, {host_name3, ns_type::ns_t_aaaa, "::1.2.3.6"},
2511     };
2512 
2513     test::DNSResponder dns(listen_addr);
2514     StartDns(dns, records);
2515     std::vector<std::string> servers = {listen_addr};
2516     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
2517 
2518     // ANDROID_RESOLV_NO_CACHE_STORE
2519     int fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2520                               ANDROID_RESOLV_NO_CACHE_STORE);
2521     EXPECT_TRUE(fd1 != -1);
2522     int fd2 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2523                               ANDROID_RESOLV_NO_CACHE_STORE);
2524     EXPECT_TRUE(fd2 != -1);
2525     int fd3 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2526                               ANDROID_RESOLV_NO_CACHE_STORE);
2527     EXPECT_TRUE(fd3 != -1);
2528 
2529     expectAnswersValid(fd3, AF_INET, "1.2.3.4");
2530     expectAnswersValid(fd2, AF_INET, "1.2.3.4");
2531     expectAnswersValid(fd1, AF_INET, "1.2.3.4");
2532 
2533     // No cache exists, expect 3 queries
2534     EXPECT_EQ(3U, GetNumQueries(dns, host_name1));
2535 
2536     // Raise a query with no flags to ensure no cache exists. Also make an cache entry for the
2537     // query.
2538     fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
2539 
2540     EXPECT_TRUE(fd1 != -1);
2541 
2542     expectAnswersValid(fd1, AF_INET, "1.2.3.4");
2543 
2544     // Expect 4 queries because there should be no cache before this query.
2545     EXPECT_EQ(4U, GetNumQueries(dns, host_name1));
2546 
2547     // Now we have the cache entry, re-query with ANDROID_RESOLV_NO_CACHE_STORE to ensure
2548     // that ANDROID_RESOLV_NO_CACHE_STORE implied ANDROID_RESOLV_NO_CACHE_LOOKUP.
2549     fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2550                           ANDROID_RESOLV_NO_CACHE_STORE);
2551     EXPECT_TRUE(fd1 != -1);
2552     expectAnswersValid(fd1, AF_INET, "1.2.3.4");
2553     // Expect 5 queries because we shouldn't do cache lookup for the query which has
2554     // ANDROID_RESOLV_NO_CACHE_STORE.
2555     EXPECT_EQ(5U, GetNumQueries(dns, host_name1));
2556 
2557     // ANDROID_RESOLV_NO_CACHE_LOOKUP
2558     fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2559                           ANDROID_RESOLV_NO_CACHE_LOOKUP);
2560     fd2 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
2561                           ANDROID_RESOLV_NO_CACHE_LOOKUP);
2562 
2563     EXPECT_TRUE(fd1 != -1);
2564     EXPECT_TRUE(fd2 != -1);
2565 
2566     expectAnswersValid(fd2, AF_INET, "1.2.3.4");
2567     expectAnswersValid(fd1, AF_INET, "1.2.3.4");
2568 
2569     // Cache was skipped, expect 2 more queries.
2570     EXPECT_EQ(7U, GetNumQueries(dns, host_name1));
2571 
2572     // Re-query verify cache works
2573     fd1 = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
2574     EXPECT_TRUE(fd1 != -1);
2575     expectAnswersValid(fd1, AF_INET, "1.2.3.4");
2576 
2577     // Cache hits,  expect still 7 queries
2578     EXPECT_EQ(7U, GetNumQueries(dns, host_name1));
2579 
2580     // Start to verify if ANDROID_RESOLV_NO_CACHE_LOOKUP does write response into cache
2581     dns.clearQueries();
2582 
2583     fd1 = resNetworkQuery(TEST_NETID, "howdy.example2.com", ns_c_in, ns_t_aaaa,
2584                           ANDROID_RESOLV_NO_CACHE_LOOKUP);
2585     fd2 = resNetworkQuery(TEST_NETID, "howdy.example2.com", ns_c_in, ns_t_aaaa,
2586                           ANDROID_RESOLV_NO_CACHE_LOOKUP);
2587 
2588     EXPECT_TRUE(fd1 != -1);
2589     EXPECT_TRUE(fd2 != -1);
2590 
2591     expectAnswersValid(fd2, AF_INET6, "::1.2.3.5");
2592     expectAnswersValid(fd1, AF_INET6, "::1.2.3.5");
2593 
2594     // Skip cache, expect 2 queries
2595     EXPECT_EQ(2U, GetNumQueries(dns, host_name2));
2596 
2597     // Re-query without flags
2598     fd1 = resNetworkQuery(TEST_NETID, "howdy.example2.com", ns_c_in, ns_t_aaaa, 0);
2599     fd2 = resNetworkQuery(TEST_NETID, "howdy.example2.com", ns_c_in, ns_t_aaaa, 0);
2600 
2601     EXPECT_TRUE(fd1 != -1);
2602     EXPECT_TRUE(fd2 != -1);
2603 
2604     expectAnswersValid(fd2, AF_INET6, "::1.2.3.5");
2605     expectAnswersValid(fd1, AF_INET6, "::1.2.3.5");
2606 
2607     // Cache hits, expect still 2 queries
2608     EXPECT_EQ(2U, GetNumQueries(dns, host_name2));
2609 
2610     // Test both ANDROID_RESOLV_NO_CACHE_STORE and ANDROID_RESOLV_NO_CACHE_LOOKUP are set
2611     dns.clearQueries();
2612 
2613     // Make sure that the cache of "howdy.example3.com" exists.
2614     fd1 = resNetworkQuery(TEST_NETID, "howdy.example3.com", ns_c_in, ns_t_aaaa, 0);
2615     EXPECT_TRUE(fd1 != -1);
2616     expectAnswersValid(fd1, AF_INET6, "::1.2.3.6");
2617     EXPECT_EQ(1U, GetNumQueries(dns, host_name3));
2618 
2619     // Re-query with testFlags
2620     const int testFlag = ANDROID_RESOLV_NO_CACHE_STORE | ANDROID_RESOLV_NO_CACHE_LOOKUP;
2621     fd1 = resNetworkQuery(TEST_NETID, "howdy.example3.com", ns_c_in, ns_t_aaaa, testFlag);
2622     EXPECT_TRUE(fd1 != -1);
2623     expectAnswersValid(fd1, AF_INET6, "::1.2.3.6");
2624     // Expect cache lookup is skipped.
2625     EXPECT_EQ(2U, GetNumQueries(dns, host_name3));
2626 
2627     // Do another query with testFlags
2628     fd1 = resNetworkQuery(TEST_NETID, "howdy.example3.com", ns_c_in, ns_t_a, testFlag);
2629     EXPECT_TRUE(fd1 != -1);
2630     expectAnswersValid(fd1, AF_INET, "1.2.3.6");
2631     // Expect cache lookup is skipped.
2632     EXPECT_EQ(3U, GetNumQueries(dns, host_name3));
2633 
2634     // Re-query with no flags
2635     fd1 = resNetworkQuery(TEST_NETID, "howdy.example3.com", ns_c_in, ns_t_a, 0);
2636     EXPECT_TRUE(fd1 != -1);
2637     expectAnswersValid(fd1, AF_INET, "1.2.3.6");
2638     // Expect no cache hit because cache storing is also skipped in previous query.
2639     EXPECT_EQ(4U, GetNumQueries(dns, host_name3));
2640 }
2641 
// Checks that a query sent with ANDROID_RESOLV_NO_CACHE_STORE while the cached answer is stale
// goes to the server but does not refresh the stale cache entry (see b/148842821).
TEST_F(ResolverTest, Async_NoCacheStoreFlagDoesNotRefreshStaleCacheEntry) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
    };

    test::DNSResponder dns(listen_addr);
    StartDns(dns, records);
    std::vector<std::string> servers = {listen_addr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    const unsigned SHORT_TTL_SEC = 1;
    dns.setTtl(SHORT_TTL_SEC);

    // Refer to b/148842821 for the purpose of below test steps.
    // Basically, this test is used to ensure stale cache case is handled
    // correctly with ANDROID_RESOLV_NO_CACHE_STORE.
    int fd = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
    EXPECT_TRUE(fd != -1);
    expectAnswersValid(fd, AF_INET, "1.2.3.4");

    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
    dns.clearQueries();

    // Wait until the cache entry has expired. sleep() only accepts whole seconds, so it cannot
    // express the extra 500 ms margin past the TTL (a fractional argument is truncated). Use a
    // millisecond-precision sleep so the entry is reliably stale before the next query.
    std::this_thread::sleep_for(std::chrono::milliseconds(SHORT_TTL_SEC * 1000 + 500));

    // Now request the same hostname again.
    // We should see a new DNS query because the entry in cache has become stale.
    // Due to ANDROID_RESOLV_NO_CACHE_STORE, this query must *not* refresh that stale entry.
    fd = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a,
                         ANDROID_RESOLV_NO_CACHE_STORE);
    EXPECT_TRUE(fd != -1);
    expectAnswersValid(fd, AF_INET, "1.2.3.4");
    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
    dns.clearQueries();

    // If the cache is still stale, we expect to see one more DNS query
    // (this time the cache will be refreshed, but we're not checking for it).
    fd = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, 0);
    EXPECT_TRUE(fd != -1);
    expectAnswersValid(fd, AF_INET, "1.2.3.4");
    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
}
2687 
// Checks ANDROID_RESOLV_NO_RETRY against two servers that never answer. With the flag, each lookup
// is sent exactly once (to a randomly selected server). Without it, the unanswered lookups are
// retried and reach both servers.
TEST_F(ResolverTest, Async_NoRetryFlag) {
    constexpr char kAddr0[] = "127.0.0.4";
    constexpr char kAddr1[] = "127.0.0.6";
    constexpr char kHostName[] = "howdy.example.com.";
    const std::vector<DnsRecord> kRecords = {
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
            {kHostName, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };

    test::DNSResponder dns0(kAddr0);
    test::DNSResponder dns1(kAddr1);
    StartDns(dns0, kRecords);
    StartDns(dns1, kRecords);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({kAddr0, kAddr1}));

    // Resets the query logs of both servers.
    const auto clearAllQueries = [&dns0, &dns1] {
        dns0.clearQueries();
        dns1.clearQueries();
    };

    // Sends an A and an AAAA query with |flags| and expects both to time out. Each timeout is
    // also expected to produce one RCODE_TIMEOUT res_nsend event.
    const auto queryBothTypesExpectingTimeout = [&](uint32_t flags) {
        int fdA = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_a, flags);
        EXPECT_TRUE(fdA != -1);
        int fdAaaa = resNetworkQuery(TEST_NETID, "howdy.example.com", ns_c_in, ns_t_aaaa, flags);
        EXPECT_TRUE(fdAaaa != -1);

        expectAnswersNotValid(fdA, -ETIMEDOUT);
        expectAnswersNotValid(fdAaaa, -ETIMEDOUT);
        ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, RCODE_TIMEOUT, "howdy.example.com", {});
        ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, RCODE_TIMEOUT, "howdy.example.com", {});
    };

    clearAllQueries();

    // Neither server ever answers.
    dns0.setResponseProbability(0.0);
    dns1.setResponseProbability(0.0);

    queryBothTypesExpectingTimeout(ANDROID_RESOLV_NO_RETRY);
    // No retry: one query per lookup, 2 in total, each sent to a randomly selected server.
    EXPECT_EQ(2U, GetNumQueries(dns0, kHostName) + GetNumQueries(dns1, kHostName));

    clearAllQueries();

    queryBothTypesExpectingTimeout(0);
    // With retries, each server sees 4 queries (8 in total across both lookups).
    EXPECT_EQ(4U, GetNumQueries(dns0, kHostName));
    EXPECT_EQ(4U, GetNumQueries(dns1, kHostName));
}
2745 
TEST_F(ResolverTest,Async_VerifyQueryID)2746 TEST_F(ResolverTest, Async_VerifyQueryID) {
2747     constexpr char listen_addr[] = "127.0.0.4";
2748     constexpr char host_name[] = "howdy.example.com.";
2749     const std::vector<DnsRecord> records = {
2750             {host_name, ns_type::ns_t_a, "1.2.3.4"},
2751             {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
2752     };
2753 
2754     test::DNSResponder dns(listen_addr);
2755     StartDns(dns, records);
2756     std::vector<std::string> servers = {listen_addr};
2757     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
2758 
2759     const uint8_t queryBuf1[] = {
2760             /* Header */
2761             0x55, 0x66, /* Transaction ID */
2762             0x01, 0x00, /* Flags */
2763             0x00, 0x01, /* Questions */
2764             0x00, 0x00, /* Answer RRs */
2765             0x00, 0x00, /* Authority RRs */
2766             0x00, 0x00, /* Additional RRs */
2767             /* Queries */
2768             0x05, 0x68, 0x6f, 0x77, 0x64, 0x79, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
2769             0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
2770             0x00, 0x01,                   /* Type */
2771             0x00, 0x01                    /* Class */
2772     };
2773 
2774     int fd = resNetworkSend(TEST_NETID, queryBuf1, sizeof(queryBuf1), 0);
2775     EXPECT_TRUE(fd != -1);
2776 
2777     uint8_t buf[MAXPACKET] = {};
2778     int rcode;
2779 
2780     int res = getAsyncResponse(fd, &rcode, buf, MAXPACKET);
2781     EXPECT_GT(res, 0);
2782     EXPECT_EQ("1.2.3.4", toString(buf, res, AF_INET));
2783 
2784     auto hp = reinterpret_cast<HEADER*>(buf);
2785     EXPECT_EQ(21862U, htons(hp->id));
2786 
2787     EXPECT_EQ(1U, GetNumQueries(dns, host_name));
2788 
2789     const uint8_t queryBuf2[] = {
2790             /* Header */
2791             0x00, 0x53, /* Transaction ID */
2792             0x01, 0x00, /* Flags */
2793             0x00, 0x01, /* Questions */
2794             0x00, 0x00, /* Answer RRs */
2795             0x00, 0x00, /* Authority RRs */
2796             0x00, 0x00, /* Additional RRs */
2797             /* Queries */
2798             0x05, 0x68, 0x6f, 0x77, 0x64, 0x79, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
2799             0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
2800             0x00, 0x01,                   /* Type */
2801             0x00, 0x01                    /* Class */
2802     };
2803 
2804     // Re-query verify cache works and query id is correct
2805     fd = resNetworkSend(TEST_NETID, queryBuf2, sizeof(queryBuf2), 0);
2806 
2807     EXPECT_TRUE(fd != -1);
2808 
2809     res = getAsyncResponse(fd, &rcode, buf, MAXPACKET);
2810     EXPECT_GT(res, 0);
2811     EXPECT_EQ("1.2.3.4", toString(buf, res, AF_INET));
2812 
2813     EXPECT_EQ(0x0053U, htons(hp->id));
2814 
2815     EXPECT_EQ(1U, GetNumQueries(dns, host_name));
2816 }
2817 
// This test checks that the resolver does not generate requests containing an OPT RR when using
// cleartext DNS. If we query a DNS server that does not support EDNS0 and it responds with
// FORMERR_ON_EDNS, we fall back to no EDNS0 and try again. If the server does not respond at all,
// we won't retry, so we get no answer.
TEST_F(ResolverTest,BrokenEdns)2822 TEST_F(ResolverTest, BrokenEdns) {
2823     typedef test::DNSResponder::Edns Edns;
2824     enum ExpectResult { EXPECT_FAILURE, EXPECT_SUCCESS };
2825 
2826     // Perform cleartext query in off mode.
2827     const char OFF[] = "off";
2828 
2829     // Perform cleartext query when there's no private DNS server validated in opportunistic mode.
2830     const char OPPORTUNISTIC_UDP[] = "opportunistic_udp";
2831 
2832     // Perform cleartext query when there is a private DNS server validated in opportunistic mode.
2833     const char OPPORTUNISTIC_FALLBACK_UDP[] = "opportunistic_fallback_udp";
2834 
2835     // Perform cyphertext query in opportunistic mode.
2836     const char OPPORTUNISTIC_TLS[] = "opportunistic_tls";
2837 
2838     // Perform cyphertext query in strict mode.
2839     const char STRICT[] = "strict";
2840 
2841     const char GETHOSTBYNAME[] = "gethostbyname";
2842     const char GETADDRINFO[] = "getaddrinfo";
2843     const char ADDR4[] = "192.0.2.1";
2844     const char CLEARTEXT_ADDR[] = "127.0.0.53";
2845     const char CLEARTEXT_PORT[] = "53";
2846     const char TLS_PORT[] = "853";
2847     const std::vector<std::string> servers = {CLEARTEXT_ADDR};
2848     ResolverParamsParcel paramsForCleanup = DnsResponderClient::GetDefaultResolverParamsParcel();
2849     paramsForCleanup.servers.clear();
2850     paramsForCleanup.tlsServers.clear();
2851 
2852     test::DNSResponder dns(CLEARTEXT_ADDR, CLEARTEXT_PORT, ns_rcode::ns_r_servfail);
2853     ASSERT_TRUE(dns.startServer());
2854 
2855     test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
2856 
2857     // clang-format off
2858     static const struct TestConfig {
2859         std::string mode;
2860         std::string method;
2861         Edns edns;
2862         ExpectResult expectResult;
2863 
2864         std::string asHostName() const {
2865             const char* ednsString;
2866             switch (edns) {
2867                 case Edns::ON:
2868                     ednsString = "ednsOn";
2869                     break;
2870                 case Edns::FORMERR_ON_EDNS:
2871                     ednsString = "ednsFormerr";
2872                     break;
2873                 case Edns::DROP:
2874                     ednsString = "ednsDrop";
2875                     break;
2876                 default:
2877                     ednsString = "";
2878                     break;
2879             }
2880             return StringPrintf("%s.%s.%s.", mode.c_str(), method.c_str(), ednsString);
2881         }
2882     } testConfigs[] = {
2883             // In OPPORTUNISTIC_TLS, if the DNS server doesn't support EDNS0 but TLS, the lookup
2884             // fails. Could such server exist? if so, we might need to fix it to fallback to
2885             // cleartext query. If the server still make no response for the queries with EDNS0, we
2886             // might also need to fix it to retry without EDNS0.
2887             // Another thing is that {OPPORTUNISTIC_TLS, Edns::DROP} and {STRICT, Edns::DROP} are
2888             // commented out since TLS timeout is not configurable.
2889             // TODO: Uncomment them after TLS timeout is configurable.
2890             {OFF,                        GETHOSTBYNAME, Edns::ON,      EXPECT_SUCCESS},
2891             {OPPORTUNISTIC_UDP,          GETHOSTBYNAME, Edns::ON,      EXPECT_SUCCESS},
2892             {OPPORTUNISTIC_FALLBACK_UDP, GETHOSTBYNAME, Edns::ON,      EXPECT_SUCCESS},
2893             {OPPORTUNISTIC_TLS,          GETHOSTBYNAME, Edns::ON,      EXPECT_SUCCESS},
2894             {STRICT,                     GETHOSTBYNAME, Edns::ON,      EXPECT_SUCCESS},
2895             {OFF,                        GETHOSTBYNAME, Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2896             {OPPORTUNISTIC_UDP,          GETHOSTBYNAME, Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2897             {OPPORTUNISTIC_FALLBACK_UDP, GETHOSTBYNAME, Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2898             {OPPORTUNISTIC_TLS,          GETHOSTBYNAME, Edns::FORMERR_ON_EDNS, EXPECT_FAILURE},
2899             {STRICT,                     GETHOSTBYNAME, Edns::FORMERR_ON_EDNS, EXPECT_FAILURE},
2900             {OFF,                        GETHOSTBYNAME, Edns::DROP,    EXPECT_SUCCESS},
2901             {OPPORTUNISTIC_UDP,          GETHOSTBYNAME, Edns::DROP,    EXPECT_SUCCESS},
2902 
2903             // The failure is due to no retry on timeout. Maybe fix it?
2904             {OPPORTUNISTIC_FALLBACK_UDP, GETHOSTBYNAME, Edns::DROP,    EXPECT_FAILURE},
2905 
2906             //{OPPORTUNISTIC_TLS,        GETHOSTBYNAME, Edns::DROP,    EXPECT_FAILURE},
2907             //{STRICT,                   GETHOSTBYNAME, Edns::DROP,    EXPECT_FAILURE},
2908             {OFF,                        GETADDRINFO,   Edns::ON,      EXPECT_SUCCESS},
2909             {OPPORTUNISTIC_UDP,          GETADDRINFO,   Edns::ON,      EXPECT_SUCCESS},
2910             {OPPORTUNISTIC_FALLBACK_UDP, GETADDRINFO,   Edns::ON,      EXPECT_SUCCESS},
2911             {OPPORTUNISTIC_TLS,          GETADDRINFO,   Edns::ON,      EXPECT_SUCCESS},
2912             {STRICT,                     GETADDRINFO,   Edns::ON,      EXPECT_SUCCESS},
2913             {OFF,                        GETADDRINFO,   Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2914             {OPPORTUNISTIC_UDP,          GETADDRINFO,   Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2915             {OPPORTUNISTIC_FALLBACK_UDP, GETADDRINFO,   Edns::FORMERR_ON_EDNS, EXPECT_SUCCESS},
2916             {OPPORTUNISTIC_TLS,          GETADDRINFO,   Edns::FORMERR_ON_EDNS, EXPECT_FAILURE},
2917             {STRICT,                     GETADDRINFO,   Edns::FORMERR_ON_EDNS, EXPECT_FAILURE},
2918             {OFF,                        GETADDRINFO,   Edns::DROP,    EXPECT_SUCCESS},
2919             {OPPORTUNISTIC_UDP,          GETADDRINFO,   Edns::DROP,    EXPECT_SUCCESS},
2920 
2921             // The failure is due to no retry on timeout. Maybe fix it?
2922             {OPPORTUNISTIC_FALLBACK_UDP, GETADDRINFO,   Edns::DROP,    EXPECT_FAILURE},
2923 
2924             //{OPPORTUNISTIC_TLS, GETADDRINFO,   Edns::DROP,   EXPECT_FAILURE},
2925             //{STRICT,            GETADDRINFO,   Edns::DROP,   EXPECT_FAILURE},
2926     };
2927     // clang-format on
2928 
2929     for (const auto& config : testConfigs) {
2930         const std::string testHostName = config.asHostName();
2931         SCOPED_TRACE(testHostName);
2932 
2933         const char* host_name = testHostName.c_str();
2934         dns.addMapping(host_name, ns_type::ns_t_a, ADDR4);
2935         dns.setEdns(config.edns);
2936 
2937         if (config.mode == OFF) {
2938             if (tls.running()) {
2939                 ASSERT_TRUE(tls.stopServer());
2940             }
2941             ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
2942         } else if (config.mode == OPPORTUNISTIC_UDP) {
2943             if (tls.running()) {
2944                 ASSERT_TRUE(tls.stopServer());
2945             }
2946             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
2947                                                        kDefaultParams, ""));
2948             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
2949         } else if (config.mode == OPPORTUNISTIC_TLS || config.mode == OPPORTUNISTIC_FALLBACK_UDP) {
2950             if (!tls.running()) {
2951                 ASSERT_TRUE(tls.startServer());
2952             }
2953             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
2954                                                        kDefaultParams, ""));
2955             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
2956 
2957             if (config.mode == OPPORTUNISTIC_FALLBACK_UDP) {
2958                 // Force the resolver to fallback to cleartext queries.
2959                 ASSERT_TRUE(tls.stopServer());
2960             }
2961         } else if (config.mode == STRICT) {
2962             if (!tls.running()) {
2963                 ASSERT_TRUE(tls.startServer());
2964             }
2965             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
2966                                                        kDefaultParams, kDefaultPrivateDnsHostName));
2967             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
2968         }
2969 
2970         if (config.method == GETHOSTBYNAME) {
2971             const hostent* h_result = gethostbyname(host_name);
2972             if (config.expectResult == EXPECT_SUCCESS) {
2973                 EXPECT_LE(1U, GetNumQueries(dns, host_name));
2974                 ASSERT_TRUE(h_result != nullptr);
2975                 ASSERT_EQ(4, h_result->h_length);
2976                 ASSERT_FALSE(h_result->h_addr_list[0] == nullptr);
2977                 EXPECT_EQ(ADDR4, ToString(h_result));
2978                 EXPECT_TRUE(h_result->h_addr_list[1] == nullptr);
2979                 ExpectDnsEvent(INetdEventListener::EVENT_GETHOSTBYNAME, 0, host_name, {ADDR4});
2980             } else {
2981                 EXPECT_EQ(0U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
2982                 ASSERT_TRUE(h_result == nullptr);
2983                 ASSERT_EQ(HOST_NOT_FOUND, h_errno);
2984                 int returnCode = (config.edns == Edns::DROP) ? RCODE_TIMEOUT : EAI_FAIL;
2985                 ExpectDnsEvent(INetdEventListener::EVENT_GETHOSTBYNAME, returnCode, host_name, {});
2986             }
2987         } else if (config.method == GETADDRINFO) {
2988             ScopedAddrinfo ai_result;
2989             addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
2990             ai_result = safe_getaddrinfo(host_name, nullptr, &hints);
2991             if (config.expectResult == EXPECT_SUCCESS) {
2992                 EXPECT_TRUE(ai_result != nullptr);
2993                 EXPECT_EQ(1U, GetNumQueries(dns, host_name));
2994                 const std::string result_str = ToString(ai_result);
2995                 EXPECT_EQ(ADDR4, result_str);
2996                 ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, 0, host_name, {ADDR4});
2997             } else {
2998                 EXPECT_TRUE(ai_result == nullptr);
2999                 EXPECT_EQ(0U, GetNumQueries(dns, host_name));
3000                 int returnCode = (config.edns == Edns::DROP) ? RCODE_TIMEOUT : EAI_FAIL;
3001                 ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, returnCode, host_name, {});
3002             }
3003         } else {
3004             FAIL() << "Unsupported query method: " << config.method;
3005         }
3006 
3007         tls.clearQueries();
3008         dns.clearQueries();
3009 
3010         // Clear the setup to force the resolver to validate private DNS servers in every test.
3011         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(paramsForCleanup));
3012     }
3013 }
3014 
// DNS-over-TLS validation succeeds, but the server stops responding to TLS queries after a while.
// The resolver should make a reasonable number of retries instead of spinning forever. We don't
// have an efficient way to detect whether the resolver is stuck in an infinite loop; however, in
// that case the test case will fail due to a timeout.
TEST_F(ResolverTest, UnstableTls) {
    const char CLEARTEXT_ADDR[] = "127.0.0.53";
    const char CLEARTEXT_PORT[] = "53";
    const char TLS_PORT[] = "853";
    const char* host_name1 = "nonexistent1.example.com.";
    const char* host_name2 = "nonexistent2.example.com.";
    const std::vector<std::string> servers = {CLEARTEXT_ADDR};

    test::DNSResponder dns(CLEARTEXT_ADDR, CLEARTEXT_PORT, ns_rcode::ns_r_servfail);
    ASSERT_TRUE(dns.startServer());
    dns.setEdns(test::DNSResponder::Edns::FORMERR_ON_EDNS);
    test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));

    // Shut down the TLS server to get an error. It's similar to the no-response case but without
    // waiting. If the server failed to stop, the rest of the test would exercise the wrong path,
    // so treat that as fatal.
    ASSERT_TRUE(tls.stopServer());

    const hostent* h_result = gethostbyname(host_name1);
    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
    ASSERT_TRUE(h_result == nullptr);
    ASSERT_EQ(HOST_NOT_FOUND, h_errno);

    addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
    ScopedAddrinfo ai_result = safe_getaddrinfo(host_name2, nullptr, &hints);
    EXPECT_TRUE(ai_result == nullptr);
    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
}
3048 
// DNS-over-TLS validation succeeds, but the server stops responding to TLS queries after a while.
// Moreover, the server responds with RCODE=FORMERR even to non-EDNS queries.
TEST_F(ResolverTest, BogusDnsServer) {
    const char CLEARTEXT_ADDR[] = "127.0.0.53";
    const char CLEARTEXT_PORT[] = "53";
    const char TLS_PORT[] = "853";
    const char* host_name1 = "nonexistent1.example.com.";
    const char* host_name2 = "nonexistent2.example.com.";
    const std::vector<std::string> servers = {CLEARTEXT_ADDR};

    test::DNSResponder dns(CLEARTEXT_ADDR, CLEARTEXT_PORT, ns_rcode::ns_r_servfail);
    ASSERT_TRUE(dns.startServer());
    test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
    ASSERT_TRUE(tls.startServer());
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));

    // Shut down the TLS server to get an error. It's similar to the no-response case but without
    // waiting. If the server failed to stop, the rest of the test would exercise the wrong path,
    // so treat that as fatal.
    ASSERT_TRUE(tls.stopServer());
    // From now on the cleartext server answers FORMERR unconditionally, with or without EDNS0.
    dns.setEdns(test::DNSResponder::Edns::FORMERR_UNCOND);

    const hostent* h_result = gethostbyname(host_name1);
    EXPECT_EQ(0U, GetNumQueries(dns, host_name1));
    ASSERT_TRUE(h_result == nullptr);
    ASSERT_EQ(HOST_NOT_FOUND, h_errno);

    addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
    ScopedAddrinfo ai_result = safe_getaddrinfo(host_name2, nullptr, &hints);
    EXPECT_TRUE(ai_result == nullptr);
    EXPECT_EQ(0U, GetNumQueries(dns, host_name2));
}
3080 
// Checks that, with a discovered NAT64 prefix, an AF_UNSPEC lookup of an IPv4-only name returns
// an address synthesized from that prefix. Once prefix discovery stops, the plain IPv4 answer is
// returned instead.
TEST_F(ResolverTest, GetAddrInfo_Dns64Synthesize) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kHostName[] = "v4only.example.com.";
    const std::vector<DnsRecord> kRecords = {
            {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
    };

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, kRecords);

    const std::vector<std::string> servers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Kick off NAT64 prefix discovery and block until a prefix is known.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // Hints are required so that netd knows which address families the caller wants.
    const addrinfo hints = {.ai_family = AF_UNSPEC};

    ScopedAddrinfo result = safe_getaddrinfo("v4only", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    // TODO: BUG: there should only be two queries, one AAAA (which returns no records) and one A
    // (which returns 1.2.3.4). But there is an extra AAAA.
    EXPECT_EQ(3U, GetNumQueries(dns, kHostName));
    EXPECT_EQ("64:ff9b::102:304", ToString(result));

    // Once prefix discovery is stopped, no address is synthesized any more.
    EXPECT_TRUE(mDnsClient.resolvService()->stopPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_NOT_FOUND));

    dns.clearQueries();

    result = safe_getaddrinfo("v4only", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    // TODO: BUG: there should only be one query, an AAAA (which returns no records), because the
    // A is already cached. But there is an extra AAAA.
    EXPECT_EQ(2U, GetNumQueries(dns, kHostName));
    EXPECT_EQ("1.2.3.4", ToString(result));
}
3127 
// Checks that an explicit address family controls DNS64 synthesis. AF_INET6 gets a synthesized
// AAAA address; AF_INET gets the original A record.
TEST_F(ResolverTest, GetAddrInfo_Dns64QuerySpecified) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kHostName[] = "v4only.example.com.";
    const std::vector<DnsRecord> kRecords = {
            {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
    };

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, kRecords);
    const std::vector<std::string> servers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Kick off NAT64 prefix discovery and block until a prefix is known.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // AF_INET6: the A answer must be turned into a synthesized AAAA address.
    addrinfo hints = {.ai_family = AF_INET6};
    ScopedAddrinfo result = safe_getaddrinfo("v4only", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ("64:ff9b::102:304", ToString(result));

    // AF_INET: the A answer must be returned as-is, without synthesis.
    hints.ai_family = AF_INET;
    result = safe_getaddrinfo("v4only", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_LE(2U, GetNumQueries(dns, kHostName));
    EXPECT_EQ("1.2.3.4", ToString(result));
}
3163 
// Checks that an AF_UNSPEC lookup of a name that has a native AAAA record does not produce a
// DNS64-synthesized address, even with a NAT64 prefix discovered.
TEST_F(ResolverTest, GetAddrInfo_Dns64QueryUnspecifiedV6) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kHostName[] = "v4v6.example.com.";
    const std::vector<DnsRecord> kRecords = {
            {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
            {kHostName, ns_type::ns_t_aaaa, "2001:db8::1.2.3.4"},
    };

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, kRecords);
    const std::vector<std::string> servers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Kick off NAT64 prefix discovery and block until a prefix is known.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    const addrinfo hints = {.ai_family = AF_UNSPEC};
    ScopedAddrinfo result = safe_getaddrinfo("v4v6", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_LE(2U, GetNumQueries(dns, kHostName));

    // Since a native AAAA answer exists, every returned address must be either the A record or
    // the native AAAA record, never a synthesized one.
    const std::vector<std::string> addrs = ToStrings(result);
    for (const auto& addr : addrs) {
        EXPECT_TRUE(addr == "1.2.3.4" || addr == "2001:db8::102:304")
                << ", result_str='" << addr << "'";
    }
}
3195 
// Checks that an AF_UNSPEC lookup of a name with only an A record yields a DNS64-synthesized
// address when a NAT64 prefix has been discovered.
TEST_F(ResolverTest, GetAddrInfo_Dns64QueryUnspecifiedNoV6) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kHostName[] = "v4v6.example.com.";
    const std::vector<DnsRecord> kRecords = {
            {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
    };

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, kRecords);
    const std::vector<std::string> servers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Kick off NAT64 prefix discovery and block until a prefix is known.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    const addrinfo hints = {.ai_family = AF_UNSPEC};
    ScopedAddrinfo result = safe_getaddrinfo("v4v6", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_LE(2U, GetNumQueries(dns, kHostName));

    // No AAAA answer exists, so the A answer is synthesized into an IPv6 address.
    EXPECT_EQ("64:ff9b::102:304", ToString(result));
}
3223 
TEST_F(ResolverTest,GetAddrInfo_Dns64QuerySpecialUseIPv4Addresses)3224 TEST_F(ResolverTest, GetAddrInfo_Dns64QuerySpecialUseIPv4Addresses) {
3225     constexpr char THIS_NETWORK[] = "this_network";
3226     constexpr char LOOPBACK[] = "loopback";
3227     constexpr char LINK_LOCAL[] = "link_local";
3228     constexpr char MULTICAST[] = "multicast";
3229     constexpr char LIMITED_BROADCAST[] = "limited_broadcast";
3230 
3231     constexpr char ADDR_THIS_NETWORK[] = "0.0.0.1";
3232     constexpr char ADDR_LOOPBACK[] = "127.0.0.1";
3233     constexpr char ADDR_LINK_LOCAL[] = "169.254.0.1";
3234     constexpr char ADDR_MULTICAST[] = "224.0.0.1";
3235     constexpr char ADDR_LIMITED_BROADCAST[] = "255.255.255.255";
3236 
3237     constexpr char listen_addr[] = "::1";
3238     constexpr char dns64_name[] = "ipv4only.arpa.";
3239 
3240     test::DNSResponder dns(listen_addr);
3241     StartDns(dns, {{dns64_name, ns_type::ns_t_aaaa, "64:ff9b::"}});
3242     const std::vector<std::string> servers = {listen_addr};
3243     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3244 
3245     // Start NAT64 prefix discovery and wait for it to complete.
3246     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3247     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3248 
3249     // clang-format off
3250     static const struct TestConfig {
3251         std::string name;
3252         std::string addr;
3253 
3254         std::string asHostName() const { return StringPrintf("%s.example.com.", name.c_str()); }
3255     } testConfigs[]{
3256         {THIS_NETWORK,      ADDR_THIS_NETWORK},
3257         {LOOPBACK,          ADDR_LOOPBACK},
3258         {LINK_LOCAL,        ADDR_LINK_LOCAL},
3259         {MULTICAST,         ADDR_MULTICAST},
3260         {LIMITED_BROADCAST, ADDR_LIMITED_BROADCAST}
3261     };
3262     // clang-format on
3263 
3264     for (const auto& config : testConfigs) {
3265         const std::string testHostName = config.asHostName();
3266         SCOPED_TRACE(testHostName);
3267 
3268         const char* host_name = testHostName.c_str();
3269         dns.addMapping(host_name, ns_type::ns_t_a, config.addr.c_str());
3270 
3271         addrinfo hints;
3272         memset(&hints, 0, sizeof(hints));
3273         hints.ai_family = AF_INET6;
3274         ScopedAddrinfo result = safe_getaddrinfo(config.name.c_str(), nullptr, &hints);
3275         // In AF_INET6 case, don't return IPv4 answers
3276         EXPECT_TRUE(result == nullptr);
3277         EXPECT_LE(2U, GetNumQueries(dns, host_name));
3278         dns.clearQueries();
3279 
3280         memset(&hints, 0, sizeof(hints));
3281         hints.ai_family = AF_UNSPEC;
3282         result = safe_getaddrinfo(config.name.c_str(), nullptr, &hints);
3283         EXPECT_TRUE(result != nullptr);
3284         // Expect IPv6 query only. IPv4 answer has been cached in previous query.
3285         EXPECT_LE(1U, GetNumQueries(dns, host_name));
3286         // In AF_UNSPEC case, don't synthesize special use IPv4 address.
3287         std::string result_str = ToString(result);
3288         EXPECT_EQ(result_str, config.addr.c_str());
3289         dns.clearQueries();
3290     }
3291 }
3292 
// getaddrinfo() with a null |hints| argument must behave like ai_family = AF_UNSPEC under DNS64:
// an AAAA answer is synthesized from the NAT64 prefix only when the name has A records and no AAAA
// records of its own.
TEST_F(ResolverTest, GetAddrInfo_Dns64QueryWithNullArgumentHints) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kV4OnlyName[] = "v4only.example.com.";
    constexpr char kDualStackName[] = "v4v6.example.com.";

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, {
                          {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
                          {kV4OnlyName, ns_type::ns_t_a, "1.2.3.4"},
                          {kDualStackName, ns_type::ns_t_a, "1.2.3.4"},
                          {kDualStackName, ns_type::ns_t_aaaa, "2001:db8::1.2.3.4"},
                  });
    const std::vector<std::string> nameservers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(nameservers));

    // Discover the NAT64 prefix (served by the ipv4only.arpa. AAAA record above) before any lookup.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // A-only name: the answer is the synthesized 64:ff9b::1.2.3.4.
    ScopedAddrinfo addrs = safe_getaddrinfo("v4only", nullptr, nullptr);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_LE(2U, GetNumQueries(dns, kV4OnlyName));
    EXPECT_EQ(ToString(addrs), "64:ff9b::102:304");
    dns.clearQueries();

    // Dual-stack name: nothing is synthesized, so only the served A/AAAA addresses may come back.
    addrs = safe_getaddrinfo("v4v6", nullptr, nullptr);
    EXPECT_TRUE(addrs != nullptr);
    EXPECT_LE(2U, GetNumQueries(dns, kDualStackName));
    for (const auto& addr : ToStrings(addrs)) {
        EXPECT_TRUE(addr == "1.2.3.4" || addr == "2001:db8::102:304")
                << ", result_str='" << addr << "'";
    }
}
3333 
TEST_F(ResolverTest,GetAddrInfo_Dns64QueryNullArgumentNode)3334 TEST_F(ResolverTest, GetAddrInfo_Dns64QueryNullArgumentNode) {
3335     constexpr char ADDR_ANYADDR_V4[] = "0.0.0.0";
3336     constexpr char ADDR_ANYADDR_V6[] = "::";
3337     constexpr char ADDR_LOCALHOST_V4[] = "127.0.0.1";
3338     constexpr char ADDR_LOCALHOST_V6[] = "::1";
3339 
3340     constexpr char PORT_NAME_HTTP[] = "http";
3341     constexpr char PORT_NUMBER_HTTP[] = "80";
3342 
3343     constexpr char listen_addr[] = "::1";
3344     constexpr char dns64_name[] = "ipv4only.arpa.";
3345 
3346     test::DNSResponder dns(listen_addr);
3347     StartDns(dns, {{dns64_name, ns_type::ns_t_aaaa, "64:ff9b::"}});
3348     const std::vector<std::string> servers = {listen_addr};
3349     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3350 
3351     // Start NAT64 prefix discovery and wait for it to complete.
3352     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3353     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3354 
3355     // clang-format off
3356     // If node is null, return address is listed by libc/getaddrinfo.c as follows.
3357     // - passive socket -> anyaddr (0.0.0.0 or ::)
3358     // - non-passive socket -> localhost (127.0.0.1 or ::1)
3359     static const struct TestConfig {
3360         int flag;
3361         std::string addr_v4;
3362         std::string addr_v6;
3363 
3364         std::string asParameters() const {
3365             return StringPrintf("flag=%d, addr_v4=%s, addr_v6=%s", flag, addr_v4.c_str(),
3366                                 addr_v6.c_str());
3367         }
3368     } testConfigs[]{
3369         {0 /* non-passive */, ADDR_LOCALHOST_V4, ADDR_LOCALHOST_V6},
3370         {AI_PASSIVE,          ADDR_ANYADDR_V4,   ADDR_ANYADDR_V6}
3371     };
3372     // clang-format on
3373 
3374     for (const auto& config : testConfigs) {
3375         SCOPED_TRACE(config.asParameters());
3376 
3377         addrinfo hints = {
3378                 .ai_flags = config.flag,
3379                 .ai_family = AF_UNSPEC,  // any address family
3380                 .ai_socktype = 0,        // any type
3381                 .ai_protocol = 0,        // any protocol
3382         };
3383 
3384         // Assign hostname as null and service as port name.
3385         ScopedAddrinfo result = safe_getaddrinfo(nullptr, PORT_NAME_HTTP, &hints);
3386         ASSERT_TRUE(result != nullptr);
3387 
3388         // Can't be synthesized because it should not get into Netd.
3389         std::vector<std::string> result_strs = ToStrings(result);
3390         for (const auto& str : result_strs) {
3391             EXPECT_TRUE(str == config.addr_v4 || str == config.addr_v6)
3392                     << ", result_str='" << str << "'";
3393         }
3394 
3395         // Assign hostname as null and service as numeric port number.
3396         hints.ai_flags = config.flag | AI_NUMERICSERV;
3397         result = safe_getaddrinfo(nullptr, PORT_NUMBER_HTTP, &hints);
3398         ASSERT_TRUE(result != nullptr);
3399 
3400         // Can't be synthesized because it should not get into Netd.
3401         result_strs = ToStrings(result);
3402         for (const auto& str : result_strs) {
3403             EXPECT_TRUE(str == config.addr_v4 || str == config.addr_v6)
3404                     << ", result_str='" << str << "'";
3405         }
3406     }
3407 }
3408 
// Reverse gethostbyaddr() lookups of ordinary IPv4 and IPv6 addresses must ignore a discovered
// NAT64 prefix.
TEST_F(ResolverTest, GetHostByAddr_ReverseDnsQueryWithHavingNat64Prefix) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kPtrName[] = "v4v6.example.com.";
    // PTR owner name for 1.2.3.4.
    constexpr char kPtrAddrV4[] = "4.3.2.1.in-addr.arpa.";
    // PTR owner name for 2001:db8::102:304.
    constexpr char kPtrAddrV6[] =
            "4.0.3.0.2.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.";

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, {
                          {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
                          {kPtrAddrV4, ns_type::ns_t_ptr, kPtrName},
                          {kPtrAddrV6, ns_type::ns_t_ptr, kPtrName},
                  });
    const std::vector<std::string> nameservers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(nameservers));

    // Discover the NAT64 prefix before any lookup.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // The entry's h_name, or "null" when it is unset.
    const auto nameOf = [](const hostent* he) {
        return std::string(he->h_name ? he->h_name : "null");
    };

    // IPv4 reverse lookup.
    in_addr v4addr;
    inet_pton(AF_INET, "1.2.3.4", &v4addr);
    hostent* he = gethostbyaddr(&v4addr, sizeof(v4addr), AF_INET);
    ASSERT_TRUE(he != nullptr);
    EXPECT_EQ(nameOf(he), "v4v6.example.com");

    // IPv6 reverse lookup.
    in6_addr v6addr;
    inet_pton(AF_INET6, "2001:db8::102:304", &v6addr);
    he = gethostbyaddr(&v6addr, sizeof(v6addr), AF_INET6);
    ASSERT_TRUE(he != nullptr);
    EXPECT_EQ(nameOf(he), "v4v6.example.com");
}
3451 
TEST_F(ResolverTest,GetHostByAddr_ReverseDns64Query)3452 TEST_F(ResolverTest, GetHostByAddr_ReverseDns64Query) {
3453     constexpr char listen_addr[] = "::1";
3454     constexpr char dns64_name[] = "ipv4only.arpa.";
3455     constexpr char ptr_name[] = "v4only.example.com.";
3456     // PTR record for IPv4 address 1.2.3.4
3457     constexpr char ptr_addr_v4[] = "4.3.2.1.in-addr.arpa.";
3458     // PTR record for IPv6 address 64:ff9b::1.2.3.4
3459     constexpr char ptr_addr_v6_nomapping[] =
3460             "4.0.3.0.2.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.b.9.f.f.4.6.0.0.ip6.arpa.";
3461     constexpr char ptr_name_v6_synthesis[] = "v6synthesis.example.com.";
3462     // PTR record for IPv6 address 64:ff9b::5.6.7.8
3463     constexpr char ptr_addr_v6_synthesis[] =
3464             "8.0.7.0.6.0.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.b.9.f.f.4.6.0.0.ip6.arpa.";
3465     const std::vector<DnsRecord> records = {
3466             {dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
3467             {ptr_addr_v4, ns_type::ns_t_ptr, ptr_name},
3468             {ptr_addr_v6_synthesis, ns_type::ns_t_ptr, ptr_name_v6_synthesis},
3469     };
3470 
3471     test::DNSResponder dns(listen_addr);
3472     StartDns(dns, records);
3473     // "ptr_addr_v6_nomapping" is not mapped in DNS server
3474     const std::vector<std::string> servers = {listen_addr};
3475     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3476 
3477     // Start NAT64 prefix discovery and wait for it to complete.
3478     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3479     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3480 
3481     // Synthesized PTR record doesn't exist on DNS server
3482     // Reverse IPv6 DNS64 query while DNS server doesn't have an answer for synthesized address.
3483     // After querying synthesized address failed, expect that prefix is removed from IPv6
3484     // synthesized address and do reverse IPv4 query instead.
3485     struct in6_addr v6addr;
3486     inet_pton(AF_INET6, "64:ff9b::1.2.3.4", &v6addr);
3487     struct hostent* result = gethostbyaddr(&v6addr, sizeof(v6addr), AF_INET6);
3488     ASSERT_TRUE(result != nullptr);
3489     EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v6_nomapping));  // PTR record not exist
3490     EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v4));            // PTR record exist
3491     std::string result_str = result->h_name ? result->h_name : "null";
3492     EXPECT_EQ(result_str, "v4only.example.com");
3493     // Check that return address has been mapped from IPv4 to IPv6 address because Netd
3494     // removes NAT64 prefix and does IPv4 DNS reverse lookup in this case. Then, Netd
3495     // fakes the return IPv4 address as original queried IPv6 address.
3496     result_str = ToString(result);
3497     EXPECT_EQ(result_str, "64:ff9b::102:304");
3498     dns.clearQueries();
3499 
3500     // Synthesized PTR record exists on DNS server
3501     // Reverse IPv6 DNS64 query while DNS server has an answer for synthesized address.
3502     // Expect to Netd pass through synthesized address for DNS queries.
3503     inet_pton(AF_INET6, "64:ff9b::5.6.7.8", &v6addr);
3504     result = gethostbyaddr(&v6addr, sizeof(v6addr), AF_INET6);
3505     ASSERT_TRUE(result != nullptr);
3506     EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v6_synthesis));
3507     result_str = result->h_name ? result->h_name : "null";
3508     EXPECT_EQ(result_str, "v6synthesis.example.com");
3509 }
3510 
// A reverse gethostbyaddr() lookup of a NAT64-synthesized address whose embedded IPv4 address is
// listed in /etc/hosts must be answered from the hosts file. It must report the original IPv6
// address and "localhost" as the name.
TEST_F(ResolverTest, GetHostByAddr_ReverseDns64QueryFromHostFile) {
    constexpr char dns64_name[] = "ipv4only.arpa.";
    constexpr char host_name[] = "localhost";
    // The address is synthesized by prefix64:localhost.
    constexpr char host_addr[] = "64:ff9b::7f00:1";
    // PTR owner name of 127.0.0.1, the IPv4 address embedded in |host_addr|. A lookup of the
    // de-prefixed address that reached the DNS server would query this name.
    constexpr char ptr_addr_v4[] = "1.0.0.127.in-addr.arpa.";
    constexpr char listen_addr[] = "::1";

    test::DNSResponder dns(listen_addr);
    StartDns(dns, {{dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"}});
    const std::vector<std::string> servers = {listen_addr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Start NAT64 prefix discovery and wait for it to complete.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // Using synthesized "localhost" address to be a trick for resolving host name
    // from host file /etc/hosts and "localhost" is the only name in /etc/hosts. Note that this is
    // not realistic: the code never synthesizes AAAA records for addresses in 127.0.0.0/8.
    struct in6_addr v6addr;
    ASSERT_EQ(1, inet_pton(AF_INET6, host_addr, &v6addr));
    struct hostent* result = gethostbyaddr(&v6addr, sizeof(v6addr), AF_INET6);
    ASSERT_TRUE(result != nullptr);
    // Reverse lookups send PTR queries, never a query for |host_name| itself, so counting
    // |host_name| queries alone cannot detect a DNS round trip. The synthesized IPv6 PTR name may
    // still be queried first because no record exists for it. The de-prefixed 127.0.0.1, however,
    // must be resolved from /etc/hosts, so its PTR name must never reach the DNS server.
    EXPECT_EQ(0U, GetNumQueries(dns, host_name));
    EXPECT_EQ(0U, GetNumQueries(dns, ptr_addr_v4));

    ASSERT_EQ(sizeof(in6_addr), static_cast<size_t>(result->h_length));
    ASSERT_EQ(AF_INET6, result->h_addrtype);
    std::string result_str = ToString(result);
    EXPECT_EQ(result_str, host_addr);
    result_str = result->h_name ? result->h_name : "null";
    EXPECT_EQ(result_str, host_name);
}
3544 
// gethostbyaddr() must follow the CNAMEs of an RFC 2317 (section 4) classless reverse delegation,
// for both the '/' and the '-' zone-naming notations.
TEST_F(ResolverTest, GetHostByAddr_cnamesClasslessReverseDelegation) {
    // One address per notation: 192.0.2.0/25 uses '/', 192.0.3.0-127 uses '-'.
    constexpr char kSlashAddr[] = "192.0.2.1";
    constexpr char kHyphenAddr[] = "192.0.3.1";

    static const std::vector<DnsRecord> kDelegationRecords = {
            // '/' notation: 1.2.0.192.in-addr.arpa. -> CNAME -> PTR.
            {"1.2.0.192.in-addr.arpa.", ns_type::ns_t_cname, "1.0/25.2.0.192.in-addr.arpa."},
            {"1.0/25.2.0.192.in-addr.arpa.", ns_type::ns_t_ptr, kHelloExampleCom},

            // '-' notation: 1.3.0.192.in-addr.arpa. -> CNAME -> PTR.
            {"1.3.0.192.in-addr.arpa.", ns_type::ns_t_cname, "1.0-127.3.0.192.in-addr.arpa."},
            {"1.0-127.3.0.192.in-addr.arpa.", ns_type::ns_t_ptr, kHelloExampleCom},
    };

    test::DNSResponder dns;
    StartDns(dns, kDelegationRecords);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    for (const char* addr : {kSlashAddr, kHyphenAddr}) {
        SCOPED_TRACE(addr);

        in_addr v4addr;
        ASSERT_TRUE(inet_pton(AF_INET, addr, &v4addr));
        hostent* he = gethostbyaddr(&v4addr, sizeof(v4addr), AF_INET);
        ASSERT_TRUE(he != nullptr);
        EXPECT_STREQ("hello.example.com", he->h_name);
    }
}
3576 
TEST_F(ResolverTest,GetNameInfo_ReverseDnsQueryWithHavingNat64Prefix)3577 TEST_F(ResolverTest, GetNameInfo_ReverseDnsQueryWithHavingNat64Prefix) {
3578     constexpr char listen_addr[] = "::1";
3579     constexpr char dns64_name[] = "ipv4only.arpa.";
3580     constexpr char ptr_name[] = "v4v6.example.com.";
3581     // PTR record for IPv4 address 1.2.3.4
3582     constexpr char ptr_addr_v4[] = "4.3.2.1.in-addr.arpa.";
3583     // PTR record for IPv6 address 2001:db8::102:304
3584     constexpr char ptr_addr_v6[] =
3585             "4.0.3.0.2.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.";
3586     const std::vector<DnsRecord> records = {
3587             {dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
3588             {ptr_addr_v4, ns_type::ns_t_ptr, ptr_name},
3589             {ptr_addr_v6, ns_type::ns_t_ptr, ptr_name},
3590     };
3591 
3592     test::DNSResponder dns(listen_addr);
3593     StartDns(dns, records);
3594     const std::vector<std::string> servers = {listen_addr};
3595     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3596 
3597     // Start NAT64 prefix discovery and wait for it to complete.
3598     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3599     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3600 
3601     // clang-format off
3602     static const struct TestConfig {
3603         int flag;
3604         int family;
3605         std::string addr;
3606         std::string host;
3607 
3608         std::string asParameters() const {
3609             return StringPrintf("flag=%d, family=%d, addr=%s, host=%s", flag, family, addr.c_str(),
3610                                 host.c_str());
3611         }
3612     } testConfigs[]{
3613         {NI_NAMEREQD,    AF_INET,  "1.2.3.4",           "v4v6.example.com"},
3614         {NI_NUMERICHOST, AF_INET,  "1.2.3.4",           "1.2.3.4"},
3615         {0,              AF_INET,  "1.2.3.4",           "v4v6.example.com"},
3616         {0,              AF_INET,  "5.6.7.8",           "5.6.7.8"},           // unmapped
3617         {NI_NAMEREQD,    AF_INET6, "2001:db8::102:304", "v4v6.example.com"},
3618         {NI_NUMERICHOST, AF_INET6, "2001:db8::102:304", "2001:db8::102:304"},
3619         {0,              AF_INET6, "2001:db8::102:304", "v4v6.example.com"},
3620         {0,              AF_INET6, "2001:db8::506:708", "2001:db8::506:708"}, // unmapped
3621     };
3622     // clang-format on
3623 
3624     // Reverse IPv4/IPv6 DNS query. Prefix should have no effect on it.
3625     for (const auto& config : testConfigs) {
3626         SCOPED_TRACE(config.asParameters());
3627 
3628         int rv;
3629         char host[NI_MAXHOST];
3630         struct sockaddr_in sin;
3631         struct sockaddr_in6 sin6;
3632         if (config.family == AF_INET) {
3633             memset(&sin, 0, sizeof(sin));
3634             sin.sin_family = AF_INET;
3635             inet_pton(AF_INET, config.addr.c_str(), &sin.sin_addr);
3636             rv = getnameinfo((const struct sockaddr*)&sin, sizeof(sin), host, sizeof(host), nullptr,
3637                              0, config.flag);
3638             if (config.flag == NI_NAMEREQD) EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v4));
3639         } else if (config.family == AF_INET6) {
3640             memset(&sin6, 0, sizeof(sin6));
3641             sin6.sin6_family = AF_INET6;
3642             inet_pton(AF_INET6, config.addr.c_str(), &sin6.sin6_addr);
3643             rv = getnameinfo((const struct sockaddr*)&sin6, sizeof(sin6), host, sizeof(host),
3644                              nullptr, 0, config.flag);
3645             if (config.flag == NI_NAMEREQD) EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v6));
3646         }
3647         ASSERT_EQ(0, rv);
3648         std::string result_str = host;
3649         EXPECT_EQ(result_str, config.host);
3650         dns.clearQueries();
3651     }
3652 }
3653 
TEST_F(ResolverTest,GetNameInfo_ReverseDns64Query)3654 TEST_F(ResolverTest, GetNameInfo_ReverseDns64Query) {
3655     constexpr char listen_addr[] = "::1";
3656     constexpr char dns64_name[] = "ipv4only.arpa.";
3657     constexpr char ptr_name[] = "v4only.example.com.";
3658     // PTR record for IPv4 address 1.2.3.4
3659     constexpr char ptr_addr_v4[] = "4.3.2.1.in-addr.arpa.";
3660     // PTR record for IPv6 address 64:ff9b::1.2.3.4
3661     constexpr char ptr_addr_v6_nomapping[] =
3662             "4.0.3.0.2.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.b.9.f.f.4.6.0.0.ip6.arpa.";
3663     constexpr char ptr_name_v6_synthesis[] = "v6synthesis.example.com.";
3664     // PTR record for IPv6 address 64:ff9b::5.6.7.8
3665     constexpr char ptr_addr_v6_synthesis[] =
3666             "8.0.7.0.6.0.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.b.9.f.f.4.6.0.0.ip6.arpa.";
3667     const std::vector<DnsRecord> records = {
3668             {dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
3669             {ptr_addr_v4, ns_type::ns_t_ptr, ptr_name},
3670             {ptr_addr_v6_synthesis, ns_type::ns_t_ptr, ptr_name_v6_synthesis},
3671     };
3672 
3673     test::DNSResponder dns(listen_addr);
3674     StartDns(dns, records);
3675     const std::vector<std::string> servers = {listen_addr};
3676     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3677 
3678     // Start NAT64 prefix discovery and wait for it to complete.
3679     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3680     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3681 
3682     // clang-format off
3683     static const struct TestConfig {
3684         bool hasSynthesizedPtrRecord;
3685         int flag;
3686         std::string addr;
3687         std::string host;
3688 
3689         std::string asParameters() const {
3690             return StringPrintf("hasSynthesizedPtrRecord=%d, flag=%d, addr=%s, host=%s",
3691                                 hasSynthesizedPtrRecord, flag, addr.c_str(), host.c_str());
3692         }
3693     } testConfigs[]{
3694         {false, NI_NAMEREQD,    "64:ff9b::102:304", "v4only.example.com"},
3695         {false, NI_NUMERICHOST, "64:ff9b::102:304", "64:ff9b::102:304"},
3696         {false, 0,              "64:ff9b::102:304", "v4only.example.com"},
3697         {true,  NI_NAMEREQD,    "64:ff9b::506:708", "v6synthesis.example.com"},
3698         {true,  NI_NUMERICHOST, "64:ff9b::506:708", "64:ff9b::506:708"},
3699         {true,  0,              "64:ff9b::506:708", "v6synthesis.example.com"}
3700     };
3701     // clang-format on
3702 
3703     // hasSynthesizedPtrRecord = false
3704     //   Synthesized PTR record doesn't exist on DNS server
3705     //   Reverse IPv6 DNS64 query while DNS server doesn't have an answer for synthesized address.
3706     //   After querying synthesized address failed, expect that prefix is removed from IPv6
3707     //   synthesized address and do reverse IPv4 query instead.
3708     //
3709     // hasSynthesizedPtrRecord = true
3710     //   Synthesized PTR record exists on DNS server
3711     //   Reverse IPv6 DNS64 query while DNS server has an answer for synthesized address.
3712     //   Expect to just pass through synthesized address for DNS queries.
3713     for (const auto& config : testConfigs) {
3714         SCOPED_TRACE(config.asParameters());
3715 
3716         char host[NI_MAXHOST];
3717         struct sockaddr_in6 sin6;
3718         memset(&sin6, 0, sizeof(sin6));
3719         sin6.sin6_family = AF_INET6;
3720         inet_pton(AF_INET6, config.addr.c_str(), &sin6.sin6_addr);
3721         int rv = getnameinfo((const struct sockaddr*)&sin6, sizeof(sin6), host, sizeof(host),
3722                              nullptr, 0, config.flag);
3723         ASSERT_EQ(0, rv);
3724         if (config.flag == NI_NAMEREQD) {
3725             if (config.hasSynthesizedPtrRecord) {
3726                 EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v6_synthesis));
3727             } else {
3728                 EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v6_nomapping));  // PTR record not exist.
3729                 EXPECT_LE(1U, GetNumQueries(dns, ptr_addr_v4));            // PTR record exist.
3730             }
3731         }
3732         std::string result_str = host;
3733         EXPECT_EQ(result_str, config.host);
3734         dns.clearQueries();
3735     }
3736 }
3737 
// A reverse getnameinfo() lookup of a NAT64-synthesized address whose embedded IPv4 address is
// listed in /etc/hosts must be answered from the hosts file with "localhost".
TEST_F(ResolverTest, GetNameInfo_ReverseDns64QueryFromHostFile) {
    constexpr char dns64_name[] = "ipv4only.arpa.";
    constexpr char host_name[] = "localhost";
    // The address is synthesized by prefix64:localhost.
    constexpr char host_addr[] = "64:ff9b::7f00:1";
    // PTR owner name of 127.0.0.1, the IPv4 address embedded in |host_addr|. A lookup of the
    // de-prefixed address that reached the DNS server would query this name.
    constexpr char ptr_addr_v4[] = "1.0.0.127.in-addr.arpa.";
    constexpr char listen_addr[] = "::1";

    test::DNSResponder dns(listen_addr);

    StartDns(dns, {{dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"}});
    const std::vector<std::string> servers = {listen_addr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));

    // Start NAT64 prefix discovery and wait for it to complete.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // Using synthesized "localhost" address to be a trick for resolving host name
    // from host file /etc/hosts and "localhost" is the only name in /etc/hosts. Note that this is
    // not realistic: the code never synthesizes AAAA records for addresses in 127.0.0.0/8.
    char host[NI_MAXHOST];
    struct sockaddr_in6 sin6 = {.sin6_family = AF_INET6};
    ASSERT_EQ(1, inet_pton(AF_INET6, host_addr, &sin6.sin6_addr));
    int rv = getnameinfo(reinterpret_cast<const struct sockaddr*>(&sin6), sizeof(sin6), host,
                         sizeof(host), nullptr, 0, NI_NAMEREQD);
    ASSERT_EQ(0, rv);
    // Reverse lookups send PTR queries, never a query for |host_name| itself, so counting
    // |host_name| queries alone cannot detect a DNS round trip. The synthesized IPv6 PTR name may
    // still be queried first because no record exists for it. The de-prefixed 127.0.0.1, however,
    // must be resolved from /etc/hosts, so its PTR name must never reach the DNS server.
    EXPECT_EQ(0U, GetNumQueries(dns, host_name));
    EXPECT_EQ(0U, GetNumQueries(dns, ptr_addr_v4));

    std::string result_str = host;
    EXPECT_EQ(result_str, host_name);
}
3770 
// getnameinfo(NI_NAMEREQD) must follow the CNAMEs of an RFC 2317 (section 4) classless reverse
// delegation, for both the '/' and the '-' zone-naming notations.
TEST_F(ResolverTest, GetNameInfo_cnamesClasslessReverseDelegation) {
    // One address per notation: 192.0.2.0/25 uses '/', 192.0.3.0-127 uses '-'.
    constexpr char kSlashAddr[] = "192.0.2.1";
    constexpr char kHyphenAddr[] = "192.0.3.1";

    static const std::vector<DnsRecord> kDelegationRecords = {
            // '/' notation: 1.2.0.192.in-addr.arpa. -> CNAME -> PTR.
            {"1.2.0.192.in-addr.arpa.", ns_type::ns_t_cname, "1.0/25.2.0.192.in-addr.arpa."},
            {"1.0/25.2.0.192.in-addr.arpa.", ns_type::ns_t_ptr, kHelloExampleCom},

            // '-' notation: 1.3.0.192.in-addr.arpa. -> CNAME -> PTR.
            {"1.3.0.192.in-addr.arpa.", ns_type::ns_t_cname, "1.0-127.3.0.192.in-addr.arpa."},
            {"1.0-127.3.0.192.in-addr.arpa.", ns_type::ns_t_ptr, kHelloExampleCom},
    };

    test::DNSResponder dns;
    StartDns(dns, kDelegationRecords);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    for (const char* addr : {kSlashAddr, kHyphenAddr}) {
        SCOPED_TRACE(addr);

        sockaddr_in sin = {.sin_family = AF_INET};
        ASSERT_TRUE(inet_pton(AF_INET, addr, &sin.sin_addr));

        char host[NI_MAXHOST];
        const int rv = getnameinfo(reinterpret_cast<const sockaddr*>(&sin), sizeof(sin), host,
                                   sizeof(host), nullptr, 0, NI_NAMEREQD);
        ASSERT_EQ(0, rv);
        EXPECT_STREQ("hello.example.com", host);
    }
}
3804 
// gethostbyname2(AF_INET6) on an IPv4-only name must return an address synthesized from the
// discovered NAT64 prefix.
TEST_F(ResolverTest, GetHostByName2_Dns64Synthesize) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kV4OnlyName[] = "ipv4only.example.com.";

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, {
                          {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
                          {kV4OnlyName, ns_type::ns_t_a, "1.2.3.4"},
                  });
    const std::vector<std::string> nameservers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(nameservers));

    // Discover the NAT64 prefix before any lookup.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // Only an A record exists, so the answer is 64:ff9b::1.2.3.4.
    hostent* he = gethostbyname2("ipv4only", AF_INET6);
    ASSERT_TRUE(he != nullptr);
    EXPECT_LE(1U, GetNumQueries(dns, kV4OnlyName));
    EXPECT_EQ(ToString(he), "64:ff9b::102:304");
}
3830 
// Forward gethostbyname2() lookups of a dual-stack name must ignore a discovered NAT64 prefix and
// return the native A or AAAA record.
TEST_F(ResolverTest, GetHostByName2_DnsQueryWithHavingNat64Prefix) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    constexpr char kDualStackName[] = "v4v6.example.com.";

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, {
                          {kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
                          {kDualStackName, ns_type::ns_t_a, "1.2.3.4"},
                          {kDualStackName, ns_type::ns_t_aaaa, "2001:db8::1.2.3.4"},
                  });
    const std::vector<std::string> nameservers = {kListenAddr};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(nameservers));

    // Discover the NAT64 prefix before any lookup.
    EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
    EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

    // AF_INET: the native A record.
    hostent* he = gethostbyname2("v4v6", AF_INET);
    ASSERT_TRUE(he != nullptr);
    EXPECT_LE(1U, GetNumQueries(dns, kDualStackName));
    EXPECT_EQ(ToString(he), "1.2.3.4");
    dns.clearQueries();

    // AF_INET6: the native AAAA record, not a synthesized address.
    he = gethostbyname2("v4v6", AF_INET6);
    ASSERT_TRUE(he != nullptr);
    EXPECT_LE(1U, GetNumQueries(dns, kDualStackName));
    EXPECT_EQ(ToString(he), "2001:db8::102:304");
}
3865 
TEST_F(ResolverTest,GetHostByName2_Dns64QuerySpecialUseIPv4Addresses)3866 TEST_F(ResolverTest, GetHostByName2_Dns64QuerySpecialUseIPv4Addresses) {
3867     constexpr char THIS_NETWORK[] = "this_network";
3868     constexpr char LOOPBACK[] = "loopback";
3869     constexpr char LINK_LOCAL[] = "link_local";
3870     constexpr char MULTICAST[] = "multicast";
3871     constexpr char LIMITED_BROADCAST[] = "limited_broadcast";
3872 
3873     constexpr char ADDR_THIS_NETWORK[] = "0.0.0.1";
3874     constexpr char ADDR_LOOPBACK[] = "127.0.0.1";
3875     constexpr char ADDR_LINK_LOCAL[] = "169.254.0.1";
3876     constexpr char ADDR_MULTICAST[] = "224.0.0.1";
3877     constexpr char ADDR_LIMITED_BROADCAST[] = "255.255.255.255";
3878 
3879     constexpr char listen_addr[] = "::1";
3880     constexpr char dns64_name[] = "ipv4only.arpa.";
3881 
3882     test::DNSResponder dns(listen_addr);
3883     StartDns(dns, {{dns64_name, ns_type::ns_t_aaaa, "64:ff9b::"}});
3884     const std::vector<std::string> servers = {listen_addr};
3885     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3886 
3887     // Start NAT64 prefix discovery and wait for it to complete.
3888     EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
3889     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
3890 
3891     // clang-format off
3892     static const struct TestConfig {
3893         std::string name;
3894         std::string addr;
3895 
3896         std::string asHostName() const {
3897             return StringPrintf("%s.example.com.", name.c_str());
3898         }
3899     } testConfigs[]{
3900         {THIS_NETWORK,      ADDR_THIS_NETWORK},
3901         {LOOPBACK,          ADDR_LOOPBACK},
3902         {LINK_LOCAL,        ADDR_LINK_LOCAL},
3903         {MULTICAST,         ADDR_MULTICAST},
3904         {LIMITED_BROADCAST, ADDR_LIMITED_BROADCAST}
3905     };
3906     // clang-format on
3907 
3908     for (const auto& config : testConfigs) {
3909         const std::string testHostName = config.asHostName();
3910         SCOPED_TRACE(testHostName);
3911 
3912         const char* host_name = testHostName.c_str();
3913         dns.addMapping(host_name, ns_type::ns_t_a, config.addr.c_str());
3914 
3915         struct hostent* result = gethostbyname2(config.name.c_str(), AF_INET6);
3916         EXPECT_LE(1U, GetNumQueries(dns, host_name));
3917 
3918         // In AF_INET6 case, don't synthesize special use IPv4 address.
3919         // Expect to have no answer
3920         EXPECT_EQ(nullptr, result);
3921 
3922         dns.clearQueries();
3923     }
3924 }
3925 
// NAT64 prefix discovery must always query over cleartext DNS, in both OPPORTUNISTIC and
// STRICT private DNS modes, even when a working TLS server is configured.
TEST_F(ResolverTest, PrefixDiscoveryBypassTls) {
    constexpr char kListenAddr[] = "::1";
    constexpr char kDns64Name[] = "ipv4only.arpa.";
    const std::vector<std::string> servers = {kListenAddr};

    test::DNSResponder dns(kListenAddr);
    StartDns(dns, {{kDns64Name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"}});
    // TLS frontend on port 853 forwarding to the cleartext server on port 53.
    test::DnsTlsFrontend tls(kListenAddr, "853", kListenAddr, "53");
    ASSERT_TRUE(tls.startServer());

    // Waits for private DNS validation, runs NAT64 prefix discovery to completion, then checks
    // that the TLS server saw no discovery traffic while the cleartext server saw exactly one
    // query for ipv4only.arpa.
    const auto expectDiscoveryBypassesTls = [&]() {
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
        EXPECT_TRUE(tls.waitForQueries(1));
        tls.clearQueries();

        EXPECT_TRUE(mDnsClient.resolvService()->startPrefix64Discovery(TEST_NETID).isOk());
        EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));

        EXPECT_EQ(0, tls.queries()) << dns.dumpQueries();
        EXPECT_EQ(1U, GetNumQueries(dns, kDns64Name)) << dns.dumpQueries();
    };

    // OPPORTUNISTIC mode: TLS servers configured with an empty TLS name.
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
    expectDiscoveryBypassesTls();

    // Recreate the test network so the resolver cache starts out empty.
    mDnsClient.TearDown();
    mDnsClient.SetUp();
    dns.clearQueries();

    // STRICT mode: a TLS name is configured, yet discovery must still stay off TLS.
    ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                               kDefaultPrivateDnsHostName));
    expectDiscoveryBypassesTls();
}
3972 
TEST_F(ResolverTest,SetAndClearNat64Prefix)3973 TEST_F(ResolverTest, SetAndClearNat64Prefix) {
3974     constexpr char host_name[] = "v4.example.com.";
3975     constexpr char listen_addr[] = "::1";
3976     const std::vector<DnsRecord> records = {
3977             {host_name, ns_type::ns_t_a, "1.2.3.4"},
3978     };
3979     const std::string kNat64Prefix1 = "64:ff9b::/96";
3980     const std::string kNat64Prefix2 = "2001:db8:6464::/96";
3981 
3982     test::DNSResponder dns(listen_addr);
3983     StartDns(dns, records);
3984     const std::vector<std::string> servers = {listen_addr};
3985     ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
3986 
3987     auto resolvService = mDnsClient.resolvService();
3988     addrinfo hints = {.ai_family = AF_INET6};
3989 
3990     // No NAT64 prefix, no AAAA record.
3991     ScopedAddrinfo result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
3992     ASSERT_TRUE(result == nullptr);
3993 
3994     // Set the prefix, and expect to get a synthesized AAAA record.
3995     EXPECT_TRUE(resolvService->setPrefix64(TEST_NETID, kNat64Prefix2).isOk());
3996     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
3997     ASSERT_FALSE(result == nullptr);
3998     EXPECT_EQ("2001:db8:6464::102:304", ToString(result));
3999 
4000     // Update the prefix, expect to see AAAA records from the new prefix.
4001     EXPECT_TRUE(resolvService->setPrefix64(TEST_NETID, kNat64Prefix1).isOk());
4002     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4003     ASSERT_FALSE(result == nullptr);
4004     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4005 
4006     // Non-/96 prefixes are ignored.
4007     auto status = resolvService->setPrefix64(TEST_NETID, "64:ff9b::/64");
4008     EXPECT_FALSE(status.isOk());
4009     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4010     EXPECT_EQ(EINVAL, status.getServiceSpecificError());
4011 
4012     // Invalid prefixes are ignored.
4013     status = resolvService->setPrefix64(TEST_NETID, "192.0.2.0/24");
4014     EXPECT_FALSE(status.isOk());
4015     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4016     EXPECT_EQ(EINVAL, status.getServiceSpecificError());
4017 
4018     status = resolvService->setPrefix64(TEST_NETID, "192.0.2.1");
4019     EXPECT_FALSE(status.isOk());
4020     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4021     EXPECT_EQ(EINVAL, status.getServiceSpecificError());
4022 
4023     status = resolvService->setPrefix64(TEST_NETID, "hello");
4024     EXPECT_FALSE(status.isOk());
4025     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4026     EXPECT_EQ(EINVAL, status.getServiceSpecificError());
4027 
4028     // DNS64 synthesis is still working.
4029     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4030     ASSERT_FALSE(result == nullptr);
4031     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4032 
4033     // Clear the prefix. No AAAA records any more.
4034     EXPECT_TRUE(resolvService->setPrefix64(TEST_NETID, "").isOk());
4035     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4036     EXPECT_TRUE(result == nullptr);
4037 
4038     // Calling startPrefix64Discovery clears the prefix.
4039     EXPECT_TRUE(resolvService->setPrefix64(TEST_NETID, kNat64Prefix1).isOk());
4040     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4041     ASSERT_FALSE(result == nullptr);
4042     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4043 
4044     EXPECT_TRUE(resolvService->startPrefix64Discovery(TEST_NETID).isOk());
4045     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4046     ASSERT_TRUE(result == nullptr);
4047 
4048     // setPrefix64 fails if prefix discovery is started, even if no prefix is yet discovered...
4049     status = resolvService->setPrefix64(TEST_NETID, kNat64Prefix1);
4050     EXPECT_FALSE(status.isOk());
4051     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4052     EXPECT_EQ(EEXIST, status.getServiceSpecificError());
4053 
4054     // .. and clearing the prefix also has no effect.
4055     status = resolvService->setPrefix64(TEST_NETID, "");
4056     EXPECT_FALSE(status.isOk());
4057     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4058     EXPECT_EQ(ENOENT, status.getServiceSpecificError());
4059 
4060     // setPrefix64 succeeds again when prefix discovery is stopped.
4061     EXPECT_TRUE(resolvService->stopPrefix64Discovery(TEST_NETID).isOk());
4062     EXPECT_TRUE(resolvService->setPrefix64(TEST_NETID, kNat64Prefix1).isOk());
4063     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4064     ASSERT_FALSE(result == nullptr);
4065     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4066 
4067     // Calling stopPrefix64Discovery clears the prefix.
4068     EXPECT_TRUE(resolvService->stopPrefix64Discovery(TEST_NETID).isOk());
4069     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4070     ASSERT_TRUE(result == nullptr);
4071 
4072     // Set up NAT64 prefix discovery.
4073     constexpr char dns64_name[] = "ipv4only.arpa.";
4074     const std::vector<DnsRecord> newRecords = {
4075             {host_name, ns_type::ns_t_a, "1.2.3.4"},
4076             {dns64_name, ns_type::ns_t_aaaa, "64:ff9b::192.0.0.170"},
4077     };
4078     dns.stopServer();
4079     StartDns(dns, newRecords);
4080 
4081     EXPECT_TRUE(resolvService->startPrefix64Discovery(TEST_NETID).isOk());
4082     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_FOUND));
4083     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4084     ASSERT_FALSE(result == nullptr);
4085     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4086 
4087     // setPrefix64 fails if NAT64 prefix discovery has succeeded, and the discovered prefix
4088     // continues to be used.
4089     status = resolvService->setPrefix64(TEST_NETID, kNat64Prefix2);
4090     EXPECT_FALSE(status.isOk());
4091     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4092     EXPECT_EQ(EEXIST, status.getServiceSpecificError());
4093 
4094     // Clearing the prefix also has no effect if discovery is started.
4095     status = resolvService->setPrefix64(TEST_NETID, "");
4096     EXPECT_FALSE(status.isOk());
4097     EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
4098     EXPECT_EQ(ENOENT, status.getServiceSpecificError());
4099 
4100     result = safe_getaddrinfo("v4.example.com", nullptr, &hints);
4101     ASSERT_FALSE(result == nullptr);
4102     EXPECT_EQ("64:ff9b::102:304", ToString(result));
4103 
4104     EXPECT_TRUE(resolvService->stopPrefix64Discovery(TEST_NETID).isOk());
4105     EXPECT_TRUE(WaitForNat64Prefix(EXPECT_NOT_FOUND));
4106 
4107     EXPECT_EQ(0, sDnsMetricsListener->getUnexpectedNat64PrefixUpdates());
4108     EXPECT_EQ(0, sUnsolicitedEventListener->getUnexpectedNat64PrefixUpdates());
4109 }
4110 
4111 namespace {
4112 
4113 class ScopedSetNetworkForProcess {
4114   public:
ScopedSetNetworkForProcess(unsigned netId)4115     explicit ScopedSetNetworkForProcess(unsigned netId) {
4116         mStoredNetId = getNetworkForProcess();
4117         if (netId == mStoredNetId) return;
4118         EXPECT_EQ(0, setNetworkForProcess(netId));
4119     }
~ScopedSetNetworkForProcess()4120     ~ScopedSetNetworkForProcess() { EXPECT_EQ(0, setNetworkForProcess(mStoredNetId)); }
4121 
4122   private:
4123     unsigned mStoredNetId;
4124 };
4125 
4126 class ScopedSetNetworkForResolv {
4127   public:
ScopedSetNetworkForResolv(unsigned netId)4128     explicit ScopedSetNetworkForResolv(unsigned netId) { EXPECT_EQ(0, setNetworkForResolv(netId)); }
~ScopedSetNetworkForResolv()4129     ~ScopedSetNetworkForResolv() { EXPECT_EQ(0, setNetworkForResolv(NETID_UNSET)); }
4130 };
4131 
sendCommand(int fd,const std::string & cmd)4132 void sendCommand(int fd, const std::string& cmd) {
4133     ssize_t rc = TEMP_FAILURE_RETRY(write(fd, cmd.c_str(), cmd.size() + 1));
4134     EXPECT_EQ(rc, static_cast<ssize_t>(cmd.size() + 1));
4135 }
4136 
readBE32(int fd)4137 int32_t readBE32(int fd) {
4138     int32_t tmp;
4139     int n = TEMP_FAILURE_RETRY(read(fd, &tmp, sizeof(tmp)));
4140     EXPECT_TRUE(n > 0);
4141     return ntohl(tmp);
4142 }
4143 
readResponseCode(int fd)4144 int readResponseCode(int fd) {
4145     char buf[4];
4146     int n = TEMP_FAILURE_RETRY(read(fd, &buf, sizeof(buf)));
4147     EXPECT_TRUE(n > 0);
4148     // The format of response code is that 4 bytes for the code & null.
4149     buf[3] = '\0';
4150     int result;
4151     EXPECT_TRUE(ParseInt(buf, &result));
4152     return result;
4153 }
4154 
checkAndClearUseLocalNameserversFlag(unsigned * netid)4155 bool checkAndClearUseLocalNameserversFlag(unsigned* netid) {
4156     if (netid == nullptr || ((*netid) & NETID_USE_LOCAL_NAMESERVERS) == 0) {
4157         return false;
4158     }
4159     *netid = (*netid) & ~NETID_USE_LOCAL_NAMESERVERS;
4160     return true;
4161 }
4162 
makeUidRangeParcel(int start,int stop)4163 aidl::android::net::UidRangeParcel makeUidRangeParcel(int start, int stop) {
4164     aidl::android::net::UidRangeParcel res;
4165     res.start = start;
4166     res.stop = stop;
4167 
4168     return res;
4169 }
4170 
expectNetIdWithLocalNameserversFlag(unsigned netId)4171 void expectNetIdWithLocalNameserversFlag(unsigned netId) {
4172     unsigned dnsNetId = 0;
4173     EXPECT_EQ(0, getNetworkForDns(&dnsNetId));
4174     EXPECT_TRUE(checkAndClearUseLocalNameserversFlag(&dnsNetId));
4175     EXPECT_EQ(netId, static_cast<unsigned>(dnsNetId));
4176 }
4177 
expectDnsNetIdEquals(unsigned netId)4178 void expectDnsNetIdEquals(unsigned netId) {
4179     unsigned dnsNetId = 0;
4180     EXPECT_EQ(0, getNetworkForDns(&dnsNetId));
4181     EXPECT_EQ(netId, static_cast<unsigned>(dnsNetId));
4182 }
4183 
expectDnsNetIdIsDefaultNetwork(INetd * netdService)4184 void expectDnsNetIdIsDefaultNetwork(INetd* netdService) {
4185     int currentNetid;
4186     EXPECT_TRUE(netdService->networkGetDefault(&currentNetid).isOk());
4187     expectDnsNetIdEquals(currentNetid);
4188 }
4189 
expectDnsNetIdWithVpn(INetd * netdService,unsigned vpnNetId,unsigned expectedNetId)4190 void expectDnsNetIdWithVpn(INetd* netdService, unsigned vpnNetId, unsigned expectedNetId) {
4191     if (DnsResponderClient::isRemoteVersionSupported(netdService, 6)) {
4192         const auto& config = DnsResponderClient::makeNativeNetworkConfig(
4193                 vpnNetId, NativeNetworkType::VIRTUAL, INetd::PERMISSION_NONE, /*secure=*/false);
4194         EXPECT_TRUE(netdService->networkCreate(config).isOk());
4195     } else {
4196 #pragma clang diagnostic push
4197 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
4198         EXPECT_TRUE(netdService->networkCreateVpn(vpnNetId, false /* secure */).isOk());
4199 #pragma clang diagnostic pop
4200     }
4201 
4202     uid_t uid = getuid();
4203     // Add uid to VPN
4204     EXPECT_TRUE(netdService->networkAddUidRanges(vpnNetId, {makeUidRangeParcel(uid, uid)}).isOk());
4205     expectDnsNetIdEquals(expectedNetId);
4206     EXPECT_TRUE(netdService->networkDestroy(vpnNetId).isOk());
4207 }
4208 
4209 }  // namespace
4210 
// Verifies how getNetworkForDns() reflects the process network, the resolver network, the
// NETID_USE_LOCAL_NAMESERVERS flag and a bypassable VPN. Also checks how DnsProxyListener
// answers a malformed and an unknown "getdnsnetid" command.
TEST_F(ResolverTest, getDnsNetId) {
    // We've called setNetworkForProcess in SetupOemNetwork, so reset to default first.
    EXPECT_EQ(0, setNetworkForProcess(NETID_UNSET));

    expectDnsNetIdIsDefaultNetwork(mDnsClient.netdService());
    expectDnsNetIdWithVpn(mDnsClient.netdService(), TEST_VPN_NETID, TEST_VPN_NETID);

    // Test with setNetworkForProcess
    {
        ScopedSetNetworkForProcess scopedSetNetworkForProcess(TEST_NETID);
        expectDnsNetIdEquals(TEST_NETID);
    }

    // Test with setNetworkForProcess with NETID_USE_LOCAL_NAMESERVERS
    {
        ScopedSetNetworkForProcess scopedSetNetworkForProcess(TEST_NETID |
                                                              NETID_USE_LOCAL_NAMESERVERS);
        expectNetIdWithLocalNameserversFlag(TEST_NETID);
    }

    // Test with setNetworkForResolv
    {
        ScopedSetNetworkForResolv scopedSetNetworkForResolv(TEST_NETID);
        expectDnsNetIdEquals(TEST_NETID);
    }

    // Test with setNetworkForResolv with NETID_USE_LOCAL_NAMESERVERS
    {
        ScopedSetNetworkForResolv scopedSetNetworkForResolv(TEST_NETID |
                                                            NETID_USE_LOCAL_NAMESERVERS);
        expectNetIdWithLocalNameserversFlag(TEST_NETID);
    }

    // Test with setNetworkForResolv under bypassable vpn
    {
        ScopedSetNetworkForResolv scopedSetNetworkForResolv(TEST_NETID);
        expectDnsNetIdWithVpn(mDnsClient.netdService(), TEST_VPN_NETID, TEST_NETID);
    }

    // Create socket connected to DnsProxyListener. Abort if that fails (dns_open_proxy()
    // returns a negative value): every command below would otherwise go to an invalid fd.
    unique_fd ufd(dns_open_proxy());
    ASSERT_GE(ufd.get(), 0);
    const int fd = ufd.get();

    // Test command with wrong netId
    sendCommand(fd, "getdnsnetid abc");
    EXPECT_EQ(ResponseCode::DnsProxyQueryResult, readResponseCode(fd));
    EXPECT_EQ(-EINVAL, readBE32(fd));

    // Test unsupported command
    sendCommand(fd, "getdnsnetidNotSupported");
    // Keep in sync with FrameworkListener.cpp (500, "Command not recognized")
    EXPECT_EQ(500, readResponseCode(fd));
}
4265 
// With a blocking UID rule installed for TEST_UID, resNetworkQuery() A/AAAA lookups are
// refused on R+ (API >= 30) and expected to succeed on earlier releases.
TEST_F(ResolverTest, BlockDnsQueryWithUidRule) {
    SKIP_IF_BPF_NOT_SUPPORTED;
    constexpr char kIpv4Server[] = "127.0.0.4";
    constexpr char kIpv6Server[] = "::1";
    constexpr char kHostName[] = "howdy.example.com.";
    const std::vector<DnsRecord> answers = {
            {kHostName, ns_type::ns_t_a, "1.2.3.4"},
            {kHostName, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    INetd* netdService = mDnsClient.netdService();

    test::DNSResponder dnsV4(kIpv4Server);
    test::DNSResponder dnsV6(kIpv6Server);
    StartDns(dnsV4, answers);
    StartDns(dnsV6, answers);

    std::vector<std::string> servers = {kIpv4Server, kIpv6Server};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
    dnsV4.clearQueries();
    dnsV6.clearQueries();

    // The rule stays in place until the end of the test.
    ScopeBlockedUIDRule blockRule(netdService, TEST_UID);
    const int fdA = resNetworkQuery(TEST_NETID, kHostName, ns_c_in, ns_t_a, 0);
    const int fdAaaa = resNetworkQuery(TEST_NETID, kHostName, ns_c_in, ns_t_aaaa, 0);
    EXPECT_TRUE(fdA != -1);
    EXPECT_TRUE(fdAaaa != -1);

    uint8_t answerA[MAXPACKET] = {};
    uint8_t answerAaaa[MAXPACKET] = {};
    int rcode;
    // Collect the AAAA response before the A response.
    const int lenAaaa = getAsyncResponse(fdAaaa, &rcode, answerAaaa, MAXPACKET);
    const int lenA = getAsyncResponse(fdA, &rcode, answerA, MAXPACKET);

    if (!isAtLeastR) {
        // Pre-R (API < 30): both answers are expected despite the rule.
        EXPECT_GT(lenAaaa, 0);
        EXPECT_EQ("::1.2.3.4", toString(answerAaaa, lenAaaa, AF_INET6));
        EXPECT_GT(lenA, 0);
        EXPECT_EQ("1.2.3.4", toString(answerA, lenA, AF_INET));
        // DnsEvents are deliberately not checked: their order is not guaranteed.
        return;
    }

    // R+ (API >= 30): both queries are refused, each reported as a failed RES_NSEND event.
    EXPECT_EQ(lenAaaa, -ECONNREFUSED);
    EXPECT_EQ(lenA, -ECONNREFUSED);
    ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, EAI_SYSTEM, "howdy.example.com", {});
    ExpectDnsEvent(INetdEventListener::EVENT_RES_NSEND, EAI_SYSTEM, "howdy.example.com", {});
}
4313 
// With a blocking UID rule for TEST_UID, getaddrinfo() fails on R+ (API >= 30) and reports the
// per-lookup error code in its DnsEvent. On earlier releases both A and AAAA answers come back.
TEST_F(ResolverTest, GetAddrinfo_BlockDnsQueryWithUidRule) {
    SKIP_IF_BPF_NOT_SUPPORTED;
    constexpr char listen_addr1[] = "127.0.0.4";
    constexpr char listen_addr2[] = "::1";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    test::DNSResponder dns1(listen_addr1);
    test::DNSResponder dns2(listen_addr2);
    StartDns(dns1, records);
    StartDns(dns2, records);

    std::vector<std::string> servers = {listen_addr1, listen_addr2};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, kDefaultSearchDomains, kDefaultParams));

    const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};

    // Deliberately not `static`: host_name is an automatic array. A function-local static table
    // would keep a pointer into the first run's stack frame, which dangles if this test body
    // runs again in the same process (e.g. with --gtest_repeat).
    const struct {
        const char* hname;
        int expectedErrorCode;
    } kTestData[] = {
            {host_name, EAI_NODATA},
            // To test the query with search domain.
            {"howdy", EAI_AGAIN},
    };

    INetd* netdService = mDnsClient.netdService();
    for (const auto& td : kTestData) {
        SCOPED_TRACE(td.hname);
        ScopeBlockedUIDRule scopeBlockUidRule(netdService, TEST_UID);
        // If API level >= 30 (R+), these queries should be blocked.
        if (isAtLeastR) {
            addrinfo* result = nullptr;
            // getaddrinfo() in bionic would convert all errors to EAI_NODATA
            // except EAI_SYSTEM.
            EXPECT_EQ(EAI_NODATA, getaddrinfo(td.hname, nullptr, &hints, &result));
            ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, td.expectedErrorCode, td.hname,
                           {});
        } else {
            ScopedAddrinfo result = safe_getaddrinfo(td.hname, nullptr, &hints);
            EXPECT_NE(nullptr, result);
            EXPECT_THAT(ToStrings(result),
                        testing::UnorderedElementsAreArray({"1.2.3.4", "::1.2.3.4"}));
            // To avoid flaky test, do not evaluate DnsEvent since event order is not guaranteed.
        }
    }
}
4363 
TEST_F(ResolverTest,EnforceDnsUid)4364 TEST_F(ResolverTest, EnforceDnsUid) {
4365     SKIP_IF_BPF_NOT_SUPPORTED;
4366     constexpr char listen_addr1[] = "127.0.0.4";
4367     constexpr char listen_addr2[] = "::1";
4368     constexpr char host_name[] = "howdy.example.com.";
4369     const std::vector<DnsRecord> records = {
4370             {host_name, ns_type::ns_t_a, "1.2.3.4"},
4371             {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
4372     };
4373     INetd* netdService = mDnsClient.netdService();
4374 
4375     test::DNSResponder dns1(listen_addr1);
4376     test::DNSResponder dns2(listen_addr2);
4377     StartDns(dns1, records);
4378     StartDns(dns2, records);
4379 
4380     // switch uid of DNS queries from applications to AID_DNS
4381     ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4382     parcel.servers = {listen_addr1, listen_addr2};
4383     ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk());
4384 
4385     uint8_t buf[MAXPACKET] = {};
4386     uint8_t buf2[MAXPACKET] = {};
4387     int rcode;
4388     {
4389         ScopeBlockedUIDRule scopeBlockUidRule(netdService, TEST_UID);
4390         // Dns Queries should be blocked
4391         const int fd1 = resNetworkQuery(TEST_NETID, host_name, ns_c_in, ns_t_a, 0);
4392         const int fd2 = resNetworkQuery(TEST_NETID, host_name, ns_c_in, ns_t_aaaa, 0);
4393         EXPECT_TRUE(fd1 != -1);
4394         EXPECT_TRUE(fd2 != -1);
4395 
4396         const int res2 = getAsyncResponse(fd2, &rcode, buf2, MAXPACKET);
4397         const int res1 = getAsyncResponse(fd1, &rcode, buf, MAXPACKET);
4398         // If API level >= 30 (R+), the query should be blocked.
4399         if (isAtLeastR) {
4400             EXPECT_EQ(res2, -ECONNREFUSED);
4401             EXPECT_EQ(res1, -ECONNREFUSED);
4402         } else {
4403             EXPECT_GT(res2, 0);
4404             EXPECT_EQ("::1.2.3.4", toString(buf2, res2, AF_INET6));
4405             EXPECT_GT(res1, 0);
4406             EXPECT_EQ("1.2.3.4", toString(buf, res1, AF_INET));
4407         }
4408     }
4409 
4410     memset(buf, 0, MAXPACKET);
4411     ResolverOptionsParcel resolverOptions;
4412     resolverOptions.enforceDnsUid = true;
4413     if (!mIsResolverOptionIPCSupported) {
4414         parcel.resolverOptions = resolverOptions;
4415         ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk());
4416     } else {
4417         ASSERT_TRUE(mDnsClient.resolvService()
4418                             ->setResolverOptions(parcel.netId, resolverOptions)
4419                             .isOk());
4420     }
4421 
4422     {
4423         ScopeBlockedUIDRule scopeBlockUidRule(netdService, TEST_UID);
4424         // Dns Queries should NOT be blocked
4425         int fd1 = resNetworkQuery(TEST_NETID, host_name, ns_c_in, ns_t_a, 0);
4426         int fd2 = resNetworkQuery(TEST_NETID, host_name, ns_c_in, ns_t_aaaa, 0);
4427         EXPECT_TRUE(fd1 != -1);
4428         EXPECT_TRUE(fd2 != -1);
4429 
4430         int res = getAsyncResponse(fd2, &rcode, buf, MAXPACKET);
4431         EXPECT_EQ("::1.2.3.4", toString(buf, res, AF_INET6));
4432 
4433         memset(buf, 0, MAXPACKET);
4434         res = getAsyncResponse(fd1, &rcode, buf, MAXPACKET);
4435         EXPECT_EQ("1.2.3.4", toString(buf, res, AF_INET));
4436 
4437         // @TODO: So far we know that uid of DNS queries are no more set to DNS requester. But we
4438         // don't check if they are actually being set to AID_DNS, because system uids are always
4439         // allowed in bpf_owner_match(). Audit by firewallSetUidRule(AID_DNS) + sending queries is
4440         // infeasible. Fix it if the behavior of bpf_owner_match() is changed in the future, or if
4441         // we have better idea to deal with this.
4442     }
4443 }
4444 
TEST_F(ResolverTest,ConnectTlsServerTimeout)4445 TEST_F(ResolverTest, ConnectTlsServerTimeout) {
4446     constexpr char hostname1[] = "query1.example.com.";
4447     constexpr char hostname2[] = "query2.example.com.";
4448     const std::vector<DnsRecord> records = {
4449             {hostname1, ns_type::ns_t_a, "1.2.3.4"},
4450             {hostname2, ns_type::ns_t_a, "1.2.3.5"},
4451     };
4452 
4453     static const struct TestConfig {
4454         bool asyncHandshake;
4455         int maxRetries;
4456 
4457         // if asyncHandshake:
4458         //   expectedTimeout = Min(DotQueryTimeoutMs, dotConnectTimeoutMs * maxRetries)
4459         // otherwise:
4460         //   expectedTimeout = dotConnectTimeoutMs
4461         int expectedTimeout;
4462     } testConfigs[] = {
4463             // Test mis-configured dot_maxtries flag.
4464             {false, 0, 1000}, {true, 0, 1000},
4465 
4466             {false, 1, 1000}, {false, 3, 1000}, {true, 1, 1000}, {true, 3, 3000},
4467     };
4468 
4469     for (const auto& config : testConfigs) {
4470         SCOPED_TRACE(fmt::format("testConfig: [{}, {}]", config.asyncHandshake, config.maxRetries));
4471 
4472         // Because a DnsTlsTransport lasts at least 5 minutes in spite of network
4473         // destroyed, let the resolver creates an unique DnsTlsTransport every time
4474         // so that the DnsTlsTransport won't interfere the other tests.
4475         const std::string addr = getUniqueIPv4Address();
4476         test::DNSResponder dns(addr);
4477         StartDns(dns, records);
4478         test::DnsTlsFrontend tls(addr, "853", addr, "53");
4479         ASSERT_TRUE(tls.startServer());
4480 
4481         // The resolver will adjust the timeout value to 1000ms since the value is too small.
4482         ScopedSystemProperties sp1(kDotConnectTimeoutMsFlag, "100");
4483 
4484         // Infinite timeout.
4485         ScopedSystemProperties sp2(kDotQueryTimeoutMsFlag, "-1");
4486 
4487         ScopedSystemProperties sp3(kDotAsyncHandshakeFlag, config.asyncHandshake ? "1" : "0");
4488         ScopedSystemProperties sp4(kDotMaxretriesFlag, std::to_string(config.maxRetries));
4489         resetNetwork();
4490 
4491         // Set up resolver to opportunistic mode.
4492         auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4493         parcel.servers = {addr};
4494         parcel.tlsServers = {addr};
4495         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4496         EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4497         EXPECT_TRUE(tls.waitForQueries(1));
4498         tls.clearQueries();
4499         dns.clearQueries();
4500 
4501         // The server becomes unresponsive to the handshake request.
4502         tls.setHangOnHandshakeForTesting(true);
4503 
4504         // Expect the things happening in getaddrinfo():
4505         //   1. Connect to the private DNS server.
4506         //   2. SSL handshake times out.
4507         //   3. Fallback to UDP transport, and then get the answer.
4508         const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
4509         auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname1, nullptr, hints);
4510 
4511         EXPECT_NE(nullptr, result);
4512         EXPECT_EQ(0, tls.queries());
4513         EXPECT_EQ(1U, GetNumQueries(dns, hostname1));
4514         EXPECT_EQ(records.at(0).addr, ToString(result));
4515 
4516         // A loose upper bound is set by adding 1000ms buffer time. Theoretically, getaddrinfo()
4517         // should just take a bit more than expetTimeout milliseconds.
4518         EXPECT_GE(timeTakenMs, config.expectedTimeout);
4519         EXPECT_LE(timeTakenMs, config.expectedTimeout + 1000);
4520 
4521         // Set the server to be responsive. Verify that the resolver will attempt to reconnect
4522         // to the server and then get the result within the timeout.
4523         tls.setHangOnHandshakeForTesting(false);
4524         std::tie(result, timeTakenMs) = safe_getaddrinfo_time_taken(hostname2, nullptr, hints);
4525 
4526         EXPECT_NE(nullptr, result);
4527         EXPECT_TRUE(tls.waitForQueries(1));
4528         EXPECT_EQ(1U, GetNumQueries(dns, hostname2));
4529         EXPECT_EQ(records.at(1).addr, ToString(result));
4530 
4531         EXPECT_LE(timeTakenMs, 200);
4532     }
4533 }
4534 
TEST_F(ResolverTest,ConnectTlsServerTimeout_ConcurrentQueries)4535 TEST_F(ResolverTest, ConnectTlsServerTimeout_ConcurrentQueries) {
4536     constexpr uint32_t cacheFlag = ANDROID_RESOLV_NO_CACHE_LOOKUP;
4537     constexpr char hostname[] = "hello.example.com.";
4538     const std::vector<DnsRecord> records = {
4539             {hostname, ns_type::ns_t_a, "1.2.3.4"},
4540     };
4541     int testConfigCount = 0;
4542 
4543     static const struct TestConfig {
4544         bool asyncHandshake;
4545         int dotConnectTimeoutMs;
4546         int dotQueryTimeoutMs;
4547         int maxRetries;
4548         int concurrency;
4549 
4550         // if asyncHandshake:
4551         //   expectedTimeout = Min(DotQueryTimeoutMs, dotConnectTimeoutMs * maxRetries)
4552         // otherwise:
4553         //   expectedTimeout = dotConnectTimeoutMs * concurrency
4554         int expectedTimeout;
4555     } testConfigs[] = {
4556             // clang-format off
4557             {false, 1000, 3000, 1, 5,  5000},
4558             {false, 1000, 3000, 3, 5,  5000},
4559             {false, 2000, 1500, 3, 2,  4000},
4560             {true,  1000, 3000, 1, 5,  1000},
4561             {true,  2500, 1500, 1, 10, 1500},
4562             {true,  1000, 5000, 3, 5,  3000},
4563             // clang-format on
4564     };
4565 
4566     // Launch query threads. Expected behaviors are:
4567     // - when dot_async_handshake is disabled, one of the query threads triggers a
4568     //   handshake and then times out. Then same as another query thread, and so forth.
4569     // - when dot_async_handshake is enabled, only one handshake is triggered, and then
4570     //   all of the query threads time out at the same time.
4571     for (const auto& config : testConfigs) {
4572         testConfigCount++;
4573         ScopedSystemProperties sp1(kDotQueryTimeoutMsFlag,
4574                                    std::to_string(config.dotQueryTimeoutMs));
4575         ScopedSystemProperties sp2(kDotConnectTimeoutMsFlag,
4576                                    std::to_string(config.dotConnectTimeoutMs));
4577         ScopedSystemProperties sp3(kDotAsyncHandshakeFlag, config.asyncHandshake ? "1" : "0");
4578         ScopedSystemProperties sp4(kDotMaxretriesFlag, std::to_string(config.maxRetries));
4579         resetNetwork();
4580 
4581         for (const auto& dnsMode : {"OPPORTUNISTIC", "STRICT"}) {
4582             SCOPED_TRACE(fmt::format("testConfig: [{}, {}]", testConfigCount, dnsMode));
4583 
4584             // Because a DnsTlsTransport lasts at least 5 minutes in spite of network
4585             // destroyed, let the resolver creates an unique DnsTlsTransport every time
4586             // so that the DnsTlsTransport won't interfere the other tests.
4587             const std::string addr = getUniqueIPv4Address();
4588             test::DNSResponder dns(addr);
4589             StartDns(dns, records);
4590             test::DnsTlsFrontend tls(addr, "853", addr, "53");
4591             ASSERT_TRUE(tls.startServer());
4592 
4593             auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4594             parcel.servers = {addr};
4595             parcel.tlsServers = {addr};
4596             if (dnsMode == "STRICT") parcel.tlsName = kDefaultPrivateDnsHostName;
4597             ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4598             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4599             EXPECT_TRUE(tls.waitForQueries(1));
4600 
4601             // The server becomes unresponsive to the handshake request.
4602             tls.setHangOnHandshakeForTesting(true);
4603 
4604             Stopwatch s;
4605             std::vector<std::thread> threads(config.concurrency);
4606             for (std::thread& thread : threads) {
4607                 thread = std::thread([&]() {
4608                     int fd = resNetworkQuery(TEST_NETID, hostname, ns_c_in, ns_t_a, cacheFlag);
4609                     dnsMode == "STRICT" ? expectAnswersNotValid(fd, -ETIMEDOUT)
4610                                         : expectAnswersValid(fd, AF_INET, "1.2.3.4");
4611                 });
4612             }
4613             for (std::thread& thread : threads) {
4614                 thread.join();
4615             }
4616 
4617             const int timeTakenMs = s.timeTakenUs() / 1000;
4618             // A loose upper bound is set by adding 1000ms buffer time. Theoretically, it should
4619             // just take a bit more than expetTimeout milliseconds for the result.
4620             EXPECT_GE(timeTakenMs, config.expectedTimeout);
4621             EXPECT_LE(timeTakenMs, config.expectedTimeout + 1000);
4622 
4623             // Recover the server from being unresponsive and try again.
4624             tls.setHangOnHandshakeForTesting(false);
4625             int fd = resNetworkQuery(TEST_NETID, hostname, ns_c_in, ns_t_a, cacheFlag);
4626             if (dnsMode == "STRICT" && config.asyncHandshake &&
4627                 config.dotQueryTimeoutMs < (config.dotConnectTimeoutMs * config.maxRetries)) {
4628                 // In this case, the connection handshake is supposed to be in progress. Queries
4629                 // sent before the handshake finishes will time out (either due to connect timeout
4630                 // or query timeout).
4631                 expectAnswersNotValid(fd, -ETIMEDOUT);
4632             } else {
4633                 expectAnswersValid(fd, AF_INET, "1.2.3.4");
4634             }
4635         }
4636     }
4637 }
4638 
TEST_F(ResolverTest,QueryTlsServerTimeout)4639 TEST_F(ResolverTest, QueryTlsServerTimeout) {
4640     constexpr uint32_t cacheFlag = ANDROID_RESOLV_NO_CACHE_LOOKUP;
4641     constexpr int INFINITE_QUERY_TIMEOUT = -1;
4642     constexpr int DOT_SERVER_UNRESPONSIVE_TIME_MS = 5000;
4643     constexpr char hostname1[] = "query1.example.com.";
4644     constexpr char hostname2[] = "query2.example.com.";
4645     const std::vector<DnsRecord> records = {
4646             {hostname1, ns_type::ns_t_a, "1.2.3.4"},
4647             {hostname2, ns_type::ns_t_a, "1.2.3.5"},
4648     };
4649 
4650     for (const int queryTimeoutMs : {INFINITE_QUERY_TIMEOUT, 1000}) {
4651         for (const auto& dnsMode : {"OPPORTUNISTIC", "STRICT"}) {
4652             SCOPED_TRACE(fmt::format("testConfig: [{}] [{}]", dnsMode, queryTimeoutMs));
4653 
4654             const std::string addr = getUniqueIPv4Address();
4655             test::DNSResponder dns(addr);
4656             StartDns(dns, records);
4657             test::DnsTlsFrontend tls(addr, "853", addr, "53");
4658             ASSERT_TRUE(tls.startServer());
4659 
4660             ScopedSystemProperties sp(kDotQueryTimeoutMsFlag, std::to_string(queryTimeoutMs));
4661             resetNetwork();
4662 
4663             auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4664             parcel.servers = {addr};
4665             parcel.tlsServers = {addr};
4666             if (dnsMode == "STRICT") parcel.tlsName = kDefaultPrivateDnsHostName;
4667 
4668             ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4669             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4670             EXPECT_TRUE(tls.waitForQueries(1));
4671             tls.clearQueries();
4672 
4673             // Set the DoT server to be unresponsive to DNS queries until either it receives
4674             // 2 queries or 5s later.
4675             tls.setDelayQueries(2);
4676             tls.setDelayQueriesTimeout(DOT_SERVER_UNRESPONSIVE_TIME_MS);
4677 
4678             // First query.
4679             Stopwatch s;
4680             int fd = resNetworkQuery(TEST_NETID, hostname1, ns_c_in, ns_t_a, cacheFlag);
4681             if (dnsMode == "STRICT" && queryTimeoutMs != INFINITE_QUERY_TIMEOUT) {
4682                 expectAnswersNotValid(fd, -ETIMEDOUT);
4683             } else {
4684                 expectAnswersValid(fd, AF_INET, "1.2.3.4");
4685             }
4686 
4687             // Besides checking the result of the query, check how much time the
4688             // resolver processed the query.
4689             int timeTakenMs = s.getTimeAndResetUs() / 1000;
4690             const int expectedTimeTakenMs = (queryTimeoutMs == INFINITE_QUERY_TIMEOUT)
4691                                                     ? DOT_SERVER_UNRESPONSIVE_TIME_MS
4692                                                     : queryTimeoutMs;
4693             EXPECT_GE(timeTakenMs, expectedTimeTakenMs);
4694             EXPECT_LE(timeTakenMs, expectedTimeTakenMs + 1000);
4695 
4696             // Second query.
4697             tls.setDelayQueries(1);
4698             fd = resNetworkQuery(TEST_NETID, hostname2, ns_c_in, ns_t_a, cacheFlag);
4699             expectAnswersValid(fd, AF_INET, "1.2.3.5");
4700 
4701             // Also check how much time the resolver processed the query.
4702             timeTakenMs = s.timeTakenUs() / 1000;
4703             EXPECT_LE(timeTakenMs, 500);
4704             EXPECT_EQ(2, tls.queries());
4705         }
4706     }
4707 }
4708 
4709 // Verifies that the DnsResolver re-validates the DoT server when several DNS queries to
4710 // the server fails in a row.
TEST_F(ResolverTest,TlsServerRevalidation)4711 TEST_F(ResolverTest, TlsServerRevalidation) {
4712     constexpr uint32_t cacheFlag = ANDROID_RESOLV_NO_CACHE_LOOKUP;
4713     constexpr int dotXportUnusableThreshold = 10;
4714     constexpr int dotQueryTimeoutMs = 1000;
4715     constexpr char hostname[] = "hello.example.com.";
4716     const std::vector<DnsRecord> records = {
4717             {hostname, ns_type::ns_t_a, "1.2.3.4"},
4718     };
4719 
4720     static const struct TestConfig {
4721         std::string dnsMode;
4722         int validationThreshold;
4723         int queries;
4724 
4725         // Expected behavior in the DnsResolver.
4726         bool expectRevalidationHappen;
4727         bool expectDotUnusable;
4728     } testConfigs[] = {
4729             // clang-format off
4730             {"OPPORTUNISTIC", -1,  5, false, false},
4731             {"OPPORTUNISTIC", -1, 10, false, false},
4732             {"OPPORTUNISTIC",  5,  5,  true, false},
4733             {"OPPORTUNISTIC",  5, 10,  true,  true},
4734             {"STRICT",        -1,  5, false, false},
4735             {"STRICT",        -1, 10, false, false},
4736             {"STRICT",         5,  5, false, false},
4737             {"STRICT",         5, 10, false, false},
4738             // clang-format on
4739     };
4740 
4741     for (const auto& config : testConfigs) {
4742         SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}]", config.dnsMode,
4743                                  config.validationThreshold, config.queries));
4744         const int queries = config.queries;
4745         const int delayQueriesTimeout = dotQueryTimeoutMs + 1000;
4746 
4747         ScopedSystemProperties sp1(kDotRevalidationThresholdFlag,
4748                                    std::to_string(config.validationThreshold));
4749         ScopedSystemProperties sp2(kDotXportUnusableThresholdFlag,
4750                                    std::to_string(dotXportUnusableThreshold));
4751         ScopedSystemProperties sp3(kDotQueryTimeoutMsFlag, std::to_string(dotQueryTimeoutMs));
4752         resetNetwork();
4753 
4754         // This test is sensitive to the number of queries sent in DoT validation.
4755         int latencyFactor;
4756         int latencyOffsetMs;
4757         if (isAtLeastR) {
4758             // The feature is enabled by default in R.
4759             latencyFactor = std::stoi(GetProperty(kDotValidationLatencyFactorFlag, "3"));
4760             latencyOffsetMs = std::stoi(GetProperty(kDotValidationLatencyOffsetMsFlag, "100"));
4761         } else {
4762             // The feature is disabled by default in Q.
4763             latencyFactor = std::stoi(GetProperty(kDotValidationLatencyFactorFlag, "-1"));
4764             latencyOffsetMs = std::stoi(GetProperty(kDotValidationLatencyOffsetMsFlag, "-1"));
4765         }
4766         const bool dotValidationExtraProbes = (config.dnsMode == "OPPORTUNISTIC") &&
4767                                               (latencyFactor >= 0 && latencyOffsetMs >= 0 &&
4768                                                latencyFactor + latencyOffsetMs != 0);
4769 
4770         const std::string addr = getUniqueIPv4Address();
4771         test::DNSResponder dns(addr);
4772         StartDns(dns, records);
4773         test::DnsTlsFrontend tls(addr, "853", addr, "53");
4774         ASSERT_TRUE(tls.startServer());
4775 
4776         auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4777         parcel.servers = {addr};
4778         parcel.tlsServers = {addr};
4779         if (config.dnsMode == "STRICT") parcel.tlsName = kDefaultPrivateDnsHostName;
4780         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4781         EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4782         if (dotValidationExtraProbes) {
4783             EXPECT_TRUE(tls.waitForQueries(2));
4784         } else {
4785             EXPECT_TRUE(tls.waitForQueries(1));
4786         }
4787         tls.clearQueries();
4788         dns.clearQueries();
4789 
4790         // Expect the things happening in order:
4791         // 1. Configure the DoT server to postpone |queries + 1| DNS queries.
4792         // 2. Send |queries| DNS queries, they will time out in 1 second.
4793         // 3. 1 second later, the DoT server still waits for one more DNS query until
4794         //    |delayQueriesTimeout| times out.
4795         // 4. (opportunistic mode only) Meanwhile, DoT revalidation happens. The DnsResolver
4796         //    creates a new connection and sends a query to the DoT server.
4797         // 5. 1 second later, |delayQueriesTimeout| times out. The DoT server flushes all of the
4798         //    postponed DNS queries, and handles the query which comes from the revalidation.
4799         // 6. (opportunistic mode only) The revalidation succeeds.
4800         // 7. Send another DNS query, and expect it will succeed.
4801         // 8. (opportunistic mode only) If the DoT server has been deemed as unusable, the
4802         //    DnsResolver skips trying the DoT server.
4803 
4804         // Step 1.
4805         tls.setDelayQueries(queries + 1);
4806         tls.setDelayQueriesTimeout(delayQueriesTimeout);
4807 
4808         // Step 2.
4809         std::vector<std::thread> threads1(queries);
4810         for (std::thread& thread : threads1) {
4811             thread = std::thread([&]() {
4812                 int fd = resNetworkQuery(TEST_NETID, hostname, ns_c_in, ns_t_a, cacheFlag);
4813                 config.dnsMode == "STRICT" ? expectAnswersNotValid(fd, -ETIMEDOUT)
4814                                            : expectAnswersValid(fd, AF_INET, "1.2.3.4");
4815             });
4816         }
4817 
4818         // Step 3 and 4.
4819         for (std::thread& thread : threads1) {
4820             thread.join();
4821         }
4822 
4823         // Recover the config to make the revalidation can succeed.
4824         tls.setDelayQueries(1);
4825 
4826         // Step 5 and 6.
4827         int expectedDotQueries = queries;
4828         int extraDnsProbe = 0;
4829         if (config.expectRevalidationHappen) {
4830             EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4831             expectedDotQueries++;
4832 
4833             if (dotValidationExtraProbes) {
4834                 expectedDotQueries++;
4835                 extraDnsProbe = 1;
4836             }
4837         }
4838 
4839         // Step 7 and 8.
4840         int fd = resNetworkQuery(TEST_NETID, hostname, ns_c_in, ns_t_a, cacheFlag);
4841         expectAnswersValid(fd, AF_INET, "1.2.3.4");
4842         expectedDotQueries++;
4843 
4844         const int expectedDo53Queries =
4845                 expectedDotQueries +
4846                 (config.dnsMode == "OPPORTUNISTIC" ? (queries + extraDnsProbe) : 0);
4847 
4848         if (config.expectDotUnusable) {
4849             // A DoT server can be deemed as unusable only in opportunistic mode. When it happens,
4850             // the DnsResolver doesn't use the DoT server for a certain period of time.
4851             expectedDotQueries--;
4852         }
4853 
4854         // This code makes the test more robust to race condition.
4855         EXPECT_TRUE(tls.waitForQueries(expectedDotQueries));
4856 
4857         EXPECT_EQ(dns.queries().size(), static_cast<unsigned>(expectedDo53Queries));
4858         EXPECT_EQ(tls.queries(), expectedDotQueries);
4859     }
4860 }
4861 
4862 // Verifies that private DNS validation fails if DoT server is much slower than cleartext server.
TEST_F(ResolverTest,TlsServerValidation_UdpProbe)4863 TEST_F(ResolverTest, TlsServerValidation_UdpProbe) {
4864     constexpr char backend_addr[] = "127.0.0.3";
4865     test::DNSResponder backend(backend_addr);
4866     backend.setResponseDelayMs(200);
4867     ASSERT_TRUE(backend.startServer());
4868 
4869     static const struct TestConfig {
4870         int latencyFactor;
4871         int latencyOffsetMs;
4872         bool udpProbeLost;
4873         size_t expectedUdpProbes;
4874         bool expectedValidationPass;
4875     } testConfigs[] = {
4876             // clang-format off
4877             {-1, -1,  false, 0, true},
4878             {0,  0,   false, 0, true},
4879             {1,  10,  false, 1, false},
4880             {1,  10,  true,  2, false},
4881             {5,  300, false, 1, true},
4882             {5,  300, true,  2, true},
4883             // clang-format on
4884     };
4885 
4886     for (const auto& config : testConfigs) {
4887         SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}]", config.latencyFactor,
4888                                  config.latencyOffsetMs, config.udpProbeLost));
4889 
4890         const std::string addr = getUniqueIPv4Address();
4891         test::DNSResponder dns(addr, "53", static_cast<ns_rcode>(-1));
4892         test::DnsTlsFrontend tls(addr, "853", backend_addr, "53");
4893         dns.setResponseDelayMs(10);
4894         ASSERT_TRUE(dns.startServer());
4895         ASSERT_TRUE(tls.startServer());
4896 
4897         ScopedSystemProperties sp1(kDotValidationLatencyFactorFlag,
4898                                    std::to_string(config.latencyFactor));
4899         ScopedSystemProperties sp2(kDotValidationLatencyOffsetMsFlag,
4900                                    std::to_string(config.latencyOffsetMs));
4901         resetNetwork();
4902 
4903         std::unique_ptr<std::thread> thread;
4904         if (config.udpProbeLost) {
4905             thread.reset(new std::thread([&dns]() {
4906                 // Simulate that the first UDP probe is lost and the second UDP probe succeeds.
4907                 dns.setResponseProbability(0.0);
4908                 std::this_thread::sleep_for(std::chrono::seconds(2));
4909                 dns.setResponseProbability(1.0);
4910             }));
4911         }
4912 
4913         // Set up opportunistic mode, and wait for the validation complete.
4914         auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
4915         parcel.servers = {addr};
4916         parcel.tlsServers = {addr};
4917         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4918 
4919         // The timeout of WaitForPrivateDnsValidation is 5 seconds which is still enough for
4920         // the testcase of UDP probe lost because the retry of UDP probe happens after 3 seconds.
4921         EXPECT_TRUE(
4922                 WaitForPrivateDnsValidation(tls.listen_address(), config.expectedValidationPass));
4923         EXPECT_EQ(dns.queries().size(), config.expectedUdpProbes);
4924         dns.clearQueries();
4925 
4926         // Test that Private DNS validation always pass in strict mode.
4927         parcel.tlsName = kDefaultPrivateDnsHostName;
4928         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
4929         EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
4930         EXPECT_EQ(dns.queries().size(), 0U);
4931 
4932         if (thread) {
4933             thread->join();
4934             thread.reset();
4935         }
4936     }
4937 }
4938 
// Verifies that flushNetworkCache() drops cached answers, so the next lookup reaches the server.
TEST_F(ResolverTest, FlushNetworkCache) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    test::DNSResponder dns;
    StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // The first lookup reaches the server. Every lookup is checked because the query counts
    // below only mean something if the lookups actually succeed.
    const hostent* result = gethostbyname("hello");
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, kHelloExampleCom));

    // get result from cache
    result = gethostbyname("hello");
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, kHelloExampleCom));

    EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());

    // After the flush, the lookup has to reach the server again.
    result = gethostbyname("hello");
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ(2U, GetNumQueriesForType(dns, ns_type::ns_t_a, kHelloExampleCom));
}
4957 
// Flushes the network cache at random moments while lookups are in flight. Every lookup must
// still return the expected address.
TEST_F(ResolverTest, FlushNetworkCache_random) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    constexpr int num_flush = 10;
    constexpr int num_queries = 20;
    test::DNSResponder dns;
    StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
    const addrinfo hints = {.ai_family = AF_INET};

    // Background flusher: |num_flush| flushes, each after a random pause of under 10ms.
    std::thread flusher([this]() {
        for (int remaining = num_flush; remaining > 0; --remaining) {
            usleep(arc4random_uniform(10 * 1000));  // 10ms
            EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());
        }
    });

    // Foreground lookups racing against the flushes.
    for (int n = 0; n < num_queries; ++n) {
        ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
        EXPECT_TRUE(result != nullptr);
        EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
    }
    flusher.join();
}
4982 
4983 // flush cache while one query is wait-for-response, another is pending.
TEST_F(ResolverTest,FlushNetworkCache_concurrent)4984 TEST_F(ResolverTest, FlushNetworkCache_concurrent) {
4985     SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
4986     const char* listen_addr1 = "127.0.0.9";
4987     const char* listen_addr2 = "127.0.0.10";
4988     test::DNSResponder dns1(listen_addr1);
4989     test::DNSResponder dns2(listen_addr2);
4990     StartDns(dns1, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
4991     StartDns(dns2, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
4992     addrinfo hints = {.ai_family = AF_INET};
4993 
4994     // step 1: set server#1 into deferred responding mode
4995     dns1.setDeferredResp(true);
4996     std::thread t1([&listen_addr1, &hints, this]() {
4997         ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr1}));
4998         // step 3: query
4999         ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
5000         // step 9: check result
5001         EXPECT_TRUE(result != nullptr);
5002         EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
5003     });
5004 
5005     // step 2: wait for the query to reach the server
5006     while (GetNumQueries(dns1, kHelloExampleCom) == 0) {
5007         usleep(1000);  // 1ms
5008     }
5009 
5010     std::thread t2([&listen_addr2, &hints, &dns2, this]() {
5011         ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr2}));
5012         // step 5: query (should be blocked in resolver)
5013         ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
5014         // step 7: check result
5015         EXPECT_TRUE(result != nullptr);
5016         EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
5017         EXPECT_EQ(1U, GetNumQueriesForType(dns2, ns_type::ns_t_a, kHelloExampleCom));
5018     });
5019 
5020     // step 4: wait a bit for the 2nd query to enter pending state
5021     usleep(100 * 1000);  // 100ms
5022     // step 6: flush cache (will unblock pending queries)
5023     EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());
5024     t2.join();
5025 
5026     // step 8: resume server#1
5027     dns1.setDeferredResp(false);
5028     t1.join();
5029 
5030     // step 10: verify if result is correctly cached
5031     dns2.clearQueries();
5032     ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
5033     EXPECT_EQ(0U, GetNumQueries(dns2, kHelloExampleCom));
5034     EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
5035 }
5036 
5037 // TODO: Perhaps to have a boundary conditions test for TCP and UDP.
TEST_F(ResolverTest,TcpQueryWithOversizePayload)5038 TEST_F(ResolverTest, TcpQueryWithOversizePayload) {
5039     test::DNSResponder dns;
5040     StartDns(dns, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
5041     ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
5042 
5043     int fd = dns_open_proxy();
5044     ASSERT_TRUE(fd > 0);
5045 
5046     // Sending DNS query over TCP once the packet sizes exceed 512 bytes.
5047     // The raw data is combined with Question section and Additional section
5048     // Question section : query "hello.example.com", type A, class IN
5049     // Additional section : type OPT (41), Option PADDING, Option Length 546
5050     // Padding option which allows DNS clients and servers to artificially
5051     // increase the size of a DNS message by a variable number of bytes.
5052     // See also RFC7830, section 3
5053     const std::string query =
5054             "+c0BAAABAAAAAAABBWhlbGxvB2V4YW1wbGUDY29tAAABAAEAACkgAAAAgAACJgAMAiIAAAAAAAAAAAAAAAAAA"
5055             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5056             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5057             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5058             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5059             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5060             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5061             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5062             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
5063             "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
5064     const std::string cmd =
5065             "resnsend " + std::to_string(TEST_NETID) + " 0 " /* ResNsendFlags */ + query + '\0';
5066     ssize_t rc = TEMP_FAILURE_RETRY(write(fd, cmd.c_str(), cmd.size()));
5067     EXPECT_EQ(rc, static_cast<ssize_t>(cmd.size()));
5068     expectAnswersValid(fd, AF_INET, kHelloExampleComAddrV4);
5069     EXPECT_EQ(1U, GetNumQueriesForProtocol(dns, IPPROTO_TCP, kHelloExampleCom));
5070     EXPECT_EQ(0U, GetNumQueriesForProtocol(dns, IPPROTO_UDP, kHelloExampleCom));
5071 }
5072 
TEST_F(ResolverTest,TruncatedRspMode)5073 TEST_F(ResolverTest, TruncatedRspMode) {
5074     constexpr char listen_addr[] = "127.0.0.4";
5075     constexpr char listen_addr2[] = "127.0.0.5";
5076     constexpr char listen_srv[] = "53";
5077 
5078     test::DNSResponder dns(listen_addr, listen_srv, static_cast<ns_rcode>(-1));
5079     test::DNSResponder dns2(listen_addr2, listen_srv, static_cast<ns_rcode>(-1));
5080     // dns supports UDP only, dns2 support UDP and TCP
5081     dns.setResponseProbability(0.0, IPPROTO_TCP);
5082     StartDns(dns, kLargeCnameChainRecords);
5083     StartDns(dns2, kLargeCnameChainRecords);
5084 
5085     const struct TestConfig {
5086         const std::optional<int32_t> tcMode;
5087         const bool ret;
5088         const unsigned numQueries;
5089         std::string asParameters() const {
5090             return StringPrintf("tcMode: %d, ret: %s, numQueries: %u", tcMode.value_or(-1),
5091                                 ret ? "true" : "false", numQueries);
5092         }
5093     } testConfigs[]{
5094             // clang-format off
5095             {std::nullopt,                                      true,  0}, /* mode unset */
5096             {aidl::android::net::IDnsResolver::TC_MODE_DEFAULT, true,  0}, /* default mode */
5097             {-666,                                              false, 0}, /* invalid input */
5098             {aidl::android::net::IDnsResolver::TC_MODE_UDP_TCP, true,  1}, /* alternative mode */
5099             // clang-format on
5100     };
5101 
5102     for (const auto& config : testConfigs) {
5103         SCOPED_TRACE(config.asParameters());
5104 
5105         ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5106         parcel.servers = {listen_addr, listen_addr2};
5107         ResolverOptionsParcel resolverOptions;
5108         if (config.tcMode.has_value()) resolverOptions.tcMode = config.tcMode.value();
5109         if (!mIsResolverOptionIPCSupported) {
5110             parcel.resolverOptions = resolverOptions;
5111             ASSERT_EQ(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk(),
5112                       config.ret);
5113         } else {
5114             ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk());
5115         }
5116         if (mIsResolverOptionIPCSupported) {
5117             ASSERT_EQ(mDnsClient.resolvService()
5118                               ->setResolverOptions(parcel.netId, resolverOptions)
5119                               .isOk(),
5120                       config.ret);
5121         }
5122 
5123         const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
5124         ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
5125         ASSERT_TRUE(result != nullptr);
5126         EXPECT_EQ(ToString(result), kHelloExampleComAddrV4);
5127         // TC_MODE_DEFAULT: resolver retries on TCP-only on each name server.
5128         // TC_MODE_UDP_TCP: resolver retries on TCP on the same server, falls back to UDP from next.
5129         ASSERT_EQ(GetNumQueriesForProtocol(dns, IPPROTO_UDP, kHelloExampleCom), 1U);
5130         ASSERT_EQ(GetNumQueriesForProtocol(dns, IPPROTO_TCP, kHelloExampleCom), 1U);
5131         ASSERT_EQ(GetNumQueriesForProtocol(dns2, IPPROTO_UDP, kHelloExampleCom), config.numQueries);
5132         ASSERT_EQ(GetNumQueriesForProtocol(dns2, IPPROTO_TCP, kHelloExampleCom), 1U);
5133 
5134         dns.clearQueries();
5135         dns2.clearQueries();
5136         ASSERT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());
5137 
5138         // Clear the stats to make the resolver always choose the same server for the first query.
5139         parcel.servers.clear();
5140         parcel.tlsServers.clear();
5141         if (!mIsResolverOptionIPCSupported) {
5142             ASSERT_EQ(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk(),
5143                       config.ret);
5144         } else {
5145             ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk());
5146         }
5147     }
5148 }
5149 
TEST_F(ResolverTest,RepeatedSetup_ResolverStatusRemains)5150 TEST_F(ResolverTest, RepeatedSetup_ResolverStatusRemains) {
5151     constexpr char unusable_listen_addr[] = "127.0.0.3";
5152     constexpr char listen_addr[] = "127.0.0.4";
5153     constexpr char hostname[] = "a.hello.query.";
5154     const auto repeatedSetResolversFromParcel = [&](const ResolverParamsParcel& parcel) {
5155         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5156         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5157         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5158     };
5159 
5160     test::DNSResponder dns(listen_addr);
5161     StartDns(dns, {{hostname, ns_type::ns_t_a, "1.2.3.3"}});
5162     test::DnsTlsFrontend tls1(listen_addr, "853", listen_addr, "53");
5163     ASSERT_TRUE(tls1.startServer());
5164 
5165     // Private DNS off mode.
5166     ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5167     parcel.servers = {unusable_listen_addr, listen_addr};
5168     parcel.tlsServers.clear();
5169     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5170 
5171     // Send a query.
5172     const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
5173     EXPECT_NE(safe_getaddrinfo(hostname, nullptr, &hints), nullptr);
5174 
5175     // Check the stats as expected.
5176     const std::vector<NameserverStats> expectedCleartextDnsStats = {
5177             NameserverStats(unusable_listen_addr).setInternalErrors(1),
5178             NameserverStats(listen_addr).setSuccesses(1),
5179     };
5180     EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
5181     EXPECT_EQ(GetNumQueries(dns, hostname), 1U);
5182 
5183     // The stats is supposed to remain as long as the list of cleartext DNS servers is unchanged.
5184     static const struct TestConfig {
5185         std::vector<std::string> servers;
5186         std::vector<std::string> tlsServers;
5187         std::string tlsName;
5188     } testConfigs[] = {
5189             // Private DNS opportunistic mode.
5190             {{listen_addr, unusable_listen_addr}, {listen_addr, unusable_listen_addr}, ""},
5191             {{unusable_listen_addr, listen_addr}, {unusable_listen_addr, listen_addr}, ""},
5192 
5193             // Private DNS strict mode.
5194             {{listen_addr, unusable_listen_addr}, {"127.0.0.100"}, kDefaultPrivateDnsHostName},
5195             {{unusable_listen_addr, listen_addr}, {"127.0.0.100"}, kDefaultPrivateDnsHostName},
5196 
5197             // Private DNS off mode.
5198             {{unusable_listen_addr, listen_addr}, {}, ""},
5199             {{listen_addr, unusable_listen_addr}, {}, ""},
5200     };
5201 
5202     for (const auto& config : testConfigs) {
5203         SCOPED_TRACE(fmt::format("testConfig: [{}] [{}] [{}]", fmt::join(config.servers, ","),
5204                                  fmt::join(config.tlsServers, ","), config.tlsName));
5205         parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5206         parcel.servers = config.servers;
5207         parcel.tlsServers = config.tlsServers;
5208         parcel.tlsName = config.tlsName;
5209         repeatedSetResolversFromParcel(parcel);
5210         EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
5211 
5212         // The stats remains when the list of search domains changes.
5213         parcel.domains.push_back("tmp.domains");
5214         repeatedSetResolversFromParcel(parcel);
5215         EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
5216 
5217         // The stats remains when the parameters change (except maxSamples).
5218         parcel.sampleValiditySeconds++;
5219         parcel.successThreshold++;
5220         parcel.minSamples++;
5221         parcel.baseTimeoutMsec++;
5222         parcel.retryCount++;
5223         repeatedSetResolversFromParcel(parcel);
5224         EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
5225     }
5226 
5227     // The cache remains.
5228     EXPECT_NE(safe_getaddrinfo(hostname, nullptr, &hints), nullptr);
5229     EXPECT_EQ(GetNumQueries(dns, hostname), 1U);
5230 }
5231 
TEST_F(ResolverTest,RepeatedSetup_NoRedundantPrivateDnsValidation)5232 TEST_F(ResolverTest, RepeatedSetup_NoRedundantPrivateDnsValidation) {
5233     const std::string addr1 = getUniqueIPv4Address();  // For a workable DNS server.
5234     const std::string addr2 = getUniqueIPv4Address();  // For an unresponsive DNS server.
5235     const std::string unusable_addr = getUniqueIPv4Address();
5236     const auto waitForPrivateDnsStateUpdated = []() {
5237         // A buffer time for the PrivateDnsConfiguration instance to update its map,
5238         // mPrivateDnsValidateThreads, which is used for tracking validation threads.
5239         // Since there is a time gap between when PrivateDnsConfiguration reports
5240         // onPrivateDnsValidationEvent and when PrivateDnsConfiguration updates the map, this is a
5241         // workaround to avoid the test starts a subsequent resolver setup during the time gap.
5242         // TODO: Report onPrivateDnsValidationEvent after all the relevant updates are complete.
5243         // Reference to b/152009023.
5244         std::this_thread::sleep_for(20ms);
5245     };
5246 
5247     test::DNSResponder dns1(addr1);
5248     test::DNSResponder dns2(addr2);
5249     StartDns(dns1, {});
5250     StartDns(dns2, {});
5251     test::DnsTlsFrontend workableTls(addr1, "853", addr1, "53");
5252     test::DnsTlsFrontend unresponsiveTls(addr2, "853", addr2, "53");
5253     int validationAttemptsToUnresponsiveTls = 1;
5254     unresponsiveTls.setHangOnHandshakeForTesting(true);
5255     ASSERT_TRUE(workableTls.startServer());
5256     ASSERT_TRUE(unresponsiveTls.startServer());
5257 
5258     // First setup.
5259     ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5260     parcel.servers = {addr1, addr2, unusable_addr};
5261     parcel.tlsServers = {addr1, addr2, unusable_addr};
5262     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5263 
5264     // Check the validation results.
5265     EXPECT_TRUE(WaitForPrivateDnsValidation(workableTls.listen_address(), true));
5266     EXPECT_TRUE(WaitForPrivateDnsValidation(unusable_addr, false));
5267 
5268     // The validation is still in progress.
5269     EXPECT_EQ(unresponsiveTls.acceptConnectionsCount(), validationAttemptsToUnresponsiveTls);
5270 
5271     static const struct TestConfig {
5272         std::vector<std::string> tlsServers;
5273         std::string tlsName;
5274     } testConfigs[] = {
5275             {{addr1, addr2, unusable_addr}, ""},
5276             {{unusable_addr, addr1, addr2}, ""},
5277             {{unusable_addr, addr1, addr2}, kDefaultPrivateDnsHostName},
5278             {{addr1, addr2, unusable_addr}, kDefaultPrivateDnsHostName},
5279     };
5280 
5281     std::string TlsNameLastTime;
5282     for (const auto& config : testConfigs) {
5283         SCOPED_TRACE(fmt::format("testConfig: [{}] [{}]", fmt::join(config.tlsServers, ","),
5284                                  config.tlsName));
5285         parcel.servers = config.tlsServers;
5286         parcel.tlsServers = config.tlsServers;
5287         parcel.tlsName = config.tlsName;
5288         parcel.caCertificate = config.tlsName.empty() ? "" : kCaCert;
5289 
5290         const bool dnsModeChanged = (TlsNameLastTime != config.tlsName);
5291 
5292         waitForPrivateDnsStateUpdated();
5293         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5294 
5295         for (const auto& serverAddr : parcel.tlsServers) {
5296             SCOPED_TRACE(serverAddr);
5297             if (serverAddr == workableTls.listen_address()) {
5298                 if (dnsModeChanged) {
5299                     // Despite the identical IP address, the server is regarded as a different
5300                     // server when DnsTlsServer.name is different. The resolver treats it as a
5301                     // different object and begins the validation process.
5302                     EXPECT_TRUE(WaitForPrivateDnsValidation(serverAddr, true));
5303                 }
5304             } else if (serverAddr == unresponsiveTls.listen_address()) {
5305                 if (dnsModeChanged) {
5306                     // Despite the identical IP address, the server is regarded as a different
5307                     // server when DnsTlsServer.name is different. The resolver treats it as a
5308                     // different object and begins the validation process.
5309                     validationAttemptsToUnresponsiveTls++;
5310 
5311                     // This is the limitation from DnsTlsFrontend. DnsTlsFrontend can't operate
5312                     // concurrently. As soon as there's another connection request,
5313                     // DnsTlsFrontend resets the unique_fd to the new connection.
5314                     EXPECT_TRUE(WaitForPrivateDnsValidation(serverAddr, false));
5315                 }
5316             } else {
5317                 // Must be unusable_addr.
5318                 // In opportunistic mode, when a validation for a private DNS server fails, the
5319                 // resolver just marks the server as failed and doesn't re-evaluate it, but the
5320                 // server can be re-evaluated when setResolverConfiguration() is called.
5321                 // However, in strict mode, the resolver automatically re-evaluates the server and
5322                 // marks the server as in_progress until the validation succeeds, so repeated setup
5323                 // makes no effect.
5324                 if (dnsModeChanged || config.tlsName.empty() /* not in strict mode */) {
5325                     EXPECT_TRUE(WaitForPrivateDnsValidation(serverAddr, false));
5326                 }
5327             }
5328         }
5329 
5330         // Repeated setups make no effect in strict mode.
5331         waitForPrivateDnsStateUpdated();
5332         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5333         if (config.tlsName.empty()) {
5334             EXPECT_TRUE(WaitForPrivateDnsValidation(unusable_addr, false));
5335         }
5336         waitForPrivateDnsStateUpdated();
5337         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5338         if (config.tlsName.empty()) {
5339             EXPECT_TRUE(WaitForPrivateDnsValidation(unusable_addr, false));
5340         }
5341 
5342         EXPECT_EQ(unresponsiveTls.acceptConnectionsCount(), validationAttemptsToUnresponsiveTls);
5343 
5344         TlsNameLastTime = config.tlsName;
5345     }
5346 
5347     // Check that all the validation results are caught.
5348     // Note: it doesn't mean no validation being in progress.
5349     EXPECT_FALSE(hasUncaughtPrivateDnsValidation(addr1));
5350     EXPECT_FALSE(hasUncaughtPrivateDnsValidation(addr2));
5351     EXPECT_FALSE(hasUncaughtPrivateDnsValidation(unusable_addr));
5352 }
5353 
TEST_F(ResolverTest,RepeatedSetup_KeepChangingPrivateDnsServers)5354 TEST_F(ResolverTest, RepeatedSetup_KeepChangingPrivateDnsServers) {
5355     enum TlsServerState { WORKING, UNSUPPORTED, UNRESPONSIVE };
5356     const std::string addr1 = getUniqueIPv4Address();
5357     const std::string addr2 = getUniqueIPv4Address();
5358     const auto waitForPrivateDnsStateUpdated = []() {
5359         // A buffer time for PrivateDnsConfiguration to update its state. It prevents this test
5360         // being flaky. See b/152009023 for the reason.
5361         std::this_thread::sleep_for(20ms);
5362     };
5363 
5364     test::DNSResponder dns1(addr1);
5365     test::DNSResponder dns2(addr2);
5366     StartDns(dns1, {});
5367     StartDns(dns2, {});
5368     test::DnsTlsFrontend tls1(addr1, "853", addr1, "53");
5369     test::DnsTlsFrontend tls2(addr2, "853", addr2, "53");
5370     ASSERT_TRUE(tls1.startServer());
5371     ASSERT_TRUE(tls2.startServer());
5372 
5373     static const struct TestConfig {
5374         std::string tlsServer;
5375         std::string tlsName;
5376         bool expectNothingHappenWhenServerUnsupported;
5377         bool expectNothingHappenWhenServerUnresponsive;
5378         std::string asTestName() const {
5379             return fmt::format("{}, {}, {}, {}", tlsServer, tlsName,
5380                                expectNothingHappenWhenServerUnsupported,
5381                                expectNothingHappenWhenServerUnresponsive);
5382         }
5383     } testConfigs[] = {
5384             {{addr1}, "", false, false},
5385             {{addr2}, "", false, false},
5386             {{addr1}, "", false, true},
5387             {{addr2}, "", false, true},
5388 
5389             // expectNothingHappenWhenServerUnresponsive is false in the two cases because of the
5390             // limitation from DnsTlsFrontend which can't operate concurrently.
5391             {{addr1}, kDefaultPrivateDnsHostName, false, false},
5392             {{addr2}, kDefaultPrivateDnsHostName, false, false},
5393             {{addr1}, kDefaultPrivateDnsHostName, true, true},
5394             {{addr2}, kDefaultPrivateDnsHostName, true, true},
5395 
5396             // expectNothingHappenWhenServerUnresponsive is true in the two cases because of the
5397             // limitation from DnsTlsFrontend which can't operate concurrently.
5398             {{addr1}, "", true, false},
5399             {{addr2}, "", true, false},
5400             {{addr1}, "", true, true},
5401             {{addr2}, "", true, true},
5402     };
5403 
5404     for (const auto& serverState : {WORKING, UNSUPPORTED, UNRESPONSIVE}) {
5405         int testIndex = 0;
5406         for (const auto& config : testConfigs) {
5407             SCOPED_TRACE(fmt::format("serverState:{} testIndex:{} testConfig:[{}]", serverState,
5408                                      testIndex++, config.asTestName()));
5409             auto& tls = (config.tlsServer == addr1) ? tls1 : tls2;
5410 
5411             if (serverState == UNSUPPORTED && tls.running()) ASSERT_TRUE(tls.stopServer());
5412             if (serverState != UNSUPPORTED && !tls.running()) ASSERT_TRUE(tls.startServer());
5413 
5414             tls.setHangOnHandshakeForTesting(serverState == UNRESPONSIVE);
5415             const int connectCountsBefore = tls.acceptConnectionsCount();
5416 
5417             waitForPrivateDnsStateUpdated();
5418             ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5419             parcel.servers = {config.tlsServer};
5420             parcel.tlsServers = {config.tlsServer};
5421             parcel.tlsName = config.tlsName;
5422             parcel.caCertificate = config.tlsName.empty() ? "" : kCaCert;
5423             ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5424 
5425             if (serverState == WORKING) {
5426                 EXPECT_TRUE(WaitForPrivateDnsValidation(config.tlsServer, true));
5427             } else if (serverState == UNSUPPORTED) {
5428                 if (config.expectNothingHappenWhenServerUnsupported) {
5429                     // It's possible that the resolver hasn't yet started to
5430                     // connect. Wait a while.
5431                     // TODO: See if we can get rid of the hard waiting time, such as comparing
5432                     // the CountDiff across two tests.
5433                     std::this_thread::sleep_for(100ms);
5434                     EXPECT_EQ(tls.acceptConnectionsCount(), connectCountsBefore);
5435                 } else {
5436                     EXPECT_TRUE(WaitForPrivateDnsValidation(config.tlsServer, false));
5437                 }
5438             } else {
5439                 // Must be UNRESPONSIVE.
5440                 // DnsTlsFrontend is the only signal for checking whether or not the resolver starts
5441                 // another validation when the server is unresponsive.
5442                 const int expectCountDiff =
5443                         config.expectNothingHappenWhenServerUnresponsive ? 0 : 1;
5444                 if (expectCountDiff == 0) {
5445                     // It's possible that the resolver hasn't yet started to
5446                     // connect. Wait a while.
5447                     std::this_thread::sleep_for(100ms);
5448                 } else {
5449                     EXPECT_TRUE(WaitForPrivateDnsValidation(config.tlsServer, false));
5450                 }
5451                 const auto condition = [&]() {
5452                     return tls.acceptConnectionsCount() == connectCountsBefore + expectCountDiff;
5453                 };
5454                 EXPECT_TRUE(PollForCondition(condition));
5455             }
5456         }
5457 
5458         // Set to off mode to reset the PrivateDnsConfiguration state.
5459         ResolverParamsParcel setupOffmode = DnsResponderClient::GetDefaultResolverParamsParcel();
5460         setupOffmode.tlsServers.clear();
5461         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupOffmode));
5462     }
5463 
5464     // Check that all the validation results are caught.
5465     // Note: it doesn't mean no validation being in progress.
5466     EXPECT_FALSE(hasUncaughtPrivateDnsValidation(addr1));
5467     EXPECT_FALSE(hasUncaughtPrivateDnsValidation(addr2));
5468 }
5469 
// A configuration carrying a CA certificate is accepted under the test's default identity, but
// the same configuration must be rejected with EX_SECURITY when the caller runs as AID_SYSTEM or
// as TEST_UID.
TEST_F(ResolverTest, PermissionCheckOnCertificateInjection) {
    ResolverParamsParcel paramsWithCert = DnsResponderClient::GetDefaultResolverParamsParcel();
    paramsWithCert.caCertificate = kCaCert;
    ASSERT_TRUE(mDnsClient.resolvService()->setResolverConfiguration(paramsWithCert).isOk());

    for (const uid_t callerUid : {AID_SYSTEM, TEST_UID}) {
        // Switch identity for the duration of this iteration only.
        ScopedChangeUID asCaller(callerUid);
        auto rejection = mDnsClient.resolvService()->setResolverConfiguration(paramsWithCert);
        EXPECT_EQ(rejection.getExceptionCode(), EX_SECURITY);
    }
}
5481 
5482 // Parameterized tests.
5483 // TODO: Merge the existing tests as parameterized test if possible.
5484 // TODO: Perhaps move parameterized tests to an independent file.
// The libc name-resolution API exercised by a ResolverParameterizedTest instance.
enum class CallType {
    GETADDRINFO,
    GETHOSTBYNAME,
};
5486 class ResolverParameterizedTest : public ResolverTest,
5487                                   public testing::WithParamInterface<CallType> {
5488   protected:
VerifyQueryHelloExampleComV4(const test::DNSResponder & dns,const CallType calltype,const bool verifyNumQueries=true)5489     void VerifyQueryHelloExampleComV4(const test::DNSResponder& dns, const CallType calltype,
5490                                       const bool verifyNumQueries = true) {
5491         if (calltype == CallType::GETADDRINFO) {
5492             const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
5493             ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
5494             ASSERT_TRUE(result != nullptr);
5495             EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
5496         } else if (calltype == CallType::GETHOSTBYNAME) {
5497             const hostent* result = gethostbyname("hello");
5498             ASSERT_TRUE(result != nullptr);
5499             ASSERT_EQ(4, result->h_length);
5500             ASSERT_FALSE(result->h_addr_list[0] == nullptr);
5501             EXPECT_EQ(kHelloExampleComAddrV4, ToString(result));
5502             EXPECT_TRUE(result->h_addr_list[1] == nullptr);
5503         } else {
5504             FAIL() << "Unsupported call type: " << static_cast<uint32_t>(calltype);
5505         }
5506         if (verifyNumQueries) EXPECT_EQ(1U, GetNumQueries(dns, kHelloExampleCom));
5507     }
5508 };
5509 
5510 INSTANTIATE_TEST_SUITE_P(QueryCallTest, ResolverParameterizedTest,
5511                          testing::Values(CallType::GETADDRINFO, CallType::GETHOSTBYNAME),
__anonccf6c6f21902(const testing::TestParamInfo<CallType>& info) 5512                          [](const testing::TestParamInfo<CallType>& info) {
5513                              switch (info.param) {
5514                                  case CallType::GETADDRINFO:
5515                                      return "GetAddrInfo";
5516                                  case CallType::GETHOSTBYNAME:
5517                                      return "GetHostByName";
5518                                  default:
5519                                      return "InvalidParameter";  // Should not happen.
5520                              }
5521                          });
5522 
TEST_P(ResolverParameterizedTest,AuthoritySectionAndAdditionalSection)5523 TEST_P(ResolverParameterizedTest, AuthoritySectionAndAdditionalSection) {
5524     // DNS response may have more information in authority section and additional section.
5525     // Currently, getanswer() of packages/modules/DnsResolver/getaddrinfo.cpp doesn't parse the
5526     // content of authority section and additional section. Test these sections if they crash
5527     // the resolver, just in case. See also RFC 1035 section 4.1.
5528     const auto& calltype = GetParam();
5529     test::DNSHeader header(kDefaultDnsHeader);
5530 
5531     // Create a DNS response which has a authoritative nameserver record in authority
5532     // section and its relevant address record in additional section.
5533     //
5534     // Question
5535     //   hello.example.com.     IN      A
5536     // Answer
5537     //   hello.example.com.     IN      A   1.2.3.4
5538     // Authority:
5539     //   hello.example.com.     IN      NS  ns1.example.com.
5540     // Additional:
5541     //   ns1.example.com.       IN      A   5.6.7.8
5542     //
5543     // A response may have only question, answer, and authority section. Current testing response
5544     // should be able to cover this condition.
5545 
5546     // Question section.
5547     test::DNSQuestion question{
5548             .qname = {.name = kHelloExampleCom},
5549             .qtype = ns_type::ns_t_a,
5550             .qclass = ns_c_in,
5551     };
5552     header.questions.push_back(std::move(question));
5553 
5554     // Answer section.
5555     test::DNSRecord recordAnswer{
5556             .name = {.name = kHelloExampleCom},
5557             .rtype = ns_type::ns_t_a,
5558             .rclass = ns_c_in,
5559             .ttl = 0,  // no cache
5560     };
5561     EXPECT_TRUE(test::DNSResponder::fillRdata(kHelloExampleComAddrV4, recordAnswer));
5562     header.answers.push_back(std::move(recordAnswer));
5563 
5564     // Authority section.
5565     test::DNSRecord recordAuthority{
5566             .name = {.name = kHelloExampleCom},
5567             .rtype = ns_type::ns_t_ns,
5568             .rclass = ns_c_in,
5569             .ttl = 0,  // no cache
5570     };
5571     EXPECT_TRUE(test::DNSResponder::fillRdata("ns1.example.com.", recordAuthority));
5572     header.authorities.push_back(std::move(recordAuthority));
5573 
5574     // Additional section.
5575     test::DNSRecord recordAdditional{
5576             .name = {.name = "ns1.example.com."},
5577             .rtype = ns_type::ns_t_a,
5578             .rclass = ns_c_in,
5579             .ttl = 0,  // no cache
5580     };
5581     EXPECT_TRUE(test::DNSResponder::fillRdata("5.6.7.8", recordAdditional));
5582     header.additionals.push_back(std::move(recordAdditional));
5583 
5584     // Start DNS server.
5585     test::DNSResponder dns(test::DNSResponder::MappingType::DNS_HEADER);
5586     dns.addMappingDnsHeader(kHelloExampleCom, ns_type::ns_t_a, header);
5587     ASSERT_TRUE(dns.startServer());
5588     ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
5589     dns.clearQueries();
5590 
5591     // Expect that get the address and the resolver doesn't crash.
5592     VerifyQueryHelloExampleComV4(dns, calltype);
5593 }
5594 
TEST_P(ResolverParameterizedTest,MessageCompression)5595 TEST_P(ResolverParameterizedTest, MessageCompression) {
5596     const auto& calltype = GetParam();
5597 
5598     // The response with compressed domain name by a pointer. See RFC 1035 section 4.1.4.
5599     //
5600     // Ignoring the other fields of the message, the domain name of question section and answer
5601     // section are presented as:
5602     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5603     // 12 |           5           |           h           |
5604     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5605     // 14 |           e           |           l           |
5606     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5607     // 16 |           l           |           o           |
5608     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5609     // 18 |           7           |           e           |
5610     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5611     // 20 |           x           |           a           |
5612     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5613     // 22 |           m           |           p           |
5614     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5615     // 24 |           l           |           e           |
5616     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5617     // 26 |           3           |           c           |
5618     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5619     // 28 |           o           |           m           |
5620     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5621     // 30 |           0           |          ...          |
5622     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5623     //
5624     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5625     // 35 | 1  1|                12                       |
5626     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5627     const std::vector<uint8_t> kResponseAPointer = {
5628             /* Header */
5629             0x00, 0x00, /* Transaction ID: 0x0000 */
5630             0x81, 0x80, /* Flags: qr rd ra */
5631             0x00, 0x01, /* Questions: 1 */
5632             0x00, 0x01, /* Answer RRs: 1 */
5633             0x00, 0x00, /* Authority RRs: 0 */
5634             0x00, 0x00, /* Additional RRs: 0 */
5635             /* Queries */
5636             0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
5637             0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name: hello.example.com */
5638             0x00, 0x01,                   /* Type: A */
5639             0x00, 0x01,                   /* Class: IN */
5640             /* Answers */
5641             0xc0, 0x0c,             /* Name: hello.example.com (a pointer) */
5642             0x00, 0x01,             /* Type: A */
5643             0x00, 0x01,             /* Class: IN */
5644             0x00, 0x00, 0x00, 0x00, /* Time to live: 0 */
5645             0x00, 0x04,             /* Data length: 4 */
5646             0x01, 0x02, 0x03, 0x04  /* Address: 1.2.3.4 */
5647     };
5648 
5649     // The response with compressed domain name by a sequence of labels ending with a pointer. See
5650     // RFC 1035 section 4.1.4.
5651     //
5652     // Ignoring the other fields of the message, the domain name of question section and answer
5653     // section are presented as:
5654     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5655     // 12 |           5           |           h           |
5656     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5657     // 14 |           e           |           l           |
5658     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5659     // 16 |           l           |           o           |
5660     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5661     // 18 |           7           |           e           |
5662     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5663     // 20 |           x           |           a           |
5664     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5665     // 22 |           m           |           p           |
5666     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5667     // 24 |           l           |           e           |
5668     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5669     // 26 |           3           |           c           |
5670     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5671     // 28 |           o           |           m           |
5672     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5673     // 30 |           0           |          ...          |
5674     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5675     //
5676     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5677     // 35 |           5           |           h           |
5678     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5679     // 37 |           e           |           l           |
5680     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5681     // 39 |           l           |           o           |
5682     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5683     // 41 | 1  1|                18                       |
5684     //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
5685     const std::vector<uint8_t> kResponseLabelEndingWithAPointer = {
5686             /* Header */
5687             0x00, 0x00, /* Transaction ID: 0x0000 */
5688             0x81, 0x80, /* Flags: qr rd ra */
5689             0x00, 0x01, /* Questions: 1 */
5690             0x00, 0x01, /* Answer RRs: 1 */
5691             0x00, 0x00, /* Authority RRs: 0 */
5692             0x00, 0x00, /* Additional RRs: 0 */
5693             /* Queries */
5694             0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
5695             0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name: hello.example.com */
5696             0x00, 0x01,                   /* Type: A */
5697             0x00, 0x01,                   /* Class: IN */
5698             /* Answers */
5699             0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xc0,
5700             0x12,                   /* Name: hello.example.com (a label ending with a pointer) */
5701             0x00, 0x01,             /* Type: A */
5702             0x00, 0x01,             /* Class: IN */
5703             0x00, 0x00, 0x00, 0x00, /* Time to live: 0 */
5704             0x00, 0x04,             /* Data length: 4 */
5705             0x01, 0x02, 0x03, 0x04  /* Address: 1.2.3.4 */
5706     };
5707 
5708     for (const auto& response : {kResponseAPointer, kResponseLabelEndingWithAPointer}) {
5709         SCOPED_TRACE(StringPrintf("Hex dump: %s", toHex(makeSlice(response)).c_str()));
5710 
5711         test::DNSResponder dns(test::DNSResponder::MappingType::BINARY_PACKET);
5712         dns.addMappingBinaryPacket(kHelloExampleComQueryV4, response);
5713         StartDns(dns, {});
5714         ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
5715 
5716         // Expect no cache because the TTL of testing responses are 0.
5717         VerifyQueryHelloExampleComV4(dns, calltype);
5718     }
5719 }
5720 
// kLargeCnameChainRecords is expected to yield a truncated UDP answer, so the resolver retries
// over TCP (RFC 1035 section 4.2.1). Exactly one query per protocol is expected.
TEST_P(ResolverParameterizedTest, TruncatedResponse) {
    test::DNSResponder dns;
    StartDns(dns, kLargeCnameChainRecords);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // The total query count is checked per protocol below instead of by the helper.
    VerifyQueryHelloExampleComV4(dns, GetParam(), /*verifyNumQueries=*/false);
    EXPECT_EQ(1U, GetNumQueriesForProtocol(dns, IPPROTO_UDP, kHelloExampleCom));
    EXPECT_EQ(1U, GetNumQueriesForProtocol(dns, IPPROTO_TCP, kHelloExampleCom));
}
5733 
// With the keep_listening_udp experiment enabled, the resolver is expected to keep listening on
// a server it already timed out on. |delayedDns| answers only after |delayTimeMs| (longer than
// the 1000 ms base timeout) and |neverRespondDns| never answers, yet the lookup should still
// succeed with |delayedDns|'s late answer.
TEST_F(ResolverTest, KeepListeningUDP) {
    constexpr char listen_addr1[] = "127.0.0.4";
    constexpr char listen_addr2[] = "127.0.0.5";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    const std::vector<int> params = {300, 25, 8, 8, 1000 /* BASE_TIMEOUT_MSEC */,
                                     1 /* retry count */};
    const int delayTimeMs = 1500;

    test::DNSResponder neverRespondDns(listen_addr2, "53", static_cast<ns_rcode>(-1));
    neverRespondDns.setResponseProbability(0.0);
    StartDns(neverRespondDns, records);
    ScopedSystemProperties scopedSystemProperties(
            "persist.device_config.netd_native.keep_listening_udp", "1");
    // Re-setup test network to make experiment flag take effect.
    resetNetwork();

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr1, listen_addr2},
                                                  kDefaultSearchDomains, params));
    // There are 2 DNS servers for this test.
    // |delayedDns| will be blocked for |delayTimeMs|, then start to respond to requests.
    // |neverRespondDns| will never respond.
    // In the first try, the resolver sends the query to |delayedDns| but gets a timeout error
    // because |delayTimeMs| > DNS timeout.
    // In the second try, the resolver sends the query to |neverRespondDns| and listens on both
    // servers, so it receives the answer coming from |delayedDns|.

    test::DNSResponder delayedDns(listen_addr1);
    delayedDns.setResponseDelayMs(delayTimeMs);
    StartDns(delayedDns, records);

    // Request only AAAA so that the resolver performs a single round of queries.
    const addrinfo hints = {.ai_family = AF_INET6, .ai_socktype = SOCK_DGRAM};
    ScopedAddrinfo result = safe_getaddrinfo(host_name, nullptr, &hints);
    // Fatal: without a result there is nothing meaningful to compare below.
    ASSERT_TRUE(result != nullptr);

    // EXPECT_EQ reports both the expected and the actual string on mismatch.
    EXPECT_EQ("::1.2.3.4", ToString(result));
}
5775 
// With parallel_lookup_release enabled, the A and AAAA queries of an AF_UNSPEC lookup are expected
// to be sent concurrently. Against a server that never answers, the lookup should therefore fail
// after roughly one DNS timeout (timeout 1s, retry count 1), and the server should see 2 queries.
TEST_F(ResolverTest, GetAddrInfoParallelLookupTimeout) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr char host_name[] = "howdy.example.com.";
    constexpr int TIMING_TOLERANCE_MS = 200;
    constexpr int DNS_TIMEOUT_MS = 1000;
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_a, "1.2.3.4"},
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    const std::vector<int> params = {300, 25, 8, 8, DNS_TIMEOUT_MS /* BASE_TIMEOUT_MSEC */,
                                     1 /* retry count */};

    // A server that drops every query, so the lookup can only end by timing out.
    test::DNSResponder neverRespondDns(listen_addr, "53", static_cast<ns_rcode>(-1));
    neverRespondDns.setResponseProbability(0.0);
    StartDns(neverRespondDns, records);

    ScopedSystemProperties scopedSystemProperties(
            "persist.device_config.netd_native.parallel_lookup_release", "1");
    // The default parallel_lookup_sleep_time is assumed small enough to be ignored here.
    // Re-create the test network so that the experiment flag takes effect.
    resetNetwork();

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr}, kDefaultSearchDomains, params));
    neverRespondDns.clearQueries();

    const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    auto [ai, elapsedMs] = safe_getaddrinfo_time_taken(host_name, nullptr, hints);

    EXPECT_TRUE(ai == nullptr);
    EXPECT_NEAR(DNS_TIMEOUT_MS, elapsedMs, TIMING_TOLERANCE_MS)
            << "took time should approximate equal timeout";
    EXPECT_EQ(2U, GetNumQueries(neverRespondDns, host_name));
    ExpectDnsEvent(INetdEventListener::EVENT_GETADDRINFO, RCODE_TIMEOUT, host_name, {});
}
5813 
// Verifies that parallel_lookup_sleep_time makes a parallel A/AAAA lookup that misses the cache
// take about the configured time, and that it does not affect a lookup answered from the cache.
TEST_F(ResolverTest, GetAddrInfoParallelLookupSleepTime) {
    constexpr char listen_addr[] = "127.0.0.4";
    constexpr int TIMING_TOLERANCE_MS = 200;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4},
            {kHelloExampleCom, ns_type::ns_t_aaaa, kHelloExampleComAddrV6},
    };
    const std::vector<int> params = {300, 25, 8, 8, 1000 /* BASE_TIMEOUT_MSEC */,
                                     1 /* retry count */};
    test::DNSResponder dns(listen_addr);
    StartDns(dns, records);
    ScopedSystemProperties scopedSystemProperties1(
            "persist.device_config.netd_native.parallel_lookup_release", "1");
    constexpr int PARALLEL_LOOKUP_SLEEP_TIME_MS = 500;
    ScopedSystemProperties scopedSystemProperties2(
            "persist.device_config.netd_native.parallel_lookup_sleep_time",
            std::to_string(PARALLEL_LOOKUP_SLEEP_TIME_MS));
    // Re-setup test network to make experiment flag take effect.
    resetNetwork();

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr}, kDefaultSearchDomains, params));
    dns.clearQueries();

    // Expect safe_getaddrinfo_time_taken() to take ~500ms to return because
    // parallel_lookup_sleep_time is set to 500ms.
    const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(kHelloExampleCom, nullptr, hints);

    EXPECT_NE(nullptr, result);
    EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(
                                           {kHelloExampleComAddrV4, kHelloExampleComAddrV6}));
    // The expected duration here is the configured sleep time, not a DNS timeout.
    EXPECT_NEAR(PARALLEL_LOOKUP_SLEEP_TIME_MS, timeTakenMs, TIMING_TOLERANCE_MS)
            << "took time should approximate equal parallel_lookup_sleep_time";
    EXPECT_EQ(2U, GetNumQueries(dns, kHelloExampleCom));

    // Expect PARALLEL_LOOKUP_SLEEP_TIME_MS not to affect the query in the cache hit case.
    dns.clearQueries();
    std::tie(result, timeTakenMs) = safe_getaddrinfo_time_taken(kHelloExampleCom, nullptr, hints);
    EXPECT_NE(nullptr, result);
    EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(
                                           {kHelloExampleComAddrV4, kHelloExampleComAddrV6}));
    EXPECT_GT(PARALLEL_LOOKUP_SLEEP_TIME_MS, timeTakenMs);
    EXPECT_EQ(0U, GetNumQueries(dns, kHelloExampleCom));
}
5858 
// Issues lookups while TEST_UID is blocked by a firewall rule, then checks the per-server stats
// and query counts. The blocked queries must not show up as failures against the nameservers.
TEST_F(ResolverTest, BlockDnsQueryUidDoesNotLeadToBadServer) {
    SKIP_IF_BPF_NOT_SUPPORTED;
    constexpr char listen_addr1[] = "127.0.0.4";
    constexpr char listen_addr2[] = "::1";
    constexpr int kNumLookups = 10;

    test::DNSResponder dns1(listen_addr1);
    test::DNSResponder dns2(listen_addr2);
    StartDns(dns1, {});
    StartDns(dns2, {});

    std::vector<std::string> servers = {listen_addr1, listen_addr2};
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers));
    dns1.clearQueries();
    dns2.clearQueries();

    {
        // The UID is only blocked for the lifetime of this scope.
        ScopeBlockedUIDRule scopeBlockUidRule(mDnsClient.netdService(), TEST_UID);
        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
        // The lookup result differs between R+ and Q, and is deliberately ignored: this test
        // only ensures that the blocked-UID rule doesn't turn the servers into bad servers.
        for (int lookup = 0; lookup < kNumLookups; ++lookup) {
            const std::string hostName = fmt::format("blocked{}.com", lookup);
            safe_getaddrinfo(hostName.c_str(), nullptr, &hints);
        }
    }

    const ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();

    // On R+ (api level >= 30) every query packet is blocked, so no stats are recorded.
    // Before R, all of the queries succeed against the first server.
    const std::vector<NameserverStats> expectedDnsStats = {
            NameserverStats(listen_addr1).setSuccesses(isAtLeastR ? 0 : setupParams.maxSamples),
            NameserverStats(listen_addr2),
    };
    expectStatsEqualTo(expectedDnsStats);

    // On R+ the server sees no queries. Before R it sees one query per lookup per search
    // attempt, i.e. kNumLookups * (setupParams.domains.size() + 1) (20 with one search domain).
    EXPECT_EQ(dns1.queries().size(),
              isAtLeastR ? 0U : kNumLookups * (setupParams.domains.size() + 1));
    EXPECT_EQ(dns2.queries().size(), 0U);
}
5897 
TEST_F(ResolverTest,DnsServerSelection)5898 TEST_F(ResolverTest, DnsServerSelection) {
5899     test::DNSResponder dns1("127.0.0.3");
5900     test::DNSResponder dns2("127.0.0.4");
5901     test::DNSResponder dns3("127.0.0.5");
5902 
5903     dns1.setResponseDelayMs(10);
5904     dns2.setResponseDelayMs(25);
5905     dns3.setResponseDelayMs(50);
5906     StartDns(dns1, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
5907     StartDns(dns2, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
5908     StartDns(dns3, {{kHelloExampleCom, ns_type::ns_t_a, kHelloExampleComAddrV4}});
5909 
5910     // NOTE: the servers must be sorted alphabetically.
5911     std::vector<std::string> serverList = {
5912             dns1.listen_address(),
5913             dns2.listen_address(),
5914             dns3.listen_address(),
5915     };
5916 
5917     do {
5918         SCOPED_TRACE(fmt::format("testConfig: [{}]", fmt::join(serverList, ", ")));
5919         const int queryNum = 50;
5920         int64_t accumulatedTime = 0;
5921 
5922         // The flag can be reset any time. It's better to re-setup the flag in each iteration.
5923         ScopedSystemProperties scopedSystemProperties(kSortNameserversFlag, "1");
5924 
5925         // Restart the testing network to 1) make the flag take effect and 2) reset the statistics.
5926         resetNetwork();
5927 
5928         // DnsServerSelection doesn't apply to private DNS.
5929         ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
5930         setupParams.servers = serverList;
5931         setupParams.tlsServers.clear();
5932         ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));
5933 
5934         // DNSResponder doesn't handle queries concurrently, so don't allow more than
5935         // one in-flight query.
5936         for (int i = 0; i < queryNum; i++) {
5937             Stopwatch s;
5938             int fd = resNetworkQuery(TEST_NETID, kHelloExampleCom, ns_c_in, ns_t_a,
5939                                      ANDROID_RESOLV_NO_CACHE_LOOKUP);
5940             expectAnswersValid(fd, AF_INET, kHelloExampleComAddrV4);
5941             accumulatedTime += s.timeTakenUs();
5942         }
5943 
5944         const int dns1Count = dns1.queries().size();
5945         const int dns2Count = dns2.queries().size();
5946         const int dns3Count = dns3.queries().size();
5947 
5948         // All of the servers have ever been selected. In addition, the less latency server
5949         // is selected more frequently.
5950         EXPECT_GT(dns1Count, 0);
5951         EXPECT_GT(dns2Count, 0);
5952         EXPECT_GT(dns3Count, 0);
5953         EXPECT_GE(dns1Count, dns2Count);
5954         EXPECT_GE(dns2Count, dns3Count);
5955 
5956         const int averageTime = accumulatedTime / queryNum;
5957         LOG(INFO) << "ResolverTest#DnsServerSelection: averageTime " << averageTime << "us";
5958 
5959         dns1.clearQueries();
5960         dns2.clearQueries();
5961         dns3.clearQueries();
5962     } while (std::next_permutation(serverList.begin(), serverList.end()));
5963 }
5964 
TEST_F(ResolverTest,MultipleDotQueriesInOnePacket)5965 TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) {
5966     constexpr char hostname1[] = "query1.example.com.";
5967     constexpr char hostname2[] = "query2.example.com.";
5968     const std::vector<DnsRecord> records = {
5969             {hostname1, ns_type::ns_t_a, "1.2.3.4"},
5970             {hostname2, ns_type::ns_t_a, "1.2.3.5"},
5971     };
5972 
5973     const std::string addr = getUniqueIPv4Address();
5974     test::DNSResponder dns(addr);
5975     StartDns(dns, records);
5976     test::DnsTlsFrontend tls(addr, "853", addr, "53");
5977     ASSERT_TRUE(tls.startServer());
5978 
5979     // Set up resolver to strict mode.
5980     auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
5981     parcel.servers = {addr};
5982     parcel.tlsServers = {addr};
5983     parcel.tlsName = kDefaultPrivateDnsHostName;
5984     parcel.caCertificate = kCaCert;
5985     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
5986     EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
5987     EXPECT_TRUE(tls.waitForQueries(1));
5988     tls.clearQueries();
5989     dns.clearQueries();
5990 
5991     const auto queryAndCheck = [&](const std::string& hostname,
5992                                    const std::vector<DnsRecord>& records) {
5993         SCOPED_TRACE(hostname);
5994 
5995         const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
5996         auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints);
5997 
5998         std::vector<std::string> expectedAnswers;
5999         for (const auto& r : records) {
6000             if (r.host_name == hostname) expectedAnswers.push_back(r.addr);
6001         }
6002 
6003         EXPECT_LE(timeTakenMs, 200);
6004         ASSERT_NE(result, nullptr);
6005         EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers));
6006     };
6007 
6008     // Set tls to reply DNS responses in one TCP packet and not to close the connection from its
6009     // side.
6010     tls.setDelayQueries(2);
6011     tls.setDelayQueriesTimeout(500);
6012     tls.setPassiveClose(true);
6013 
6014     // Start sending DNS requests at the same time.
6015     std::array<std::thread, 2> threads;
6016     threads[0] = std::thread(queryAndCheck, hostname1, records);
6017     threads[1] = std::thread(queryAndCheck, hostname2, records);
6018 
6019     threads[0].join();
6020     threads[1].join();
6021 
6022     // Also check no additional queries due to DoT reconnection.
6023     EXPECT_EQ(tls.queries(), 2);
6024 }
6025 
6026 // ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
6027 // The resolver sends queries to address A, and then there will be a TunForwarder helping forward
6028 // the packets to address B, which is the address on which the testing server is listening. The
6029 // answer packets responded from the testing server go through the reverse path back to the
6030 // resolver.
6031 //
// To achieve that, it needs to set up an interface with routing rules. Tests are not
6033 // supposed to initiate DNS servers on their own; instead, some utilities are added to the class to
6034 // help the setup.
6035 //
6036 // An example of how to use it:
6037 // TEST_F() {
6038 //     ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(V4);
6039 //     network.init();
6040 //
6041 //     auto dns = network.addIpv4Dns();
//     StartDns(*dns->dnsServer, {});
6043 //
6044 //     network.setDnsConfiguration();
6045 //     network.startTunForwarder();
6046 //
6047 //     // Send queries here
6048 // }
6049 
6050 class ResolverMultinetworkTest : public ResolverTest {
6051   protected:
6052     enum class ConnectivityType { V4, V6, V4V6 };
6053     static constexpr int TEST_NETID_BASE = 10000;
6054 
6055     struct DnsServerPair {
DnsServerPairResolverMultinetworkTest::DnsServerPair6056         DnsServerPair(std::shared_ptr<test::DNSResponder> server, std::string addr)
6057             : dnsServer(server), dnsAddr(addr) {}
6058         std::shared_ptr<test::DNSResponder> dnsServer;
6059         std::string dnsAddr;  // The DNS server address used for setResolverConfiguration().
6060         // TODO: Add test::DnsTlsFrontend* and std::string for DoT.
6061     };
6062 
6063     class ScopedNetwork {
6064       public:
ScopedNetwork(unsigned netId,ConnectivityType type,INetd * netdSrv,IDnsResolver * dnsResolvSrv,const char * networkName)6065         ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
6066                       IDnsResolver* dnsResolvSrv, const char* networkName)
6067             : mNetId(netId),
6068               mConnectivityType(type),
6069               mNetdSrv(netdSrv),
6070               mDnsResolvSrv(dnsResolvSrv),
6071               mNetworkName(networkName) {
6072             mIfname = fmt::format("testtun{}", netId);
6073         }
~ScopedNetwork()6074         virtual ~ScopedNetwork() {
6075             if (mNetdSrv != nullptr) mNetdSrv->networkDestroy(mNetId);
6076             if (mDnsResolvSrv != nullptr) mDnsResolvSrv->destroyNetworkCache(mNetId);
6077         }
6078 
6079         Result<void> init();
addIpv4Dns()6080         Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
addIpv6Dns()6081         Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
startTunForwarder()6082         bool startTunForwarder() { return mTunForwarder->startForwarding(); }
6083         bool setDnsConfiguration() const;
6084         bool clearDnsConfiguration() const;
netId() const6085         unsigned netId() const { return mNetId; }
name() const6086         std::string name() const { return mNetworkName; }
6087 
6088       protected:
6089         // Subclasses should implement it to decide which network should be create.
6090         virtual Result<void> createNetwork() const = 0;
6091 
6092         const unsigned mNetId;
6093         const ConnectivityType mConnectivityType;
6094         INetd* mNetdSrv;
6095         IDnsResolver* mDnsResolvSrv;
6096         const std::string mNetworkName;
6097         std::string mIfname;
6098         std::unique_ptr<TunForwarder> mTunForwarder;
6099         std::vector<DnsServerPair> mDnsServerPairs;
6100 
6101       private:
6102         Result<DnsServerPair> addDns(ConnectivityType connectivity);
6103         // Assuming mNetId is unique during ResolverMultinetworkTest, make the
6104         // address based on it to avoid conflicts.
makeIpv4AddrString(uint8_t n) const6105         std::string makeIpv4AddrString(uint8_t n) const {
6106             return StringPrintf("192.168.%u.%u", (mNetId - TEST_NETID_BASE), n);
6107         }
makeIpv6AddrString(uint8_t n) const6108         std::string makeIpv6AddrString(uint8_t n) const {
6109             return StringPrintf("2001:db8:%u::%u", (mNetId - TEST_NETID_BASE), n);
6110         }
6111     };
6112 
6113     class ScopedPhysicalNetwork : public ScopedNetwork {
6114       public:
ScopedPhysicalNetwork(unsigned netId,const char * networkName)6115         ScopedPhysicalNetwork(unsigned netId, const char* networkName)
6116             : ScopedNetwork(netId, ConnectivityType::V4V6, nullptr, nullptr, networkName) {}
ScopedPhysicalNetwork(unsigned netId,ConnectivityType type,INetd * netdSrv,IDnsResolver * dnsResolvSrv,const char * name="Physical")6117         ScopedPhysicalNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
6118                               IDnsResolver* dnsResolvSrv, const char* name = "Physical")
6119             : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name) {}
6120 
6121       protected:
createNetwork() const6122         Result<void> createNetwork() const override {
6123             ::ndk::ScopedAStatus r;
6124             if (DnsResponderClient::isRemoteVersionSupported(mNetdSrv, 6)) {
6125                 const auto& config = DnsResponderClient::makeNativeNetworkConfig(
6126                         mNetId, NativeNetworkType::PHYSICAL, INetd::PERMISSION_NONE,
6127                         /*secure=*/false);
6128                 r = mNetdSrv->networkCreate(config);
6129             } else {
6130 #pragma clang diagnostic push
6131 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
6132                 r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_NONE);
6133 #pragma clang diagnostic pop
6134             }
6135 
6136             if (!r.isOk()) {
6137                 return Error() << r.getMessage();
6138             }
6139             return {};
6140         }
6141     };
6142 
6143     class ScopedVirtualNetwork : public ScopedNetwork {
6144       public:
ScopedVirtualNetwork(unsigned netId,ConnectivityType type,INetd * netdSrv,IDnsResolver * dnsResolvSrv,const char * name,bool isSecure)6145         ScopedVirtualNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
6146                              IDnsResolver* dnsResolvSrv, const char* name, bool isSecure)
6147             : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name), mIsSecure(isSecure) {}
~ScopedVirtualNetwork()6148         ~ScopedVirtualNetwork() {
6149             if (!mVpnIsolationUids.empty()) {
6150                 const std::vector<int> tmpUids(mVpnIsolationUids.begin(), mVpnIsolationUids.end());
6151                 mNetdSrv->firewallRemoveUidInterfaceRules(tmpUids);
6152             }
6153         }
6154         // Enable VPN isolation. Ensures that uid can only receive packets on mIfname.
enableVpnIsolation(int uid)6155         Result<void> enableVpnIsolation(int uid) {
6156             if (auto r = mNetdSrv->firewallAddUidInterfaceRules(mIfname, {uid}); !r.isOk()) {
6157                 return Error() << r.getMessage();
6158             }
6159             mVpnIsolationUids.insert(uid);
6160             return {};
6161         }
disableVpnIsolation(int uid)6162         Result<void> disableVpnIsolation(int uid) {
6163             if (auto r = mNetdSrv->firewallRemoveUidInterfaceRules({static_cast<int>(uid)});
6164                 !r.isOk()) {
6165                 return Error() << r.getMessage();
6166             }
6167             mVpnIsolationUids.erase(uid);
6168             return {};
6169         }
addUser(uid_t uid) const6170         Result<void> addUser(uid_t uid) const { return addUidRange(uid, uid); }
addUidRange(uid_t from,uid_t to) const6171         Result<void> addUidRange(uid_t from, uid_t to) const {
6172             if (auto r = mNetdSrv->networkAddUidRanges(mNetId, {makeUidRangeParcel(from, to)});
6173                 !r.isOk()) {
6174                 return Error() << r.getMessage();
6175             }
6176             return {};
6177         }
6178 
6179       protected:
createNetwork() const6180         Result<void> createNetwork() const override {
6181             ::ndk::ScopedAStatus r;
6182             if (DnsResponderClient::isRemoteVersionSupported(mNetdSrv, 6)) {
6183                 const auto& config = DnsResponderClient::makeNativeNetworkConfig(
6184                         mNetId, NativeNetworkType::VIRTUAL, INetd::PERMISSION_NONE, mIsSecure);
6185                 r = mNetdSrv->networkCreate(config);
6186             } else {
6187 #pragma clang diagnostic push
6188 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
6189                 r = mNetdSrv->networkCreateVpn(mNetId, mIsSecure);
6190 #pragma clang diagnostic pop
6191             }
6192 
6193             if (!r.isOk()) {
6194                 return Error() << r.getMessage();
6195             }
6196             return {};
6197         }
6198 
6199         bool mIsSecure = false;
6200         std::unordered_set<int> mVpnIsolationUids;
6201     };
6202 
SetUp()6203     void SetUp() override {
6204         ResolverTest::SetUp();
6205         ASSERT_NE(mDnsClient.netdService(), nullptr);
6206         ASSERT_NE(mDnsClient.resolvService(), nullptr);
6207     }
6208 
TearDown()6209     void TearDown() override {
6210         ResolverTest::TearDown();
6211         // Restore default network
6212         if (mStoredDefaultNetwork >= 0) {
6213             mDnsClient.netdService()->networkSetDefault(mStoredDefaultNetwork);
6214         }
6215     }
6216 
CreateScopedPhysicalNetwork(ConnectivityType type,const char * name="Physical")6217     ScopedPhysicalNetwork CreateScopedPhysicalNetwork(ConnectivityType type,
6218                                                       const char* name = "Physical") {
6219         return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService(), name};
6220     }
CreateScopedVirtualNetwork(ConnectivityType type,bool isSecure,const char * name="Virtual")6221     ScopedVirtualNetwork CreateScopedVirtualNetwork(ConnectivityType type, bool isSecure,
6222                                                     const char* name = "Virtual") {
6223         return {getFreeNetId(), type,    mDnsClient.netdService(), mDnsClient.resolvService(),
6224                 name,           isSecure};
6225     }
6226     void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);
setDefaultNetwork(int netId)6227     void setDefaultNetwork(int netId) {
6228         // Save current default network at the first call.
6229         std::call_once(defaultNetworkFlag, [&]() {
6230             ASSERT_TRUE(mDnsClient.netdService()->networkGetDefault(&mStoredDefaultNetwork).isOk());
6231         });
6232         ASSERT_TRUE(mDnsClient.netdService()->networkSetDefault(netId).isOk());
6233     }
getFreeNetId()6234     unsigned getFreeNetId() {
6235         if (mNextNetId == TEST_NETID_BASE + 256) mNextNetId = TEST_NETID_BASE;
6236         return mNextNetId++;
6237     }
6238 
6239   private:
6240     // Use a different netId because this class inherits from the class ResolverTest which
6241     // always creates TEST_NETID in setup. It's incremented when CreateScoped{Physical,
6242     // Virtual}Network() is called.
6243     // Note: 255 is the maximum number of (mNextNetId - TEST_NETID_BASE) here as mNextNetId
6244     // is used to create address.
6245     unsigned mNextNetId = TEST_NETID_BASE;
6246     // Use -1 to represent that default network was not modified because
6247     // real netId must be an unsigned value.
6248     int mStoredDefaultNetwork = -1;
6249     std::once_flag defaultNetworkFlag;
6250 };
6251 
init()6252 Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
6253     if (mNetdSrv == nullptr || mDnsResolvSrv == nullptr) return Error() << "srv not available";
6254     unique_fd ufd = TunForwarder::createTun(mIfname);
6255     if (!ufd.ok()) {
6256         return Errorf("createTun for {} failed", mIfname);
6257     }
6258     mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));
6259 
6260     if (auto r = createNetwork(); !r.ok()) {
6261         return r;
6262     }
6263     if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
6264         return Error() << r.getMessage();
6265     }
6266     if (auto r = mNetdSrv->networkAddInterface(mNetId, mIfname); !r.isOk()) {
6267         return Error() << r.getMessage();
6268     }
6269 
6270     if (mConnectivityType == ConnectivityType::V4 || mConnectivityType == ConnectivityType::V4V6) {
6271         const std::string v4Addr = makeIpv4AddrString(1);
6272         if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v4Addr, 32); !r.isOk()) {
6273             return Error() << r.getMessage();
6274         }
6275         if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "0.0.0.0/0", ""); !r.isOk()) {
6276             return Error() << r.getMessage();
6277         }
6278     }
6279     if (mConnectivityType == ConnectivityType::V6 || mConnectivityType == ConnectivityType::V4V6) {
6280         const std::string v6Addr = makeIpv6AddrString(1);
6281         if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v6Addr, 128); !r.isOk()) {
6282             return Error() << r.getMessage();
6283         }
6284         if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "::/0", ""); !r.isOk()) {
6285             return Error() << r.getMessage();
6286         }
6287     }
6288 
6289     return {};
6290 }
6291 
StartDns(test::DNSResponder & dns,const std::vector<DnsRecord> & records)6292 void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
6293                                         const std::vector<DnsRecord>& records) {
6294     ResolverTest::StartDns(dns, records);
6295 
6296     // Bind the DNSResponder's sockets to the network if specified.
6297     if (std::optional<unsigned> netId = dns.getNetwork(); netId.has_value()) {
6298         setNetworkForSocket(netId.value(), dns.getUdpSocket());
6299         setNetworkForSocket(netId.value(), dns.getTcpSocket());
6300     }
6301 }
6302 
addDns(ConnectivityType type)6303 Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
6304         ConnectivityType type) {
6305     const int index = mDnsServerPairs.size();
6306     const int prefixLen = (type == ConnectivityType::V4) ? 32 : 128;
6307 
6308     const std::function<std::string(unsigned)> makeIpString =
6309             std::bind((type == ConnectivityType::V4) ? &ScopedNetwork::makeIpv4AddrString
6310                                                      : &ScopedNetwork::makeIpv6AddrString,
6311                       this, std::placeholders::_1);
6312 
6313     std::string src1 = makeIpString(1);            // The address from which the resolver will send.
6314     std::string dst1 = makeIpString(
6315             index + 100 +
6316             (mNetId - TEST_NETID_BASE));           // The address to which the resolver will send.
6317     std::string src2 = dst1;                       // The address translated from src1.
6318     std::string dst2 = makeIpString(
6319             index + 200 + (mNetId - TEST_NETID_BASE));  // The address translated from dst2.
6320 
6321     if (!mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
6322         !mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
6323         return Errorf("Failed to add the rules ({}, {}, {}, {})", src1, dst1, src2, dst2);
6324     }
6325 
6326     if (!mNetdSrv->interfaceAddAddress(mIfname, dst2, prefixLen).isOk()) {
6327         return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
6328     }
6329 
6330     return mDnsServerPairs.emplace_back(std::make_shared<test::DNSResponder>(mNetId, dst2), dst1);
6331 }
6332 
setDnsConfiguration() const6333 bool ResolverMultinetworkTest::ScopedNetwork::setDnsConfiguration() const {
6334     if (mDnsResolvSrv == nullptr) return false;
6335     ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
6336     parcel.tlsServers.clear();
6337     parcel.netId = mNetId;
6338     parcel.servers.clear();
6339     for (const auto& pair : mDnsServerPairs) {
6340         parcel.servers.push_back(pair.dnsAddr);
6341     }
6342     return mDnsResolvSrv->setResolverConfiguration(parcel).isOk();
6343 }
6344 
clearDnsConfiguration() const6345 bool ResolverMultinetworkTest::ScopedNetwork::clearDnsConfiguration() const {
6346     if (mDnsResolvSrv == nullptr) return false;
6347     return mDnsResolvSrv->destroyNetworkCache(mNetId).isOk() &&
6348            mDnsResolvSrv->createNetworkCache(mNetId).isOk();
6349 }
6350 
namespace {

// Convenient wrapper for making getaddrinfo call like framework.
// On failure the returned error carries the non-zero return code of
// android_getaddrinfofornet().
Result<ScopedAddrinfo> android_getaddrinfofornet_wrapper(const char* name, int netId) {
    // Use the same parameter as libcore/ojluni/src/main/java/java/net/Inet6AddressImpl.java.
    static const addrinfo hints = {
            .ai_flags = AI_ADDRCONFIG,
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_STREAM,
    };
    addrinfo* res = nullptr;
    const int rv = android_getaddrinfofornet(name, nullptr, &hints, netId, MARK_UNSET, &res);
    if (rv != 0) return Error() << rv;
    return ScopedAddrinfo(res);
}

// Looks up |name| on |netId| while running as |uid| (via ScopedChangeUID). Expects the lookup to
// succeed with exactly the addresses in |expectedResult|, in any order.
void expectDnsWorksForUid(const char* name, unsigned netId, uid_t uid,
                          const std::vector<std::string>& expectedResult) {
    ScopedChangeUID scopedChangeUID(uid);
    auto lookup = android_getaddrinfofornet_wrapper(name, netId);
    ASSERT_RESULT_OK(lookup);
    ScopedAddrinfo addresses(std::move(lookup.value()));
    EXPECT_THAT(ToStrings(addresses), testing::UnorderedElementsAreArray(expectedResult));
}

}  // namespace
6379 
// For each connectivity type, looks up a name that has both A and AAAA records with
// AI_ADDRCONFIG hints. Only the records of the address families the network has are expected,
// and only the queries for those families should reach the DNS server.
TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
    constexpr char host_name[] = "ohayou.example.com.";

    const std::array<ConnectivityType, 3> allTypes = {
            ConnectivityType::V4,
            ConnectivityType::V6,
            ConnectivityType::V4V6,
    };
    for (const auto& type : allTypes) {
        // ConnectivityType is a scoped enum: convert it explicitly before passing it through
        // the printf-style varargs, so that it matches the %d conversion.
        SCOPED_TRACE(StringPrintf("ConnectivityType: %d", static_cast<int>(type)));

        // Create a network.
        ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(type);
        ASSERT_RESULT_OK(network.init());

        // Add a testing DNS server.
        const Result<DnsServerPair> dnsPair =
                (type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
        ASSERT_RESULT_OK(dnsPair);
        StartDns(*dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.0"},
                                       {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});

        // Set up resolver and start forwarding.
        ASSERT_TRUE(network.setDnsConfiguration());
        ASSERT_TRUE(network.startTunForwarder());

        auto result = android_getaddrinfofornet_wrapper(host_name, network.netId());
        ASSERT_RESULT_OK(result);
        ScopedAddrinfo ai_result(std::move(result.value()));
        std::vector<std::string> result_strs = ToStrings(ai_result);
        std::vector<std::string> expectedResult;
        size_t expectedQueries = 0;

        if (type == ConnectivityType::V6 || type == ConnectivityType::V4V6) {
            expectedResult.emplace_back("2001:db8:cafe:d00d::31");
            expectedQueries++;
        }
        if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
            expectedResult.emplace_back("192.0.2.0");
            expectedQueries++;
        }
        EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
        EXPECT_EQ(GetNumQueries(*dnsPair->dnsServer, host_name), expectedQueries);
    }
}
6425 
// Verifies that destroying a network while a query on it is still waiting for an answer is
// handled cleanly: the in-flight lookup is expected to end with -ETIMEDOUT.
TEST_F(ResolverMultinetworkTest, NetworkDestroyedDuringQueryInFlight) {
    constexpr char host_name[] = "ohayou.example.com.";

    // Create a dual-stack network and add an IPv4 DNS server. The network is owned by a
    // unique_ptr so that it can be destroyed in the middle of the test.
    auto network = std::make_unique<ScopedPhysicalNetwork>(getFreeNetId(), ConnectivityType::V4V6,
                                                           mDnsClient.netdService(),
                                                           mDnsClient.resolvService());
    ASSERT_RESULT_OK(network->init());
    const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
    ASSERT_RESULT_OK(dnsPair);

    // Set the DNS server unresponsive: it still records the query (see the poll below) but never
    // replies (response probability 0; rcode -1 presumably suppresses error replies too).
    dnsPair->dnsServer->setResponseProbability(0.0);
    dnsPair->dnsServer->setErrorRcode(static_cast<ns_rcode>(-1));
    StartDns(*dnsPair->dnsServer, {});

    // Set up resolver and start forwarding.
    ASSERT_TRUE(network->setDnsConfiguration());
    ASSERT_TRUE(network->startTunForwarder());

    // Read the netId up front so the lookup thread never touches `network`, which the main
    // thread resets below while the lookup may still be running.
    const unsigned netId = network->netId();

    // Expect the things happening in order:
    // 1. The thread sends the query to the dns server which is unresponsive.
    // 2. The network is destroyed while the thread is waiting for the response from the dns server.
    // 3. After the dns server timeout, the thread retries but fails to connect.
    std::thread lookup([&, netId]() {
        const int fd = resNetworkQuery(netId, host_name, ns_c_in, ns_t_a, 0);
        // Failure is reported as a negative errno (as with android_res_nquery), not only -1, so
        // require a valid fd before waiting on it.
        ASSERT_GE(fd, 0);
        expectAnswersNotValid(fd, -ETIMEDOUT);
    });

    // Tear down the network as soon as the dns server receives the query.
    const auto condition = [&]() { return GetNumQueries(*dnsPair->dnsServer, host_name) == 1U; };
    EXPECT_TRUE(PollForCondition(condition));
    network.reset();

    lookup.join();
}
6463 
// Checks that every network owns a separate DNS cache. The same A query is sent to two networks
// whose servers give different answers. Flushing network 1's cache then sends only network 1
// back to its server, while network 2 keeps answering from cache.
TEST_F(ResolverMultinetworkTest, OneCachePerNetwork) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    constexpr char host_name[] = "ohayou.example.com.";

    // Two dual-stack networks, each served by its own IPv4 DNS server.
    ScopedPhysicalNetwork network1 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
    ScopedPhysicalNetwork network2 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
    ASSERT_RESULT_OK(network1.init());
    ASSERT_RESULT_OK(network2.init());

    const Result<DnsServerPair> dnsPair1 = network1.addIpv4Dns();
    const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
    ASSERT_RESULT_OK(dnsPair1);
    ASSERT_RESULT_OK(dnsPair2);
    auto& server1 = *dnsPair1->dnsServer;
    auto& server2 = *dnsPair2->dnsServer;

    // Same name, different address per server, so every answer identifies the server behind it.
    StartDns(server1, {{host_name, ns_type::ns_t_a, "192.0.2.0"}});
    StartDns(server2, {{host_name, ns_type::ns_t_a, "192.0.2.1"}});

    // Bring up resolver configuration and tun forwarding, network 1 first, then network 2.
    ASSERT_TRUE(network1.setDnsConfiguration());
    ASSERT_TRUE(network1.startTunForwarder());
    ASSERT_TRUE(network2.setDnsConfiguration());
    ASSERT_TRUE(network2.startTunForwarder());

    const unsigned netId1 = network1.netId();
    const unsigned netId2 = network2.netId();

    // Sends the A query for host_name on the given network and returns the answer fd.
    const auto queryA = [&](unsigned targetNetId) {
        return resNetworkQuery(targetNetId, host_name, ns_c_in, ns_t_a, 0);
    };

    // Round one: each network asks its own server exactly once.
    int fd1 = queryA(netId1);
    int fd2 = queryA(netId2);
    expectAnswersValid(fd1, AF_INET, "192.0.2.0");
    expectAnswersValid(fd2, AF_INET, "192.0.2.1");
    EXPECT_EQ(GetNumQueries(server1, host_name), 1U);
    EXPECT_EQ(GetNumQueries(server2, host_name), 1U);

    // Round two, after flushing only network 1's cache: server 1 sees a second query while
    // server 2's count stays at one.
    EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(netId1).isOk());
    fd1 = queryA(netId1);
    fd2 = queryA(netId2);
    expectAnswersValid(fd1, AF_INET, "192.0.2.0");
    expectAnswersValid(fd2, AF_INET, "192.0.2.1");
    EXPECT_EQ(GetNumQueries(server1, host_name), 2U);
    EXPECT_EQ(GetNumQueries(server2, host_name), 1U);
}
6507 
TEST_F(ResolverMultinetworkTest,DnsWithVpn)6508 TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
6509     SKIP_IF_BPF_NOT_SUPPORTED;
6510     SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
6511     constexpr char host_name[] = "ohayou.example.com.";
6512     constexpr char ipv4_addr[] = "192.0.2.0";
6513     constexpr char ipv6_addr[] = "2001:db8:cafe:d00d::31";
6514 
6515     const std::pair<ConnectivityType, std::vector<std::string>> testPairs[] = {
6516             {ConnectivityType::V4, {ipv4_addr}},
6517             {ConnectivityType::V6, {ipv6_addr}},
6518             {ConnectivityType::V4V6, {ipv6_addr, ipv4_addr}},
6519     };
6520     for (const auto& [type, result] : testPairs) {
6521         SCOPED_TRACE(StringPrintf("ConnectivityType: %d", type));
6522 
6523         // Create a network.
6524         ScopedPhysicalNetwork underlyingNetwork = CreateScopedPhysicalNetwork(type, "Underlying");
6525         ScopedVirtualNetwork bypassableVpnNetwork =
6526                 CreateScopedVirtualNetwork(type, false, "BypassableVpn");
6527         ScopedVirtualNetwork secureVpnNetwork = CreateScopedVirtualNetwork(type, true, "SecureVpn");
6528 
6529         ASSERT_RESULT_OK(underlyingNetwork.init());
6530         ASSERT_RESULT_OK(bypassableVpnNetwork.init());
6531         ASSERT_RESULT_OK(secureVpnNetwork.init());
6532         ASSERT_RESULT_OK(bypassableVpnNetwork.addUser(TEST_UID));
6533         ASSERT_RESULT_OK(secureVpnNetwork.addUser(TEST_UID2));
6534 
6535         auto setupDnsFn = [&](std::shared_ptr<test::DNSResponder> dnsServer,
6536                               ScopedNetwork* nw) -> void {
6537             StartDns(*dnsServer, {{host_name, ns_type::ns_t_a, ipv4_addr},
6538                                   {host_name, ns_type::ns_t_aaaa, ipv6_addr}});
6539             ASSERT_TRUE(nw->setDnsConfiguration());
6540             ASSERT_TRUE(nw->startTunForwarder());
6541         };
6542         // Add a testing DNS server to networks.
6543         const Result<DnsServerPair> underlyingPair = (type == ConnectivityType::V4)
6544                                                              ? underlyingNetwork.addIpv4Dns()
6545                                                              : underlyingNetwork.addIpv6Dns();
6546         ASSERT_RESULT_OK(underlyingPair);
6547         const Result<DnsServerPair> bypassableVpnPair = (type == ConnectivityType::V4)
6548                                                                 ? bypassableVpnNetwork.addIpv4Dns()
6549                                                                 : bypassableVpnNetwork.addIpv6Dns();
6550         ASSERT_RESULT_OK(bypassableVpnPair);
6551         const Result<DnsServerPair> secureVpnPair = (type == ConnectivityType::V4)
6552                                                             ? secureVpnNetwork.addIpv4Dns()
6553                                                             : secureVpnNetwork.addIpv6Dns();
6554         ASSERT_RESULT_OK(secureVpnPair);
6555         // Set up resolver and start forwarding for networks.
6556         setupDnsFn(underlyingPair->dnsServer, &underlyingNetwork);
6557         setupDnsFn(bypassableVpnPair->dnsServer, &bypassableVpnNetwork);
6558         setupDnsFn(secureVpnPair->dnsServer, &secureVpnNetwork);
6559 
6560         setDefaultNetwork(underlyingNetwork.netId());
6561         const unsigned underlyingNetId = underlyingNetwork.netId();
6562         const unsigned bypassableVpnNetId = bypassableVpnNetwork.netId();
6563         const unsigned secureVpnNetId = secureVpnNetwork.netId();
6564         // We've called setNetworkForProcess in SetupOemNetwork, so reset to default first.
6565         ScopedSetNetworkForProcess scopedSetNetworkForProcess(NETID_UNSET);
6566         auto expectDnsQueryCountsFn = [&](size_t count,
6567                                           std::shared_ptr<test::DNSResponder> dnsServer,
6568                                           unsigned expectedDnsNetId) -> void {
6569             EXPECT_EQ(GetNumQueries(*dnsServer, host_name), count);
6570             EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(expectedDnsNetId).isOk());
6571             dnsServer->clearQueries();
6572             // Give DnsResolver some time to clear cache to avoid race.
6573             usleep(5 * 1000);
6574         };
6575 
6576         // Create a object to represent default network, do not init it.
6577         ScopedPhysicalNetwork defaultNetwork{NETID_UNSET, "Default"};
6578 
6579         // Test VPN with DNS server under 4 different network selection scenarios.
6580         // See the test config for the expectation.
6581         const struct TestConfig {
6582             ScopedNetwork* selectedNetwork;
6583             unsigned expectedDnsNetId;
6584             std::shared_ptr<test::DNSResponder> expectedDnsServer;
6585         } vpnWithDnsServerConfigs[]{
6586                 // clang-format off
6587                 // Queries use the bypassable VPN by default.
6588                 {&defaultNetwork,       bypassableVpnNetId, bypassableVpnPair->dnsServer},
6589                 // Choosing the underlying network works because the VPN is bypassable.
6590                 {&underlyingNetwork,    underlyingNetId,    underlyingPair->dnsServer},
6591                 // Selecting the VPN sends the query on the VPN.
6592                 {&bypassableVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
6593                 // TEST_UID does not have access to the secure VPN.
6594                 {&secureVpnNetwork,     bypassableVpnNetId, bypassableVpnPair->dnsServer},
6595                 // clang-format on
6596         };
6597         for (const auto& config : vpnWithDnsServerConfigs) {
6598             SCOPED_TRACE(fmt::format("Bypassble VPN with DnsServer, selectedNetwork = {}",
6599                                      config.selectedNetwork->name()));
6600             expectDnsWorksForUid(host_name, config.selectedNetwork->netId(), TEST_UID, result);
6601             expectDnsQueryCountsFn(result.size(), config.expectedDnsServer,
6602                                    config.expectedDnsNetId);
6603         }
6604 
6605         std::vector<ScopedNetwork*> nwVec{&defaultNetwork, &underlyingNetwork,
6606                                           &bypassableVpnNetwork, &secureVpnNetwork};
6607         // Test the VPN without DNS server with the same combination as before.
6608         ASSERT_TRUE(bypassableVpnNetwork.clearDnsConfiguration());
6609         // Test bypassable VPN, TEST_UID
6610         for (const auto* selectedNetwork : nwVec) {
6611             SCOPED_TRACE(fmt::format("Bypassble VPN without DnsServer, selectedNetwork = {}",
6612                                      selectedNetwork->name()));
6613             expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID, result);
6614             expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
6615         }
6616 
6617         // The same test scenario as before plus enableVpnIsolation for secure VPN, TEST_UID2.
6618         for (bool enableVpnIsolation : {false, true}) {
6619             SCOPED_TRACE(fmt::format("enableVpnIsolation = {}", enableVpnIsolation));
6620             if (enableVpnIsolation) {
6621                 EXPECT_RESULT_OK(secureVpnNetwork.enableVpnIsolation(TEST_UID2));
6622             }
6623 
6624             // Test secure VPN without DNS server.
6625             ASSERT_TRUE(secureVpnNetwork.clearDnsConfiguration());
6626             for (const auto* selectedNetwork : nwVec) {
6627                 SCOPED_TRACE(fmt::format("Secure VPN without DnsServer, selectedNetwork = {}",
6628                                          selectedNetwork->name()));
6629                 expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
6630                 expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
6631             }
6632 
6633             // Test secure VPN with DNS server.
6634             ASSERT_TRUE(secureVpnNetwork.setDnsConfiguration());
6635             for (const auto* selectedNetwork : nwVec) {
6636                 SCOPED_TRACE(fmt::format("Secure VPN with DnsServer, selectedNetwork = {}",
6637                                          selectedNetwork->name()));
6638                 expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
6639                 expectDnsQueryCountsFn(result.size(), secureVpnPair->dnsServer, secureVpnNetId);
6640             }
6641 
6642             if (enableVpnIsolation) {
6643                 EXPECT_RESULT_OK(secureVpnNetwork.disableVpnIsolation(TEST_UID2));
6644             }
6645         }
6646     }
6647 }
6648