/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
// Use the Android annotation instead.
import android.support.annotation.NonNull;
// import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Driver class to drive model inference with TensorFlow Lite.
 *
 * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
 * are executed for model inference.
 *
 * <p>For example, if a model takes only one input and returns only one output:
 *
 * <pre>{@code
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>If a model takes multiple inputs or outputs:
 *
 * <pre>{@code
 * Object[] inputs = {input0, input1, ...};
 * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 * FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4);  // Float tensor, shape 3x2x4.
 * ith_output.order(ByteOrder.nativeOrder());
 * map_of_indices_to_outputs.put(i, ith_output);
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 * }
 * }</pre>
 *
 * <p>If a model takes or produces string tensors:
 *
 * <pre>{@code
 * String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 * String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
 *   interpreter.run(input, output);
 * }
 * }</pre>
 *
 * <p>The order of inputs and outputs is determined when converting the TensorFlow model to the
 * TensorFlow Lite model with Toco, as are the default shapes of the inputs.
 *
 * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
 * be implicitly resized according to that array's shape. When inputs are provided as {@link
 * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
 * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
 * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
 * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
 * #getOutputTensor(int)}.
 *
 * <p><b>WARNING:</b> Instances of an {@code Interpreter} are <b>not</b> thread-safe. An {@code
 * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
 *
 * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
 * but this is not guaranteed.
 */
public final class Interpreter implements AutoCloseable {

  /**
   * An options class for controlling runtime interpreter behavior.
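   *
   * <p>For example, a minimal configuration sketch (the thread count and model file name here are
   * illustrative, not recommendations):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options()
   *     .setNumThreads(2)     // Use two threads for ops that support multi-threading.
   *     .setUseNNAPI(true);   // Opt in to NNAPI acceleration where available.
   * try (Interpreter interpreter = new Interpreter(new File("model.tflite"), options)) {
   *   // Run inference...
   * }
   * }</pre>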
   */
  public static class Options {
    public Options() {}

    /**
     * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
     * platform-dependent value.
     */
    public Options setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /**
     * Sets whether to use the NN API (if available) for op execution. Defaults to false
     * (disabled).
     */
    public Options setUseNNAPI(boolean useNNAPI) {
      this.useNNAPI = useNNAPI;
      return this;
    }

    /**
     * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to
     * false (disallowed).
     *
     * @deprecated Prefer using {@link
     *     org.tensorflow.lite.nnapi.NnApiDelegate.Options#setAllowFp16(boolean enable)}.
     */
    @Deprecated
    public Options setAllowFp16PrecisionForFp32(boolean allow) {
      this.allowFp16PrecisionForFp32 = allow;
      return this;
    }

    /**
     * Adds a {@link Delegate} to be applied during interpreter creation.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
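     *
     * <p>For illustration, a minimal sketch of delegating execution to NNAPI via {@link
     * org.tensorflow.lite.nnapi.NnApiDelegate}, assuming that delegate is available in your
     * build:
     *
     * <pre>{@code
     * NnApiDelegate nnApiDelegate = new NnApiDelegate();
     * Interpreter.Options options = new Interpreter.Options().addDelegate(nnApiDelegate);
     * }</pre>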
     */
    public Options addDelegate(Delegate delegate) {
      delegates.add(delegate);
      return this;
    }

    /**
     * Advanced: Set if buffer handle output is allowed.
     *
     * <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the
     * data of output tensors available in the CPU-allocated tensor buffers by default. If the
     * client can consume the buffer handle directly (e.g., reading output from an OpenGL
     * texture), it can set this flag to false, avoiding the copy of data to the CPU buffer. The
     * delegate documentation should indicate whether this is supported and how it can be used.
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setAllowBufferHandleOutput(boolean allow) {
      this.allowBufferHandleOutput = allow;
      return this;
    }

    /**
     * Advanced: Set if the interpreter is able to be cancelled.
     *
     * @see Interpreter#setCancelled(boolean)
     */
    public Options setCancellable(boolean allow) {
      this.allowCancellation = allow;
      return this;
    }

    /**
     * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
     *
     * <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
     * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
     * operations. Eventually, we plan to enable this by default, as it can provide significant
     * performance benefits for many classes of floating point models. See
     * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
     * for more details.
     *
     * <p>Things to keep in mind when enabling this flag:
     *
     * <ul>
     *   <li>Startup time and resize time may increase.
     *   <li>Baseline memory consumption may increase.
     *   <li>May be ignored if another delegate (e.g., NNAPI) has been applied.
     *   <li>Quantized models will not see any benefit.
     * </ul>
     *
     * <p>WARNING: This is an experimental interface that is subject to change.
     */
    public Options setUseXNNPACK(boolean useXNNPACK) {
      this.useXNNPACK = useXNNPACK;
      return this;
    }

    int numThreads = -1;
    Boolean useNNAPI;
    Boolean allowFp16PrecisionForFp32;
    Boolean allowBufferHandleOutput;
    Boolean allowCancellation;

    // TODO(b/171856982): Update this comment when applying the XNNPACK delegate by default is
    // enabled for the C++ TfLite library on the Android platform.
    // Note: the initial "null" value indicates default behavior, which may mean the XNNPACK
    // delegate will be applied by default.
    Boolean useXNNPACK;
    final List<Delegate> delegates = new ArrayList<>();
  }

  /**
   * Initializes an {@code Interpreter}.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile) {
    this(modelFile, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} and specifies the number of threads used for inference.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param numThreads number of threads to use for inference
   * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
   *     be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull File modelFile, int numThreads) {
    this(modelFile, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a set of custom {@link Interpreter.Options}.
   *
   * @param modelFile a file of a pre-trained TF Lite model
   * @param options a set of options for customizing interpreter behavior
   * @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
   *     model.
   */
  public Interpreter(@NonNull File modelFile, Options options) {
    wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
    signatureNameList = getSignatureDefNames();
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or
   * a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
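   *
   * <p>For illustration, one common way to obtain such a buffer is to memory-map the model file (a
   * minimal sketch; the file name is hypothetical):
   *
   * <pre>{@code
   * try (FileInputStream is = new FileInputStream(new File("model.tflite"));
   *     FileChannel channel = is.getChannel()) {
   *   MappedByteBuffer modelBuffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
   *   Interpreter interpreter = new Interpreter(modelBuffer);
   * }
   * }</pre>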
   *
   * @throws IllegalArgumentException if {@code byteBuffer} is not a {@link MappedByteBuffer} nor a
   *     direct {@link ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the
   * number of threads used for inference.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or
   * a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
    this(byteBuffer, new Options().setNumThreads(numThreads));
  }

  /**
   * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
   *
   * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an {@code
   * Interpreter}.
   *
   * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
   *     will be removed in a future release.
   */
  @Deprecated
  public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
    this(mappedByteBuffer, /* options= */ null);
  }

  /**
   * Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
   * custom {@link Interpreter.Options}.
   *
   * <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@link MappedByteBuffer} that memory-maps a model file, or
   * a direct {@link ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   *
   * @throws IllegalArgumentException if {@code byteBuffer} is not a {@link MappedByteBuffer} nor a
   *     direct {@link ByteBuffer} of nativeOrder.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    wrapper = new NativeInterpreterWrapper(byteBuffer, options);
    signatureNameList = getSignatureDefNames();
  }

  /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API is more efficient if a {@link java.nio.Buffer} (preferably direct, but not
   * required) is used as the input/output data type. Please consider using {@link java.nio.Buffer}
   * to feed and fetch primitive data for better performance. The following concrete {@link
   * java.nio.Buffer} types are supported:
   *
   * <ul>
   *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
   *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
   *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * <p>Note that boolean types are supported only as arrays, not as {@link java.nio.Buffer}s or
   * scalar inputs.
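   *
   * <p>For illustration, a minimal sketch of sizing a direct input buffer from the input tensor's
   * byte count (the input index and populated contents are hypothetical):
   *
   * <pre>{@code
   * // Allocate a direct, native-ordered buffer matching the input tensor's size in bytes.
   * ByteBuffer input = ByteBuffer.allocateDirect(interpreter.getInputTensor(0).numBytes());
   * input.order(ByteOrder.nativeOrder());
   * // Populate the buffer, then rewind so the read position is at the start.
   * input.rewind();
   * }</pre>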
   *
   * @param input an array or multidimensional array, or a {@link java.nio.Buffer} of primitive
   *     types including int, float, long, and byte. {@link java.nio.Buffer} is the preferred way
   *     to pass large input data for primitive types, whereas string types require using the
   *     (multi-dimensional) array input path. When a {@link java.nio.Buffer} is used, its content
   *     should remain unchanged until model inference is done, and the caller must ensure that the
   *     {@link java.nio.Buffer} is at the appropriate read position. A {@code null} value is
   *     allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
   *     and such a buffer has been bound to the input {@link Tensor}.
   * @param output a multidimensional array of output data, or a {@link java.nio.Buffer} of
   *     primitive types including int, float, long, and byte. When a {@link java.nio.Buffer} is
   *     used, the caller must ensure that it is set to the appropriate write position. A {@code
   *     null} value is allowed only if the caller is using a {@link Delegate} that allows buffer
   *     handle interop, and such a buffer has been bound to the output {@link Tensor}. See {@link
   *     Interpreter.Options#setAllowBufferHandleOutput(boolean)}.
   * @throws IllegalArgumentException if {@code input} or {@code output} is null or empty, or if an
   *     error occurs when running inference.
   * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
   *     interrupted by {@code setCancelled(true)}.
   */
  public void run(Object input, Object output) {
    Object[] inputs = {input};
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, output);
    runForMultipleInputsOutputs(inputs, outputs);
  }

  /**
   * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
   *
   * <p>Warning: The API is more efficient if {@link java.nio.Buffer}s (preferably direct, but not
   * required) are used as the input/output data types. Please consider using {@link
   * java.nio.Buffer} to feed and fetch primitive data for better performance. The following
   * concrete {@link java.nio.Buffer} types are supported:
   *
   * <ul>
   *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
   *   <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
   *   <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
   * </ul>
   *
   * <p>Note that boolean types are supported only as arrays, not as {@link java.nio.Buffer}s or
   * scalar inputs.
   *
   * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are
   * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
   * such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
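   *
   * <p>For illustration, a minimal sketch with two outputs, one fetched into an array and one into
   * a buffer (shapes and indices are hypothetical):
   *
   * <pre>{@code
   * Object[] inputs = {input0, input1};
   * float[][] output0 = new float[1][10];
   * FloatBuffer output1 = FloatBuffer.allocate(interpreter.getOutputTensor(1).numElements());
   * Map<Integer, Object> outputs = new HashMap<>();
   * outputs.put(0, output0);
   * outputs.put(1, output1);
   * interpreter.runForMultipleInputsOutputs(inputs, outputs);
   * }</pre>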
   *
   * @param inputs an array of input data. The inputs should be in the same order as the inputs of
   *     the model. Each input can be an array or multidimensional array, or a {@link
   *     java.nio.Buffer} of primitive types including int, float, long, and byte. {@link
   *     java.nio.Buffer} is the preferred way to pass large input data, whereas string types
   *     require using the (multi-dimensional) array input path. When a {@link java.nio.Buffer} is
   *     used, its content should remain unchanged until model inference is done, and the caller
   *     must ensure that the {@link java.nio.Buffer} is at the appropriate read position.
   * @param outputs a map mapping output indices to multidimensional arrays of output data or
   *     {@link java.nio.Buffer}s of primitive types including int, float, long, and byte. It only
   *     needs to keep entries for the outputs to be used. When a {@link java.nio.Buffer} is used,
   *     the caller must ensure that it is set to the appropriate write position.
   * @throws IllegalArgumentException if {@code inputs} or {@code outputs} is null or empty, or if
   *     an error occurs when running inference.
   */
  public void runForMultipleInputsOutputs(
      @NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
    checkNotClosed();
    wrapper.run(inputs, outputs);
  }

  /**
   * Runs model inference based on the SignatureDef provided through {@code methodName}.
   *
   * <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and
   * output data types.
   *
   * <p>WARNING: This is an experimental API and subject to change.
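   *
   * <p>For illustration, a minimal sketch (the signature name "serving_default" and the tensor
   * names "x" and "y" are hypothetical; query the actual names via {@link #getSignatureDefNames()},
   * {@link #getSignatureInputs(String)}, and {@link #getSignatureOutputs(String)}):
   *
   * <pre>{@code
   * Map<String, Object> inputs = new HashMap<>();
   * inputs.put("x", inputArray);
   * Map<String, Object> outputs = new HashMap<>();
   * outputs.put("y", outputArray);
   * interpreter.runSignature(inputs, outputs, "serving_default");
   * }</pre>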
   *
   * @param inputs a map from the input name in the SignatureDef to an input object.
   * @param outputs a map from the output name in the SignatureDef to output data.
   * @param methodName the exported method name identifying the SignatureDef.
   * @throws IllegalArgumentException if {@code inputs} or {@code outputs} or {@code methodName} is
   *     null or empty, or if an error occurs when running inference.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs,
      @NonNull Map<String, Object> outputs,
      String methodName) {
    checkNotClosed();
    if (methodName == null && signatureNameList.length == 1) {
      methodName = signatureNameList[0];
    }
    if (methodName == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef methodName should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureNameList));
    }
    wrapper.runSignature(inputs, outputs, methodName);
  }

  /**
   * Same as {@link Interpreter#runSignature(Map, Map, String)} but doesn't require passing a
   * methodName, assuming the model has one SignatureDef. If the model has more than one
   * SignatureDef it will throw an exception.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void runSignature(
      @NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
    checkNotClosed();
    runSignature(inputs, outputs, null);
  }

  /**
   * Explicitly updates allocations for all tensors, if necessary.
   *
   * <p>This will propagate shapes and memory allocations for all dependent tensors using the input
   * tensor shape(s) as given.
   *
   * <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
   * execution if any input tensors have been resized. This call is most useful in determining the
   * shapes for any output tensors before executing the graph, e.g.,
   *
   * <pre>{@code
   * interpreter.resizeInput(0, new int[]{1, 4, 4, 3});
   * interpreter.allocateTensors();
   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
   * // Populate inputs...
   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
   * interpreter.run(input, output);
   * // Process outputs...
   * }</pre>
   *
   * @throws IllegalStateException if the graph's tensors could not be successfully allocated.
   */
  public void allocateTensors() {
    checkNotClosed();
    wrapper.allocateTensors();
  }

  /**
   * Resizes idx-th input of the native model to the given dims.
   *
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input.
   */
  public void resizeInput(int idx, @NonNull int[] dims) {
    checkNotClosed();
    wrapper.resizeInput(idx, dims, false);
  }

  /**
   * Resizes idx-th input of the native model to the given dims.
   *
   * <p>When {@code strict} is true, only unknown dimensions can be resized. Unknown dimensions are
   * indicated as {@code -1} in the array returned by {@code Tensor.shapeSignature()}.
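   *
   * <p>For illustration, a minimal sketch of resizing only an unknown (batch) dimension, assuming
   * a hypothetical shape signature of [-1, 224, 224, 3]:
   *
   * <pre>{@code
   * // Only the -1 (batch) dimension may change when strict is true.
   * interpreter.resizeInput(0, new int[]{4, 224, 224, 3}, true);
   * }</pre>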
   *
   * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number
   *     of model inputs; or if an error occurs when resizing the idx-th input. Additionally, the
   *     error occurs when attempting to resize a tensor with fixed dimensions when {@code strict}
   *     is true.
   */
  public void resizeInput(int idx, @NonNull int[] dims, boolean strict) {
    checkNotClosed();
    wrapper.resizeInput(idx, dims, strict);
  }

  /** Gets the number of input tensors. */
  public int getInputTensorCount() {
    checkNotClosed();
    return wrapper.getInputTensorCount();
  }

  /**
   * Gets index of an input given the op name of the input.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any input in the model used
   *     to initialize the {@link Interpreter}.
   */
  public int getInputIndex(String opName) {
    checkNotClosed();
    return wrapper.getInputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided input index.
   *
   * @throws IllegalArgumentException if {@code inputIndex} is negative or is not smaller than the
   *     number of model inputs.
   */
  public Tensor getInputTensor(int inputIndex) {
    checkNotClosed();
    return wrapper.getInputTensor(inputIndex);
  }

  /**
   * Gets the Tensor associated with the provided input name and signature method name.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param inputName the input name in the signature.
   * @param methodName the exported method name identifying the SignatureDef; can be null if the
   *     model has one signature.
   * @throws IllegalArgumentException if {@code inputName} or {@code methodName} is null or empty,
   *     or an invalid name is provided.
   */
  public Tensor getInputTensorFromSignature(String inputName, String methodName) {
    checkNotClosed();
    if (methodName == null && signatureNameList.length == 1) {
      methodName = signatureNameList[0];
    }
    if (methodName == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef methodName should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureNameList));
    }
    return wrapper.getInputTensor(inputName, methodName);
  }

  /**
   * Gets the list of SignatureDef exported method names available in the model.
   *
   * <p>WARNING: This is an experimental API and subject to change.
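   *
   * <p>For illustration, a minimal sketch of enumerating the signatures and their tensor names:
   *
   * <pre>{@code
   * for (String signature : interpreter.getSignatureDefNames()) {
   *   String[] inputNames = interpreter.getSignatureInputs(signature);
   *   String[] outputNames = interpreter.getSignatureOutputs(signature);
   *   // Inspect names...
   * }
   * }</pre>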
   */
  public String[] getSignatureDefNames() {
    checkNotClosed();
    return wrapper.getSignatureDefNames();
  }

  /**
   * Gets the list of SignatureDef inputs for method {@code methodName}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureInputs(String methodName) {
    checkNotClosed();
    return wrapper.getSignatureInputs(methodName);
  }

  /**
   * Gets the list of SignatureDef outputs for method {@code methodName}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public String[] getSignatureOutputs(String methodName) {
    checkNotClosed();
    return wrapper.getSignatureOutputs(methodName);
  }

  /** Gets the number of output Tensors. */
  public int getOutputTensorCount() {
    checkNotClosed();
    return wrapper.getOutputTensorCount();
  }

  /**
   * Gets index of an output given the op name of the output.
   *
   * @throws IllegalArgumentException if {@code opName} does not match any output in the model used
   *     to initialize the {@link Interpreter}.
   */
  public int getOutputIndex(String opName) {
    checkNotClosed();
    return wrapper.getOutputIndex(opName);
  }

  /**
   * Gets the Tensor associated with the provided output index.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
   * @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the
   *     number of model outputs.
   */
  public Tensor getOutputTensor(int outputIndex) {
    checkNotClosed();
    return wrapper.getOutputTensor(outputIndex);
  }

  /**
   * Gets the Tensor associated with the provided output name in a specific signature method.
   *
   * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
   * is executed. If you need updated details *before* running inference (e.g., after resizing an
   * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
   * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
   * that are dependent on input *values*, the output shape may not be fully determined until
   * running inference.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   *
   * @param outputName the output name in the signature.
   * @param methodName the exported method name identifying the SignatureDef; can be null if the
   *     model has one signature.
   * @throws IllegalArgumentException if {@code outputName} or {@code methodName} is null or empty,
   *     or an invalid name is provided.
   */
  public Tensor getOutputTensorFromSignature(String outputName, String methodName) {
    checkNotClosed();
    if (methodName == null && signatureNameList.length == 1) {
      methodName = signatureNameList[0];
    }
    if (methodName == null) {
      throw new IllegalArgumentException(
          "Input error: SignatureDef methodName should not be null. null is only allowed if the"
              + " model has a single Signature. Available Signatures: "
              + Arrays.toString(signatureNameList));
    }
    return wrapper.getOutputTensor(outputName, methodName);
  }

  /**
   * Returns native inference timing.
   *
   * @throws IllegalArgumentException if the model is not initialized by the {@link Interpreter}.
   */
  public Long getLastNativeInferenceDurationNanoseconds() {
    checkNotClosed();
    return wrapper.getLastNativeInferenceDurationNanoseconds();
  }

  /**
   * Sets the number of threads to be used for ops that support multi-threading.
   *
   * @deprecated Prefer using {@link Interpreter.Options#setNumThreads(int)} directly for
   *     controlling multi-threading. This method will be removed in a future release.
   */
  @Deprecated
  public void setNumThreads(int numThreads) {
    checkNotClosed();
    wrapper.setNumThreads(numThreads);
  }

  /**
   * Advanced: Modifies the graph with the provided {@link Delegate}.
   *
   * @throws IllegalArgumentException if an error occurs when modifying the graph with {@code
   *     delegate}.
   * @deprecated Prefer using {@link Interpreter.Options#addDelegate} to provide delegates at
   *     creation time. This method will be removed in a future release.
   */
  @Deprecated
  public void modifyGraphWithDelegate(Delegate delegate) {
    checkNotClosed();
    wrapper.modifyGraphWithDelegate(delegate);
  }

  /**
   * Advanced: Resets all variable tensors to the default value.
   *
   * <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
   *
   * <p>WARNING: This is an experimental API and subject to change.
   */
  public void resetVariableTensors() {
    checkNotClosed();
    wrapper.resetVariableTensors();
  }

  /**
   * Advanced: Interrupts inference in the middle of a call to {@link Interpreter#run}.
   *
   * <p>A cancellation flag will be set to true when this function gets called. The interpreter
   * will check the flag between Op invocations, and if it's {@code true}, the interpreter will
   * stop execution. The interpreter will remain in a cancelled state until explicitly
   * "uncancelled" by {@code setCancelled(false)}.
   *
   * <p>WARNING: This is an experimental API and subject to change.
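   *
   * <p>For illustration, a minimal sketch of cancelling a long-running inference from another
   * thread (the threading setup here is hypothetical):
   *
   * <pre>{@code
   * Interpreter.Options options = new Interpreter.Options().setCancellable(true);
   * Interpreter interpreter = new Interpreter(modelFile, options);
   * // On a worker thread:
   * //   interpreter.run(input, output);
   * // On another thread, to request a best-effort stop:
   * interpreter.setCancelled(true);
   * }</pre>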
   *
   * @param cancelled {@code true} to cancel inference in a best-effort way; {@code false} to
   *     resume.
   * @throws IllegalStateException if the interpreter is not initialized with the cancellable
   *     option, which is off by default.
   * @see Interpreter.Options#setCancellable(boolean)
   */
  public void setCancelled(boolean cancelled) {
    wrapper.setCancelled(cancelled);
  }

  int getExecutionPlanLength() {
    checkNotClosed();
    return wrapper.getExecutionPlanLength();
  }

  /** Releases resources associated with the {@code Interpreter}. */
  @Override
  public void close() {
    if (wrapper != null) {
      wrapper.close();
      wrapper = null;
    }
  }

  // For Object.finalize, see https://bugs.openjdk.java.net/browse/JDK-8165641
  @SuppressWarnings("deprecation")
  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  private void checkNotClosed() {
    if (wrapper == null) {
      throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
    }
  }

  NativeInterpreterWrapper wrapper;
  String[] signatureNameList;
}