1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite.task.text.nlclassifier; 17 18 import android.content.Context; 19 import android.os.ParcelFileDescriptor; 20 import com.google.auto.value.AutoValue; 21 import java.io.File; 22 import java.io.IOException; 23 import java.nio.ByteBuffer; 24 import java.nio.MappedByteBuffer; 25 import java.util.List; 26 import org.tensorflow.lite.annotations.UsedByReflection; 27 import org.tensorflow.lite.support.label.Category; 28 import org.tensorflow.lite.task.core.BaseTaskApi; 29 import org.tensorflow.lite.task.core.TaskJniUtils; 30 import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; 31 32 /** 33 * Classifier API for natural language classification tasks, categorizes string into different 34 * classes. 35 * 36 * <p>The API expects a TFLite model with the following input/output tensor: 37 * 38 * <ul> 39 * <li>Input tensor (kTfLiteString) 40 * <ul> 41 * <li>input of the model, accepts a string. 42 * </ul> 43 * <li>Output score tensor 44 * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool) 45 * <ul> 46 * <li>output scores for each class, if type is one of the Int types, dequantize it, if it 47 * is Bool type, convert the values to 0.0 and 1.0 respectively. 48 * <li>can have an optional associated file in metadata for labels, the file should be a 49 * plain text file with one label per line, the number of labels should match the number 50 * of categories the model outputs. Output label tensor: optional (kTfLiteString) - 51 * output classname for each class, should be of the same length with scores. If this 52 * tensor is not present, the API uses score indices as classnames. - will be ignored if 53 * output score tensor already has an associated label file. 54 * </ul> 55 * <li>Optional Output label tensor (kTfLiteString/kTfLiteInt32) 56 * <ul> 57 * <li>output classname for each class, should be of the same length with scores. If this 58 * tensor is not present, the API uses score indices as classnames. 59 * <li>will be ignored if output score tensor already has an associated labe file. 60 * </ul> 61 * </ul> 62 * 63 * <p>By default the API tries to find the input/output tensors with default configurations in 64 * {@link NLClassifierOptions}, with tensor name prioritized over tensor index. The option is 65 * configurable for different TFLite models. 66 */ 67 public class NLClassifier extends BaseTaskApi { 68 69 /** Options to identify input and output tensors of the model. */ 70 @AutoValue 71 @UsedByReflection("nl_classifier_jni.cc") 72 public abstract static class NLClassifierOptions { 73 private static final int DEFAULT_INPUT_TENSOR_INDEX = 0; 74 private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0; 75 // By default there is no output label tensor. The label file can be attached 76 // to the output score tensor metadata. 77 private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1; 78 private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT"; 79 private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE"; 80 private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; 81 82 @UsedByReflection("nl_classifier_jni.cc") inputTensorIndex()83 abstract int inputTensorIndex(); 84 85 @UsedByReflection("nl_classifier_jni.cc") outputScoreTensorIndex()86 abstract int outputScoreTensorIndex(); 87 88 @UsedByReflection("nl_classifier_jni.cc") outputLabelTensorIndex()89 abstract int outputLabelTensorIndex(); 90 91 @UsedByReflection("nl_classifier_jni.cc") inputTensorName()92 abstract String inputTensorName(); 93 94 @UsedByReflection("nl_classifier_jni.cc") outputScoreTensorName()95 abstract String outputScoreTensorName(); 96 97 @UsedByReflection("nl_classifier_jni.cc") outputLabelTensorName()98 abstract String outputLabelTensorName(); 99 builder()100 public static Builder builder() { 101 return new AutoValue_NLClassifier_NLClassifierOptions.Builder() 102 .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX) 103 .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX) 104 .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX) 105 .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME) 106 .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME) 107 .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME); 108 } 109 110 /** Builder for {@link NLClassifierOptions}. */ 111 @AutoValue.Builder 112 public abstract static class Builder { setInputTensorIndex(int value)113 public abstract Builder setInputTensorIndex(int value); 114 setOutputScoreTensorIndex(int value)115 public abstract Builder setOutputScoreTensorIndex(int value); 116 setOutputLabelTensorIndex(int value)117 public abstract Builder setOutputLabelTensorIndex(int value); 118 setInputTensorName(String value)119 public abstract Builder setInputTensorName(String value); 120 setOutputScoreTensorName(String value)121 public abstract Builder setOutputScoreTensorName(String value); 122 setOutputLabelTensorName(String value)123 public abstract Builder setOutputLabelTensorName(String value); 124 build()125 public abstract NLClassifierOptions build(); 126 } 127 } 128 129 private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native"; 130 131 /** 132 * Constructor to initialize the JNI with a pointer from C++. 133 * 134 * @param nativeHandle a pointer referencing memory allocated in C++. 135 */ NLClassifier(long nativeHandle)136 protected NLClassifier(long nativeHandle) { 137 super(nativeHandle); 138 } 139 140 /** 141 * Create {@link NLClassifier} from default {@link NLClassifierOptions}. 142 * 143 * @param context Android context. 144 * @param pathToModel Path to the classification model relative to asset dir. 145 * @return {@link NLClassifier} instance. 146 * @throws IOException If model file fails to load. 147 */ createFromFile(Context context, String pathToModel)148 public static NLClassifier createFromFile(Context context, String pathToModel) 149 throws IOException { 150 return createFromFileAndOptions(context, pathToModel, NLClassifierOptions.builder().build()); 151 } 152 153 /** 154 * Create {@link NLClassifier} from default {@link NLClassifierOptions}. 155 * 156 * @param modelFile The classification model {@link File} instance. 157 * @return {@link NLClassifier} instance. 158 * @throws IOException If model file fails to load. 159 */ createFromFile(File modelFile)160 public static NLClassifier createFromFile(File modelFile) throws IOException { 161 return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); 162 } 163 164 /** 165 * Create {@link NLClassifier} from {@link NLClassifierOptions}. 166 * 167 * @param context Android context 168 * @param pathToModel Path to the classification model relative to asset dir. 169 * @param options Configurations for the model. 170 * @return {@link NLClassifier} instance. 171 * @throws IOException If model file fails to load. 172 */ createFromFileAndOptions( Context context, String pathToModel, NLClassifierOptions options)173 public static NLClassifier createFromFileAndOptions( 174 Context context, String pathToModel, NLClassifierOptions options) throws IOException { 175 return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, pathToModel), options); 176 } 177 178 /** 179 * Create {@link NLClassifier} from {@link NLClassifierOptions}. 180 * 181 * @param modelFile The classification model {@link File} instance. 182 * @param options Configurations for the model. 183 * @return {@link NLClassifier} instance. 184 * @throws IOException If model file fails to load. 185 */ createFromFileAndOptions( File modelFile, final NLClassifierOptions options)186 public static NLClassifier createFromFileAndOptions( 187 File modelFile, final NLClassifierOptions options) throws IOException { 188 try (ParcelFileDescriptor descriptor = 189 ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { 190 return new NLClassifier( 191 TaskJniUtils.createHandleFromLibrary( 192 new EmptyHandleProvider() { 193 @Override 194 public long createHandle() { 195 return initJniWithFileDescriptor(options, descriptor.getFd()); 196 } 197 }, 198 NL_CLASSIFIER_NATIVE_LIBNAME)); 199 } 200 } 201 202 /** 203 * Create {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. 204 * 205 * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the 206 * classification model 207 * @param options Configurations for the model 208 * @return {@link NLClassifier} instance 209 * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a 210 * {@link MappedByteBuffer} 211 */ 212 public static NLClassifier createFromBufferAndOptions( 213 final ByteBuffer modelBuffer, final NLClassifierOptions options) { 214 if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { 215 throw new IllegalArgumentException( 216 "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); 217 } 218 return new NLClassifier( 219 TaskJniUtils.createHandleFromLibrary( 220 new EmptyHandleProvider() { 221 @Override 222 public long createHandle() { 223 return initJniWithByteBuffer(options, modelBuffer); 224 } 225 }, 226 NL_CLASSIFIER_NATIVE_LIBNAME)); 227 } 228 229 /** 230 * Perform classification on a string input, returns classified {@link Category}s. 231 * 232 * @param text input text to the model. 233 * @return A list of Category results. 234 */ 235 public List<Category> classify(String text) { 236 return classifyNative(getNativeHandle(), text); 237 } 238 239 private static native long initJniWithByteBuffer( 240 NLClassifierOptions options, ByteBuffer modelBuffer); 241 242 private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd); 243 244 private static native List<Category> classifyNative(long nativeHandle, String text); 245 246 @Override 247 protected void deinit(long nativeHandle) { 248 deinitJni(nativeHandle); 249 } 250 251 /** 252 * Native implementation to release memory pointed by the pointer. 253 * 254 * @param nativeHandle pointer to memory allocated 255 */ 256 private native void deinitJni(long nativeHandle); 257 } 258