/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.task.text.nlclassifier;

import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.List;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;

/**
 * Classifier API for natural language classification tasks, categorizes string into different
 * classes.
 *
 * <p>The API expects a TFLite model with the following input/output tensor:
 *
 * <ul>
 *   <li>Input tensor (kTfLiteString)
 *       <ul>
 *         <li>input of the model, accepts a string.
 *       </ul>
 *   <li>Output score tensor
 *       (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
 *       <ul>
 *         <li>output scores for each class, if type is one of the Int types, dequantize it, if it
 *             is Bool type, convert the values to 0.0 and 1.0 respectively.
 *         <li>can have an optional associated file in metadata for labels, the file should be a
 *             plain text file with one label per line, the number of labels should match the number
 *             of categories the model outputs.
 *       </ul>
 *   <li>Optional Output label tensor (kTfLiteString/kTfLiteInt32)
 *       <ul>
 *         <li>output classname for each class, should be of the same length with scores. If this
 *             tensor is not present, the API uses score indices as classnames.
 *         <li>will be ignored if output score tensor already has an associated label file.
 *       </ul>
 * </ul>
 *
 * <p>By default the API tries to find the input/output tensors with default configurations in
 * {@link NLClassifierOptions}, with tensor name prioritized over tensor index. The option is
 * configurable for different TFLite models.
 */
67 public class NLClassifier extends BaseTaskApi {
68 
69   /** Options to identify input and output tensors of the model. */
70   @AutoValue
71   @UsedByReflection("nl_classifier_jni.cc")
72   public abstract static class NLClassifierOptions {
73     private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
74     private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
75     // By default there is no output label tensor. The label file can be attached
76     // to the output score tensor metadata.
77     private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
78     private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
79     private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
80     private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
81 
82     @UsedByReflection("nl_classifier_jni.cc")
inputTensorIndex()83     abstract int inputTensorIndex();
84 
85     @UsedByReflection("nl_classifier_jni.cc")
outputScoreTensorIndex()86     abstract int outputScoreTensorIndex();
87 
88     @UsedByReflection("nl_classifier_jni.cc")
outputLabelTensorIndex()89     abstract int outputLabelTensorIndex();
90 
91     @UsedByReflection("nl_classifier_jni.cc")
inputTensorName()92     abstract String inputTensorName();
93 
94     @UsedByReflection("nl_classifier_jni.cc")
outputScoreTensorName()95     abstract String outputScoreTensorName();
96 
97     @UsedByReflection("nl_classifier_jni.cc")
outputLabelTensorName()98     abstract String outputLabelTensorName();
99 
builder()100     public static Builder builder() {
101       return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
102           .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
103           .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
104           .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
105           .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
106           .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
107           .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
108     }
109 
110     /** Builder for {@link NLClassifierOptions}. */
111     @AutoValue.Builder
112     public abstract static class Builder {
setInputTensorIndex(int value)113       public abstract Builder setInputTensorIndex(int value);
114 
setOutputScoreTensorIndex(int value)115       public abstract Builder setOutputScoreTensorIndex(int value);
116 
setOutputLabelTensorIndex(int value)117       public abstract Builder setOutputLabelTensorIndex(int value);
118 
setInputTensorName(String value)119       public abstract Builder setInputTensorName(String value);
120 
setOutputScoreTensorName(String value)121       public abstract Builder setOutputScoreTensorName(String value);
122 
setOutputLabelTensorName(String value)123       public abstract Builder setOutputLabelTensorName(String value);
124 
build()125       public abstract NLClassifierOptions build();
126     }
127   }
128 
129   private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native";
130 
131   /**
132    * Constructor to initialize the JNI with a pointer from C++.
133    *
134    * @param nativeHandle a pointer referencing memory allocated in C++.
135    */
NLClassifier(long nativeHandle)136   protected NLClassifier(long nativeHandle) {
137     super(nativeHandle);
138   }
139 
140   /**
141    * Create {@link NLClassifier} from default {@link NLClassifierOptions}.
142    *
143    * @param context Android context.
144    * @param pathToModel Path to the classification model relative to asset dir.
145    * @return {@link NLClassifier} instance.
146    * @throws IOException If model file fails to load.
147    */
createFromFile(Context context, String pathToModel)148   public static NLClassifier createFromFile(Context context, String pathToModel)
149       throws IOException {
150     return createFromFileAndOptions(context, pathToModel, NLClassifierOptions.builder().build());
151   }
152 
153   /**
154    * Create {@link NLClassifier} from default {@link NLClassifierOptions}.
155    *
156    * @param modelFile The classification model {@link File} instance.
157    * @return {@link NLClassifier} instance.
158    * @throws IOException If model file fails to load.
159    */
createFromFile(File modelFile)160   public static NLClassifier createFromFile(File modelFile) throws IOException {
161     return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
162   }
163 
164   /**
165    * Create {@link NLClassifier} from {@link NLClassifierOptions}.
166    *
167    * @param context Android context
168    * @param pathToModel Path to the classification model relative to asset dir.
169    * @param options Configurations for the model.
170    * @return {@link NLClassifier} instance.
171    * @throws IOException If model file fails to load.
172    */
createFromFileAndOptions( Context context, String pathToModel, NLClassifierOptions options)173   public static NLClassifier createFromFileAndOptions(
174       Context context, String pathToModel, NLClassifierOptions options) throws IOException {
175     return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, pathToModel), options);
176   }
177 
178   /**
179    * Create {@link NLClassifier} from {@link NLClassifierOptions}.
180    *
181    * @param modelFile The classification model {@link File} instance.
182    * @param options Configurations for the model.
183    * @return {@link NLClassifier} instance.
184    * @throws IOException If model file fails to load.
185    */
createFromFileAndOptions( File modelFile, final NLClassifierOptions options)186   public static NLClassifier createFromFileAndOptions(
187       File modelFile, final NLClassifierOptions options) throws IOException {
188     try (ParcelFileDescriptor descriptor =
189         ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
190       return new NLClassifier(
191           TaskJniUtils.createHandleFromLibrary(
192               new EmptyHandleProvider() {
193                 @Override
194                 public long createHandle() {
195                   return initJniWithFileDescriptor(options, descriptor.getFd());
196                 }
197               },
198               NL_CLASSIFIER_NATIVE_LIBNAME));
199     }
200   }
201 
202   /**
203    * Create {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
204    *
205    * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
206    *     classification model
207    * @param options Configurations for the model
208    * @return {@link NLClassifier} instance
209    * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
210    *     {@link MappedByteBuffer}
211    */
212   public static NLClassifier createFromBufferAndOptions(
213       final ByteBuffer modelBuffer, final NLClassifierOptions options) {
214     if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
215       throw new IllegalArgumentException(
216           "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
217     }
218     return new NLClassifier(
219         TaskJniUtils.createHandleFromLibrary(
220             new EmptyHandleProvider() {
221               @Override
222               public long createHandle() {
223                 return initJniWithByteBuffer(options, modelBuffer);
224               }
225             },
226             NL_CLASSIFIER_NATIVE_LIBNAME));
227   }
228 
229   /**
230    * Perform classification on a string input, returns classified {@link Category}s.
231    *
232    * @param text input text to the model.
233    * @return A list of Category results.
234    */
235   public List<Category> classify(String text) {
236     return classifyNative(getNativeHandle(), text);
237   }
238 
239   private static native long initJniWithByteBuffer(
240       NLClassifierOptions options, ByteBuffer modelBuffer);
241 
242   private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd);
243 
244   private static native List<Category> classifyNative(long nativeHandle, String text);
245 
246   @Override
247   protected void deinit(long nativeHandle) {
248     deinitJni(nativeHandle);
249   }
250 
251   /**
252    * Native implementation to release memory pointed by the pointer.
253    *
254    * @param nativeHandle pointer to memory allocated
255    */
256   private native void deinitJni(long nativeHandle);
257 }
258