/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.adservices.ondevicepersonalization;

import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.deserializeObject;
import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.serializeObject;

import android.annotation.FlaggedApi;
import android.annotation.IntDef;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.SuppressLint;

import com.android.adservices.ondevicepersonalization.flags.Flags;
import com.android.internal.util.Preconditions;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Objects;

/**
 * Contains all the information needed for a run of model inference. The input of {@link
 * ModelManager#run}.
 */
public final class InferenceInput {
    /** The configuration that controls runtime interpreter behavior. */
    @NonNull private Params mParams;

    /**
     * A byte array that holds input data. The inputs should be in the same order as the inputs
     * of the model.
     *
     * <p>For LiteRT, this field is mapped to the inputs of runForMultipleInputsOutputs:
     * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
     *
     * <pre>{@code
     * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
     * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
     * Object[] inputData = {input0, input1, ...};
     * byte[] data = serializeObject(inputData);
     * }</pre>
     *
     * <p>For an ExecuTorch model, this field is a serialized EValue array.
     *
     * @hide
     */
    @NonNull private byte[] mData;

    /**
     * The number of input examples. Adopters can set this field to run batching inference. The
     * batch size is 1 by default. The batch size should match the input data size.
     */
    private int mBatchSize = 1;

    /**
     * The empty InferenceOutput representing the expected output structure. For LiteRT, the
     * inference code will verify whether this expected output structure matches the model output
     * signature.
     *
     * <p>If a model produces string tensors:
     *
     * <pre>{@code
     * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
     * HashMap<Integer, Object> outputs = new HashMap<>();
     * outputs.put(0, output);
     * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build();
     * }</pre>
     */
    @NonNull private InferenceOutput mExpectedOutputStructure;

    public static class Params {
        /** A {@link KeyValueStore} where the pre-trained model is stored. */
        @NonNull private KeyValueStore mKeyValueStore;

        /** The key of the table where the corresponding value stores a pre-trained model. */
        @NonNull private String mModelKey;

        /** The model inference will run on CPU. */
        public static final int DELEGATE_CPU = 1;

        /**
         * The delegate to run model inference.
         *
         * @hide
         */
        @IntDef(
                prefix = "DELEGATE_",
                value = {DELEGATE_CPU})
        @Retention(RetentionPolicy.SOURCE)
        public @interface Delegate {}

        /**
         * The delegate to run model inference. If not set, the default value is {@link
         * #DELEGATE_CPU}.
         */
        private @Delegate int mDelegateType = DELEGATE_CPU;

        /** The model is a TensorFlow Lite model. */
        public static final int MODEL_TYPE_TENSORFLOW_LITE = 1;

        /** The model is an ExecuTorch model. */
        @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED)
        public static final int MODEL_TYPE_EXECUTORCH = 2;

        /**
         * The type of the model.
         *
         * @hide
         */
        @IntDef(
                prefix = "MODEL_TYPE",
                value = {MODEL_TYPE_TENSORFLOW_LITE, MODEL_TYPE_EXECUTORCH})
        @Retention(RetentionPolicy.SOURCE)
        public @interface ModelType {}

        /**
         * The type of the pre-trained model. If not set, the default value is {@link
         * #MODEL_TYPE_TENSORFLOW_LITE}.
         */
        private @ModelType int mModelType = MODEL_TYPE_TENSORFLOW_LITE;

        /**
         * The number of threads used for intraop parallelism on CPU; must be a positive number.
         * Adopters can set this field based on model architecture. The actual thread number
         * depends on system resources and other constraints.
         */
        private @IntRange(from = 1) int mRecommendedNumThreads = 1;

        /**
         * Creates a new Params.
         *
         * @param keyValueStore A {@link KeyValueStore} where the pre-trained model is stored.
         * @param modelKey The key of the table where the corresponding value stores a pre-trained
         *     model.
         * @param delegateType The delegate to run model inference. If not set, the default value
         *     is {@link #DELEGATE_CPU}.
         * @param modelType The type of the pre-trained model. If not set, the default value is
         *     {@link #MODEL_TYPE_TENSORFLOW_LITE}.
         * @param recommendedNumThreads The number of threads used for intraop parallelism on CPU;
         *     must be a positive number. Adopters can set this field based on model architecture.
         *     The actual thread number depends on system resources and other constraints.
         * @hide
         */
        public Params(
                @NonNull KeyValueStore keyValueStore,
                @NonNull String modelKey,
                @Delegate int delegateType,
                @ModelType int modelType,
                @IntRange(from = 1) int recommendedNumThreads) {
            this.mKeyValueStore = Objects.requireNonNull(keyValueStore);
            this.mModelKey = Objects.requireNonNull(modelKey);
            this.mDelegateType = delegateType;
            this.mModelType = modelType;

            if (!(mModelType == MODEL_TYPE_TENSORFLOW_LITE)
                    && !(mModelType == MODEL_TYPE_EXECUTORCH)) {
                throw new java.lang.IllegalArgumentException(
                        "modelType was "
                                + mModelType
                                + " but must be one of: "
                                + "MODEL_TYPE_TENSORFLOW_LITE("
                                + MODEL_TYPE_TENSORFLOW_LITE
                                + "), "
                                + "MODEL_TYPE_EXECUTORCH("
                                + MODEL_TYPE_EXECUTORCH
                                + ")");
            }

            this.mRecommendedNumThreads = recommendedNumThreads;
            Preconditions.checkState(
                    recommendedNumThreads >= 1,
                    "recommended thread number should be greater than or equal to 1");
        }

        /** A {@link KeyValueStore} where the pre-trained model is stored. */
        public @NonNull KeyValueStore getKeyValueStore() {
            return mKeyValueStore;
        }

        /** The key of the table where the corresponding value stores a pre-trained model. */
        public @NonNull String getModelKey() {
            return mModelKey;
        }

        /**
         * The delegate to run model inference. If not set, the default value is {@link
         * #DELEGATE_CPU}.
         */
        public @Delegate int getDelegateType() {
            return mDelegateType;
        }

        /**
         * The type of the pre-trained model. If not set, the default value is {@link
         * #MODEL_TYPE_TENSORFLOW_LITE}.
         */
        public @ModelType int getModelType() {
            return mModelType;
        }

        /**
         * The number of threads used for intraop parallelism on CPU; must be a positive number.
         * Adopters can set this field based on model architecture. The actual thread number
         * depends on system resources and other constraints.
         */
        public @IntRange(from = 1) int getRecommendedNumThreads() {
            return mRecommendedNumThreads;
        }

        @Override
        public boolean equals(@android.annotation.Nullable Object o) {
            // You can override field equality logic by defining either of the methods like:
            // boolean fieldNameEquals(Params other) { ... }
            // boolean fieldNameEquals(FieldType otherValue) { ... }

            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            @SuppressWarnings("unchecked")
            Params that = (Params) o;
            //noinspection PointlessBooleanExpression
            return true
                    && java.util.Objects.equals(mKeyValueStore, that.mKeyValueStore)
                    && java.util.Objects.equals(mModelKey, that.mModelKey)
                    && mDelegateType == that.mDelegateType
                    && mModelType == that.mModelType
                    && mRecommendedNumThreads == that.mRecommendedNumThreads;
        }

        @Override
        public int hashCode() {
            // You can override field hashCode logic by defining methods like:
            // int fieldNameHashCode() { ... }

            int _hash = 1;
            _hash = 31 * _hash + java.util.Objects.hashCode(mKeyValueStore);
            _hash = 31 * _hash + java.util.Objects.hashCode(mModelKey);
            _hash = 31 * _hash + mDelegateType;
            _hash = 31 * _hash + mModelType;
            _hash = 31 * _hash + mRecommendedNumThreads;
            return _hash;
        }

        /** A builder for {@link Params} */
        @SuppressWarnings("WeakerAccess")
        public static final class Builder {

            private @NonNull KeyValueStore mKeyValueStore;
            private @NonNull String mModelKey;
            private @Delegate int mDelegateType;
            private @ModelType int mModelType;
            private @IntRange(from = 1) int mRecommendedNumThreads;

            private long mBuilderFieldsSet = 0L;

            /**
             * Creates a new Builder.
             *
             * @param keyValueStore a {@link KeyValueStore} where the pre-trained model is stored.
             * @param modelKey key of the table where the corresponding value stores a pre-trained
             *     model.
             */
            public Builder(@NonNull KeyValueStore keyValueStore, @NonNull String modelKey) {
                mKeyValueStore = Objects.requireNonNull(keyValueStore);
                mModelKey = Objects.requireNonNull(modelKey);
            }

            /** A {@link KeyValueStore} where the pre-trained model is stored. */
            public @NonNull Builder setKeyValueStore(@NonNull KeyValueStore value) {
                mBuilderFieldsSet |= 0x1;
                mKeyValueStore = value;
                return this;
            }

            /** The key of the table where the corresponding value stores a pre-trained model. */
            public @NonNull Builder setModelKey(@NonNull String value) {
                mBuilderFieldsSet |= 0x2;
                mModelKey = value;
                return this;
            }

            /**
             * The delegate to run model inference. If not set, the default value is {@link
             * #DELEGATE_CPU}.
             */
            public @NonNull Builder setDelegateType(@Delegate int value) {
                mBuilderFieldsSet |= 0x4;
                mDelegateType = value;
                return this;
            }

            /**
             * The type of the pre-trained model. If not set, the default value is {@link
             * #MODEL_TYPE_TENSORFLOW_LITE}.
             */
            public @NonNull Builder setModelType(@ModelType int value) {
                mBuilderFieldsSet |= 0x8;
                mModelType = value;
                return this;
            }

            /**
             * The number of threads used for intraop parallelism on CPU; must be a positive
             * number. Adopters can set this field based on model architecture. The actual thread
             * number depends on system resources and other constraints.
             */
            public @NonNull Builder setRecommendedNumThreads(@IntRange(from = 1) int value) {
                mBuilderFieldsSet |= 0x10;
                mRecommendedNumThreads = value;
                return this;
            }

            /** Builds the instance. This builder should not be touched after calling this! */
            public @NonNull Params build() {
                mBuilderFieldsSet |= 0x20; // Mark builder used

                if ((mBuilderFieldsSet & 0x4) == 0) {
                    mDelegateType = DELEGATE_CPU;
                }
                if ((mBuilderFieldsSet & 0x8) == 0) {
                    mModelType = MODEL_TYPE_TENSORFLOW_LITE;
                }
                if ((mBuilderFieldsSet & 0x10) == 0) {
                    mRecommendedNumThreads = 1;
                }
                Params o =
                        new Params(
                                mKeyValueStore,
                                mModelKey,
                                mDelegateType,
                                mModelType,
                                mRecommendedNumThreads);
                return o;
            }
        }
    }
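
    // A minimal sketch of building Params via the Builder (illustrative only): the model key
    // "my_model" and the assumption that the KeyValueStore is supplied by the platform (here as a
    // method parameter) are placeholders, not part of this API.
    //
    //   static InferenceInput.Params buildParams(KeyValueStore remoteData) {
    //       return new InferenceInput.Params.Builder(remoteData, "my_model")
    //               .setDelegateType(InferenceInput.Params.DELEGATE_CPU)
    //               .setModelType(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE)
    //               .setRecommendedNumThreads(2)
    //               .build();
    //   }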

    /* package-private */ InferenceInput(
            @NonNull Params params,
            @NonNull byte[] data,
            int batchSize,
            @NonNull InferenceOutput expectedOutputStructure) {
        this.mParams = Objects.requireNonNull(params);
        this.mData = Objects.requireNonNull(data);
        this.mBatchSize = batchSize;
        this.mExpectedOutputStructure = Objects.requireNonNull(expectedOutputStructure);
    }

    /** The configuration that controls runtime interpreter behavior. */
    public @NonNull Params getParams() {
        return mParams;
    }

    /**
     * A byte array that holds input data. The inputs should be in the same order as the inputs
     * of the model.
     *
     * <p>For LiteRT, this field is mapped to the inputs of runForMultipleInputsOutputs:
     * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
     *
     * <pre>{@code
     * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
     * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
     * Object[] inputData = {input0, input1, ...};
     * byte[] data = serializeObject(inputData);
     * }</pre>
     *
     * <p>For an ExecuTorch model, this field is a serialized EValue array.
     */
    @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED)
    public @NonNull byte[] getData() {
        return mData;
    }

    /**
     * Note: use {@link InferenceInput#getData()} instead.
     *
     * <p>An array of input data. The inputs should be in the same order as the inputs of the
     * model.
     *
     * <p>For example, if a model takes multiple inputs:
     *
     * <pre>{@code
     * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
     * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
     * Object[] inputData = {input0, input1, ...};
     * }</pre>
     *
     * For LiteRT, this field is mapped to the inputs of runForMultipleInputsOutputs:
     * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
     */
    @SuppressLint("ArrayReturn")
    public @NonNull Object[] getInputData() {
        return (Object[]) deserializeObject(mData);
    }

    /**
     * The number of input examples. Adopters can set this field to run batching inference. The
     * batch size is 1 by default. The batch size should match the input data size.
     */
    public int getBatchSize() {
        return mBatchSize;
    }
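
    // A hedged illustration of batching (the model, its input tensor shape [batch, 4], and the
    // helper method below are hypothetical, not part of this API): with two input examples, the
    // leading dimension of the input and the batch size should both be 2.
    //
    //   static InferenceInput buildBatchedInput(
    //           InferenceInput.Params params, InferenceOutput expectedOutputStructure) {
    //       // Two examples, each with 4 features -> input tensor shape [2, 4].
    //       float[][] features = {{1f, 2f, 3f, 4f}, {5f, 6f, 7f, 8f}};
    //       return new InferenceInput.Builder(params, new Object[] {features}, expectedOutputStructure)
    //               .setBatchSize(2)
    //               .build();
    //   }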

    /**
     * The empty InferenceOutput representing the expected output structure. For LiteRT, the
     * inference code will verify whether this expected output structure matches the model output
     * signature.
     *
     * <p>If a model produces string tensors:
     *
     * <pre>{@code
     * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
     * HashMap<Integer, Object> outputs = new HashMap<>();
     * outputs.put(0, output);
     * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build();
     * }</pre>
     */
    public @NonNull InferenceOutput getExpectedOutputStructure() {
        return mExpectedOutputStructure;
    }

    @Override
    public boolean equals(@android.annotation.Nullable Object o) {
        // You can override field equality logic by defining either of the methods like:
        // boolean fieldNameEquals(InferenceInput other) { ... }
        // boolean fieldNameEquals(FieldType otherValue) { ... }

        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        @SuppressWarnings("unchecked")
        InferenceInput that = (InferenceInput) o;
        //noinspection PointlessBooleanExpression
        return true
                && java.util.Objects.equals(mParams, that.mParams)
                && java.util.Arrays.equals(mData, that.mData)
                && mBatchSize == that.mBatchSize
                && java.util.Objects.equals(
                        mExpectedOutputStructure, that.mExpectedOutputStructure);
    }

    @Override
    public int hashCode() {
        // You can override field hashCode logic by defining methods like:
        // int fieldNameHashCode() { ... }

        int _hash = 1;
        _hash = 31 * _hash + java.util.Objects.hashCode(mParams);
        _hash = 31 * _hash + java.util.Arrays.hashCode(mData);
        _hash = 31 * _hash + mBatchSize;
        _hash = 31 * _hash + java.util.Objects.hashCode(mExpectedOutputStructure);
        return _hash;
    }

    /** A builder for {@link InferenceInput} */
    @SuppressWarnings("WeakerAccess")
    public static final class Builder {

        private @NonNull Params mParams;
        private @NonNull byte[] mData;
        private int mBatchSize;
        private @NonNull InferenceOutput mExpectedOutputStructure =
                new InferenceOutput.Builder().build();

        private long mBuilderFieldsSet = 0L;

        /**
         * Note: use {@link InferenceInput.Builder#Builder(Params, byte[])} instead.
         *
         * <p>Creates a new Builder for LiteRT model inference input. For LiteRT, the inputData
         * field is mapped to the inputs of runForMultipleInputsOutputs:
         * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
         * The inputs should be in the same order as the inputs of the model.
         *
         * <p>For example, if a model takes multiple inputs:
         *
         * <pre>{@code
         * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
         * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
         * Object[] inputData = {input0, input1, ...};
         * }</pre>
         *
         * For LiteRT, the inference code will verify whether the expected output structure
         * matches the model output signature.
         *
         * <p>If a model produces string tensors:
         *
         * <pre>{@code
         * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
         * HashMap<Integer, Object> outputs = new HashMap<>();
         * outputs.put(0, output);
         * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build();
         * }</pre>
         *
         * @param params configuration that controls runtime interpreter behavior.
         * @param inputData an array of input data.
         * @param expectedOutputStructure an empty InferenceOutput representing the expected output
         *     structure.
         */
        public Builder(
                @NonNull Params params,
                @SuppressLint("ArrayReturn") @NonNull Object[] inputData,
                @NonNull InferenceOutput expectedOutputStructure) {
            mParams = Objects.requireNonNull(params);
            mData = serializeObject(Objects.requireNonNull(inputData));
            mExpectedOutputStructure = Objects.requireNonNull(expectedOutputStructure);
        }

        /**
         * Creates a new Builder with the provided runtime parameters and input data.
         *
         * <p>For LiteRT, the inputData field is mapped to the inputs of runForMultipleInputsOutputs:
         * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
         * For example, if a model takes multiple inputs:
         *
         * <pre>{@code
         * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
         * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
         * Object[] data = {input0, input1, ...};
         * byte[] inputData = serializeObject(data);
         * }</pre>
         *
         * <p>For ExecuTorch, the inputData field is mapped to a serialized EValue array.
         *
         * @param params configuration that controls runtime interpreter behavior.
         * @param inputData byte array that holds serialized input data.
         */
        @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED)
        public Builder(@NonNull Params params, @NonNull byte[] inputData) {
            mParams = Objects.requireNonNull(params);
            mData = Objects.requireNonNull(inputData);
        }

        /** The configuration that controls runtime interpreter behavior. */
        public @NonNull Builder setParams(@NonNull Params value) {
            mBuilderFieldsSet |= 0x1;
            mParams = value;
            return this;
        }

        /**
         * A byte array that holds input data. The inputs should be in the same order as the
         * inputs of the model.
         *
         * <p>For LiteRT, this field is mapped to the inputs of runForMultipleInputsOutputs:
         * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
         *
         * <pre>{@code
         * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
         * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
         * Object[] data = {input0, input1, ...};
         * byte[] inputData = serializeObject(data);
         * }</pre>
         *
         * <p>For an ExecuTorch model, this field is a serialized EValue array.
         */
        @FlaggedApi(Flags.FLAG_EXECUTORCH_INFERENCE_API_ENABLED)
        public @NonNull Builder setInputData(@NonNull byte[] value) {
            mBuilderFieldsSet |= 0x2;
            mData = value;
            return this;
        }

        /**
         * Note: use {@link InferenceInput.Builder#setInputData(byte[])} instead.
         *
         * <p>An array of input data. The inputs should be in the same order as the inputs of the
         * model.
         *
         * <p>For example, if a model takes multiple inputs:
         *
         * <pre>{@code
         * String[] input0 = {"foo", "bar"}; // string tensor shape is [2].
         * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3].
         * Object[] inputData = {input0, input1, ...};
         * }</pre>
         *
         * For LiteRT, this field is mapped to the inputs of runForMultipleInputsOutputs:
         * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9
         */
        public @NonNull Builder setInputData(@NonNull Object... value) {
            mBuilderFieldsSet |= 0x2;
            mData = serializeObject(value);
            return this;
        }

        /**
         * The number of input examples. Adopters can set this field to run batching inference.
         * The batch size is 1 by default. The batch size should match the input data size.
         */
        public @NonNull Builder setBatchSize(int value) {
            mBuilderFieldsSet |= 0x4;
            mBatchSize = value;
            return this;
        }

        /**
         * The empty InferenceOutput representing the expected output structure. It is only
         * required for LiteRT models. For LiteRT, the inference code will verify whether this
         * expected output structure matches the model output signature.
         *
         * <p>If a model produces string tensors:
         *
         * <pre>{@code
         * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
         * HashMap<Integer, Object> outputs = new HashMap<>();
         * outputs.put(0, output);
         * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build();
         * }</pre>
         */
        public @NonNull Builder setExpectedOutputStructure(@NonNull InferenceOutput value) {
            mBuilderFieldsSet |= 0x8;
            mExpectedOutputStructure = value;
            return this;
        }

        /** @hide */
        private void validateInputData() {
            Preconditions.checkArgument(
                    mData.length > 0, "Input data should not be empty for InferenceInput.");
        }

        /** @hide */
        private void validateOutputStructure() {
            // ExecuTorch models don't require setting the output structure.
            if (mParams.getModelType() != Params.MODEL_TYPE_TENSORFLOW_LITE) {
                return;
            }
            Preconditions.checkArgument(
                    !mExpectedOutputStructure.getDataOutputs().isEmpty()
                            || mExpectedOutputStructure.getData().length > 0,
                    "ExpectedOutputStructure field is required for TensorflowLite model.");
        }

        /** Builds the instance. This builder should not be touched after calling this! */
        public @NonNull InferenceInput build() {

            mBuilderFieldsSet |= 0x10; // Mark builder used

            if ((mBuilderFieldsSet & 0x4) == 0) {
                mBatchSize = 1;
            }
            validateInputData();
            validateOutputStructure();
            InferenceInput o =
                    new InferenceInput(mParams, mData, mBatchSize, mExpectedOutputStructure);
            return o;
        }
    }
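
    // A minimal end-to-end sketch for a LiteRT model (illustrative only; the model key, tensor
    // shapes, and the helper method below are hypothetical, and the KeyValueStore is assumed to
    // be supplied by the platform). The resulting InferenceInput is what gets passed to
    // {@link ModelManager#run}.
    //
    //   static InferenceInput buildLiteRtInput(KeyValueStore remoteData) {
    //       InferenceInput.Params params =
    //               new InferenceInput.Params.Builder(remoteData, "my_model").build();
    //       float[][] input0 = {{1f, 2f, 3f}}; // input tensor shape is [1, 3].
    //       float[][] output0 = new float[1][2]; // output tensor shape is [1, 2].
    //       java.util.HashMap<Integer, Object> outputs = new java.util.HashMap<>();
    //       outputs.put(0, output0);
    //       InferenceOutput expectedOutputStructure =
    //               new InferenceOutput.Builder().setDataOutputs(outputs).build();
    //       return new InferenceInput.Builder(params, new Object[] {input0}, expectedOutputStructure)
    //               .build();
    //   }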
}