• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include <netdutils/NetNativeTestBase.h>
20 #include <resolv_stats_test_utils.h>
21 
22 #include "PrivateDnsConfiguration.h"
23 #include "resolv_cache.h"
24 #include "tests/dns_responder/dns_responder.h"
25 #include "tests/dns_responder/dns_tls_frontend.h"
26 #include "tests/resolv_test_utils.h"
27 
28 namespace android::net {
29 
30 using namespace std::chrono_literals;
31 
32 class PrivateDnsConfigurationTest : public NetNativeTestBase {
33   public:
34     using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
35 
SetUpTestSuite()36     static void SetUpTestSuite() {
37         // stopServer() will be called in their destructor.
38         ASSERT_TRUE(tls1.startServer());
39         ASSERT_TRUE(tls2.startServer());
40         ASSERT_TRUE(backend.startServer());
41         ASSERT_TRUE(backend1ForUdpProbe.startServer());
42         ASSERT_TRUE(backend2ForUdpProbe.startServer());
43     }
44 
SetUp()45     void SetUp() {
46         mPdc.setObserver(&mObserver);
47         mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1))
48                 .withMaximumRetransmissionTime(std::chrono::seconds(1));
49 
50         // The default and sole action when the observer is notified of onValidationStateUpdate.
51         // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
52         // when mObserver.onValidationStateUpdate is expected to be called, like:
53         //
54         //   EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
55         //
56         // This is to ensure that tests can monitor how many validation threads are running. Tests
57         // must wait until every validation thread finishes.
58         ON_CALL(mObserver, onValidationStateUpdate)
59                 .WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
60                     std::lock_guard guard(mObserver.lock);
61                     if (validation == Validation::in_process) {
62                         auto it = mObserver.serverStateMap.find(server);
63                         if (it == mObserver.serverStateMap.end() ||
64                             it->second != Validation::in_process) {
65                             // Increment runningThreads only when receive the first in_process
66                             // notification. The rest of the continuous in_process notifications
67                             // are due to probe retry which runs on the same thread.
68                             // TODO: consider adding onValidationThreadStart() and
69                             // onValidationThreadEnd() callbacks.
70                             mObserver.runningThreads++;
71                         }
72                     } else if (validation == Validation::success ||
73                                validation == Validation::fail) {
74                         mObserver.runningThreads--;
75                     }
76                     mObserver.serverStateMap[server] = validation;
77                 });
78 
79         // Create a NetConfig for stats.
80         EXPECT_EQ(0, resolv_create_cache_for_net(kNetId));
81     }
82 
TearDown()83     void TearDown() {
84         // Reset the state for the next test.
85         resolv_delete_cache_for_net(kNetId);
86         mPdc.set(kNetId, kMark, {}, {}, {}, {});
87     }
88 
89   protected:
90     class MockObserver : public PrivateDnsValidationObserver {
91       public:
92         MOCK_METHOD(void, onValidationStateUpdate,
93                     (const std::string& serverIp, Validation validation, uint32_t netId),
94                     (override));
95 
getServerStateMap() const96         std::map<std::string, Validation> getServerStateMap() const {
97             std::lock_guard guard(lock);
98             return serverStateMap;
99         }
100 
removeFromServerStateMap(const std::string & server)101         void removeFromServerStateMap(const std::string& server) {
102             std::lock_guard guard(lock);
103             if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
104                 serverStateMap.erase(it);
105         }
106 
107         // The current number of validation threads running.
108         std::atomic<int> runningThreads = 0;
109 
110         mutable std::mutex lock;
111         std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
112     };
113 
expectPrivateDnsStatus(PrivateDnsMode mode)114     void expectPrivateDnsStatus(PrivateDnsMode mode) {
115         // Use PollForCondition because mObserver is notified asynchronously.
116         EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); }));
117     }
118 
checkPrivateDnsStatus(PrivateDnsMode mode)119     bool checkPrivateDnsStatus(PrivateDnsMode mode) {
120         const PrivateDnsStatus status = mPdc.getStatus(kNetId);
121         if (status.mode != mode) return false;
122 
123         std::map<std::string, Validation> serverStateMap;
124         for (const auto& [server, validation] : status.dotServersMap) {
125             serverStateMap[ToString(&server.ss)] = validation;
126         }
127         return (serverStateMap == mObserver.getServerStateMap());
128     }
129 
hasPrivateDnsServer(const ServerIdentity & identity,unsigned netId)130     bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
131         return mPdc.getDotServer(identity, netId).ok();
132     }
133 
134     static constexpr uint32_t kNetId = 30;
135     static constexpr uint32_t kMark = 30;
136     static constexpr char kBackend[] = "127.0.2.1";
137     static constexpr char kServer1[] = "127.0.2.2";
138     static constexpr char kServer2[] = "127.0.2.3";
139 
140     MockObserver mObserver;
141     inline static PrivateDnsConfiguration mPdc;
142 
143     // TODO: Because incorrect CAs result in validation failed in strict mode, have
144     // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
145     inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
146     inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
147     inline static test::DNSResponder backend{kBackend, "53"};
148     inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
149     inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
150 };
151 
// Setting one DoT server must produce an in_process notification followed by a success
// notification, and leave the network in opportunistic mode.
TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
    const auto allValidationThreadsDone = [&]() { return mObserver.runningThreads == 0; };

    // The two notifications must arrive in exactly this order.
    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Do not leave the test while a validation thread is still alive.
    ASSERT_TRUE(PollForCondition(allValidationThreadsDone));
}
162 
// With the backend responder stopped, setting one DoT server must produce an in_process
// notification followed by a fail notification; the mode is still opportunistic.
TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
    const auto allValidationThreadsDone = [&]() { return mObserver.runningThreads == 0; };

    // The backend stays down for the duration of the validation.
    ASSERT_TRUE(backend.stopServer());

    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Strictly wait for all of the validation finish; otherwise, the test can crash somehow.
    ASSERT_TRUE(PollForCondition(allValidationThreadsDone));

    // Bring the shared backend back for the tests that follow.
    ASSERT_TRUE(backend.startServer());
}
177 
// Verifies that a revalidation requested while the backend is temporarily down retries and
// eventually succeeds once the backend is back.
TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    // Step 1: Set up and wait for validation complete.
    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // Step 2: Simulate the DNS is temporarily broken, and then request a validation.
    // Expect the validation to run as follows:
    //   1. DnsResolver notifies of Validation::in_process when the validation is about to run.
    //   2. The first probing fails. DnsResolver notifies of Validation::in_process.
    //   3. One second later, the second probing begins and succeeds. DnsResolver notifies of
    //      Validation::success.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
            .Times(2);
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    // Stop the backend before spawning the restart thread: the one-second outage is then measured
    // from the moment the backend is really down, and a failure to stop it aborts the test while
    // no joinable thread exists yet.
    ASSERT_TRUE(backend.stopServer());
    std::thread t([] {
        std::this_thread::sleep_for(1000ms);
        // Report (rather than silently ignore) a failure to bring the shared backend back.
        EXPECT_TRUE(backend.startServer());
    });
    EXPECT_TRUE(mPdc.requestDotValidation(kNetId, ServerIdentity(server), kMark).ok());

    // Only non-fatal expectations are used between thread creation and join(), so |t| is always
    // joined before the test returns.
    t.join();
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
211 
// Keeps validations pending by deferring the backend's responses, changes the server list while
// they are pending, and checks which notifications are delivered once responses are released.
TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    // Polls until the observer reports exactly |expected| running validation threads.
    const auto pollRunningThreads = [&](int expected) {
        return PollForCondition([&]() { return mObserver.runningThreads == expected; });
    };

    backend.setDeferredResp(true);

    // onValidationStateUpdate() is called in sequence.
    {
        testing::InSequence seq;

        // First validation: kServer1.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(pollRunningThreads(1));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Second validation: kServer2 replaces kServer1 in the server list.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        ASSERT_TRUE(pollRunningThreads(2));
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
        // onValidationStateUpdate() will be caught.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1, kServer2}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // The status keeps unchanged if pass invalid arguments.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {}), -EINVAL);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    }

    // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected
    // server for the network.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));

    // Release the deferred responses and wait for both validation threads to finish.
    backend.setDeferredResp(false);
    ASSERT_TRUE(pollRunningThreads(0));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
255 
// While a validation is pending, either clears the server list ("OFF") or clears the network
// ("NETWORK_DESTROYED"); in both cases the pending validation must end with a fail notification
// and the resulting mode must be OFF.
TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
    // Polls until the observer reports exactly |expected| running validation threads.
    const auto pollRunningThreads = [&](int expected) {
        return PollForCondition([&]() { return mObserver.runningThreads == expected; });
    };

    for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
        SCOPED_TRACE(config);

        // Keep the validation pending until the responses are released below.
        backend.setDeferredResp(true);

        testing::InSequence seq;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(pollRunningThreads(1));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Apply the scenario under test while the validation thread is still running.
        if (config == "NETWORK_DESTROYED") {
            mPdc.clear(kNetId);
        } else if (config == "OFF") {
            EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}, {}), 0);
        }

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        backend.setDeferredResp(false);

        ASSERT_TRUE(pollRunningThreads(0));

        // Drop the observer's record of kServer1 before comparing it with the status.
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OFF);
    }
}
281 
// Neither an invalid server address nor an empty server list may start a validation.
TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    // If onValidationStateUpdate() is called, the test will fail with uninteresting mock
    // function calls in the end of the test.

    // Expects mode OFF and an empty DoT server map for kNetId.
    const auto expectOffWithoutServers = [&]() {
        const PrivateDnsStatus current = mPdc.getStatus(kNetId);
        EXPECT_EQ(current.mode, PrivateDnsMode::OFF);
        EXPECT_THAT(current.dotServersMap, testing::IsEmpty());
    };

    // An unparsable address is rejected.
    const auto invalidResult = mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {});
    EXPECT_EQ(invalidResult, -EINVAL);
    expectOffWithoutServers();

    // An empty server list is accepted.
    const auto emptyResult = mPdc.set(kNetId, kMark, {}, {}, {}, {});
    EXPECT_EQ(emptyResult, 0);
    expectOffWithoutServers();
}
298 
// ServerIdentity must compare equal for identical servers and unequal as soon as the socket
// address (IP or port) or the provider hostname differs.
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    DnsTlsServer reference(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    reference.name = "dns.example.com";

    // Different socket address.
    {
        DnsTlsServer candidate = reference;
        EXPECT_EQ(ServerIdentity(reference), ServerIdentity(candidate));

        // Same IP, different port.
        candidate.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
        EXPECT_NE(ServerIdentity(reference), ServerIdentity(candidate));

        // Same port, different IP.
        candidate.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
        EXPECT_NE(ServerIdentity(reference), ServerIdentity(candidate));
    }

    // Different provider hostname.
    {
        DnsTlsServer candidate = reference;
        EXPECT_EQ(ServerIdentity(reference), ServerIdentity(candidate));

        candidate.name = "other.example.com";
        EXPECT_NE(ServerIdentity(reference), ServerIdentity(candidate));

        candidate.name = "";
        EXPECT_NE(ServerIdentity(reference), ServerIdentity(candidate));
    }
}
319 
TEST_F(PrivateDnsConfigurationTest,RequestValidation)320 TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
321     const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
322     const ServerIdentity identity(server);
323 
324     testing::InSequence seq;
325 
326     for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
327         SCOPED_TRACE(config);
328 
329         EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
330         if (config == "SUCCESS") {
331             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
332         } else if (config == "IN_PROGRESS") {
333             backend.setDeferredResp(true);
334         } else {
335             // config = "FAIL"
336             ASSERT_TRUE(backend.stopServer());
337             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
338         }
339         EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
340         expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
341 
342         // Wait until the validation state is transitioned.
343         const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
344         ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));
345 
346         if (config == "SUCCESS") {
347             EXPECT_CALL(mObserver,
348                         onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
349             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
350             EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
351         } else if (config == "IN_PROGRESS") {
352             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
353             EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
354         } else if (config == "FAIL") {
355             EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
356         }
357 
358         // Resending the same request or requesting nonexistent servers are denied.
359         EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
360         EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark + 1).ok());
361         EXPECT_FALSE(mPdc.requestDotValidation(kNetId + 1, identity, kMark).ok());
362 
363         // Reset the test state.
364         backend.setDeferredResp(false);
365         backend.startServer();
366 
367         // Ensure the status of mObserver is synced.
368         expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
369 
370         ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
371         mPdc.clear(kNetId);
372     }
373 }
374 
// getDotServer() must find a server only after it has been set, and only on the network it was
// set for.
TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
    const ServerIdentity identity1(server1);
    const ServerIdentity identity2(server2);

    // Nothing is configured yet.
    EXPECT_FALSE(hasPrivateDnsServer(identity1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity2, kNetId));

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Only kServer1 on kNetId is known now.
    EXPECT_TRUE(hasPrivateDnsServer(identity1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity2, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity1, kNetId + 1));

    const auto allValidationThreadsDone = [&]() { return mObserver.runningThreads == 0; };
    ASSERT_TRUE(PollForCondition(allValidationThreadsDone));
}
394 
395 // Tests that getStatusForMetrics() returns the correct data.
TEST_F(PrivateDnsConfigurationTest,GetStatusForMetrics)396 TEST_F(PrivateDnsConfigurationTest, GetStatusForMetrics) {
397     tls2.stopServer();
398     const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
399     const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
400 
401     // Suppress the warning.
402     EXPECT_CALL(mObserver, onValidationStateUpdate).Times(4);
403 
404     // Set 1 unencrypted server and 2 encrypted servers (one will pass DoT validation; the other
405     // will fail. Both of them don't support DoH).
406     EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {kServer1, kServer2}, {}, {}), 0);
407     ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
408 
409     // Get the metric before call clear().
410     NetworkDnsServerSupportReported event = mPdc.getStatusForMetrics(kNetId);
411     NetworkDnsServerSupportReported expectedEvent;
412     // It's NT_UNKNOWN because this test didn't call resolv_set_nameservers() to set
413     // the network type.
414     expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
415     expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_OPPORTUNISTIC);
416     Server* server = expectedEvent.mutable_servers()->add_server();
417     server->set_protocol(PROTO_UDP);  // kServer2
418     server->set_index(0);
419     server->set_validated(false);
420     server = expectedEvent.mutable_servers()->add_server();
421     server->set_protocol(PROTO_DOT);  // kServer1
422     server->set_index(0);
423     server->set_validated(true);
424     server = expectedEvent.mutable_servers()->add_server();
425     server->set_protocol(PROTO_DOT);  // kServer2
426     server->set_index(1);
427     server->set_validated(false);
428     server = expectedEvent.mutable_servers()->add_server();
429     server->set_protocol(PROTO_DOH);  // kServer1
430     server->set_index(0);
431     server->set_validated(false);
432     server = expectedEvent.mutable_servers()->add_server();
433     server->set_protocol(PROTO_DOH);  // kServer2
434     server->set_index(1);
435     server->set_validated(false);
436     EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));
437 
438     // Get the metric after call clear().
439     mPdc.clear(kNetId);
440     event = mPdc.getStatusForMetrics(kNetId);
441     expectedEvent.Clear();
442     expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
443     expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_UNKNOWN);
444     EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));
445 
446     tls2.startServer();
447 }
448 
449 // TODO: add ValidationFail_Strict test.
450 
451 }  // namespace android::net
452