1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // JNI wrapper for the TextClassifier.
18
19 #include "textclassifier_jni.h"
20
21 #include <jni.h>
22 #include <type_traits>
23 #include <vector>
24
25 #include "text-classifier.h"
26 #include "util/base/integral_types.h"
27 #include "util/java/scoped_local_ref.h"
28 #include "util/java/string_utils.h"
29 #include "util/memory/mmap.h"
30 #include "util/utf8/unilib.h"
31
32 using libtextclassifier2::AnnotatedSpan;
33 using libtextclassifier2::AnnotationOptions;
34 using libtextclassifier2::ClassificationOptions;
35 using libtextclassifier2::ClassificationResult;
36 using libtextclassifier2::CodepointSpan;
37 using libtextclassifier2::JStringToUtf8String;
38 using libtextclassifier2::Model;
39 using libtextclassifier2::ScopedLocalRef;
40 using libtextclassifier2::SelectionOptions;
41 using libtextclassifier2::TextClassifier;
42 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
43 using libtextclassifier2::UniLib;
44 #endif
45
46 namespace libtextclassifier2 {
47
48 using libtextclassifier2::CodepointSpan;
49
50 namespace {
51
ToStlString(JNIEnv * env,const jstring & str)52 std::string ToStlString(JNIEnv* env, const jstring& str) {
53 std::string result;
54 JStringToUtf8String(env, str, &result);
55 return result;
56 }
57
ClassificationResultsToJObjectArray(JNIEnv * env,const std::vector<ClassificationResult> & classification_result)58 jobjectArray ClassificationResultsToJObjectArray(
59 JNIEnv* env,
60 const std::vector<ClassificationResult>& classification_result) {
61 const ScopedLocalRef<jclass> result_class(
62 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
63 env);
64 if (!result_class) {
65 TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
66 return nullptr;
67 }
68 const ScopedLocalRef<jclass> datetime_parse_class(
69 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
70 if (!datetime_parse_class) {
71 TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
72 return nullptr;
73 }
74
75 const jmethodID result_class_constructor =
76 env->GetMethodID(result_class.get(), "<init>",
77 "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
78 "$DatetimeResult;)V");
79 const jmethodID datetime_parse_class_constructor =
80 env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
81
82 const jobjectArray results = env->NewObjectArray(classification_result.size(),
83 result_class.get(), nullptr);
84 for (int i = 0; i < classification_result.size(); i++) {
85 jstring row_string =
86 env->NewStringUTF(classification_result[i].collection.c_str());
87 jobject row_datetime_parse = nullptr;
88 if (classification_result[i].datetime_parse_result.IsSet()) {
89 row_datetime_parse = env->NewObject(
90 datetime_parse_class.get(), datetime_parse_class_constructor,
91 classification_result[i].datetime_parse_result.time_ms_utc,
92 classification_result[i].datetime_parse_result.granularity);
93 }
94 jobject result =
95 env->NewObject(result_class.get(), result_class_constructor, row_string,
96 static_cast<jfloat>(classification_result[i].score),
97 row_datetime_parse);
98 env->SetObjectArrayElement(results, i, result);
99 env->DeleteLocalRef(result);
100 }
101 return results;
102 }
103
104 template <typename T, typename F>
CallJniMethod0(JNIEnv * env,jobject object,jclass class_object,F function,const std::string & method_name,const std::string & return_java_type)105 std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
106 jclass class_object, F function,
107 const std::string& method_name,
108 const std::string& return_java_type) {
109 const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
110 ("()" + return_java_type).c_str());
111 if (!method) {
112 return std::make_pair(false, T());
113 }
114 return std::make_pair(true, (env->*function)(object, method));
115 }
116
FromJavaSelectionOptions(JNIEnv * env,jobject joptions)117 SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
118 if (!joptions) {
119 return {};
120 }
121
122 const ScopedLocalRef<jclass> options_class(
123 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
124 env);
125 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
126 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
127 "getLocales", "Ljava/lang/String;");
128 if (!status_or_locales.first) {
129 return {};
130 }
131
132 SelectionOptions options;
133 options.locales =
134 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
135
136 return options;
137 }
138
139 template <typename T>
FromJavaOptionsInternal(JNIEnv * env,jobject joptions,const std::string & class_name)140 T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
141 const std::string& class_name) {
142 if (!joptions) {
143 return {};
144 }
145
146 const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
147 env);
148 if (!options_class) {
149 return {};
150 }
151
152 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
153 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
154 "getLocale", "Ljava/lang/String;");
155 const std::pair<bool, jobject> status_or_reference_timezone =
156 CallJniMethod0<jobject>(env, joptions, options_class.get(),
157 &JNIEnv::CallObjectMethod, "getReferenceTimezone",
158 "Ljava/lang/String;");
159 const std::pair<bool, int64> status_or_reference_time_ms_utc =
160 CallJniMethod0<int64>(env, joptions, options_class.get(),
161 &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
162 "J");
163
164 if (!status_or_locales.first || !status_or_reference_timezone.first ||
165 !status_or_reference_time_ms_utc.first) {
166 return {};
167 }
168
169 T options;
170 options.locales =
171 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
172 options.reference_timezone = ToStlString(
173 env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
174 options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
175 return options;
176 }
177
FromJavaClassificationOptions(JNIEnv * env,jobject joptions)178 ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
179 jobject joptions) {
180 return FromJavaOptionsInternal<ClassificationOptions>(
181 env, joptions,
182 TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions");
183 }
184
FromJavaAnnotationOptions(JNIEnv * env,jobject joptions)185 AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
186 return FromJavaOptionsInternal<AnnotationOptions>(
187 env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions");
188 }
189
ConvertIndicesBMPUTF8(const std::string & utf8_str,CodepointSpan orig_indices,bool from_utf8)190 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
191 CodepointSpan orig_indices,
192 bool from_utf8) {
193 const libtextclassifier2::UnicodeText unicode_str =
194 libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
195
196 int unicode_index = 0;
197 int bmp_index = 0;
198
199 const int* source_index;
200 const int* target_index;
201 if (from_utf8) {
202 source_index = &unicode_index;
203 target_index = &bmp_index;
204 } else {
205 source_index = &bmp_index;
206 target_index = &unicode_index;
207 }
208
209 CodepointSpan result{-1, -1};
210 std::function<void()> assign_indices_fn = [&result, &orig_indices,
211 &source_index, &target_index]() {
212 if (orig_indices.first == *source_index) {
213 result.first = *target_index;
214 }
215
216 if (orig_indices.second == *source_index) {
217 result.second = *target_index;
218 }
219 };
220
221 for (auto it = unicode_str.begin(); it != unicode_str.end();
222 ++it, ++unicode_index, ++bmp_index) {
223 assign_indices_fn();
224
225 // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
226 if (*it > 0xFFFF) {
227 ++bmp_index;
228 }
229 }
230 assign_indices_fn();
231
232 return result;
233 }
234
235 } // namespace
236
ConvertIndicesBMPToUTF8(const std::string & utf8_str,CodepointSpan bmp_indices)237 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
238 CodepointSpan bmp_indices) {
239 return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
240 }
241
ConvertIndicesUTF8ToBMP(const std::string & utf8_str,CodepointSpan utf8_indices)242 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
243 CodepointSpan utf8_indices) {
244 return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
245 }
246
GetFdFromAssetFileDescriptor(JNIEnv * env,jobject afd)247 jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
248 // Get system-level file descriptor from AssetFileDescriptor.
249 ScopedLocalRef<jclass> afd_class(
250 env->FindClass("android/content/res/AssetFileDescriptor"), env);
251 if (afd_class == nullptr) {
252 TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
253 return reinterpret_cast<jlong>(nullptr);
254 }
255 jmethodID afd_class_getFileDescriptor = env->GetMethodID(
256 afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
257 if (afd_class_getFileDescriptor == nullptr) {
258 TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
259 return reinterpret_cast<jlong>(nullptr);
260 }
261
262 ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
263 env);
264 if (fd_class == nullptr) {
265 TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
266 return reinterpret_cast<jlong>(nullptr);
267 }
268 jfieldID fd_class_descriptor =
269 env->GetFieldID(fd_class.get(), "descriptor", "I");
270 if (fd_class_descriptor == nullptr) {
271 TC_LOG(ERROR) << "Couldn't find descriptor.";
272 return reinterpret_cast<jlong>(nullptr);
273 }
274
275 jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
276 return env->GetIntField(bundle_jfd, fd_class_descriptor);
277 }
278
GetLocalesFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)279 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
280 if (!mmap->handle().ok()) {
281 return env->NewStringUTF("");
282 }
283 const Model* model = libtextclassifier2::ViewModel(
284 mmap->handle().start(), mmap->handle().num_bytes());
285 if (!model || !model->locales()) {
286 return env->NewStringUTF("");
287 }
288 return env->NewStringUTF(model->locales()->c_str());
289 }
290
GetVersionFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)291 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
292 if (!mmap->handle().ok()) {
293 return 0;
294 }
295 const Model* model = libtextclassifier2::ViewModel(
296 mmap->handle().start(), mmap->handle().num_bytes());
297 if (!model) {
298 return 0;
299 }
300 return model->version();
301 }
302
GetNameFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)303 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
304 if (!mmap->handle().ok()) {
305 return env->NewStringUTF("");
306 }
307 const Model* model = libtextclassifier2::ViewModel(
308 mmap->handle().start(), mmap->handle().num_bytes());
309 if (!model || !model->name()) {
310 return env->NewStringUTF("");
311 }
312 return env->NewStringUTF(model->name()->c_str());
313 }
314
315 } // namespace libtextclassifier2
316
317 using libtextclassifier2::ClassificationResultsToJObjectArray;
318 using libtextclassifier2::ConvertIndicesBMPToUTF8;
319 using libtextclassifier2::ConvertIndicesUTF8ToBMP;
320 using libtextclassifier2::FromJavaAnnotationOptions;
321 using libtextclassifier2::FromJavaClassificationOptions;
322 using libtextclassifier2::FromJavaSelectionOptions;
323 using libtextclassifier2::ToStlString;
324
// Creates a TextClassifier from the model readable through file descriptor
// `fd` and returns its address as an opaque handle for the Java side. The
// handle is consumed by the other native methods and freed by nativeClose.
//
// With LIBTEXTCLASSIFIER_UNILIB_JAVAICU, a UniLib bound to `env` is passed to
// the factory, matching nativeNewFromPath and
// nativeNewFromAssetFileDescriptor. The UniLib must be an argument of
// FromFileDescriptor: placed after the call, the comma operator would make
// the cast return the UniLib pointer and leak the classifier.
//
// NOTE(review): the handle is presumably 0 when the model cannot be loaded;
// confirm against TextClassifier::FromFileDescriptor.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd, new UniLib(env)).release());
#else
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd).release());
#endif
}
335
// Creates a TextClassifier from the model file at `path` and returns its
// address as an opaque handle for the Java side. With
// LIBTEXTCLASSIFIER_UNILIB_JAVAICU, a UniLib bound to `env` is handed to the
// factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string model_path = ToStlString(env, path);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto classifier = TextClassifier::FromPath(model_path, new UniLib(env));
#else
  auto classifier = TextClassifier::FromPath(model_path);
#endif
  return reinterpret_cast<jlong>(classifier.release());
}
346
// Creates a TextClassifier from the region [`offset`, `offset` + `size`) of
// the file behind the Java AssetFileDescriptor `afd` and returns its address
// as an opaque handle for the Java side. With
// LIBTEXTCLASSIFIER_UNILIB_JAVAICU, a UniLib bound to `env` is handed to the
// factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto classifier = TextClassifier::FromFileDescriptor(model_fd, offset, size,
                                                       new UniLib(env));
#else
  auto classifier = TextClassifier::FromFileDescriptor(model_fd, offset, size);
#endif
  return reinterpret_cast<jlong>(classifier.release());
}
359
// Runs selection suggestion on `context` for the span
// [`selection_begin`, `selection_end`) and returns the suggested span as a
// two-element Java int array {begin, end}.
//
// The incoming indices are converted from BMP indexing to codepoint indices
// before calling the model, and the model's answer is converted back, so the
// Java side sees BMP indices throughout.
//
// Returns nullptr when `ptr` is 0 or when the result array cannot be
// allocated.
JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }

  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  const CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  const CodepointSpan suggested = model->SuggestSelection(
      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
  const CodepointSpan selection =
      ConvertIndicesUTF8ToBMP(context_utf8, suggested);

  jintArray result = env->NewIntArray(2);
  if (!result) {
    // Allocation failed; writing into a null array is not allowed.
    return nullptr;
  }
  const jint bounds[2] = {static_cast<jint>(selection.first),
                          static_cast<jint>(selection.second)};
  env->SetIntArrayRegion(result, 0, 2, bounds);
  return result;
}
381
// Classifies the span [`selection_begin`, `selection_end`) of `context` and
// returns the results as a Java array of ClassificationResult objects. The
// incoming BMP indices are converted to codepoint indices before the model is
// called. Returns nullptr when `ptr` is 0.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  TextClassifier* classifier = reinterpret_cast<TextClassifier*>(ptr);

  const std::string text = ToStlString(env, context);
  const CodepointSpan codepoint_span =
      ConvertIndicesBMPToUTF8(text, {selection_begin, selection_end});
  const std::vector<ClassificationResult> classifications =
      classifier->ClassifyText(text, codepoint_span,
                               FromJavaClassificationOptions(env, options));

  return ClassificationResultsToJObjectArray(env, classifications);
}
399
// Annotates `context` and returns the annotations as a Java array of
// <TC class>$AnnotatedSpan objects. Each object carries the span converted
// from codepoint indices to BMP indices, plus the span's classification
// results as a Java array.
//
// Returns nullptr when `ptr` is 0, when the AnnotatedSpan class or its
// constructor cannot be resolved, when the result array cannot be allocated,
// or when the classification results of an annotation cannot be converted.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
  const std::string context_utf8 = ToStlString(env, context);
  const std::vector<AnnotatedSpan> annotations =
      model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));

  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"), env);
  if (!result_class) {
    TC_LOG(ERROR) << "Couldn't find result class: "
                  << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
    return nullptr;
  }

  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>",
      "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
  if (!result_class_constructor) {
    TC_LOG(ERROR) << "Couldn't find AnnotatedSpan constructor.";
    return nullptr;
  }

  const jsize annotation_count = static_cast<jsize>(annotations.size());
  jobjectArray results =
      env->NewObjectArray(annotation_count, result_class.get(), nullptr);
  if (!results) {
    TC_LOG(ERROR) << "Couldn't allocate AnnotatedSpan array.";
    return nullptr;
  }

  for (jsize i = 0; i < annotation_count; ++i) {
    const CodepointSpan span_bmp =
        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    jobjectArray classifications = ClassificationResultsToJObjectArray(
        env, annotations[i].classification);
    if (!classifications) {
      // The conversion only fails when a Java lookup or allocation failed;
      // stop instead of issuing further JNI calls in that state.
      return nullptr;
    }
    jobject result = env->NewObject(
        result_class.get(), result_class_constructor,
        static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
        classifications);
    env->SetObjectArrayElement(results, i, result);

    // Release the per-annotation local references so their number does not
    // grow with the amount of annotations.
    env->DeleteLocalRef(result);
    env->DeleteLocalRef(classifications);
  }
  return results;
}
440
JNI_METHOD(void,TC_CLASS_NAME,nativeClose)441 JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
442 (JNIEnv* env, jobject thiz, jlong ptr) {
443 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
444 delete model;
445 }
446
// Deprecated alias of nativeGetLocales: logs a warning and forwards the call
// unchanged.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
  TC_LOG(WARNING) << "Using deprecated getLanguage().";
  const jstring locales =
      JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
  return locales;
}
452
// Maps the model behind file descriptor `fd` and returns its locales as a
// Java string (empty when they cannot be read). The mapping only lives for
// the duration of this call.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return libtextclassifier2::GetLocalesFromMmap(env, &model_mmap);
}
459
// Maps the region [`offset`, `offset` + `size`) of the file behind the Java
// AssetFileDescriptor `afd` and returns the model's locales as a Java string
// (empty when they cannot be read). The mapping only lives for the duration
// of this call.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return libtextclassifier2::GetLocalesFromMmap(env, &model_mmap);
}
467
// Maps the model behind file descriptor `fd` and returns its version (0 when
// it cannot be read). The mapping only lives for the duration of this call.
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return libtextclassifier2::GetVersionFromMmap(env, &model_mmap);
}
474
// Maps the region [`offset`, `offset` + `size`) of the file behind the Java
// AssetFileDescriptor `afd` and returns the model's version (0 when it cannot
// be read). The mapping only lives for the duration of this call.
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return libtextclassifier2::GetVersionFromMmap(env, &model_mmap);
}
482
// Maps the model behind file descriptor `fd` and returns its name as a Java
// string (empty when it cannot be read). The mapping only lives for the
// duration of this call.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return libtextclassifier2::GetNameFromMmap(env, &model_mmap);
}
489
// Maps the region [`offset`, `offset` + `size`) of the file behind the Java
// AssetFileDescriptor `afd` and returns the model's name as a Java string
// (empty when it cannot be read). The mapping only lives for the duration of
// this call.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return libtextclassifier2::GetNameFromMmap(env, &model_mmap);
}
497