1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "utils/intents/intent-generator.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "actions/lua-utils.h"
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/hash/farmhash.h"
#include "utils/java/jni-base.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
#include "flatbuffers/reflection_generated.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#ifdef __cplusplus
}
#endif
44
45 namespace libtextclassifier3 {
46 namespace {
47
// Field names under which values and callbacks are exposed to Lua generator
// snippets through the `external` and `external.android` tables.
// NOTE: despite its name, kReferenceTimeUsecKey names a millisecond value
// ("reference_time_ms_utc").
constexpr char kReferenceTimeUsecKey[] = "reference_time_ms_utc";
constexpr char kHashKey[] = "hash";
constexpr char kUrlSchemaKey[] = "url_schema";
constexpr char kUrlHostKey[] = "url_host";
constexpr char kUrlEncodeKey[] = "urlencode";
constexpr char kPackageNameKey[] = "package_name";
constexpr char kDeviceLocaleKey[] = "device_locales";
constexpr char kFormatKey[] = "format";
56
57 // An Android specific Lua environment with JNI backed callbacks.
58 class JniLuaEnvironment : public LuaEnvironment {
59 public:
60 JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
61 const jobject context,
62 const std::vector<Locale>& device_locales);
63 // Environment setup.
64 bool Initialize();
65
66 // Runs an intent generator snippet.
67 bool RunIntentGenerator(const std::string& generator_snippet,
68 std::vector<RemoteActionTemplate>* remote_actions);
69
70 protected:
71 virtual void SetupExternalHook();
72
73 int HandleExternalCallback();
74 int HandleAndroidCallback();
75 int HandleUserRestrictionsCallback();
76 int HandleUrlEncode();
77 int HandleUrlSchema();
78 int HandleHash();
79 int HandleFormat();
80 int HandleAndroidStringResources();
81 int HandleUrlHost();
82
83 // Checks and retrieves string resources from the model.
84 bool LookupModelStringResource();
85
86 // Reads and create a RemoteAction result from Lua.
87 RemoteActionTemplate ReadRemoteActionTemplateResult();
88
89 // Reads the extras from the Lua result.
90 void ReadExtras(std::map<std::string, Variant>* extra);
91
92 // Reads the intent categories array from a Lua result.
93 void ReadCategories(std::vector<std::string>* category);
94
95 // Retrieves user manager if not previously done.
96 bool RetrieveUserManager();
97
98 // Retrieves system resources if not previously done.
99 bool RetrieveSystemResources();
100
101 // Parse the url string by using Uri.parse from Java.
102 ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
103
104 // Read remote action templates from lua generator.
105 int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
106
107 const Resources& resources_;
108 JNIEnv* jenv_;
109 const JniCache* jni_cache_;
110 const jobject context_;
111 std::vector<Locale> device_locales_;
112
113 ScopedGlobalRef<jobject> usermanager_;
114 // Whether we previously attempted to retrieve the UserManager before.
115 bool usermanager_retrieved_;
116
117 ScopedGlobalRef<jobject> system_resources_;
118 // Whether we previously attempted to retrieve the system resources.
119 bool system_resources_resources_retrieved_;
120
121 // Cached JNI references for Java strings `string` and `android`.
122 ScopedGlobalRef<jstring> string_;
123 ScopedGlobalRef<jstring> android_;
124 };
125
// Stores the inputs. No Java objects are created here: the global references
// start out null and are filled in later by Initialize() (string_, android_)
// and by the lazy Retrieve* helpers (usermanager_, system_resources_).
// A null `jni_cache` is tolerated: the JNIEnv and the JavaVM handed to each
// global-reference holder then fall back to nullptr.
JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
                                     const JniCache* jni_cache,
                                     const jobject context,
                                     const std::vector<Locale>& device_locales)
    : resources_(resources),
      jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
      jni_cache_(jni_cache),
      context_(context),
      device_locales_(device_locales),
      usermanager_(/*object=*/nullptr,
                   /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      usermanager_retrieved_(false),
      system_resources_(/*object=*/nullptr,
                        /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      system_resources_resources_retrieved_(false),
      string_(/*object=*/nullptr,
              /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      android_(/*object=*/nullptr,
               /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
145
Initialize()146 bool JniLuaEnvironment::Initialize() {
147 string_ =
148 MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
149 android_ =
150 MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
151 if (string_ == nullptr || android_ == nullptr) {
152 TC3_LOG(ERROR) << "Could not allocate constant strings references.";
153 return false;
154 }
155 return (RunProtected([this] {
156 LoadDefaultLibraries();
157 SetupExternalHook();
158 lua_setglobal(state_, "external");
159 return LUA_OK;
160 }) == LUA_OK);
161 }
162
// Builds the `external` table and leaves it on top of the Lua stack;
// Initialize() then stores it as the global `external`. It exposes:
//   * entity: the bundle with all information about a classification (added
//     by the subclasses' overrides, not here).
//   * android: callbacks into specific android provided methods.
//   * android.user_restrictions: callbacks to check user permissions.
//   * android.R: callbacks to retrieve string resources.
// Other `external` fields (hash, format) are resolved by
// HandleExternalCallback.
void JniLuaEnvironment::SetupExternalHook() {
  // Push `external`; field lookups are presumably routed to
  // HandleExternalCallback by BindTable.
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
      "external");

  // android
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
      "android");
  {
    // android.user_restrictions
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleUserRestrictionsCallback>(
        "user_restrictions");
    lua_setfield(state_, /*idx=*/-2, "user_restrictions");

    // android.R
    // Callback to access android string resources.
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleAndroidStringResources>("R");
    lua_setfield(state_, /*idx=*/-2, "R");
  }
  // Stack is now ..., external, android: attach `android` to `external`.
  lua_setfield(state_, /*idx=*/-2, "android");
}
190
HandleExternalCallback()191 int JniLuaEnvironment::HandleExternalCallback() {
192 const StringPiece key = ReadString(/*index=*/-1);
193 if (key.Equals(kHashKey)) {
194 Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
195 return 1;
196 } else if (key.Equals(kFormatKey)) {
197 Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
198 return 1;
199 } else {
200 TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
201 lua_error(state_);
202 return 0;
203 }
204 }
205
HandleAndroidCallback()206 int JniLuaEnvironment::HandleAndroidCallback() {
207 const StringPiece key = ReadString(/*index=*/-1);
208 if (key.Equals(kDeviceLocaleKey)) {
209 // Provide the locale as table with the individual fields set.
210 lua_newtable(state_);
211 for (int i = 0; i < device_locales_.size(); i++) {
212 // Adjust index to 1-based indexing for Lua.
213 lua_pushinteger(state_, i + 1);
214 lua_newtable(state_);
215 PushString(device_locales_[i].Language());
216 lua_setfield(state_, -2, "language");
217 PushString(device_locales_[i].Region());
218 lua_setfield(state_, -2, "region");
219 PushString(device_locales_[i].Script());
220 lua_setfield(state_, -2, "script");
221 lua_settable(state_, /*idx=*/-3);
222 }
223 return 1;
224 } else if (key.Equals(kPackageNameKey)) {
225 if (context_ == nullptr) {
226 TC3_LOG(ERROR) << "Context invalid.";
227 lua_error(state_);
228 return 0;
229 }
230 ScopedLocalRef<jstring> package_name_str(
231 static_cast<jstring>(jenv_->CallObjectMethod(
232 context_, jni_cache_->context_get_package_name)));
233 if (jni_cache_->ExceptionCheckAndClear()) {
234 TC3_LOG(ERROR) << "Error calling Context.getPackageName";
235 lua_error(state_);
236 return 0;
237 }
238 PushString(ToStlString(jenv_, package_name_str.get()));
239 return 1;
240 } else if (key.Equals(kUrlEncodeKey)) {
241 Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
242 return 1;
243 } else if (key.Equals(kUrlHostKey)) {
244 Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
245 return 1;
246 } else if (key.Equals(kUrlSchemaKey)) {
247 Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
248 return 1;
249 } else {
250 TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
251 lua_error(state_);
252 return 0;
253 }
254 }
255
HandleUserRestrictionsCallback()256 int JniLuaEnvironment::HandleUserRestrictionsCallback() {
257 if (jni_cache_->usermanager_class == nullptr ||
258 jni_cache_->usermanager_get_user_restrictions == nullptr) {
259 // UserManager is only available for API level >= 17 and
260 // getUserRestrictions only for API level >= 18, so we just return false
261 // normally here.
262 lua_pushboolean(state_, false);
263 return 1;
264 }
265
266 // Get user manager if not previously retrieved.
267 if (!RetrieveUserManager()) {
268 TC3_LOG(ERROR) << "Error retrieving user manager.";
269 lua_error(state_);
270 return 0;
271 }
272
273 ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
274 usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
275 if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
276 TC3_LOG(ERROR) << "Error calling getUserRestrictions";
277 lua_error(state_);
278 return 0;
279 }
280
281 const StringPiece key_str = ReadString(/*index=*/-1);
282 if (key_str.empty()) {
283 TC3_LOG(ERROR) << "Expected string, got null.";
284 lua_error(state_);
285 return 0;
286 }
287
288 ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
289 if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
290 TC3_LOG(ERROR) << "Expected string, got null.";
291 lua_error(state_);
292 return 0;
293 }
294 const bool permission = jenv_->CallBooleanMethod(
295 bundle.get(), jni_cache_->bundle_get_boolean, key.get());
296 if (jni_cache_->ExceptionCheckAndClear()) {
297 TC3_LOG(ERROR) << "Error getting bundle value";
298 lua_pushboolean(state_, false);
299 } else {
300 lua_pushboolean(state_, permission);
301 }
302 return 1;
303 }
304
HandleUrlEncode()305 int JniLuaEnvironment::HandleUrlEncode() {
306 const StringPiece input = ReadString(/*index=*/1);
307 if (input.empty()) {
308 TC3_LOG(ERROR) << "Expected string, got null.";
309 lua_error(state_);
310 return 0;
311 }
312
313 // Call Java URL encoder.
314 ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
315 if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
316 TC3_LOG(ERROR) << "Expected string, got null.";
317 lua_error(state_);
318 return 0;
319 }
320 ScopedLocalRef<jstring> encoded_str(
321 static_cast<jstring>(jenv_->CallStaticObjectMethod(
322 jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
323 input_str.get(), jni_cache_->string_utf8.get())));
324 if (jni_cache_->ExceptionCheckAndClear()) {
325 TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
326 lua_error(state_);
327 return 0;
328 }
329 PushString(ToStlString(jenv_, encoded_str.get()));
330 return 1;
331 }
332
ParseUri(StringPiece url) const333 ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
334 if (url.empty()) {
335 return nullptr;
336 }
337
338 // Call to Java URI parser.
339 ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
340 if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
341 TC3_LOG(ERROR) << "Expected string, got null";
342 return nullptr;
343 }
344
345 // Try to parse uri and get scheme.
346 ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
347 jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
348 if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
349 TC3_LOG(ERROR) << "Error calling Uri.parse";
350 }
351 return uri;
352 }
353
HandleUrlSchema()354 int JniLuaEnvironment::HandleUrlSchema() {
355 StringPiece url = ReadString(/*index=*/1);
356
357 ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
358 if (parsed_uri == nullptr) {
359 lua_error(state_);
360 return 0;
361 }
362
363 ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
364 jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
365 if (jni_cache_->ExceptionCheckAndClear()) {
366 TC3_LOG(ERROR) << "Error calling Uri.getScheme";
367 lua_error(state_);
368 return 0;
369 }
370 if (scheme_str == nullptr) {
371 lua_pushnil(state_);
372 } else {
373 PushString(ToStlString(jenv_, scheme_str.get()));
374 }
375 return 1;
376 }
377
HandleUrlHost()378 int JniLuaEnvironment::HandleUrlHost() {
379 StringPiece url = ReadString(/*index=*/-1);
380
381 ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
382 if (parsed_uri == nullptr) {
383 lua_error(state_);
384 return 0;
385 }
386
387 ScopedLocalRef<jstring> host_str(static_cast<jstring>(
388 jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
389 if (jni_cache_->ExceptionCheckAndClear()) {
390 TC3_LOG(ERROR) << "Error calling Uri.getHost";
391 lua_error(state_);
392 return 0;
393 }
394 if (host_str == nullptr) {
395 lua_pushnil(state_);
396 } else {
397 PushString(ToStlString(jenv_, host_str.get()));
398 }
399 return 1;
400 }
401
HandleHash()402 int JniLuaEnvironment::HandleHash() {
403 const StringPiece input = ReadString(/*index=*/-1);
404 lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
405 return 1;
406 }
407
HandleFormat()408 int JniLuaEnvironment::HandleFormat() {
409 const int num_args = lua_gettop(state_);
410 std::vector<StringPiece> args(num_args - 1);
411 for (int i = 0; i < num_args - 1; i++) {
412 args[i] = ReadString(/*index=*/i + 2);
413 }
414 PushString(strings::Substitute(ReadString(/*index=*/1), args));
415 return 1;
416 }
417
LookupModelStringResource()418 bool JniLuaEnvironment::LookupModelStringResource() {
419 // Handle only lookup by name.
420 if (lua_type(state_, 2) != LUA_TSTRING) {
421 return false;
422 }
423
424 const StringPiece resource_name = ReadString(/*index=*/-1);
425 std::string resource_content;
426 if (!resources_.GetResourceContent(device_locales_, resource_name,
427 &resource_content)) {
428 // Resource cannot be provided by the model.
429 return false;
430 }
431
432 PushString(resource_content);
433 return true;
434 }
435
HandleAndroidStringResources()436 int JniLuaEnvironment::HandleAndroidStringResources() {
437 // Check whether the requested resource can be served from the model data.
438 if (LookupModelStringResource()) {
439 return 1;
440 }
441
442 // Get system resources if not previously retrieved.
443 if (!RetrieveSystemResources()) {
444 TC3_LOG(ERROR) << "Error retrieving system resources.";
445 lua_error(state_);
446 return 0;
447 }
448
449 int resource_id;
450 switch (lua_type(state_, -1)) {
451 case LUA_TNUMBER:
452 resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
453 break;
454 case LUA_TSTRING: {
455 const StringPiece resource_name_str = ReadString(/*index=*/-1);
456 if (resource_name_str.empty()) {
457 TC3_LOG(ERROR) << "No resource name provided.";
458 lua_error(state_);
459 return 0;
460 }
461 ScopedLocalRef<jstring> resource_name =
462 jni_cache_->ConvertToJavaString(resource_name_str);
463 if (resource_name == nullptr) {
464 TC3_LOG(ERROR) << "Invalid resource name.";
465 lua_error(state_);
466 return 0;
467 }
468 resource_id = jenv_->CallIntMethod(
469 system_resources_.get(), jni_cache_->resources_get_identifier,
470 resource_name.get(), string_.get(), android_.get());
471 if (jni_cache_->ExceptionCheckAndClear()) {
472 TC3_LOG(ERROR) << "Error calling getIdentifier.";
473 lua_error(state_);
474 return 0;
475 }
476 break;
477 }
478 default:
479 TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
480 lua_error(state_);
481 return 0;
482 }
483 if (resource_id == 0) {
484 TC3_LOG(ERROR) << "Resource not found.";
485 lua_pushnil(state_);
486 return 1;
487 }
488 ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
489 jenv_->CallObjectMethod(system_resources_.get(),
490 jni_cache_->resources_get_string, resource_id)));
491 if (jni_cache_->ExceptionCheckAndClear()) {
492 TC3_LOG(ERROR) << "Error calling getString.";
493 lua_error(state_);
494 return 0;
495 }
496 if (resource_str == nullptr) {
497 lua_pushnil(state_);
498 } else {
499 PushString(ToStlString(jenv_, resource_str.get()));
500 }
501 return 1;
502 }
503
RetrieveSystemResources()504 bool JniLuaEnvironment::RetrieveSystemResources() {
505 if (system_resources_resources_retrieved_) {
506 return (system_resources_ != nullptr);
507 }
508 system_resources_resources_retrieved_ = true;
509 jobject system_resources_ref = jenv_->CallStaticObjectMethod(
510 jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
511 if (jni_cache_->ExceptionCheckAndClear()) {
512 TC3_LOG(ERROR) << "Error calling getSystem.";
513 return false;
514 }
515 system_resources_ =
516 MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
517 return (system_resources_ != nullptr);
518 }
519
RetrieveUserManager()520 bool JniLuaEnvironment::RetrieveUserManager() {
521 if (context_ == nullptr) {
522 return false;
523 }
524 if (usermanager_retrieved_) {
525 return (usermanager_ != nullptr);
526 }
527 usermanager_retrieved_ = true;
528 ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
529 jobject usermanager_ref = jenv_->CallObjectMethod(
530 context_, jni_cache_->context_get_system_service, service.get());
531 if (jni_cache_->ExceptionCheckAndClear()) {
532 TC3_LOG(ERROR) << "Error calling getSystemService.";
533 return false;
534 }
535 usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
536 return (usermanager_ != nullptr);
537 }
538
ReadRemoteActionTemplateResult()539 RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
540 RemoteActionTemplate result;
541 // Read intent template.
542 lua_pushnil(state_);
543 while (lua_next(state_, /*idx=*/-2)) {
544 const StringPiece key = ReadString(/*index=*/-2);
545 if (key.Equals("title_without_entity")) {
546 result.title_without_entity = ReadString(/*index=*/-1).ToString();
547 } else if (key.Equals("title_with_entity")) {
548 result.title_with_entity = ReadString(/*index=*/-1).ToString();
549 } else if (key.Equals("description")) {
550 result.description = ReadString(/*index=*/-1).ToString();
551 } else if (key.Equals("description_with_app_name")) {
552 result.description_with_app_name = ReadString(/*index=*/-1).ToString();
553 } else if (key.Equals("action")) {
554 result.action = ReadString(/*index=*/-1).ToString();
555 } else if (key.Equals("data")) {
556 result.data = ReadString(/*index=*/-1).ToString();
557 } else if (key.Equals("type")) {
558 result.type = ReadString(/*index=*/-1).ToString();
559 } else if (key.Equals("flags")) {
560 result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
561 } else if (key.Equals("package_name")) {
562 result.package_name = ReadString(/*index=*/-1).ToString();
563 } else if (key.Equals("request_code")) {
564 result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
565 } else if (key.Equals("category")) {
566 ReadCategories(&result.category);
567 } else if (key.Equals("extra")) {
568 ReadExtras(&result.extra);
569 } else {
570 TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
571 }
572 lua_pop(state_, 1);
573 }
574 lua_pop(state_, 1);
575 return result;
576 }
577
ReadCategories(std::vector<std::string> * category)578 void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
579 // Read category array.
580 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
581 TC3_LOG(ERROR) << "Expected categories table, got: "
582 << lua_type(state_, /*idx=*/-1);
583 lua_pop(state_, 1);
584 return;
585 }
586 lua_pushnil(state_);
587 while (lua_next(state_, /*idx=*/-2)) {
588 category->push_back(ReadString(/*index=*/-1).ToString());
589 lua_pop(state_, 1);
590 }
591 }
592
ReadExtras(std::map<std::string,Variant> * extra)593 void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
594 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
595 TC3_LOG(ERROR) << "Expected extras table, got: "
596 << lua_type(state_, /*idx=*/-1);
597 lua_pop(state_, 1);
598 return;
599 }
600 lua_pushnil(state_);
601 while (lua_next(state_, /*idx=*/-2)) {
602 // Each entry is a table specifying name and value.
603 // The value is specified via a type specific field as Lua doesn't allow
604 // to easily distinguish between different number types.
605 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
606 TC3_LOG(ERROR) << "Expected a table for an extra, got: "
607 << lua_type(state_, /*idx=*/-1);
608 lua_pop(state_, 1);
609 return;
610 }
611 std::string name;
612 Variant value;
613
614 lua_pushnil(state_);
615 while (lua_next(state_, /*idx=*/-2)) {
616 const StringPiece key = ReadString(/*index=*/-2);
617 if (key.Equals("name")) {
618 name = ReadString(/*index=*/-1).ToString();
619 } else if (key.Equals("int_value")) {
620 value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
621 } else if (key.Equals("long_value")) {
622 value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
623 } else if (key.Equals("float_value")) {
624 value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
625 } else if (key.Equals("bool_value")) {
626 value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
627 } else if (key.Equals("string_value")) {
628 value = Variant(ReadString(/*index=*/-1).ToString());
629 } else {
630 TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
631 }
632 lua_pop(state_, 1);
633 }
634 if (!name.empty()) {
635 (*extra)[name] = value;
636 } else {
637 TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
638 }
639 lua_pop(state_, 1);
640 }
641 }
642
ReadRemoteActionTemplates(std::vector<RemoteActionTemplate> * result)643 int JniLuaEnvironment::ReadRemoteActionTemplates(
644 std::vector<RemoteActionTemplate>* result) {
645 // Read result.
646 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
647 TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
648 lua_error(state_);
649 return LUA_ERRRUN;
650 }
651
652 // Read remote action templates array.
653 lua_pushnil(state_);
654 while (lua_next(state_, /*idx=*/-2)) {
655 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
656 TC3_LOG(ERROR) << "Expected intent table, got: "
657 << lua_type(state_, /*idx=*/-1);
658 lua_pop(state_, 1);
659 continue;
660 }
661 result->push_back(ReadRemoteActionTemplateResult());
662 }
663 lua_pop(state_, /*n=*/1);
664 return LUA_OK;
665 }
666
RunIntentGenerator(const std::string & generator_snippet,std::vector<RemoteActionTemplate> * remote_actions)667 bool JniLuaEnvironment::RunIntentGenerator(
668 const std::string& generator_snippet,
669 std::vector<RemoteActionTemplate>* remote_actions) {
670 int status;
671 status = luaL_loadbuffer(state_, generator_snippet.data(),
672 generator_snippet.size(),
673 /*name=*/nullptr);
674 if (status != LUA_OK) {
675 TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
676 return false;
677 }
678 status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
679 if (status != LUA_OK) {
680 TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
681 return false;
682 }
683 if (RunProtected(
684 [this, remote_actions] {
685 return ReadRemoteActionTemplates(remote_actions);
686 },
687 /*num_args=*/1) != LUA_OK) {
688 TC3_LOG(ERROR) << "Could not read results.";
689 return false;
690 }
691 // Check that we correctly cleaned-up the state.
692 const int stack_size = lua_gettop(state_);
693 if (stack_size > 0) {
694 TC3_LOG(ERROR) << "Unexpected stack size.";
695 lua_settop(state_, 0);
696 return false;
697 }
698 return true;
699 }
700
// Lua environment for classification result intent generation.
//
// Adds `reference_time_ms_utc` and `entity` (the classification result) to the
// `external` table. The entity text and classification are held by reference,
// so they must outlive this environment.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
 public:
  AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
                          const jobject context,
                          const std::vector<Locale>& device_locales,
                          const std::string& entity_text,
                          const ClassificationResult& classification,
                          const int64 reference_time_ms_utc,
                          const reflection::Schema* entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        entity_text_(entity_text),
        classification_(classification),
        reference_time_ms_utc_(reference_time_ms_utc),
        entity_data_schema_(entity_data_schema) {}

 protected:
  // Extends the `external` table pushed by the base implementation.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    // The key constant is named "Usec" but its value (and this field) is in
    // milliseconds.
    lua_pushinteger(state_, reference_time_ms_utc_);
    lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);

    PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const std::string& entity_text_;
  const ClassificationResult& classification_;
  const int64 reference_time_ms_utc_;

  // Reflection schema data.
  const reflection::Schema* const entity_data_schema_;
};
734
// Lua environment for actions intent generation.
//
// Adds `conversation` (built by ConversationIterator::NewIterator over the
// conversation messages) and `entity` (the action suggestion) to the
// `external` table. The conversation and action are held by reference, so
// they must outlive this environment.
class ActionsJniLuaEnvironment : public JniLuaEnvironment {
 public:
  ActionsJniLuaEnvironment(
      const Resources& resources, const JniCache* jni_cache,
      const jobject context, const std::vector<Locale>& device_locales,
      const Conversation& conversation, const ActionSuggestion& action,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        conversation_(conversation),
        action_(action),
        annotation_iterator_(annotations_entity_data_schema, this),
        conversation_iterator_(annotations_entity_data_schema, this),
        entity_data_schema_(actions_entity_data_schema) {}

 protected:
  // Extends the `external` table pushed by the base implementation.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    conversation_iterator_.NewIterator("conversation", &conversation_.messages,
                                       state_);
    lua_setfield(state_, /*idx=*/-2, "conversation");

    PushAction(action_, entity_data_schema_, annotation_iterator_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const Conversation& conversation_;
  const ActionSuggestion& action_;
  // Both iterators are built with the annotations entity data schema.
  const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
  const ConversationIterator conversation_iterator_;
  // Actions entity data schema, used when pushing the action itself.
  const reflection::Schema* entity_data_schema_;
};
768
769 } // namespace
770
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)771 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
772 const IntentFactoryModel* options, const ResourcePool* resources,
773 const std::shared_ptr<JniCache>& jni_cache) {
774 std::unique_ptr<IntentGenerator> intent_generator(
775 new IntentGenerator(options, resources, jni_cache));
776
777 if (options == nullptr || options->generator() == nullptr) {
778 TC3_LOG(ERROR) << "No intent generator options.";
779 return nullptr;
780 }
781
782 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
783 ZlibDecompressor::Instance();
784 if (!zlib_decompressor) {
785 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
786 return nullptr;
787 }
788
789 for (const IntentFactoryModel_::IntentGenerator* generator :
790 *options->generator()) {
791 std::string lua_template_generator;
792 if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
793 generator->lua_template_generator(),
794 generator->compressed_lua_template_generator(),
795 &lua_template_generator)) {
796 TC3_LOG(ERROR) << "Could not decompress generator template.";
797 return nullptr;
798 }
799
800 std::string lua_code = lua_template_generator;
801 if (options->precompile_generators()) {
802 if (!Compile(lua_template_generator, &lua_code)) {
803 TC3_LOG(ERROR) << "Could not precompile generator template.";
804 return nullptr;
805 }
806 }
807
808 intent_generator->generators_[generator->type()->str()] = lua_code;
809 }
810
811 return intent_generator;
812 }
813
ParseDeviceLocales(const jstring device_locales) const814 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
815 const jstring device_locales) const {
816 if (device_locales == nullptr) {
817 TC3_LOG(ERROR) << "No locales provided.";
818 return {};
819 }
820 ScopedStringChars locales_str =
821 GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
822 if (locales_str == nullptr) {
823 TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
824 return {};
825 }
826 std::vector<Locale> locales;
827 if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
828 &locales)) {
829 TC3_LOG(ERROR) << "Cannot parse locales.";
830 return {};
831 }
832 return locales;
833 }
834
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const835 bool IntentGenerator::GenerateIntents(
836 const jstring device_locales, const ClassificationResult& classification,
837 const int64 reference_time_ms_utc, const std::string& text,
838 const CodepointSpan selection_indices, const jobject context,
839 const reflection::Schema* annotations_entity_data_schema,
840 std::vector<RemoteActionTemplate>* remote_actions) const {
841 if (options_ == nullptr) {
842 return false;
843 }
844
845 // Retrieve generator for specified entity.
846 auto it = generators_.find(classification.collection);
847 if (it == generators_.end()) {
848 return true;
849 }
850
851 const std::string entity_text =
852 UTF8ToUnicodeText(text, /*do_copy=*/false)
853 .UTF8Substring(selection_indices.first, selection_indices.second);
854
855 std::unique_ptr<AnnotatorJniEnvironment> interpreter(
856 new AnnotatorJniEnvironment(
857 resources_, jni_cache_.get(), context,
858 ParseDeviceLocales(device_locales), entity_text, classification,
859 reference_time_ms_utc, annotations_entity_data_schema));
860
861 if (!interpreter->Initialize()) {
862 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
863 return false;
864 }
865
866 return interpreter->RunIntentGenerator(it->second, remote_actions);
867 }
868
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const869 bool IntentGenerator::GenerateIntents(
870 const jstring device_locales, const ActionSuggestion& action,
871 const Conversation& conversation, const jobject context,
872 const reflection::Schema* annotations_entity_data_schema,
873 const reflection::Schema* actions_entity_data_schema,
874 std::vector<RemoteActionTemplate>* remote_actions) const {
875 if (options_ == nullptr) {
876 return false;
877 }
878
879 // Retrieve generator for specified action.
880 auto it = generators_.find(action.type);
881 if (it == generators_.end()) {
882 return true;
883 }
884
885 std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
886 new ActionsJniLuaEnvironment(
887 resources_, jni_cache_.get(), context,
888 ParseDeviceLocales(device_locales), conversation, action,
889 actions_entity_data_schema, annotations_entity_data_schema));
890
891 if (!interpreter->Initialize()) {
892 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
893 return false;
894 }
895
896 return interpreter->RunIntentGenerator(it->second, remote_actions);
897 }
898
899 } // namespace libtextclassifier3
900