1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/lua-utils.h"
18
19 namespace libtextclassifier3 {
20 namespace {
21 static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
22 {LUA_TABLIBNAME, luaopen_table},
23 {LUA_STRLIBNAME, luaopen_string},
24 {LUA_MATHLIBNAME, luaopen_math},
25 {nullptr, nullptr}};
26
// Names of the Lua table fields used when exchanging annotations, spans,
// classification results and actions with Lua snippets.
constexpr char kTextKey[] = "text";
constexpr char kTimeUsecKey[] = "parsed_time_ms_utc";
constexpr char kGranularityKey[] = "granularity";
constexpr char kCollectionKey[] = "collection";
constexpr char kNameKey[] = "name";
constexpr char kScoreKey[] = "score";
constexpr char kPriorityScoreKey[] = "priority_score";
constexpr char kTypeKey[] = "type";
constexpr char kResponseTextKey[] = "response_text";
constexpr char kAnnotationKey[] = "annotation";
constexpr char kSpanKey[] = "span";
constexpr char kMessageKey[] = "message";
constexpr char kBeginKey[] = "begin";
constexpr char kEndKey[] = "end";
constexpr char kClassificationKey[] = "classification";
constexpr char kSerializedEntity[] = "serialized_entity";
constexpr char kEntityKey[] = "entity";
44
45 // Implementation of a lua_Writer that appends the data to a string.
LuaStringWriter(lua_State * state,const void * data,size_t size,void * result)46 int LuaStringWriter(lua_State* state, const void* data, size_t size,
47 void* result) {
48 std::string* const result_string = static_cast<std::string*>(result);
49 result_string->insert(result_string->size(), static_cast<const char*>(data),
50 size);
51 return LUA_OK;
52 }
53
54 } // namespace
55
LuaEnvironment()56 LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
57
~LuaEnvironment()58 LuaEnvironment::~LuaEnvironment() {
59 if (state_ != nullptr) {
60 lua_close(state_);
61 }
62 }
63
PushFlatbuffer(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table) const64 void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema,
65 const reflection::Object* type,
66 const flatbuffers::Table* table) const {
67 PushLazyObject(
68 std::bind(&LuaEnvironment::GetField, this, schema, type, table));
69 }
70
GetField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table) const71 int LuaEnvironment::GetField(const reflection::Schema* schema,
72 const reflection::Object* type,
73 const flatbuffers::Table* table) const {
74 const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop);
75 const reflection::Field* field = type->fields()->LookupByKey(field_name);
76 if (field == nullptr) {
77 lua_error(state_);
78 return 0;
79 }
80 // Provide primitive fields directly.
81 const reflection::BaseType field_type = field->type()->base_type();
82 switch (field_type) {
83 case reflection::Bool:
84 Push(table->GetField<bool>(field->offset(), field->default_integer()));
85 break;
86 case reflection::UByte:
87 Push(table->GetField<uint8>(field->offset(), field->default_integer()));
88 break;
89 case reflection::Byte:
90 Push(table->GetField<int8>(field->offset(), field->default_integer()));
91 break;
92 case reflection::Int:
93 Push(table->GetField<int32>(field->offset(), field->default_integer()));
94 break;
95 case reflection::UInt:
96 Push(table->GetField<uint32>(field->offset(), field->default_integer()));
97 break;
98 case reflection::Long:
99 Push(table->GetField<int64>(field->offset(), field->default_integer()));
100 break;
101 case reflection::ULong:
102 Push(table->GetField<uint64>(field->offset(), field->default_integer()));
103 break;
104 case reflection::Float:
105 Push(table->GetField<float>(field->offset(), field->default_real()));
106 break;
107 case reflection::Double:
108 Push(table->GetField<double>(field->offset(), field->default_real()));
109 break;
110 case reflection::String: {
111 Push(table->GetPointer<const flatbuffers::String*>(field->offset()));
112 break;
113 }
114 case reflection::Obj: {
115 const flatbuffers::Table* field_table =
116 table->GetPointer<const flatbuffers::Table*>(field->offset());
117 if (field_table == nullptr) {
118 // Field was not set in entity data.
119 return 0;
120 }
121 const reflection::Object* field_type =
122 schema->objects()->Get(field->type()->index());
123 PushFlatbuffer(schema, field_type, field_table);
124 break;
125 }
126 case reflection::Vector: {
127 const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector =
128 table->GetPointer<
129 const flatbuffers::Vector<flatbuffers::Offset<void>>*>(
130 field->offset());
131 if (field_vector == nullptr) {
132 // Repeated field was not set in flatbuffer.
133 PushEmptyVector();
134 break;
135 }
136 switch (field->type()->element()) {
137 case reflection::Bool:
138 PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
139 field->offset()));
140 break;
141 case reflection::UByte:
142 PushRepeatedField(
143 table->GetPointer<const flatbuffers::Vector<uint8>*>(
144 field->offset()));
145 break;
146 case reflection::Byte:
147 PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>(
148 field->offset()));
149 break;
150 case reflection::Int:
151 PushRepeatedField(
152 table->GetPointer<const flatbuffers::Vector<int32>*>(
153 field->offset()));
154 break;
155 case reflection::UInt:
156 PushRepeatedField(
157 table->GetPointer<const flatbuffers::Vector<uint32>*>(
158 field->offset()));
159 break;
160 case reflection::Long:
161 PushRepeatedField(
162 table->GetPointer<const flatbuffers::Vector<int64>*>(
163 field->offset()));
164 break;
165 case reflection::ULong:
166 PushRepeatedField(
167 table->GetPointer<const flatbuffers::Vector<uint64>*>(
168 field->offset()));
169 break;
170 case reflection::Float:
171 PushRepeatedField(
172 table->GetPointer<const flatbuffers::Vector<float>*>(
173 field->offset()));
174 break;
175 case reflection::Double:
176 PushRepeatedField(
177 table->GetPointer<const flatbuffers::Vector<double>*>(
178 field->offset()));
179 break;
180 case reflection::String:
181 PushRepeatedField(
182 table->GetPointer<const flatbuffers::Vector<
183 flatbuffers::Offset<flatbuffers::String>>*>(field->offset()));
184 break;
185 case reflection::Obj:
186 PushRepeatedFlatbufferField(
187 schema, schema->objects()->Get(field->type()->index()),
188 table->GetPointer<const flatbuffers::Vector<
189 flatbuffers::Offset<flatbuffers::Table>>*>(field->offset()));
190 break;
191 default:
192 TC3_LOG(ERROR) << "Unsupported repeated type: "
193 << field->type()->element();
194 lua_error(state_);
195 return 0;
196 }
197 break;
198 }
199 default:
200 TC3_LOG(ERROR) << "Unsupported type: " << field_type;
201 lua_error(state_);
202 return 0;
203 }
204 return 1;
205 }
206
ReadFlatbuffer(const int index,MutableFlatbuffer * buffer) const207 int LuaEnvironment::ReadFlatbuffer(const int index,
208 MutableFlatbuffer* buffer) const {
209 if (buffer == nullptr) {
210 TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
211 lua_error(state_);
212 return LUA_ERRRUN;
213 }
214 if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
215 TC3_LOG(ERROR) << "Expected table, got: "
216 << lua_type(state_, /*idx=*/kIndexStackTop);
217 lua_error(state_);
218 return LUA_ERRRUN;
219 }
220
221 lua_pushnil(state_);
222 while (Next(index - 1)) {
223 const StringPiece key = ReadString(/*index=*/index - 1);
224 const reflection::Field* field = buffer->GetFieldOrNull(key);
225 if (field == nullptr) {
226 TC3_LOG(ERROR) << "Unknown field: " << key;
227 lua_error(state_);
228 return LUA_ERRRUN;
229 }
230 switch (field->type()->base_type()) {
231 case reflection::Obj:
232 ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field));
233 break;
234 case reflection::Bool:
235 buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop));
236 break;
237 case reflection::Byte:
238 buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop));
239 break;
240 case reflection::UByte:
241 buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop));
242 break;
243 case reflection::Int:
244 buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop));
245 break;
246 case reflection::UInt:
247 buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop));
248 break;
249 case reflection::Long:
250 buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop));
251 break;
252 case reflection::ULong:
253 buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop));
254 break;
255 case reflection::Float:
256 buffer->Set(field, Read<float>(/*index=*/kIndexStackTop));
257 break;
258 case reflection::Double:
259 buffer->Set(field, Read<double>(/*index=*/kIndexStackTop));
260 break;
261 case reflection::String: {
262 buffer->Set(field, ReadString(/*index=*/kIndexStackTop));
263 break;
264 }
265 case reflection::Vector: {
266 // Read repeated field.
267 switch (field->type()->element()) {
268 case reflection::Bool:
269 ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
270 buffer->Repeated(field));
271 break;
272 case reflection::Byte:
273 ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
274 buffer->Repeated(field));
275 break;
276 case reflection::UByte:
277 ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
278 buffer->Repeated(field));
279 break;
280 case reflection::Int:
281 ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
282 buffer->Repeated(field));
283 break;
284 case reflection::UInt:
285 ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
286 buffer->Repeated(field));
287 break;
288 case reflection::Long:
289 ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
290 buffer->Repeated(field));
291 break;
292 case reflection::ULong:
293 ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
294 buffer->Repeated(field));
295 break;
296 case reflection::Float:
297 ReadRepeatedField<float>(/*index=*/kIndexStackTop,
298 buffer->Repeated(field));
299 break;
300 case reflection::Double:
301 ReadRepeatedField<double>(/*index=*/kIndexStackTop,
302 buffer->Repeated(field));
303 break;
304 case reflection::String:
305 ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
306 buffer->Repeated(field));
307 break;
308 case reflection::Obj:
309 ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
310 buffer->Repeated(field));
311 break;
312 default:
313 TC3_LOG(ERROR) << "Unsupported repeated field type: "
314 << field->type()->element();
315 lua_error(state_);
316 return LUA_ERRRUN;
317 }
318 break;
319 }
320 default:
321 TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
322 lua_error(state_);
323 return LUA_ERRRUN;
324 }
325 lua_pop(state_, 1);
326 }
327 return LUA_OK;
328 }
329
LoadDefaultLibraries()330 void LuaEnvironment::LoadDefaultLibraries() {
331 for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) {
332 luaL_requiref(state_, lib->name, lib->func, 1);
333 lua_pop(state_, 1); // Remove lib.
334 }
335 }
336
ReadString(const int index) const337 StringPiece LuaEnvironment::ReadString(const int index) const {
338 size_t length = 0;
339 const char* data = lua_tolstring(state_, index, &length);
340 return StringPiece(data, length);
341 }
342
PushString(const StringPiece str) const343 void LuaEnvironment::PushString(const StringPiece str) const {
344 lua_pushlstring(state_, str.data(), str.size());
345 }
346
Compile(StringPiece snippet,std::string * bytecode) const347 bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const {
348 if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
349 /*name=*/nullptr) != LUA_OK) {
350 TC3_LOG(ERROR) << "Could not compile lua snippet: "
351 << ReadString(/*index=*/kIndexStackTop);
352 lua_pop(state_, 1);
353 return false;
354 }
355 if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
356 TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
357 lua_pop(state_, 1);
358 return false;
359 }
360 lua_pop(state_, 1);
361 return true;
362 }
363
PushAnnotation(const ClassificationResult & classification,const reflection::Schema * entity_data_schema) const364 void LuaEnvironment::PushAnnotation(
365 const ClassificationResult& classification,
366 const reflection::Schema* entity_data_schema) const {
367 if (entity_data_schema == nullptr ||
368 classification.serialized_entity_data.empty()) {
369 // Empty table.
370 lua_newtable(state_);
371 } else {
372 PushFlatbuffer(entity_data_schema,
373 flatbuffers::GetRoot<flatbuffers::Table>(
374 classification.serialized_entity_data.data()));
375 }
376 Push(classification.datetime_parse_result.time_ms_utc);
377 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTimeUsecKey);
378 Push(classification.datetime_parse_result.granularity);
379 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kGranularityKey);
380 Push(classification.collection);
381 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kCollectionKey);
382 Push(classification.score);
383 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
384 Push(classification.serialized_entity_data);
385 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSerializedEntity);
386 }
387
PushAnnotation(const ClassificationResult & classification,StringPiece text,const reflection::Schema * entity_data_schema) const388 void LuaEnvironment::PushAnnotation(
389 const ClassificationResult& classification, StringPiece text,
390 const reflection::Schema* entity_data_schema) const {
391 PushAnnotation(classification, entity_data_schema);
392 Push(text);
393 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTextKey);
394 }
395
PushAnnotation(const ActionSuggestionAnnotation & annotation,const reflection::Schema * entity_data_schema) const396 void LuaEnvironment::PushAnnotation(
397 const ActionSuggestionAnnotation& annotation,
398 const reflection::Schema* entity_data_schema) const {
399 PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema);
400 PushString(annotation.name);
401 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kNameKey);
402 {
403 lua_newtable(state_);
404 Push(annotation.span.message_index);
405 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kMessageKey);
406 Push(annotation.span.span.first);
407 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
408 Push(annotation.span.span.second);
409 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
410 }
411 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
412 }
413
PushAnnotatedSpan(const AnnotatedSpan & annotated_span,const reflection::Schema * entity_data_schema) const414 void LuaEnvironment::PushAnnotatedSpan(
415 const AnnotatedSpan& annotated_span,
416 const reflection::Schema* entity_data_schema) const {
417 lua_newtable(state_);
418 {
419 lua_newtable(state_);
420 Push(annotated_span.span.first);
421 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
422 Push(annotated_span.span.second);
423 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
424 }
425 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
426 PushAnnotations(&annotated_span.classification, entity_data_schema);
427 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kClassificationKey);
428 }
429
PushAnnotatedSpans(const std::vector<AnnotatedSpan> * annotated_spans,const reflection::Schema * entity_data_schema) const430 void LuaEnvironment::PushAnnotatedSpans(
431 const std::vector<AnnotatedSpan>* annotated_spans,
432 const reflection::Schema* entity_data_schema) const {
433 PushIterator(annotated_spans ? annotated_spans->size() : 0,
434 [this, annotated_spans, entity_data_schema](const int64 index) {
435 PushAnnotatedSpan(annotated_spans->at(index),
436 entity_data_schema);
437 return 1;
438 });
439 }
440
ReadSpan() const441 MessageTextSpan LuaEnvironment::ReadSpan() const {
442 MessageTextSpan span;
443 lua_pushnil(state_);
444 while (Next(/*index=*/kIndexStackTop - 1)) {
445 const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
446 if (key.Equals(kMessageKey)) {
447 span.message_index = Read<int>(/*index=*/kIndexStackTop);
448 } else if (key.Equals(kBeginKey)) {
449 span.span.first = Read<int>(/*index=*/kIndexStackTop);
450 } else if (key.Equals(kEndKey)) {
451 span.span.second = Read<int>(/*index=*/kIndexStackTop);
452 } else if (key.Equals(kTextKey)) {
453 span.text = Read<std::string>(/*index=*/kIndexStackTop);
454 } else {
455 TC3_LOG(INFO) << "Unknown span field: " << key;
456 }
457 lua_pop(state_, 1);
458 }
459 return span;
460 }
461
ReadAnnotations(const reflection::Schema * entity_data_schema,std::vector<ActionSuggestionAnnotation> * annotations) const462 int LuaEnvironment::ReadAnnotations(
463 const reflection::Schema* entity_data_schema,
464 std::vector<ActionSuggestionAnnotation>* annotations) const {
465 if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
466 TC3_LOG(ERROR) << "Expected annotations table, got: "
467 << lua_type(state_, /*idx=*/kIndexStackTop);
468 lua_pop(state_, 1);
469 lua_error(state_);
470 return LUA_ERRRUN;
471 }
472
473 // Read actions.
474 lua_pushnil(state_);
475 while (Next(/*index=*/kIndexStackTop - 1)) {
476 if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
477 TC3_LOG(ERROR) << "Expected annotation table, got: "
478 << lua_type(state_, /*idx=*/kIndexStackTop);
479 lua_pop(state_, 1);
480 continue;
481 }
482 annotations->push_back(ReadAnnotation(entity_data_schema));
483 lua_pop(state_, 1);
484 }
485 return LUA_OK;
486 }
487
ReadAnnotation(const reflection::Schema * entity_data_schema) const488 ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation(
489 const reflection::Schema* entity_data_schema) const {
490 ActionSuggestionAnnotation annotation;
491 lua_pushnil(state_);
492 while (Next(/*index=*/kIndexStackTop - 1)) {
493 const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
494 if (key.Equals(kNameKey)) {
495 annotation.name = Read<std::string>(/*index=*/kIndexStackTop);
496 } else if (key.Equals(kSpanKey)) {
497 annotation.span = ReadSpan();
498 } else if (key.Equals(kEntityKey)) {
499 annotation.entity = ReadClassificationResult(entity_data_schema);
500 } else {
501 TC3_LOG(ERROR) << "Unknown annotation field: " << key;
502 }
503 lua_pop(state_, 1);
504 }
505 return annotation;
506 }
507
ReadClassificationResult(const reflection::Schema * entity_data_schema) const508 ClassificationResult LuaEnvironment::ReadClassificationResult(
509 const reflection::Schema* entity_data_schema) const {
510 ClassificationResult classification;
511 lua_pushnil(state_);
512 while (Next(/*index=*/kIndexStackTop - 1)) {
513 const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
514 if (key.Equals(kCollectionKey)) {
515 classification.collection = Read<std::string>(/*index=*/kIndexStackTop);
516 } else if (key.Equals(kScoreKey)) {
517 classification.score = Read<float>(/*index=*/kIndexStackTop);
518 } else if (key.Equals(kTimeUsecKey)) {
519 classification.datetime_parse_result.time_ms_utc =
520 Read<int64>(/*index=*/kIndexStackTop);
521 } else if (key.Equals(kGranularityKey)) {
522 classification.datetime_parse_result.granularity =
523 static_cast<DatetimeGranularity>(
524 lua_tonumber(state_, /*idx=*/kIndexStackTop));
525 } else if (key.Equals(kSerializedEntity)) {
526 classification.serialized_entity_data =
527 Read<std::string>(/*index=*/kIndexStackTop);
528 } else if (key.Equals(kEntityKey)) {
529 auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
530 ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
531 classification.serialized_entity_data = buffer->Serialize();
532 } else {
533 TC3_LOG(INFO) << "Unknown classification result field: " << key;
534 }
535 lua_pop(state_, 1);
536 }
537 return classification;
538 }
539
PushAction(const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const540 void LuaEnvironment::PushAction(
541 const ActionSuggestion& action,
542 const reflection::Schema* actions_entity_data_schema,
543 const reflection::Schema* annotations_entity_data_schema) const {
544 if (actions_entity_data_schema == nullptr ||
545 action.serialized_entity_data.empty()) {
546 // Empty table.
547 lua_newtable(state_);
548 } else {
549 PushFlatbuffer(actions_entity_data_schema,
550 flatbuffers::GetRoot<flatbuffers::Table>(
551 action.serialized_entity_data.data()));
552 }
553 PushString(action.type);
554 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTypeKey);
555 PushString(action.response_text);
556 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kResponseTextKey);
557 Push(action.score);
558 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
559 Push(action.priority_score);
560 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kPriorityScoreKey);
561 PushAnnotations(&action.annotations, annotations_entity_data_schema);
562 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kAnnotationKey);
563 }
564
PushActions(const std::vector<ActionSuggestion> * actions,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const565 void LuaEnvironment::PushActions(
566 const std::vector<ActionSuggestion>* actions,
567 const reflection::Schema* actions_entity_data_schema,
568 const reflection::Schema* annotations_entity_data_schema) const {
569 PushIterator(actions ? actions->size() : 0,
570 [this, actions, actions_entity_data_schema,
571 annotations_entity_data_schema](const int64 index) {
572 PushAction(actions->at(index), actions_entity_data_schema,
573 annotations_entity_data_schema);
574 return 1;
575 });
576 }
577
ReadAction(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const578 ActionSuggestion LuaEnvironment::ReadAction(
579 const reflection::Schema* actions_entity_data_schema,
580 const reflection::Schema* annotations_entity_data_schema) const {
581 ActionSuggestion action;
582 lua_pushnil(state_);
583 while (Next(/*index=*/kIndexStackTop - 1)) {
584 const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
585 if (key.Equals(kResponseTextKey)) {
586 action.response_text = Read<std::string>(/*index=*/kIndexStackTop);
587 } else if (key.Equals(kTypeKey)) {
588 action.type = Read<std::string>(/*index=*/kIndexStackTop);
589 } else if (key.Equals(kScoreKey)) {
590 action.score = Read<float>(/*index=*/kIndexStackTop);
591 } else if (key.Equals(kPriorityScoreKey)) {
592 action.priority_score = Read<float>(/*index=*/kIndexStackTop);
593 } else if (key.Equals(kAnnotationKey)) {
594 ReadAnnotations(actions_entity_data_schema, &action.annotations);
595 } else if (key.Equals(kEntityKey)) {
596 auto buffer =
597 MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
598 ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
599 action.serialized_entity_data = buffer->Serialize();
600 } else {
601 TC3_LOG(INFO) << "Unknown action field: " << key;
602 }
603 lua_pop(state_, 1);
604 }
605 return action;
606 }
607
ReadActions(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,std::vector<ActionSuggestion> * actions) const608 int LuaEnvironment::ReadActions(
609 const reflection::Schema* actions_entity_data_schema,
610 const reflection::Schema* annotations_entity_data_schema,
611 std::vector<ActionSuggestion>* actions) const {
612 // Read actions.
613 lua_pushnil(state_);
614 while (Next(/*index=*/kIndexStackTop - 1)) {
615 if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
616 TC3_LOG(ERROR) << "Expected action table, got: "
617 << lua_type(state_, /*idx=*/kIndexStackTop);
618 lua_pop(state_, 1);
619 continue;
620 }
621 actions->push_back(
622 ReadAction(actions_entity_data_schema, annotations_entity_data_schema));
623 lua_pop(state_, /*n=*/1);
624 }
625 lua_pop(state_, /*n=*/1);
626
627 return LUA_OK;
628 }
629
PushConversation(const std::vector<ConversationMessage> * conversation,const reflection::Schema * annotations_entity_data_schema) const630 void LuaEnvironment::PushConversation(
631 const std::vector<ConversationMessage>* conversation,
632 const reflection::Schema* annotations_entity_data_schema) const {
633 PushIterator(
634 conversation ? conversation->size() : 0,
635 [this, conversation, annotations_entity_data_schema](const int64 index) {
636 const ConversationMessage& message = conversation->at(index);
637 lua_newtable(state_);
638 Push(message.user_id);
639 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "user_id");
640 Push(message.text);
641 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "text");
642 Push(message.reference_time_ms_utc);
643 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "time_ms_utc");
644 Push(message.reference_timezone);
645 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "timezone");
646 PushAnnotatedSpans(&message.annotations,
647 annotations_entity_data_schema);
648 lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "annotation");
649 return 1;
650 });
651 }
652
Compile(StringPiece snippet,std::string * bytecode)653 bool Compile(StringPiece snippet, std::string* bytecode) {
654 return LuaEnvironment().Compile(snippet, bytecode);
655 }
656
657 } // namespace libtextclassifier3
658