• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.adservices.service.topics.classifier;
18 
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

import android.annotation.NonNull;
import android.content.Context;
import android.os.Build;

import androidx.annotation.RequiresApi;

import com.android.adservices.LoggerFactory;
import com.android.adservices.data.topics.Topic;
import com.android.adservices.service.Flags;
import com.android.adservices.service.Flags.ClassifierType;
import com.android.adservices.service.FlagsFactory;
import com.android.adservices.service.stats.AdServicesLoggerImpl;
import com.android.adservices.service.topics.CacheManager;
import com.android.internal.annotations.VisibleForTesting;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.stream.Stream;
46 
47 /**
48  * Manager class to control the classifier behaviour between available types of classifier based on
49  * classifier flags.
50  */
51 // TODO(b/269798827): Enable for R.
52 @RequiresApi(Build.VERSION_CODES.S)
53 public class ClassifierManager implements Classifier {
54     private static final LoggerFactory.Logger sLogger = LoggerFactory.getTopicsLogger();
55     private static ClassifierManager sSingleton;
56 
57     private Supplier<OnDeviceClassifier> mOnDeviceClassifier;
58     private Supplier<PrecomputedClassifier> mPrecomputedClassifier;
59 
60     @VisibleForTesting
ClassifierManager( @onNull Supplier<OnDeviceClassifier> onDeviceClassifier, @NonNull Supplier<PrecomputedClassifier> precomputedClassifier)61     ClassifierManager(
62             @NonNull Supplier<OnDeviceClassifier> onDeviceClassifier,
63             @NonNull Supplier<PrecomputedClassifier> precomputedClassifier) {
64         mOnDeviceClassifier = onDeviceClassifier;
65         mPrecomputedClassifier = precomputedClassifier;
66     }
67 
68     /** Returns the singleton instance of the {@link ClassifierManager} given a context. */
69     @NonNull
getInstance(@onNull Context context)70     public static ClassifierManager getInstance(@NonNull Context context) {
71         synchronized (ClassifierManager.class) {
72             if (sSingleton == null) {
73                 // Note: we need to have a singleton ModelManager shared by both Classifiers.
74                 sSingleton =
75                         new ClassifierManager(
76                                 Suppliers.memoize(
77                                         () ->
78                                                 new OnDeviceClassifier(
79                                                         new Random(),
80                                                         ModelManager.getInstance(context),
81                                                         CacheManager.getInstance(context),
82                                                         ClassifierInputManager.getInstance(context),
83                                                         AdServicesLoggerImpl.getInstance())),
84                                 Suppliers.memoize(
85                                         () ->
86                                                 new PrecomputedClassifier(
87                                                         ModelManager.getInstance(context),
88                                                         CacheManager.getInstance(context),
89                                                         AdServicesLoggerImpl.getInstance())));
90             }
91         }
92         return sSingleton;
93     }
94 
95     /**
96      * {@inheritDoc}
97      *
98      * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values.
99      */
100     @Override
classify(Set<String> apps)101     public Map<String, List<Topic>> classify(Set<String> apps) {
102         Flags flags = FlagsFactory.getFlags();
103 
104         if (flags.getTopicsOnDeviceClassifierKillSwitch()) {
105             sLogger.v(
106                     "On-device classifier disabled via topics on device classifier kill switch - "
107                             + "falling back to precomputed classifier");
108             return mPrecomputedClassifier.get().classify(apps);
109         }
110 
111         @ClassifierType int classifierTypeFlag = flags.getClassifierType();
112         if (classifierTypeFlag == Flags.PRECOMPUTED_CLASSIFIER) {
113             sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_CLASSIFIER");
114             return mPrecomputedClassifier.get().classify(apps);
115         } else if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) {
116             sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = ON_DEVICE_CLASSIFIER");
117             return mOnDeviceClassifier.get().classify(apps);
118         } else {
119             sLogger.v(
120                     "ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_THEN_ON_DEVICE");
121             // PRECOMPUTED_THEN_ON_DEVICE
122             // Default if classifierTypeFlag value is not set/invalid.
123             // precomputedClassifications expects non-empty values.
124             Map<String, List<Topic>> precomputedClassifications =
125                     mPrecomputedClassifier.get().classify(apps);
126             // Collect package names that do not have any topics in the precomputed list.
127             Set<String> remainingApps =
128                     apps.stream()
129                             .filter(
130                                     packageName ->
131                                             !isValidValue(packageName, precomputedClassifications))
132                             .collect(toSet());
133             Map<String, List<Topic>> onDeviceClassifications =
134                     mOnDeviceClassifier.get().classify(remainingApps);
135 
136             // Combine classification values. On device classifications are used for values that
137             // do not have valid precomputed classifications.
138             Map<String, List<Topic>> combinedClassifications =
139                     Stream.concat(
140                                     onDeviceClassifications.entrySet().stream(),
141                                     precomputedClassifications.entrySet().stream())
142                             .collect(
143                                     toMap(
144                                             Entry::getKey,
145                                             Entry::getValue,
146                                             ClassifierManager::combineTopics));
147             return combinedClassifications;
148         }
149     }
150 
151     /**
152      * {@inheritDoc}
153      *
154      * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values.
155      */
156     @Override
getTopTopics( Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics)157     public List<Topic> getTopTopics(
158             Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics) {
159         Flags flags = FlagsFactory.getFlags();
160 
161         if (flags.getTopicsOnDeviceClassifierKillSwitch()) {
162             sLogger.v(
163                     "On-device classifier disabled via topics on device classifier kill switch - "
164                             + "falling back to precomputed classifier");
165             return mPrecomputedClassifier
166                     .get()
167                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
168         }
169 
170         @ClassifierType int classifierTypeFlag = flags.getClassifierType();
171         // getTopTopics has the same implementation.
172         // If the loaded assets are same, the output will be same.
173         if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) {
174             return mOnDeviceClassifier
175                     .get()
176                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
177         } else {
178             // Use getTopics from PrecomputedClassifier as default.
179             return mPrecomputedClassifier
180                     .get()
181                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
182         }
183     }
184 
185     // Prefer precomputed values for topics if the list is not empty.
combineTopics( List<Topic> onDeviceValue, List<Topic> precomputedValue)186     private static List<Topic> combineTopics(
187             List<Topic> onDeviceValue, List<Topic> precomputedValue) {
188         if (!precomputedValue.isEmpty()) {
189             return precomputedValue;
190         }
191         return onDeviceValue;
192     }
193 
194     // Return true if package name has non-empty list of topics in the classifications.
isValidValue(String packageName, Map<String, List<Topic>> classifications)195     private boolean isValidValue(String packageName, Map<String, List<Topic>> classifications) {
196         if (classifications.containsKey(packageName)
197                 && !classifications.get(packageName).isEmpty()) {
198             return true;
199         }
200         return false;
201     }
202 }
203