1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include <netdutils/NetNativeTestBase.h>
20 #include <resolv_stats_test_utils.h>
21
22 #include "PrivateDnsConfiguration.h"
23 #include "resolv_cache.h"
24 #include "tests/dns_responder/dns_responder.h"
25 #include "tests/dns_responder/dns_tls_frontend.h"
26 #include "tests/resolv_test_utils.h"
27
28 namespace android::net {
29
30 using namespace std::chrono_literals;
31
32 class PrivateDnsConfigurationTest : public NetNativeTestBase {
33 public:
34 using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
35
SetUpTestSuite()36 static void SetUpTestSuite() {
37 // stopServer() will be called in their destructor.
38 ASSERT_TRUE(tls1.startServer());
39 ASSERT_TRUE(tls2.startServer());
40 ASSERT_TRUE(backend.startServer());
41 ASSERT_TRUE(backend1ForUdpProbe.startServer());
42 ASSERT_TRUE(backend2ForUdpProbe.startServer());
43 }
44
SetUp()45 void SetUp() {
46 mPdc.setObserver(&mObserver);
47 mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1))
48 .withMaximumRetransmissionTime(std::chrono::seconds(1));
49
50 // The default and sole action when the observer is notified of onValidationStateUpdate.
51 // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
52 // when mObserver.onValidationStateUpdate is expected to be called, like:
53 //
54 // EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
55 //
56 // This is to ensure that tests can monitor how many validation threads are running. Tests
57 // must wait until every validation thread finishes.
58 ON_CALL(mObserver, onValidationStateUpdate)
59 .WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
60 std::lock_guard guard(mObserver.lock);
61 if (validation == Validation::in_process) {
62 auto it = mObserver.serverStateMap.find(server);
63 if (it == mObserver.serverStateMap.end() ||
64 it->second != Validation::in_process) {
65 // Increment runningThreads only when receive the first in_process
66 // notification. The rest of the continuous in_process notifications
67 // are due to probe retry which runs on the same thread.
68 // TODO: consider adding onValidationThreadStart() and
69 // onValidationThreadEnd() callbacks.
70 mObserver.runningThreads++;
71 }
72 } else if (validation == Validation::success ||
73 validation == Validation::fail) {
74 mObserver.runningThreads--;
75 }
76 mObserver.serverStateMap[server] = validation;
77 });
78
79 // Create a NetConfig for stats.
80 EXPECT_EQ(0, resolv_create_cache_for_net(kNetId));
81 }
82
TearDown()83 void TearDown() {
84 // Reset the state for the next test.
85 resolv_delete_cache_for_net(kNetId);
86 mPdc.set(kNetId, kMark, {}, {}, {}, {});
87 }
88
89 protected:
90 class MockObserver : public PrivateDnsValidationObserver {
91 public:
92 MOCK_METHOD(void, onValidationStateUpdate,
93 (const std::string& serverIp, Validation validation, uint32_t netId),
94 (override));
95
getServerStateMap() const96 std::map<std::string, Validation> getServerStateMap() const {
97 std::lock_guard guard(lock);
98 return serverStateMap;
99 }
100
removeFromServerStateMap(const std::string & server)101 void removeFromServerStateMap(const std::string& server) {
102 std::lock_guard guard(lock);
103 if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
104 serverStateMap.erase(it);
105 }
106
107 // The current number of validation threads running.
108 std::atomic<int> runningThreads = 0;
109
110 mutable std::mutex lock;
111 std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
112 };
113
expectPrivateDnsStatus(PrivateDnsMode mode)114 void expectPrivateDnsStatus(PrivateDnsMode mode) {
115 // Use PollForCondition because mObserver is notified asynchronously.
116 EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); }));
117 }
118
checkPrivateDnsStatus(PrivateDnsMode mode)119 bool checkPrivateDnsStatus(PrivateDnsMode mode) {
120 const PrivateDnsStatus status = mPdc.getStatus(kNetId);
121 if (status.mode != mode) return false;
122
123 std::map<std::string, Validation> serverStateMap;
124 for (const auto& [server, validation] : status.dotServersMap) {
125 serverStateMap[ToString(&server.ss)] = validation;
126 }
127 return (serverStateMap == mObserver.getServerStateMap());
128 }
129
hasPrivateDnsServer(const ServerIdentity & identity,unsigned netId)130 bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
131 return mPdc.getDotServer(identity, netId).ok();
132 }
133
134 static constexpr uint32_t kNetId = 30;
135 static constexpr uint32_t kMark = 30;
136 static constexpr char kBackend[] = "127.0.2.1";
137 static constexpr char kServer1[] = "127.0.2.2";
138 static constexpr char kServer2[] = "127.0.2.3";
139
140 MockObserver mObserver;
141 inline static PrivateDnsConfiguration mPdc;
142
143 // TODO: Because incorrect CAs result in validation failed in strict mode, have
144 // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
145 inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
146 inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
147 inline static test::DNSResponder backend{kBackend, "53"};
148 inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
149 inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
150 };
151
// With all shared servers running, validating kServer1 must report in_process and then success,
// after which no validation thread may remain.
TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
    // The two notifications must arrive in exactly this order.
    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    const auto noValidationRunning = [this]() { return mObserver.runningThreads == 0; };
    ASSERT_TRUE(PollForCondition(noValidationRunning));
}
162
// With the backend behind the TLS frontends stopped, validating kServer1 in opportunistic mode
// must report in_process and then fail.
TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
    ASSERT_TRUE(backend.stopServer());
    // Restart the shared backend whenever this test returns, including early returns caused by
    // a failing ASSERT_*, so that a failure here does not cascade into later tests.
    struct BackendRestarter {
        ~BackendRestarter() { EXPECT_TRUE(backend.startServer()); }
    } restartBackendOnExit;

    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Strictly wait for all validation threads to finish; otherwise, the test can crash.
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
177
// A server that validated successfully can be revalidated on request; while the backend is
// temporarily down, the revalidation retries and eventually succeeds.
TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    // Step 1: Set up and wait for validation complete.
    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // Step 2: Simulate the DNS is temporarily broken, and then request a validation.
    // Expect the validation to run as follows:
    // 1. DnsResolver notifies of Validation::in_process when the validation is about to run.
    // 2. The first probing fails. DnsResolver notifies of Validation::in_process.
    // 3. One second later, the second probing begins and succeeds. DnsResolver notifies of
    //    Validation::success.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
            .Times(2);
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    // Stop the backend before spawning the thread that restarts it, so that stop-before-start is
    // guaranteed by program order rather than by the sleep duration. No ASSERT_* may run between
    // creating |restarter| and joining it: returning with a joinable std::thread terminates.
    ASSERT_TRUE(backend.stopServer());
    std::thread restarter([] {
        std::this_thread::sleep_for(1000ms);
        EXPECT_TRUE(backend.startServer());
    });
    EXPECT_TRUE(mPdc.requestDotValidation(kNetId, ServerIdentity(server), kMark).ok());

    restarter.join();
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
211
// While validations are held by deferred backend responses:
// - repeated set() calls with overlapping server lists must not start duplicate validations;
// - invalid input must not change the status.
// Once responses resume, the held validations must finish.
TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    backend.setDeferredResp(true);
    // Release deferred responses whenever this test returns, including early returns from a
    // failing ASSERT_*, so that validations waiting on the backend, and the deferred mode of the
    // shared backend itself, don't leak into later tests.
    struct DeferredRespResetter {
        ~DeferredRespResetter() { backend.setDeferredResp(false); }
    } resetDeferredRespOnExit;

    // onValidationStateUpdate() is called in sequence.
    {
        testing::InSequence seq;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 2; }));
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
        // onValidationStateUpdate() will be caught.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1, kServer2}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // The status stays unchanged when invalid arguments are passed.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {}), -EINVAL);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    }

    // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected
    // server for the network.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));
    backend.setDeferredResp(false);

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
255
// Once private DNS is turned off for the network, or the network is destroyed, a validation
// still in progress must report Validation::fail and the mode must become OFF.
TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
    // Release deferred responses whenever this test returns, including early returns from a
    // failing ASSERT_*, so that validations waiting on the backend, and the deferred mode of the
    // shared backend itself, don't leak into later tests.
    struct DeferredRespResetter {
        ~DeferredRespResetter() { backend.setDeferredResp(false); }
    } resetDeferredRespOnExit;

    for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
        SCOPED_TRACE(config);
        backend.setDeferredResp(true);

        testing::InSequence seq;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        if (config == "OFF") {
            EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}, {}), 0);
        } else if (config == "NETWORK_DESTROYED") {
            mPdc.clear(kNetId);
        }

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        backend.setDeferredResp(false);

        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OFF);
    }
}
281
// Neither an invalid server address nor an empty server list may start a validation: the mode
// stays OFF and no DoT server is listed. Any onValidationStateUpdate() call would be reported
// by gMock as an uninteresting mock function call at the end of the test.
TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    const auto expectOffWithoutDotServers = []() {
        const PrivateDnsStatus current = mPdc.getStatus(kNetId);
        EXPECT_EQ(current.mode, PrivateDnsMode::OFF);
        EXPECT_THAT(current.dotServersMap, testing::IsEmpty());
    };

    // An unparsable address is rejected with -EINVAL.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {}), -EINVAL);
    expectOffWithoutDotServers();

    // An empty configuration is accepted.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}, {}), 0);
    expectOffWithoutDotServers();
}
298
// ServerIdentity equality must depend on both the socket address (IP and port) and the
// provider hostname.
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    DnsTlsServer baseline(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    baseline.name = "dns.example.com";

    // A different port or a different IP address gives a different identity.
    DnsTlsServer variant = baseline;
    EXPECT_EQ(ServerIdentity(baseline), ServerIdentity(variant));
    variant.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_NE(ServerIdentity(baseline), ServerIdentity(variant));
    variant.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(baseline), ServerIdentity(variant));

    // A different or an empty provider hostname gives a different identity.
    variant = baseline;
    EXPECT_EQ(ServerIdentity(baseline), ServerIdentity(variant));
    variant.name = "other.example.com";
    EXPECT_NE(ServerIdentity(baseline), ServerIdentity(variant));
    variant.name = "";
    EXPECT_NE(ServerIdentity(baseline), ServerIdentity(variant));
}
319
// Exercises requestDotValidation() for a server whose previous validation succeeded, is still
// in progress, or failed. Only the successfully validated server can be revalidated. Duplicate
// requests, and requests for another mark or another network, are denied.
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);

    testing::InSequence seq;

    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (config == "IN_PROGRESS") {
            // Hold the probe response so that the validation stays in progress.
            backend.setDeferredResp(true);
        } else {
            // config == "FAIL"
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state is transitioned. In the IN_PROGRESS case the
        // validation is held by the deferred response, so its thread is still running.
        const int expectedRunningThreads = (config == "IN_PROGRESS") ? 1 : 0;
        ASSERT_TRUE(PollForCondition(
                [&]() { return mObserver.runningThreads == expectedRunningThreads; }));

        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            // The pending validation succeeds once the deferred response is released below.
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state. Only the FAIL case stopped the backend, so only that case
        // restarts it, and the restart is checked.
        backend.setDeferredResp(false);
        if (config == "FAIL") {
            EXPECT_TRUE(backend.startServer());
        }

        // Ensure the status of mObserver is synced.
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mPdc.clear(kNetId);
    }
}
374
// getDotServer() must find only the servers configured for the queried network.
TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer dotServer1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer dotServer2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
    const ServerIdentity id1(dotServer1);
    const ServerIdentity id2(dotServer2);

    // Nothing has been configured yet.
    EXPECT_FALSE(hasPrivateDnsServer(id1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(id2, kNetId));

    // Expect exactly two notifications of any kind so that gMock does not warn about
    // uninteresting calls.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Only kServer1 on kNetId is known.
    EXPECT_TRUE(hasPrivateDnsServer(id1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(id2, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(id1, kNetId + 1));

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
394
395 // Tests that getStatusForMetrics() returns the correct data.
TEST_F(PrivateDnsConfigurationTest,GetStatusForMetrics)396 TEST_F(PrivateDnsConfigurationTest, GetStatusForMetrics) {
397 tls2.stopServer();
398 const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
399 const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
400
401 // Suppress the warning.
402 EXPECT_CALL(mObserver, onValidationStateUpdate).Times(4);
403
404 // Set 1 unencrypted server and 2 encrypted servers (one will pass DoT validation; the other
405 // will fail. Both of them don't support DoH).
406 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {kServer1, kServer2}, {}, {}), 0);
407 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
408
409 // Get the metric before call clear().
410 NetworkDnsServerSupportReported event = mPdc.getStatusForMetrics(kNetId);
411 NetworkDnsServerSupportReported expectedEvent;
412 // It's NT_UNKNOWN because this test didn't call resolv_set_nameservers() to set
413 // the network type.
414 expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
415 expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_OPPORTUNISTIC);
416 Server* server = expectedEvent.mutable_servers()->add_server();
417 server->set_protocol(PROTO_UDP); // kServer2
418 server->set_index(0);
419 server->set_validated(false);
420 server = expectedEvent.mutable_servers()->add_server();
421 server->set_protocol(PROTO_DOT); // kServer1
422 server->set_index(0);
423 server->set_validated(true);
424 server = expectedEvent.mutable_servers()->add_server();
425 server->set_protocol(PROTO_DOT); // kServer2
426 server->set_index(1);
427 server->set_validated(false);
428 server = expectedEvent.mutable_servers()->add_server();
429 server->set_protocol(PROTO_DOH); // kServer1
430 server->set_index(0);
431 server->set_validated(false);
432 server = expectedEvent.mutable_servers()->add_server();
433 server->set_protocol(PROTO_DOH); // kServer2
434 server->set_index(1);
435 server->set_validated(false);
436 EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));
437
438 // Get the metric after call clear().
439 mPdc.clear(kNetId);
440 event = mPdc.getStatusForMetrics(kNetId);
441 expectedEvent.Clear();
442 expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
443 expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_UNKNOWN);
444 EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));
445
446 tls2.startServer();
447 }
448
449 // TODO: add ValidationFail_Strict test.
450
451 } // namespace android::net
452