• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.textclassifier;
18 
19 import static java.lang.Math.min;
20 
21 import android.content.Context;
22 import android.content.res.AssetManager;
23 import android.os.LocaleList;
24 import androidx.annotation.GuardedBy;
25 import androidx.collection.ArrayMap;
26 import com.android.textclassifier.common.ModelFile;
27 import com.android.textclassifier.common.ModelType;
28 import com.android.textclassifier.common.ModelType.ModelTypeDef;
29 import com.android.textclassifier.common.TextClassifierSettings;
30 import com.android.textclassifier.common.base.TcLog;
31 import com.android.textclassifier.downloader.ModelDownloadManager;
32 import com.android.textclassifier.utils.IndentingPrintWriter;
33 import com.google.common.annotations.VisibleForTesting;
34 import com.google.common.base.Preconditions;
35 import com.google.common.base.Splitter;
36 import com.google.common.base.Supplier;
37 import com.google.common.collect.ImmutableList;
38 import java.io.File;
39 import java.io.IOException;
40 import java.util.List;
41 import java.util.Locale;
42 import java.util.Map;
43 import java.util.regex.Matcher;
44 import java.util.regex.Pattern;
45 import javax.annotation.Nullable;
46 
47 // TODO(licha): Consider making this a singleton class
48 // TODO(licha): Check whether this is thread-safe
49 /**
50  * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
51  * model files to load.
52  */
53 final class ModelFileManagerImpl implements ModelFileManager {
54 
55   private static final String TAG = "ModelFileManagerImpl";
56 
57   private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
58   private static final String ASSETS_DIR = "textclassifier";
59 
60   private ImmutableList<ModelFileLister> modelFileListers;
61 
62   private final TextClassifierSettings settings;
63 
ModelFileManagerImpl( Context context, ModelDownloadManager modelDownloadManager, TextClassifierSettings settings)64   public ModelFileManagerImpl(
65       Context context, ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
66 
67     Preconditions.checkNotNull(context);
68     Preconditions.checkNotNull(modelDownloadManager);
69 
70     this.settings = Preconditions.checkNotNull(settings);
71 
72     AssetManager assetManager = context.getAssets();
73     modelFileListers =
74         ImmutableList.of(
75             // Annotator models.
76             new RegularFileFullMatchLister(
77                 ModelType.ANNOTATOR,
78                 new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
79                 /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
80             new AssetFilePatternMatchLister(
81                 assetManager,
82                 ModelType.ANNOTATOR,
83                 ASSETS_DIR,
84                 "annotator\\.(.*)\\.model",
85                 /* isEnabled= */ () -> true),
86             // Actions models.
87             new RegularFileFullMatchLister(
88                 ModelType.ACTIONS_SUGGESTIONS,
89                 new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
90                 /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
91             new AssetFilePatternMatchLister(
92                 assetManager,
93                 ModelType.ACTIONS_SUGGESTIONS,
94                 ASSETS_DIR,
95                 "actions_suggestions\\.(.*)\\.model",
96                 /* isEnabled= */ () -> true),
97             // LangID models.
98             new RegularFileFullMatchLister(
99                 ModelType.LANG_ID,
100                 new File(CONFIG_UPDATER_DIR, "lang_id.model"),
101                 /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
102             new AssetFilePatternMatchLister(
103                 assetManager,
104                 ModelType.LANG_ID,
105                 ASSETS_DIR,
106                 "lang_id.model",
107                 /* isEnabled= */ () -> true),
108             new DownloaderModelsLister(modelDownloadManager, settings));
109   }
110 
111   @VisibleForTesting
ModelFileManagerImpl( Context context, List<ModelFileLister> modelFileListers, TextClassifierSettings settings)112   public ModelFileManagerImpl(
113       Context context, List<ModelFileLister> modelFileListers, TextClassifierSettings settings) {
114     this.modelFileListers = ImmutableList.copyOf(modelFileListers);
115     this.settings = settings;
116   }
117 
listModelFiles(@odelTypeDef String modelType)118   public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
119     Preconditions.checkNotNull(modelType);
120 
121     ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
122     for (ModelFileLister modelFileLister : modelFileListers) {
123       modelFiles.addAll(modelFileLister.list(modelType));
124     }
125     return modelFiles.build();
126   }
127 
128   /** Lists model files. */
129   @FunctionalInterface
130   public interface ModelFileLister {
list(@odelTypeDef String modelType)131     List<ModelFile> list(@ModelTypeDef String modelType);
132   }
133 
134   /** Lists Downloader models */
135   public static class DownloaderModelsLister implements ModelFileLister {
136 
137     private final ModelDownloadManager modelDownloadManager;
138     private final TextClassifierSettings settings;
139 
140     /**
141      * @param modelDownloadManager manager of downloaded models
142      * @param settings current settings
143      */
DownloaderModelsLister( ModelDownloadManager modelDownloadManager, TextClassifierSettings settings)144     public DownloaderModelsLister(
145         ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
146       this.modelDownloadManager = Preconditions.checkNotNull(modelDownloadManager);
147       this.settings = Preconditions.checkNotNull(settings);
148     }
149 
150     @Override
list(@odelTypeDef String modelType)151     public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
152       ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
153       if (settings.isModelDownloadManagerEnabled()) {
154         for (File modelFile : modelDownloadManager.listDownloadedModels(modelType)) {
155           try {
156             // TODO(licha): Construct downloader model files with locale tag in our internal
157             // database
158             modelFilesBuilder.add(ModelFile.createFromRegularFile(modelFile, modelType));
159           } catch (IOException e) {
160             TcLog.e(TAG, "Failed to create ModelFile: " + modelFile.getAbsolutePath(), e);
161           }
162         }
163       }
164       return modelFilesBuilder.build();
165     }
166   }
167 
168   /** Lists model files by performing full match on file path. */
169   public static class RegularFileFullMatchLister implements ModelFileLister {
170     private final String modelType;
171     private final File targetFile;
172     private final Supplier<Boolean> isEnabled;
173 
174     /**
175      * @param modelType the type of the model
176      * @param targetFile the expected model file
177      * @param isEnabled whether this lister is enabled
178      */
RegularFileFullMatchLister( @odelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled)179     public RegularFileFullMatchLister(
180         @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
181       this.modelType = Preconditions.checkNotNull(modelType);
182       this.targetFile = Preconditions.checkNotNull(targetFile);
183       this.isEnabled = Preconditions.checkNotNull(isEnabled);
184     }
185 
186     @Override
list(@odelTypeDef String modelType)187     public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
188       if (!this.modelType.equals(modelType)) {
189         return ImmutableList.of();
190       }
191       if (!isEnabled.get()) {
192         return ImmutableList.of();
193       }
194       if (!targetFile.exists()) {
195         return ImmutableList.of();
196       }
197       try {
198         return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
199       } catch (IOException e) {
200         TcLog.e(
201             TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
202       }
203       return ImmutableList.of();
204     }
205   }
206 
207   /** Lists model file in a specified folder by doing pattern matching on file names. */
208   public static class RegularFilePatternMatchLister implements ModelFileLister {
209     private final String modelType;
210     private final File folder;
211     private final Pattern fileNamePattern;
212     private final Supplier<Boolean> isEnabled;
213 
214     /**
215      * @param modelType the type of the model
216      * @param folder the folder to list files
217      * @param fileNameRegex the regex to match the file name in the specified folder
218      * @param isEnabled whether the lister is enabled
219      */
RegularFilePatternMatchLister( @odelTypeDef String modelType, File folder, String fileNameRegex, Supplier<Boolean> isEnabled)220     public RegularFilePatternMatchLister(
221         @ModelTypeDef String modelType,
222         File folder,
223         String fileNameRegex,
224         Supplier<Boolean> isEnabled) {
225       this.modelType = Preconditions.checkNotNull(modelType);
226       this.folder = Preconditions.checkNotNull(folder);
227       this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
228       this.isEnabled = Preconditions.checkNotNull(isEnabled);
229     }
230 
231     @Override
list(@odelTypeDef String modelType)232     public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
233       if (!this.modelType.equals(modelType)) {
234         return ImmutableList.of();
235       }
236       if (!isEnabled.get()) {
237         return ImmutableList.of();
238       }
239       if (!folder.isDirectory()) {
240         return ImmutableList.of();
241       }
242       File[] files = folder.listFiles();
243       if (files == null) {
244         return ImmutableList.of();
245       }
246       ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
247       for (File file : files) {
248         final Matcher matcher = fileNamePattern.matcher(file.getName());
249         if (!matcher.matches() || !file.isFile()) {
250           continue;
251         }
252         try {
253           modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
254         } catch (IOException e) {
255           TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
256         }
257       }
258       return modelFilesBuilder.build();
259     }
260   }
261 
262   /** Lists the model files preloaded in the APK file. */
263   public static class AssetFilePatternMatchLister implements ModelFileLister {
264     private final AssetManager assetManager;
265     private final String modelType;
266     private final String pathToList;
267     private final Pattern fileNamePattern;
268     private final Supplier<Boolean> isEnabled;
269     private final Object lock = new Object();
270     // Assets won't change without updating the app, so cache the result for performance reason.
271     @GuardedBy("lock")
272     private final Map<String, ImmutableList<ModelFile>> resultCache;
273 
274     /**
275      * @param modelType the type of the model.
276      * @param pathToList the folder to list files
277      * @param fileNameRegex the regex to match the file name in the specified folder
278      * @param isEnabled whether this lister is enabled
279      */
AssetFilePatternMatchLister( AssetManager assetManager, @ModelTypeDef String modelType, String pathToList, String fileNameRegex, Supplier<Boolean> isEnabled)280     public AssetFilePatternMatchLister(
281         AssetManager assetManager,
282         @ModelTypeDef String modelType,
283         String pathToList,
284         String fileNameRegex,
285         Supplier<Boolean> isEnabled) {
286       this.assetManager = Preconditions.checkNotNull(assetManager);
287       this.modelType = Preconditions.checkNotNull(modelType);
288       this.pathToList = Preconditions.checkNotNull(pathToList);
289       this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
290       this.isEnabled = Preconditions.checkNotNull(isEnabled);
291       resultCache = new ArrayMap<>();
292     }
293 
294     @Override
list(@odelTypeDef String modelType)295     public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
296       if (!this.modelType.equals(modelType)) {
297         return ImmutableList.of();
298       }
299       if (!isEnabled.get()) {
300         return ImmutableList.of();
301       }
302       synchronized (lock) {
303         if (resultCache.get(modelType) != null) {
304           return resultCache.get(modelType);
305         }
306         String[] fileNames = null;
307         try {
308           fileNames = assetManager.list(pathToList);
309         } catch (IOException e) {
310           TcLog.e(TAG, "Failed to list assets", e);
311         }
312         if (fileNames == null) {
313           return ImmutableList.of();
314         }
315         ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
316         for (String fileName : fileNames) {
317           final Matcher matcher = fileNamePattern.matcher(fileName);
318           if (!matcher.matches()) {
319             continue;
320           }
321           String absolutePath =
322               new StringBuilder(pathToList).append('/').append(fileName).toString();
323           try {
324             modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
325           } catch (IOException e) {
326             TcLog.e(TAG, "Failed to call createFromAsset with: " + absolutePath, e);
327           }
328         }
329         ImmutableList<ModelFile> result = modelFilesBuilder.build();
330         resultCache.put(modelType, result);
331         return result;
332       }
333     }
334   }
335 
336   /**
337    * Returns the best locale matching the given detected locales and the default device localelist.
338    * Default locale returned if no matching locale is found.
339    *
340    * @param localePreferences list of optional locale preferences. Used if request contains
341    *     preference and multi_language_support is disabled.
342    * @param detectedLocales ordered list of locales detected from Tcs request text, use {@code null}
343    *     if no detected locales provided.
344    */
findBestModelLocale( @ullable LocaleList localePreferences, @Nullable LocaleList detectedLocales)345   public Locale findBestModelLocale(
346       @Nullable LocaleList localePreferences, @Nullable LocaleList detectedLocales) {
347     if (!settings.isMultiLanguageSupportEnabled() || isEmptyLocaleList(detectedLocales)) {
348       return isEmptyLocaleList(localePreferences) ? Locale.getDefault() : localePreferences.get(0);
349     }
350     Locale bestLocale = Locale.getDefault();
351     LocaleList adjustedLocales = LocaleList.getAdjustedDefault();
352     // we only intersect detected locales with locales for which we have predownloaded models.
353     // Number of downlaoded locale models is determined by flag in tcs settings
354     int numberOfActiveModels = min(adjustedLocales.size(), settings.getMultiLanguageModelsLimit());
355     List<String> filteredDeviceLocales =
356         Splitter.on(",")
357             .splitToList(adjustedLocales.toLanguageTags())
358             .subList(0, numberOfActiveModels);
359     LocaleList filteredDeviceLocaleList =
360         LocaleList.forLanguageTags(String.join(",", filteredDeviceLocales));
361     List<Locale.LanguageRange> deviceLanguageRange =
362         Locale.LanguageRange.parse(filteredDeviceLocaleList.toLanguageTags());
363     for (int i = 0; i < detectedLocales.size(); i++) {
364       if (Locale.lookupTag(
365               deviceLanguageRange, ImmutableList.of(detectedLocales.get(i).getLanguage()))
366           != null) {
367         bestLocale = detectedLocales.get(i);
368         break;
369       }
370     }
371     return bestLocale;
372   }
373 
374   @Nullable
375   @Override
findBestModelFile( @odelTypeDef String modelType, @Nullable LocaleList localePreferences, @Nullable LocaleList detectedLocales)376   public ModelFile findBestModelFile(
377       @ModelTypeDef String modelType,
378       @Nullable LocaleList localePreferences,
379       @Nullable LocaleList detectedLocales) {
380     Locale targetLocale = findBestModelLocale(localePreferences, detectedLocales);
381     // detectedLocales usually only contains 2-char language (e.g. en), while locale in
382     // localePreferences is usually complete (e.g. en_US). Log only if targetLocale is not a prefix.
383     if (!isEmptyLocaleList(localePreferences)
384         && !localePreferences.get(0).toString().startsWith(targetLocale.toString())) {
385       TcLog.d(
386           TAG,
387           String.format(
388               Locale.US,
389               "localePreference and targetLocale mismatch: preference: %s, target: %s",
390               localePreferences.get(0),
391               targetLocale));
392     }
393     return findBestModelFile(modelType, targetLocale);
394   }
395 
396   /**
397    * Returns the best model file for the given locale, {@code null} if nothing is found.
398    *
399    * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
400    * @param targetLocale the preferred locale from preferences or detected locales default locales
401    *     if non given or detected.
402    */
403   @Nullable
findBestModelFile(@odelTypeDef String modelType, Locale targetLocale)404   private ModelFile findBestModelFile(@ModelTypeDef String modelType, Locale targetLocale) {
405     List<Locale.LanguageRange> deviceLanguageRanges =
406         Locale.LanguageRange.parse(LocaleList.getDefault().toLanguageTags());
407     boolean languageIndependentModelOnly = false;
408     if (Locale.lookupTag(deviceLanguageRanges, ImmutableList.of(targetLocale.getLanguage()))
409         == null) {
410       // If the targetLocale's language is not in device locale list, we don't match it to avoid
411       // leaking user language profile to the callers.
412       languageIndependentModelOnly = true;
413     }
414     List<Locale.LanguageRange> targetLanguageRanges =
415         Locale.LanguageRange.parse(targetLocale.toLanguageTag());
416     ModelFile bestModel = null;
417     for (ModelFile model : listModelFiles(modelType)) {
418       if (languageIndependentModelOnly && !model.languageIndependent) {
419         continue;
420       }
421       if (model.isAnyLanguageSupported(targetLanguageRanges)) {
422         if (model.isPreferredTo(bestModel)) {
423           bestModel = model;
424         }
425       }
426     }
427     return bestModel;
428   }
429 
430   /**
431    * Helpter function to check if LocaleList is null or empty
432    *
433    * @param localeList locale list to be checked
434    */
isEmptyLocaleList(@ullable LocaleList localeList)435   private static boolean isEmptyLocaleList(@Nullable LocaleList localeList) {
436     return localeList == null || localeList.isEmpty();
437   }
438 
439   @Override
dump(IndentingPrintWriter printWriter)440   public void dump(IndentingPrintWriter printWriter) {
441     printWriter.println("ModelFileManagerImpl:");
442     printWriter.increaseIndent();
443     for (@ModelTypeDef String modelType : ModelType.values()) {
444       printWriter.println(modelType + " model file(s):");
445       printWriter.increaseIndent();
446       for (ModelFile modelFile : listModelFiles(modelType)) {
447         printWriter.println(modelFile.toString());
448       }
449       printWriter.decreaseIndent();
450     }
451     printWriter.decreaseIndent();
452   }
453 }
454