1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <jni.h>

#include <memory>
#include <string>

#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h"
23
24 namespace tflite {
25 namespace task {
26 // To be provided by a link-time library
27 extern std::unique_ptr<OpResolver> CreateOpResolver();
28
29 } // namespace task
30 } // namespace tflite
31
32 namespace {
33
34 using ::tflite::support::utils::kAssertionError;
35 using ::tflite::support::utils::kInvalidPointer;
36 using ::tflite::support::utils::GetMappedFileBuffer;
37 using ::tflite::support::utils::JStringToString;
38 using ::tflite::support::utils::ThrowException;
39 using ::tflite::task::text::nlclassifier::NLClassifier;
40 using ::tflite::task::text::nlclassifier::NLClassifierOptions;
41 using ::tflite::task::text::nlclassifier::RunClassifier;
42
43
ConvertJavaNLClassifierOptions(JNIEnv * env,jobject java_nl_classifier_options)44 NLClassifierOptions ConvertJavaNLClassifierOptions(
45 JNIEnv* env, jobject java_nl_classifier_options) {
46 jclass nl_classifier_options_class = env->FindClass(
47 "org/tensorflow/lite/task/text/nlclassifier/"
48 "NLClassifier$NLClassifierOptions");
49 jmethodID input_tensor_index_method_id =
50 env->GetMethodID(nl_classifier_options_class, "inputTensorIndex", "()I");
51 jmethodID output_score_tensor_index_method_id = env->GetMethodID(
52 nl_classifier_options_class, "outputScoreTensorIndex", "()I");
53 jmethodID output_label_tensor_index_method_id = env->GetMethodID(
54 nl_classifier_options_class, "outputLabelTensorIndex", "()I");
55 jmethodID input_tensor_name_method_id = env->GetMethodID(
56 nl_classifier_options_class, "inputTensorName", "()Ljava/lang/String;");
57 jmethodID output_score_tensor_name_method_id =
58 env->GetMethodID(nl_classifier_options_class, "outputScoreTensorName",
59 "()Ljava/lang/String;");
60 jmethodID output_label_tensor_name_method_id =
61 env->GetMethodID(nl_classifier_options_class, "outputLabelTensorName",
62 "()Ljava/lang/String;");
63
64 return {
65 .input_tensor_index = env->CallIntMethod(java_nl_classifier_options,
66 input_tensor_index_method_id),
67 .output_score_tensor_index = env->CallIntMethod(
68 java_nl_classifier_options, output_score_tensor_index_method_id),
69 .output_label_tensor_index = env->CallIntMethod(
70 java_nl_classifier_options, output_label_tensor_index_method_id),
71 .input_tensor_name = JStringToString(
72 env, (jstring)env->CallObjectMethod(java_nl_classifier_options,
73 input_tensor_name_method_id)),
74 .output_score_tensor_name = JStringToString(
75 env,
76 (jstring)env->CallObjectMethod(java_nl_classifier_options,
77 output_score_tensor_name_method_id)),
78 .output_label_tensor_name = JStringToString(
79 env,
80 (jstring)env->CallObjectMethod(java_nl_classifier_options,
81 output_label_tensor_name_method_id)),
82 };
83 }
84
85 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)86 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
87 JNIEnv* env, jobject thiz, jlong native_handle) {
88 delete reinterpret_cast<NLClassifier*>(native_handle);
89 }
90
91 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(JNIEnv * env,jclass thiz,jobject nl_classifier_options,jobject model_buffer)92 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
93 JNIEnv* env, jclass thiz, jobject nl_classifier_options,
94 jobject model_buffer) {
95 auto model = GetMappedFileBuffer(env, model_buffer);
96 tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
97 NLClassifier::CreateFromBufferAndOptions(
98 model.data(), model.size(),
99 ConvertJavaNLClassifierOptions(env, nl_classifier_options),
100 tflite::task::CreateOpResolver());
101
102 if (status.ok()) {
103 return reinterpret_cast<jlong>(status->release());
104 } else {
105 ThrowException(env, kAssertionError,
106 "Error occurred when initializing NLClassifier: %s",
107 status.status().message().data());
108 return kInvalidPointer;
109 }
110 }
111
112 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(JNIEnv * env,jclass thiz,jobject nl_classifier_options,jint fd)113 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
114 JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd) {
115 tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
116 NLClassifier::CreateFromFdAndOptions(
117 fd, ConvertJavaNLClassifierOptions(env, nl_classifier_options),
118 tflite::task::CreateOpResolver());
119 if (status.ok()) {
120 return reinterpret_cast<jlong>(status->release());
121 } else {
122 ThrowException(env, kAssertionError,
123 "Error occurred when initializing NLClassifier: %s",
124 status.status().message().data());
125 return kInvalidPointer;
126 }
127 }
128
129 extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(JNIEnv * env,jclass thiz,jlong native_handle,jstring text)130 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
131 JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
132 return RunClassifier(env, native_handle, text);
133 }
134
135 } // namespace
136