1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import java.lang.reflect.InvocationTargetException; 19 import java.nio.ByteBuffer; 20 import java.nio.ByteOrder; 21 import java.nio.MappedByteBuffer; 22 import java.util.ArrayList; 23 import java.util.HashMap; 24 import java.util.List; 25 import java.util.Map; 26 import java.util.TreeMap; 27 import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime; 28 import org.tensorflow.lite.InterpreterImpl.Options; 29 import org.tensorflow.lite.annotations.UsedByReflection; 30 import org.tensorflow.lite.nnapi.NnApiDelegate; 31 32 /** 33 * An internal wrapper that wraps native interpreter and controls model execution. 34 * 35 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be 36 * explicitly freed by invoking the {@link #close()} method when the {@code 37 * NativeInterpreterWrapper} object is no longer needed. 38 * 39 * <p>Note: This class is not thread safe. 40 */ 41 class NativeInterpreterWrapper implements AutoCloseable { 42 43 // This is changed to RuntimeFlavor.SYSTEM for TF Lite in Google Play Services. 44 private static final RuntimeFlavor RUNTIME_FLAVOR = RuntimeFlavor.APPLICATION; 45 NativeInterpreterWrapper(String modelPath)46 NativeInterpreterWrapper(String modelPath) { 47 this(modelPath, /* options= */ null); 48 } 49 NativeInterpreterWrapper(ByteBuffer byteBuffer)50 NativeInterpreterWrapper(ByteBuffer byteBuffer) { 51 this(byteBuffer, /* options= */ null); 52 } 53 NativeInterpreterWrapper(String modelPath, InterpreterImpl.Options options)54 NativeInterpreterWrapper(String modelPath, InterpreterImpl.Options options) { 55 TensorFlowLite.init(); 56 long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); 57 long modelHandle = createModel(modelPath, errorHandle); 58 init(errorHandle, modelHandle, options); 59 } 60 NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options)61 NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options) { 62 TensorFlowLite.init(); 63 if (buffer == null 64 || (!(buffer instanceof MappedByteBuffer) 65 && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) { 66 throw new IllegalArgumentException( 67 "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct " 68 + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content."); 69 } 70 this.modelByteBuffer = buffer; 71 long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); 72 long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); 73 init(errorHandle, modelHandle, options); 74 } 75 init(long errorHandle, long modelHandle, InterpreterImpl.Options options)76 private void init(long errorHandle, long modelHandle, InterpreterImpl.Options options) { 77 if (options == null) { 78 options = new InterpreterImpl.Options(); 79 } 80 this.errorHandle = errorHandle; 81 this.modelHandle = modelHandle; 82 // First create the interpreter without delegates. We need an interpreter in order to figure 83 // out whether the model contains any unresolved flex ops, and creating the interpreter with 84 // delegates might fail if there are any unresolved flex ops. 85 // (Alternatively, we could determine this without needing to recreate the interpreter 86 // by passing the tflite::Model in to here, and then traversing that?) 87 ArrayList<Long> delegateHandles = new ArrayList<>(); 88 boolean useXnnpack = true; 89 if (options.useXNNPACK != null) { 90 useXnnpack = options.useXNNPACK; 91 } 92 this.interpreterHandle = 93 createInterpreter( 94 modelHandle, errorHandle, options.getNumThreads(), useXnnpack, delegateHandles); 95 this.originalGraphHasUnresolvedFlexOp = hasUnresolvedFlexOp(interpreterHandle); 96 addDelegates(options); 97 initDelegatesWithInterpreterFactory(); 98 delegateHandles.ensureCapacity(delegates.size()); 99 for (Delegate delegate : delegates) { 100 delegateHandles.add(delegate.getNativeHandle()); 101 } 102 if (!delegateHandles.isEmpty()) { 103 // If there are any delegates enabled, recreate the interpreter with those delegates. 104 delete(/* errorHandle= */ 0, /* modelHandle= */ 0, this.interpreterHandle); 105 this.interpreterHandle = 106 createInterpreter( 107 modelHandle, errorHandle, options.getNumThreads(), useXnnpack, delegateHandles); 108 } 109 if (options.allowFp16PrecisionForFp32 != null) { 110 allowFp16PrecisionForFp32(interpreterHandle, options.allowFp16PrecisionForFp32); 111 } 112 if (options.allowBufferHandleOutput != null) { 113 allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput); 114 } 115 if (options.isCancellable()) { 116 this.cancellationFlagHandle = createCancellationFlag(interpreterHandle); 117 } 118 this.inputTensors = new TensorImpl[getInputCount(interpreterHandle)]; 119 this.outputTensors = new TensorImpl[getOutputCount(interpreterHandle)]; 120 if (options.allowFp16PrecisionForFp32 != null) { 121 allowFp16PrecisionForFp32(interpreterHandle, options.allowFp16PrecisionForFp32); 122 } 123 if (options.allowBufferHandleOutput != null) { 124 allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput); 125 } 126 allocateTensors(interpreterHandle, errorHandle); 127 this.isMemoryAllocated = true; 128 } 129 130 /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ 131 @Override close()132 public void close() { 133 // Close the tensors first as they may reference the native interpreter. 134 for (int i = 0; i < inputTensors.length; ++i) { 135 if (inputTensors[i] != null) { 136 inputTensors[i].close(); 137 inputTensors[i] = null; 138 } 139 } 140 for (int i = 0; i < outputTensors.length; ++i) { 141 if (outputTensors[i] != null) { 142 outputTensors[i].close(); 143 outputTensors[i] = null; 144 } 145 } 146 delete(errorHandle, modelHandle, interpreterHandle); 147 deleteCancellationFlag(cancellationFlagHandle); 148 errorHandle = 0; 149 modelHandle = 0; 150 interpreterHandle = 0; 151 cancellationFlagHandle = 0; 152 modelByteBuffer = null; 153 inputsIndexes = null; 154 outputsIndexes = null; 155 isMemoryAllocated = false; 156 delegates.clear(); 157 for (Delegate ownedDelegate : ownedDelegates) { 158 ownedDelegate.close(); 159 } 160 ownedDelegates.clear(); 161 } 162 163 /** Runs model inference based on SignatureDef provided through {@code signatureKey}. */ runSignature( Map<String, Object> inputs, Map<String, Object> outputs, String signatureKey)164 public void runSignature( 165 Map<String, Object> inputs, Map<String, Object> outputs, String signatureKey) { 166 inferenceDurationNanoseconds = -1; 167 if (inputs == null || inputs.isEmpty()) { 168 throw new IllegalArgumentException("Input error: Inputs should not be null or empty."); 169 } 170 if (outputs == null) { 171 throw new IllegalArgumentException("Input error: Outputs should not be null."); 172 } 173 NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey); 174 int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex(); 175 if (subgraphIndex == 0) { 176 // Map inputs/output to input indexes. 177 Object[] inputsList = new Object[inputs.size()]; 178 for (Map.Entry<String, Object> input : inputs.entrySet()) { 179 inputsList[signatureRunnerWrapper.getInputIndex(input.getKey())] = input.getValue(); 180 } 181 Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>(); 182 for (Map.Entry<String, Object> output : outputs.entrySet()) { 183 outputsWithOutputIndex.put( 184 signatureRunnerWrapper.getOutputIndex(output.getKey()), output.getValue()); 185 } 186 run(inputsList, outputsWithOutputIndex); 187 return; 188 } 189 190 for (Map.Entry<String, Object> input : inputs.entrySet()) { 191 TensorImpl tensor = getInputTensor(input.getKey(), signatureKey); 192 int[] newShape = tensor.getInputShapeIfDifferent(input.getValue()); 193 if (newShape != null) { 194 signatureRunnerWrapper.resizeInput(input.getKey(), newShape); 195 } 196 } 197 198 signatureRunnerWrapper.allocateTensorsIfNeeded(); 199 200 for (Map.Entry<String, Object> input : inputs.entrySet()) { 201 signatureRunnerWrapper.getInputTensor(input.getKey()).setTo(input.getValue()); 202 } 203 204 long inferenceStartNanos = System.nanoTime(); 205 signatureRunnerWrapper.invoke(); 206 long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos; 207 208 for (Map.Entry<String, Object> output : outputs.entrySet()) { 209 // Null output placeholders are allowed and ignored. 210 if (output.getValue() != null) { 211 signatureRunnerWrapper.getOutputTensor(output.getKey()).copyTo(output.getValue()); 212 } 213 } 214 215 // Only set if the entire operation succeeds. 216 this.inferenceDurationNanoseconds = inferenceDurationNanoseconds; 217 } 218 219 /** Sets inputs, runs model inference and returns outputs. */ run(Object[] inputs, Map<Integer, Object> outputs)220 void run(Object[] inputs, Map<Integer, Object> outputs) { 221 inferenceDurationNanoseconds = -1; 222 if (inputs == null || inputs.length == 0) { 223 throw new IllegalArgumentException("Input error: Inputs should not be null or empty."); 224 } 225 if (outputs == null) { 226 throw new IllegalArgumentException("Input error: Outputs should not be null."); 227 } 228 229 // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs. 230 // Rather than forcing an immediate resize + allocation if an input's shape differs, we first 231 // flush all resizes, avoiding redundant allocations. 232 for (int i = 0; i < inputs.length; ++i) { 233 TensorImpl tensor = getInputTensor(i); 234 int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]); 235 if (newShape != null) { 236 resizeInput(i, newShape); 237 } 238 } 239 240 boolean allocatedTensors = allocateTensorsIfNeeded(); 241 242 for (int i = 0; i < inputs.length; ++i) { 243 getInputTensor(i).setTo(inputs[i]); 244 } 245 246 long inferenceStartNanos = System.nanoTime(); 247 run(interpreterHandle, errorHandle); 248 long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos; 249 250 // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes. 251 if (allocatedTensors) { 252 for (TensorImpl outputTensor : outputTensors) { 253 if (outputTensor != null) { 254 outputTensor.refreshShape(); 255 } 256 } 257 } 258 for (Map.Entry<Integer, Object> output : outputs.entrySet()) { 259 // Null output placeholders are allowed and ignored. 260 if (output.getValue() != null) { 261 getOutputTensor(output.getKey()).copyTo(output.getValue()); 262 } 263 } 264 265 // Only set if the entire operation succeeds. 266 this.inferenceDurationNanoseconds = inferenceDurationNanoseconds; 267 } 268 269 /** Resizes dimensions of a specific input. */ resizeInput(int idx, int[] dims)270 void resizeInput(int idx, int[] dims) { 271 resizeInput(idx, dims, false); 272 } 273 274 /** Resizes dimensions of a specific input. */ resizeInput(int idx, int[] dims, boolean strict)275 void resizeInput(int idx, int[] dims, boolean strict) { 276 if (resizeInput(interpreterHandle, errorHandle, idx, dims, strict)) { 277 // Tensor allocation is deferred until either an explicit `allocateTensors()` call or 278 // `invoke()` avoiding redundant allocations if multiple tensors are simultaneosly resized. 279 isMemoryAllocated = false; 280 if (inputTensors[idx] != null) { 281 inputTensors[idx].refreshShape(); 282 } 283 } 284 } 285 286 /** Triggers explicit allocation of tensors. */ allocateTensors()287 void allocateTensors() { 288 allocateTensorsIfNeeded(); 289 } 290 291 /** 292 * Allocates tensor memory space in the given subgraph and returns true when allocation happens 293 */ allocateTensorsIfNeeded()294 private boolean allocateTensorsIfNeeded() { 295 if (isMemoryAllocated) { 296 return false; 297 } 298 299 isMemoryAllocated = true; 300 allocateTensors(interpreterHandle, errorHandle); 301 for (TensorImpl outputTensor : outputTensors) { 302 if (outputTensor != null) { 303 outputTensor.refreshShape(); 304 } 305 } 306 return true; 307 } 308 309 /** Gets index of an input given its name. */ getInputIndex(String name)310 int getInputIndex(String name) { 311 if (inputsIndexes == null) { 312 String[] names = getInputNames(interpreterHandle); 313 inputsIndexes = new HashMap<>(); 314 if (names != null) { 315 for (int i = 0; i < names.length; ++i) { 316 inputsIndexes.put(names[i], i); 317 } 318 } 319 } 320 if (inputsIndexes.containsKey(name)) { 321 return inputsIndexes.get(name); 322 } else { 323 throw new IllegalArgumentException( 324 String.format( 325 "Input error: '%s' is not a valid name for any input. Names of inputs and their " 326 + "indexes are %s", 327 name, inputsIndexes)); 328 } 329 } 330 331 /** Gets index of an output given its name. */ getOutputIndex(String name)332 int getOutputIndex(String name) { 333 if (outputsIndexes == null) { 334 String[] names = getOutputNames(interpreterHandle); 335 outputsIndexes = new HashMap<>(); 336 if (names != null) { 337 for (int i = 0; i < names.length; ++i) { 338 outputsIndexes.put(names[i], i); 339 } 340 } 341 } 342 if (outputsIndexes.containsKey(name)) { 343 return outputsIndexes.get(name); 344 } else { 345 throw new IllegalArgumentException( 346 String.format( 347 "Input error: '%s' is not a valid name for any output. Names of outputs and their " 348 + "indexes are %s", 349 name, outputsIndexes)); 350 } 351 } 352 353 /** 354 * Gets the last inference duration in nanoseconds. It returns null if there is no previous 355 * inference run or the last inference run failed. 356 */ getLastNativeInferenceDurationNanoseconds()357 Long getLastNativeInferenceDurationNanoseconds() { 358 return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds; 359 } 360 361 /** Gets the number of input tensors. */ getInputTensorCount()362 int getInputTensorCount() { 363 return inputTensors.length; 364 } 365 366 /** 367 * Gets the input {@link TensorImpl} for the provided input index. 368 * 369 * @throws IllegalArgumentException if the input index is invalid. 370 */ getInputTensor(int index)371 TensorImpl getInputTensor(int index) { 372 if (index < 0 || index >= inputTensors.length) { 373 throw new IllegalArgumentException("Invalid input Tensor index: " + index); 374 } 375 TensorImpl inputTensor = inputTensors[index]; 376 if (inputTensor == null) { 377 inputTensor = 378 inputTensors[index] = 379 TensorImpl.fromIndex( 380 interpreterHandle, getInputTensorIndex(interpreterHandle, index)); 381 } 382 return inputTensor; 383 } 384 385 /** 386 * Gets the input {@link TensorImpl} given the tensor name and method in the signature. 387 * 388 * @throws IllegalArgumentException if the input name is invalid. 389 */ getInputTensor(String inputName, String signatureKey)390 TensorImpl getInputTensor(String inputName, String signatureKey) { 391 if (inputName == null) { 392 throw new IllegalArgumentException("Invalid input tensor name provided (null)"); 393 } 394 NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey); 395 int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex(); 396 if (subgraphIndex > 0) { 397 return signatureRunnerWrapper.getInputTensor(inputName); 398 } 399 400 int inputIndex = signatureRunnerWrapper.getInputIndex(inputName); 401 return getInputTensor(inputIndex); 402 } 403 404 /** Gets the keys of SignatureDefs available in the model, if any. */ getSignatureKeys()405 public String[] getSignatureKeys() { 406 return getSignatureKeys(interpreterHandle); 407 } 408 409 /** Gets the list of SignatureDefs inputs for method {@code signatureKey} */ getSignatureInputs(String signatureKey)410 String[] getSignatureInputs(String signatureKey) { 411 return getSignatureRunnerWrapper(signatureKey).inputNames(); 412 } 413 414 /** Gets the list of SignatureDefs outputs for method {@code signatureKey} */ getSignatureOutputs(String signatureKey)415 String[] getSignatureOutputs(String signatureKey) { 416 return getSignatureRunnerWrapper(signatureKey).outputNames(); 417 } 418 419 /** Gets the number of output tensors. */ getOutputTensorCount()420 int getOutputTensorCount() { 421 return outputTensors.length; 422 } 423 424 /** 425 * Gets the output {@link TensorImpl} for the provided output index. 426 * 427 * @throws IllegalArgumentException if the output index is invalid. 428 */ getOutputTensor(int index)429 TensorImpl getOutputTensor(int index) { 430 if (index < 0 || index >= outputTensors.length) { 431 throw new IllegalArgumentException("Invalid output Tensor index: " + index); 432 } 433 TensorImpl outputTensor = outputTensors[index]; 434 if (outputTensor == null) { 435 outputTensor = 436 outputTensors[index] = 437 TensorImpl.fromIndex( 438 interpreterHandle, getOutputTensorIndex(interpreterHandle, index)); 439 } 440 return outputTensor; 441 } 442 443 /** 444 * Gets the output {@link TensorImpl} given the tensor name and method in the signature. 445 * 446 * @throws IllegalArgumentException if the output name is invalid. 447 */ getOutputTensor(String outputName, String signatureKey)448 TensorImpl getOutputTensor(String outputName, String signatureKey) { 449 if (outputName == null) { 450 throw new IllegalArgumentException("Invalid output tensor name provided (null)"); 451 } 452 NativeSignatureRunnerWrapper signatureRunnerWrapper = getSignatureRunnerWrapper(signatureKey); 453 int subgraphIndex = signatureRunnerWrapper.getSubgraphIndex(); 454 if (subgraphIndex > 0) { 455 return signatureRunnerWrapper.getOutputTensor(outputName); 456 } 457 458 int outputIndex = signatureRunnerWrapper.getOutputIndex(outputName); 459 return getOutputTensor(outputIndex); 460 } 461 462 /** Gets the number of ops in the execution plan. */ getExecutionPlanLength()463 int getExecutionPlanLength() { 464 return getExecutionPlanLength(interpreterHandle); 465 } 466 467 /** 468 * Sets internal cancellation flag. If it's true, the interpreter will try to interrupt any 469 * invocation between ops. 470 */ setCancelled(boolean value)471 void setCancelled(boolean value) { 472 if (cancellationFlagHandle == 0) { 473 throw new IllegalStateException( 474 "Cannot cancel the inference. Have you called InterpreterApi.Options.setCancellable?"); 475 } 476 setCancelled(interpreterHandle, cancellationFlagHandle, value); 477 } 478 479 // Add all the delegates specified in the options (other than XNNPACK) to this.delegates. addDelegates(InterpreterImpl.Options options)480 private void addDelegates(InterpreterImpl.Options options) { 481 // First add the flex delegate if necessary. This ensures the graph is fully resolved before 482 // applying other delegates. 483 if (originalGraphHasUnresolvedFlexOp) { 484 Delegate optionalFlexDelegate = maybeCreateFlexDelegate(options.getDelegates()); 485 if (optionalFlexDelegate != null) { 486 ownedDelegates.add(optionalFlexDelegate); 487 delegates.add(optionalFlexDelegate); 488 } 489 } 490 // Now add the user-supplied delegates. 491 addUserProvidedDelegates(options); 492 for (DelegateFactory delegateFactory : options.getDelegateFactories()) { 493 Delegate delegate = delegateFactory.create(RUNTIME_FLAVOR); 494 ownedDelegates.add(delegate); 495 delegates.add(delegate); 496 } 497 if (options.getUseNNAPI()) { 498 NnApiDelegate optionalNnApiDelegate = new NnApiDelegate(); 499 ownedDelegates.add(optionalNnApiDelegate); 500 delegates.add(optionalNnApiDelegate); 501 } 502 } 503 addUserProvidedDelegates(Options options)504 private void addUserProvidedDelegates(Options options) { 505 for (Delegate delegate : options.getDelegates()) { 506 // NnApiDelegate is compatible with both the system and built-in runtimes and therefore can be 507 // added directly even when using TF Lite from the system. 508 if (options.getRuntime() != TfLiteRuntime.FROM_APPLICATION_ONLY 509 && !(delegate instanceof NnApiDelegate)) { 510 throw new IllegalArgumentException( 511 "Instantiated delegates (other than NnApiDelegate) are not allowed when using TF Lite" 512 + " from Google Play Services. Please use" 513 + " InterpreterApi.Options.addDelegateFactory() with an appropriate DelegateFactory" 514 + " instead."); 515 } 516 delegates.add(delegate); 517 } 518 } 519 520 // Complete the initialization of any delegates that require an InterpreterFactoryApi instance. initDelegatesWithInterpreterFactory()521 private void initDelegatesWithInterpreterFactory() { 522 InterpreterFactoryApi interpreterFactoryApi = new InterpreterFactoryImpl(); 523 for (Delegate delegate : delegates) { 524 if (delegate instanceof NnApiDelegate) { 525 ((NnApiDelegate) delegate).initWithInterpreterFactoryApi(interpreterFactoryApi); 526 } 527 } 528 } 529 getSignatureRunnerWrapper(String signatureKey)530 private NativeSignatureRunnerWrapper getSignatureRunnerWrapper(String signatureKey) { 531 if (signatureRunnerMap == null) { 532 signatureRunnerMap = new HashMap<>(); 533 } 534 if (!signatureRunnerMap.containsKey(signatureKey)) { 535 signatureRunnerMap.put( 536 signatureKey, 537 new NativeSignatureRunnerWrapper(interpreterHandle, errorHandle, signatureKey)); 538 } 539 return signatureRunnerMap.get(signatureKey); 540 } 541 maybeCreateFlexDelegate(List<Delegate> delegates)542 private static Delegate maybeCreateFlexDelegate(List<Delegate> delegates) { 543 try { 544 Class<?> clazz = Class.forName("org.tensorflow.lite.flex.FlexDelegate"); 545 // No need to create the Flex delegate if one has already been provided. 546 for (Delegate delegate : delegates) { 547 if (clazz.isInstance(delegate)) { 548 return null; 549 } 550 } 551 return (Delegate) clazz.getConstructor().newInstance(); 552 } catch (ClassNotFoundException 553 | IllegalAccessException 554 | IllegalArgumentException 555 | InstantiationException 556 | InvocationTargetException 557 | NoSuchMethodException 558 | SecurityException e) { 559 // The error will propagate when tensors are allocated. 560 return null; 561 } 562 } 563 564 private static final int ERROR_BUFFER_SIZE = 512; 565 566 long errorHandle; 567 568 long interpreterHandle; 569 570 private long modelHandle; 571 572 private long cancellationFlagHandle = 0; 573 574 @UsedByReflection("nativeinterpreterwrapper_jni.cc") 575 private long inferenceDurationNanoseconds = -1; 576 577 private ByteBuffer modelByteBuffer; 578 579 // Lazily constructed maps of input and output names to input and output Tensor indexes. 580 private Map<String, Integer> inputsIndexes; 581 private Map<String, Integer> outputsIndexes; 582 583 // A map from signature key to its native wrapper object. 584 private Map<String, NativeSignatureRunnerWrapper> signatureRunnerMap; 585 586 // Lazily constructed and populated arrays of input and output Tensor wrappers. 587 private TensorImpl[] inputTensors; 588 private TensorImpl[] outputTensors; 589 590 // Whether subgraph's tensor memory space is allocated. 591 private boolean isMemoryAllocated = false; 592 593 // Whether the model has any Flex custom ops that can't be resolved by the OpResolver. 594 private boolean originalGraphHasUnresolvedFlexOp = false; 595 596 // As the Java Delegate owns the native delegate instance, we keep a strong ref to any injected 597 // delegates for safety. 598 private final List<Delegate> delegates = new ArrayList<>(); 599 600 // List of owned delegates that must be closed when the interpreter is closed. 601 private final List<Delegate> ownedDelegates = new ArrayList<>(); 602 run(long interpreterHandle, long errorHandle)603 private static native void run(long interpreterHandle, long errorHandle); 604 resizeInput( long interpreterHandle, long errorHandle, int inputIdx, int[] dims, boolean strict)605 private static native boolean resizeInput( 606 long interpreterHandle, long errorHandle, int inputIdx, int[] dims, boolean strict); 607 allocateTensors(long interpreterHandle, long errorHandle)608 private static native long allocateTensors(long interpreterHandle, long errorHandle); 609 getSignatureKeys(long interpreterHandle)610 private static native String[] getSignatureKeys(long interpreterHandle); 611 setCancelled( long interpreterHandle, long cancellationFlagHandle, boolean value)612 private static native void setCancelled( 613 long interpreterHandle, long cancellationFlagHandle, boolean value); 614 hasUnresolvedFlexOp(long interpreterHandle)615 private static native boolean hasUnresolvedFlexOp(long interpreterHandle); 616 getInputTensorIndex(long interpreterHandle, int inputIdx)617 private static native int getInputTensorIndex(long interpreterHandle, int inputIdx); 618 getOutputTensorIndex(long interpreterHandle, int outputIdx)619 private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx); 620 getInputCount(long interpreterHandle)621 private static native int getInputCount(long interpreterHandle); 622 getOutputCount(long interpreterHandle)623 private static native int getOutputCount(long interpreterHandle); 624 getExecutionPlanLength(long interpreterHandle)625 private static native int getExecutionPlanLength(long interpreterHandle); 626 getInputNames(long interpreterHandle)627 private static native String[] getInputNames(long interpreterHandle); 628 getOutputNames(long interpreterHandle)629 private static native String[] getOutputNames(long interpreterHandle); 630 allowFp16PrecisionForFp32(long interpreterHandle, boolean allow)631 private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow); 632 allowBufferHandleOutput(long interpreterHandle, boolean allow)633 private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow); 634 createErrorReporter(int size)635 private static native long createErrorReporter(int size); 636 createModel(String modelPathOrBuffer, long errorHandle)637 private static native long createModel(String modelPathOrBuffer, long errorHandle); 638 createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle)639 private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle); 640 createInterpreter( long modelHandle, long errorHandle, int numThreads, boolean useXnnpack, List<Long> delegateHandles)641 private static native long createInterpreter( 642 long modelHandle, 643 long errorHandle, 644 int numThreads, 645 boolean useXnnpack, 646 List<Long> delegateHandles); 647 createCancellationFlag(long interpreterHandle)648 private static native long createCancellationFlag(long interpreterHandle); 649 deleteCancellationFlag(long cancellationFlagHandle)650 private static native long deleteCancellationFlag(long cancellationFlagHandle); 651 delete(long errorHandle, long modelHandle, long interpreterHandle)652 private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); 653 } 654