1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Simple JNI wrapper for the SmartSelection library.
18
19 #include "textclassifier_jni.h"
20
21 #include <jni.h>
22 #include <vector>
23
24 #include "lang_id/lang-id.h"
25 #include "smartselect/text-classification-model.h"
26
27 using libtextclassifier::TextClassificationModel;
28 using libtextclassifier::ModelOptions;
29 using libtextclassifier::nlp_core::lang_id::LangId;
30
31 namespace {
32
JStringToUtf8String(JNIEnv * env,const jstring & jstr,std::string * result)33 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
34 std::string* result) {
35 if (jstr == nullptr) {
36 *result = std::string();
37 return false;
38 }
39
40 jclass string_class = env->FindClass("java/lang/String");
41 jmethodID get_bytes_id =
42 env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
43
44 jstring encoding = env->NewStringUTF("UTF-8");
45 jbyteArray array = reinterpret_cast<jbyteArray>(
46 env->CallObjectMethod(jstr, get_bytes_id, encoding));
47
48 jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
49 int length = env->GetArrayLength(array);
50
51 *result = std::string(reinterpret_cast<char*>(array_bytes), length);
52
53 // Release the array.
54 env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
55 env->DeleteLocalRef(array);
56 env->DeleteLocalRef(string_class);
57 env->DeleteLocalRef(encoding);
58
59 return true;
60 }
61
ToStlString(JNIEnv * env,const jstring & str)62 std::string ToStlString(JNIEnv* env, const jstring& str) {
63 std::string result;
64 JStringToUtf8String(env, str, &result);
65 return result;
66 }
67
ScoredStringsToJObjectArray(JNIEnv * env,const std::string & result_class_name,const std::vector<std::pair<std::string,float>> & classification_result)68 jobjectArray ScoredStringsToJObjectArray(
69 JNIEnv* env, const std::string& result_class_name,
70 const std::vector<std::pair<std::string, float>>& classification_result) {
71 jclass result_class = env->FindClass(result_class_name.c_str());
72 jmethodID result_class_constructor =
73 env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
74
75 jobjectArray results =
76 env->NewObjectArray(classification_result.size(), result_class, nullptr);
77
78 for (int i = 0; i < classification_result.size(); i++) {
79 jstring row_string =
80 env->NewStringUTF(classification_result[i].first.c_str());
81 jobject result =
82 env->NewObject(result_class, result_class_constructor, row_string,
83 static_cast<jfloat>(classification_result[i].second));
84 env->SetObjectArrayElement(results, i, result);
85 env->DeleteLocalRef(result);
86 }
87 env->DeleteLocalRef(result_class);
88 return results;
89 }
90
91 } // namespace
92
93 namespace libtextclassifier {
94
95 using libtextclassifier::CodepointSpan;
96
97 namespace {
98
ConvertIndicesBMPUTF8(const std::string & utf8_str,CodepointSpan orig_indices,bool from_utf8)99 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
100 CodepointSpan orig_indices,
101 bool from_utf8) {
102 const libtextclassifier::UnicodeText unicode_str =
103 libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
104
105 int unicode_index = 0;
106 int bmp_index = 0;
107
108 const int* source_index;
109 const int* target_index;
110 if (from_utf8) {
111 source_index = &unicode_index;
112 target_index = &bmp_index;
113 } else {
114 source_index = &bmp_index;
115 target_index = &unicode_index;
116 }
117
118 CodepointSpan result{-1, -1};
119 std::function<void()> assign_indices_fn = [&result, &orig_indices,
120 &source_index, &target_index]() {
121 if (orig_indices.first == *source_index) {
122 result.first = *target_index;
123 }
124
125 if (orig_indices.second == *source_index) {
126 result.second = *target_index;
127 }
128 };
129
130 for (auto it = unicode_str.begin(); it != unicode_str.end();
131 ++it, ++unicode_index, ++bmp_index) {
132 assign_indices_fn();
133
134 // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
135 if (*it > 0xFFFF) {
136 ++bmp_index;
137 }
138 }
139 assign_indices_fn();
140
141 return result;
142 }
143
144 } // namespace
145
ConvertIndicesBMPToUTF8(const std::string & utf8_str,CodepointSpan orig_indices)146 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
147 CodepointSpan orig_indices) {
148 return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false);
149 }
150
ConvertIndicesUTF8ToBMP(const std::string & utf8_str,CodepointSpan orig_indices)151 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
152 CodepointSpan orig_indices) {
153 return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true);
154 }
155
156 } // namespace libtextclassifier
157
158 using libtextclassifier::ConvertIndicesUTF8ToBMP;
159 using libtextclassifier::ConvertIndicesBMPToUTF8;
160 using libtextclassifier::CodepointSpan;
161
162 JNIEXPORT jlong JNICALL
Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv * env,jobject thiz,jint fd)163 Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
164 jobject thiz,
165 jint fd) {
166 TextClassificationModel* model = new TextClassificationModel(fd);
167 return reinterpret_cast<jlong>(model);
168 }
169
170 JNIEXPORT jintArray JNICALL
Java_android_view_textclassifier_SmartSelection_nativeSuggest(JNIEnv * env,jobject thiz,jlong ptr,jstring context,jint selection_begin,jint selection_end)171 Java_android_view_textclassifier_SmartSelection_nativeSuggest(
172 JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
173 jint selection_end) {
174 TextClassificationModel* model =
175 reinterpret_cast<TextClassificationModel*>(ptr);
176
177 const std::string context_utf8 = ToStlString(env, context);
178 CodepointSpan input_indices =
179 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
180 CodepointSpan selection =
181 model->SuggestSelection(context_utf8, input_indices);
182 selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
183
184 jintArray result = env->NewIntArray(2);
185 env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
186 env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
187 return result;
188 }
189
190 JNIEXPORT jobjectArray JNICALL
Java_android_view_textclassifier_SmartSelection_nativeClassifyText(JNIEnv * env,jobject thiz,jlong ptr,jstring context,jint selection_begin,jint selection_end,jint input_flags)191 Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
192 JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
193 jint selection_end, jint input_flags) {
194 TextClassificationModel* ff_model =
195 reinterpret_cast<TextClassificationModel*>(ptr);
196 const std::vector<std::pair<std::string, float>> classification_result =
197 ff_model->ClassifyText(ToStlString(env, context),
198 {selection_begin, selection_end}, input_flags);
199
200 return ScoredStringsToJObjectArray(
201 env, "android/view/textclassifier/SmartSelection$ClassificationResult",
202 classification_result);
203 }
204
205 JNIEXPORT void JNICALL
Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv * env,jobject thiz,jlong ptr)206 Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
207 jobject thiz,
208 jlong ptr) {
209 TextClassificationModel* model =
210 reinterpret_cast<TextClassificationModel*>(ptr);
211 delete model;
212 }
213
Java_android_view_textclassifier_LangId_nativeNew(JNIEnv * env,jobject thiz,jint fd)214 JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
215 JNIEnv* env, jobject thiz, jint fd) {
216 return reinterpret_cast<jlong>(new LangId(fd));
217 }
218
219 JNIEXPORT jstring JNICALL
Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv * env,jobject clazz,jint fd)220 Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
221 jobject clazz,
222 jint fd) {
223 ModelOptions model_options;
224 if (ReadSelectionModelOptions(fd, &model_options)) {
225 return env->NewStringUTF(model_options.language().c_str());
226 } else {
227 return env->NewStringUTF("UNK");
228 }
229 }
230
231 JNIEXPORT jint JNICALL
Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv * env,jobject clazz,jint fd)232 Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
233 jobject clazz,
234 jint fd) {
235 ModelOptions model_options;
236 if (ReadSelectionModelOptions(fd, &model_options)) {
237 return model_options.version();
238 } else {
239 return -1;
240 }
241 }
242
243 JNIEXPORT jobjectArray JNICALL
Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv * env,jobject thiz,jlong ptr,jstring text)244 Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
245 jobject thiz,
246 jlong ptr,
247 jstring text) {
248 LangId* lang_id = reinterpret_cast<LangId*>(ptr);
249 const std::vector<std::pair<std::string, float>> scored_languages =
250 lang_id->FindLanguages(ToStlString(env, text));
251
252 return ScoredStringsToJObjectArray(
253 env, "android/view/textclassifier/LangId$ClassificationResult",
254 scored_languages);
255 }
256
Java_android_view_textclassifier_LangId_nativeClose(JNIEnv * env,jobject thiz,jlong ptr)257 JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
258 JNIEnv* env, jobject thiz, jlong ptr) {
259 LangId* lang_id = reinterpret_cast<LangId*>(ptr);
260 delete lang_id;
261 }
262
Java_android_view_textclassifier_LangId_nativeGetVersion(JNIEnv * env,jobject clazz,jint fd)263 JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
264 JNIEnv* env, jobject clazz, jint fd) {
265 std::unique_ptr<LangId> lang_id(new LangId(fd));
266 return lang_id->version();
267 }
268