1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.service.topics.classifier; 18 19 import android.annotation.NonNull; 20 import android.content.Context; 21 import android.content.res.AssetFileDescriptor; 22 import android.content.res.AssetManager; 23 import android.net.Uri; 24 import android.os.Build; 25 import android.util.ArrayMap; 26 import android.util.JsonReader; 27 28 import androidx.annotation.Nullable; 29 import androidx.annotation.RequiresApi; 30 31 import com.android.adservices.LogUtil; 32 import com.android.adservices.LoggerFactory; 33 import com.android.adservices.download.MobileDataDownloadFactory; 34 import com.android.adservices.service.FlagsFactory; 35 import com.android.adservices.service.topics.classifier.ClassifierInputConfig.ClassifierInputField; 36 import com.android.internal.annotations.VisibleForTesting; 37 38 import com.google.android.libraries.mobiledatadownload.GetFileGroupRequest; 39 import com.google.android.libraries.mobiledatadownload.MobileDataDownload; 40 import com.google.android.libraries.mobiledatadownload.file.SynchronousFileStorage; 41 import com.google.android.libraries.mobiledatadownload.file.openers.MappedByteBufferOpener; 42 import com.google.android.libraries.mobiledatadownload.file.openers.ReadStreamOpener; 43 import com.google.common.collect.ImmutableList; 44 import com.google.common.collect.ImmutableMap; 45 import com.google.mobiledatadownload.ClientConfigProto.ClientFile; 46 import com.google.mobiledatadownload.ClientConfigProto.ClientFileGroup; 47 48 import java.io.BufferedReader; 49 import java.io.FileInputStream; 50 import java.io.IOException; 51 import java.io.InputStream; 52 import java.io.InputStreamReader; 53 import java.nio.ByteBuffer; 54 import java.nio.MappedByteBuffer; 55 import java.nio.channels.FileChannel; 56 import java.util.ArrayList; 57 import java.util.Arrays; 58 import java.util.HashSet; 59 import java.util.IllegalFormatException; 60 import java.util.List; 61 import java.util.Map; 62 import java.util.Set; 63 import java.util.concurrent.ExecutionException; 64 65 /** 66 * Model Manager. 67 * 68 * <p>Model Manager to manage models used the Classifier. Currently, there are 2 types of models: 1) 69 * Bundled Model in the APK. 2) Downloaded Model via MDD. 70 * 71 * <p>ModelManager will select the right model to serve Classifier. 72 */ 73 // TODO(b/269798827): Enable for R. 74 @RequiresApi(Build.VERSION_CODES.S) 75 public class ModelManager { 76 private static final LoggerFactory.Logger sLogger = LoggerFactory.getTopicsLogger(); 77 public static final String BUNDLED_LABELS_FILE_PATH = "classifier/labels_topics.txt"; 78 public static final String BUNDLED_TOP_APP_FILE_PATH = "classifier/precomputed_app_list.csv"; 79 public static final String BUNDLED_CLASSIFIER_ASSETS_METADATA_FILE_PATH = 80 "classifier/classifier_assets_metadata.json"; 81 private static final String BUNDLED_CLASSIFIER_INPUT_CONFIG_FILE_PATH = 82 "classifier/classifier_input_config.txt"; 83 public static final String BUNDLED_MODEL_FILE_PATH = "classifier/model.tflite"; 84 85 private static final String FILE_GROUP_NAME = "topics-classifier-model"; 86 87 // Use "\t" as a delimiter to read the precomputed app topics file 88 private static final String LIST_COLUMN_DELIMITER = "\t"; 89 // Use "," as a delimiter to read multi-topics of one app in precomputed app topics file 90 private static final String TOPICS_DELIMITER = ","; 91 // Arbitrary string representing contents of a classifier input field to validate input format. 92 private static final String CLASSIFIER_INPUT_FIELD = "CLASSIFIER_INPUT_FIELD"; 93 94 // The key name of asset metadata property in classifier_assets_metadata.json 95 private static final String ASSET_PROPERTY_NAME = "property"; 96 // The key name of asset element in classifier_assets_metadata.json 97 private static final String ASSET_ELEMENT_NAME = "asset_name"; 98 // The attributions of assets property in classifier_assets_metadata.json 99 private static final Set<String> ASSETS_PROPERTY_ATTRIBUTIONS = 100 new HashSet( 101 Arrays.asList("taxonomy_type", "taxonomy_version", "build_id", "updated_date")); 102 // The attributions of assets metadata in classifier_assets_metadata.json 103 private static final Set<String> ASSETS_NORMAL_ATTRIBUTIONS = 104 new HashSet(Arrays.asList("asset_version", "path", "checksum", "updated_date")); 105 106 private static final String DOWNLOADED_LABEL_FILE_ID = "labels_topics.txt"; 107 private static final String DOWNLOADED_TOP_APPS_FILE_ID = "precomputed_app_list.csv"; 108 private static final String DOWNLOADED_CLASSIFIER_ASSETS_METADATA_FILE_ID = 109 "classifier_assets_metadata.json"; 110 private static final String DOWNLOADED_CLASSIFIER_INPUT_CONFIG_FILE_ID = 111 "classifier_input_config.txt"; 112 private static final String DOWNLOADED_MODEL_FILE_ID = "model.tflite"; 113 114 private static ModelManager sSingleton; 115 private final Context mContext; 116 private final AssetManager mAssetManager; 117 private final String mLabelsFilePath; 118 private final String mTopAppsFilePath; 119 private final String mClassifierAssetsMetadataPath; 120 private final String mModelFilePath; 121 private final String mClassifierInputConfigPath; 122 private final SynchronousFileStorage mFileStorage; 123 private final Map<String, ClientFile> mDownloadedFiles; 124 125 @VisibleForTesting ModelManager( @onNull Context context, @NonNull String labelsFilePath, @NonNull String topAppsFilePath, @NonNull String classifierAssetsMetadataPath, @NonNull String classifierInputConfigPath, @NonNull String modelFilePath, @NonNull SynchronousFileStorage fileStorage, @Nullable Map<String, ClientFile> downloadedFiles)126 ModelManager( 127 @NonNull Context context, 128 @NonNull String labelsFilePath, 129 @NonNull String topAppsFilePath, 130 @NonNull String classifierAssetsMetadataPath, 131 @NonNull String classifierInputConfigPath, 132 @NonNull String modelFilePath, 133 @NonNull SynchronousFileStorage fileStorage, 134 @Nullable Map<String, ClientFile> downloadedFiles) { 135 mContext = context.getApplicationContext(); 136 mAssetManager = context.getAssets(); 137 mLabelsFilePath = labelsFilePath; 138 mTopAppsFilePath = topAppsFilePath; 139 mClassifierAssetsMetadataPath = classifierAssetsMetadataPath; 140 mClassifierInputConfigPath = classifierInputConfigPath; 141 mModelFilePath = modelFilePath; 142 mFileStorage = fileStorage; 143 mDownloadedFiles = downloadedFiles; 144 } 145 146 /** Returns the singleton instance of the {@link ModelManager} given a context. */ 147 @NonNull getInstance(@onNull Context context)148 public static ModelManager getInstance(@NonNull Context context) { 149 synchronized (ModelManager.class) { 150 if (sSingleton == null) { 151 sSingleton = 152 new ModelManager( 153 context, 154 BUNDLED_LABELS_FILE_PATH, 155 BUNDLED_TOP_APP_FILE_PATH, 156 BUNDLED_CLASSIFIER_ASSETS_METADATA_FILE_PATH, 157 BUNDLED_CLASSIFIER_INPUT_CONFIG_FILE_PATH, 158 BUNDLED_MODEL_FILE_PATH, 159 MobileDataDownloadFactory.getFileStorage(context), 160 getDownloadedFiles(context)); 161 } 162 } 163 return sSingleton; 164 } 165 166 /** 167 * This function populates metadata files to a map. 168 * 169 * @param context {@link Context} 170 * @return A map<FileId, ClientFile> contains downloaded fileId with ClientFile or null if no 171 * downloaded files found. 172 */ 173 @VisibleForTesting getDownloadedFiles(@onNull Context context)174 static @Nullable Map<String, ClientFile> getDownloadedFiles(@NonNull Context context) { 175 ClientFileGroup fileGroup = getClientFileGroup(context); 176 if (fileGroup == null) { 177 sLogger.d("ClientFileGroup is null."); 178 return null; 179 } 180 Map<String, ClientFile> downloadedFiles = new ArrayMap<>(); 181 sLogger.v("Populating downloadFiles map."); 182 fileGroup.getFileList().stream() 183 .forEach(file -> downloadedFiles.put(file.getFileId(), file)); 184 return downloadedFiles; 185 } 186 187 /** Returns topics-classifier-model ClientFileGroup */ 188 @VisibleForTesting 189 @Nullable getClientFileGroup(@onNull Context context)190 static ClientFileGroup getClientFileGroup(@NonNull Context context) { 191 MobileDataDownload mobileDataDownload = 192 MobileDataDownloadFactory.getMdd(context, FlagsFactory.getFlags()); 193 GetFileGroupRequest getFileGroupRequest = 194 GetFileGroupRequest.newBuilder().setGroupName(FILE_GROUP_NAME).build(); 195 ClientFileGroup fileGroup = null; 196 try { 197 // TODO(b/242908564). Remove get() 198 fileGroup = mobileDataDownload.getFileGroup(getFileGroupRequest).get(); 199 } catch (ExecutionException | InterruptedException e) { 200 sLogger.e(e, "Unable to load MDD file group."); 201 return null; 202 } 203 return fileGroup; 204 } 205 206 /** 207 * Returns the build id of model that will be used for classification. This function will 208 * compare the build id from bundled asset and the downloaded model and choose the newer build 209 * id. 210 */ getBuildId()211 public long getBuildId() { 212 return useDownloadedFiles() 213 ? getDownloadedModelBuildId() 214 : CommonClassifierHelper.getBundledModelBuildId( 215 mContext, mClassifierAssetsMetadataPath); 216 } 217 218 // Return true if Model Manager should use downloaded model. Otherwise, use bundled model. 219 @VisibleForTesting useDownloadedFiles()220 boolean useDownloadedFiles() { 221 if (FlagsFactory.getFlags().getClassifierForceUseBundledFiles()) { 222 sLogger.d( 223 "ModelManager uses bundled model because flag" 224 + " classifier_force_use_bundled_files is enabled"); 225 return false; 226 } else if (mDownloadedFiles == null || mDownloadedFiles.size() == 0) { 227 // Use bundled model if no downloaded files available. 228 sLogger.d( 229 "ModelManager uses bundled model because there is no downloaded files" 230 + " available"); 231 return false; 232 } 233 234 long downloadedModelBuildId = getDownloadedModelBuildId(); 235 long bundledModelBuildId = 236 CommonClassifierHelper.getBundledModelBuildId( 237 mContext, mClassifierAssetsMetadataPath); 238 if (downloadedModelBuildId <= bundledModelBuildId) { 239 // Mdd has not downloaded new version of model. Use bundled model. 240 sLogger.d( 241 "ModelManager uses bundled model build id = %d because downloaded model build" 242 + " id = %d is not the latest version", 243 bundledModelBuildId, downloadedModelBuildId); 244 return false; 245 } 246 sLogger.d("ModelManager uses downloaded model build id = %d", downloadedModelBuildId); 247 return true; 248 } 249 250 /** 251 * Load TFLite model as a ByteBuffer. 252 * 253 * @throws IOException if failed to read downloaded or bundled model file. 254 */ 255 @NonNull retrieveModel()256 public ByteBuffer retrieveModel() throws IOException { 257 if (useDownloadedFiles()) { 258 ClientFile downloadedFile = mDownloadedFiles.get(DOWNLOADED_MODEL_FILE_ID); 259 MappedByteBuffer buffer = null; 260 if (downloadedFile == null) { 261 sLogger.e("Failed to find downloaded model file"); 262 return ByteBuffer.allocate(0); 263 } else { 264 buffer = 265 mFileStorage.open( 266 Uri.parse(downloadedFile.getFileUri()), 267 MappedByteBufferOpener.createForRead()); 268 return buffer; 269 } 270 } else { 271 try { 272 // Use bundled files. 273 AssetFileDescriptor fileDescriptor = mAssetManager.openFd(mModelFilePath); 274 FileInputStream inputStream = 275 new FileInputStream(fileDescriptor.getFileDescriptor()); 276 FileChannel fileChannel = inputStream.getChannel(); 277 278 long startOffset = fileDescriptor.getStartOffset(); 279 long declaredLength = fileDescriptor.getDeclaredLength(); 280 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); 281 } catch (IOException | NullPointerException e) { 282 sLogger.e(e, "Error loading the bundled classifier model"); 283 return ByteBuffer.allocate(0); 284 } 285 } 286 } 287 288 /** Returns true if the classifier model is available for classification. */ isModelAvailable()289 public boolean isModelAvailable() { 290 if (useDownloadedFiles()) { 291 // Downloaded model is always expected to be available. 292 return true; 293 } else { 294 // Check if the non-zero model file is present in the apk assets. 295 try { 296 return mAssetManager.openFd(mModelFilePath).getLength() > 0; 297 } catch (IOException e) { 298 sLogger.e(e, "[ML] No classifier model available."); 299 return false; 300 } 301 } 302 } 303 304 /** 305 * Retrieve a list of topicIDs from labels file. 306 * 307 * @return The list of topicIDs from downloaded or bundled labels file. Empty list will be 308 * returned for {@link IOException}. 309 */ 310 @NonNull retrieveLabels()311 public ImmutableList<Integer> retrieveLabels() { 312 ImmutableList.Builder<Integer> labels = new ImmutableList.Builder(); 313 InputStream inputStream = null; // InputStream.nullInputStream() is not available on S-. 314 if (useDownloadedFiles()) { 315 inputStream = readDownloadedFile(DOWNLOADED_LABEL_FILE_ID); 316 } else { 317 // Use bundled files. 318 try { 319 inputStream = mAssetManager.open(mLabelsFilePath); 320 } catch (IOException e) { 321 sLogger.e(e, "Failed to read labels file"); 322 } 323 } 324 return inputStream == null ? labels.build() : getLabelsList(labels, inputStream); 325 } 326 327 @NonNull getLabelsList( @onNull ImmutableList.Builder<Integer> labels, @NonNull InputStream inputStream)328 private ImmutableList<Integer> getLabelsList( 329 @NonNull ImmutableList.Builder<Integer> labels, @NonNull InputStream inputStream) { 330 String line; 331 try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream)) { 332 BufferedReader reader = new BufferedReader(inputStreamReader); 333 334 while ((line = reader.readLine()) != null) { 335 // If the line has at least 1 digit, this line will be added to the labels. 336 if (line.length() > 0) { 337 labels.add(Integer.parseInt(line)); 338 } 339 } 340 } catch (IOException e) { 341 sLogger.e(e, "Unable to read precomputed labels"); 342 // When catching IOException -> return empty immutable list 343 // TODO(b/226944089): A strategy to handle exceptions 344 // in Classifier and PrecomputedLoader 345 return ImmutableList.of(); 346 } 347 348 return labels.build(); 349 } 350 351 /** 352 * Retrieve the app classification topicIDs. 353 * 354 * @return The map from App to the list of its classification topicIDs. 355 */ 356 @NonNull retrieveAppClassificationTopics()357 public Map<String, List<Integer>> retrieveAppClassificationTopics() { 358 // appTopicsMap = Map<App, List<Topic>> 359 Map<String, List<Integer>> appTopicsMap = new ArrayMap<>(); 360 361 // The immutable set of the topics from labels file 362 ImmutableList<Integer> validTopics = retrieveLabels(); 363 InputStream inputStream = null; 364 if (useDownloadedFiles()) { 365 inputStream = readDownloadedFile(DOWNLOADED_TOP_APPS_FILE_ID); 366 } else { 367 // Use bundled files. 368 try { 369 inputStream = mAssetManager.open(mTopAppsFilePath); 370 } catch (IOException e) { 371 sLogger.e(e, "Failed to read top apps file"); 372 } 373 } 374 return inputStream == null 375 ? appTopicsMap 376 : getAppsTopicMap(appTopicsMap, validTopics, inputStream); 377 } 378 379 @NonNull getAppsTopicMap( @onNull Map<String, List<Integer>> appTopicsMap, @NonNull ImmutableList<Integer> validTopics, @NonNull InputStream inputStream)380 private Map<String, List<Integer>> getAppsTopicMap( 381 @NonNull Map<String, List<Integer>> appTopicsMap, 382 @NonNull ImmutableList<Integer> validTopics, 383 @NonNull InputStream inputStream) { 384 String line; 385 try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream)) { 386 BufferedReader reader = new BufferedReader(inputStreamReader); 387 388 // Skip first line (columns name) 389 reader.readLine(); 390 391 while ((line = reader.readLine()) != null) { 392 String[] columns = line.split(LIST_COLUMN_DELIMITER); 393 394 // If the line has less than 2 elements, this app contains empty topic 395 // and save an empty topic list of this app in appTopicsMap. 396 if (columns.length < 2) { 397 // columns[0] if the app's name 398 appTopicsMap.put(columns[0], ImmutableList.of()); 399 continue; 400 } 401 402 // The first column is app package name 403 String app = columns[0]; 404 405 // The second column is multi-topics of the app 406 String[] appTopics = columns[1].split(TOPICS_DELIMITER); 407 408 // This list is used to temporarily store the allowed topicIDs of one app. 409 List<Integer> allowedAppTopics = new ArrayList<>(); 410 411 for (String appTopic : appTopics) { 412 // The topic will not save to the app topics map 413 // if it is not a valid topic in labels file 414 if (!validTopics.contains(Integer.parseInt(appTopic))) { 415 sLogger.e( 416 "Unable to load topicID \"%s\" in app \"%s\", " 417 + "because it is not a valid topic in labels file.", 418 appTopic, app); 419 continue; 420 } 421 422 // Add the allowed topic to the list 423 allowedAppTopics.add(Integer.parseInt(appTopic)); 424 } 425 426 appTopicsMap.put(app, ImmutableList.copyOf(allowedAppTopics)); 427 } 428 } catch (IOException e) { 429 sLogger.e(e, "Unable to read precomputed app topics list"); 430 // When catching IOException -> return empty hash map 431 // TODO(b/226944089): A strategy to handle exceptions 432 // in Classifier and PrecomputedLoader 433 return ImmutableMap.of(); 434 } 435 436 return appTopicsMap; 437 } 438 439 /** 440 * Retrieve the assets names and their corresponding metadata. 441 * 442 * @return The immutable map of assets metadata from {@code mClassifierAssetsMetadataPath}. 443 * Empty map will be returned for {@link IOException}. 444 */ 445 @NonNull retrieveClassifierAssetsMetadata()446 public ImmutableMap<String, ImmutableMap<String, String>> retrieveClassifierAssetsMetadata() { 447 // Initialize a ImmutableMap.Builder to store the classifier assets metadata iteratively. 448 // classifierAssetsMetadata = ImmutableMap<AssetName, ImmutableMap<MetadataName, Value>> 449 ImmutableMap.Builder<String, ImmutableMap<String, String>> classifierAssetsMetadata = 450 new ImmutableMap.Builder<>(); 451 InputStream inputStream = null; 452 if (useDownloadedFiles()) { 453 inputStream = readDownloadedFile(DOWNLOADED_CLASSIFIER_ASSETS_METADATA_FILE_ID); 454 } else { 455 // Use bundled files. 456 try { 457 inputStream = mAssetManager.open(mClassifierAssetsMetadataPath); 458 } catch (IOException e) { 459 sLogger.e(e, "Failed to read bundled metadata file"); 460 } 461 } 462 return inputStream == null 463 ? classifierAssetsMetadata.build() 464 : getAssetsMetadataMap(classifierAssetsMetadata, inputStream); 465 } 466 467 @NonNull getAssetsMetadataMap( @onNull ImmutableMap.Builder<String, ImmutableMap<String, String>> classifierAssetsMetadata, @NonNull InputStream inputStream)468 private ImmutableMap<String, ImmutableMap<String, String>> getAssetsMetadataMap( 469 @NonNull 470 ImmutableMap.Builder<String, ImmutableMap<String, String>> 471 classifierAssetsMetadata, 472 @NonNull InputStream inputStream) { 473 try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream)) { 474 JsonReader reader = new JsonReader(inputStreamReader); 475 476 reader.beginArray(); 477 while (reader.hasNext()) { 478 // Use an immutable map to store the metadata of one asset. 479 // assetMetadata = ImmutableMap<MetadataName, Value> 480 ImmutableMap.Builder<String, String> assetMetadata = new ImmutableMap.Builder<>(); 481 482 // Use jsonElementKey to save the key name of each array element. 483 String jsonElementKey = null; 484 485 // Begin to read one json element in the array here. 486 reader.beginObject(); 487 if (reader.hasNext()) { 488 String elementKeyName = reader.nextName(); 489 490 if (elementKeyName.equals(ASSET_PROPERTY_NAME)) { 491 jsonElementKey = reader.nextString(); 492 493 while (reader.hasNext()) { 494 String attribution = reader.nextName(); 495 // Check if the attribution name can be found in the property's key set. 496 if (ASSETS_PROPERTY_ATTRIBUTIONS.contains(attribution)) { 497 assetMetadata.put(attribution, reader.nextString()); 498 } else { 499 // Skip the redundant metadata name if it can't be found 500 // in the ASSETS_PROPERTY_ATTRIBUTIONS. 501 reader.skipValue(); 502 sLogger.e( 503 attribution, 504 " is a redundant metadata attribution of " 505 + "metadata property."); 506 } 507 } 508 } else if (elementKeyName.equals(ASSET_ELEMENT_NAME)) { 509 jsonElementKey = reader.nextString(); 510 511 while (reader.hasNext()) { 512 String attribution = reader.nextName(); 513 // Check if the attribution name can be found in the asset's key set. 514 if (ASSETS_NORMAL_ATTRIBUTIONS.contains(attribution)) { 515 assetMetadata.put(attribution, reader.nextString()); 516 } else { 517 // Skip the redundant metadata name if it can't be found 518 // in the ASSET_NORMAL_ATTRIBUTIONS. 519 reader.skipValue(); 520 sLogger.e( 521 attribution, 522 " is a redundant metadata attribution of asset."); 523 } 524 } 525 } else { 526 // Skip the json element if it doesn't have key "property" or "asset_name". 527 while (reader.hasNext()) { 528 reader.skipValue(); 529 } 530 sLogger.e( 531 "Can't load this json element, " 532 + "because \"property\" or \"asset_name\" " 533 + "can't be found in the json element."); 534 } 535 } 536 reader.endObject(); 537 538 // Save the metadata of the asset if and only if the assetName can be retrieved 539 // correctly from the metadata json file. 540 if (jsonElementKey != null) { 541 classifierAssetsMetadata.put(jsonElementKey, assetMetadata.build()); 542 } 543 } 544 reader.endArray(); 545 } catch (IOException e) { 546 sLogger.e(e, "Unable to read classifier assets metadata file"); 547 // When catching IOException -> return empty immutable map 548 return ImmutableMap.of(); 549 } 550 return classifierAssetsMetadata.build(); 551 } 552 553 /** 554 * Retrieve classifier input configuration from config file. 555 * 556 * @return A ClassifierInputConfig containing the format string for the classifier input and a 557 * list of fields to populate it. Empty ClassifierInputConfig will be returned for {@link 558 * IOException}. 559 */ 560 @NonNull retrieveClassifierInputConfig()561 public ClassifierInputConfig retrieveClassifierInputConfig() { 562 InputStream inputStream = null; // InputStream.nullInputStream() is not available on S-. 563 if (useDownloadedFiles()) { 564 inputStream = readDownloadedFile(DOWNLOADED_CLASSIFIER_INPUT_CONFIG_FILE_ID); 565 } else { 566 // Use bundled files. 567 try { 568 inputStream = mAssetManager.open(mClassifierInputConfigPath); 569 } catch (IOException e) { 570 LogUtil.e(e, "Failed to read classifier input config file"); 571 } 572 } 573 return inputStream == null 574 ? ClassifierInputConfig.getEmptyConfig() 575 : getClassifierInputConfig(inputStream); 576 } 577 578 @NonNull getClassifierInputConfig(@onNull InputStream inputStream)579 private ClassifierInputConfig getClassifierInputConfig(@NonNull InputStream inputStream) { 580 String line; 581 String inputFormat; 582 ImmutableList.Builder<ClassifierInputField> inputFields = ImmutableList.builder(); 583 584 try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream)) { 585 BufferedReader reader = new BufferedReader(inputStreamReader); 586 587 if ((line = reader.readLine()) == null || line.length() == 0) { 588 return ClassifierInputConfig.getEmptyConfig(); 589 } 590 inputFormat = line; 591 592 while ((line = reader.readLine()) != null) { 593 // If the line has at least 1 character, this line will be added to the input 594 // fields. 595 if (line.length() > 0) { 596 try { 597 inputFields.add(ClassifierInputField.valueOf(line)); 598 } catch (IllegalArgumentException e) { 599 LogUtil.e("Invalid input field in classifier input config: {}", line); 600 return ClassifierInputConfig.getEmptyConfig(); 601 } 602 } 603 } 604 } catch (IOException e) { 605 LogUtil.e(e, "Unable to read classifier input config"); 606 // When catching IOException -> return empty ClassifierInputConfig 607 // TODO(b/226944089): A strategy to handle exceptions 608 // in Classifier and PrecomputedLoader 609 return ClassifierInputConfig.getEmptyConfig(); 610 } 611 612 ClassifierInputConfig classifierInputConfig = 613 new ClassifierInputConfig(inputFormat, inputFields.build()); 614 615 if (!validateClassifierInputConfig(classifierInputConfig)) { 616 return ClassifierInputConfig.getEmptyConfig(); 617 } 618 619 return classifierInputConfig; 620 } 621 622 @NonNull validateClassifierInputConfig( @onNull ClassifierInputConfig classifierInputConfig)623 private boolean validateClassifierInputConfig( 624 @NonNull ClassifierInputConfig classifierInputConfig) { 625 String[] inputFields = new String[classifierInputConfig.getInputFields().size()]; 626 Arrays.fill(inputFields, CLASSIFIER_INPUT_FIELD); 627 try { 628 String formattedInput = 629 String.format(classifierInputConfig.getInputFormat(), (Object[]) inputFields); 630 LogUtil.d("Validated classifier input format: {}", formattedInput); 631 } catch (IllegalFormatException e) { 632 LogUtil.e("Classifier input config is incorrectly formatted"); 633 return false; 634 } 635 return true; 636 } 637 638 // Return an InputStream if downloaded model file can be found by 639 // ClientFile.file_id. 640 @NonNull readDownloadedFile(String fileId)641 private InputStream readDownloadedFile(String fileId) { 642 InputStream inputStream = null; 643 ClientFile downloadedFile = mDownloadedFiles.get(fileId); 644 if (downloadedFile == null) { 645 sLogger.e("Failed to find downloaded %s file", fileId); 646 return inputStream; 647 } 648 try { 649 inputStream = 650 mFileStorage.open( 651 Uri.parse(downloadedFile.getFileUri()), ReadStreamOpener.create()); 652 } catch (IOException e) { 653 sLogger.e(e, "Failed to load fileId = %s", fileId); 654 } 655 return inputStream; 656 } 657 658 /** 659 * Gets downloaded model build id from topics-classifier-model ClientFileGroup. Returns 0 if 660 * there is no downloaded file. 661 * 662 * @return downloaded model build id. 663 */ getDownloadedModelBuildId()664 private long getDownloadedModelBuildId() { 665 ClientFileGroup clientFileGroup = getClientFileGroup(mContext); 666 if (clientFileGroup == null) { 667 return 0; 668 } 669 return clientFileGroup.getBuildId(); 670 } 671 } 672