• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/intents/intent-generator.h"
18 
19 #include <vector>
20 
21 #include "utils/base/logging.h"
22 #include "utils/intents/jni-lua.h"
23 #include "utils/java/jni-helper.h"
24 #include "utils/utf8/unicodetext.h"
25 #include "utils/zlib/zlib.h"
26 
27 #ifdef __cplusplus
28 extern "C" {
29 #endif
30 #include "lauxlib.h"
31 #include "lua.h"
32 #ifdef __cplusplus
33 }
34 #endif
35 
36 namespace libtextclassifier3 {
37 namespace {
38 
39 static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
40 
41 // Lua environment for classfication result intent generation.
42 class AnnotatorJniEnvironment : public JniLuaEnvironment {
43  public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema)44   AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
45                           const jobject context,
46                           const std::vector<Locale>& device_locales,
47                           const std::string& entity_text,
48                           const ClassificationResult& classification,
49                           const int64 reference_time_ms_utc,
50                           const reflection::Schema* entity_data_schema)
51       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
52         entity_text_(entity_text),
53         classification_(classification),
54         reference_time_ms_utc_(reference_time_ms_utc),
55         entity_data_schema_(entity_data_schema) {}
56 
57  protected:
SetupExternalHook()58   void SetupExternalHook() override {
59     JniLuaEnvironment::SetupExternalHook();
60     lua_pushinteger(state_, reference_time_ms_utc_);
61     lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
62 
63     PushAnnotation(classification_, entity_text_, entity_data_schema_);
64     lua_setfield(state_, /*idx=*/-2, "entity");
65   }
66 
67   const std::string& entity_text_;
68   const ClassificationResult& classification_;
69   const int64 reference_time_ms_utc_;
70 
71   // Reflection schema data.
72   const reflection::Schema* const entity_data_schema_;
73 };
74 
75 // Lua environment for actions intent generation.
76 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
77  public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)78   ActionsJniLuaEnvironment(
79       const Resources& resources, const JniCache* jni_cache,
80       const jobject context, const std::vector<Locale>& device_locales,
81       const Conversation& conversation, const ActionSuggestion& action,
82       const reflection::Schema* actions_entity_data_schema,
83       const reflection::Schema* annotations_entity_data_schema)
84       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
85         conversation_(conversation),
86         action_(action),
87         actions_entity_data_schema_(actions_entity_data_schema),
88         annotations_entity_data_schema_(annotations_entity_data_schema) {}
89 
90  protected:
SetupExternalHook()91   void SetupExternalHook() override {
92     JniLuaEnvironment::SetupExternalHook();
93     PushConversation(&conversation_.messages, annotations_entity_data_schema_);
94     lua_setfield(state_, /*idx=*/-2, "conversation");
95 
96     PushAction(action_, actions_entity_data_schema_,
97                annotations_entity_data_schema_);
98     lua_setfield(state_, /*idx=*/-2, "entity");
99   }
100 
101   const Conversation& conversation_;
102   const ActionSuggestion& action_;
103   const reflection::Schema* actions_entity_data_schema_;
104   const reflection::Schema* annotations_entity_data_schema_;
105 };
106 
107 }  // namespace
108 
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)109 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
110     const IntentFactoryModel* options, const ResourcePool* resources,
111     const std::shared_ptr<JniCache>& jni_cache) {
112   std::unique_ptr<IntentGenerator> intent_generator(
113       new IntentGenerator(options, resources, jni_cache));
114 
115   if (options == nullptr || options->generator() == nullptr) {
116     TC3_LOG(ERROR) << "No intent generator options.";
117     return nullptr;
118   }
119 
120   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
121       ZlibDecompressor::Instance();
122   if (!zlib_decompressor) {
123     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
124     return nullptr;
125   }
126 
127   for (const IntentFactoryModel_::IntentGenerator* generator :
128        *options->generator()) {
129     std::string lua_template_generator;
130     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
131             generator->lua_template_generator(),
132             generator->compressed_lua_template_generator(),
133             &lua_template_generator)) {
134       TC3_LOG(ERROR) << "Could not decompress generator template.";
135       return nullptr;
136     }
137 
138     std::string lua_code = lua_template_generator;
139     if (options->precompile_generators()) {
140       if (!Compile(lua_template_generator, &lua_code)) {
141         TC3_LOG(ERROR) << "Could not precompile generator template.";
142         return nullptr;
143       }
144     }
145 
146     intent_generator->generators_[generator->type()->str()] = lua_code;
147   }
148 
149   return intent_generator;
150 }
151 
ParseDeviceLocales(const jstring device_locales) const152 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
153     const jstring device_locales) const {
154   if (device_locales == nullptr) {
155     TC3_LOG(ERROR) << "No locales provided.";
156     return {};
157   }
158   StatusOr<std::string> status_or_locales_str =
159       JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
160   if (!status_or_locales_str.ok()) {
161     TC3_LOG(ERROR)
162         << "JStringToUtf8String failed, cannot retrieve provided locales.";
163     return {};
164   }
165   std::vector<Locale> locales;
166   if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
167     TC3_LOG(ERROR) << "Cannot parse locales.";
168     return {};
169   }
170   return locales;
171 }
172 
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const173 bool IntentGenerator::GenerateIntents(
174     const jstring device_locales, const ClassificationResult& classification,
175     const int64 reference_time_ms_utc, const std::string& text,
176     const CodepointSpan selection_indices, const jobject context,
177     const reflection::Schema* annotations_entity_data_schema,
178     std::vector<RemoteActionTemplate>* remote_actions) const {
179   if (options_ == nullptr) {
180     return false;
181   }
182 
183   // Retrieve generator for specified entity.
184   auto it = generators_.find(classification.collection);
185   if (it == generators_.end()) {
186     TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
187     return true;
188   }
189 
190   const std::string entity_text =
191       UTF8ToUnicodeText(text, /*do_copy=*/false)
192           .UTF8Substring(selection_indices.first, selection_indices.second);
193 
194   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
195       new AnnotatorJniEnvironment(
196           resources_, jni_cache_.get(), context,
197           ParseDeviceLocales(device_locales), entity_text, classification,
198           reference_time_ms_utc, annotations_entity_data_schema));
199 
200   if (!interpreter->Initialize()) {
201     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
202     return false;
203   }
204 
205   return interpreter->RunIntentGenerator(it->second, remote_actions);
206 }
207 
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const208 bool IntentGenerator::GenerateIntents(
209     const jstring device_locales, const ActionSuggestion& action,
210     const Conversation& conversation, const jobject context,
211     const reflection::Schema* annotations_entity_data_schema,
212     const reflection::Schema* actions_entity_data_schema,
213     std::vector<RemoteActionTemplate>* remote_actions) const {
214   if (options_ == nullptr) {
215     return false;
216   }
217 
218   // Retrieve generator for specified action.
219   auto it = generators_.find(action.type);
220   if (it == generators_.end()) {
221     return true;
222   }
223 
224   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
225       new ActionsJniLuaEnvironment(
226           resources_, jni_cache_.get(), context,
227           ParseDeviceLocales(device_locales), conversation, action,
228           actions_entity_data_schema, annotations_entity_data_schema));
229 
230   if (!interpreter->Initialize()) {
231     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
232     return false;
233   }
234 
235   return interpreter->RunIntentGenerator(it->second, remote_actions);
236 }
237 
238 }  // namespace libtextclassifier3
239