• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/intents/intent-generator.h"
18 
19 #include <vector>
20 
21 #include "actions/lua-utils.h"
22 #include "actions/types.h"
23 #include "annotator/types.h"
24 #include "utils/base/logging.h"
25 #include "utils/hash/farmhash.h"
26 #include "utils/java/jni-base.h"
27 #include "utils/java/string_utils.h"
28 #include "utils/lua-utils.h"
29 #include "utils/strings/stringpiece.h"
30 #include "utils/strings/substitute.h"
31 #include "utils/utf8/unicodetext.h"
32 #include "utils/variant.h"
33 #include "utils/zlib/zlib.h"
34 #include "flatbuffers/reflection_generated.h"
35 
36 #ifdef __cplusplus
37 extern "C" {
38 #endif
39 #include "lauxlib.h"
40 #include "lua.h"
41 #ifdef __cplusplus
42 }
43 #endif
44 
45 namespace libtextclassifier3 {
46 namespace {
47 
// Field names exposed on the `external` table.
// NOTE: despite the "Usec" in its identifier, the key (and the value stored
// under it) is the reference time in milliseconds since the epoch, UTC.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kHashKey = "hash";
constexpr const char* kFormatKey = "format";

// Field names exposed on the `external.android` table.
constexpr const char* kUrlSchemaKey = "url_schema";
constexpr const char* kUrlHostKey = "url_host";
constexpr const char* kUrlEncodeKey = "urlencode";
constexpr const char* kPackageNameKey = "package_name";
constexpr const char* kDeviceLocaleKey = "device_locales";
56 
57 // An Android specific Lua environment with JNI backed callbacks.
58 class JniLuaEnvironment : public LuaEnvironment {
59  public:
60   JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
61                     const jobject context,
62                     const std::vector<Locale>& device_locales);
63   // Environment setup.
64   bool Initialize();
65 
66   // Runs an intent generator snippet.
67   bool RunIntentGenerator(const std::string& generator_snippet,
68                           std::vector<RemoteActionTemplate>* remote_actions);
69 
70  protected:
71   virtual void SetupExternalHook();
72 
73   int HandleExternalCallback();
74   int HandleAndroidCallback();
75   int HandleUserRestrictionsCallback();
76   int HandleUrlEncode();
77   int HandleUrlSchema();
78   int HandleHash();
79   int HandleFormat();
80   int HandleAndroidStringResources();
81   int HandleUrlHost();
82 
83   // Checks and retrieves string resources from the model.
84   bool LookupModelStringResource();
85 
86   // Reads and create a RemoteAction result from Lua.
87   RemoteActionTemplate ReadRemoteActionTemplateResult();
88 
89   // Reads the extras from the Lua result.
90   void ReadExtras(std::map<std::string, Variant>* extra);
91 
92   // Reads the intent categories array from a Lua result.
93   void ReadCategories(std::vector<std::string>* category);
94 
95   // Retrieves user manager if not previously done.
96   bool RetrieveUserManager();
97 
98   // Retrieves system resources if not previously done.
99   bool RetrieveSystemResources();
100 
101   // Parse the url string by using Uri.parse from Java.
102   ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
103 
104   // Read remote action templates from lua generator.
105   int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
106 
107   const Resources& resources_;
108   JNIEnv* jenv_;
109   const JniCache* jni_cache_;
110   const jobject context_;
111   std::vector<Locale> device_locales_;
112 
113   ScopedGlobalRef<jobject> usermanager_;
114   // Whether we previously attempted to retrieve the UserManager before.
115   bool usermanager_retrieved_;
116 
117   ScopedGlobalRef<jobject> system_resources_;
118   // Whether we previously attempted to retrieve the system resources.
119   bool system_resources_resources_retrieved_;
120 
121   // Cached JNI references for Java strings `string` and `android`.
122   ScopedGlobalRef<jstring> string_;
123   ScopedGlobalRef<jstring> android_;
124 };
125 
JniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales)126 JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
127                                      const JniCache* jni_cache,
128                                      const jobject context,
129                                      const std::vector<Locale>& device_locales)
130     : resources_(resources),
131       jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
132       jni_cache_(jni_cache),
133       context_(context),
134       device_locales_(device_locales),
135       usermanager_(/*object=*/nullptr,
136                    /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
137       usermanager_retrieved_(false),
138       system_resources_(/*object=*/nullptr,
139                         /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
140       system_resources_resources_retrieved_(false),
141       string_(/*object=*/nullptr,
142               /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
143       android_(/*object=*/nullptr,
144                /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
145 
Initialize()146 bool JniLuaEnvironment::Initialize() {
147   string_ =
148       MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
149   android_ =
150       MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
151   if (string_ == nullptr || android_ == nullptr) {
152     TC3_LOG(ERROR) << "Could not allocate constant strings references.";
153     return false;
154   }
155   return (RunProtected([this] {
156             LoadDefaultLibraries();
157             SetupExternalHook();
158             lua_setglobal(state_, "external");
159             return LUA_OK;
160           }) == LUA_OK);
161 }
162 
SetupExternalHook()163 void JniLuaEnvironment::SetupExternalHook() {
164   // This exposes an `external` object with the following fields:
165   //   * entity: the bundle with all information about a classification.
166   //   * android: callbacks into specific android provided methods.
167   //   * android.user_restrictions: callbacks to check user permissions.
168   //   * android.R: callbacks to retrieve string resources.
169   BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
170       "external");
171 
172   // android
173   BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
174       "android");
175   {
176     // android.user_restrictions
177     BindTable<JniLuaEnvironment,
178               &JniLuaEnvironment::HandleUserRestrictionsCallback>(
179         "user_restrictions");
180     lua_setfield(state_, /*idx=*/-2, "user_restrictions");
181 
182     // android.R
183     // Callback to access android string resources.
184     BindTable<JniLuaEnvironment,
185               &JniLuaEnvironment::HandleAndroidStringResources>("R");
186     lua_setfield(state_, /*idx=*/-2, "R");
187   }
188   lua_setfield(state_, /*idx=*/-2, "android");
189 }
190 
HandleExternalCallback()191 int JniLuaEnvironment::HandleExternalCallback() {
192   const StringPiece key = ReadString(/*index=*/-1);
193   if (key.Equals(kHashKey)) {
194     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
195     return 1;
196   } else if (key.Equals(kFormatKey)) {
197     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
198     return 1;
199   } else {
200     TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
201     lua_error(state_);
202     return 0;
203   }
204 }
205 
HandleAndroidCallback()206 int JniLuaEnvironment::HandleAndroidCallback() {
207   const StringPiece key = ReadString(/*index=*/-1);
208   if (key.Equals(kDeviceLocaleKey)) {
209     // Provide the locale as table with the individual fields set.
210     lua_newtable(state_);
211     for (int i = 0; i < device_locales_.size(); i++) {
212       // Adjust index to 1-based indexing for Lua.
213       lua_pushinteger(state_, i + 1);
214       lua_newtable(state_);
215       PushString(device_locales_[i].Language());
216       lua_setfield(state_, -2, "language");
217       PushString(device_locales_[i].Region());
218       lua_setfield(state_, -2, "region");
219       PushString(device_locales_[i].Script());
220       lua_setfield(state_, -2, "script");
221       lua_settable(state_, /*idx=*/-3);
222     }
223     return 1;
224   } else if (key.Equals(kPackageNameKey)) {
225     if (context_ == nullptr) {
226       TC3_LOG(ERROR) << "Context invalid.";
227       lua_error(state_);
228       return 0;
229     }
230     ScopedLocalRef<jstring> package_name_str(
231         static_cast<jstring>(jenv_->CallObjectMethod(
232             context_, jni_cache_->context_get_package_name)));
233     if (jni_cache_->ExceptionCheckAndClear()) {
234       TC3_LOG(ERROR) << "Error calling Context.getPackageName";
235       lua_error(state_);
236       return 0;
237     }
238     PushString(ToStlString(jenv_, package_name_str.get()));
239     return 1;
240   } else if (key.Equals(kUrlEncodeKey)) {
241     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
242     return 1;
243   } else if (key.Equals(kUrlHostKey)) {
244     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
245     return 1;
246   } else if (key.Equals(kUrlSchemaKey)) {
247     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
248     return 1;
249   } else {
250     TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
251     lua_error(state_);
252     return 0;
253   }
254 }
255 
HandleUserRestrictionsCallback()256 int JniLuaEnvironment::HandleUserRestrictionsCallback() {
257   if (jni_cache_->usermanager_class == nullptr ||
258       jni_cache_->usermanager_get_user_restrictions == nullptr) {
259     // UserManager is only available for API level >= 17 and
260     // getUserRestrictions only for API level >= 18, so we just return false
261     // normally here.
262     lua_pushboolean(state_, false);
263     return 1;
264   }
265 
266   // Get user manager if not previously retrieved.
267   if (!RetrieveUserManager()) {
268     TC3_LOG(ERROR) << "Error retrieving user manager.";
269     lua_error(state_);
270     return 0;
271   }
272 
273   ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
274       usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
275   if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
276     TC3_LOG(ERROR) << "Error calling getUserRestrictions";
277     lua_error(state_);
278     return 0;
279   }
280 
281   const StringPiece key_str = ReadString(/*index=*/-1);
282   if (key_str.empty()) {
283     TC3_LOG(ERROR) << "Expected string, got null.";
284     lua_error(state_);
285     return 0;
286   }
287 
288   ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
289   if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
290     TC3_LOG(ERROR) << "Expected string, got null.";
291     lua_error(state_);
292     return 0;
293   }
294   const bool permission = jenv_->CallBooleanMethod(
295       bundle.get(), jni_cache_->bundle_get_boolean, key.get());
296   if (jni_cache_->ExceptionCheckAndClear()) {
297     TC3_LOG(ERROR) << "Error getting bundle value";
298     lua_pushboolean(state_, false);
299   } else {
300     lua_pushboolean(state_, permission);
301   }
302   return 1;
303 }
304 
HandleUrlEncode()305 int JniLuaEnvironment::HandleUrlEncode() {
306   const StringPiece input = ReadString(/*index=*/1);
307   if (input.empty()) {
308     TC3_LOG(ERROR) << "Expected string, got null.";
309     lua_error(state_);
310     return 0;
311   }
312 
313   // Call Java URL encoder.
314   ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
315   if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
316     TC3_LOG(ERROR) << "Expected string, got null.";
317     lua_error(state_);
318     return 0;
319   }
320   ScopedLocalRef<jstring> encoded_str(
321       static_cast<jstring>(jenv_->CallStaticObjectMethod(
322           jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
323           input_str.get(), jni_cache_->string_utf8.get())));
324   if (jni_cache_->ExceptionCheckAndClear()) {
325     TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
326     lua_error(state_);
327     return 0;
328   }
329   PushString(ToStlString(jenv_, encoded_str.get()));
330   return 1;
331 }
332 
ParseUri(StringPiece url) const333 ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
334   if (url.empty()) {
335     return nullptr;
336   }
337 
338   // Call to Java URI parser.
339   ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
340   if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
341     TC3_LOG(ERROR) << "Expected string, got null";
342     return nullptr;
343   }
344 
345   // Try to parse uri and get scheme.
346   ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
347       jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
348   if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
349     TC3_LOG(ERROR) << "Error calling Uri.parse";
350   }
351   return uri;
352 }
353 
HandleUrlSchema()354 int JniLuaEnvironment::HandleUrlSchema() {
355   StringPiece url = ReadString(/*index=*/1);
356 
357   ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
358   if (parsed_uri == nullptr) {
359     lua_error(state_);
360     return 0;
361   }
362 
363   ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
364       jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
365   if (jni_cache_->ExceptionCheckAndClear()) {
366     TC3_LOG(ERROR) << "Error calling Uri.getScheme";
367     lua_error(state_);
368     return 0;
369   }
370   if (scheme_str == nullptr) {
371     lua_pushnil(state_);
372   } else {
373     PushString(ToStlString(jenv_, scheme_str.get()));
374   }
375   return 1;
376 }
377 
HandleUrlHost()378 int JniLuaEnvironment::HandleUrlHost() {
379   StringPiece url = ReadString(/*index=*/-1);
380 
381   ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
382   if (parsed_uri == nullptr) {
383     lua_error(state_);
384     return 0;
385   }
386 
387   ScopedLocalRef<jstring> host_str(static_cast<jstring>(
388       jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
389   if (jni_cache_->ExceptionCheckAndClear()) {
390     TC3_LOG(ERROR) << "Error calling Uri.getHost";
391     lua_error(state_);
392     return 0;
393   }
394   if (host_str == nullptr) {
395     lua_pushnil(state_);
396   } else {
397     PushString(ToStlString(jenv_, host_str.get()));
398   }
399   return 1;
400 }
401 
HandleHash()402 int JniLuaEnvironment::HandleHash() {
403   const StringPiece input = ReadString(/*index=*/-1);
404   lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
405   return 1;
406 }
407 
HandleFormat()408 int JniLuaEnvironment::HandleFormat() {
409   const int num_args = lua_gettop(state_);
410   std::vector<StringPiece> args(num_args - 1);
411   for (int i = 0; i < num_args - 1; i++) {
412     args[i] = ReadString(/*index=*/i + 2);
413   }
414   PushString(strings::Substitute(ReadString(/*index=*/1), args));
415   return 1;
416 }
417 
LookupModelStringResource()418 bool JniLuaEnvironment::LookupModelStringResource() {
419   // Handle only lookup by name.
420   if (lua_type(state_, 2) != LUA_TSTRING) {
421     return false;
422   }
423 
424   const StringPiece resource_name = ReadString(/*index=*/-1);
425   std::string resource_content;
426   if (!resources_.GetResourceContent(device_locales_, resource_name,
427                                      &resource_content)) {
428     // Resource cannot be provided by the model.
429     return false;
430   }
431 
432   PushString(resource_content);
433   return true;
434 }
435 
HandleAndroidStringResources()436 int JniLuaEnvironment::HandleAndroidStringResources() {
437   // Check whether the requested resource can be served from the model data.
438   if (LookupModelStringResource()) {
439     return 1;
440   }
441 
442   // Get system resources if not previously retrieved.
443   if (!RetrieveSystemResources()) {
444     TC3_LOG(ERROR) << "Error retrieving system resources.";
445     lua_error(state_);
446     return 0;
447   }
448 
449   int resource_id;
450   switch (lua_type(state_, -1)) {
451     case LUA_TNUMBER:
452       resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
453       break;
454     case LUA_TSTRING: {
455       const StringPiece resource_name_str = ReadString(/*index=*/-1);
456       if (resource_name_str.empty()) {
457         TC3_LOG(ERROR) << "No resource name provided.";
458         lua_error(state_);
459         return 0;
460       }
461       ScopedLocalRef<jstring> resource_name =
462           jni_cache_->ConvertToJavaString(resource_name_str);
463       if (resource_name == nullptr) {
464         TC3_LOG(ERROR) << "Invalid resource name.";
465         lua_error(state_);
466         return 0;
467       }
468       resource_id = jenv_->CallIntMethod(
469           system_resources_.get(), jni_cache_->resources_get_identifier,
470           resource_name.get(), string_.get(), android_.get());
471       if (jni_cache_->ExceptionCheckAndClear()) {
472         TC3_LOG(ERROR) << "Error calling getIdentifier.";
473         lua_error(state_);
474         return 0;
475       }
476       break;
477     }
478     default:
479       TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
480       lua_error(state_);
481       return 0;
482   }
483   if (resource_id == 0) {
484     TC3_LOG(ERROR) << "Resource not found.";
485     lua_pushnil(state_);
486     return 1;
487   }
488   ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
489       jenv_->CallObjectMethod(system_resources_.get(),
490                               jni_cache_->resources_get_string, resource_id)));
491   if (jni_cache_->ExceptionCheckAndClear()) {
492     TC3_LOG(ERROR) << "Error calling getString.";
493     lua_error(state_);
494     return 0;
495   }
496   if (resource_str == nullptr) {
497     lua_pushnil(state_);
498   } else {
499     PushString(ToStlString(jenv_, resource_str.get()));
500   }
501   return 1;
502 }
503 
RetrieveSystemResources()504 bool JniLuaEnvironment::RetrieveSystemResources() {
505   if (system_resources_resources_retrieved_) {
506     return (system_resources_ != nullptr);
507   }
508   system_resources_resources_retrieved_ = true;
509   jobject system_resources_ref = jenv_->CallStaticObjectMethod(
510       jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
511   if (jni_cache_->ExceptionCheckAndClear()) {
512     TC3_LOG(ERROR) << "Error calling getSystem.";
513     return false;
514   }
515   system_resources_ =
516       MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
517   return (system_resources_ != nullptr);
518 }
519 
RetrieveUserManager()520 bool JniLuaEnvironment::RetrieveUserManager() {
521   if (context_ == nullptr) {
522     return false;
523   }
524   if (usermanager_retrieved_) {
525     return (usermanager_ != nullptr);
526   }
527   usermanager_retrieved_ = true;
528   ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
529   jobject usermanager_ref = jenv_->CallObjectMethod(
530       context_, jni_cache_->context_get_system_service, service.get());
531   if (jni_cache_->ExceptionCheckAndClear()) {
532     TC3_LOG(ERROR) << "Error calling getSystemService.";
533     return false;
534   }
535   usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
536   return (usermanager_ != nullptr);
537 }
538 
ReadRemoteActionTemplateResult()539 RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
540   RemoteActionTemplate result;
541   // Read intent template.
542   lua_pushnil(state_);
543   while (lua_next(state_, /*idx=*/-2)) {
544     const StringPiece key = ReadString(/*index=*/-2);
545     if (key.Equals("title_without_entity")) {
546       result.title_without_entity = ReadString(/*index=*/-1).ToString();
547     } else if (key.Equals("title_with_entity")) {
548       result.title_with_entity = ReadString(/*index=*/-1).ToString();
549     } else if (key.Equals("description")) {
550       result.description = ReadString(/*index=*/-1).ToString();
551     } else if (key.Equals("description_with_app_name")) {
552       result.description_with_app_name = ReadString(/*index=*/-1).ToString();
553     } else if (key.Equals("action")) {
554       result.action = ReadString(/*index=*/-1).ToString();
555     } else if (key.Equals("data")) {
556       result.data = ReadString(/*index=*/-1).ToString();
557     } else if (key.Equals("type")) {
558       result.type = ReadString(/*index=*/-1).ToString();
559     } else if (key.Equals("flags")) {
560       result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
561     } else if (key.Equals("package_name")) {
562       result.package_name = ReadString(/*index=*/-1).ToString();
563     } else if (key.Equals("request_code")) {
564       result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
565     } else if (key.Equals("category")) {
566       ReadCategories(&result.category);
567     } else if (key.Equals("extra")) {
568       ReadExtras(&result.extra);
569     } else {
570       TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
571     }
572     lua_pop(state_, 1);
573   }
574   lua_pop(state_, 1);
575   return result;
576 }
577 
ReadCategories(std::vector<std::string> * category)578 void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
579   // Read category array.
580   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
581     TC3_LOG(ERROR) << "Expected categories table, got: "
582                    << lua_type(state_, /*idx=*/-1);
583     lua_pop(state_, 1);
584     return;
585   }
586   lua_pushnil(state_);
587   while (lua_next(state_, /*idx=*/-2)) {
588     category->push_back(ReadString(/*index=*/-1).ToString());
589     lua_pop(state_, 1);
590   }
591 }
592 
ReadExtras(std::map<std::string,Variant> * extra)593 void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
594   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
595     TC3_LOG(ERROR) << "Expected extras table, got: "
596                    << lua_type(state_, /*idx=*/-1);
597     lua_pop(state_, 1);
598     return;
599   }
600   lua_pushnil(state_);
601   while (lua_next(state_, /*idx=*/-2)) {
602     // Each entry is a table specifying name and value.
603     // The value is specified via a type specific field as Lua doesn't allow
604     // to easily distinguish between different number types.
605     if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
606       TC3_LOG(ERROR) << "Expected a table for an extra, got: "
607                      << lua_type(state_, /*idx=*/-1);
608       lua_pop(state_, 1);
609       return;
610     }
611     std::string name;
612     Variant value;
613 
614     lua_pushnil(state_);
615     while (lua_next(state_, /*idx=*/-2)) {
616       const StringPiece key = ReadString(/*index=*/-2);
617       if (key.Equals("name")) {
618         name = ReadString(/*index=*/-1).ToString();
619       } else if (key.Equals("int_value")) {
620         value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
621       } else if (key.Equals("long_value")) {
622         value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
623       } else if (key.Equals("float_value")) {
624         value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
625       } else if (key.Equals("bool_value")) {
626         value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
627       } else if (key.Equals("string_value")) {
628         value = Variant(ReadString(/*index=*/-1).ToString());
629       } else {
630         TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
631       }
632       lua_pop(state_, 1);
633     }
634     if (!name.empty()) {
635       (*extra)[name] = value;
636     } else {
637       TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
638     }
639     lua_pop(state_, 1);
640   }
641 }
642 
ReadRemoteActionTemplates(std::vector<RemoteActionTemplate> * result)643 int JniLuaEnvironment::ReadRemoteActionTemplates(
644     std::vector<RemoteActionTemplate>* result) {
645   // Read result.
646   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
647     TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
648     lua_error(state_);
649     return LUA_ERRRUN;
650   }
651 
652   // Read remote action templates array.
653   lua_pushnil(state_);
654   while (lua_next(state_, /*idx=*/-2)) {
655     if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
656       TC3_LOG(ERROR) << "Expected intent table, got: "
657                      << lua_type(state_, /*idx=*/-1);
658       lua_pop(state_, 1);
659       continue;
660     }
661     result->push_back(ReadRemoteActionTemplateResult());
662   }
663   lua_pop(state_, /*n=*/1);
664   return LUA_OK;
665 }
666 
RunIntentGenerator(const std::string & generator_snippet,std::vector<RemoteActionTemplate> * remote_actions)667 bool JniLuaEnvironment::RunIntentGenerator(
668     const std::string& generator_snippet,
669     std::vector<RemoteActionTemplate>* remote_actions) {
670   int status;
671   status = luaL_loadbuffer(state_, generator_snippet.data(),
672                            generator_snippet.size(),
673                            /*name=*/nullptr);
674   if (status != LUA_OK) {
675     TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
676     return false;
677   }
678   status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
679   if (status != LUA_OK) {
680     TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
681     return false;
682   }
683   if (RunProtected(
684           [this, remote_actions] {
685             return ReadRemoteActionTemplates(remote_actions);
686           },
687           /*num_args=*/1) != LUA_OK) {
688     TC3_LOG(ERROR) << "Could not read results.";
689     return false;
690   }
691   // Check that we correctly cleaned-up the state.
692   const int stack_size = lua_gettop(state_);
693   if (stack_size > 0) {
694     TC3_LOG(ERROR) << "Unexpected stack size.";
695     lua_settop(state_, 0);
696     return false;
697   }
698   return true;
699 }
700 
701 // Lua environment for classfication result intent generation.
702 class AnnotatorJniEnvironment : public JniLuaEnvironment {
703  public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema)704   AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
705                           const jobject context,
706                           const std::vector<Locale>& device_locales,
707                           const std::string& entity_text,
708                           const ClassificationResult& classification,
709                           const int64 reference_time_ms_utc,
710                           const reflection::Schema* entity_data_schema)
711       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
712         entity_text_(entity_text),
713         classification_(classification),
714         reference_time_ms_utc_(reference_time_ms_utc),
715         entity_data_schema_(entity_data_schema) {}
716 
717  protected:
SetupExternalHook()718   void SetupExternalHook() override {
719     JniLuaEnvironment::SetupExternalHook();
720     lua_pushinteger(state_, reference_time_ms_utc_);
721     lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
722 
723     PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
724     lua_setfield(state_, /*idx=*/-2, "entity");
725   }
726 
727   const std::string& entity_text_;
728   const ClassificationResult& classification_;
729   const int64 reference_time_ms_utc_;
730 
731   // Reflection schema data.
732   const reflection::Schema* const entity_data_schema_;
733 };
734 
735 // Lua environment for actions intent generation.
736 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
737  public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)738   ActionsJniLuaEnvironment(
739       const Resources& resources, const JniCache* jni_cache,
740       const jobject context, const std::vector<Locale>& device_locales,
741       const Conversation& conversation, const ActionSuggestion& action,
742       const reflection::Schema* actions_entity_data_schema,
743       const reflection::Schema* annotations_entity_data_schema)
744       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
745         conversation_(conversation),
746         action_(action),
747         annotation_iterator_(annotations_entity_data_schema, this),
748         conversation_iterator_(annotations_entity_data_schema, this),
749         entity_data_schema_(actions_entity_data_schema) {}
750 
751  protected:
SetupExternalHook()752   void SetupExternalHook() override {
753     JniLuaEnvironment::SetupExternalHook();
754     conversation_iterator_.NewIterator("conversation", &conversation_.messages,
755                                        state_);
756     lua_setfield(state_, /*idx=*/-2, "conversation");
757 
758     PushAction(action_, entity_data_schema_, annotation_iterator_, this);
759     lua_setfield(state_, /*idx=*/-2, "entity");
760   }
761 
762   const Conversation& conversation_;
763   const ActionSuggestion& action_;
764   const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
765   const ConversationIterator conversation_iterator_;
766   const reflection::Schema* entity_data_schema_;
767 };
768 
769 }  // namespace
770 
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)771 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
772     const IntentFactoryModel* options, const ResourcePool* resources,
773     const std::shared_ptr<JniCache>& jni_cache) {
774   std::unique_ptr<IntentGenerator> intent_generator(
775       new IntentGenerator(options, resources, jni_cache));
776 
777   if (options == nullptr || options->generator() == nullptr) {
778     TC3_LOG(ERROR) << "No intent generator options.";
779     return nullptr;
780   }
781 
782   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
783       ZlibDecompressor::Instance();
784   if (!zlib_decompressor) {
785     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
786     return nullptr;
787   }
788 
789   for (const IntentFactoryModel_::IntentGenerator* generator :
790        *options->generator()) {
791     std::string lua_template_generator;
792     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
793             generator->lua_template_generator(),
794             generator->compressed_lua_template_generator(),
795             &lua_template_generator)) {
796       TC3_LOG(ERROR) << "Could not decompress generator template.";
797       return nullptr;
798     }
799 
800     std::string lua_code = lua_template_generator;
801     if (options->precompile_generators()) {
802       if (!Compile(lua_template_generator, &lua_code)) {
803         TC3_LOG(ERROR) << "Could not precompile generator template.";
804         return nullptr;
805       }
806     }
807 
808     intent_generator->generators_[generator->type()->str()] = lua_code;
809   }
810 
811   return intent_generator;
812 }
813 
ParseDeviceLocales(const jstring device_locales) const814 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
815     const jstring device_locales) const {
816   if (device_locales == nullptr) {
817     TC3_LOG(ERROR) << "No locales provided.";
818     return {};
819   }
820   ScopedStringChars locales_str =
821       GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
822   if (locales_str == nullptr) {
823     TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
824     return {};
825   }
826   std::vector<Locale> locales;
827   if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
828                     &locales)) {
829     TC3_LOG(ERROR) << "Cannot parse locales.";
830     return {};
831   }
832   return locales;
833 }
834 
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const835 bool IntentGenerator::GenerateIntents(
836     const jstring device_locales, const ClassificationResult& classification,
837     const int64 reference_time_ms_utc, const std::string& text,
838     const CodepointSpan selection_indices, const jobject context,
839     const reflection::Schema* annotations_entity_data_schema,
840     std::vector<RemoteActionTemplate>* remote_actions) const {
841   if (options_ == nullptr) {
842     return false;
843   }
844 
845   // Retrieve generator for specified entity.
846   auto it = generators_.find(classification.collection);
847   if (it == generators_.end()) {
848     return true;
849   }
850 
851   const std::string entity_text =
852       UTF8ToUnicodeText(text, /*do_copy=*/false)
853           .UTF8Substring(selection_indices.first, selection_indices.second);
854 
855   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
856       new AnnotatorJniEnvironment(
857           resources_, jni_cache_.get(), context,
858           ParseDeviceLocales(device_locales), entity_text, classification,
859           reference_time_ms_utc, annotations_entity_data_schema));
860 
861   if (!interpreter->Initialize()) {
862     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
863     return false;
864   }
865 
866   return interpreter->RunIntentGenerator(it->second, remote_actions);
867 }
868 
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const869 bool IntentGenerator::GenerateIntents(
870     const jstring device_locales, const ActionSuggestion& action,
871     const Conversation& conversation, const jobject context,
872     const reflection::Schema* annotations_entity_data_schema,
873     const reflection::Schema* actions_entity_data_schema,
874     std::vector<RemoteActionTemplate>* remote_actions) const {
875   if (options_ == nullptr) {
876     return false;
877   }
878 
879   // Retrieve generator for specified action.
880   auto it = generators_.find(action.type);
881   if (it == generators_.end()) {
882     return true;
883   }
884 
885   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
886       new ActionsJniLuaEnvironment(
887           resources_, jni_cache_.get(), context,
888           ParseDeviceLocales(device_locales), conversation, action,
889           actions_entity_data_schema, annotations_entity_data_schema));
890 
891   if (!interpreter->Initialize()) {
892     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
893     return false;
894   }
895 
896   return interpreter->RunIntentGenerator(it->second, remote_actions);
897 }
898 
899 }  // namespace libtextclassifier3
900