1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.connectivity.mdns; 18 19 import static com.android.net.module.util.DnsUtils.equalsIgnoreDnsCase; 20 import static com.android.net.module.util.DnsUtils.toDnsUpperCase; 21 import static com.android.net.module.util.HandlerUtils.ensureRunningOnHandlerThread; 22 import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER; 23 24 import static java.lang.Math.min; 25 26 import android.annotation.NonNull; 27 import android.annotation.Nullable; 28 import android.os.Handler; 29 import android.os.Looper; 30 import android.util.ArrayMap; 31 32 import com.android.internal.annotations.VisibleForTesting; 33 import com.android.server.connectivity.mdns.util.MdnsUtils; 34 35 import java.io.PrintWriter; 36 import java.util.ArrayList; 37 import java.util.Collections; 38 import java.util.Iterator; 39 import java.util.List; 40 import java.util.Objects; 41 42 /** 43 * The {@link MdnsServiceCache} manages the service which discovers from each socket and cache these 44 * services to reduce duplicated queries. 45 * 46 * <p>This class is not thread safe, it is intended to be used only from the looper thread. 47 * However, the constructor is an exception, as it is called on another thread; 48 * therefore for thread safety all members of this class MUST either be final or initialized 49 * to their default value (0, false or null). 50 */ 51 public class MdnsServiceCache { 52 public static class CacheKey { 53 @NonNull final String mUpperCaseServiceType; 54 @NonNull final SocketKey mSocketKey; 55 CacheKey(@onNull String serviceType, @NonNull SocketKey socketKey)56 CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) { 57 mUpperCaseServiceType = toDnsUpperCase(serviceType); 58 mSocketKey = socketKey; 59 } 60 hashCode()61 @Override public int hashCode() { 62 return Objects.hash(mUpperCaseServiceType, mSocketKey); 63 } 64 equals(Object other)65 @Override public boolean equals(Object other) { 66 if (this == other) { 67 return true; 68 } 69 if (!(other instanceof CacheKey)) { 70 return false; 71 } 72 return Objects.equals(mUpperCaseServiceType, ((CacheKey) other).mUpperCaseServiceType) 73 && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey); 74 } 75 76 @Override toString()77 public String toString() { 78 return "CacheKey{ ServiceType=" + mUpperCaseServiceType + ", " + mSocketKey + " }"; 79 } 80 } 81 82 public static class CachedService { 83 @NonNull final MdnsResponse mService; 84 boolean mServiceExpired; 85 CachedService(MdnsResponse service)86 CachedService(MdnsResponse service) { 87 mService = service; 88 mServiceExpired = false; 89 } 90 } 91 92 /** 93 * A map of cached services. Key is composed of service type and socket. Value is the list of 94 * services which are discovered from the given CacheKey. 95 * When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted 96 * by expiration time, with the earliest entries appearing first. This sorting allows the 97 * removal process to progress through the expiration check efficiently. 98 */ 99 @NonNull 100 private final ArrayMap<CacheKey, List<CachedService>> mCachedServices = new ArrayMap<>(); 101 /** 102 * A map of service expire callbacks. Key is composed of service type and socket and value is 103 * the callback listener. 104 */ 105 @NonNull 106 private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>(); 107 @NonNull 108 private final Handler mHandler; 109 @NonNull 110 private final MdnsFeatureFlags mMdnsFeatureFlags; 111 @NonNull 112 private final MdnsUtils.Clock mClock; 113 private long mNextExpirationTime = EXPIRATION_NEVER; 114 MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags)115 public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { 116 this(looper, mdnsFeatureFlags, new MdnsUtils.Clock()); 117 } 118 119 @VisibleForTesting MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, @NonNull MdnsUtils.Clock clock)120 MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, 121 @NonNull MdnsUtils.Clock clock) { 122 mHandler = new Handler(looper); 123 mMdnsFeatureFlags = mdnsFeatureFlags; 124 mClock = clock; 125 } 126 cachedServicesToResponses(List<CachedService> cachedServices)127 private List<MdnsResponse> cachedServicesToResponses(List<CachedService> cachedServices) { 128 final List<MdnsResponse> responses = new ArrayList<>(); 129 for (CachedService cachedService : cachedServices) { 130 responses.add(cachedService.mService); 131 } 132 return responses; 133 } 134 135 /** 136 * Get the cache services which are queried from given service type and socket. 137 * 138 * @param cacheKey the target CacheKey. 139 * @return the set of services which matches the given service type. 140 */ 141 @NonNull getCachedServices(@onNull CacheKey cacheKey)142 public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) { 143 ensureRunningOnHandlerThread(mHandler); 144 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 145 maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); 146 } 147 return mCachedServices.containsKey(cacheKey) 148 ? Collections.unmodifiableList( 149 cachedServicesToResponses(mCachedServices.get(cacheKey))) 150 : Collections.emptyList(); 151 } 152 153 /** 154 * Find a matched response for given service name 155 * 156 * @param responses the responses to be searched. 157 * @param serviceName the target service name 158 * @return the response which matches the given service name or null if not found. 159 */ findMatchedResponse(@onNull List<MdnsResponse> responses, @NonNull String serviceName)160 public static MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses, 161 @NonNull String serviceName) { 162 for (MdnsResponse response : responses) { 163 if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { 164 return response; 165 } 166 } 167 return null; 168 } 169 findMatchedCachedService( @onNull List<CachedService> cachedServices, @NonNull String serviceName)170 private static CachedService findMatchedCachedService( 171 @NonNull List<CachedService> cachedServices, @NonNull String serviceName) { 172 for (CachedService cachedService : cachedServices) { 173 if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) { 174 return cachedService; 175 } 176 } 177 return null; 178 } 179 180 /** 181 * Get the cache service. 182 * 183 * @param serviceName the target service name. 184 * @param cacheKey the target CacheKey. 185 * @return the service which matches given conditions. 186 */ 187 @Nullable getCachedService(@onNull String serviceName, @NonNull CacheKey cacheKey)188 public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { 189 ensureRunningOnHandlerThread(mHandler); 190 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 191 maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); 192 } 193 final List<CachedService> cachedServices = mCachedServices.get(cacheKey); 194 if (cachedServices == null) { 195 return null; 196 } 197 final CachedService cachedService = findMatchedCachedService(cachedServices, serviceName); 198 return cachedService != null ? new MdnsResponse(cachedService.mService) : null; 199 } 200 insertServiceAndSortList( List<CachedService> cachedServices, CachedService cachedService, long now)201 static void insertServiceAndSortList( 202 List<CachedService> cachedServices, CachedService cachedService, long now) { 203 // binarySearch returns "the index of the search key, if it is contained in the list; 204 // otherwise, (-(insertion point) - 1)" 205 final int searchRes = Collections.binarySearch(cachedServices, cachedService, 206 // Sort the list by ttl. 207 (o1, o2) -> Long.compare(o1.mService.getMinRemainingTtl(now), 208 o2.mService.getMinRemainingTtl(now))); 209 cachedServices.add(searchRes >= 0 ? searchRes : (-searchRes - 1), cachedService); 210 } 211 212 /** 213 * Add or update a service. 214 * 215 * @param cacheKey the target CacheKey. 216 * @param response the response of the discovered service. 217 */ addOrUpdateService(@onNull CacheKey cacheKey, @NonNull MdnsResponse response)218 public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) { 219 ensureRunningOnHandlerThread(mHandler); 220 final List<CachedService> cachedServices = mCachedServices.computeIfAbsent( 221 cacheKey, key -> new ArrayList<>()); 222 // Remove existing service if present. 223 final CachedService existing = findMatchedCachedService(cachedServices, 224 response.getServiceInstanceName()); 225 cachedServices.remove(existing); 226 227 final CachedService cachedService = new CachedService(response); 228 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 229 final long now = mClock.elapsedRealtime(); 230 // Insert and sort service 231 insertServiceAndSortList(cachedServices, cachedService, now); 232 // Update the next expiration check time when a new service is added. 233 mNextExpirationTime = getNextExpirationTime(now); 234 } else { 235 cachedServices.add(cachedService); 236 } 237 } 238 239 /** 240 * Remove a service which matches the given service name, type and socket. 241 * 242 * @param serviceName the target service name. 243 * @param cacheKey the target CacheKey. 244 */ 245 @Nullable removeService(@onNull String serviceName, @NonNull CacheKey cacheKey)246 public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { 247 ensureRunningOnHandlerThread(mHandler); 248 final List<CachedService> cachedServices = mCachedServices.get(cacheKey); 249 if (cachedServices == null) { 250 return null; 251 } 252 final Iterator<CachedService> iterator = cachedServices.iterator(); 253 CachedService removedService = null; 254 while (iterator.hasNext()) { 255 final CachedService cachedService = iterator.next(); 256 if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) { 257 iterator.remove(); 258 removedService = cachedService; 259 break; 260 } 261 } 262 263 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 264 // Remove the serviceType if no response. 265 if (cachedServices.isEmpty()) { 266 mCachedServices.remove(cacheKey); 267 } 268 // Update the next expiration check time when a service is removed. 269 mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime()); 270 } 271 return removedService == null ? null : removedService.mService; 272 } 273 274 /** 275 * Remove services which matches the given type and socket. 276 * 277 * @param cacheKey the target CacheKey. 278 */ removeServices(@onNull CacheKey cacheKey)279 public void removeServices(@NonNull CacheKey cacheKey) { 280 ensureRunningOnHandlerThread(mHandler); 281 // Remove all services 282 if (mCachedServices.remove(cacheKey) == null) { 283 return; 284 } 285 // Update the next expiration check time if services are removed. 286 mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime()); 287 } 288 289 /** 290 * Register a callback to listen to service expiration. 291 * 292 * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient 293 * relies on this. 294 * 295 * @param cacheKey the target CacheKey. 296 * @param callback the callback that notify the service is expired. 297 */ registerServiceExpiredCallback(@onNull CacheKey cacheKey, @NonNull ServiceExpiredCallback callback)298 public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey, 299 @NonNull ServiceExpiredCallback callback) { 300 ensureRunningOnHandlerThread(mHandler); 301 mCallbacks.put(cacheKey, callback); 302 } 303 304 /** 305 * Unregister the service expired callback. 306 * 307 * @param cacheKey the CacheKey that is registered to listen service expiration before. 308 */ unregisterServiceExpiredCallback(@onNull CacheKey cacheKey)309 public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) { 310 ensureRunningOnHandlerThread(mHandler); 311 mCallbacks.remove(cacheKey); 312 } 313 notifyServiceExpired(@onNull CacheKey cacheKey, @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)314 private void notifyServiceExpired(@NonNull CacheKey cacheKey, 315 @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) { 316 final ServiceExpiredCallback callback = mCallbacks.get(cacheKey); 317 if (callback == null) { 318 // The cached service is no listener. 319 return; 320 } 321 mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse)); 322 } 323 removeExpiredServices(@onNull List<CachedService> cachedServices, long now)324 static List<CachedService> removeExpiredServices(@NonNull List<CachedService> cachedServices, 325 long now) { 326 final List<CachedService> removedServices = new ArrayList<>(); 327 final Iterator<CachedService> iterator = cachedServices.iterator(); 328 while (iterator.hasNext()) { 329 final CachedService cachedService = iterator.next(); 330 // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's 331 // expired. Then send service update notification. 332 if (!cachedService.mService.hasServiceRecord() 333 || cachedService.mService.getMinRemainingTtl(now) > 0) { 334 // The responses are sorted by the service record ttl time. Break out of loop 335 // early if service is not expired or no service record. 336 break; 337 } 338 // Remove the ttl expired service. 339 iterator.remove(); 340 removedServices.add(cachedService); 341 } 342 return removedServices; 343 } 344 getNextExpirationTime(long now)345 private long getNextExpirationTime(long now) { 346 if (mCachedServices.isEmpty()) { 347 return EXPIRATION_NEVER; 348 } 349 350 long minRemainingTtl = EXPIRATION_NEVER; 351 for (int i = 0; i < mCachedServices.size(); i++) { 352 minRemainingTtl = min(minRemainingTtl, 353 // The empty lists are not kept in the map, so there's always at least one 354 // element in the list. Therefore, it's fine to get the first element without a 355 // null check. 356 mCachedServices.valueAt(i).get(0).mService.getMinRemainingTtl(now)); 357 } 358 return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl; 359 } 360 361 /** 362 * Check whether the ttl time is expired on each service and notify to the listeners 363 */ maybeRemoveExpiredServices(CacheKey cacheKey, long now)364 private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) { 365 ensureRunningOnHandlerThread(mHandler); 366 if (now < mNextExpirationTime) { 367 // Skip the check if ttl time is not expired. 368 return; 369 } 370 371 final List<CachedService> cachedServices = mCachedServices.get(cacheKey); 372 if (cachedServices == null) { 373 // No such services. 374 return; 375 } 376 377 final List<CachedService> removedServices = removeExpiredServices(cachedServices, now); 378 if (removedServices.isEmpty()) { 379 // No expired services. 380 return; 381 } 382 383 for (CachedService previousService : removedServices) { 384 notifyServiceExpired(cacheKey, previousService.mService, null /* newResponse */); 385 } 386 387 // Remove the serviceType if no response. 388 if (cachedServices.isEmpty()) { 389 mCachedServices.remove(cacheKey); 390 } 391 392 // Update next expiration time. 393 mNextExpirationTime = getNextExpirationTime(now); 394 } 395 396 /** 397 * Dump ServiceCache state. 398 */ dump(PrintWriter pw, String indent)399 public void dump(PrintWriter pw, String indent) { 400 ensureRunningOnHandlerThread(mHandler); 401 // IndentingPrintWriter cannot be used on the mDNS stack build. So, manually add an indent. 402 for (int i = 0; i < mCachedServices.size(); i++) { 403 final CacheKey key = mCachedServices.keyAt(i); 404 pw.println(indent + key); 405 for (CachedService cachedService : mCachedServices.valueAt(i)) { 406 pw.println(indent + " Response{ " + cachedService.mService 407 + " } Expired=" + cachedService.mServiceExpired); 408 } 409 pw.println(); 410 } 411 } 412 413 /*** Callbacks for listening service expiration */ 414 public interface ServiceExpiredCallback { 415 /*** Notify the service is expired */ onServiceRecordExpired(@onNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)416 void onServiceRecordExpired(@NonNull MdnsResponse previousResponse, 417 @Nullable MdnsResponse newResponse); 418 } 419 420 // TODO: Schedule a job to check ttl expiration for all services and notify to the clients. 421 } 422