1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "utils/intents/intent-generator.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/intents/jni-lua.h"
#include "utils/java/jni-helper.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib.h"
26
27 #ifdef __cplusplus
28 extern "C" {
29 #endif
30 #include "lauxlib.h"
31 #include "lua.h"
32 #ifdef __cplusplus
33 }
34 #endif
35
36 namespace libtextclassifier3 {
37 namespace {
38
39 static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
40
41 // Lua environment for classfication result intent generation.
42 class AnnotatorJniEnvironment : public JniLuaEnvironment {
43 public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema)44 AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
45 const jobject context,
46 const std::vector<Locale>& device_locales,
47 const std::string& entity_text,
48 const ClassificationResult& classification,
49 const int64 reference_time_ms_utc,
50 const reflection::Schema* entity_data_schema)
51 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
52 entity_text_(entity_text),
53 classification_(classification),
54 reference_time_ms_utc_(reference_time_ms_utc),
55 entity_data_schema_(entity_data_schema) {}
56
57 protected:
SetupExternalHook()58 void SetupExternalHook() override {
59 JniLuaEnvironment::SetupExternalHook();
60 lua_pushinteger(state_, reference_time_ms_utc_);
61 lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
62
63 PushAnnotation(classification_, entity_text_, entity_data_schema_);
64 lua_setfield(state_, /*idx=*/-2, "entity");
65 }
66
67 const std::string& entity_text_;
68 const ClassificationResult& classification_;
69 const int64 reference_time_ms_utc_;
70
71 // Reflection schema data.
72 const reflection::Schema* const entity_data_schema_;
73 };
74
75 // Lua environment for actions intent generation.
76 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
77 public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)78 ActionsJniLuaEnvironment(
79 const Resources& resources, const JniCache* jni_cache,
80 const jobject context, const std::vector<Locale>& device_locales,
81 const Conversation& conversation, const ActionSuggestion& action,
82 const reflection::Schema* actions_entity_data_schema,
83 const reflection::Schema* annotations_entity_data_schema)
84 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
85 conversation_(conversation),
86 action_(action),
87 actions_entity_data_schema_(actions_entity_data_schema),
88 annotations_entity_data_schema_(annotations_entity_data_schema) {}
89
90 protected:
SetupExternalHook()91 void SetupExternalHook() override {
92 JniLuaEnvironment::SetupExternalHook();
93 PushConversation(&conversation_.messages, annotations_entity_data_schema_);
94 lua_setfield(state_, /*idx=*/-2, "conversation");
95
96 PushAction(action_, actions_entity_data_schema_,
97 annotations_entity_data_schema_);
98 lua_setfield(state_, /*idx=*/-2, "entity");
99 }
100
101 const Conversation& conversation_;
102 const ActionSuggestion& action_;
103 const reflection::Schema* actions_entity_data_schema_;
104 const reflection::Schema* annotations_entity_data_schema_;
105 };
106
107 } // namespace
108
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)109 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
110 const IntentFactoryModel* options, const ResourcePool* resources,
111 const std::shared_ptr<JniCache>& jni_cache) {
112 std::unique_ptr<IntentGenerator> intent_generator(
113 new IntentGenerator(options, resources, jni_cache));
114
115 if (options == nullptr || options->generator() == nullptr) {
116 TC3_LOG(ERROR) << "No intent generator options.";
117 return nullptr;
118 }
119
120 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
121 ZlibDecompressor::Instance();
122 if (!zlib_decompressor) {
123 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
124 return nullptr;
125 }
126
127 for (const IntentFactoryModel_::IntentGenerator* generator :
128 *options->generator()) {
129 std::string lua_template_generator;
130 if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
131 generator->lua_template_generator(),
132 generator->compressed_lua_template_generator(),
133 &lua_template_generator)) {
134 TC3_LOG(ERROR) << "Could not decompress generator template.";
135 return nullptr;
136 }
137
138 std::string lua_code = lua_template_generator;
139 if (options->precompile_generators()) {
140 if (!Compile(lua_template_generator, &lua_code)) {
141 TC3_LOG(ERROR) << "Could not precompile generator template.";
142 return nullptr;
143 }
144 }
145
146 intent_generator->generators_[generator->type()->str()] = lua_code;
147 }
148
149 return intent_generator;
150 }
151
ParseDeviceLocales(const jstring device_locales) const152 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
153 const jstring device_locales) const {
154 if (device_locales == nullptr) {
155 TC3_LOG(ERROR) << "No locales provided.";
156 return {};
157 }
158 StatusOr<std::string> status_or_locales_str =
159 JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
160 if (!status_or_locales_str.ok()) {
161 TC3_LOG(ERROR)
162 << "JStringToUtf8String failed, cannot retrieve provided locales.";
163 return {};
164 }
165 std::vector<Locale> locales;
166 if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
167 TC3_LOG(ERROR) << "Cannot parse locales.";
168 return {};
169 }
170 return locales;
171 }
172
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const173 bool IntentGenerator::GenerateIntents(
174 const jstring device_locales, const ClassificationResult& classification,
175 const int64 reference_time_ms_utc, const std::string& text,
176 const CodepointSpan selection_indices, const jobject context,
177 const reflection::Schema* annotations_entity_data_schema,
178 std::vector<RemoteActionTemplate>* remote_actions) const {
179 if (options_ == nullptr) {
180 return false;
181 }
182
183 // Retrieve generator for specified entity.
184 auto it = generators_.find(classification.collection);
185 if (it == generators_.end()) {
186 TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
187 return true;
188 }
189
190 const std::string entity_text =
191 UTF8ToUnicodeText(text, /*do_copy=*/false)
192 .UTF8Substring(selection_indices.first, selection_indices.second);
193
194 std::unique_ptr<AnnotatorJniEnvironment> interpreter(
195 new AnnotatorJniEnvironment(
196 resources_, jni_cache_.get(), context,
197 ParseDeviceLocales(device_locales), entity_text, classification,
198 reference_time_ms_utc, annotations_entity_data_schema));
199
200 if (!interpreter->Initialize()) {
201 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
202 return false;
203 }
204
205 return interpreter->RunIntentGenerator(it->second, remote_actions);
206 }
207
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const208 bool IntentGenerator::GenerateIntents(
209 const jstring device_locales, const ActionSuggestion& action,
210 const Conversation& conversation, const jobject context,
211 const reflection::Schema* annotations_entity_data_schema,
212 const reflection::Schema* actions_entity_data_schema,
213 std::vector<RemoteActionTemplate>* remote_actions) const {
214 if (options_ == nullptr) {
215 return false;
216 }
217
218 // Retrieve generator for specified action.
219 auto it = generators_.find(action.type);
220 if (it == generators_.end()) {
221 return true;
222 }
223
224 std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
225 new ActionsJniLuaEnvironment(
226 resources_, jni_cache_.get(), context,
227 ParseDeviceLocales(device_locales), conversation, action,
228 actions_entity_data_schema, annotations_entity_data_schema));
229
230 if (!interpreter->Initialize()) {
231 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
232 return false;
233 }
234
235 return interpreter->RunIntentGenerator(it->second, remote_actions);
236 }
237
238 } // namespace libtextclassifier3
239