• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <jni.h>

#include <memory>
#include <string>

#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h"
23 
24 namespace tflite {
25 namespace task {
26 // To be provided by a link-time library
27 extern std::unique_ptr<OpResolver> CreateOpResolver();
28 
29 }  // namespace task
30 }  // namespace tflite
31 
32 namespace {
33 
34 using ::tflite::support::utils::kAssertionError;
35 using ::tflite::support::utils::kInvalidPointer;
36 using ::tflite::support::utils::GetMappedFileBuffer;
37 using ::tflite::support::utils::JStringToString;
38 using ::tflite::support::utils::ThrowException;
39 using ::tflite::task::text::nlclassifier::NLClassifier;
40 using ::tflite::task::text::nlclassifier::NLClassifierOptions;
41 using ::tflite::task::text::nlclassifier::RunClassifier;
42 
43 
ConvertJavaNLClassifierOptions(JNIEnv * env,jobject java_nl_classifier_options)44 NLClassifierOptions ConvertJavaNLClassifierOptions(
45     JNIEnv* env, jobject java_nl_classifier_options) {
46   jclass nl_classifier_options_class = env->FindClass(
47       "org/tensorflow/lite/task/text/nlclassifier/"
48       "NLClassifier$NLClassifierOptions");
49   jmethodID input_tensor_index_method_id =
50       env->GetMethodID(nl_classifier_options_class, "inputTensorIndex", "()I");
51   jmethodID output_score_tensor_index_method_id = env->GetMethodID(
52       nl_classifier_options_class, "outputScoreTensorIndex", "()I");
53   jmethodID output_label_tensor_index_method_id = env->GetMethodID(
54       nl_classifier_options_class, "outputLabelTensorIndex", "()I");
55   jmethodID input_tensor_name_method_id = env->GetMethodID(
56       nl_classifier_options_class, "inputTensorName", "()Ljava/lang/String;");
57   jmethodID output_score_tensor_name_method_id =
58       env->GetMethodID(nl_classifier_options_class, "outputScoreTensorName",
59                        "()Ljava/lang/String;");
60   jmethodID output_label_tensor_name_method_id =
61       env->GetMethodID(nl_classifier_options_class, "outputLabelTensorName",
62                        "()Ljava/lang/String;");
63 
64   return {
65       .input_tensor_index = env->CallIntMethod(java_nl_classifier_options,
66                                                input_tensor_index_method_id),
67       .output_score_tensor_index = env->CallIntMethod(
68           java_nl_classifier_options, output_score_tensor_index_method_id),
69       .output_label_tensor_index = env->CallIntMethod(
70           java_nl_classifier_options, output_label_tensor_index_method_id),
71       .input_tensor_name = JStringToString(
72           env, (jstring)env->CallObjectMethod(java_nl_classifier_options,
73                                               input_tensor_name_method_id)),
74       .output_score_tensor_name = JStringToString(
75           env,
76           (jstring)env->CallObjectMethod(java_nl_classifier_options,
77                                          output_score_tensor_name_method_id)),
78       .output_label_tensor_name = JStringToString(
79           env,
80           (jstring)env->CallObjectMethod(java_nl_classifier_options,
81                                          output_label_tensor_name_method_id)),
82   };
83 }
84 
85 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)86 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
87     JNIEnv* env, jobject thiz, jlong native_handle) {
88   delete reinterpret_cast<NLClassifier*>(native_handle);
89 }
90 
91 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(JNIEnv * env,jclass thiz,jobject nl_classifier_options,jobject model_buffer)92 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
93     JNIEnv* env, jclass thiz, jobject nl_classifier_options,
94     jobject model_buffer) {
95   auto model = GetMappedFileBuffer(env, model_buffer);
96   tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
97       NLClassifier::CreateFromBufferAndOptions(
98           model.data(), model.size(),
99           ConvertJavaNLClassifierOptions(env, nl_classifier_options),
100           tflite::task::CreateOpResolver());
101 
102   if (status.ok()) {
103     return reinterpret_cast<jlong>(status->release());
104   } else {
105     ThrowException(env, kAssertionError,
106                    "Error occurred when initializing NLClassifier: %s",
107                    status.status().message().data());
108     return kInvalidPointer;
109   }
110 }
111 
112 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(JNIEnv * env,jclass thiz,jobject nl_classifier_options,jint fd)113 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
114     JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd) {
115   tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
116       NLClassifier::CreateFromFdAndOptions(
117           fd, ConvertJavaNLClassifierOptions(env, nl_classifier_options),
118           tflite::task::CreateOpResolver());
119   if (status.ok()) {
120     return reinterpret_cast<jlong>(status->release());
121   } else {
122     ThrowException(env, kAssertionError,
123                    "Error occurred when initializing NLClassifier: %s",
124                    status.status().message().data());
125     return kInvalidPointer;
126   }
127 }
128 
129 extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(JNIEnv * env,jclass thiz,jlong native_handle,jstring text)130 Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
131     JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
132   return RunClassifier(env, native_handle, text);
133 }
134 
135 }  // namespace
136