1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-utils.h"
18
19 namespace libtextclassifier3 {
namespace {
// Field names of the Lua tables that the functions in this file push to and
// read back from scripts. The unnamed namespace already gives these internal
// linkage, so no `static` is needed.
constexpr const char* kTextKey = "text";
// NOTE(review): the identifier says "Usec" while the key string (and the
// `time_ms_utc` field it is paired with) says milliseconds.
constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
constexpr const char* kGranularityKey = "granularity";
constexpr const char* kCollectionKey = "collection";
constexpr const char* kNameKey = "name";
constexpr const char* kScoreKey = "score";
constexpr const char* kPriorityScoreKey = "priority_score";
constexpr const char* kTypeKey = "type";
constexpr const char* kResponseTextKey = "response_text";
constexpr const char* kAnnotationKey = "annotation";
constexpr const char* kSpanKey = "span";
constexpr const char* kMessageKey = "message";
constexpr const char* kBeginKey = "begin";
constexpr const char* kEndKey = "end";
constexpr const char* kClassificationKey = "classification";
constexpr const char* kSerializedEntity = "serialized_entity";
constexpr const char* kEntityKey = "entity";
}  // namespace
39
40 template <>
Item(const std::vector<ClassificationResult> * annotations,StringPiece key,lua_State * state) const41 int AnnotationIterator<ClassificationResult>::Item(
42 const std::vector<ClassificationResult>* annotations, StringPiece key,
43 lua_State* state) const {
44 // Lookup annotation by collection.
45 for (const ClassificationResult& annotation : *annotations) {
46 if (key.Equals(annotation.collection)) {
47 PushAnnotation(annotation, entity_data_schema_, env_);
48 return 1;
49 }
50 }
51 TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
52 << " found.";
53 lua_error(state);
54 return 0;
55 }
56
57 template <>
Item(const std::vector<ActionSuggestionAnnotation> * annotations,StringPiece key,lua_State * state) const58 int AnnotationIterator<ActionSuggestionAnnotation>::Item(
59 const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
60 lua_State* state) const {
61 // Lookup annotation by name.
62 for (const ActionSuggestionAnnotation& annotation : *annotations) {
63 if (key.Equals(annotation.name)) {
64 PushAnnotation(annotation, entity_data_schema_, env_);
65 return 1;
66 }
67 }
68 TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
69 lua_error(state);
70 return 0;
71 }
72
PushAnnotation(const ClassificationResult & classification,const reflection::Schema * entity_data_schema,LuaEnvironment * env)73 void PushAnnotation(const ClassificationResult& classification,
74 const reflection::Schema* entity_data_schema,
75 LuaEnvironment* env) {
76 if (entity_data_schema == nullptr ||
77 classification.serialized_entity_data.empty()) {
78 // Empty table.
79 lua_newtable(env->state());
80 } else {
81 env->PushFlatbuffer(entity_data_schema,
82 flatbuffers::GetRoot<flatbuffers::Table>(
83 classification.serialized_entity_data.data()));
84 }
85 lua_pushinteger(env->state(),
86 classification.datetime_parse_result.time_ms_utc);
87 lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
88 lua_pushinteger(env->state(),
89 classification.datetime_parse_result.granularity);
90 lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
91 env->PushString(classification.collection);
92 lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
93 lua_pushnumber(env->state(), classification.score);
94 lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
95 env->PushString(classification.serialized_entity_data);
96 lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
97 }
98
PushAnnotation(const ClassificationResult & classification,StringPiece text,const reflection::Schema * entity_data_schema,LuaEnvironment * env)99 void PushAnnotation(const ClassificationResult& classification,
100 StringPiece text,
101 const reflection::Schema* entity_data_schema,
102 LuaEnvironment* env) {
103 PushAnnotation(classification, entity_data_schema, env);
104 env->PushString(text);
105 lua_setfield(env->state(), /*idx=*/-2, kTextKey);
106 }
107
PushAnnotatedSpan(const AnnotatedSpan & annotated_span,const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)108 void PushAnnotatedSpan(
109 const AnnotatedSpan& annotated_span,
110 const AnnotationIterator<ClassificationResult>& annotation_iterator,
111 LuaEnvironment* env) {
112 lua_newtable(env->state());
113 {
114 lua_newtable(env->state());
115 lua_pushinteger(env->state(), annotated_span.span.first);
116 lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
117 lua_pushinteger(env->state(), annotated_span.span.second);
118 lua_setfield(env->state(), /*idx=*/-2, kEndKey);
119 }
120 lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
121 annotation_iterator.NewIterator(kClassificationKey,
122 &annotated_span.classification, env->state());
123 lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
124 }
125
ReadSpan(LuaEnvironment * env)126 MessageTextSpan ReadSpan(LuaEnvironment* env) {
127 MessageTextSpan span;
128 lua_pushnil(env->state());
129 while (lua_next(env->state(), /*idx=*/-2)) {
130 const StringPiece key = env->ReadString(/*index=*/-2);
131 if (key.Equals(kMessageKey)) {
132 span.message_index =
133 static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
134 } else if (key.Equals(kBeginKey)) {
135 span.span.first =
136 static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
137 } else if (key.Equals(kEndKey)) {
138 span.span.second =
139 static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
140 } else if (key.Equals(kTextKey)) {
141 span.text = env->ReadString(/*index=*/-1).ToString();
142 } else {
143 TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
144 }
145 lua_pop(env->state(), 1);
146 }
147 return span;
148 }
149
ReadAnnotations(const reflection::Schema * entity_data_schema,LuaEnvironment * env,std::vector<ActionSuggestionAnnotation> * annotations)150 int ReadAnnotations(const reflection::Schema* entity_data_schema,
151 LuaEnvironment* env,
152 std::vector<ActionSuggestionAnnotation>* annotations) {
153 if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
154 TC3_LOG(ERROR) << "Expected annotations table, got: "
155 << lua_type(env->state(), /*idx=*/-1);
156 lua_pop(env->state(), 1);
157 lua_error(env->state());
158 return LUA_ERRRUN;
159 }
160
161 // Read actions.
162 lua_pushnil(env->state());
163 while (lua_next(env->state(), /*idx=*/-2)) {
164 if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
165 TC3_LOG(ERROR) << "Expected annotation table, got: "
166 << lua_type(env->state(), /*idx=*/-1);
167 lua_pop(env->state(), 1);
168 continue;
169 }
170 annotations->push_back(ReadAnnotation(entity_data_schema, env));
171 lua_pop(env->state(), 1);
172 }
173 return LUA_OK;
174 }
175
ReadAnnotation(const reflection::Schema * entity_data_schema,LuaEnvironment * env)176 ActionSuggestionAnnotation ReadAnnotation(
177 const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
178 ActionSuggestionAnnotation annotation;
179 lua_pushnil(env->state());
180 while (lua_next(env->state(), /*idx=*/-2)) {
181 const StringPiece key = env->ReadString(/*index=*/-2);
182 if (key.Equals(kNameKey)) {
183 annotation.name = env->ReadString(/*index=*/-1).ToString();
184 } else if (key.Equals(kSpanKey)) {
185 annotation.span = ReadSpan(env);
186 } else if (key.Equals(kEntityKey)) {
187 annotation.entity = ReadClassificationResult(entity_data_schema, env);
188 } else {
189 TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
190 }
191 lua_pop(env->state(), 1);
192 }
193 return annotation;
194 }
195
ReadClassificationResult(const reflection::Schema * entity_data_schema,LuaEnvironment * env)196 ClassificationResult ReadClassificationResult(
197 const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
198 ClassificationResult classification;
199 lua_pushnil(env->state());
200 while (lua_next(env->state(), /*idx=*/-2)) {
201 const StringPiece key = env->ReadString(/*index=*/-2);
202 if (key.Equals(kCollectionKey)) {
203 classification.collection = env->ReadString(/*index=*/-1).ToString();
204 } else if (key.Equals(kScoreKey)) {
205 classification.score =
206 static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
207 } else if (key.Equals(kTimeUsecKey)) {
208 classification.datetime_parse_result.time_ms_utc =
209 static_cast<int64>(lua_tonumber(env->state(), /*idx=*/-1));
210 } else if (key.Equals(kGranularityKey)) {
211 classification.datetime_parse_result.granularity =
212 static_cast<DatetimeGranularity>(
213 lua_tonumber(env->state(), /*idx=*/-1));
214 } else if (key.Equals(kSerializedEntity)) {
215 classification.serialized_entity_data =
216 env->ReadString(/*index=*/-1).ToString();
217 } else if (key.Equals(kEntityKey)) {
218 auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
219 env->ReadFlatbuffer(buffer.get());
220 classification.serialized_entity_data = buffer->Serialize();
221 } else {
222 TC3_LOG(INFO) << "Unknown classification result field: "
223 << key.ToString();
224 }
225 lua_pop(env->state(), 1);
226 }
227 return classification;
228 }
229
PushAnnotation(const ActionSuggestionAnnotation & annotation,const reflection::Schema * entity_data_schema,LuaEnvironment * env)230 void PushAnnotation(const ActionSuggestionAnnotation& annotation,
231 const reflection::Schema* entity_data_schema,
232 LuaEnvironment* env) {
233 PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
234 env);
235 env->PushString(annotation.name);
236 lua_setfield(env->state(), /*idx=*/-2, kNameKey);
237 {
238 lua_newtable(env->state());
239 lua_pushinteger(env->state(), annotation.span.message_index);
240 lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
241 lua_pushinteger(env->state(), annotation.span.span.first);
242 lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
243 lua_pushinteger(env->state(), annotation.span.span.second);
244 lua_setfield(env->state(), /*idx=*/-2, kEndKey);
245 }
246 lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
247 }
248
PushAction(const ActionSuggestion & action,const reflection::Schema * entity_data_schema,const AnnotationIterator<ActionSuggestionAnnotation> & annotation_iterator,LuaEnvironment * env)249 void PushAction(
250 const ActionSuggestion& action,
251 const reflection::Schema* entity_data_schema,
252 const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
253 LuaEnvironment* env) {
254 if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
255 // Empty table.
256 lua_newtable(env->state());
257 } else {
258 env->PushFlatbuffer(entity_data_schema,
259 flatbuffers::GetRoot<flatbuffers::Table>(
260 action.serialized_entity_data.data()));
261 }
262 env->PushString(action.type);
263 lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
264 env->PushString(action.response_text);
265 lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
266 lua_pushnumber(env->state(), action.score);
267 lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
268 lua_pushnumber(env->state(), action.priority_score);
269 lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
270 annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
271 env->state());
272 lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
273 }
274
ReadAction(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env)275 ActionSuggestion ReadAction(
276 const reflection::Schema* actions_entity_data_schema,
277 const reflection::Schema* annotations_entity_data_schema,
278 LuaEnvironment* env) {
279 ActionSuggestion action;
280 lua_pushnil(env->state());
281 while (lua_next(env->state(), /*idx=*/-2)) {
282 const StringPiece key = env->ReadString(/*index=*/-2);
283 if (key.Equals(kResponseTextKey)) {
284 action.response_text = env->ReadString(/*index=*/-1).ToString();
285 } else if (key.Equals(kTypeKey)) {
286 action.type = env->ReadString(/*index=*/-1).ToString();
287 } else if (key.Equals(kScoreKey)) {
288 action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
289 } else if (key.Equals(kPriorityScoreKey)) {
290 action.priority_score =
291 static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
292 } else if (key.Equals(kAnnotationKey)) {
293 ReadAnnotations(actions_entity_data_schema, env, &action.annotations);
294 } else if (key.Equals(kEntityKey)) {
295 auto buffer =
296 ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
297 env->ReadFlatbuffer(buffer.get());
298 action.serialized_entity_data = buffer->Serialize();
299 } else {
300 TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
301 }
302 lua_pop(env->state(), 1);
303 }
304 return action;
305 }
306
ReadActions(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env,std::vector<ActionSuggestion> * actions)307 int ReadActions(const reflection::Schema* actions_entity_data_schema,
308 const reflection::Schema* annotations_entity_data_schema,
309 LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
310 if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
311 TC3_LOG(ERROR) << "Expected actions table, got: "
312 << lua_type(env->state(), /*idx=*/-1);
313 lua_pop(env->state(), 1);
314 lua_error(env->state());
315 return LUA_ERRRUN;
316 }
317
318 // Read actions.
319 lua_pushnil(env->state());
320 while (lua_next(env->state(), /*idx=*/-2)) {
321 if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
322 TC3_LOG(ERROR) << "Expected action table, got: "
323 << lua_type(env->state(), /*idx=*/-1);
324 lua_pop(env->state(), 1);
325 continue;
326 }
327 actions->push_back(ReadAction(actions_entity_data_schema,
328 annotations_entity_data_schema, env));
329 lua_pop(env->state(), /*n=1*/ 1);
330 }
331 lua_pop(env->state(), /*n=*/1);
332
333 return LUA_OK;
334 }
335
Item(const std::vector<ConversationMessage> * messages,const int64 pos,lua_State * state) const336 int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
337 const int64 pos, lua_State* state) const {
338 const ConversationMessage& message = (*messages)[pos];
339 lua_newtable(state);
340 lua_pushinteger(state, message.user_id);
341 lua_setfield(state, /*idx=*/-2, "user_id");
342 env_->PushString(message.text);
343 lua_setfield(state, /*idx=*/-2, "text");
344 lua_pushinteger(state, message.reference_time_ms_utc);
345 lua_setfield(state, /*idx=*/-2, "time_ms_utc");
346 env_->PushString(message.reference_timezone);
347 lua_setfield(state, /*idx=*/-2, "timezone");
348 annotated_span_iterator_.NewIterator("annotation", &message.annotations,
349 state);
350 lua_setfield(state, /*idx=*/-2, "annotation");
351 return 1;
352 }
353
354 } // namespace libtextclassifier3
355