1 /* 2 * Copyright (C) 2024 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.ondevicepersonalization; 18 19 import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.deserializeObject; 20 import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.serializeObject; 21 22 import android.annotation.FlaggedApi; 23 import android.annotation.IntDef; 24 import android.annotation.IntRange; 25 import android.annotation.NonNull; 26 import android.annotation.SuppressLint; 27 28 import com.android.adservices.ondevicepersonalization.flags.Flags; 29 import com.android.internal.util.Preconditions; 30 31 import java.lang.annotation.Retention; 32 import java.lang.annotation.RetentionPolicy; 33 import java.util.Objects; 34 35 /** 36 * Contains all the information needed for a run of model inference. The input of {@link 37 * ModelManager#run}. 38 */ 39 public final class InferenceInput { 40 /** The configuration that controls runtime interpreter behavior. */ 41 @NonNull private Params mParams; 42 43 /** 44 * A byte array that holds input data. The inputs should be in the same order as inputs of the 45 * model. 46 * 47 * <p>For LiteRT, this field is mapped to inputs of runForMultipleInputsOutputs: 48 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 49 * 50 * <pre>{@code 51 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 52 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 53 * Object[] inputData = {input0, input1, ...}; 54 * byte[] data = serializeObject(inputData); 55 * }</pre> 56 * 57 * <p>For Executorch model, this field is a serialized EValue array. 58 * 59 * @hide 60 */ 61 @NonNull private byte[] mData; 62 63 /** 64 * The number of input examples. Adopter can set this field to run batching inference. The batch 65 * size is 1 by default. The batch size should match the input data size. 66 */ 67 private int mBatchSize = 1; 68 69 /** 70 * The empty InferenceOutput representing the expected output structure. For LiteRT, the 71 * inference code will verify whether this expected output structure matches model output 72 * signature. 73 * 74 * <p>If a model produce string tensors: 75 * 76 * <pre>{@code 77 * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. 78 * HashMap<Integer, Object> outputs = new HashMap<>(); 79 * outputs.put(0, output); 80 * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); 81 * }</pre> 82 */ 83 @NonNull private InferenceOutput mExpectedOutputStructure; 84 85 public static class Params { 86 /** A {@link KeyValueStore} where pre-trained model is stored. */ 87 @NonNull private KeyValueStore mKeyValueStore; 88 89 /** The key of the table where the corresponding value stores a pre-trained model. */ 90 @NonNull private String mModelKey; 91 92 /** The model inference will run on CPU. */ 93 public static final int DELEGATE_CPU = 1; 94 95 /** 96 * The delegate to run model inference. 97 * 98 * @hide 99 */ 100 @IntDef( 101 prefix = "DELEGATE_", 102 value = {DELEGATE_CPU}) 103 @Retention(RetentionPolicy.SOURCE) 104 public @interface Delegate {} 105 106 /** 107 * The delegate to run model inference. If not set, the default value is {@link 108 * #DELEGATE_CPU}. 109 */ 110 private @Delegate int mDelegateType = DELEGATE_CPU; 111 112 /** The model is a tensorflow lite model. */ 113 public static final int MODEL_TYPE_TENSORFLOW_LITE = 1; 114 115 /** The model is an executorch model. */ 116 @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED) 117 public static final int MODEL_TYPE_EXECUTORCH = 2; 118 119 /** 120 * The type of the model. 121 * 122 * @hide 123 */ 124 @IntDef( 125 prefix = "MODEL_TYPE", 126 value = {MODEL_TYPE_TENSORFLOW_LITE, MODEL_TYPE_EXECUTORCH}) 127 @Retention(RetentionPolicy.SOURCE) 128 public @interface ModelType {} 129 130 /** 131 * The type of the pre-trained model. If not set, the default value is {@link 132 * #MODEL_TYPE_TENSORFLOW_LITE} . 133 */ 134 private @ModelType int mModelType = MODEL_TYPE_TENSORFLOW_LITE; 135 136 /** 137 * The number of threads used for intraop parallelism on CPU, must be positive number. 138 * Adopters can set this field based on model architecture. The actual thread number depends 139 * on system resources and other constraints. 140 */ 141 private @IntRange(from = 1) int mRecommendedNumThreads = 1; 142 143 /** 144 * Creates a new Params. 145 * 146 * @param keyValueStore A {@link KeyValueStore} where pre-trained model is stored. 147 * @param modelKey The key of the table where the corresponding value stores a pre-trained 148 * model. 149 * @param delegateType The delegate to run model inference. If not set, the default value is 150 * {@link #DELEGATE_CPU}. 151 * @param modelType The type of the pre-trained model. If not set, the default value is 152 * {@link #MODEL_TYPE_TENSORFLOW_LITE} . 153 * @param recommendedNumThreads The number of threads used for intraop parallelism on CPU, 154 * must be positive number. Adopters can set this field based on model architecture. The 155 * actual thread number depends on system resources and other constraints. 156 * @hide 157 */ Params( @onNull KeyValueStore keyValueStore, @NonNull String modelKey, @Delegate int delegateType, @ModelType int modelType, @IntRange(from = 1) int recommendedNumThreads)158 public Params( 159 @NonNull KeyValueStore keyValueStore, 160 @NonNull String modelKey, 161 @Delegate int delegateType, 162 @ModelType int modelType, 163 @IntRange(from = 1) int recommendedNumThreads) { 164 this.mKeyValueStore = Objects.requireNonNull(keyValueStore); 165 this.mModelKey = Objects.requireNonNull(modelKey); 166 this.mDelegateType = delegateType; 167 this.mModelType = modelType; 168 169 if (!(mModelType == MODEL_TYPE_TENSORFLOW_LITE) 170 && !(mModelType == MODEL_TYPE_EXECUTORCH)) { 171 throw new java.lang.IllegalArgumentException( 172 "modelType was " 173 + mModelType 174 + " but must be one of: " 175 + "MODEL_TYPE_TENSORFLOW_LITE(" 176 + MODEL_TYPE_TENSORFLOW_LITE 177 + "), " 178 + "MODEL_TYPE_EXECUTORCH(" 179 + MODEL_TYPE_EXECUTORCH 180 + ")"); 181 } 182 183 this.mRecommendedNumThreads = recommendedNumThreads; 184 Preconditions.checkState( 185 recommendedNumThreads >= 1, 186 "recommend thread number should be large or equal to 1"); 187 } 188 189 /** A {@link KeyValueStore} where pre-trained model is stored. */ getKeyValueStore()190 public @NonNull KeyValueStore getKeyValueStore() { 191 return mKeyValueStore; 192 } 193 194 /** The key of the table where the corresponding value stores a pre-trained model. */ getModelKey()195 public @NonNull String getModelKey() { 196 return mModelKey; 197 } 198 199 /** 200 * The delegate to run model inference. If not set, the default value is {@link 201 * #DELEGATE_CPU}. 202 */ getDelegateType()203 public @Delegate int getDelegateType() { 204 return mDelegateType; 205 } 206 207 /** 208 * The type of the pre-trained model. If not set, the default value is {@link 209 * #MODEL_TYPE_TENSORFLOW_LITE} . 210 */ getModelType()211 public @ModelType int getModelType() { 212 return mModelType; 213 } 214 215 /** 216 * The number of threads used for intraop parallelism on CPU, must be positive number. 217 * Adopters can set this field based on model architecture. The actual thread number depends 218 * on system resources and other constraints. 219 */ getRecommendedNumThreads()220 public @IntRange(from = 1) int getRecommendedNumThreads() { 221 return mRecommendedNumThreads; 222 } 223 224 @Override equals(@ndroid.annotation.Nullable Object o)225 public boolean equals(@android.annotation.Nullable Object o) { 226 // You can override field equality logic by defining either of the methods like: 227 // boolean fieldNameEquals(Params other) { ... } 228 // boolean fieldNameEquals(FieldType otherValue) { ... } 229 230 if (this == o) return true; 231 if (o == null || getClass() != o.getClass()) return false; 232 @SuppressWarnings("unchecked") 233 Params that = (Params) o; 234 //noinspection PointlessBooleanExpression 235 return true 236 && java.util.Objects.equals(mKeyValueStore, that.mKeyValueStore) 237 && java.util.Objects.equals(mModelKey, that.mModelKey) 238 && mDelegateType == that.mDelegateType 239 && mModelType == that.mModelType 240 && mRecommendedNumThreads == that.mRecommendedNumThreads; 241 } 242 243 @Override hashCode()244 public int hashCode() { 245 // You can override field hashCode logic by defining methods like: 246 // int fieldNameHashCode() { ... } 247 248 int _hash = 1; 249 _hash = 31 * _hash + java.util.Objects.hashCode(mKeyValueStore); 250 _hash = 31 * _hash + java.util.Objects.hashCode(mModelKey); 251 _hash = 31 * _hash + mDelegateType; 252 _hash = 31 * _hash + mModelType; 253 _hash = 31 * _hash + mRecommendedNumThreads; 254 return _hash; 255 } 256 257 /** A builder for {@link Params} */ 258 @SuppressWarnings("WeakerAccess") 259 public static final class Builder { 260 261 private @NonNull KeyValueStore mKeyValueStore; 262 private @NonNull String mModelKey; 263 private @Delegate int mDelegateType; 264 private @ModelType int mModelType; 265 private @IntRange(from = 1) int mRecommendedNumThreads; 266 267 private long mBuilderFieldsSet = 0L; 268 269 /** 270 * Creates a new Builder. 271 * 272 * @param keyValueStore a {@link KeyValueStore} where pre-trained model is stored. 273 * @param modelKey key of the table where the corresponding value stores a pre-trained 274 * model. 275 */ Builder(@onNull KeyValueStore keyValueStore, @NonNull String modelKey)276 public Builder(@NonNull KeyValueStore keyValueStore, @NonNull String modelKey) { 277 mKeyValueStore = Objects.requireNonNull(keyValueStore); 278 mModelKey = Objects.requireNonNull(modelKey); 279 } 280 281 /** A {@link KeyValueStore} where pre-trained model is stored. */ setKeyValueStore(@onNull KeyValueStore value)282 public @NonNull Builder setKeyValueStore(@NonNull KeyValueStore value) { 283 mBuilderFieldsSet |= 0x1; 284 mKeyValueStore = value; 285 return this; 286 } 287 288 /** The key of the table where the corresponding value stores a pre-trained model. */ setModelKey(@onNull String value)289 public @NonNull Builder setModelKey(@NonNull String value) { 290 mBuilderFieldsSet |= 0x2; 291 mModelKey = value; 292 return this; 293 } 294 295 /** 296 * The delegate to run model inference. If not set, the default value is {@link 297 * #DELEGATE_CPU}. 298 */ setDelegateType(@elegate int value)299 public @NonNull Builder setDelegateType(@Delegate int value) { 300 mBuilderFieldsSet |= 0x4; 301 mDelegateType = value; 302 return this; 303 } 304 305 /** 306 * The type of the pre-trained model. If not set, the default value is {@link 307 * #MODEL_TYPE_TENSORFLOW_LITE} . 308 */ setModelType(@odelType int value)309 public @NonNull Builder setModelType(@ModelType int value) { 310 mBuilderFieldsSet |= 0x8; 311 mModelType = value; 312 return this; 313 } 314 315 /** 316 * The number of threads used for intraop parallelism on CPU, must be positive number. 317 * Adopters can set this field based on model architecture. The actual thread number 318 * depends on system resources and other constraints. 319 */ setRecommendedNumThreads(@ntRangefrom = 1) int value)320 public @NonNull Builder setRecommendedNumThreads(@IntRange(from = 1) int value) { 321 mBuilderFieldsSet |= 0x10; 322 mRecommendedNumThreads = value; 323 return this; 324 } 325 326 /** Builds the instance. This builder should not be touched after calling this! */ build()327 public @NonNull Params build() { 328 mBuilderFieldsSet |= 0x20; // Mark builder used 329 330 if ((mBuilderFieldsSet & 0x4) == 0) { 331 mDelegateType = DELEGATE_CPU; 332 } 333 if ((mBuilderFieldsSet & 0x8) == 0) { 334 mModelType = MODEL_TYPE_TENSORFLOW_LITE; 335 } 336 if ((mBuilderFieldsSet & 0x10) == 0) { 337 mRecommendedNumThreads = 1; 338 } 339 Params o = 340 new Params( 341 mKeyValueStore, 342 mModelKey, 343 mDelegateType, 344 mModelType, 345 mRecommendedNumThreads); 346 return o; 347 } 348 } 349 } 350 InferenceInput( @onNull Params params, @NonNull byte[] data, int batchSize, @NonNull InferenceOutput expectedOutputStructure)351 /* package-private */ InferenceInput( 352 @NonNull Params params, 353 @NonNull byte[] data, 354 int batchSize, 355 @NonNull InferenceOutput expectedOutputStructure) { 356 this.mParams = Objects.requireNonNull(params); 357 this.mData = Objects.requireNonNull(data); 358 this.mBatchSize = batchSize; 359 this.mExpectedOutputStructure = Objects.requireNonNull(expectedOutputStructure); 360 } 361 362 /** The configuration that controls runtime interpreter behavior. */ getParams()363 public @NonNull Params getParams() { 364 return mParams; 365 } 366 367 /** 368 * A byte array that holds input data. The inputs should be in the same order as inputs of the 369 * model. 370 * 371 * <p>For LiteRT, this field is mapped to inputs of runForMultipleInputsOutputs: 372 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 373 * 374 * <pre>{@code 375 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 376 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 377 * Object[] inputData = {input0, input1, ...}; 378 * byte[] data = serializeObject(inputData); 379 * }</pre> 380 * 381 * <p>For Executorch model, this field is a serialized EValue array. 382 */ 383 @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED) getData()384 public @NonNull byte[] getData() { 385 return mData; 386 } 387 388 /** 389 * Note: use {@link InferenceInput#getData()} instead. 390 * 391 * <p>An array of input data. The inputs should be in the same order as inputs of the model. 392 * 393 * <p>For example, if a model takes multiple inputs: 394 * 395 * <pre>{@code 396 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 397 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 398 * Object[] inputData = {input0, input1, ...}; 399 * }</pre> 400 * 401 * For LiteRT, this field is mapped to inputs of runForMultipleInputsOutputs: 402 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 403 */ 404 @SuppressLint("ArrayReturn") getInputData()405 public @NonNull Object[] getInputData() { 406 return (Object[]) deserializeObject(mData); 407 } 408 409 /** 410 * The number of input examples. Adopter can set this field to run batching inference. The batch 411 * size is 1 by default. The batch size should match the input data size. 412 */ getBatchSize()413 public int getBatchSize() { 414 return mBatchSize; 415 } 416 417 /** 418 * The empty InferenceOutput representing the expected output structure. For LiteRT, the 419 * inference code will verify whether this expected output structure matches model output 420 * signature. 421 * 422 * <p>If a model produce string tensors: 423 * 424 * <pre>{@code 425 * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. 426 * HashMap<Integer, Object> outputs = new HashMap<>(); 427 * outputs.put(0, output); 428 * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); 429 * }</pre> 430 */ getExpectedOutputStructure()431 public @NonNull InferenceOutput getExpectedOutputStructure() { 432 return mExpectedOutputStructure; 433 } 434 435 @Override equals(@ndroid.annotation.Nullable Object o)436 public boolean equals(@android.annotation.Nullable Object o) { 437 // You can override field equality logic by defining either of the methods like: 438 // boolean fieldNameEquals(InferenceInput other) { ... } 439 // boolean fieldNameEquals(FieldType otherValue) { ... } 440 441 if (this == o) return true; 442 if (o == null || getClass() != o.getClass()) return false; 443 @SuppressWarnings("unchecked") 444 InferenceInput that = (InferenceInput) o; 445 //noinspection PointlessBooleanExpression 446 return true 447 && java.util.Objects.equals(mParams, that.mParams) 448 && java.util.Arrays.equals(mData, that.mData) 449 && mBatchSize == that.mBatchSize 450 && java.util.Objects.equals( 451 mExpectedOutputStructure, that.mExpectedOutputStructure); 452 } 453 454 @Override hashCode()455 public int hashCode() { 456 // You can override field hashCode logic by defining methods like: 457 // int fieldNameHashCode() { ... } 458 459 int _hash = 1; 460 _hash = 31 * _hash + java.util.Objects.hashCode(mParams); 461 _hash = 31 * _hash + java.util.Arrays.hashCode(mData); 462 _hash = 31 * _hash + mBatchSize; 463 _hash = 31 * _hash + java.util.Objects.hashCode(mExpectedOutputStructure); 464 return _hash; 465 } 466 467 /** A builder for {@link InferenceInput} */ 468 @SuppressWarnings("WeakerAccess") 469 public static final class Builder { 470 471 private @NonNull Params mParams; 472 private @NonNull byte[] mData; 473 private int mBatchSize; 474 private @NonNull InferenceOutput mExpectedOutputStructure = 475 new InferenceOutput.Builder().build(); 476 477 private long mBuilderFieldsSet = 0L; 478 479 /** 480 * Note: use {@link InferenceInput.Builder#Builder(Params, byte[])} instead. 481 * 482 * <p>Creates a new Builder for LiteRT model inference input. For LiteRT, inputData field is 483 * mapped to inputs of runForMultipleInputsOutputs: 484 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 485 * The inputs should be in the same order as inputs * of the model. * 486 * 487 * <p>For example, if a model takes multiple inputs: * 488 * 489 * <pre>{@code 490 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 491 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 492 * Object[] inputData = {input0, input1, ...}; 493 * }</pre> 494 * 495 * For LiteRT, the inference code will verify whether the expected output structure matches 496 * model output signature. 497 * 498 * <p>If a model produce string tensors: 499 * 500 * <pre>{@code 501 * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. 502 * HashMap<Integer, Object> outputs = new HashMap<>(); 503 * outputs.put(0, output); 504 * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); 505 * 506 * }</pre> 507 * 508 * @param params configuration that controls runtime interpreter behavior. 509 * @param inputData an array of input data. 510 * @param expectedOutputStructure an empty InferenceOutput representing the expected output 511 * structure. 512 */ Builder( @onNull Params params, @SuppressLint("ArrayReturn") @NonNull Object[] inputData, @NonNull InferenceOutput expectedOutputStructure)513 public Builder( 514 @NonNull Params params, 515 @SuppressLint("ArrayReturn") @NonNull Object[] inputData, 516 @NonNull InferenceOutput expectedOutputStructure) { 517 mParams = Objects.requireNonNull(params); 518 mData = serializeObject(Objects.requireNonNull(inputData)); 519 mExpectedOutputStructure = Objects.requireNonNull(expectedOutputStructure); 520 } 521 522 /** 523 * Creates a new Builder with provided runtime parameters and input data. 524 * 525 * <p>For LiteRT, inputData field is mapped to inputs of runForMultipleInputsOutputs: 526 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 527 * For example, if a model takes multiple inputs: 528 * 529 * <pre>{@code 530 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 531 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 532 * Object[] data = {input0, input1, ...}; 533 * byte[] inputData = serializeObject(data); 534 * }</pre> 535 * 536 * <p>For Executorch, inputData field is mapped to a serialized EValue array. 537 * 538 * @param params configuration that controls runtime interpreter behavior. 539 * @param inputData byte array that holds serialized input data. 540 */ 541 @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED) Builder(@onNull Params params, @NonNull byte[] inputData)542 public Builder(@NonNull Params params, @NonNull byte[] inputData) { 543 mParams = Objects.requireNonNull(params); 544 mData = Objects.requireNonNull(inputData); 545 } 546 547 /** The configuration that controls runtime interpreter behavior. */ setParams(@onNull Params value)548 public @NonNull Builder setParams(@NonNull Params value) { 549 mBuilderFieldsSet |= 0x1; 550 mParams = value; 551 return this; 552 } 553 554 /** 555 * A byte array that holds input data. The inputs should be in the same order as inputs of 556 * the model. 557 * 558 * <p>For LiteRT, this field is mapped to inputs of runForMultipleInputsOutputs: 559 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 560 * 561 * <pre>{@code 562 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 563 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 564 * Object[] data = {input0, input1, ...}; 565 * byte[] inputData = serializeObject(data); 566 * }</pre> 567 * 568 * <p>For Executorch model, this field is a serialized EValue array. 569 */ 570 @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED) setInputData(@onNull byte[] value)571 public @NonNull Builder setInputData(@NonNull byte[] value) { 572 mBuilderFieldsSet |= 0x2; 573 mData = value; 574 return this; 575 } 576 577 /** 578 * Note: use {@link InferenceInput.Builder#setInputData(byte[])} instead. 579 * 580 * <p>An array of input data. The inputs should be in the same order as inputs of the model. 581 * 582 * <p>For example, if a model takes multiple inputs: 583 * 584 * <pre>{@code 585 * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. 586 * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. 587 * Object[] inputData = {input0, input1, ...}; 588 * }</pre> 589 * 590 * For LiteRT, this field is mapped to inputs of runForMultipleInputsOutputs: 591 * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 592 */ setInputData(@onNull Object... value)593 public @NonNull Builder setInputData(@NonNull Object... value) { 594 mBuilderFieldsSet |= 0x2; 595 mData = serializeObject(value); 596 return this; 597 } 598 599 /** 600 * The number of input examples. Adopter can set this field to run batching inference. The 601 * batch size is 1 by default. The batch size should match the input data size. 602 */ setBatchSize(int value)603 public @NonNull Builder setBatchSize(int value) { 604 mBuilderFieldsSet |= 0x4; 605 mBatchSize = value; 606 return this; 607 } 608 609 /** 610 * The empty InferenceOutput representing the expected output structure. It's only required 611 * by LiteRT model. For LiteRT, the inference code will verify whether this expected output 612 * structure matches model output signature. 613 * 614 * <p>If a model produce string tensors: 615 * 616 * <pre>{@code 617 * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. 618 * HashMap<Integer, Object> outputs = new HashMap<>(); 619 * outputs.put(0, output); 620 * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); 621 * }</pre> 622 */ setExpectedOutputStructure(@onNull InferenceOutput value)623 public @NonNull Builder setExpectedOutputStructure(@NonNull InferenceOutput value) { 624 mBuilderFieldsSet |= 0x8; 625 mExpectedOutputStructure = value; 626 return this; 627 } 628 629 /** @hide */ validateInputData()630 private void validateInputData() { 631 Preconditions.checkArgument( 632 mData.length > 0, "Input data should not be empty for InferenceInput."); 633 } 634 635 /** @hide */ validateOutputStructure()636 private void validateOutputStructure() { 637 // ExecuTorch model doesn't require set output structure. 638 if (mParams.getModelType() != Params.MODEL_TYPE_TENSORFLOW_LITE) { 639 return; 640 } 641 Preconditions.checkArgument( 642 !mExpectedOutputStructure.getDataOutputs().isEmpty() 643 || mExpectedOutputStructure.getData().length > 0, 644 "ExpectedOutputStructure field is required for TensorflowLite model."); 645 } 646 647 /** Builds the instance. This builder should not be touched after calling this! */ build()648 public @NonNull InferenceInput build() { 649 650 mBuilderFieldsSet |= 0x10; // Mark builder used 651 652 if ((mBuilderFieldsSet & 0x4) == 0) { 653 mBatchSize = 1; 654 } 655 validateInputData(); 656 validateOutputStructure(); 657 InferenceInput o = 658 new InferenceInput(mParams, mData, mBatchSize, mExpectedOutputStructure); 659 return o; 660 } 661 } 662 } 663