• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.lite;
17 
18 /**
19  * An internal wrapper that wraps native SignatureRunner.
20  *
21  * <p>Note: This class is not thread safe.
22  */
23 final class NativeSignatureRunnerWrapper {
NativeSignatureRunnerWrapper(long interpreterHandle, long errorHandle, String signatureKey)24   NativeSignatureRunnerWrapper(long interpreterHandle, long errorHandle, String signatureKey) {
25     this.errorHandle = errorHandle;
26     signatureRunnerHandle = nativeGetSignatureRunner(interpreterHandle, signatureKey);
27     if (signatureRunnerHandle == -1) {
28       throw new IllegalArgumentException("Input error: Signature " + signatureKey + " not found.");
29     }
30   }
31 
32   /** Gets the subgraph index associated with this Signature. */
getSubgraphIndex()33   public int getSubgraphIndex() {
34     return nativeGetSubgraphIndex(signatureRunnerHandle);
35   }
36 
37   /** Gets the inputs of this Signature. */
inputNames()38   public String[] inputNames() {
39     return nativeInputNames(signatureRunnerHandle);
40   }
41 
42   /** Gets the outputs of this Signature. */
outputNames()43   public String[] outputNames() {
44     return nativeOutputNames(signatureRunnerHandle);
45   }
46 
47   /** Gets the input tensor specified by {@code inputName}. */
getInputTensor(String inputName)48   public TensorImpl getInputTensor(String inputName) {
49     return TensorImpl.fromSignatureInput(signatureRunnerHandle, inputName);
50   }
51 
52   /** Gets the output tensor specified by {@code outputName}. */
getOutputTensor(String outputName)53   public TensorImpl getOutputTensor(String outputName) {
54     return TensorImpl.fromSignatureOutput(signatureRunnerHandle, outputName);
55   }
56 
57   /** Gets the index of the input specified by {@code inputName}. */
getInputIndex(String inputName)58   public int getInputIndex(String inputName) {
59     int inputIndex = nativeGetInputIndex(signatureRunnerHandle, inputName);
60     if (inputIndex == -1) {
61       throw new IllegalArgumentException("Input error: input " + inputName + " not found.");
62     }
63     return inputIndex;
64   }
65 
66   /** Gets the index of the output specified by {@code outputName}. */
getOutputIndex(String outputName)67   public int getOutputIndex(String outputName) {
68     int outputIndex = nativeGetOutputIndex(signatureRunnerHandle, outputName);
69     if (outputIndex == -1) {
70       throw new IllegalArgumentException("Input error: output " + outputName + " not found.");
71     }
72     return outputIndex;
73   }
74 
75   /** Resizes dimensions of a specific input. */
resizeInput(String inputName, int[] dims)76   public boolean resizeInput(String inputName, int[] dims) {
77     isMemoryAllocated = false;
78     return nativeResizeInput(signatureRunnerHandle, errorHandle, inputName, dims);
79   }
80 
81   /** Allocates tensor memory space. */
allocateTensorsIfNeeded()82   public void allocateTensorsIfNeeded() {
83     if (isMemoryAllocated) {
84       return;
85     }
86 
87     nativeAllocateTensors(signatureRunnerHandle, errorHandle);
88     isMemoryAllocated = true;
89   }
90 
91   /** Runs inference for this Signature. */
invoke()92   public void invoke() {
93     nativeInvoke(signatureRunnerHandle, errorHandle);
94   }
95 
96   private final long signatureRunnerHandle;
97 
98   private final long errorHandle;
99 
100   private boolean isMemoryAllocated = false;
101 
nativeGetSignatureRunner(long interpreterHandle, String signatureKey)102   private static native long nativeGetSignatureRunner(long interpreterHandle, String signatureKey);
103 
nativeGetSubgraphIndex(long signatureRunnerHandle)104   private static native int nativeGetSubgraphIndex(long signatureRunnerHandle);
105 
nativeInputNames(long signatureRunnerHandle)106   private static native String[] nativeInputNames(long signatureRunnerHandle);
107 
nativeOutputNames(long signatureRunnerHandle)108   private static native String[] nativeOutputNames(long signatureRunnerHandle);
109 
nativeGetInputIndex(long signatureRunnerHandle, String inputName)110   private static native int nativeGetInputIndex(long signatureRunnerHandle, String inputName);
111 
nativeGetOutputIndex(long signatureRunnerHandle, String outputName)112   private static native int nativeGetOutputIndex(long signatureRunnerHandle, String outputName);
113 
nativeResizeInput( long signatureRunnerHandle, long errorHandle, String inputName, int[] dims)114   private static native boolean nativeResizeInput(
115       long signatureRunnerHandle, long errorHandle, String inputName, int[] dims);
116 
nativeAllocateTensors(long signatureRunnerHandle, long errorHandle)117   private static native void nativeAllocateTensors(long signatureRunnerHandle, long errorHandle);
118 
nativeInvoke(long signatureRunnerHandle, long errorHandle)119   private static native void nativeInvoke(long signatureRunnerHandle, long errorHandle);
120 }
121