1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.service.topics.classifier; 18 19 import static java.util.stream.Collectors.toMap; 20 import static java.util.stream.Collectors.toSet; 21 22 import android.annotation.NonNull; 23 import android.content.Context; 24 import android.os.Build; 25 26 import androidx.annotation.RequiresApi; 27 28 import com.android.adservices.LoggerFactory; 29 import com.android.adservices.data.topics.Topic; 30 import com.android.adservices.service.Flags; 31 import com.android.adservices.service.Flags.ClassifierType; 32 import com.android.adservices.service.FlagsFactory; 33 import com.android.adservices.service.stats.AdServicesLoggerImpl; 34 import com.android.adservices.service.topics.CacheManager; 35 import com.android.internal.annotations.VisibleForTesting; 36 37 import com.google.common.base.Supplier; 38 import com.google.common.base.Suppliers; 39 40 import java.util.List; 41 import java.util.Map; 42 import java.util.Map.Entry; 43 import java.util.Random; 44 import java.util.Set; 45 import java.util.stream.Stream; 46 47 /** 48 * Manager class to control the classifier behaviour between available types of classifier based on 49 * classifier flags. 50 */ 51 // TODO(b/269798827): Enable for R. 52 @RequiresApi(Build.VERSION_CODES.S) 53 public class ClassifierManager implements Classifier { 54 private static final LoggerFactory.Logger sLogger = LoggerFactory.getTopicsLogger(); 55 private static ClassifierManager sSingleton; 56 57 private Supplier<OnDeviceClassifier> mOnDeviceClassifier; 58 private Supplier<PrecomputedClassifier> mPrecomputedClassifier; 59 60 @VisibleForTesting ClassifierManager( @onNull Supplier<OnDeviceClassifier> onDeviceClassifier, @NonNull Supplier<PrecomputedClassifier> precomputedClassifier)61 ClassifierManager( 62 @NonNull Supplier<OnDeviceClassifier> onDeviceClassifier, 63 @NonNull Supplier<PrecomputedClassifier> precomputedClassifier) { 64 mOnDeviceClassifier = onDeviceClassifier; 65 mPrecomputedClassifier = precomputedClassifier; 66 } 67 68 /** Returns the singleton instance of the {@link ClassifierManager} given a context. */ 69 @NonNull getInstance(@onNull Context context)70 public static ClassifierManager getInstance(@NonNull Context context) { 71 synchronized (ClassifierManager.class) { 72 if (sSingleton == null) { 73 // Note: we need to have a singleton ModelManager shared by both Classifiers. 74 sSingleton = 75 new ClassifierManager( 76 Suppliers.memoize( 77 () -> 78 new OnDeviceClassifier( 79 new Random(), 80 ModelManager.getInstance(context), 81 CacheManager.getInstance(context), 82 ClassifierInputManager.getInstance(context), 83 AdServicesLoggerImpl.getInstance())), 84 Suppliers.memoize( 85 () -> 86 new PrecomputedClassifier( 87 ModelManager.getInstance(context), 88 CacheManager.getInstance(context), 89 AdServicesLoggerImpl.getInstance()))); 90 } 91 } 92 return sSingleton; 93 } 94 95 /** 96 * {@inheritDoc} 97 * 98 * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values. 99 */ 100 @Override classify(Set<String> apps)101 public Map<String, List<Topic>> classify(Set<String> apps) { 102 Flags flags = FlagsFactory.getFlags(); 103 104 if (flags.getTopicsOnDeviceClassifierKillSwitch()) { 105 sLogger.v( 106 "On-device classifier disabled via topics on device classifier kill switch - " 107 + "falling back to precomputed classifier"); 108 return mPrecomputedClassifier.get().classify(apps); 109 } 110 111 @ClassifierType int classifierTypeFlag = flags.getClassifierType(); 112 if (classifierTypeFlag == Flags.PRECOMPUTED_CLASSIFIER) { 113 sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_CLASSIFIER"); 114 return mPrecomputedClassifier.get().classify(apps); 115 } else if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) { 116 sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = ON_DEVICE_CLASSIFIER"); 117 return mOnDeviceClassifier.get().classify(apps); 118 } else { 119 sLogger.v( 120 "ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_THEN_ON_DEVICE"); 121 // PRECOMPUTED_THEN_ON_DEVICE 122 // Default if classifierTypeFlag value is not set/invalid. 123 // precomputedClassifications expects non-empty values. 124 Map<String, List<Topic>> precomputedClassifications = 125 mPrecomputedClassifier.get().classify(apps); 126 // Collect package names that do not have any topics in the precomputed list. 127 Set<String> remainingApps = 128 apps.stream() 129 .filter( 130 packageName -> 131 !isValidValue(packageName, precomputedClassifications)) 132 .collect(toSet()); 133 Map<String, List<Topic>> onDeviceClassifications = 134 mOnDeviceClassifier.get().classify(remainingApps); 135 136 // Combine classification values. On device classifications are used for values that 137 // do not have valid precomputed classifications. 138 Map<String, List<Topic>> combinedClassifications = 139 Stream.concat( 140 onDeviceClassifications.entrySet().stream(), 141 precomputedClassifications.entrySet().stream()) 142 .collect( 143 toMap( 144 Entry::getKey, 145 Entry::getValue, 146 ClassifierManager::combineTopics)); 147 return combinedClassifications; 148 } 149 } 150 151 /** 152 * {@inheritDoc} 153 * 154 * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values. 155 */ 156 @Override getTopTopics( Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics)157 public List<Topic> getTopTopics( 158 Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics) { 159 Flags flags = FlagsFactory.getFlags(); 160 161 if (flags.getTopicsOnDeviceClassifierKillSwitch()) { 162 sLogger.v( 163 "On-device classifier disabled via topics on device classifier kill switch - " 164 + "falling back to precomputed classifier"); 165 return mPrecomputedClassifier 166 .get() 167 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 168 } 169 170 @ClassifierType int classifierTypeFlag = flags.getClassifierType(); 171 // getTopTopics has the same implementation. 172 // If the loaded assets are same, the output will be same. 173 if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) { 174 return mOnDeviceClassifier 175 .get() 176 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 177 } else { 178 // Use getTopics from PrecomputedClassifier as default. 179 return mPrecomputedClassifier 180 .get() 181 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 182 } 183 } 184 185 // Prefer precomputed values for topics if the list is not empty. combineTopics( List<Topic> onDeviceValue, List<Topic> precomputedValue)186 private static List<Topic> combineTopics( 187 List<Topic> onDeviceValue, List<Topic> precomputedValue) { 188 if (!precomputedValue.isEmpty()) { 189 return precomputedValue; 190 } 191 return onDeviceValue; 192 } 193 194 // Return true if package name has non-empty list of topics in the classifications. isValidValue(String packageName, Map<String, List<Topic>> classifications)195 private boolean isValidValue(String packageName, Map<String, List<Topic>> classifications) { 196 if (classifications.containsKey(packageName) 197 && !classifications.get(packageName).isEmpty()) { 198 return true; 199 } 200 return false; 201 } 202 } 203