1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
18 #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
19
#include <functional>
#include <new>
#include <string>
#include <vector>

#include "actions/types.h"
#include "annotator/types.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
28
29 #ifdef __cplusplus
30 extern "C" {
31 #endif
32 #include "lauxlib.h"
33 #include "lua.h"
34 #include "lualib.h"
35 #ifdef __cplusplus
36 }
37 #endif
38
39 namespace libtextclassifier3 {
40
// Metatable keys. `__len`, `__pairs`, `__index` and `__gc` are Lua metamethod
// names; `__next` is a custom key that LuaEnvironment::Next() looks up to let
// lazily provided objects drive their own iteration.
// (Namespace-scope constexpr variables are implicitly const and therefore
// have internal linkage, so no `static` is needed.)
constexpr char kLengthKey[] = "__len";
constexpr char kPairsKey[] = "__pairs";
constexpr char kIndexKey[] = "__index";
constexpr char kGcKey[] = "__gc";
constexpr char kNextKey[] = "__next";

// Relative Lua stack index that addresses the top-most element.
constexpr int kIndexStackTop = -1;
48
49 // Casts to the lua user data type.
50 template <typename T>
AsUserData(const T * value)51 void* AsUserData(const T* value) {
52 return static_cast<void*>(const_cast<T*>(value));
53 }
54 template <typename T>
AsUserData(const T value)55 void* AsUserData(const T value) {
56 return reinterpret_cast<void*>(value);
57 }
58
59 // Retrieves up-values.
60 template <typename T>
FromUpValue(const int index,lua_State * state)61 T FromUpValue(const int index, lua_State* state) {
62 return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
63 }
64
65 class LuaEnvironment {
66 public:
67 virtual ~LuaEnvironment();
68 explicit LuaEnvironment();
69
70 // Compile a lua snippet into binary bytecode.
71 // NOTE: The compiled bytecode might not be compatible across Lua versions
72 // and platforms.
73 bool Compile(StringPiece snippet, std::string* bytecode) const;
74
75 // Loads default libraries.
76 void LoadDefaultLibraries();
77
78 // Provides a callback to Lua.
79 template <typename T>
PushFunction(int (T::* handler)())80 void PushFunction(int (T::*handler)()) {
81 PushFunction(std::bind(handler, static_cast<T*>(this)));
82 }
83
84 template <typename F>
PushFunction(const F & func)85 void PushFunction(const F& func) const {
86 // Copy closure to the lua stack.
87 new (lua_newuserdata(state_, sizeof(func))) F(func);
88
89 // Register garbage collection callback.
90 lua_newtable(state_);
91 lua_pushcfunction(state_, &ReleaseFunction<F>);
92 lua_setfield(state_, -2, kGcKey);
93 lua_setmetatable(state_, -2);
94
95 // Push dispatch.
96 lua_pushcclosure(state_, &CallFunction<F>, 1);
97 }
98
99 // Sets up a named table that calls back whenever a member is accessed.
100 // This allows to lazily provide required information to the script.
101 template <typename T>
PushLazyObject(int (T::* handler)())102 void PushLazyObject(int (T::*handler)()) {
103 PushLazyObject(std::bind(handler, static_cast<T*>(this)));
104 }
105
106 template <typename F>
PushLazyObject(const F & func)107 void PushLazyObject(const F& func) const {
108 lua_newtable(state_);
109 lua_newtable(state_);
110 PushFunction(func);
111 lua_setfield(state_, -2, kIndexKey);
112 lua_setmetatable(state_, -2);
113 }
114
Push(const int64 value)115 void Push(const int64 value) const { lua_pushinteger(state_, value); }
Push(const uint64 value)116 void Push(const uint64 value) const { lua_pushinteger(state_, value); }
Push(const int32 value)117 void Push(const int32 value) const { lua_pushinteger(state_, value); }
Push(const uint32 value)118 void Push(const uint32 value) const { lua_pushinteger(state_, value); }
Push(const int16 value)119 void Push(const int16 value) const { lua_pushinteger(state_, value); }
Push(const uint16 value)120 void Push(const uint16 value) const { lua_pushinteger(state_, value); }
Push(const int8 value)121 void Push(const int8 value) const { lua_pushinteger(state_, value); }
Push(const uint8 value)122 void Push(const uint8 value) const { lua_pushinteger(state_, value); }
Push(const float value)123 void Push(const float value) const { lua_pushnumber(state_, value); }
Push(const double value)124 void Push(const double value) const { lua_pushnumber(state_, value); }
Push(const bool value)125 void Push(const bool value) const { lua_pushboolean(state_, value); }
Push(const StringPiece value)126 void Push(const StringPiece value) const { PushString(value); }
Push(const flatbuffers::String * value)127 void Push(const flatbuffers::String* value) const {
128 if (value == nullptr) {
129 PushString("");
130 } else {
131 PushString(StringPiece(value->c_str(), value->size()));
132 }
133 }
134
135 template <typename T>
136 T Read(const int index = -1) const;
137
138 template <>
139 int64 Read<int64>(const int index) const {
140 return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
141 }
142
143 template <>
144 uint64 Read<uint64>(const int index) const {
145 return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
146 }
147
148 template <>
149 int32 Read<int32>(const int index) const {
150 return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
151 }
152
153 template <>
154 uint32 Read<uint32>(const int index) const {
155 return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
156 }
157
158 template <>
159 int16 Read<int16>(const int index) const {
160 return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
161 }
162
163 template <>
164 uint16 Read<uint16>(const int index) const {
165 return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
166 }
167
168 template <>
169 int8 Read<int8>(const int index) const {
170 return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
171 }
172
173 template <>
174 uint8 Read<uint8>(const int index) const {
175 return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
176 }
177
178 template <>
179 float Read<float>(const int index) const {
180 return static_cast<float>(lua_tonumber(state_, /*idx=*/index));
181 }
182
183 template <>
184 double Read<double>(const int index) const {
185 return static_cast<double>(lua_tonumber(state_, /*idx=*/index));
186 }
187
188 template <>
189 bool Read<bool>(const int index) const {
190 return lua_toboolean(state_, /*idx=*/index);
191 }
192
193 template <>
194 StringPiece Read<StringPiece>(const int index) const {
195 return ReadString(index);
196 }
197
198 template <>
199 std::string Read<std::string>(const int index) const {
200 return ReadString(index).ToString();
201 }
202
203 // Reads a string from the stack.
204 StringPiece ReadString(int index) const;
205
206 // Pushes a string to the stack.
207 void PushString(const StringPiece str) const;
208
209 // Pushes a flatbuffer to the stack.
PushFlatbuffer(const reflection::Schema * schema,const flatbuffers::Table * table)210 void PushFlatbuffer(const reflection::Schema* schema,
211 const flatbuffers::Table* table) const {
212 PushFlatbuffer(schema, schema->root_table(), table);
213 }
214
215 // Reads a flatbuffer from the stack.
216 int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;
217
218 // Pushes an iterator.
219 template <typename ItemCallback, typename KeyCallback>
PushIterator(const int length,const ItemCallback & item_callback,const KeyCallback & key_callback)220 void PushIterator(const int length, const ItemCallback& item_callback,
221 const KeyCallback& key_callback) const {
222 lua_newtable(state_);
223 CreateIteratorMetatable(length, item_callback);
224 PushFunction([this, length, item_callback, key_callback]() {
225 return Iterator::Dispatch(this, length, item_callback, key_callback);
226 });
227 lua_setfield(state_, -2, kIndexKey);
228 lua_setmetatable(state_, -2);
229 }
230
231 template <typename ItemCallback>
PushIterator(const int length,const ItemCallback & item_callback)232 void PushIterator(const int length, const ItemCallback& item_callback) const {
233 lua_newtable(state_);
234 CreateIteratorMetatable(length, item_callback);
235 PushFunction([this, length, item_callback]() {
236 return Iterator::Dispatch(this, length, item_callback);
237 });
238 lua_setfield(state_, -2, kIndexKey);
239 lua_setmetatable(state_, -2);
240 }
241
242 template <typename ItemCallback>
CreateIteratorMetatable(const int length,const ItemCallback & item_callback)243 void CreateIteratorMetatable(const int length,
244 const ItemCallback& item_callback) const {
245 lua_newtable(state_);
246 PushFunction([this, length]() { return Iterator::Length(this, length); });
247 lua_setfield(state_, -2, kLengthKey);
248 PushFunction([this, length, item_callback]() {
249 return Iterator::IterItems(this, length, item_callback);
250 });
251 lua_setfield(state_, -2, kPairsKey);
252 PushFunction([this, length, item_callback]() {
253 return Iterator::Next(this, length, item_callback);
254 });
255 lua_setfield(state_, -2, kNextKey);
256 }
257
258 template <typename T>
PushVectorIterator(const std::vector<T> * items)259 void PushVectorIterator(const std::vector<T>* items) const {
260 PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
261 this->Push(items->at(pos));
262 return 1;
263 });
264 }
265
266 template <typename T>
PushVector(const std::vector<T> & items)267 void PushVector(const std::vector<T>& items) const {
268 lua_newtable(state_);
269 for (int i = 0; i < items.size(); i++) {
270 // Key: index, 1-based.
271 Push(i + 1);
272
273 // Value.
274 Push(items[i]);
275 lua_settable(state_, /*idx=*/-3);
276 }
277 }
278
PushEmptyVector()279 void PushEmptyVector() const { lua_newtable(state_); }
280
281 template <typename T>
282 std::vector<T> ReadVector(const int index = -1) const {
283 std::vector<T> result;
284 if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
285 TC3_LOG(ERROR) << "Expected a table, got: "
286 << lua_type(state_, /*idx=*/kIndexStackTop);
287 lua_pop(state_, 1);
288 return {};
289 }
290 lua_pushnil(state_);
291 while (Next(index - 1)) {
292 result.push_back(Read<T>(/*index=*/kIndexStackTop));
293 lua_pop(state_, 1);
294 }
295 return result;
296 }
297
298 // Runs a closure in protected mode.
299 // `func`: closure to run in protected mode.
300 // `num_lua_args`: number of arguments from the lua stack to process.
301 // `num_results`: number of result values pushed on the stack.
302 template <typename F>
303 int RunProtected(const F& func, const int num_args = 0,
304 const int num_results = 0) const {
305 PushFunction(func);
306 // Put the closure before the arguments on the stack.
307 if (num_args > 0) {
308 lua_insert(state_, -(1 + num_args));
309 }
310 return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
311 }
312
313 // Auxiliary methods to handle model results.
314 // Provides an annotation to lua.
315 void PushAnnotation(const ClassificationResult& classification,
316 const reflection::Schema* entity_data_schema) const;
317 void PushAnnotation(const ClassificationResult& classification,
318 StringPiece text,
319 const reflection::Schema* entity_data_schema) const;
320 void PushAnnotation(const ActionSuggestionAnnotation& annotation,
321 const reflection::Schema* entity_data_schema) const;
322
323 template <typename Annotation>
PushAnnotations(const std::vector<Annotation> * annotations,const reflection::Schema * entity_data_schema)324 void PushAnnotations(const std::vector<Annotation>* annotations,
325 const reflection::Schema* entity_data_schema) const {
326 PushIterator(
327 annotations ? annotations->size() : 0,
328 [this, annotations, entity_data_schema](const int64 index) {
329 PushAnnotation(annotations->at(index), entity_data_schema);
330 return 1;
331 },
332 [this, annotations, entity_data_schema](StringPiece name) {
333 if (const Annotation* annotation =
334 GetAnnotationByName(*annotations, name)) {
335 PushAnnotation(*annotation, entity_data_schema);
336 return 1;
337 } else {
338 return 0;
339 }
340 });
341 }
342
343 // Pushes a span to the lua stack.
344 void PushAnnotatedSpan(const AnnotatedSpan& annotated_span,
345 const reflection::Schema* entity_data_schema) const;
346 void PushAnnotatedSpans(const std::vector<AnnotatedSpan>* annotated_spans,
347 const reflection::Schema* entity_data_schema) const;
348
349 // Reads a message text span from lua.
350 MessageTextSpan ReadSpan() const;
351
352 ActionSuggestionAnnotation ReadAnnotation(
353 const reflection::Schema* entity_data_schema) const;
354 int ReadAnnotations(
355 const reflection::Schema* entity_data_schema,
356 std::vector<ActionSuggestionAnnotation>* annotations) const;
357 ClassificationResult ReadClassificationResult(
358 const reflection::Schema* entity_data_schema) const;
359
360 // Provides an action to lua.
361 void PushAction(
362 const ActionSuggestion& action,
363 const reflection::Schema* actions_entity_data_schema,
364 const reflection::Schema* annotations_entity_data_schema) const;
365
366 void PushActions(
367 const std::vector<ActionSuggestion>* actions,
368 const reflection::Schema* actions_entity_data_schema,
369 const reflection::Schema* annotations_entity_data_schema) const;
370
371 ActionSuggestion ReadAction(
372 const reflection::Schema* actions_entity_data_schema,
373 const reflection::Schema* annotations_entity_data_schema) const;
374
375 int ReadActions(const reflection::Schema* actions_entity_data_schema,
376 const reflection::Schema* annotations_entity_data_schema,
377 std::vector<ActionSuggestion>* actions) const;
378
379 // Conversation message iterator.
380 void PushConversation(
381 const std::vector<ConversationMessage>* conversation,
382 const reflection::Schema* annotations_entity_data_schema) const;
383
state()384 lua_State* state() const { return state_; }
385
386 protected:
387 // Wrapper for handling iteration over containers.
388 class Iterator {
389 public:
390 // Starts a new key-value pair iterator.
391 template <typename ItemCallback>
IterItems(const LuaEnvironment * env,const int length,const ItemCallback & callback)392 static int IterItems(const LuaEnvironment* env, const int length,
393 const ItemCallback& callback) {
394 env->PushFunction([env, callback, length, pos = 0]() mutable {
395 if (pos >= length) {
396 lua_pushnil(env->state());
397 return 1;
398 }
399
400 // Push key.
401 lua_pushinteger(env->state(), pos + 1);
402
403 // Push item.
404 return 1 + callback(pos++);
405 });
406 return 1; // Num. results.
407 }
408
409 // Gets the next element.
410 template <typename ItemCallback>
Next(const LuaEnvironment * env,const int length,const ItemCallback & item_callback)411 static int Next(const LuaEnvironment* env, const int length,
412 const ItemCallback& item_callback) {
413 int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop)
414 ? 0
415 : env->Read<int64>(/*index=*/kIndexStackTop);
416 if (pos < length) {
417 // Push next key.
418 lua_pushinteger(env->state(), pos + 1);
419
420 // Push item.
421 return 1 + item_callback(pos);
422 } else {
423 lua_pushnil(env->state());
424 return 1;
425 }
426 }
427
428 // Returns the length of the container the iterator processes.
Length(const LuaEnvironment * env,const int length)429 static int Length(const LuaEnvironment* env, const int length) {
430 lua_pushinteger(env->state(), length);
431 return 1; // Num. results.
432 }
433
434 // Handles item queries to the iterator.
435 // Elements of the container can either be queried by name or index.
436 // Dispatch will check how an element is accessed and
437 // calls `key_callback` for access by name and `item_callback` for access by
438 // index.
439 template <typename ItemCallback, typename KeyCallback>
Dispatch(const LuaEnvironment * env,const int length,const ItemCallback & item_callback,const KeyCallback & key_callback)440 static int Dispatch(const LuaEnvironment* env, const int length,
441 const ItemCallback& item_callback,
442 const KeyCallback& key_callback) {
443 switch (lua_type(env->state(), kIndexStackTop)) {
444 case LUA_TNUMBER: {
445 // Lua is one based, so adjust the index here.
446 const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
447 if (index < 0 || index >= length) {
448 TC3_LOG(ERROR) << "Invalid index: " << index;
449 lua_error(env->state());
450 return 0;
451 }
452 return item_callback(index);
453 }
454 case LUA_TSTRING: {
455 return key_callback(env->ReadString(kIndexStackTop));
456 }
457 default:
458 TC3_LOG(ERROR) << "Unexpected access type: "
459 << lua_type(env->state(), kIndexStackTop);
460 lua_error(env->state());
461 return 0;
462 }
463 }
464
465 template <typename ItemCallback>
Dispatch(const LuaEnvironment * env,const int length,const ItemCallback & item_callback)466 static int Dispatch(const LuaEnvironment* env, const int length,
467 const ItemCallback& item_callback) {
468 switch (lua_type(env->state(), kIndexStackTop)) {
469 case LUA_TNUMBER: {
470 // Lua is one based, so adjust the index here.
471 const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
472 if (index < 0 || index >= length) {
473 TC3_LOG(ERROR) << "Invalid index: " << index;
474 lua_error(env->state());
475 return 0;
476 }
477 return item_callback(index);
478 }
479 default:
480 TC3_LOG(ERROR) << "Unexpected access type: "
481 << lua_type(env->state(), kIndexStackTop);
482 lua_error(env->state());
483 return 0;
484 }
485 }
486 };
487
488 // Calls the deconstructor from a previously pushed function.
489 template <typename T>
ReleaseFunction(lua_State * state)490 static int ReleaseFunction(lua_State* state) {
491 static_cast<T*>(lua_touserdata(state, 1))->~T();
492 return 0;
493 }
494
495 template <typename T>
CallFunction(lua_State * state)496 static int CallFunction(lua_State* state) {
497 return (*static_cast<T*>(lua_touserdata(state, lua_upvalueindex(1))))();
498 }
499
500 // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
501 void PushFlatbuffer(const reflection::Schema* schema,
502 const reflection::Object* type,
503 const flatbuffers::Table* table) const;
504 int GetField(const reflection::Schema* schema, const reflection::Object* type,
505 const flatbuffers::Table* table) const;
506
507 // Reads a repeated field from lua.
508 template <typename T>
ReadRepeatedField(const int index,RepeatedField * result)509 void ReadRepeatedField(const int index, RepeatedField* result) const {
510 for (const T& element : ReadVector<T>(index)) {
511 result->Add(element);
512 }
513 }
514
515 template <>
516 void ReadRepeatedField<MutableFlatbuffer>(const int index,
517 RepeatedField* result) const {
518 lua_pushnil(state_);
519 while (Next(index - 1)) {
520 ReadFlatbuffer(index, result->Add());
521 lua_pop(state_, 1);
522 }
523 }
524
525 // Pushes a repeated field to the lua stack.
526 template <typename T>
PushRepeatedField(const flatbuffers::Vector<T> * items)527 void PushRepeatedField(const flatbuffers::Vector<T>* items) const {
528 PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
529 Push(items->Get(pos));
530 return 1; // Num. results.
531 });
532 }
533
PushRepeatedFlatbufferField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>> * items)534 void PushRepeatedFlatbufferField(
535 const reflection::Schema* schema, const reflection::Object* type,
536 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* items)
537 const {
538 PushIterator(items ? items->size() : 0,
539 [this, schema, type, items](const int64 pos) {
540 PushFlatbuffer(schema, type, items->Get(pos));
541 return 1; // Num. results.
542 });
543 }
544
545 // Overloads Lua next function to use __next key on the metatable.
546 // This allows us to treat lua objects and lazy objects provided by our
547 // callbacks uniformly.
Next(int index)548 int Next(int index) const {
549 // Check whether the (meta)table of this object has an associated "__next"
550 // entry. This means, we registered our own callback. So we explicitly call
551 // that.
552 if (luaL_getmetafield(state_, index, kNextKey)) {
553 // Callback is now on top of the stack, so adjust relative indices by 1.
554 if (index < 0) {
555 index--;
556 }
557
558 // Copy the reference to the table.
559 lua_pushvalue(state_, index);
560
561 // Move the key to top to have it as second argument for the callback.
562 // Copy the key to the top.
563 lua_pushvalue(state_, -3);
564
565 // Remove the copy of the key.
566 lua_remove(state_, -4);
567
568 // Call the callback with (key and table as arguments).
569 lua_pcall(state_, /*nargs=*/2 /* table, key */,
570 /*nresults=*/2 /* key, item */, 0);
571
572 // Next returned nil, it's the end.
573 if (lua_isnil(state_, kIndexStackTop)) {
574 // Remove nil value.
575 // Results will be padded to `nresults` specified above, so we need
576 // to remove two elements here.
577 lua_pop(state_, 2);
578 return 0;
579 }
580
581 return 2; // Num. results.
582 } else if (lua_istable(state_, index)) {
583 return lua_next(state_, index);
584 }
585
586 // Remove the key.
587 lua_pop(state_, 1);
588 return 0;
589 }
590
GetAnnotationByName(const std::vector<ClassificationResult> & annotations,StringPiece name)591 static const ClassificationResult* GetAnnotationByName(
592 const std::vector<ClassificationResult>& annotations, StringPiece name) {
593 // Lookup annotation by collection.
594 for (const ClassificationResult& annotation : annotations) {
595 if (name.Equals(annotation.collection)) {
596 return &annotation;
597 }
598 }
599 TC3_LOG(ERROR) << "No annotation with collection: " << name << " found.";
600 return nullptr;
601 }
602
GetAnnotationByName(const std::vector<ActionSuggestionAnnotation> & annotations,StringPiece name)603 static const ActionSuggestionAnnotation* GetAnnotationByName(
604 const std::vector<ActionSuggestionAnnotation>& annotations,
605 StringPiece name) {
606 // Lookup annotation by name.
607 for (const ActionSuggestionAnnotation& annotation : annotations) {
608 if (name.Equals(annotation.name)) {
609 return &annotation;
610 }
611 }
612 TC3_LOG(ERROR) << "No annotation with name: " << name << " found.";
613 return nullptr;
614 }
615
616 lua_State* state_;
617 }; // namespace libtextclassifier3
618
619 bool Compile(StringPiece snippet, std::string* bytecode);
620
621 } // namespace libtextclassifier3
622
623 #endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
624