1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <array> 20 #include <list> 21 #include <map> 22 #include <mutex> 23 #include <vector> 24 25 #include <android-base/format.h> 26 #include <android-base/logging.h> 27 #include <android-base/result.h> 28 #include <android-base/thread_annotations.h> 29 #include <netdutils/BackoffSequence.h> 30 #include <netdutils/DumpWriter.h> 31 #include <netdutils/InternetAddresses.h> 32 #include <netdutils/Slice.h> 33 34 #include "DnsTlsServer.h" 35 #include "LockedQueue.h" 36 #include "PrivateDnsValidationObserver.h" 37 #include "doh.h" 38 39 namespace android { 40 namespace net { 41 42 // TODO: decouple the dependency of DnsTlsServer. 43 struct PrivateDnsStatus { 44 PrivateDnsMode mode; 45 46 // TODO: change the type to std::vector<DnsTlsServer>. 47 std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap; 48 49 std::map<netdutils::IPSockAddr, Validation> dohServersMap; 50 validatedServersPrivateDnsStatus51 std::list<DnsTlsServer> validatedServers() const { 52 std::list<DnsTlsServer> servers; 53 54 for (const auto& pair : dotServersMap) { 55 if (pair.second == Validation::success) { 56 servers.push_back(pair.first); 57 } 58 } 59 return servers; 60 } 61 hasValidatedDohServersPrivateDnsStatus62 bool hasValidatedDohServers() const { 63 for (const auto& [_, status] : dohServersMap) { 64 if (status == Validation::success) { 65 return true; 66 } 67 } 68 return false; 69 } 70 }; 71 72 class PrivateDnsConfiguration { 73 public: 74 static constexpr int kDohQueryDefaultTimeoutMs = 30000; 75 static constexpr int kDohProbeDefaultTimeoutMs = 60000; 76 77 // The default value for QUIC max_idle_timeout. 78 static constexpr int kDohIdleDefaultTimeoutMs = 55000; 79 80 struct ServerIdentity { 81 const netdutils::IPSockAddr sockaddr; 82 const std::string provider; 83 ServerIdentityServerIdentity84 explicit ServerIdentity(const IPrivateDnsServer& server) 85 : sockaddr(server.addr()), provider(server.provider()) {} ServerIdentityServerIdentity86 ServerIdentity(const netdutils::IPSockAddr& addr, const std::string& host) 87 : sockaddr(addr), provider(host) {} 88 89 bool operator<(const ServerIdentity& other) const { 90 return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider); 91 } 92 bool operator==(const ServerIdentity& other) const { 93 return std::tie(sockaddr, provider) == std::tie(other.sockaddr, other.provider); 94 } 95 }; 96 97 // The only instance of PrivateDnsConfiguration. getInstance()98 static PrivateDnsConfiguration& getInstance() { 99 static PrivateDnsConfiguration instance; 100 return instance; 101 } 102 103 int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, 104 const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock); 105 106 void initDoh() EXCLUDES(mPrivateDnsLock); 107 108 int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, 109 const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock); 110 111 PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock); 112 113 void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); 114 115 void clearDoh(unsigned netId) EXCLUDES(mPrivateDnsLock); 116 117 ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer, 118 uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock); 119 120 // Request the server to be revalidated on a connection tagged with |mark|. 121 // Returns a Result to indicate if the request is accepted. 122 base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity, 123 uint32_t mark) EXCLUDES(mPrivateDnsLock); 124 125 void setObserver(PrivateDnsValidationObserver* observer); 126 127 void dump(netdutils::DumpWriter& dw) const; 128 129 void onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr, const char* host) 130 EXCLUDES(mPrivateDnsLock); 131 132 base::Result<netdutils::IPSockAddr> getDohServer(unsigned netId) const 133 EXCLUDES(mPrivateDnsLock); 134 135 private: 136 typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker; 137 138 PrivateDnsConfiguration() = default; 139 140 // Launchs a thread to run the validation for |server| on the network |netId|. 141 // |isRevalidation| is true if this call is due to a revalidation request. 142 void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation) 143 REQUIRES(mPrivateDnsLock); 144 145 bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success, 146 bool isRevalidation) EXCLUDES(mPrivateDnsLock); 147 148 void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, 149 bool success) const REQUIRES(mPrivateDnsLock); 150 151 // Decide if a validation for |server| is needed. Note that servers that have failed 152 // multiple validation attempts but for which there is still a validating 153 // thread running are marked as being Validation::in_process. 154 bool needsValidation(const IPrivateDnsServer& server) const REQUIRES(mPrivateDnsLock); 155 156 void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) 157 REQUIRES(mPrivateDnsLock); 158 159 // For testing. 160 base::Result<IPrivateDnsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId) 161 EXCLUDES(mPrivateDnsLock); 162 163 base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity, 164 unsigned netId) REQUIRES(mPrivateDnsLock); 165 166 void initDohLocked() REQUIRES(mPrivateDnsLock); 167 void clearDohLocked(unsigned netId) REQUIRES(mPrivateDnsLock); 168 169 mutable std::mutex mPrivateDnsLock; 170 std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock); 171 172 // Contains all servers for a network, along with their current validation status. 173 // In case a server is removed due to a configuration change, it remains in this map, 174 // but is marked inactive. 175 // Any pending validation threads will continue running because we have no way to cancel them. 176 std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock); 177 178 void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation, 179 uint32_t netId) const REQUIRES(mPrivateDnsLock); 180 181 bool needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const 182 REQUIRES(mPrivateDnsLock); 183 184 // TODO: fix the reentrancy problem. 185 PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock); 186 187 DohDispatcher* mDohDispatcher; 188 189 friend class PrivateDnsConfigurationTest; 190 191 // It's not const because PrivateDnsConfigurationTest needs to override it. 192 // TODO: make it const by dependency injection. 193 netdutils::BackoffSequence<>::Builder mBackoffBuilder = 194 netdutils::BackoffSequence<>::Builder() 195 .withInitialRetransmissionTime(std::chrono::seconds(60)) 196 .withMaximumRetransmissionTime(std::chrono::seconds(3600)); 197 198 struct DohIdentity { 199 std::string httpsTemplate; 200 std::string ipAddr; 201 std::string host; 202 Validation status; 203 bool operator<(const DohIdentity& other) const { 204 return std::tie(ipAddr, host) < std::tie(other.ipAddr, other.host); 205 } 206 bool operator==(const DohIdentity& other) const { 207 return std::tie(ipAddr, host) == std::tie(other.ipAddr, other.host); 208 } 209 bool operator<(const ServerIdentity& other) const { 210 std::string otherIp = other.sockaddr.ip().toString(); 211 return std::tie(ipAddr, host) < std::tie(otherIp, other.provider); 212 } 213 bool operator==(const ServerIdentity& other) const { 214 std::string otherIp = other.sockaddr.ip().toString(); 215 return std::tie(ipAddr, host) == std::tie(otherIp, other.provider); 216 } 217 }; 218 219 struct DohProviderEntry { 220 std::string provider; 221 std::set<std::string> ips; 222 std::string host; 223 std::string httpsTemplate; 224 bool requireRootPermission; getDohIdentityDohProviderEntry225 base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& ips, 226 const std::string& host) const { 227 if (!host.empty() && this->host != host) return Errorf("host {} not matched", host); 228 for (const auto& ip : ips) { 229 if (this->ips.find(ip) == this->ips.end()) continue; 230 LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host); 231 // Only pick the first one for now. 232 return DohIdentity{httpsTemplate, ip, host, Validation::in_process}; 233 } 234 return Errorf("server not matched"); 235 }; 236 }; 237 238 // TODO: Move below DoH relevant stuff into Rust implementation. 239 std::map<unsigned, DohIdentity> mDohTracker GUARDED_BY(mPrivateDnsLock); 240 std::array<DohProviderEntry, 4> mAvailableDoHProviders = {{ 241 {"Google", 242 {"2001:4860:4860::8888", "2001:4860:4860::8844", "8.8.8.8", "8.8.4.4"}, 243 "dns.google", 244 "https://dns.google/dns-query", 245 false}, 246 {"Cloudflare", 247 {"2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "104.16.248.249", "104.16.249.249"}, 248 "cloudflare-dns.com", 249 "https://cloudflare-dns.com/dns-query", 250 false}, 251 252 // The DoH providers for testing only. 253 // Using ResolverTestProvider requires that the DnsResolver is configured by someone 254 // who has root permission, which should be run by tests only. 255 {"ResolverTestProvider", 256 {"127.0.0.3", "::1"}, 257 "example.com", 258 "https://example.com/dns-query", 259 true}, 260 {"AndroidTesting", 261 {"192.0.2.100"}, 262 "dns.androidtesting.org", 263 "https://dns.androidtesting.org/dns-query", 264 false}, 265 }}; 266 267 struct RecordEntry { RecordEntryRecordEntry268 RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state) 269 : netId(netId), serverIdentity(identity), state(state) {} 270 271 const uint32_t netId; 272 const ServerIdentity serverIdentity; 273 const Validation state; 274 const std::chrono::system_clock::time_point timestamp = std::chrono::system_clock::now(); 275 }; 276 277 LockedRingBuffer<RecordEntry> mPrivateDnsLog{100}; 278 }; 279 280 } // namespace net 281 } // namespace android 282