1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.android.iwlan.epdg; 18 19 import android.annotation.CallbackExecutor; 20 import android.annotation.NonNull; 21 import android.annotation.Nullable; 22 import android.net.DnsResolver; 23 import android.net.DnsResolver.DnsException; 24 import android.net.Network; 25 import android.net.ParseException; 26 import android.os.CancellationSignal; 27 import android.util.Log; 28 29 import com.android.net.module.util.DnsPacket; 30 import com.android.net.module.util.DnsPacketUtils.DnsRecordParser; 31 32 import java.net.InetAddress; 33 import java.net.UnknownHostException; 34 import java.nio.BufferUnderflowException; 35 import java.nio.ByteBuffer; 36 import java.util.ArrayList; 37 import java.util.HashMap; 38 import java.util.Iterator; 39 import java.util.List; 40 import java.util.Map; 41 import java.util.concurrent.CompletableFuture; 42 import java.util.concurrent.ExecutionException; 43 import java.util.concurrent.Executor; 44 import java.util.concurrent.Executors; 45 46 /** 47 * A utility wrapper around android.net.DnsResolver that queries for SRV DNS Resource Records, and 48 * returns in the user callback a list of server (IP addresses, port number) combinations pertaining 49 * to the service requested. 50 * 51 * <p>The returned {@link List<SrvRecordInetAddress>} is currently not sorted according to priority 52 * and weight, in the mechanism described in RFC2782. 53 */ 54 final class SrvDnsResolver { 55 private static final String TAG = "SrvDnsResolver"; 56 57 /** 58 * An SRV Resource Record is queried to obtain the specific port number at which a service is 59 * offered. So the client is returned a combination of (INetAddress, port). 60 */ 61 static class SrvRecordInetAddress { 62 // Holds an IPv4/v6 address, obtained by querying getHostAddress(). 63 public final InetAddress mInetAddress; 64 // A 16-bit unsigned port number. 65 public final int mPort; 66 SrvRecordInetAddress(InetAddress inetAddress, int port)67 public SrvRecordInetAddress(InetAddress inetAddress, int port) { 68 mInetAddress = inetAddress; 69 mPort = port; 70 } 71 } 72 73 // Since the query type for SRV records is not defined in DnsResolver, it is defined here. 74 static final int QUERY_TYPE_SRV = 33; 75 76 /* 77 * Parses and stores an SRV record as described in RFC2782. 78 * 79 * Expects records of type QUERY_TYPE_SRV in the Queries and Answer records section, and records 80 * of type TYPE_A and TYPE_AAAA in the Additional Records section of the DnsPacket. 81 */ 82 static class SrvResponse extends DnsPacket { 83 static class SrvRecord { 84 // A 16-bit unsigned integer that determines the priority of the target host. Clients 85 // must attempt to contact the target host with the lowest-numbered priority first. 86 public final int priority; 87 88 // A 16-bit unsigned integer that specifies a relative weight for entries with the same 89 // priority. Larger weights should we given a proportionately higher probability of 90 // being selected. 91 public final int weight; 92 93 // A 16-bit unsigned integer that specifies the port on this target for this service. 94 public final int port; 95 96 // The domain name of the target host. A target of "." means that the service is 97 // decidedly not available at this domain. 98 public final String target; 99 100 private static final int MAXNAMESIZE = 255; 101 SrvRecord(byte[] srvRecordData)102 SrvRecord(byte[] srvRecordData) throws ParseException { 103 final ByteBuffer buf = ByteBuffer.wrap(srvRecordData); 104 105 try { 106 priority = Short.toUnsignedInt(buf.getShort()); 107 weight = Short.toUnsignedInt(buf.getShort()); 108 port = Short.toUnsignedInt(buf.getShort()); 109 // Although unexpected, some DNS servers do use name compression on portions of 110 // the 'target' field that overlap with the query section of the DNS packet. 111 target = 112 DnsRecordParser.parseName( 113 buf, 0, /* isNameCompressionSupported */ true); 114 if (target.length() > MAXNAMESIZE) { 115 throw new ParseException( 116 "Parse name failed, name size is too long: " + target.length()); 117 } 118 if (buf.hasRemaining()) { 119 throw new ParseException( 120 "Parsing SRV record data failed: more bytes than expected!"); 121 } 122 } catch (BufferUnderflowException e) { 123 throw new ParseException("Parsing SRV Record data failed with cause", e); 124 } 125 } 126 } 127 128 private final int mQueryType; 129 SrvResponse(@onNull byte[] data)130 SrvResponse(@NonNull byte[] data) throws ParseException { 131 super(data); 132 if (!mHeader.isResponse()) { 133 throw new ParseException("Not an answer packet"); 134 } 135 int numQueries = mHeader.getRecordCount(QDSECTION); 136 // Expects exactly one query in query section. 137 if (numQueries != 1) { 138 throw new ParseException("Unexpected query count: " + numQueries); 139 } 140 mQueryType = mRecords[QDSECTION].get(0).nsType; 141 if (mQueryType != QUERY_TYPE_SRV) { 142 throw new ParseException("Unexpected query type: " + mQueryType); 143 } 144 } 145 146 // Parses the Answers section of a DnsPacket to construct and return a mapping 147 // of Domain Name strings to their corresponding SRV record. parseSrvRecords()148 public @NonNull Map<String, SrvRecord> parseSrvRecords() throws ParseException { 149 final HashMap<String, SrvRecord> targetNameToSrvRecord = new HashMap<>(); 150 if (mHeader.getRecordCount(ANSECTION) == 0) return targetNameToSrvRecord; 151 152 for (final DnsRecord ansSec : mRecords[ANSECTION]) { 153 final int nsType = ansSec.nsType; 154 if (nsType != QUERY_TYPE_SRV) { 155 throw new ParseException("Unexpected DNS record type in ANSECTION: " + nsType); 156 } 157 final SrvRecord record = new SrvRecord(ansSec.getRR()); 158 if (targetNameToSrvRecord.containsKey(record.target)) { 159 throw new ParseException( 160 "Domain name " 161 + record.target 162 + " already encountered in DNS response!"); 163 } 164 targetNameToSrvRecord.put(record.target, record); 165 Log.d(TAG, "SrvRecord name: " + ansSec.dName + " target name: " + record.target); 166 } 167 return targetNameToSrvRecord; 168 } 169 170 /* 171 * Parses the 'Additional Records' section of a DnsPacket and expects 'Address Records' 172 * (TYPE_A and TYPE_AAAA records) to construct and return a mapping of Domain Name strings 173 * to their corresponding IP address(es). 174 */ parseIpAddresses()175 public @NonNull Map<String, List<InetAddress>> parseIpAddresses() throws ParseException { 176 final HashMap<String, List<InetAddress>> domainNameToIpAddress = new HashMap<>(); 177 if (mHeader.getRecordCount(ARSECTION) == 0) return domainNameToIpAddress; 178 179 for (final DnsRecord ansSec : mRecords[ARSECTION]) { 180 int nsType = ansSec.nsType; 181 if (nsType != DnsResolver.TYPE_A && nsType != DnsResolver.TYPE_AAAA) { 182 throw new ParseException("Unexpected DNS record type in ARSECTION: " + nsType); 183 } 184 domainNameToIpAddress.computeIfAbsent(ansSec.dName, k -> new ArrayList<>()); 185 try { 186 final InetAddress ipAddress = InetAddress.getByAddress(ansSec.getRR()); 187 Log.d( 188 TAG, 189 "Additional record name: " 190 + ansSec.dName 191 + " IP addr: " 192 + ipAddress.getHostAddress()); 193 domainNameToIpAddress.get(ansSec.dName).add(ipAddress); 194 } catch (UnknownHostException e) { 195 throw new ParseException( 196 "RR to IP address translation failed for domain: " + ansSec.dName); 197 } 198 } 199 return domainNameToIpAddress; 200 } 201 } 202 203 /** 204 * A decorator for {@link DnsResolver.Callback} that accumulates IPv4/v6 responses for SRV DNS 205 * queries and passes it up to the user callback. 206 */ 207 private static class SrvRecordAnswerAccumulator implements DnsResolver.Callback<byte[]> { 208 private static final String TAG = "SrvRecordAnswerAccum"; 209 210 private final Network mNetwork; 211 private final DnsResolver.Callback<List<SrvRecordInetAddress>> mUserCallback; 212 private final Executor mUserExecutor; 213 214 private static class LazyExecutor { 215 public static final Executor INSTANCE = Executors.newSingleThreadExecutor(); 216 } 217 getInternalExecutor()218 static Executor getInternalExecutor() { 219 return LazyExecutor.INSTANCE; 220 } 221 SrvRecordAnswerAccumulator( @onNull Network network, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback, @NonNull @CallbackExecutor Executor executor)222 SrvRecordAnswerAccumulator( 223 @NonNull Network network, 224 @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback, 225 @NonNull @CallbackExecutor Executor executor) { 226 mNetwork = network; 227 mUserCallback = callback; 228 mUserExecutor = executor; 229 } 230 231 /** 232 * Some DNS servers, when queried for an SRV record, do not return the IPv4/v6 records along 233 * with the SRV record. For those, we perform an additional blocking IPv4/v6 DNS query for 234 * each outstanding SRV record. 235 */ queryDns(String domainName)236 private List<InetAddress> queryDns(String domainName) throws DnsException { 237 final CompletableFuture<List<InetAddress>> result = new CompletableFuture(); 238 final DnsResolver.Callback<List<InetAddress>> cb = 239 new DnsResolver.Callback<List<InetAddress>>() { 240 @Override 241 public void onAnswer( 242 @NonNull final List<InetAddress> answer, final int rcode) { 243 if (rcode != 0) { 244 Log.e(TAG, "queryDNS Response Code = " + rcode); 245 } 246 result.complete(answer); 247 } 248 249 @Override 250 public void onError(@Nullable final DnsException error) { 251 Log.e(TAG, "queryDNS response with error : " + error); 252 result.completeExceptionally(error); 253 } 254 }; 255 DnsResolver.getInstance() 256 .query(mNetwork, domainName, DnsResolver.FLAG_EMPTY, Runnable::run, null, cb); 257 258 try { 259 return result.get(); 260 } catch (ExecutionException e) { 261 throw (DnsException) e.getCause(); 262 } catch (InterruptedException e) { 263 Thread.currentThread().interrupt(); // Restore the interrupted status 264 throw new DnsException(DnsResolver.ERROR_SYSTEM, e); 265 } 266 } 267 268 /** 269 * Composes the final (IP address, Port) combination for the client's SRV request. Performs 270 * additional DNS queries if necessary. The SRV records are presently not sorted according 271 * to priority and weight, as described in RFC2782- this is simply 'good enough'. 272 */ composeSrvRecordResult(SrvResponse response)273 private List<SrvRecordInetAddress> composeSrvRecordResult(SrvResponse response) 274 throws DnsPacket.ParseException, DnsException { 275 final List<SrvRecordInetAddress> srvRecordInetAddresses = new ArrayList<>(); 276 final Map<String, List<InetAddress>> domainNameToIpAddresses = 277 response.parseIpAddresses(); 278 final Map<String, SrvResponse.SrvRecord> targetNameToSrvRecords = 279 response.parseSrvRecords(); 280 281 Iterator<Map.Entry<String, SrvResponse.SrvRecord>> itr = 282 targetNameToSrvRecords.entrySet().iterator(); 283 284 // Checks if the received SRV RRs have a corresponding match in IP addresses. For the 285 // ones that do, adds the (IP address, port number) to the output field list. 286 while (itr.hasNext()) { 287 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next(); 288 String domainName = targetNameToSrvRecord.getKey(); 289 int port = targetNameToSrvRecord.getValue().port; 290 List<InetAddress> addresses = domainNameToIpAddresses.get(domainName); 291 if (addresses != null) { 292 // Found a match- add to output list and remove entry from SrvRecord collection. 293 for (InetAddress address : addresses) { 294 srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port)); 295 } 296 itr.remove(); 297 } 298 } 299 300 // For the SRV RRs that don't, spawns a separate DnsResolver query for each, and 301 // collects results using a blocking call. 302 itr = targetNameToSrvRecords.entrySet().iterator(); 303 while (itr.hasNext()) { 304 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next(); 305 String domainName = targetNameToSrvRecord.getKey(); 306 int port = targetNameToSrvRecord.getValue().port; 307 List<InetAddress> addresses = queryDns(domainName); 308 for (InetAddress address : addresses) { 309 srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port)); 310 } 311 } 312 return srvRecordInetAddresses; 313 } 314 315 @Override onAnswer(@onNull byte[] answer, int rcode)316 public void onAnswer(@NonNull byte[] answer, int rcode) { 317 try { 318 final SrvResponse response = new SrvResponse(answer); 319 final List<SrvRecordInetAddress> result = composeSrvRecordResult(response); 320 mUserExecutor.execute(() -> mUserCallback.onAnswer(result, rcode)); 321 } catch (DnsPacket.ParseException e) { 322 // Convert the com.android.net.module.util.DnsPacket.ParseException to an 323 // android.net.ParseException. This is the type that was used in Q and is implied 324 // by the public documentation of ERROR_PARSE. 325 // 326 // DnsPacket cannot throw android.net.ParseException directly because it's @hide. 327 final ParseException pe = new ParseException(e.reason, e.getCause()); 328 pe.setStackTrace(e.getStackTrace()); 329 Log.e(TAG, "ParseException", pe); 330 mUserExecutor.execute( 331 () -> mUserCallback.onError(new DnsException(DnsResolver.ERROR_PARSE, pe))); 332 } catch (DnsException e) { 333 mUserExecutor.execute(() -> mUserCallback.onError(e)); 334 } 335 } 336 337 @Override onError(@onNull DnsException error)338 public void onError(@NonNull DnsException error) { 339 Log.e(TAG, "onError: " + error); 340 mUserExecutor.execute(() -> mUserCallback.onError(error)); 341 } 342 } 343 344 /** 345 * Send an SRV DNS query with the specified name, class and query type. The answer will be 346 * provided asynchronously on the passed executor, through the provided {@link 347 * DnsResolver.Callback}. 348 * 349 * @param network {@link Network} specifying which network to query on. {@code null} for query 350 * on default network. 351 * @param domain SRV domain name to query ( in format _Service._Protocol.Name) 352 * @param cancellationSignal used by the caller to signal if the query should be cancelled. May 353 * be {@code null}. 354 * @param callback a {@link DnsResolver.Callback} which will be called on a separate thread to 355 * notify the caller of the result of the DNS query. 356 */ query( @ullable Network network, @NonNull String domain, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback)357 public static void query( 358 @Nullable Network network, 359 @NonNull String domain, 360 @NonNull @CallbackExecutor Executor executor, 361 @Nullable CancellationSignal cancellationSignal, 362 @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback) { 363 final SrvRecordAnswerAccumulator srvDnsCb = 364 new SrvRecordAnswerAccumulator(network, callback, executor); 365 DnsResolver.getInstance() 366 .rawQuery( 367 network, 368 domain, 369 DnsResolver.CLASS_IN, 370 QUERY_TYPE_SRV, 371 DnsResolver.FLAG_EMPTY, 372 SrvRecordAnswerAccumulator.getInternalExecutor(), 373 cancellationSignal, 374 srvDnsCb); 375 } 376 SrvDnsResolver()377 private SrvDnsResolver() {} 378 } 379