1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 /** 19 * An internal wrapper that wraps native SignatureRunner. 20 * 21 * <p>Note: This class is not thread safe. 22 */ 23 final class NativeSignatureRunnerWrapper { NativeSignatureRunnerWrapper(long interpreterHandle, long errorHandle, String signatureKey)24 NativeSignatureRunnerWrapper(long interpreterHandle, long errorHandle, String signatureKey) { 25 this.errorHandle = errorHandle; 26 signatureRunnerHandle = nativeGetSignatureRunner(interpreterHandle, signatureKey); 27 if (signatureRunnerHandle == -1) { 28 throw new IllegalArgumentException("Input error: Signature " + signatureKey + " not found."); 29 } 30 } 31 32 /** Gets the subgraph index associated with this Signature. */ getSubgraphIndex()33 public int getSubgraphIndex() { 34 return nativeGetSubgraphIndex(signatureRunnerHandle); 35 } 36 37 /** Gets the inputs of this Signature. */ inputNames()38 public String[] inputNames() { 39 return nativeInputNames(signatureRunnerHandle); 40 } 41 42 /** Gets the outputs of this Signature. */ outputNames()43 public String[] outputNames() { 44 return nativeOutputNames(signatureRunnerHandle); 45 } 46 47 /** Gets the input tensor specified by {@code inputName}. */ getInputTensor(String inputName)48 public TensorImpl getInputTensor(String inputName) { 49 return TensorImpl.fromSignatureInput(signatureRunnerHandle, inputName); 50 } 51 52 /** Gets the output tensor specified by {@code outputName}. */ getOutputTensor(String outputName)53 public TensorImpl getOutputTensor(String outputName) { 54 return TensorImpl.fromSignatureOutput(signatureRunnerHandle, outputName); 55 } 56 57 /** Gets the index of the input specified by {@code inputName}. */ getInputIndex(String inputName)58 public int getInputIndex(String inputName) { 59 int inputIndex = nativeGetInputIndex(signatureRunnerHandle, inputName); 60 if (inputIndex == -1) { 61 throw new IllegalArgumentException("Input error: input " + inputName + " not found."); 62 } 63 return inputIndex; 64 } 65 66 /** Gets the index of the output specified by {@code outputName}. */ getOutputIndex(String outputName)67 public int getOutputIndex(String outputName) { 68 int outputIndex = nativeGetOutputIndex(signatureRunnerHandle, outputName); 69 if (outputIndex == -1) { 70 throw new IllegalArgumentException("Input error: output " + outputName + " not found."); 71 } 72 return outputIndex; 73 } 74 75 /** Resizes dimensions of a specific input. */ resizeInput(String inputName, int[] dims)76 public boolean resizeInput(String inputName, int[] dims) { 77 isMemoryAllocated = false; 78 return nativeResizeInput(signatureRunnerHandle, errorHandle, inputName, dims); 79 } 80 81 /** Allocates tensor memory space. */ allocateTensorsIfNeeded()82 public void allocateTensorsIfNeeded() { 83 if (isMemoryAllocated) { 84 return; 85 } 86 87 nativeAllocateTensors(signatureRunnerHandle, errorHandle); 88 isMemoryAllocated = true; 89 } 90 91 /** Runs inference for this Signature. */ invoke()92 public void invoke() { 93 nativeInvoke(signatureRunnerHandle, errorHandle); 94 } 95 96 private final long signatureRunnerHandle; 97 98 private final long errorHandle; 99 100 private boolean isMemoryAllocated = false; 101 nativeGetSignatureRunner(long interpreterHandle, String signatureKey)102 private static native long nativeGetSignatureRunner(long interpreterHandle, String signatureKey); 103 nativeGetSubgraphIndex(long signatureRunnerHandle)104 private static native int nativeGetSubgraphIndex(long signatureRunnerHandle); 105 nativeInputNames(long signatureRunnerHandle)106 private static native String[] nativeInputNames(long signatureRunnerHandle); 107 nativeOutputNames(long signatureRunnerHandle)108 private static native String[] nativeOutputNames(long signatureRunnerHandle); 109 nativeGetInputIndex(long signatureRunnerHandle, String inputName)110 private static native int nativeGetInputIndex(long signatureRunnerHandle, String inputName); 111 nativeGetOutputIndex(long signatureRunnerHandle, String outputName)112 private static native int nativeGetOutputIndex(long signatureRunnerHandle, String outputName); 113 nativeResizeInput( long signatureRunnerHandle, long errorHandle, String inputName, int[] dims)114 private static native boolean nativeResizeInput( 115 long signatureRunnerHandle, long errorHandle, String inputName, int[] dims); 116 nativeAllocateTensors(long signatureRunnerHandle, long errorHandle)117 private static native void nativeAllocateTensors(long signatureRunnerHandle, long errorHandle); 118 nativeInvoke(long signatureRunnerHandle, long errorHandle)119 private static native void nativeInvoke(long signatureRunnerHandle, long errorHandle); 120 } 121