1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/intents/intent-generator.h"
18
19 #include <vector>
20
21 #include "actions/types.h"
22 #include "annotator/types.h"
23 #include "utils/base/logging.h"
24 #include "utils/base/statusor.h"
25 #include "utils/hash/farmhash.h"
26 #include "utils/java/jni-base.h"
27 #include "utils/java/jni-helper.h"
28 #include "utils/java/string_utils.h"
29 #include "utils/lua-utils.h"
30 #include "utils/strings/stringpiece.h"
31 #include "utils/strings/substitute.h"
32 #include "utils/utf8/unicodetext.h"
33 #include "utils/variant.h"
34 #include "utils/zlib/zlib.h"
35 #include "flatbuffers/reflection_generated.h"
36
37 #ifdef __cplusplus
38 extern "C" {
39 #endif
40 #include "lauxlib.h"
41 #include "lua.h"
42 #ifdef __cplusplus
43 }
44 #endif
45
46 namespace libtextclassifier3 {
47 namespace {
48
// Field names that Lua snippets may read from the `external` object and its
// `android` sub-object. The enclosing anonymous namespace already gives these
// internal linkage.

// Note: despite the "Usec" in the identifier, the exposed field name says the
// reference time is in milliseconds (UTC).
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kHashKey = "hash";
constexpr const char* kUrlSchemaKey = "url_schema";
constexpr const char* kUrlHostKey = "url_host";
constexpr const char* kUrlEncodeKey = "urlencode";
constexpr const char* kPackageNameKey = "package_name";
constexpr const char* kDeviceLocaleKey = "device_locales";
constexpr const char* kFormatKey = "format";
57
58 // An Android specific Lua environment with JNI backed callbacks.
59 class JniLuaEnvironment : public LuaEnvironment {
60 public:
61 JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
62 const jobject context,
63 const std::vector<Locale>& device_locales);
64 // Environment setup.
65 bool Initialize();
66
67 // Runs an intent generator snippet.
68 bool RunIntentGenerator(const std::string& generator_snippet,
69 std::vector<RemoteActionTemplate>* remote_actions);
70
71 protected:
72 virtual void SetupExternalHook();
73
74 int HandleExternalCallback();
75 int HandleAndroidCallback();
76 int HandleUserRestrictionsCallback();
77 int HandleUrlEncode();
78 int HandleUrlSchema();
79 int HandleHash();
80 int HandleFormat();
81 int HandleAndroidStringResources();
82 int HandleUrlHost();
83
84 // Checks and retrieves string resources from the model.
85 bool LookupModelStringResource() const;
86
87 // Reads and create a RemoteAction result from Lua.
88 RemoteActionTemplate ReadRemoteActionTemplateResult() const;
89
90 // Reads the extras from the Lua result.
91 std::map<std::string, Variant> ReadExtras() const;
92
93 // Retrieves user manager if not previously done.
94 bool RetrieveUserManager();
95
96 // Retrieves system resources if not previously done.
97 bool RetrieveSystemResources();
98
99 // Parse the url string by using Uri.parse from Java.
100 StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
101
102 // Read remote action templates from lua generator.
103 int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
104
105 const Resources& resources_;
106 JNIEnv* jenv_;
107 const JniCache* jni_cache_;
108 const jobject context_;
109 std::vector<Locale> device_locales_;
110
111 ScopedGlobalRef<jobject> usermanager_;
112 // Whether we previously attempted to retrieve the UserManager before.
113 bool usermanager_retrieved_;
114
115 ScopedGlobalRef<jobject> system_resources_;
116 // Whether we previously attempted to retrieve the system resources.
117 bool system_resources_resources_retrieved_;
118
119 // Cached JNI references for Java strings `string` and `android`.
120 ScopedGlobalRef<jstring> string_;
121 ScopedGlobalRef<jstring> android_;
122 };
123
JniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales)124 JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
125 const JniCache* jni_cache,
126 const jobject context,
127 const std::vector<Locale>& device_locales)
128 : resources_(resources),
129 jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
130 jni_cache_(jni_cache),
131 context_(context),
132 device_locales_(device_locales),
133 usermanager_(/*object=*/nullptr,
134 /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
135 usermanager_retrieved_(false),
136 system_resources_(/*object=*/nullptr,
137 /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
138 system_resources_resources_retrieved_(false),
139 string_(/*object=*/nullptr,
140 /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
141 android_(/*object=*/nullptr,
142 /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
143
Initialize()144 bool JniLuaEnvironment::Initialize() {
145 TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
146 JniHelper::NewStringUTF(jenv_, "string"));
147 string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
148 TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
149 JniHelper::NewStringUTF(jenv_, "android"));
150 android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
151 if (string_ == nullptr || android_ == nullptr) {
152 TC3_LOG(ERROR) << "Could not allocate constant strings references.";
153 return false;
154 }
155 return (RunProtected([this] {
156 LoadDefaultLibraries();
157 SetupExternalHook();
158 lua_setglobal(state_, "external");
159 return LUA_OK;
160 }) == LUA_OK);
161 }
162
SetupExternalHook()163 void JniLuaEnvironment::SetupExternalHook() {
164 // This exposes an `external` object with the following fields:
165 // * entity: the bundle with all information about a classification.
166 // * android: callbacks into specific android provided methods.
167 // * android.user_restrictions: callbacks to check user permissions.
168 // * android.R: callbacks to retrieve string resources.
169 PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
170
171 // android
172 PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
173 {
174 // android.user_restrictions
175 PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
176 lua_setfield(state_, /*idx=*/-2, "user_restrictions");
177
178 // android.R
179 // Callback to access android string resources.
180 PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
181 lua_setfield(state_, /*idx=*/-2, "R");
182 }
183 lua_setfield(state_, /*idx=*/-2, "android");
184 }
185
HandleExternalCallback()186 int JniLuaEnvironment::HandleExternalCallback() {
187 const StringPiece key = ReadString(kIndexStackTop);
188 if (key.Equals(kHashKey)) {
189 PushFunction(&JniLuaEnvironment::HandleHash);
190 return 1;
191 } else if (key.Equals(kFormatKey)) {
192 PushFunction(&JniLuaEnvironment::HandleFormat);
193 return 1;
194 } else {
195 TC3_LOG(ERROR) << "Undefined external access " << key;
196 lua_error(state_);
197 return 0;
198 }
199 }
200
HandleAndroidCallback()201 int JniLuaEnvironment::HandleAndroidCallback() {
202 const StringPiece key = ReadString(kIndexStackTop);
203 if (key.Equals(kDeviceLocaleKey)) {
204 // Provide the locale as table with the individual fields set.
205 lua_newtable(state_);
206 for (int i = 0; i < device_locales_.size(); i++) {
207 // Adjust index to 1-based indexing for Lua.
208 lua_pushinteger(state_, i + 1);
209 lua_newtable(state_);
210 PushString(device_locales_[i].Language());
211 lua_setfield(state_, -2, "language");
212 PushString(device_locales_[i].Region());
213 lua_setfield(state_, -2, "region");
214 PushString(device_locales_[i].Script());
215 lua_setfield(state_, -2, "script");
216 lua_settable(state_, /*idx=*/-3);
217 }
218 return 1;
219 } else if (key.Equals(kPackageNameKey)) {
220 if (context_ == nullptr) {
221 TC3_LOG(ERROR) << "Context invalid.";
222 lua_error(state_);
223 return 0;
224 }
225
226 StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
227 JniHelper::CallObjectMethod<jstring>(
228 jenv_, context_, jni_cache_->context_get_package_name);
229
230 if (!status_or_package_name_str.ok()) {
231 TC3_LOG(ERROR) << "Error calling Context.getPackageName";
232 lua_error(state_);
233 return 0;
234 }
235 StatusOr<std::string> status_or_package_name_std_str =
236 ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
237 if (!status_or_package_name_std_str.ok()) {
238 lua_error(state_);
239 return 0;
240 }
241 PushString(status_or_package_name_std_str.ValueOrDie());
242 return 1;
243 } else if (key.Equals(kUrlEncodeKey)) {
244 PushFunction(&JniLuaEnvironment::HandleUrlEncode);
245 return 1;
246 } else if (key.Equals(kUrlHostKey)) {
247 PushFunction(&JniLuaEnvironment::HandleUrlHost);
248 return 1;
249 } else if (key.Equals(kUrlSchemaKey)) {
250 PushFunction(&JniLuaEnvironment::HandleUrlSchema);
251 return 1;
252 } else {
253 TC3_LOG(ERROR) << "Undefined android reference " << key;
254 lua_error(state_);
255 return 0;
256 }
257 }
258
HandleUserRestrictionsCallback()259 int JniLuaEnvironment::HandleUserRestrictionsCallback() {
260 if (jni_cache_->usermanager_class == nullptr ||
261 jni_cache_->usermanager_get_user_restrictions == nullptr) {
262 // UserManager is only available for API level >= 17 and
263 // getUserRestrictions only for API level >= 18, so we just return false
264 // normally here.
265 lua_pushboolean(state_, false);
266 return 1;
267 }
268
269 // Get user manager if not previously retrieved.
270 if (!RetrieveUserManager()) {
271 TC3_LOG(ERROR) << "Error retrieving user manager.";
272 lua_error(state_);
273 return 0;
274 }
275
276 StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
277 JniHelper::CallObjectMethod(
278 jenv_, usermanager_.get(),
279 jni_cache_->usermanager_get_user_restrictions);
280 if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
281 TC3_LOG(ERROR) << "Error calling getUserRestrictions";
282 lua_error(state_);
283 return 0;
284 }
285
286 const StringPiece key_str = ReadString(kIndexStackTop);
287 if (key_str.empty()) {
288 TC3_LOG(ERROR) << "Expected string, got null.";
289 lua_error(state_);
290 return 0;
291 }
292
293 const StatusOr<ScopedLocalRef<jstring>> status_or_key =
294 jni_cache_->ConvertToJavaString(key_str);
295 if (!status_or_key.ok()) {
296 lua_error(state_);
297 return 0;
298 }
299 const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
300 jenv_, status_or_bundle.ValueOrDie().get(),
301 jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
302 if (!status_or_permission.ok()) {
303 TC3_LOG(ERROR) << "Error getting bundle value";
304 lua_pushboolean(state_, false);
305 } else {
306 lua_pushboolean(state_, status_or_permission.ValueOrDie());
307 }
308 return 1;
309 }
310
HandleUrlEncode()311 int JniLuaEnvironment::HandleUrlEncode() {
312 const StringPiece input = ReadString(/*index=*/1);
313 if (input.empty()) {
314 TC3_LOG(ERROR) << "Expected string, got null.";
315 lua_error(state_);
316 return 0;
317 }
318
319 // Call Java URL encoder.
320 const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
321 jni_cache_->ConvertToJavaString(input);
322 if (!status_or_input_str.ok()) {
323 lua_error(state_);
324 return 0;
325 }
326 StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
327 JniHelper::CallStaticObjectMethod<jstring>(
328 jenv_, jni_cache_->urlencoder_class.get(),
329 jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
330 jni_cache_->string_utf8.get());
331
332 if (!status_or_encoded_str.ok()) {
333 TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
334 lua_error(state_);
335 return 0;
336 }
337 const StatusOr<std::string> status_or_encoded_std_str =
338 ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
339 if (!status_or_encoded_std_str.ok()) {
340 lua_error(state_);
341 return 0;
342 }
343 PushString(status_or_encoded_std_str.ValueOrDie());
344 return 1;
345 }
346
ParseUri(StringPiece url) const347 StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
348 StringPiece url) const {
349 if (url.empty()) {
350 return {Status::UNKNOWN};
351 }
352
353 // Call to Java URI parser.
354 TC3_ASSIGN_OR_RETURN(
355 const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
356 jni_cache_->ConvertToJavaString(url));
357
358 // Try to parse uri and get scheme.
359 TC3_ASSIGN_OR_RETURN(
360 ScopedLocalRef<jobject> uri,
361 JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
362 jni_cache_->uri_parse,
363 status_or_url_str.ValueOrDie().get()));
364 if (uri == nullptr) {
365 TC3_LOG(ERROR) << "Error calling Uri.parse";
366 return {Status::UNKNOWN};
367 }
368 return uri;
369 }
370
HandleUrlSchema()371 int JniLuaEnvironment::HandleUrlSchema() {
372 StringPiece url = ReadString(/*index=*/1);
373
374 const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
375 if (!status_or_parsed_uri.ok()) {
376 lua_error(state_);
377 return 0;
378 }
379
380 const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
381 JniHelper::CallObjectMethod<jstring>(
382 jenv_, status_or_parsed_uri.ValueOrDie().get(),
383 jni_cache_->uri_get_scheme);
384 if (!status_or_scheme_str.ok()) {
385 TC3_LOG(ERROR) << "Error calling Uri.getScheme";
386 lua_error(state_);
387 return 0;
388 }
389 if (status_or_scheme_str.ValueOrDie() == nullptr) {
390 lua_pushnil(state_);
391 } else {
392 const StatusOr<std::string> status_or_scheme_std_str =
393 ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
394 if (!status_or_scheme_std_str.ok()) {
395 lua_error(state_);
396 return 0;
397 }
398 PushString(status_or_scheme_std_str.ValueOrDie());
399 }
400 return 1;
401 }
402
HandleUrlHost()403 int JniLuaEnvironment::HandleUrlHost() {
404 const StringPiece url = ReadString(kIndexStackTop);
405
406 const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
407 if (!status_or_parsed_uri.ok()) {
408 lua_error(state_);
409 return 0;
410 }
411
412 const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
413 JniHelper::CallObjectMethod<jstring>(
414 jenv_, status_or_parsed_uri.ValueOrDie().get(),
415 jni_cache_->uri_get_host);
416 if (!status_or_host_str.ok()) {
417 TC3_LOG(ERROR) << "Error calling Uri.getHost";
418 lua_error(state_);
419 return 0;
420 }
421
422 if (status_or_host_str.ValueOrDie() == nullptr) {
423 lua_pushnil(state_);
424 } else {
425 const StatusOr<std::string> status_or_host_std_str =
426 ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
427 if (!status_or_host_std_str.ok()) {
428 lua_error(state_);
429 return 0;
430 }
431 PushString(status_or_host_std_str.ValueOrDie());
432 }
433 return 1;
434 }
435
HandleHash()436 int JniLuaEnvironment::HandleHash() {
437 const StringPiece input = ReadString(kIndexStackTop);
438 lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
439 return 1;
440 }
441
HandleFormat()442 int JniLuaEnvironment::HandleFormat() {
443 const int num_args = lua_gettop(state_);
444 std::vector<StringPiece> args(num_args - 1);
445 for (int i = 0; i < num_args - 1; i++) {
446 args[i] = ReadString(/*index=*/i + 2);
447 }
448 PushString(strings::Substitute(ReadString(/*index=*/1), args));
449 return 1;
450 }
451
LookupModelStringResource() const452 bool JniLuaEnvironment::LookupModelStringResource() const {
453 // Handle only lookup by name.
454 if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
455 return false;
456 }
457
458 const StringPiece resource_name = ReadString(kIndexStackTop);
459 std::string resource_content;
460 if (!resources_.GetResourceContent(device_locales_, resource_name,
461 &resource_content)) {
462 // Resource cannot be provided by the model.
463 return false;
464 }
465
466 PushString(resource_content);
467 return true;
468 }
469
HandleAndroidStringResources()470 int JniLuaEnvironment::HandleAndroidStringResources() {
471 // Check whether the requested resource can be served from the model data.
472 if (LookupModelStringResource()) {
473 return 1;
474 }
475
476 // Get system resources if not previously retrieved.
477 if (!RetrieveSystemResources()) {
478 TC3_LOG(ERROR) << "Error retrieving system resources.";
479 lua_error(state_);
480 return 0;
481 }
482
483 int resource_id;
484 switch (lua_type(state_, kIndexStackTop)) {
485 case LUA_TNUMBER:
486 resource_id = Read<int>(/*index=*/kIndexStackTop);
487 break;
488 case LUA_TSTRING: {
489 const StringPiece resource_name_str = ReadString(kIndexStackTop);
490 if (resource_name_str.empty()) {
491 TC3_LOG(ERROR) << "No resource name provided.";
492 lua_error(state_);
493 return 0;
494 }
495 const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
496 jni_cache_->ConvertToJavaString(resource_name_str);
497 if (!status_or_resource_name.ok()) {
498 TC3_LOG(ERROR) << "Invalid resource name.";
499 lua_error(state_);
500 return 0;
501 }
502 StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
503 jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
504 status_or_resource_name.ValueOrDie().get(), string_.get(),
505 android_.get());
506 if (!status_or_resource_id.ok()) {
507 TC3_LOG(ERROR) << "Error calling getIdentifier.";
508 lua_error(state_);
509 return 0;
510 }
511 resource_id = status_or_resource_id.ValueOrDie();
512 break;
513 }
514 default:
515 TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
516 lua_error(state_);
517 return 0;
518 }
519 if (resource_id == 0) {
520 TC3_LOG(ERROR) << "Resource not found.";
521 lua_pushnil(state_);
522 return 1;
523 }
524 StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
525 JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
526 jni_cache_->resources_get_string,
527 resource_id);
528 if (!status_or_resource_str.ok()) {
529 TC3_LOG(ERROR) << "Error calling getString.";
530 lua_error(state_);
531 return 0;
532 }
533
534 if (status_or_resource_str.ValueOrDie() == nullptr) {
535 lua_pushnil(state_);
536 } else {
537 StatusOr<std::string> status_or_resource_std_str =
538 ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
539 if (!status_or_resource_std_str.ok()) {
540 lua_error(state_);
541 return 0;
542 }
543 PushString(status_or_resource_std_str.ValueOrDie());
544 }
545 return 1;
546 }
547
RetrieveSystemResources()548 bool JniLuaEnvironment::RetrieveSystemResources() {
549 if (system_resources_resources_retrieved_) {
550 return (system_resources_ != nullptr);
551 }
552 system_resources_resources_retrieved_ = true;
553 TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
554 JniHelper::CallStaticObjectMethod(
555 jenv_, jni_cache_->resources_class.get(),
556 jni_cache_->resources_get_system));
557 system_resources_ =
558 MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
559 return (system_resources_ != nullptr);
560 }
561
RetrieveUserManager()562 bool JniLuaEnvironment::RetrieveUserManager() {
563 if (context_ == nullptr) {
564 return false;
565 }
566 if (usermanager_retrieved_) {
567 return (usermanager_ != nullptr);
568 }
569 usermanager_retrieved_ = true;
570 TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
571 JniHelper::NewStringUTF(jenv_, "user"));
572 TC3_ASSIGN_OR_RETURN_FALSE(
573 const ScopedLocalRef<jobject> usermanager_ref,
574 JniHelper::CallObjectMethod(jenv_, context_,
575 jni_cache_->context_get_system_service,
576 service.get()));
577
578 usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
579 return (usermanager_ != nullptr);
580 }
581
ReadRemoteActionTemplateResult() const582 RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
583 RemoteActionTemplate result;
584 // Read intent template.
585 lua_pushnil(state_);
586 while (Next(/*index=*/-2)) {
587 const StringPiece key = ReadString(/*index=*/-2);
588 if (key.Equals("title_without_entity")) {
589 result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
590 } else if (key.Equals("title_with_entity")) {
591 result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
592 } else if (key.Equals("description")) {
593 result.description = Read<std::string>(/*index=*/kIndexStackTop);
594 } else if (key.Equals("description_with_app_name")) {
595 result.description_with_app_name =
596 Read<std::string>(/*index=*/kIndexStackTop);
597 } else if (key.Equals("action")) {
598 result.action = Read<std::string>(/*index=*/kIndexStackTop);
599 } else if (key.Equals("data")) {
600 result.data = Read<std::string>(/*index=*/kIndexStackTop);
601 } else if (key.Equals("type")) {
602 result.type = Read<std::string>(/*index=*/kIndexStackTop);
603 } else if (key.Equals("flags")) {
604 result.flags = Read<int>(/*index=*/kIndexStackTop);
605 } else if (key.Equals("package_name")) {
606 result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
607 } else if (key.Equals("request_code")) {
608 result.request_code = Read<int>(/*index=*/kIndexStackTop);
609 } else if (key.Equals("category")) {
610 result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
611 } else if (key.Equals("extra")) {
612 result.extra = ReadExtras();
613 } else {
614 TC3_LOG(INFO) << "Unknown entry: " << key;
615 }
616 lua_pop(state_, 1);
617 }
618 lua_pop(state_, 1);
619 return result;
620 }
621
ReadExtras() const622 std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
623 if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
624 TC3_LOG(ERROR) << "Expected extras table, got: "
625 << lua_type(state_, kIndexStackTop);
626 lua_pop(state_, 1);
627 return {};
628 }
629 std::map<std::string, Variant> extras;
630 lua_pushnil(state_);
631 while (Next(/*index=*/-2)) {
632 // Each entry is a table specifying name and value.
633 // The value is specified via a type specific field as Lua doesn't allow
634 // to easily distinguish between different number types.
635 if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
636 TC3_LOG(ERROR) << "Expected a table for an extra, got: "
637 << lua_type(state_, kIndexStackTop);
638 lua_pop(state_, 1);
639 return {};
640 }
641 std::string name;
642 Variant value;
643
644 lua_pushnil(state_);
645 while (Next(/*index=*/-2)) {
646 const StringPiece key = ReadString(/*index=*/-2);
647 if (key.Equals("name")) {
648 name = Read<std::string>(/*index=*/kIndexStackTop);
649 } else if (key.Equals("int_value")) {
650 value = Variant(Read<int>(/*index=*/kIndexStackTop));
651 } else if (key.Equals("long_value")) {
652 value = Variant(Read<int64>(/*index=*/kIndexStackTop));
653 } else if (key.Equals("float_value")) {
654 value = Variant(Read<float>(/*index=*/kIndexStackTop));
655 } else if (key.Equals("bool_value")) {
656 value = Variant(Read<bool>(/*index=*/kIndexStackTop));
657 } else if (key.Equals("string_value")) {
658 value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
659 } else if (key.Equals("string_array_value")) {
660 value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
661 } else if (key.Equals("float_array_value")) {
662 value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
663 } else if (key.Equals("int_array_value")) {
664 value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
665 } else if (key.Equals("named_variant_array_value")) {
666 value = Variant(ReadExtras());
667 } else {
668 TC3_LOG(INFO) << "Unknown extra field: " << key;
669 }
670 lua_pop(state_, 1);
671 }
672 if (!name.empty()) {
673 extras[name] = value;
674 } else {
675 TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
676 }
677 lua_pop(state_, 1);
678 }
679 return extras;
680 }
681
ReadRemoteActionTemplates(std::vector<RemoteActionTemplate> * result)682 int JniLuaEnvironment::ReadRemoteActionTemplates(
683 std::vector<RemoteActionTemplate>* result) {
684 // Read result.
685 if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
686 TC3_LOG(ERROR) << "Unexpected result for snippet: "
687 << lua_type(state_, kIndexStackTop);
688 lua_error(state_);
689 return LUA_ERRRUN;
690 }
691
692 // Read remote action templates array.
693 lua_pushnil(state_);
694 while (Next(/*index=*/-2)) {
695 if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
696 TC3_LOG(ERROR) << "Expected intent table, got: "
697 << lua_type(state_, kIndexStackTop);
698 lua_pop(state_, 1);
699 continue;
700 }
701 result->push_back(ReadRemoteActionTemplateResult());
702 }
703 lua_pop(state_, /*n=*/1);
704 return LUA_OK;
705 }
706
RunIntentGenerator(const std::string & generator_snippet,std::vector<RemoteActionTemplate> * remote_actions)707 bool JniLuaEnvironment::RunIntentGenerator(
708 const std::string& generator_snippet,
709 std::vector<RemoteActionTemplate>* remote_actions) {
710 int status;
711 status = luaL_loadbuffer(state_, generator_snippet.data(),
712 generator_snippet.size(),
713 /*name=*/nullptr);
714 if (status != LUA_OK) {
715 TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
716 return false;
717 }
718 status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
719 if (status != LUA_OK) {
720 TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
721 return false;
722 }
723 if (RunProtected(
724 [this, remote_actions] {
725 return ReadRemoteActionTemplates(remote_actions);
726 },
727 /*num_args=*/1) != LUA_OK) {
728 TC3_LOG(ERROR) << "Could not read results.";
729 return false;
730 }
731 // Check that we correctly cleaned-up the state.
732 const int stack_size = lua_gettop(state_);
733 if (stack_size > 0) {
734 TC3_LOG(ERROR) << "Unexpected stack size.";
735 lua_settop(state_, 0);
736 return false;
737 }
738 return true;
739 }
740
741 // Lua environment for classfication result intent generation.
742 class AnnotatorJniEnvironment : public JniLuaEnvironment {
743 public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema)744 AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
745 const jobject context,
746 const std::vector<Locale>& device_locales,
747 const std::string& entity_text,
748 const ClassificationResult& classification,
749 const int64 reference_time_ms_utc,
750 const reflection::Schema* entity_data_schema)
751 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
752 entity_text_(entity_text),
753 classification_(classification),
754 reference_time_ms_utc_(reference_time_ms_utc),
755 entity_data_schema_(entity_data_schema) {}
756
757 protected:
SetupExternalHook()758 void SetupExternalHook() override {
759 JniLuaEnvironment::SetupExternalHook();
760 lua_pushinteger(state_, reference_time_ms_utc_);
761 lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
762
763 PushAnnotation(classification_, entity_text_, entity_data_schema_);
764 lua_setfield(state_, /*idx=*/-2, "entity");
765 }
766
767 const std::string& entity_text_;
768 const ClassificationResult& classification_;
769 const int64 reference_time_ms_utc_;
770
771 // Reflection schema data.
772 const reflection::Schema* const entity_data_schema_;
773 };
774
775 // Lua environment for actions intent generation.
776 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
777 public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)778 ActionsJniLuaEnvironment(
779 const Resources& resources, const JniCache* jni_cache,
780 const jobject context, const std::vector<Locale>& device_locales,
781 const Conversation& conversation, const ActionSuggestion& action,
782 const reflection::Schema* actions_entity_data_schema,
783 const reflection::Schema* annotations_entity_data_schema)
784 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
785 conversation_(conversation),
786 action_(action),
787 actions_entity_data_schema_(actions_entity_data_schema),
788 annotations_entity_data_schema_(annotations_entity_data_schema) {}
789
790 protected:
SetupExternalHook()791 void SetupExternalHook() override {
792 JniLuaEnvironment::SetupExternalHook();
793 PushConversation(&conversation_.messages, annotations_entity_data_schema_);
794 lua_setfield(state_, /*idx=*/-2, "conversation");
795
796 PushAction(action_, actions_entity_data_schema_,
797 annotations_entity_data_schema_);
798 lua_setfield(state_, /*idx=*/-2, "entity");
799 }
800
801 const Conversation& conversation_;
802 const ActionSuggestion& action_;
803 const reflection::Schema* actions_entity_data_schema_;
804 const reflection::Schema* annotations_entity_data_schema_;
805 };
806
807 } // namespace
808
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)809 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
810 const IntentFactoryModel* options, const ResourcePool* resources,
811 const std::shared_ptr<JniCache>& jni_cache) {
812 std::unique_ptr<IntentGenerator> intent_generator(
813 new IntentGenerator(options, resources, jni_cache));
814
815 if (options == nullptr || options->generator() == nullptr) {
816 TC3_LOG(ERROR) << "No intent generator options.";
817 return nullptr;
818 }
819
820 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
821 ZlibDecompressor::Instance();
822 if (!zlib_decompressor) {
823 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
824 return nullptr;
825 }
826
827 for (const IntentFactoryModel_::IntentGenerator* generator :
828 *options->generator()) {
829 std::string lua_template_generator;
830 if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
831 generator->lua_template_generator(),
832 generator->compressed_lua_template_generator(),
833 &lua_template_generator)) {
834 TC3_LOG(ERROR) << "Could not decompress generator template.";
835 return nullptr;
836 }
837
838 std::string lua_code = lua_template_generator;
839 if (options->precompile_generators()) {
840 if (!Compile(lua_template_generator, &lua_code)) {
841 TC3_LOG(ERROR) << "Could not precompile generator template.";
842 return nullptr;
843 }
844 }
845
846 intent_generator->generators_[generator->type()->str()] = lua_code;
847 }
848
849 return intent_generator;
850 }
851
ParseDeviceLocales(const jstring device_locales) const852 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
853 const jstring device_locales) const {
854 if (device_locales == nullptr) {
855 TC3_LOG(ERROR) << "No locales provided.";
856 return {};
857 }
858 ScopedStringChars locales_str =
859 GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
860 if (locales_str == nullptr) {
861 TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
862 return {};
863 }
864 std::vector<Locale> locales;
865 if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
866 &locales)) {
867 TC3_LOG(ERROR) << "Cannot parse locales.";
868 return {};
869 }
870 return locales;
871 }
872
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const873 bool IntentGenerator::GenerateIntents(
874 const jstring device_locales, const ClassificationResult& classification,
875 const int64 reference_time_ms_utc, const std::string& text,
876 const CodepointSpan selection_indices, const jobject context,
877 const reflection::Schema* annotations_entity_data_schema,
878 std::vector<RemoteActionTemplate>* remote_actions) const {
879 if (options_ == nullptr) {
880 return false;
881 }
882
883 // Retrieve generator for specified entity.
884 auto it = generators_.find(classification.collection);
885 if (it == generators_.end()) {
886 TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
887 return true;
888 }
889
890 const std::string entity_text =
891 UTF8ToUnicodeText(text, /*do_copy=*/false)
892 .UTF8Substring(selection_indices.first, selection_indices.second);
893
894 std::unique_ptr<AnnotatorJniEnvironment> interpreter(
895 new AnnotatorJniEnvironment(
896 resources_, jni_cache_.get(), context,
897 ParseDeviceLocales(device_locales), entity_text, classification,
898 reference_time_ms_utc, annotations_entity_data_schema));
899
900 if (!interpreter->Initialize()) {
901 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
902 return false;
903 }
904
905 return interpreter->RunIntentGenerator(it->second, remote_actions);
906 }
907
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const908 bool IntentGenerator::GenerateIntents(
909 const jstring device_locales, const ActionSuggestion& action,
910 const Conversation& conversation, const jobject context,
911 const reflection::Schema* annotations_entity_data_schema,
912 const reflection::Schema* actions_entity_data_schema,
913 std::vector<RemoteActionTemplate>* remote_actions) const {
914 if (options_ == nullptr) {
915 return false;
916 }
917
918 // Retrieve generator for specified action.
919 auto it = generators_.find(action.type);
920 if (it == generators_.end()) {
921 return true;
922 }
923
924 std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
925 new ActionsJniLuaEnvironment(
926 resources_, jni_cache_.get(), context,
927 ParseDeviceLocales(device_locales), conversation, action,
928 actions_entity_data_schema, annotations_entity_data_schema));
929
930 if (!interpreter->Initialize()) {
931 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
932 return false;
933 }
934
935 return interpreter->RunIntentGenerator(it->second, remote_actions);
936 }
937
938 } // namespace libtextclassifier3
939