/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
import org.tensorflow.lite.InterpreterImpl.Options;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.nnapi.NnApiDelegate;

/**
 * An internal wrapper that wraps the native interpreter and controls model execution.
 *
 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
 * explicitly freed by invoking the {@link #close()} method when the {@code
 * NativeInterpreterWrapper} object is no longer needed.
 *
 * <p>Note: This class is not thread safe.
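 *
 * <p>For illustration only (this class is package-private; applications normally use {@code
 * Interpreter} or {@code InterpreterApi} instead, and the model path and tensor shapes below are
 * hypothetical), a minimal sketch of the intended lifecycle:
 *
 * <pre>{@code
 * // The wrapper is AutoCloseable, so try-with-resources guarantees close() runs.
 * try (NativeInterpreterWrapper wrapper =
 *     new NativeInterpreterWrapper("/data/local/tmp/model.tflite")) {
 *   Object[] inputs = {new float[1][4]};            // one input of shape [1, 4]
 *   Map<Integer, Object> outputs = new HashMap<>();
 *   outputs.put(0, new float[1][2]);                // pre-allocated buffer for output 0
 *   wrapper.run(inputs, outputs);
 * }
 * }</pre>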
 */
class NativeInterpreterWrapper implements AutoCloseable {

  // This is changed to RuntimeFlavor.SYSTEM for TF Lite in Google Play Services.
  private static final RuntimeFlavor RUNTIME_FLAVOR = RuntimeFlavor.APPLICATION;

  NativeInterpreterWrapper(String modelPath) {
    this(modelPath, /* options= */ null);
  }

  NativeInterpreterWrapper(ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  NativeInterpreterWrapper(String modelPath, InterpreterImpl.Options options) {
    TensorFlowLite.init();
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModel(modelPath, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options) {
    TensorFlowLite.init();
    if (buffer == null
        || (!(buffer instanceof MappedByteBuffer)
            && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
      throw new IllegalArgumentException(
          "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct "
              + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
    }
    this.modelByteBuffer = buffer;
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  private void init(long errorHandle, long modelHandle, InterpreterImpl.Options options) {
    if (options == null) {
      options = new InterpreterImpl.Options();
    }
    this.errorHandle = errorHandle;
    this.modelHandle = modelHandle;
    // First create the interpreter without delegates. We need an interpreter in order to figure
    // out whether the model contains any unresolved flex ops, and creating the interpreter with
    // delegates might fail if there are any unresolved flex ops.
    // (Alternatively, we could determine this without needing to recreate the interpreter
    // by passing the tflite::Model into here, and then traversing that?)
    ArrayList<Long> delegateHandles = new ArrayList<>();
    boolean useXnnpack = true;
    if (options.useXNNPACK != null) {
      useXnnpack = options.useXNNPACK;
    }
    this.interpreterHandle =
        createInterpreter(
            modelHandle, errorHandle, options.getNumThreads(), useXnnpack, delegateHandles);
    this.originalGraphHasUnresolvedFlexOp = hasUnresolvedFlexOp(interpreterHandle);
    addDelegates(options);
    initDelegatesWithInterpreterFactory();
    delegateHandles.ensureCapacity(delegates.size());
    for (Delegate delegate : delegates) {
      delegateHandles.add(delegate.getNativeHandle());
    }
    if (!delegateHandles.isEmpty()) {
      // If there are any delegates enabled, recreate the interpreter with those delegates.
      delete(/* errorHandle= */ 0, /* modelHandle= */ 0, this.interpreterHandle);
      this.interpreterHandle =
          createInterpreter(
              modelHandle, errorHandle, options.getNumThreads(), useXnnpack, delegateHandles);
    }
    if (options.allowFp16PrecisionForFp32 != null) {
      allowFp16PrecisionForFp32(interpreterHandle, options.allowFp16PrecisionForFp32);
    }
    if (options.allowBufferHandleOutput != null) {
      allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput);
    }
    if (options.isCancellable()) {
      this.cancellationFlagHandle = createCancellationFlag(interpreterHandle);
    }
    this.inputTensors = new TensorImpl[getInputCount(interpreterHandle)];
    this.outputTensors = new TensorImpl[getOutputCount(interpreterHandle)];
    allocateTensors(interpreterHandle, errorHandle);
    this.isMemoryAllocated = true;
  }

  /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
  @Override
  public void close() {
    // Close the tensors first as they may reference the native interpreter.
    for (int i = 0; i < inputTensors.length; ++i) {
      if (inputTensors[i] != null) {
        inputTensors[i].close();
        inputTensors[i] = null;
      }
    }
    for (int i = 0; i < outputTensors.length; ++i) {
      if (outputTensors[i] != null) {
        outputTensors[i].close();
        outputTensors[i] = null;
      }
    }
    delete(errorHandle, modelHandle, interpreterHandle);
    deleteCancellationFlag(cancellationFlagHandle);
    errorHandle = 0;
    modelHandle = 0;
    interpreterHandle = 0;
    cancellationFlagHandle = 0;
    modelByteBuffer = null;
    inputsIndexes = null;
    outputsIndexes = null;
    isMemoryAllocated = false;
    delegates.clear();
    for (Delegate ownedDelegate : ownedDelegates) {
      ownedDelegate.close();
    }
    ownedDelegates.clear();
  }

  /** Runs model inference based on the SignatureDef provided through {@code signatureKey}. */
  public void runSignature(
      Map<String, Object> inputs, Map<String, Object> outputs, String signatureKey) {
    inferenceDurationNanoseconds = -1;
    if (inputs == null || inputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null) {
      throw new IllegalArgumentException("Input error: Outputs should not be null.");
    }
    NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey);
    int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex();
    if (subgraphIndex == 0) {
      // Map inputs/outputs to input/output indexes.
      Object[] inputsList = new Object[inputs.size()];
      for (Map.Entry<String, Object> input : inputs.entrySet()) {
        inputsList[signatureRunnerWrapper.getInputIndex(input.getKey())] = input.getValue();
      }
      Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>();
      for (Map.Entry<String, Object> output : outputs.entrySet()) {
        outputsWithOutputIndex.put(
            signatureRunnerWrapper.getOutputIndex(output.getKey()), output.getValue());
      }
      run(inputsList, outputsWithOutputIndex);
      return;
    }

    for (Map.Entry<String, Object> input : inputs.entrySet()) {
      TensorImpl tensor = getInputTensor(input.getKey(), signatureKey);
      int[] newShape = tensor.getInputShapeIfDifferent(input.getValue());
      if (newShape != null) {
        signatureRunnerWrapper.resizeInput(input.getKey(), newShape);
      }
    }

    signatureRunnerWrapper.allocateTensorsIfNeeded();

    for (Map.Entry<String, Object> input : inputs.entrySet()) {
      signatureRunnerWrapper.getInputTensor(input.getKey()).setTo(input.getValue());
    }

    long inferenceStartNanos = System.nanoTime();
    signatureRunnerWrapper.invoke();
    long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;

    for (Map.Entry<String, Object> output : outputs.entrySet()) {
      // Null output placeholders are allowed and ignored.
      if (output.getValue() != null) {
        signatureRunnerWrapper.getOutputTensor(output.getKey()).copyTo(output.getValue());
      }
    }

    // Only set if the entire operation succeeds.
    this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
  }
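
  // For illustration only: a minimal sketch of a runSignature() call. The signature key
  // "serving_default" and the tensor names "x"/"y" are hypothetical; real keys come from
  // getSignatureKeys() and real names from getSignatureInputs()/getSignatureOutputs().
  //
  //   Map<String, Object> inputs = new HashMap<>();
  //   inputs.put("x", new float[][] {{1.0f, 2.0f}});
  //   Map<String, Object> outputs = new HashMap<>();
  //   outputs.put("y", new float[1][2]); // pre-allocated; null placeholders are also allowed
  //   wrapper.runSignature(inputs, outputs, "serving_default");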

  /** Sets inputs, runs model inference, and returns outputs. */
  void run(Object[] inputs, Map<Integer, Object> outputs) {
    inferenceDurationNanoseconds = -1;
    if (inputs == null || inputs.length == 0) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null) {
      throw new IllegalArgumentException("Input error: Outputs should not be null.");
    }

    // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
    // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
    // flush all resizes, avoiding redundant allocations.
    for (int i = 0; i < inputs.length; ++i) {
      TensorImpl tensor = getInputTensor(i);
      int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
      if (newShape != null) {
        resizeInput(i, newShape);
      }
    }

    boolean allocatedTensors = allocateTensorsIfNeeded();

    for (int i = 0; i < inputs.length; ++i) {
      getInputTensor(i).setTo(inputs[i]);
    }

    long inferenceStartNanos = System.nanoTime();
    run(interpreterHandle, errorHandle);
    long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;

    // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
    if (allocatedTensors) {
      for (TensorImpl outputTensor : outputTensors) {
        if (outputTensor != null) {
          outputTensor.refreshShape();
        }
      }
    }
    for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
      // Null output placeholders are allowed and ignored.
      if (output.getValue() != null) {
        getOutputTensor(output.getKey()).copyTo(output.getValue());
      }
    }

    // Only set if the entire operation succeeds.
    this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
  }

  /** Resizes the dimensions of a specific input. */
  void resizeInput(int idx, int[] dims) {
    resizeInput(idx, dims, false);
  }

  /** Resizes the dimensions of a specific input. */
  void resizeInput(int idx, int[] dims, boolean strict) {
    if (resizeInput(interpreterHandle, errorHandle, idx, dims, strict)) {
      // Tensor allocation is deferred until either an explicit `allocateTensors()` call or
      // `invoke()`, avoiding redundant allocations if multiple tensors are simultaneously
      // resized.
      isMemoryAllocated = false;
      if (inputTensors[idx] != null) {
        inputTensors[idx].refreshShape();
      }
    }
  }
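
  // Because allocation is deferred, several inputs can be resized back to back and the
  // re-allocation cost is paid once. For illustration only (indices and shapes are
  // hypothetical):
  //
  //   wrapper.resizeInput(0, new int[] {2, 224, 224, 3});
  //   wrapper.resizeInput(1, new int[] {2, 10});
  //   wrapper.allocateTensors(); // one allocation pass covering both resizes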

  /** Triggers explicit allocation of tensors. */
  void allocateTensors() {
    allocateTensorsIfNeeded();
  }

  /**
   * Allocates tensor memory space in the given subgraph and returns true if an allocation
   * actually happened.
   */
  private boolean allocateTensorsIfNeeded() {
    if (isMemoryAllocated) {
      return false;
    }

    isMemoryAllocated = true;
    allocateTensors(interpreterHandle, errorHandle);
    for (TensorImpl outputTensor : outputTensors) {
      if (outputTensor != null) {
        outputTensor.refreshShape();
      }
    }
    return true;
  }

  /** Gets the index of an input given its name. */
  int getInputIndex(String name) {
    if (inputsIndexes == null) {
      String[] names = getInputNames(interpreterHandle);
      inputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          inputsIndexes.put(names[i], i);
        }
      }
    }
    if (inputsIndexes.containsKey(name)) {
      return inputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any input. Names of inputs and their "
                  + "indexes are %s",
              name, inputsIndexes));
    }
  }

  /** Gets the index of an output given its name. */
  int getOutputIndex(String name) {
    if (outputsIndexes == null) {
      String[] names = getOutputNames(interpreterHandle);
      outputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          outputsIndexes.put(names[i], i);
        }
      }
    }
    if (outputsIndexes.containsKey(name)) {
      return outputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any output. Names of outputs and their "
                  + "indexes are %s",
              name, outputsIndexes));
    }
  }

  /**
   * Gets the last inference duration in nanoseconds. Returns null if there is no previous
   * inference run or if the last inference run failed.
   */
  Long getLastNativeInferenceDurationNanoseconds() {
    return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
  }

  /** Gets the number of input tensors. */
  int getInputTensorCount() {
    return inputTensors.length;
  }

  /**
   * Gets the input {@link TensorImpl} for the provided input index.
   *
   * @throws IllegalArgumentException if the input index is invalid.
   */
  TensorImpl getInputTensor(int index) {
    if (index < 0 || index >= inputTensors.length) {
      throw new IllegalArgumentException("Invalid input Tensor index: " + index);
    }
    TensorImpl inputTensor = inputTensors[index];
    if (inputTensor == null) {
      inputTensor =
          inputTensors[index] =
              TensorImpl.fromIndex(
                  interpreterHandle, getInputTensorIndex(interpreterHandle, index));
    }
    return inputTensor;
  }

  /**
   * Gets the input {@link TensorImpl} given the tensor name and the signature method.
   *
   * @throws IllegalArgumentException if the input name is invalid.
   */
  TensorImpl getInputTensor(String inputName, String signatureKey) {
    if (inputName == null) {
      throw new IllegalArgumentException("Invalid input tensor name provided (null)");
    }
    NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey);
    int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex();
    if (subgraphIndex > 0) {
      return signatureRunnerWrapper.getInputTensor(inputName);
    }

    int inputIndex = signatureRunnerWrapper.getInputIndex(inputName);
    return getInputTensor(inputIndex);
  }

  /** Gets the keys of the SignatureDefs available in the model, if any. */
  public String[] getSignatureKeys() {
    return getSignatureKeys(interpreterHandle);
  }

  /** Gets the list of SignatureDef inputs for method {@code signatureKey}. */
  String[] getSignatureInputs(String signatureKey) {
    return getSignatureRunnerWrapper(signatureKey).inputNames();
  }

  /** Gets the list of SignatureDef outputs for method {@code signatureKey}. */
  String[] getSignatureOutputs(String signatureKey) {
    return getSignatureRunnerWrapper(signatureKey).outputNames();
  }

  /** Gets the number of output tensors. */
  int getOutputTensorCount() {
    return outputTensors.length;
  }

  /**
   * Gets the output {@link TensorImpl} for the provided output index.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  TensorImpl getOutputTensor(int index) {
    if (index < 0 || index >= outputTensors.length) {
      throw new IllegalArgumentException("Invalid output Tensor index: " + index);
    }
    TensorImpl outputTensor = outputTensors[index];
    if (outputTensor == null) {
      outputTensor =
          outputTensors[index] =
              TensorImpl.fromIndex(
                  interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
    }
    return outputTensor;
  }

  /**
   * Gets the output {@link TensorImpl} given the tensor name and the signature method.
   *
   * @throws IllegalArgumentException if the output name is invalid.
   */
  TensorImpl getOutputTensor(String outputName, String signatureKey) {
    if (outputName == null) {
      throw new IllegalArgumentException("Invalid output tensor name provided (null)");
    }
    NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey);
    int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex();
    if (subgraphIndex > 0) {
      return signatureRunnerWrapper.getOutputTensor(outputName);
    }

    int outputIndex = signatureRunnerWrapper.getOutputIndex(outputName);
    return getOutputTensor(outputIndex);
  }

  /** Gets the number of ops in the execution plan. */
  int getExecutionPlanLength() {
    return getExecutionPlanLength(interpreterHandle);
  }

  /**
   * Sets the internal cancellation flag. If true, the interpreter will try to interrupt any
   * invocation between ops.
   */
  void setCancelled(boolean value) {
    if (cancellationFlagHandle == 0) {
      throw new IllegalStateException(
          "Cannot cancel the inference. Have you called InterpreterApi.Options.setCancellable?");
    }
    setCancelled(interpreterHandle, cancellationFlagHandle, value);
  }
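
  // For illustration only: a sketch of cooperative cancellation, assuming the interpreter was
  // created with InterpreterApi.Options.setCancellable(true) so that cancellationFlagHandle is
  // non-zero. The 5-second budget is hypothetical.
  //
  //   Thread watchdog = new Thread(() -> {
  //     try {
  //       Thread.sleep(5000);
  //       wrapper.setCancelled(true); // interrupts the in-flight run() between ops
  //     } catch (InterruptedException e) {
  //       Thread.currentThread().interrupt();
  //     }
  //   });
  //   watchdog.start();
  //   wrapper.run(inputs, outputs);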

  // Add all the delegates specified in the options (other than XNNPACK) to this.delegates.
  private void addDelegates(InterpreterImpl.Options options) {
    // First add the flex delegate if necessary. This ensures the graph is fully resolved before
    // applying other delegates.
    if (originalGraphHasUnresolvedFlexOp) {
      Delegate optionalFlexDelegate = maybeCreateFlexDelegate(options.getDelegates());
      if (optionalFlexDelegate != null) {
        ownedDelegates.add(optionalFlexDelegate);
        delegates.add(optionalFlexDelegate);
      }
    }
    // Now add the user-supplied delegates.
    addUserProvidedDelegates(options);
    for (DelegateFactory delegateFactory : options.getDelegateFactories()) {
      Delegate delegate = delegateFactory.create(RUNTIME_FLAVOR);
      ownedDelegates.add(delegate);
      delegates.add(delegate);
    }
    if (options.getUseNNAPI()) {
      NnApiDelegate optionalNnApiDelegate = new NnApiDelegate();
      ownedDelegates.add(optionalNnApiDelegate);
      delegates.add(optionalNnApiDelegate);
    }
  }

  private void addUserProvidedDelegates(Options options) {
    for (Delegate delegate : options.getDelegates()) {
      // NnApiDelegate is compatible with both the system and built-in runtimes and therefore can
      // be added directly even when using TF Lite from the system.
      if (options.getRuntime() != TfLiteRuntime.FROM_APPLICATION_ONLY
          && !(delegate instanceof NnApiDelegate)) {
        throw new IllegalArgumentException(
            "Instantiated delegates (other than NnApiDelegate) are not allowed when using TF Lite"
                + " from Google Play Services. Please use"
                + " InterpreterApi.Options.addDelegateFactory() with an appropriate DelegateFactory"
                + " instead.");
      }
      delegates.add(delegate);
    }
  }
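
  // For illustration only: the factory route that the error message above points at. A
  // DelegateFactory receives the RuntimeFlavor, so it can construct a delegate compatible with
  // whichever runtime is actually in use; returning an NnApiDelegate here is just one option.
  //
  //   InterpreterApi.Options options = new InterpreterApi.Options();
  //   options.addDelegateFactory(runtimeFlavor -> new NnApiDelegate());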

  // Complete the initialization of any delegates that require an InterpreterFactoryApi instance.
  private void initDelegatesWithInterpreterFactory() {
    InterpreterFactoryApi interpreterFactoryApi = new InterpreterFactoryImpl();
    for (Delegate delegate : delegates) {
      if (delegate instanceof NnApiDelegate) {
        ((NnApiDelegate) delegate).initWithInterpreterFactoryApi(interpreterFactoryApi);
      }
    }
  }

  private NativeSignatureRunnerWrapper getSignatureRunnerWrapper(String signatureKey) {
    if (signatureRunnerMap == null) {
      signatureRunnerMap = new HashMap<>();
    }
    if (!signatureRunnerMap.containsKey(signatureKey)) {
      signatureRunnerMap.put(
          signatureKey,
          new NativeSignatureRunnerWrapper(interpreterHandle, errorHandle, signatureKey));
    }
    return signatureRunnerMap.get(signatureKey);
  }

  private static Delegate maybeCreateFlexDelegate(List<Delegate> delegates) {
    try {
      Class<?> clazz = Class.forName("org.tensorflow.lite.flex.FlexDelegate");
      // No need to create the Flex delegate if one has already been provided.
      for (Delegate delegate : delegates) {
        if (clazz.isInstance(delegate)) {
          return null;
        }
      }
      return (Delegate) clazz.getConstructor().newInstance();
    } catch (ClassNotFoundException
        | IllegalAccessException
        | IllegalArgumentException
        | InstantiationException
        | InvocationTargetException
        | NoSuchMethodException
        | SecurityException e) {
      // The error will propagate when tensors are allocated.
      return null;
    }
  }

  private static final int ERROR_BUFFER_SIZE = 512;

  long errorHandle;

  long interpreterHandle;

  private long modelHandle;

  private long cancellationFlagHandle = 0;

  @UsedByReflection("nativeinterpreterwrapper_jni.cc")
  private long inferenceDurationNanoseconds = -1;

  private ByteBuffer modelByteBuffer;

  // Lazily constructed maps of input and output names to input and output Tensor indexes.
  private Map<String, Integer> inputsIndexes;
  private Map<String, Integer> outputsIndexes;

  // A map from signature key to its native wrapper object.
  private Map<String, NativeSignatureRunnerWrapper> signatureRunnerMap;

  // Lazily constructed and populated arrays of input and output Tensor wrappers.
  private TensorImpl[] inputTensors;
  private TensorImpl[] outputTensors;

  // Whether the subgraph's tensor memory space is allocated.
  private boolean isMemoryAllocated = false;

  // Whether the model has any Flex custom ops that can't be resolved by the OpResolver.
  private boolean originalGraphHasUnresolvedFlexOp = false;

  // As the Java Delegate owns the native delegate instance, we keep a strong ref to any injected
  // delegates for safety.
  private final List<Delegate> delegates = new ArrayList<>();

  // List of owned delegates that must be closed when the interpreter is closed.
  private final List<Delegate> ownedDelegates = new ArrayList<>();

  private static native void run(long interpreterHandle, long errorHandle);

  private static native boolean resizeInput(
      long interpreterHandle, long errorHandle, int inputIdx, int[] dims, boolean strict);

  private static native long allocateTensors(long interpreterHandle, long errorHandle);

  private static native String[] getSignatureKeys(long interpreterHandle);

  private static native void setCancelled(
      long interpreterHandle, long cancellationFlagHandle, boolean value);

  private static native boolean hasUnresolvedFlexOp(long interpreterHandle);

  private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

  private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);

  private static native int getInputCount(long interpreterHandle);

  private static native int getOutputCount(long interpreterHandle);

  private static native int getExecutionPlanLength(long interpreterHandle);

  private static native String[] getInputNames(long interpreterHandle);

  private static native String[] getOutputNames(long interpreterHandle);

  private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);

  private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

  private static native long createErrorReporter(int size);

  private static native long createModel(String modelPathOrBuffer, long errorHandle);

  private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

  private static native long createInterpreter(
      long modelHandle,
      long errorHandle,
      int numThreads,
      boolean useXnnpack,
      List<Long> delegateHandles);

  private static native long createCancellationFlag(long interpreterHandle);

  private static native long deleteCancellationFlag(long cancellationFlagHandle);

  private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
}