• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/intents/intent-generator.h"
18 
19 #include <vector>
20 
21 #include "actions/types.h"
22 #include "annotator/types.h"
23 #include "utils/base/logging.h"
24 #include "utils/base/statusor.h"
25 #include "utils/hash/farmhash.h"
26 #include "utils/java/jni-base.h"
27 #include "utils/java/jni-helper.h"
28 #include "utils/java/string_utils.h"
29 #include "utils/lua-utils.h"
30 #include "utils/strings/stringpiece.h"
31 #include "utils/strings/substitute.h"
32 #include "utils/utf8/unicodetext.h"
33 #include "utils/variant.h"
34 #include "utils/zlib/zlib.h"
35 #include "flatbuffers/reflection_generated.h"
36 
37 #ifdef __cplusplus
38 extern "C" {
39 #endif
40 #include "lauxlib.h"
41 #include "lua.h"
42 #ifdef __cplusplus
43 }
44 #endif
45 
46 namespace libtextclassifier3 {
47 namespace {
48 
// Field names exposed to Lua generator snippets on the `external` and
// `external.android` objects. (Already inside an anonymous namespace, so no
// `static` is needed for internal linkage.)
//
// NOTE: despite the `Usec` in its name, kReferenceTimeUsecKey names the
// millisecond field `reference_time_ms_utc`.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kHashKey = "hash";
constexpr const char* kUrlSchemaKey = "url_schema";
constexpr const char* kUrlHostKey = "url_host";
constexpr const char* kUrlEncodeKey = "urlencode";
constexpr const char* kPackageNameKey = "package_name";
constexpr const char* kDeviceLocaleKey = "device_locales";
constexpr const char* kFormatKey = "format";
57 
58 // An Android specific Lua environment with JNI backed callbacks.
59 class JniLuaEnvironment : public LuaEnvironment {
60  public:
61   JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
62                     const jobject context,
63                     const std::vector<Locale>& device_locales);
64   // Environment setup.
65   bool Initialize();
66 
67   // Runs an intent generator snippet.
68   bool RunIntentGenerator(const std::string& generator_snippet,
69                           std::vector<RemoteActionTemplate>* remote_actions);
70 
71  protected:
72   virtual void SetupExternalHook();
73 
74   int HandleExternalCallback();
75   int HandleAndroidCallback();
76   int HandleUserRestrictionsCallback();
77   int HandleUrlEncode();
78   int HandleUrlSchema();
79   int HandleHash();
80   int HandleFormat();
81   int HandleAndroidStringResources();
82   int HandleUrlHost();
83 
84   // Checks and retrieves string resources from the model.
85   bool LookupModelStringResource() const;
86 
87   // Reads and create a RemoteAction result from Lua.
88   RemoteActionTemplate ReadRemoteActionTemplateResult() const;
89 
90   // Reads the extras from the Lua result.
91   std::map<std::string, Variant> ReadExtras() const;
92 
93   // Retrieves user manager if not previously done.
94   bool RetrieveUserManager();
95 
96   // Retrieves system resources if not previously done.
97   bool RetrieveSystemResources();
98 
99   // Parse the url string by using Uri.parse from Java.
100   StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
101 
102   // Read remote action templates from lua generator.
103   int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
104 
105   const Resources& resources_;
106   JNIEnv* jenv_;
107   const JniCache* jni_cache_;
108   const jobject context_;
109   std::vector<Locale> device_locales_;
110 
111   ScopedGlobalRef<jobject> usermanager_;
112   // Whether we previously attempted to retrieve the UserManager before.
113   bool usermanager_retrieved_;
114 
115   ScopedGlobalRef<jobject> system_resources_;
116   // Whether we previously attempted to retrieve the system resources.
117   bool system_resources_resources_retrieved_;
118 
119   // Cached JNI references for Java strings `string` and `android`.
120   ScopedGlobalRef<jstring> string_;
121   ScopedGlobalRef<jstring> android_;
122 };
123 
JniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales)124 JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
125                                      const JniCache* jni_cache,
126                                      const jobject context,
127                                      const std::vector<Locale>& device_locales)
128     : resources_(resources),
129       jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
130       jni_cache_(jni_cache),
131       context_(context),
132       device_locales_(device_locales),
133       usermanager_(/*object=*/nullptr,
134                    /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
135       usermanager_retrieved_(false),
136       system_resources_(/*object=*/nullptr,
137                         /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
138       system_resources_resources_retrieved_(false),
139       string_(/*object=*/nullptr,
140               /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
141       android_(/*object=*/nullptr,
142                /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
143 
Initialize()144 bool JniLuaEnvironment::Initialize() {
145   TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
146                              JniHelper::NewStringUTF(jenv_, "string"));
147   string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
148   TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
149                              JniHelper::NewStringUTF(jenv_, "android"));
150   android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
151   if (string_ == nullptr || android_ == nullptr) {
152     TC3_LOG(ERROR) << "Could not allocate constant strings references.";
153     return false;
154   }
155   return (RunProtected([this] {
156             LoadDefaultLibraries();
157             SetupExternalHook();
158             lua_setglobal(state_, "external");
159             return LUA_OK;
160           }) == LUA_OK);
161 }
162 
SetupExternalHook()163 void JniLuaEnvironment::SetupExternalHook() {
164   // This exposes an `external` object with the following fields:
165   //   * entity: the bundle with all information about a classification.
166   //   * android: callbacks into specific android provided methods.
167   //   * android.user_restrictions: callbacks to check user permissions.
168   //   * android.R: callbacks to retrieve string resources.
169   PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
170 
171   // android
172   PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
173   {
174     // android.user_restrictions
175     PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
176     lua_setfield(state_, /*idx=*/-2, "user_restrictions");
177 
178     // android.R
179     // Callback to access android string resources.
180     PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
181     lua_setfield(state_, /*idx=*/-2, "R");
182   }
183   lua_setfield(state_, /*idx=*/-2, "android");
184 }
185 
HandleExternalCallback()186 int JniLuaEnvironment::HandleExternalCallback() {
187   const StringPiece key = ReadString(kIndexStackTop);
188   if (key.Equals(kHashKey)) {
189     PushFunction(&JniLuaEnvironment::HandleHash);
190     return 1;
191   } else if (key.Equals(kFormatKey)) {
192     PushFunction(&JniLuaEnvironment::HandleFormat);
193     return 1;
194   } else {
195     TC3_LOG(ERROR) << "Undefined external access " << key;
196     lua_error(state_);
197     return 0;
198   }
199 }
200 
HandleAndroidCallback()201 int JniLuaEnvironment::HandleAndroidCallback() {
202   const StringPiece key = ReadString(kIndexStackTop);
203   if (key.Equals(kDeviceLocaleKey)) {
204     // Provide the locale as table with the individual fields set.
205     lua_newtable(state_);
206     for (int i = 0; i < device_locales_.size(); i++) {
207       // Adjust index to 1-based indexing for Lua.
208       lua_pushinteger(state_, i + 1);
209       lua_newtable(state_);
210       PushString(device_locales_[i].Language());
211       lua_setfield(state_, -2, "language");
212       PushString(device_locales_[i].Region());
213       lua_setfield(state_, -2, "region");
214       PushString(device_locales_[i].Script());
215       lua_setfield(state_, -2, "script");
216       lua_settable(state_, /*idx=*/-3);
217     }
218     return 1;
219   } else if (key.Equals(kPackageNameKey)) {
220     if (context_ == nullptr) {
221       TC3_LOG(ERROR) << "Context invalid.";
222       lua_error(state_);
223       return 0;
224     }
225 
226     StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
227         JniHelper::CallObjectMethod<jstring>(
228             jenv_, context_, jni_cache_->context_get_package_name);
229 
230     if (!status_or_package_name_str.ok()) {
231       TC3_LOG(ERROR) << "Error calling Context.getPackageName";
232       lua_error(state_);
233       return 0;
234     }
235     StatusOr<std::string> status_or_package_name_std_str =
236         ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
237     if (!status_or_package_name_std_str.ok()) {
238       lua_error(state_);
239       return 0;
240     }
241     PushString(status_or_package_name_std_str.ValueOrDie());
242     return 1;
243   } else if (key.Equals(kUrlEncodeKey)) {
244     PushFunction(&JniLuaEnvironment::HandleUrlEncode);
245     return 1;
246   } else if (key.Equals(kUrlHostKey)) {
247     PushFunction(&JniLuaEnvironment::HandleUrlHost);
248     return 1;
249   } else if (key.Equals(kUrlSchemaKey)) {
250     PushFunction(&JniLuaEnvironment::HandleUrlSchema);
251     return 1;
252   } else {
253     TC3_LOG(ERROR) << "Undefined android reference " << key;
254     lua_error(state_);
255     return 0;
256   }
257 }
258 
HandleUserRestrictionsCallback()259 int JniLuaEnvironment::HandleUserRestrictionsCallback() {
260   if (jni_cache_->usermanager_class == nullptr ||
261       jni_cache_->usermanager_get_user_restrictions == nullptr) {
262     // UserManager is only available for API level >= 17 and
263     // getUserRestrictions only for API level >= 18, so we just return false
264     // normally here.
265     lua_pushboolean(state_, false);
266     return 1;
267   }
268 
269   // Get user manager if not previously retrieved.
270   if (!RetrieveUserManager()) {
271     TC3_LOG(ERROR) << "Error retrieving user manager.";
272     lua_error(state_);
273     return 0;
274   }
275 
276   StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
277       JniHelper::CallObjectMethod(
278           jenv_, usermanager_.get(),
279           jni_cache_->usermanager_get_user_restrictions);
280   if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
281     TC3_LOG(ERROR) << "Error calling getUserRestrictions";
282     lua_error(state_);
283     return 0;
284   }
285 
286   const StringPiece key_str = ReadString(kIndexStackTop);
287   if (key_str.empty()) {
288     TC3_LOG(ERROR) << "Expected string, got null.";
289     lua_error(state_);
290     return 0;
291   }
292 
293   const StatusOr<ScopedLocalRef<jstring>> status_or_key =
294       jni_cache_->ConvertToJavaString(key_str);
295   if (!status_or_key.ok()) {
296     lua_error(state_);
297     return 0;
298   }
299   const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
300       jenv_, status_or_bundle.ValueOrDie().get(),
301       jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
302   if (!status_or_permission.ok()) {
303     TC3_LOG(ERROR) << "Error getting bundle value";
304     lua_pushboolean(state_, false);
305   } else {
306     lua_pushboolean(state_, status_or_permission.ValueOrDie());
307   }
308   return 1;
309 }
310 
HandleUrlEncode()311 int JniLuaEnvironment::HandleUrlEncode() {
312   const StringPiece input = ReadString(/*index=*/1);
313   if (input.empty()) {
314     TC3_LOG(ERROR) << "Expected string, got null.";
315     lua_error(state_);
316     return 0;
317   }
318 
319   // Call Java URL encoder.
320   const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
321       jni_cache_->ConvertToJavaString(input);
322   if (!status_or_input_str.ok()) {
323     lua_error(state_);
324     return 0;
325   }
326   StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
327       JniHelper::CallStaticObjectMethod<jstring>(
328           jenv_, jni_cache_->urlencoder_class.get(),
329           jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
330           jni_cache_->string_utf8.get());
331 
332   if (!status_or_encoded_str.ok()) {
333     TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
334     lua_error(state_);
335     return 0;
336   }
337   const StatusOr<std::string> status_or_encoded_std_str =
338       ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
339   if (!status_or_encoded_std_str.ok()) {
340     lua_error(state_);
341     return 0;
342   }
343   PushString(status_or_encoded_std_str.ValueOrDie());
344   return 1;
345 }
346 
ParseUri(StringPiece url) const347 StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
348     StringPiece url) const {
349   if (url.empty()) {
350     return {Status::UNKNOWN};
351   }
352 
353   // Call to Java URI parser.
354   TC3_ASSIGN_OR_RETURN(
355       const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
356       jni_cache_->ConvertToJavaString(url));
357 
358   // Try to parse uri and get scheme.
359   TC3_ASSIGN_OR_RETURN(
360       ScopedLocalRef<jobject> uri,
361       JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
362                                         jni_cache_->uri_parse,
363                                         status_or_url_str.ValueOrDie().get()));
364   if (uri == nullptr) {
365     TC3_LOG(ERROR) << "Error calling Uri.parse";
366     return {Status::UNKNOWN};
367   }
368   return uri;
369 }
370 
HandleUrlSchema()371 int JniLuaEnvironment::HandleUrlSchema() {
372   StringPiece url = ReadString(/*index=*/1);
373 
374   const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
375   if (!status_or_parsed_uri.ok()) {
376     lua_error(state_);
377     return 0;
378   }
379 
380   const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
381       JniHelper::CallObjectMethod<jstring>(
382           jenv_, status_or_parsed_uri.ValueOrDie().get(),
383           jni_cache_->uri_get_scheme);
384   if (!status_or_scheme_str.ok()) {
385     TC3_LOG(ERROR) << "Error calling Uri.getScheme";
386     lua_error(state_);
387     return 0;
388   }
389   if (status_or_scheme_str.ValueOrDie() == nullptr) {
390     lua_pushnil(state_);
391   } else {
392     const StatusOr<std::string> status_or_scheme_std_str =
393         ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
394     if (!status_or_scheme_std_str.ok()) {
395       lua_error(state_);
396       return 0;
397     }
398     PushString(status_or_scheme_std_str.ValueOrDie());
399   }
400   return 1;
401 }
402 
HandleUrlHost()403 int JniLuaEnvironment::HandleUrlHost() {
404   const StringPiece url = ReadString(kIndexStackTop);
405 
406   const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
407   if (!status_or_parsed_uri.ok()) {
408     lua_error(state_);
409     return 0;
410   }
411 
412   const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
413       JniHelper::CallObjectMethod<jstring>(
414           jenv_, status_or_parsed_uri.ValueOrDie().get(),
415           jni_cache_->uri_get_host);
416   if (!status_or_host_str.ok()) {
417     TC3_LOG(ERROR) << "Error calling Uri.getHost";
418     lua_error(state_);
419     return 0;
420   }
421 
422   if (status_or_host_str.ValueOrDie() == nullptr) {
423     lua_pushnil(state_);
424   } else {
425     const StatusOr<std::string> status_or_host_std_str =
426         ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
427     if (!status_or_host_std_str.ok()) {
428       lua_error(state_);
429       return 0;
430     }
431     PushString(status_or_host_std_str.ValueOrDie());
432   }
433   return 1;
434 }
435 
HandleHash()436 int JniLuaEnvironment::HandleHash() {
437   const StringPiece input = ReadString(kIndexStackTop);
438   lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
439   return 1;
440 }
441 
HandleFormat()442 int JniLuaEnvironment::HandleFormat() {
443   const int num_args = lua_gettop(state_);
444   std::vector<StringPiece> args(num_args - 1);
445   for (int i = 0; i < num_args - 1; i++) {
446     args[i] = ReadString(/*index=*/i + 2);
447   }
448   PushString(strings::Substitute(ReadString(/*index=*/1), args));
449   return 1;
450 }
451 
LookupModelStringResource() const452 bool JniLuaEnvironment::LookupModelStringResource() const {
453   // Handle only lookup by name.
454   if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
455     return false;
456   }
457 
458   const StringPiece resource_name = ReadString(kIndexStackTop);
459   std::string resource_content;
460   if (!resources_.GetResourceContent(device_locales_, resource_name,
461                                      &resource_content)) {
462     // Resource cannot be provided by the model.
463     return false;
464   }
465 
466   PushString(resource_content);
467   return true;
468 }
469 
HandleAndroidStringResources()470 int JniLuaEnvironment::HandleAndroidStringResources() {
471   // Check whether the requested resource can be served from the model data.
472   if (LookupModelStringResource()) {
473     return 1;
474   }
475 
476   // Get system resources if not previously retrieved.
477   if (!RetrieveSystemResources()) {
478     TC3_LOG(ERROR) << "Error retrieving system resources.";
479     lua_error(state_);
480     return 0;
481   }
482 
483   int resource_id;
484   switch (lua_type(state_, kIndexStackTop)) {
485     case LUA_TNUMBER:
486       resource_id = Read<int>(/*index=*/kIndexStackTop);
487       break;
488     case LUA_TSTRING: {
489       const StringPiece resource_name_str = ReadString(kIndexStackTop);
490       if (resource_name_str.empty()) {
491         TC3_LOG(ERROR) << "No resource name provided.";
492         lua_error(state_);
493         return 0;
494       }
495       const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
496           jni_cache_->ConvertToJavaString(resource_name_str);
497       if (!status_or_resource_name.ok()) {
498         TC3_LOG(ERROR) << "Invalid resource name.";
499         lua_error(state_);
500         return 0;
501       }
502       StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
503           jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
504           status_or_resource_name.ValueOrDie().get(), string_.get(),
505           android_.get());
506       if (!status_or_resource_id.ok()) {
507         TC3_LOG(ERROR) << "Error calling getIdentifier.";
508         lua_error(state_);
509         return 0;
510       }
511       resource_id = status_or_resource_id.ValueOrDie();
512       break;
513     }
514     default:
515       TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
516       lua_error(state_);
517       return 0;
518   }
519   if (resource_id == 0) {
520     TC3_LOG(ERROR) << "Resource not found.";
521     lua_pushnil(state_);
522     return 1;
523   }
524   StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
525       JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
526                                            jni_cache_->resources_get_string,
527                                            resource_id);
528   if (!status_or_resource_str.ok()) {
529     TC3_LOG(ERROR) << "Error calling getString.";
530     lua_error(state_);
531     return 0;
532   }
533 
534   if (status_or_resource_str.ValueOrDie() == nullptr) {
535     lua_pushnil(state_);
536   } else {
537     StatusOr<std::string> status_or_resource_std_str =
538         ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
539     if (!status_or_resource_std_str.ok()) {
540       lua_error(state_);
541       return 0;
542     }
543     PushString(status_or_resource_std_str.ValueOrDie());
544   }
545   return 1;
546 }
547 
RetrieveSystemResources()548 bool JniLuaEnvironment::RetrieveSystemResources() {
549   if (system_resources_resources_retrieved_) {
550     return (system_resources_ != nullptr);
551   }
552   system_resources_resources_retrieved_ = true;
553   TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
554                              JniHelper::CallStaticObjectMethod(
555                                  jenv_, jni_cache_->resources_class.get(),
556                                  jni_cache_->resources_get_system));
557   system_resources_ =
558       MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
559   return (system_resources_ != nullptr);
560 }
561 
RetrieveUserManager()562 bool JniLuaEnvironment::RetrieveUserManager() {
563   if (context_ == nullptr) {
564     return false;
565   }
566   if (usermanager_retrieved_) {
567     return (usermanager_ != nullptr);
568   }
569   usermanager_retrieved_ = true;
570   TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
571                              JniHelper::NewStringUTF(jenv_, "user"));
572   TC3_ASSIGN_OR_RETURN_FALSE(
573       const ScopedLocalRef<jobject> usermanager_ref,
574       JniHelper::CallObjectMethod(jenv_, context_,
575                                   jni_cache_->context_get_system_service,
576                                   service.get()));
577 
578   usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
579   return (usermanager_ != nullptr);
580 }
581 
ReadRemoteActionTemplateResult() const582 RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
583   RemoteActionTemplate result;
584   // Read intent template.
585   lua_pushnil(state_);
586   while (Next(/*index=*/-2)) {
587     const StringPiece key = ReadString(/*index=*/-2);
588     if (key.Equals("title_without_entity")) {
589       result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
590     } else if (key.Equals("title_with_entity")) {
591       result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
592     } else if (key.Equals("description")) {
593       result.description = Read<std::string>(/*index=*/kIndexStackTop);
594     } else if (key.Equals("description_with_app_name")) {
595       result.description_with_app_name =
596           Read<std::string>(/*index=*/kIndexStackTop);
597     } else if (key.Equals("action")) {
598       result.action = Read<std::string>(/*index=*/kIndexStackTop);
599     } else if (key.Equals("data")) {
600       result.data = Read<std::string>(/*index=*/kIndexStackTop);
601     } else if (key.Equals("type")) {
602       result.type = Read<std::string>(/*index=*/kIndexStackTop);
603     } else if (key.Equals("flags")) {
604       result.flags = Read<int>(/*index=*/kIndexStackTop);
605     } else if (key.Equals("package_name")) {
606       result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
607     } else if (key.Equals("request_code")) {
608       result.request_code = Read<int>(/*index=*/kIndexStackTop);
609     } else if (key.Equals("category")) {
610       result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
611     } else if (key.Equals("extra")) {
612       result.extra = ReadExtras();
613     } else {
614       TC3_LOG(INFO) << "Unknown entry: " << key;
615     }
616     lua_pop(state_, 1);
617   }
618   lua_pop(state_, 1);
619   return result;
620 }
621 
ReadExtras() const622 std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
623   if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
624     TC3_LOG(ERROR) << "Expected extras table, got: "
625                    << lua_type(state_, kIndexStackTop);
626     lua_pop(state_, 1);
627     return {};
628   }
629   std::map<std::string, Variant> extras;
630   lua_pushnil(state_);
631   while (Next(/*index=*/-2)) {
632     // Each entry is a table specifying name and value.
633     // The value is specified via a type specific field as Lua doesn't allow
634     // to easily distinguish between different number types.
635     if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
636       TC3_LOG(ERROR) << "Expected a table for an extra, got: "
637                      << lua_type(state_, kIndexStackTop);
638       lua_pop(state_, 1);
639       return {};
640     }
641     std::string name;
642     Variant value;
643 
644     lua_pushnil(state_);
645     while (Next(/*index=*/-2)) {
646       const StringPiece key = ReadString(/*index=*/-2);
647       if (key.Equals("name")) {
648         name = Read<std::string>(/*index=*/kIndexStackTop);
649       } else if (key.Equals("int_value")) {
650         value = Variant(Read<int>(/*index=*/kIndexStackTop));
651       } else if (key.Equals("long_value")) {
652         value = Variant(Read<int64>(/*index=*/kIndexStackTop));
653       } else if (key.Equals("float_value")) {
654         value = Variant(Read<float>(/*index=*/kIndexStackTop));
655       } else if (key.Equals("bool_value")) {
656         value = Variant(Read<bool>(/*index=*/kIndexStackTop));
657       } else if (key.Equals("string_value")) {
658         value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
659       } else if (key.Equals("string_array_value")) {
660         value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
661       } else if (key.Equals("float_array_value")) {
662         value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
663       } else if (key.Equals("int_array_value")) {
664         value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
665       } else if (key.Equals("named_variant_array_value")) {
666         value = Variant(ReadExtras());
667       } else {
668         TC3_LOG(INFO) << "Unknown extra field: " << key;
669       }
670       lua_pop(state_, 1);
671     }
672     if (!name.empty()) {
673       extras[name] = value;
674     } else {
675       TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
676     }
677     lua_pop(state_, 1);
678   }
679   return extras;
680 }
681 
ReadRemoteActionTemplates(std::vector<RemoteActionTemplate> * result)682 int JniLuaEnvironment::ReadRemoteActionTemplates(
683     std::vector<RemoteActionTemplate>* result) {
684   // Read result.
685   if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
686     TC3_LOG(ERROR) << "Unexpected result for snippet: "
687                    << lua_type(state_, kIndexStackTop);
688     lua_error(state_);
689     return LUA_ERRRUN;
690   }
691 
692   // Read remote action templates array.
693   lua_pushnil(state_);
694   while (Next(/*index=*/-2)) {
695     if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
696       TC3_LOG(ERROR) << "Expected intent table, got: "
697                      << lua_type(state_, kIndexStackTop);
698       lua_pop(state_, 1);
699       continue;
700     }
701     result->push_back(ReadRemoteActionTemplateResult());
702   }
703   lua_pop(state_, /*n=*/1);
704   return LUA_OK;
705 }
706 
RunIntentGenerator(const std::string & generator_snippet,std::vector<RemoteActionTemplate> * remote_actions)707 bool JniLuaEnvironment::RunIntentGenerator(
708     const std::string& generator_snippet,
709     std::vector<RemoteActionTemplate>* remote_actions) {
710   int status;
711   status = luaL_loadbuffer(state_, generator_snippet.data(),
712                            generator_snippet.size(),
713                            /*name=*/nullptr);
714   if (status != LUA_OK) {
715     TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
716     return false;
717   }
718   status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
719   if (status != LUA_OK) {
720     TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
721     return false;
722   }
723   if (RunProtected(
724           [this, remote_actions] {
725             return ReadRemoteActionTemplates(remote_actions);
726           },
727           /*num_args=*/1) != LUA_OK) {
728     TC3_LOG(ERROR) << "Could not read results.";
729     return false;
730   }
731   // Check that we correctly cleaned-up the state.
732   const int stack_size = lua_gettop(state_);
733   if (stack_size > 0) {
734     TC3_LOG(ERROR) << "Unexpected stack size.";
735     lua_settop(state_, 0);
736     return false;
737   }
738   return true;
739 }
740 
741 // Lua environment for classfication result intent generation.
742 class AnnotatorJniEnvironment : public JniLuaEnvironment {
743  public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema)744   AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
745                           const jobject context,
746                           const std::vector<Locale>& device_locales,
747                           const std::string& entity_text,
748                           const ClassificationResult& classification,
749                           const int64 reference_time_ms_utc,
750                           const reflection::Schema* entity_data_schema)
751       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
752         entity_text_(entity_text),
753         classification_(classification),
754         reference_time_ms_utc_(reference_time_ms_utc),
755         entity_data_schema_(entity_data_schema) {}
756 
757  protected:
SetupExternalHook()758   void SetupExternalHook() override {
759     JniLuaEnvironment::SetupExternalHook();
760     lua_pushinteger(state_, reference_time_ms_utc_);
761     lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
762 
763     PushAnnotation(classification_, entity_text_, entity_data_schema_);
764     lua_setfield(state_, /*idx=*/-2, "entity");
765   }
766 
767   const std::string& entity_text_;
768   const ClassificationResult& classification_;
769   const int64 reference_time_ms_utc_;
770 
771   // Reflection schema data.
772   const reflection::Schema* const entity_data_schema_;
773 };
774 
775 // Lua environment for actions intent generation.
776 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
777  public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)778   ActionsJniLuaEnvironment(
779       const Resources& resources, const JniCache* jni_cache,
780       const jobject context, const std::vector<Locale>& device_locales,
781       const Conversation& conversation, const ActionSuggestion& action,
782       const reflection::Schema* actions_entity_data_schema,
783       const reflection::Schema* annotations_entity_data_schema)
784       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
785         conversation_(conversation),
786         action_(action),
787         actions_entity_data_schema_(actions_entity_data_schema),
788         annotations_entity_data_schema_(annotations_entity_data_schema) {}
789 
790  protected:
SetupExternalHook()791   void SetupExternalHook() override {
792     JniLuaEnvironment::SetupExternalHook();
793     PushConversation(&conversation_.messages, annotations_entity_data_schema_);
794     lua_setfield(state_, /*idx=*/-2, "conversation");
795 
796     PushAction(action_, actions_entity_data_schema_,
797                annotations_entity_data_schema_);
798     lua_setfield(state_, /*idx=*/-2, "entity");
799   }
800 
801   const Conversation& conversation_;
802   const ActionSuggestion& action_;
803   const reflection::Schema* actions_entity_data_schema_;
804   const reflection::Schema* annotations_entity_data_schema_;
805 };
806 
807 }  // namespace
808 
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)809 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
810     const IntentFactoryModel* options, const ResourcePool* resources,
811     const std::shared_ptr<JniCache>& jni_cache) {
812   std::unique_ptr<IntentGenerator> intent_generator(
813       new IntentGenerator(options, resources, jni_cache));
814 
815   if (options == nullptr || options->generator() == nullptr) {
816     TC3_LOG(ERROR) << "No intent generator options.";
817     return nullptr;
818   }
819 
820   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
821       ZlibDecompressor::Instance();
822   if (!zlib_decompressor) {
823     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
824     return nullptr;
825   }
826 
827   for (const IntentFactoryModel_::IntentGenerator* generator :
828        *options->generator()) {
829     std::string lua_template_generator;
830     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
831             generator->lua_template_generator(),
832             generator->compressed_lua_template_generator(),
833             &lua_template_generator)) {
834       TC3_LOG(ERROR) << "Could not decompress generator template.";
835       return nullptr;
836     }
837 
838     std::string lua_code = lua_template_generator;
839     if (options->precompile_generators()) {
840       if (!Compile(lua_template_generator, &lua_code)) {
841         TC3_LOG(ERROR) << "Could not precompile generator template.";
842         return nullptr;
843       }
844     }
845 
846     intent_generator->generators_[generator->type()->str()] = lua_code;
847   }
848 
849   return intent_generator;
850 }
851 
ParseDeviceLocales(const jstring device_locales) const852 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
853     const jstring device_locales) const {
854   if (device_locales == nullptr) {
855     TC3_LOG(ERROR) << "No locales provided.";
856     return {};
857   }
858   ScopedStringChars locales_str =
859       GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
860   if (locales_str == nullptr) {
861     TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
862     return {};
863   }
864   std::vector<Locale> locales;
865   if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
866                     &locales)) {
867     TC3_LOG(ERROR) << "Cannot parse locales.";
868     return {};
869   }
870   return locales;
871 }
872 
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const873 bool IntentGenerator::GenerateIntents(
874     const jstring device_locales, const ClassificationResult& classification,
875     const int64 reference_time_ms_utc, const std::string& text,
876     const CodepointSpan selection_indices, const jobject context,
877     const reflection::Schema* annotations_entity_data_schema,
878     std::vector<RemoteActionTemplate>* remote_actions) const {
879   if (options_ == nullptr) {
880     return false;
881   }
882 
883   // Retrieve generator for specified entity.
884   auto it = generators_.find(classification.collection);
885   if (it == generators_.end()) {
886     TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
887     return true;
888   }
889 
890   const std::string entity_text =
891       UTF8ToUnicodeText(text, /*do_copy=*/false)
892           .UTF8Substring(selection_indices.first, selection_indices.second);
893 
894   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
895       new AnnotatorJniEnvironment(
896           resources_, jni_cache_.get(), context,
897           ParseDeviceLocales(device_locales), entity_text, classification,
898           reference_time_ms_utc, annotations_entity_data_schema));
899 
900   if (!interpreter->Initialize()) {
901     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
902     return false;
903   }
904 
905   return interpreter->RunIntentGenerator(it->second, remote_actions);
906 }
907 
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const908 bool IntentGenerator::GenerateIntents(
909     const jstring device_locales, const ActionSuggestion& action,
910     const Conversation& conversation, const jobject context,
911     const reflection::Schema* annotations_entity_data_schema,
912     const reflection::Schema* actions_entity_data_schema,
913     std::vector<RemoteActionTemplate>* remote_actions) const {
914   if (options_ == nullptr) {
915     return false;
916   }
917 
918   // Retrieve generator for specified action.
919   auto it = generators_.find(action.type);
920   if (it == generators_.end()) {
921     return true;
922   }
923 
924   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
925       new ActionsJniLuaEnvironment(
926           resources_, jni_cache_.get(), context,
927           ParseDeviceLocales(device_locales), conversation, action,
928           actions_entity_data_schema, annotations_entity_data_schema));
929 
930   if (!interpreter->Initialize()) {
931     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
932     return false;
933   }
934 
935   return interpreter->RunIntentGenerator(it->second, remote_actions);
936 }
937 
938 }  // namespace libtextclassifier3
939