• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // JNI wrapper for the Annotator.
18 
19 #include "annotator/annotator_jni.h"
20 
21 #include <jni.h>
22 #include <type_traits>
23 #include <vector>
24 
25 #include "annotator/annotator.h"
26 #include "annotator/annotator_jni_common.h"
27 #include "annotator/types.h"
28 #include "utils/base/integral_types.h"
29 #include "utils/calendar/calendar.h"
30 #include "utils/intents/intent-generator.h"
31 #include "utils/intents/jni.h"
32 #include "utils/java/jni-cache.h"
33 #include "utils/java/scoped_local_ref.h"
34 #include "utils/java/string_utils.h"
35 #include "utils/memory/mmap.h"
36 #include "utils/strings/stringpiece.h"
37 #include "utils/utf8/unilib.h"
38 
39 #ifdef TC3_UNILIB_JAVAICU
40 #ifndef TC3_CALENDAR_JAVAICU
41 #error Inconsistent usage of Java ICU components
42 #else
43 #define TC3_USE_JAVAICU
44 #endif
45 #endif
46 
47 using libtextclassifier3::AnnotatedSpan;
48 using libtextclassifier3::Annotator;
49 using libtextclassifier3::ClassificationResult;
50 using libtextclassifier3::CodepointSpan;
51 using libtextclassifier3::Model;
52 using libtextclassifier3::ScopedLocalRef;
// When using Java's ICU, CalendarLib and UniLib need to be instantiated
// with a JniCache obtained from JNI. When using a standard ICU the cache is
// not needed and the objects are instantiated implicitly.
56 #ifdef TC3_USE_JAVAICU
57 using libtextclassifier3::CalendarLib;
58 using libtextclassifier3::UniLib;
59 #endif
60 
61 namespace libtextclassifier3 {
62 
63 using libtextclassifier3::CodepointSpan;
64 
65 namespace {
66 class AnnotatorJniContext {
67  public:
Create(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<Annotator> model)68   static AnnotatorJniContext* Create(
69       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
70       std::unique_ptr<Annotator> model) {
71     if (jni_cache == nullptr || model == nullptr) {
72       return nullptr;
73     }
74     std::unique_ptr<IntentGenerator> intent_generator =
75         IntentGenerator::Create(model->model()->intent_options(),
76                                 model->model()->resources(), jni_cache);
77     std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
78         libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
79     if (template_handler == nullptr) {
80       return nullptr;
81     }
82     return new AnnotatorJniContext(jni_cache, std::move(model),
83                                    std::move(intent_generator),
84                                    std::move(template_handler));
85   }
86 
jni_cache() const87   std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
88     return jni_cache_;
89   }
90 
model() const91   Annotator* model() const { return model_.get(); }
92 
intent_generator() const93   IntentGenerator* intent_generator() const { return intent_generator_.get(); }
94 
template_handler() const95   RemoteActionTemplatesHandler* template_handler() const {
96     return template_handler_.get();
97   }
98 
99  private:
AnnotatorJniContext(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<Annotator> model,std::unique_ptr<IntentGenerator> intent_generator,std::unique_ptr<RemoteActionTemplatesHandler> template_handler)100   AnnotatorJniContext(
101       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
102       std::unique_ptr<Annotator> model,
103       std::unique_ptr<IntentGenerator> intent_generator,
104       std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
105       : jni_cache_(jni_cache),
106         model_(std::move(model)),
107         intent_generator_(std::move(intent_generator)),
108         template_handler_(std::move(template_handler)) {}
109 
110   std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
111   std::unique_ptr<Annotator> model_;
112   std::unique_ptr<IntentGenerator> intent_generator_;
113   std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
114 };
115 
ClassificationResultWithIntentsToJObject(JNIEnv * env,const AnnotatorJniContext * model_context,jobject app_context,jclass result_class,jmethodID result_class_constructor,jclass datetime_parse_class,jmethodID datetime_parse_class_constructor,const jstring device_locales,const ClassificationOptions * options,const std::string & context,const CodepointSpan & selection_indices,const ClassificationResult & classification_result,bool generate_intents)116 jobject ClassificationResultWithIntentsToJObject(
117     JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
118     jclass result_class, jmethodID result_class_constructor,
119     jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
120     const jstring device_locales, const ClassificationOptions* options,
121     const std::string& context, const CodepointSpan& selection_indices,
122     const ClassificationResult& classification_result, bool generate_intents) {
123   jstring row_string =
124       env->NewStringUTF(classification_result.collection.c_str());
125 
126   jobject row_datetime_parse = nullptr;
127   if (classification_result.datetime_parse_result.IsSet()) {
128     row_datetime_parse =
129         env->NewObject(datetime_parse_class, datetime_parse_class_constructor,
130                        classification_result.datetime_parse_result.time_ms_utc,
131                        classification_result.datetime_parse_result.granularity);
132   }
133 
134   jbyteArray serialized_knowledge_result = nullptr;
135   const std::string& serialized_knowledge_result_string =
136       classification_result.serialized_knowledge_result;
137   if (!serialized_knowledge_result_string.empty()) {
138     serialized_knowledge_result =
139         env->NewByteArray(serialized_knowledge_result_string.size());
140     env->SetByteArrayRegion(serialized_knowledge_result, 0,
141                             serialized_knowledge_result_string.size(),
142                             reinterpret_cast<const jbyte*>(
143                                 serialized_knowledge_result_string.data()));
144   }
145 
146   jstring contact_name = nullptr;
147   if (!classification_result.contact_name.empty()) {
148     contact_name =
149         env->NewStringUTF(classification_result.contact_name.c_str());
150   }
151 
152   jstring contact_given_name = nullptr;
153   if (!classification_result.contact_given_name.empty()) {
154     contact_given_name =
155         env->NewStringUTF(classification_result.contact_given_name.c_str());
156   }
157 
158   jstring contact_nickname = nullptr;
159   if (!classification_result.contact_nickname.empty()) {
160     contact_nickname =
161         env->NewStringUTF(classification_result.contact_nickname.c_str());
162   }
163 
164   jstring contact_email_address = nullptr;
165   if (!classification_result.contact_email_address.empty()) {
166     contact_email_address =
167         env->NewStringUTF(classification_result.contact_email_address.c_str());
168   }
169 
170   jstring contact_phone_number = nullptr;
171   if (!classification_result.contact_phone_number.empty()) {
172     contact_phone_number =
173         env->NewStringUTF(classification_result.contact_phone_number.c_str());
174   }
175 
176   jstring contact_id = nullptr;
177   if (!classification_result.contact_id.empty()) {
178     contact_id = env->NewStringUTF(classification_result.contact_id.c_str());
179   }
180 
181   jstring app_name = nullptr;
182   if (!classification_result.app_name.empty()) {
183     app_name = env->NewStringUTF(classification_result.app_name.c_str());
184   }
185 
186   jstring app_package_name = nullptr;
187   if (!classification_result.app_package_name.empty()) {
188     app_package_name =
189         env->NewStringUTF(classification_result.app_package_name.c_str());
190   }
191 
192   jobject extras = nullptr;
193   if (model_context->model()->entity_data_schema() != nullptr &&
194       !classification_result.serialized_entity_data.empty()) {
195     extras = model_context->template_handler()->EntityDataAsNamedVariantArray(
196         model_context->model()->entity_data_schema(),
197         classification_result.serialized_entity_data);
198   }
199 
200   jbyteArray serialized_entity_data = nullptr;
201   if (!classification_result.serialized_entity_data.empty()) {
202     serialized_entity_data =
203         env->NewByteArray(classification_result.serialized_entity_data.size());
204     env->SetByteArrayRegion(
205         serialized_entity_data, 0,
206         classification_result.serialized_entity_data.size(),
207         reinterpret_cast<const jbyte*>(
208             classification_result.serialized_entity_data.data()));
209   }
210 
211   jobject remote_action_templates_result = nullptr;
212   // Only generate RemoteActionTemplate for the top classification result
213   // as classifyText does not need RemoteAction from other results anyway.
214   if (generate_intents && model_context->intent_generator() != nullptr) {
215     std::vector<RemoteActionTemplate> remote_action_templates;
216     if (model_context->intent_generator()->GenerateIntents(
217             device_locales, classification_result,
218             options->reference_time_ms_utc, context, selection_indices,
219             app_context, model_context->model()->entity_data_schema(),
220             &remote_action_templates)) {
221       remote_action_templates_result =
222           model_context->template_handler()
223               ->RemoteActionTemplatesToJObjectArray(remote_action_templates);
224     }
225   }
226 
227   return env->NewObject(
228       result_class, result_class_constructor, row_string,
229       static_cast<jfloat>(classification_result.score), row_datetime_parse,
230       serialized_knowledge_result, contact_name, contact_given_name,
231       contact_nickname, contact_email_address, contact_phone_number, contact_id,
232       app_name, app_package_name, extras, serialized_entity_data,
233       remote_action_templates_result, classification_result.duration_ms,
234       classification_result.numeric_value);
235 }
236 
ClassificationResultsWithIntentsToJObjectArray(JNIEnv * env,const AnnotatorJniContext * model_context,jobject app_context,const jstring device_locales,const ClassificationOptions * options,const std::string & context,const CodepointSpan & selection_indices,const std::vector<ClassificationResult> & classification_result,bool generate_intents)237 jobjectArray ClassificationResultsWithIntentsToJObjectArray(
238     JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
239     const jstring device_locales, const ClassificationOptions* options,
240     const std::string& context, const CodepointSpan& selection_indices,
241     const std::vector<ClassificationResult>& classification_result,
242     bool generate_intents) {
243   const ScopedLocalRef<jclass> result_class(
244       env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
245                      "$ClassificationResult"),
246       env);
247   if (!result_class) {
248     TC3_LOG(ERROR) << "Couldn't find ClassificationResult class.";
249     return nullptr;
250   }
251   const ScopedLocalRef<jclass> datetime_parse_class(
252       env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
253                      "$DatetimeResult"),
254       env);
255   if (!datetime_parse_class) {
256     TC3_LOG(ERROR) << "Couldn't find DatetimeResult class.";
257     return nullptr;
258   }
259 
260   const jmethodID result_class_constructor = env->GetMethodID(
261       result_class.get(), "<init>",
262       "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
263       "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
264       "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
265       "Ljava/lang/String;[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR
266       ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
267       ";JJ)V");
268   const jmethodID datetime_parse_class_constructor =
269       env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
270 
271   const jobjectArray results = env->NewObjectArray(classification_result.size(),
272                                                    result_class.get(), nullptr);
273   for (int i = 0; i < classification_result.size(); i++) {
274     jobject result = ClassificationResultWithIntentsToJObject(
275         env, model_context, app_context, result_class.get(),
276         result_class_constructor, datetime_parse_class.get(),
277         datetime_parse_class_constructor, device_locales, options, context,
278         selection_indices, classification_result[i],
279         generate_intents && (i == 0));
280     env->SetObjectArrayElement(results, i, result);
281     env->DeleteLocalRef(result);
282   }
283   return results;
284 }
285 
ClassificationResultsToJObjectArray(JNIEnv * env,const AnnotatorJniContext * model_context,const std::vector<ClassificationResult> & classification_result)286 jobjectArray ClassificationResultsToJObjectArray(
287     JNIEnv* env, const AnnotatorJniContext* model_context,
288     const std::vector<ClassificationResult>& classification_result) {
289   return ClassificationResultsWithIntentsToJObjectArray(
290       env, model_context,
291       /*(unused) app_context=*/nullptr,
292       /*(unused) devide_locale=*/nullptr,
293       /*(unusued) options=*/nullptr,
294       /*(unused) selection_text=*/"",
295       /*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
296       classification_result,
297       /*generate_intents=*/false);
298 }
299 
ConvertIndicesBMPUTF8(const std::string & utf8_str,CodepointSpan orig_indices,bool from_utf8)300 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
301                                     CodepointSpan orig_indices,
302                                     bool from_utf8) {
303   const libtextclassifier3::UnicodeText unicode_str =
304       libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
305 
306   int unicode_index = 0;
307   int bmp_index = 0;
308 
309   const int* source_index;
310   const int* target_index;
311   if (from_utf8) {
312     source_index = &unicode_index;
313     target_index = &bmp_index;
314   } else {
315     source_index = &bmp_index;
316     target_index = &unicode_index;
317   }
318 
319   CodepointSpan result{-1, -1};
320   std::function<void()> assign_indices_fn = [&result, &orig_indices,
321                                              &source_index, &target_index]() {
322     if (orig_indices.first == *source_index) {
323       result.first = *target_index;
324     }
325 
326     if (orig_indices.second == *source_index) {
327       result.second = *target_index;
328     }
329   };
330 
331   for (auto it = unicode_str.begin(); it != unicode_str.end();
332        ++it, ++unicode_index, ++bmp_index) {
333     assign_indices_fn();
334 
335     // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
336     if (*it > 0xFFFF) {
337       ++bmp_index;
338     }
339   }
340   assign_indices_fn();
341 
342   return result;
343 }
344 
345 }  // namespace
346 
ConvertIndicesBMPToUTF8(const std::string & utf8_str,CodepointSpan bmp_indices)347 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
348                                       CodepointSpan bmp_indices) {
349   return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
350 }
351 
ConvertIndicesUTF8ToBMP(const std::string & utf8_str,CodepointSpan utf8_indices)352 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
353                                       CodepointSpan utf8_indices) {
354   return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
355 }
356 
GetLocalesFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)357 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
358   if (!mmap->handle().ok()) {
359     return env->NewStringUTF("");
360   }
361   const Model* model = libtextclassifier3::ViewModel(
362       mmap->handle().start(), mmap->handle().num_bytes());
363   if (!model || !model->locales()) {
364     return env->NewStringUTF("");
365   }
366   return env->NewStringUTF(model->locales()->c_str());
367 }
368 
GetVersionFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)369 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
370   if (!mmap->handle().ok()) {
371     return 0;
372   }
373   const Model* model = libtextclassifier3::ViewModel(
374       mmap->handle().start(), mmap->handle().num_bytes());
375   if (!model) {
376     return 0;
377   }
378   return model->version();
379 }
380 
GetNameFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)381 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
382   if (!mmap->handle().ok()) {
383     return env->NewStringUTF("");
384   }
385   const Model* model = libtextclassifier3::ViewModel(
386       mmap->handle().start(), mmap->handle().num_bytes());
387   if (!model || !model->name()) {
388     return env->NewStringUTF("");
389   }
390   return env->NewStringUTF(model->name()->c_str());
391 }
392 
393 }  // namespace libtextclassifier3
394 
395 using libtextclassifier3::AnnotatorJniContext;
396 using libtextclassifier3::ClassificationResultsToJObjectArray;
397 using libtextclassifier3::ClassificationResultsWithIntentsToJObjectArray;
398 using libtextclassifier3::ConvertIndicesBMPToUTF8;
399 using libtextclassifier3::ConvertIndicesUTF8ToBMP;
400 using libtextclassifier3::FromJavaAnnotationOptions;
401 using libtextclassifier3::FromJavaClassificationOptions;
402 using libtextclassifier3::FromJavaSelectionOptions;
403 using libtextclassifier3::ToStlString;
404 
// Loads an Annotator from the model file open at |fd| and wraps it in an
// AnnotatorJniContext. Returns the context pointer as a jlong, or 0 on
// failure.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
(JNIEnv* env, jobject thiz, jint fd) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  // AnnotatorJniContext::Create() rejects a null cache anyway; fail before
  // loading the model (and, with Java ICU, before giving UniLib/CalendarLib a
  // null cache).
  if (jni_cache == nullptr) {
    return 0;
  }
#ifdef TC3_USE_JAVAICU
  return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
      jni_cache,
      Annotator::FromFileDescriptor(
          fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
          std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
  return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
      jni_cache, Annotator::FromFileDescriptor(fd)));
#endif
}
420 
// Loads an Annotator from the model file at |path| and wraps it in an
// AnnotatorJniContext. Returns the context pointer as a jlong, or 0 on
// failure.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string path_str = ToStlString(env, path);
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  // AnnotatorJniContext::Create() rejects a null cache anyway; fail before
  // loading the model (and, with Java ICU, before giving UniLib/CalendarLib a
  // null cache).
  if (jni_cache == nullptr) {
    return 0;
  }
#ifdef TC3_USE_JAVAICU
  return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
      jni_cache,
      Annotator::FromPath(
          path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
          std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
  return reinterpret_cast<jlong>(
      AnnotatorJniContext::Create(jni_cache, Annotator::FromPath(path_str)));
#endif
}
437 
// Loads an Annotator from the [offset, offset + size) region of the file
// behind the AssetFileDescriptor |afd| and wraps it in an
// AnnotatorJniContext. Returns the context pointer as a jlong, or 0 on
// failure.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
               nativeNewAnnotatorFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  // AnnotatorJniContext::Create() rejects a null cache anyway; fail before
  // loading the model (and, with Java ICU, before giving UniLib/CalendarLib a
  // null cache).
  if (jni_cache == nullptr) {
    return 0;
  }
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
#ifdef TC3_USE_JAVAICU
  return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
      jni_cache,
      Annotator::FromFileDescriptor(
          fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
          std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
  return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
      jni_cache, Annotator::FromFileDescriptor(fd, offset, size)));
#endif
}
455 
// Passes the serialized knowledge-engine config in |serialized_config| to the
// Annotator behind |ptr|. Returns false when |ptr| is 0, the config array is
// null, or initialization fails.
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
               nativeInitializeKnowledgeEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
  // GetArrayLength on a null array is invalid JNI usage, so reject it here.
  if (!ptr || serialized_config == nullptr) {
    return false;
  }

  Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();

  // Copy the Java byte[] into a string of the same length.
  const jsize length = env->GetArrayLength(serialized_config);
  std::string serialized_config_string(length, '\0');
  env->GetByteArrayRegion(
      serialized_config, 0, length,
      reinterpret_cast<jbyte*>(&serialized_config_string[0]));

  return model->InitializeKnowledgeEngine(serialized_config_string);
}
474 
// Passes the serialized contact-engine config in |serialized_config| to the
// Annotator behind |ptr|. Returns false when |ptr| is 0, the config array is
// null, or initialization fails.
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
               nativeInitializeContactEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
  // GetArrayLength on a null array is invalid JNI usage, so reject it here.
  if (!ptr || serialized_config == nullptr) {
    return false;
  }

  Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();

  // Copy the Java byte[] into a string of the same length.
  const jsize length = env->GetArrayLength(serialized_config);
  std::string serialized_config_string(length, '\0');
  env->GetByteArrayRegion(
      serialized_config, 0, length,
      reinterpret_cast<jbyte*>(&serialized_config_string[0]));

  return model->InitializeContactEngine(serialized_config_string);
}
493 
// Passes the serialized installed-app-engine config in |serialized_config| to
// the Annotator behind |ptr|. Returns false when |ptr| is 0, the config array
// is null, or initialization fails.
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
               nativeInitializeInstalledAppEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
  // GetArrayLength on a null array is invalid JNI usage, so reject it here.
  if (!ptr || serialized_config == nullptr) {
    return false;
  }

  Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();

  // Copy the Java byte[] into a string of the same length.
  const jsize length = env->GetArrayLength(serialized_config);
  std::string serialized_config_string(length, '\0');
  env->GetByteArrayRegion(
      serialized_config, 0, length,
      reinterpret_cast<jbyte*>(&serialized_config_string[0]));

  return model->InitializeInstalledAppEngine(serialized_config_string);
}
512 
// Returns the Annotator owned by the context behind |ptr|, as a jlong, or 0
// when |ptr| is 0. Ownership stays with the context.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
(JNIEnv* env, jobject thiz, jlong ptr) {
  if (ptr == 0) {
    return 0L;
  }
  const AnnotatorJniContext* model_context =
      reinterpret_cast<AnnotatorJniContext*>(ptr);
  Annotator* const annotator = model_context->model();
  return reinterpret_cast<jlong>(annotator);
}
521 
// Suggests a selection for the span selection_begin..selection_end of
// |context|. Indices arrive from Java as UTF-16 code-unit offsets and are
// converted to codepoint offsets for the Annotator and back for the result.
// Returns an int[2] {begin, end}, or null when |ptr| is 0 or the array cannot
// be allocated.
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
  const std::string context_utf8 = ToStlString(env, context);
  const CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  const CodepointSpan selection_utf8 = model->SuggestSelection(
      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
  const CodepointSpan selection_bmp =
      ConvertIndicesUTF8ToBMP(context_utf8, selection_utf8);

  // NewIntArray returns null (with OutOfMemoryError pending) on failure;
  // writing into a null array would be invalid.
  jintArray result = env->NewIntArray(2);
  if (result == nullptr) {
    return nullptr;
  }
  const jint span[2] = {static_cast<jint>(selection_bmp.first),
                        static_cast<jint>(selection_bmp.second)};
  env->SetIntArrayRegion(result, 0, 2, span);
  return result;
}
541 
// Classifies the span selection_begin..selection_end (UTF-16 offsets) of
// |context|. When an |app_context| is supplied, RemoteActionTemplates are
// generated for the top result; otherwise plain results are returned.
// Returns null when |ptr| is 0.
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options, jobject app_context,
 jstring device_locales) {
  if (ptr == 0) {
    return nullptr;
  }
  const auto* model_context = reinterpret_cast<AnnotatorJniContext*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  const CodepointSpan selection_utf8 =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  const libtextclassifier3::ClassificationOptions classification_options =
      FromJavaClassificationOptions(env, options);
  const std::vector<ClassificationResult> results =
      model_context->model()->ClassifyText(context_utf8, selection_utf8,
                                           classification_options);

  // Intent generation needs an application context.
  if (app_context == nullptr) {
    return ClassificationResultsToJObjectArray(env, model_context, results);
  }
  return ClassificationResultsWithIntentsToJObjectArray(
      env, model_context, app_context, device_locales,
      &classification_options, context_utf8, selection_utf8, results,
      /*generate_intents=*/true);
}
570 
// Annotates |context| and returns an AnnotatedSpan[] whose spans are UTF-16
// offsets. Returns null when |ptr| is 0 or the Java class, its constructor or
// the result array cannot be obtained.
//
// All per-annotation local references (the span object and its
// ClassificationResult[]) are released each iteration so that long texts with
// many annotations do not exhaust the JNI local reference table.
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  const AnnotatorJniContext* model_context =
      reinterpret_cast<AnnotatorJniContext*>(ptr);
  const std::string context_utf8 = ToStlString(env, context);
  const std::vector<AnnotatedSpan> annotations =
      model_context->model()->Annotate(context_utf8,
                                       FromJavaAnnotationOptions(env, options));

  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                     "$AnnotatedSpan"),
      env);
  if (!result_class) {
    TC3_LOG(ERROR) << "Couldn't find result class: "
                   << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
        "$AnnotatedSpan";
    return nullptr;
  }

  const jmethodID result_class_constructor =
      env->GetMethodID(result_class.get(), "<init>",
                       "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                       "$ClassificationResult;)V");
  if (result_class_constructor == nullptr) {
    TC3_LOG(ERROR) << "Couldn't find AnnotatedSpan constructor.";
    return nullptr;
  }

  const jobjectArray results =
      env->NewObjectArray(annotations.size(), result_class.get(), nullptr);
  if (results == nullptr) {
    return nullptr;
  }

  for (size_t i = 0; i < annotations.size(); ++i) {
    const CodepointSpan span_bmp =
        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    const ScopedLocalRef<jobjectArray> classification_results(
        ClassificationResultsToJObjectArray(env, model_context,
                                            annotations[i].classification),
        env);
    const ScopedLocalRef<jobject> result(
        env->NewObject(result_class.get(), result_class_constructor,
                       static_cast<jint>(span_bmp.first),
                       static_cast<jint>(span_bmp.second),
                       classification_results.get()),
        env);
    env->SetObjectArrayElement(results, static_cast<jsize>(i), result.get());
  }
  return results;
}
614 
// Looks up the knowledge entity |id| and returns its serialized result as a
// byte[], or null when |ptr| is 0 or the lookup fails.
TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
               nativeLookUpKnowledgeEntity)
(JNIEnv* env, jobject thiz, jlong ptr, jstring id) {
  if (ptr == 0) {
    return nullptr;
  }
  const Annotator* annotator =
      reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
  const std::string entity_id = ToStlString(env, id);
  std::string payload;
  const bool found = annotator->LookUpKnowledgeEntity(entity_id, &payload);
  if (!found) {
    return nullptr;
  }
  const jsize payload_size = payload.size();
  jbyteArray bytes = env->NewByteArray(payload_size);
  env->SetByteArrayRegion(bytes, 0, payload_size,
                          reinterpret_cast<const jbyte*>(payload.data()));
  return bytes;
}
633 
TC3_JNI_METHOD(void,TC3_ANNOTATOR_CLASS_NAME,nativeCloseAnnotator)634 TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
635 (JNIEnv* env, jobject thiz, jlong ptr) {
636   const AnnotatorJniContext* context =
637       reinterpret_cast<AnnotatorJniContext*>(ptr);
638   delete context;
639 }
640 
// Deprecated alias of nativeGetLocales(): logs a warning, then forwards.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
  TC3_LOG(WARNING) << "Using deprecated getLanguage().";
  const jstring locales = TC3_JNI_METHOD_NAME(
      TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
  return locales;
}
647 
// Returns the locales declared by the model file open at |fd| (empty string
// if unavailable).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return GetLocalesFromMmap(env, &model_mmap);
}
654 
// Returns the locales declared by the model stored in the [offset, offset +
// size) region behind |afd| (empty string if unavailable).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier3::ScopedMmap model_mmap(fd, offset, size);
  return GetLocalesFromMmap(env, &model_mmap);
}
663 
// Returns the version of the model file open at |fd| (0 if unavailable).
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return GetVersionFromMmap(env, &model_mmap);
}
670 
// Returns the version of the model stored in the [offset, offset + size)
// region behind |afd| (0 if unavailable).
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier3::ScopedMmap model_mmap(fd, offset, size);
  return GetVersionFromMmap(env, &model_mmap);
}
679 
// Returns the name of the model file open at |fd| (empty string if
// unavailable).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return GetNameFromMmap(env, &model_mmap);
}
686 
// Returns the name of the model stored in the [offset, offset + size) region
// behind |afd| (empty string if unavailable).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier3::ScopedMmap model_mmap(fd, offset, size);
  return GetNameFromMmap(env, &model_mmap);
}
695