• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import static com.android.net.module.util.DnsUtils.equalsIgnoreDnsCase;
20 import static com.android.net.module.util.DnsUtils.toDnsUpperCase;
21 import static com.android.net.module.util.HandlerUtils.ensureRunningOnHandlerThread;
22 import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER;
23 
24 import static java.lang.Math.min;
25 
26 import android.annotation.NonNull;
27 import android.annotation.Nullable;
28 import android.os.Handler;
29 import android.os.Looper;
30 import android.util.ArrayMap;
31 
32 import com.android.internal.annotations.VisibleForTesting;
33 import com.android.server.connectivity.mdns.util.MdnsUtils;
34 
35 import java.io.PrintWriter;
36 import java.util.ArrayList;
37 import java.util.Collections;
38 import java.util.Iterator;
39 import java.util.List;
40 import java.util.Objects;
41 
/**
 * The {@link MdnsServiceCache} manages the services discovered on each socket and caches them
 * to reduce duplicated queries.
 *
 * <p>This class is not thread safe; it is intended to be used only from the looper thread.
 *  The constructor is the one exception, as it is called on another thread;
 *  therefore, for thread safety, all members of this class MUST either be final or be
 *  initialized to their default value (0, false or null).
 */
51 public class MdnsServiceCache {
52     public static class CacheKey {
53         @NonNull final String mUpperCaseServiceType;
54         @NonNull final SocketKey mSocketKey;
55 
CacheKey(@onNull String serviceType, @NonNull SocketKey socketKey)56         CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) {
57             mUpperCaseServiceType = toDnsUpperCase(serviceType);
58             mSocketKey = socketKey;
59         }
60 
hashCode()61         @Override public int hashCode() {
62             return Objects.hash(mUpperCaseServiceType, mSocketKey);
63         }
64 
equals(Object other)65         @Override public boolean equals(Object other) {
66             if (this == other) {
67                 return true;
68             }
69             if (!(other instanceof CacheKey)) {
70                 return false;
71             }
72             return Objects.equals(mUpperCaseServiceType, ((CacheKey) other).mUpperCaseServiceType)
73                     && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey);
74         }
75 
76         @Override
toString()77         public String toString() {
78             return "CacheKey{ ServiceType=" + mUpperCaseServiceType + ", " + mSocketKey + " }";
79         }
80     }
81 
82     public static class CachedService {
83         @NonNull final MdnsResponse mService;
84         boolean mServiceExpired;
85 
CachedService(MdnsResponse service)86         CachedService(MdnsResponse service) {
87             mService = service;
88             mServiceExpired = false;
89         }
90     }
91 
92     /**
93      * A map of cached services. Key is composed of service type and socket. Value is the list of
94      * services which are discovered from the given CacheKey.
95      * When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted
96      * by expiration time, with the earliest entries appearing first. This sorting allows the
97      * removal process to progress through the expiration check efficiently.
98      */
99     @NonNull
100     private final ArrayMap<CacheKey, List<CachedService>> mCachedServices = new ArrayMap<>();
101     /**
102      * A map of service expire callbacks. Key is composed of service type and socket and value is
103      * the callback listener.
104      */
105     @NonNull
106     private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
107     @NonNull
108     private final Handler mHandler;
109     @NonNull
110     private final MdnsFeatureFlags mMdnsFeatureFlags;
111     @NonNull
112     private final MdnsUtils.Clock mClock;
113     private long mNextExpirationTime = EXPIRATION_NEVER;
114 
MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags)115     public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
116         this(looper, mdnsFeatureFlags, new MdnsUtils.Clock());
117     }
118 
119     @VisibleForTesting
MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, @NonNull MdnsUtils.Clock clock)120     MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags,
121             @NonNull MdnsUtils.Clock clock) {
122         mHandler = new Handler(looper);
123         mMdnsFeatureFlags = mdnsFeatureFlags;
124         mClock = clock;
125     }
126 
cachedServicesToResponses(List<CachedService> cachedServices)127     private List<MdnsResponse> cachedServicesToResponses(List<CachedService> cachedServices) {
128         final List<MdnsResponse> responses = new ArrayList<>();
129         for (CachedService cachedService : cachedServices) {
130             responses.add(cachedService.mService);
131         }
132         return responses;
133     }
134 
135     /**
136      * Get the cache services which are queried from given service type and socket.
137      *
138      * @param cacheKey the target CacheKey.
139      * @return the set of services which matches the given service type.
140      */
141     @NonNull
getCachedServices(@onNull CacheKey cacheKey)142     public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
143         ensureRunningOnHandlerThread(mHandler);
144         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
145             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
146         }
147         return mCachedServices.containsKey(cacheKey)
148                 ? Collections.unmodifiableList(
149                         cachedServicesToResponses(mCachedServices.get(cacheKey)))
150                 : Collections.emptyList();
151     }
152 
153     /**
154      * Find a matched response for given service name
155      *
156      * @param responses the responses to be searched.
157      * @param serviceName the target service name
158      * @return the response which matches the given service name or null if not found.
159      */
findMatchedResponse(@onNull List<MdnsResponse> responses, @NonNull String serviceName)160     public static MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
161             @NonNull String serviceName) {
162         for (MdnsResponse response : responses) {
163             if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
164                 return response;
165             }
166         }
167         return null;
168     }
169 
findMatchedCachedService( @onNull List<CachedService> cachedServices, @NonNull String serviceName)170     private static CachedService findMatchedCachedService(
171             @NonNull List<CachedService> cachedServices, @NonNull String serviceName) {
172         for (CachedService cachedService : cachedServices) {
173             if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
174                 return cachedService;
175             }
176         }
177         return null;
178     }
179 
180     /**
181      * Get the cache service.
182      *
183      * @param serviceName the target service name.
184      * @param cacheKey the target CacheKey.
185      * @return the service which matches given conditions.
186      */
187     @Nullable
getCachedService(@onNull String serviceName, @NonNull CacheKey cacheKey)188     public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
189         ensureRunningOnHandlerThread(mHandler);
190         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
191             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
192         }
193         final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
194         if (cachedServices == null) {
195             return null;
196         }
197         final CachedService cachedService = findMatchedCachedService(cachedServices, serviceName);
198         return cachedService != null ? new MdnsResponse(cachedService.mService) : null;
199     }
200 
insertServiceAndSortList( List<CachedService> cachedServices, CachedService cachedService, long now)201     static void insertServiceAndSortList(
202             List<CachedService> cachedServices, CachedService cachedService, long now) {
203         // binarySearch returns "the index of the search key, if it is contained in the list;
204         // otherwise, (-(insertion point) - 1)"
205         final int searchRes = Collections.binarySearch(cachedServices, cachedService,
206                 // Sort the list by ttl.
207                 (o1, o2) -> Long.compare(o1.mService.getMinRemainingTtl(now),
208                         o2.mService.getMinRemainingTtl(now)));
209         cachedServices.add(searchRes >= 0 ? searchRes : (-searchRes - 1), cachedService);
210     }
211 
212     /**
213      * Add or update a service.
214      *
215      * @param cacheKey the target CacheKey.
216      * @param response the response of the discovered service.
217      */
addOrUpdateService(@onNull CacheKey cacheKey, @NonNull MdnsResponse response)218     public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
219         ensureRunningOnHandlerThread(mHandler);
220         final List<CachedService> cachedServices = mCachedServices.computeIfAbsent(
221                 cacheKey, key -> new ArrayList<>());
222         // Remove existing service if present.
223         final CachedService existing = findMatchedCachedService(cachedServices,
224                 response.getServiceInstanceName());
225         cachedServices.remove(existing);
226 
227         final CachedService cachedService = new CachedService(response);
228         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
229             final long now = mClock.elapsedRealtime();
230             // Insert and sort service
231             insertServiceAndSortList(cachedServices, cachedService, now);
232             // Update the next expiration check time when a new service is added.
233             mNextExpirationTime = getNextExpirationTime(now);
234         } else {
235             cachedServices.add(cachedService);
236         }
237     }
238 
239     /**
240      * Remove a service which matches the given service name, type and socket.
241      *
242      * @param serviceName the target service name.
243      * @param cacheKey the target CacheKey.
244      */
245     @Nullable
removeService(@onNull String serviceName, @NonNull CacheKey cacheKey)246     public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
247         ensureRunningOnHandlerThread(mHandler);
248         final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
249         if (cachedServices == null) {
250             return null;
251         }
252         final Iterator<CachedService> iterator = cachedServices.iterator();
253         CachedService removedService = null;
254         while (iterator.hasNext()) {
255             final CachedService cachedService = iterator.next();
256             if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
257                 iterator.remove();
258                 removedService = cachedService;
259                 break;
260             }
261         }
262 
263         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
264             // Remove the serviceType if no response.
265             if (cachedServices.isEmpty()) {
266                 mCachedServices.remove(cacheKey);
267             }
268             // Update the next expiration check time when a service is removed.
269             mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
270         }
271         return removedService == null ? null : removedService.mService;
272     }
273 
274     /**
275      * Remove services which matches the given type and socket.
276      *
277      * @param cacheKey the target CacheKey.
278      */
removeServices(@onNull CacheKey cacheKey)279     public void removeServices(@NonNull CacheKey cacheKey) {
280         ensureRunningOnHandlerThread(mHandler);
281         // Remove all services
282         if (mCachedServices.remove(cacheKey) == null) {
283             return;
284         }
285         // Update the next expiration check time if services are removed.
286         mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
287     }
288 
289     /**
290      * Register a callback to listen to service expiration.
291      *
292      * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
293      * relies on this.
294      *
295      * @param cacheKey the target CacheKey.
296      * @param callback the callback that notify the service is expired.
297      */
registerServiceExpiredCallback(@onNull CacheKey cacheKey, @NonNull ServiceExpiredCallback callback)298     public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey,
299             @NonNull ServiceExpiredCallback callback) {
300         ensureRunningOnHandlerThread(mHandler);
301         mCallbacks.put(cacheKey, callback);
302     }
303 
304     /**
305      * Unregister the service expired callback.
306      *
307      * @param cacheKey the CacheKey that is registered to listen service expiration before.
308      */
unregisterServiceExpiredCallback(@onNull CacheKey cacheKey)309     public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) {
310         ensureRunningOnHandlerThread(mHandler);
311         mCallbacks.remove(cacheKey);
312     }
313 
notifyServiceExpired(@onNull CacheKey cacheKey, @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)314     private void notifyServiceExpired(@NonNull CacheKey cacheKey,
315             @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) {
316         final ServiceExpiredCallback callback = mCallbacks.get(cacheKey);
317         if (callback == null) {
318             // The cached service is no listener.
319             return;
320         }
321         mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
322     }
323 
removeExpiredServices(@onNull List<CachedService> cachedServices, long now)324     static List<CachedService> removeExpiredServices(@NonNull List<CachedService> cachedServices,
325             long now) {
326         final List<CachedService> removedServices = new ArrayList<>();
327         final Iterator<CachedService> iterator = cachedServices.iterator();
328         while (iterator.hasNext()) {
329             final CachedService cachedService = iterator.next();
330             // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
331             //  expired. Then send service update notification.
332             if (!cachedService.mService.hasServiceRecord()
333                     || cachedService.mService.getMinRemainingTtl(now) > 0) {
334                 // The responses are sorted by the service record ttl time. Break out of loop
335                 // early if service is not expired or no service record.
336                 break;
337             }
338             // Remove the ttl expired service.
339             iterator.remove();
340             removedServices.add(cachedService);
341         }
342         return removedServices;
343     }
344 
getNextExpirationTime(long now)345     private long getNextExpirationTime(long now) {
346         if (mCachedServices.isEmpty()) {
347             return EXPIRATION_NEVER;
348         }
349 
350         long minRemainingTtl = EXPIRATION_NEVER;
351         for (int i = 0; i < mCachedServices.size(); i++) {
352             minRemainingTtl = min(minRemainingTtl,
353                     // The empty lists are not kept in the map, so there's always at least one
354                     // element in the list. Therefore, it's fine to get the first element without a
355                     // null check.
356                     mCachedServices.valueAt(i).get(0).mService.getMinRemainingTtl(now));
357         }
358         return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
359     }
360 
361     /**
362      * Check whether the ttl time is expired on each service and notify to the listeners
363      */
maybeRemoveExpiredServices(CacheKey cacheKey, long now)364     private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) {
365         ensureRunningOnHandlerThread(mHandler);
366         if (now < mNextExpirationTime) {
367             // Skip the check if ttl time is not expired.
368             return;
369         }
370 
371         final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
372         if (cachedServices == null) {
373             // No such services.
374             return;
375         }
376 
377         final List<CachedService> removedServices = removeExpiredServices(cachedServices, now);
378         if (removedServices.isEmpty()) {
379             // No expired services.
380             return;
381         }
382 
383         for (CachedService previousService : removedServices) {
384             notifyServiceExpired(cacheKey, previousService.mService, null /* newResponse */);
385         }
386 
387         // Remove the serviceType if no response.
388         if (cachedServices.isEmpty()) {
389             mCachedServices.remove(cacheKey);
390         }
391 
392         // Update next expiration time.
393         mNextExpirationTime = getNextExpirationTime(now);
394     }
395 
396     /**
397      * Dump ServiceCache state.
398      */
dump(PrintWriter pw, String indent)399     public void dump(PrintWriter pw, String indent) {
400         ensureRunningOnHandlerThread(mHandler);
401         // IndentingPrintWriter cannot be used on the mDNS stack build. So, manually add an indent.
402         for (int i = 0; i < mCachedServices.size(); i++) {
403             final CacheKey key = mCachedServices.keyAt(i);
404             pw.println(indent + key);
405             for (CachedService cachedService : mCachedServices.valueAt(i)) {
406                 pw.println(indent + "  Response{ " + cachedService.mService
407                         + " } Expired=" + cachedService.mServiceExpired);
408             }
409             pw.println();
410         }
411     }
412 
413     /*** Callbacks for listening service expiration */
414     public interface ServiceExpiredCallback {
415         /*** Notify the service is expired */
onServiceRecordExpired(@onNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)416         void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
417                 @Nullable MdnsResponse newResponse);
418     }
419 
420     // TODO: Schedule a job to check ttl expiration for all services and notify to the clients.
421 }
422