1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
// JNI wrapper for the actions suggestions model (ActionsSuggestions).
18
19 #include "actions/actions_jni.h"
20
21 #include <jni.h>
22 #include <map>
23 #include <type_traits>
24 #include <vector>
25
26 #include "actions/actions-suggestions.h"
27 #include "annotator/annotator.h"
28 #include "annotator/annotator_jni_common.h"
29 #include "utils/base/integral_types.h"
30 #include "utils/intents/intent-generator.h"
31 #include "utils/intents/jni.h"
32 #include "utils/java/jni-cache.h"
33 #include "utils/java/scoped_local_ref.h"
34 #include "utils/java/string_utils.h"
35 #include "utils/memory/mmap.h"
36
37 using libtextclassifier3::ActionsSuggestions;
38 using libtextclassifier3::ActionsSuggestionsResponse;
39 using libtextclassifier3::ActionSuggestion;
40 using libtextclassifier3::ActionSuggestionOptions;
41 using libtextclassifier3::Annotator;
42 using libtextclassifier3::Conversation;
43 using libtextclassifier3::IntentGenerator;
44 using libtextclassifier3::ScopedLocalRef;
45 using libtextclassifier3::ToStlString;
46
// When using Java's ICU, UniLib needs to be instantiated with a JniCache
// created from the JNI environment. When using a standard ICU, no such handle
// is needed and the objects are instantiated implicitly.
50 #ifdef TC3_UNILIB_JAVAICU
51 using libtextclassifier3::UniLib;
52 #endif
53
54 namespace libtextclassifier3 {
55
56 namespace {
57
58 // Cached state for model inference.
59 // Keeps a jni cache, intent generator and model instance so that they don't
60 // have to be recreated for each call.
61 class ActionsSuggestionsJniContext {
62 public:
Create(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<ActionsSuggestions> model)63 static ActionsSuggestionsJniContext* Create(
64 const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
65 std::unique_ptr<ActionsSuggestions> model) {
66 if (jni_cache == nullptr || model == nullptr) {
67 return nullptr;
68 }
69 std::unique_ptr<IntentGenerator> intent_generator =
70 IntentGenerator::Create(model->model()->android_intent_options(),
71 model->model()->resources(), jni_cache);
72 std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
73 libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
74
75 if (intent_generator == nullptr || template_handler == nullptr) {
76 return nullptr;
77 }
78
79 return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
80 std::move(intent_generator),
81 std::move(template_handler));
82 }
83
jni_cache() const84 std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
85 return jni_cache_;
86 }
87
model() const88 ActionsSuggestions* model() const { return model_.get(); }
89
intent_generator() const90 IntentGenerator* intent_generator() const { return intent_generator_.get(); }
91
template_handler() const92 RemoteActionTemplatesHandler* template_handler() const {
93 return template_handler_.get();
94 }
95
96 private:
ActionsSuggestionsJniContext(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<ActionsSuggestions> model,std::unique_ptr<IntentGenerator> intent_generator,std::unique_ptr<RemoteActionTemplatesHandler> template_handler)97 ActionsSuggestionsJniContext(
98 const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
99 std::unique_ptr<ActionsSuggestions> model,
100 std::unique_ptr<IntentGenerator> intent_generator,
101 std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
102 : jni_cache_(jni_cache),
103 model_(std::move(model)),
104 intent_generator_(std::move(intent_generator)),
105 template_handler_(std::move(template_handler)) {}
106
107 std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
108 std::unique_ptr<ActionsSuggestions> model_;
109 std::unique_ptr<IntentGenerator> intent_generator_;
110 std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
111 };
112
FromJavaActionSuggestionOptions(JNIEnv * env,jobject joptions)113 ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
114 jobject joptions) {
115 ActionSuggestionOptions options = ActionSuggestionOptions::Default();
116 return options;
117 }
118
ActionSuggestionsToJObjectArray(JNIEnv * env,const ActionsSuggestionsJniContext * context,jobject app_context,const reflection::Schema * annotations_entity_data_schema,const std::vector<ActionSuggestion> & action_result,const Conversation & conversation,const jstring device_locales,const bool generate_intents)119 jobjectArray ActionSuggestionsToJObjectArray(
120 JNIEnv* env, const ActionsSuggestionsJniContext* context,
121 jobject app_context,
122 const reflection::Schema* annotations_entity_data_schema,
123 const std::vector<ActionSuggestion>& action_result,
124 const Conversation& conversation, const jstring device_locales,
125 const bool generate_intents) {
126 const ScopedLocalRef<jclass> result_class(
127 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
128 "$ActionSuggestion"),
129 env);
130 if (!result_class) {
131 TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
132 return nullptr;
133 }
134
135 const jmethodID result_class_constructor = env->GetMethodID(
136 result_class.get(), "<init>",
137 "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
138 TC3_NAMED_VARIANT_CLASS_NAME_STR
139 ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
140 const jobjectArray results =
141 env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
142 for (int i = 0; i < action_result.size(); i++) {
143 jobject extras = nullptr;
144
145 const reflection::Schema* actions_entity_data_schema =
146 context->model()->entity_data_schema();
147 if (actions_entity_data_schema != nullptr &&
148 !action_result[i].serialized_entity_data.empty()) {
149 extras = context->template_handler()->EntityDataAsNamedVariantArray(
150 actions_entity_data_schema, action_result[i].serialized_entity_data);
151 }
152
153 jbyteArray serialized_entity_data = nullptr;
154 if (!action_result[i].serialized_entity_data.empty()) {
155 serialized_entity_data =
156 env->NewByteArray(action_result[i].serialized_entity_data.size());
157 env->SetByteArrayRegion(
158 serialized_entity_data, 0,
159 action_result[i].serialized_entity_data.size(),
160 reinterpret_cast<const jbyte*>(
161 action_result[i].serialized_entity_data.data()));
162 }
163
164 jobject remote_action_templates_result = nullptr;
165 if (generate_intents) {
166 std::vector<RemoteActionTemplate> remote_action_templates;
167 if (context->intent_generator()->GenerateIntents(
168 device_locales, action_result[i], conversation, app_context,
169 actions_entity_data_schema, annotations_entity_data_schema,
170 &remote_action_templates)) {
171 remote_action_templates_result =
172 context->template_handler()->RemoteActionTemplatesToJObjectArray(
173 remote_action_templates);
174 }
175 }
176
177 ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
178 action_result[i].response_text);
179
180 ScopedLocalRef<jobject> result(env->NewObject(
181 result_class.get(), result_class_constructor, reply.get(),
182 env->NewStringUTF(action_result[i].type.c_str()),
183 static_cast<jfloat>(action_result[i].score), extras,
184 serialized_entity_data, remote_action_templates_result));
185 env->SetObjectArrayElement(results, i, result.get());
186 }
187 return results;
188 }
189
FromJavaConversationMessage(JNIEnv * env,jobject jmessage)190 ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
191 if (!jmessage) {
192 return {};
193 }
194
195 const ScopedLocalRef<jclass> message_class(
196 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
197 "$ConversationMessage"),
198 env);
199 const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
200 env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
201 "Ljava/lang/String;");
202 const std::pair<bool, int32> status_or_user_id =
203 CallJniMethod0<int32>(env, jmessage, message_class.get(),
204 &JNIEnv::CallIntMethod, "getUserId", "I");
205 const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
206 env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
207 "getReferenceTimeMsUtc", "J");
208 const std::pair<bool, jobject> status_or_reference_timezone =
209 CallJniMethod0<jobject>(env, jmessage, message_class.get(),
210 &JNIEnv::CallObjectMethod, "getReferenceTimezone",
211 "Ljava/lang/String;");
212 const std::pair<bool, jobject> status_or_detected_text_language_tags =
213 CallJniMethod0<jobject>(
214 env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
215 "getDetectedTextLanguageTags", "Ljava/lang/String;");
216 if (!status_or_text.first || !status_or_user_id.first ||
217 !status_or_detected_text_language_tags.first ||
218 !status_or_reference_time.first || !status_or_reference_timezone.first) {
219 return {};
220 }
221
222 ConversationMessage message;
223 message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
224 message.user_id = status_or_user_id.second;
225 message.reference_time_ms_utc = status_or_reference_time.second;
226 message.reference_timezone = ToStlString(
227 env, static_cast<jstring>(status_or_reference_timezone.second));
228 message.detected_text_language_tags = ToStlString(
229 env, static_cast<jstring>(status_or_detected_text_language_tags.second));
230 return message;
231 }
232
FromJavaConversation(JNIEnv * env,jobject jconversation)233 Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
234 if (!jconversation) {
235 return {};
236 }
237
238 const ScopedLocalRef<jclass> conversation_class(
239 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
240 "$Conversation"),
241 env);
242
243 const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
244 env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
245 "getConversationMessages",
246 "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
247
248 if (!status_or_messages.first) {
249 return {};
250 }
251
252 const jobjectArray jmessages =
253 reinterpret_cast<jobjectArray>(status_or_messages.second);
254
255 const int size = env->GetArrayLength(jmessages);
256
257 std::vector<ConversationMessage> messages;
258 for (int i = 0; i < size; i++) {
259 jobject jmessage = env->GetObjectArrayElement(jmessages, i);
260 ConversationMessage message = FromJavaConversationMessage(env, jmessage);
261 messages.push_back(message);
262 }
263 Conversation conversation;
264 conversation.messages = messages;
265 return conversation;
266 }
267
GetLocalesFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)268 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
269 if (!mmap->handle().ok()) {
270 return env->NewStringUTF("");
271 }
272 const ActionsModel* model = libtextclassifier3::ViewActionsModel(
273 mmap->handle().start(), mmap->handle().num_bytes());
274 if (!model || !model->locales()) {
275 return env->NewStringUTF("");
276 }
277 return env->NewStringUTF(model->locales()->c_str());
278 }
279
GetVersionFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)280 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
281 if (!mmap->handle().ok()) {
282 return 0;
283 }
284 const ActionsModel* model = libtextclassifier3::ViewActionsModel(
285 mmap->handle().start(), mmap->handle().num_bytes());
286 if (!model) {
287 return 0;
288 }
289 return model->version();
290 }
291
GetNameFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)292 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
293 if (!mmap->handle().ok()) {
294 return env->NewStringUTF("");
295 }
296 const ActionsModel* model = libtextclassifier3::ViewActionsModel(
297 mmap->handle().start(), mmap->handle().num_bytes());
298 if (!model || !model->name()) {
299 return env->NewStringUTF("");
300 }
301 return env->NewStringUTF(model->name()->c_str());
302 }
303 } // namespace
304 } // namespace libtextclassifier3
305
306 using libtextclassifier3::ActionsSuggestionsJniContext;
307 using libtextclassifier3::ActionSuggestionsToJObjectArray;
308 using libtextclassifier3::FromJavaActionSuggestionOptions;
309 using libtextclassifier3::FromJavaConversation;
310
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
  // Loads an actions model from `fd` and returns an opaque handle to a newly
  // allocated ActionsSuggestionsJniContext, or 0 on failure.
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);

  // Serialized preconditions are optional; a null array leaves them empty.
  std::string preconditions;
  if (serialized_preconditions != nullptr) {
    const bool converted = libtextclassifier3::JByteArrayToString(
        env, serialized_preconditions, &preconditions);
    if (!converted) {
      TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
      return 0;
    }
  }

#ifdef TC3_UNILIB_JAVAICU
  auto model = ActionsSuggestions::FromFileDescriptor(
      fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions);
#else
  auto model = ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
                                                      preconditions);
#endif  // TC3_UNILIB_JAVAICU

  ActionsSuggestionsJniContext* context =
      ActionsSuggestionsJniContext::Create(jni_cache, std::move(model));
  return reinterpret_cast<jlong>(context);
}
333
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
  // Loads an actions model from the file at `path` and returns an opaque
  // handle to a newly allocated ActionsSuggestionsJniContext, or 0 on failure.
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);
  const std::string model_path = ToStlString(env, path);

  // Serialized preconditions are optional; a null array leaves them empty.
  std::string preconditions;
  const bool has_preconditions = serialized_preconditions != nullptr;
  if (has_preconditions &&
      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
                                              &preconditions)) {
    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
    return 0;
  }

#ifdef TC3_UNILIB_JAVAICU
  std::unique_ptr<UniLib> unilib(new UniLib(jni_cache));
  auto model =
      ActionsSuggestions::FromPath(model_path, std::move(unilib), preconditions);
#else
  auto model = ActionsSuggestions::FromPath(model_path, /*unilib=*/nullptr,
                                            preconditions);
#endif  // TC3_UNILIB_JAVAICU

  return reinterpret_cast<jlong>(
      ActionsSuggestionsJniContext::Create(jni_cache, std::move(model)));
}
357
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
 jlong annotatorPtr, jobject app_context, jstring device_locales,
 jboolean generate_intents) {
  // Runs the actions model on the conversation and returns the suggestions as
  // a Java ActionSuggestion[]. Returns null for a zero model handle.
  if (ptr == 0) {
    return nullptr;
  }

  const Conversation conversation = FromJavaConversation(env, jconversation);
  const ActionSuggestionOptions options =
      FromJavaActionSuggestionOptions(env, joptions);

  const ActionsSuggestionsJniContext* context =
      reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
  // The annotator handle may be 0, in which case no annotator is used.
  const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);

  const ActionsSuggestionsResponse response =
      context->model()->SuggestActions(conversation, annotator, options);

  const reflection::Schema* annotations_entity_data_schema = nullptr;
  if (annotator != nullptr) {
    annotations_entity_data_schema = annotator->entity_data_schema();
  }

  return ActionSuggestionsToJObjectArray(
      env, context, app_context, annotations_entity_data_schema,
      response.actions, conversation, device_locales, generate_intents);
}
381
TC3_JNI_METHOD(void,TC3_ACTIONS_CLASS_NAME,nativeCloseActionsModel)382 TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
383 (JNIEnv* env, jobject clazz, jlong model_ptr) {
384 const ActionsSuggestionsJniContext* context =
385 reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
386 delete context;
387 }
388
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  // The model file is mapped only for the duration of this call.
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetLocalesFromMmap(env, &mmap);
}
395
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  // The model file is mapped only for the duration of this call.
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetNameFromMmap(env, &mmap);
}
402
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  // The model file is mapped only for the duration of this call.
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetVersionFromMmap(env, &mmap);
}
409