/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ #include #include "actions/types.h" #include "annotator/types.h" #include "utils/flatbuffers/mutable.h" #include "utils/strings/stringpiece.h" #include "utils/variant.h" #include "flatbuffers/reflection_generated.h" #ifdef __cplusplus extern "C" { #endif #include "lauxlib.h" #include "lua.h" #include "lualib.h" #ifdef __cplusplus } #endif namespace libtextclassifier3 { static constexpr const char kLengthKey[] = "__len"; static constexpr const char kPairsKey[] = "__pairs"; static constexpr const char kIndexKey[] = "__index"; static constexpr const char kGcKey[] = "__gc"; static constexpr const char kNextKey[] = "__next"; static constexpr const int kIndexStackTop = -1; // Casts to the lua user data type. template void* AsUserData(const T* value) { return static_cast(const_cast(value)); } template void* AsUserData(const T value) { return reinterpret_cast(value); } // Retrieves up-values. template T FromUpValue(const int index, lua_State* state) { return static_cast(lua_touserdata(state, lua_upvalueindex(index))); } class LuaEnvironment { public: virtual ~LuaEnvironment(); explicit LuaEnvironment(); // Compile a lua snippet into binary bytecode. // NOTE: The compiled bytecode might not be compatible across Lua versions // and platforms. bool Compile(StringPiece snippet, std::string* bytecode) const; // Loads default libraries. void LoadDefaultLibraries(); // Provides a callback to Lua. template void PushFunction(int (T::*handler)()) { PushFunction(std::bind(handler, static_cast(this))); } template void PushFunction(const F& func) const { // Copy closure to the lua stack. new (lua_newuserdata(state_, sizeof(func))) F(func); // Register garbage collection callback. lua_newtable(state_); lua_pushcfunction(state_, &ReleaseFunction); lua_setfield(state_, -2, kGcKey); lua_setmetatable(state_, -2); // Push dispatch. lua_pushcclosure(state_, &CallFunction, 1); } // Sets up a named table that calls back whenever a member is accessed. // This allows to lazily provide required information to the script. template void PushLazyObject(int (T::*handler)()) { PushLazyObject(std::bind(handler, static_cast(this))); } template void PushLazyObject(const F& func) const { lua_newtable(state_); lua_newtable(state_); PushFunction(func); lua_setfield(state_, -2, kIndexKey); lua_setmetatable(state_, -2); } void Push(const int64 value) const { lua_pushinteger(state_, value); } void Push(const uint64 value) const { lua_pushinteger(state_, value); } void Push(const int32 value) const { lua_pushinteger(state_, value); } void Push(const uint32 value) const { lua_pushinteger(state_, value); } void Push(const int16 value) const { lua_pushinteger(state_, value); } void Push(const uint16 value) const { lua_pushinteger(state_, value); } void Push(const int8 value) const { lua_pushinteger(state_, value); } void Push(const uint8 value) const { lua_pushinteger(state_, value); } void Push(const float value) const { lua_pushnumber(state_, value); } void Push(const double value) const { lua_pushnumber(state_, value); } void Push(const bool value) const { lua_pushboolean(state_, value); } void Push(const StringPiece value) const { PushString(value); } void Push(const flatbuffers::String* value) const { if (value == nullptr) { PushString(""); } else { PushString(StringPiece(value->c_str(), value->size())); } } template T Read(const int index = -1) const; template <> int64 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> uint64 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> int32 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> uint32 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> int16 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> uint16 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> int8 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> uint8 Read(const int index) const { return static_cast(lua_tointeger(state_, /*idx=*/index)); } template <> float Read(const int index) const { return static_cast(lua_tonumber(state_, /*idx=*/index)); } template <> double Read(const int index) const { return static_cast(lua_tonumber(state_, /*idx=*/index)); } template <> bool Read(const int index) const { return lua_toboolean(state_, /*idx=*/index); } template <> StringPiece Read(const int index) const { return ReadString(index); } template <> std::string Read(const int index) const { return ReadString(index).ToString(); } // Reads a string from the stack. StringPiece ReadString(int index) const; // Pushes a string to the stack. void PushString(const StringPiece str) const; // Pushes a flatbuffer to the stack. void PushFlatbuffer(const reflection::Schema* schema, const flatbuffers::Table* table) const { PushFlatbuffer(schema, schema->root_table(), table); } // Reads a flatbuffer from the stack. int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const; // Pushes an iterator. template void PushIterator(const int length, const ItemCallback& item_callback, const KeyCallback& key_callback) const { lua_newtable(state_); CreateIteratorMetatable(length, item_callback); PushFunction([this, length, item_callback, key_callback]() { return Iterator::Dispatch(this, length, item_callback, key_callback); }); lua_setfield(state_, -2, kIndexKey); lua_setmetatable(state_, -2); } template void PushIterator(const int length, const ItemCallback& item_callback) const { lua_newtable(state_); CreateIteratorMetatable(length, item_callback); PushFunction([this, length, item_callback]() { return Iterator::Dispatch(this, length, item_callback); }); lua_setfield(state_, -2, kIndexKey); lua_setmetatable(state_, -2); } template void CreateIteratorMetatable(const int length, const ItemCallback& item_callback) const { lua_newtable(state_); PushFunction([this, length]() { return Iterator::Length(this, length); }); lua_setfield(state_, -2, kLengthKey); PushFunction([this, length, item_callback]() { return Iterator::IterItems(this, length, item_callback); }); lua_setfield(state_, -2, kPairsKey); PushFunction([this, length, item_callback]() { return Iterator::Next(this, length, item_callback); }); lua_setfield(state_, -2, kNextKey); } template void PushVectorIterator(const std::vector* items) const { PushIterator(items ? items->size() : 0, [this, items](const int64 pos) { this->Push(items->at(pos)); return 1; }); } template void PushVector(const std::vector& items) const { lua_newtable(state_); for (int i = 0; i < items.size(); i++) { // Key: index, 1-based. Push(i + 1); // Value. Push(items[i]); lua_settable(state_, /*idx=*/-3); } } void PushEmptyVector() const { lua_newtable(state_); } template std::vector ReadVector(const int index = -1) const { std::vector result; if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected a table, got: " << lua_type(state_, /*idx=*/kIndexStackTop); lua_pop(state_, 1); return {}; } lua_pushnil(state_); while (Next(index - 1)) { result.push_back(Read(/*index=*/kIndexStackTop)); lua_pop(state_, 1); } return result; } // Runs a closure in protected mode. // `func`: closure to run in protected mode. // `num_lua_args`: number of arguments from the lua stack to process. // `num_results`: number of result values pushed on the stack. template int RunProtected(const F& func, const int num_args = 0, const int num_results = 0) const { PushFunction(func); // Put the closure before the arguments on the stack. if (num_args > 0) { lua_insert(state_, -(1 + num_args)); } return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0); } // Auxiliary methods to handle model results. // Provides an annotation to lua. void PushAnnotation(const ClassificationResult& classification, const reflection::Schema* entity_data_schema) const; void PushAnnotation(const ClassificationResult& classification, StringPiece text, const reflection::Schema* entity_data_schema) const; void PushAnnotation(const ActionSuggestionAnnotation& annotation, const reflection::Schema* entity_data_schema) const; template void PushAnnotations(const std::vector* annotations, const reflection::Schema* entity_data_schema) const { PushIterator( annotations ? annotations->size() : 0, [this, annotations, entity_data_schema](const int64 index) { PushAnnotation(annotations->at(index), entity_data_schema); return 1; }, [this, annotations, entity_data_schema](StringPiece name) { if (const Annotation* annotation = GetAnnotationByName(*annotations, name)) { PushAnnotation(*annotation, entity_data_schema); return 1; } else { return 0; } }); } // Pushes a span to the lua stack. void PushAnnotatedSpan(const AnnotatedSpan& annotated_span, const reflection::Schema* entity_data_schema) const; void PushAnnotatedSpans(const std::vector* annotated_spans, const reflection::Schema* entity_data_schema) const; // Reads a message text span from lua. MessageTextSpan ReadSpan() const; ActionSuggestionAnnotation ReadAnnotation( const reflection::Schema* entity_data_schema) const; int ReadAnnotations( const reflection::Schema* entity_data_schema, std::vector* annotations) const; ClassificationResult ReadClassificationResult( const reflection::Schema* entity_data_schema) const; // Provides an action to lua. void PushAction( const ActionSuggestion& action, const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema) const; void PushActions( const std::vector* actions, const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema) const; ActionSuggestion ReadAction( const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema) const; int ReadActions(const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema, std::vector* actions) const; // Conversation message iterator. void PushConversation( const std::vector* conversation, const reflection::Schema* annotations_entity_data_schema) const; lua_State* state() const { return state_; } protected: // Wrapper for handling iteration over containers. class Iterator { public: // Starts a new key-value pair iterator. template static int IterItems(const LuaEnvironment* env, const int length, const ItemCallback& callback) { env->PushFunction([env, callback, length, pos = 0]() mutable { if (pos >= length) { lua_pushnil(env->state()); return 1; } // Push key. lua_pushinteger(env->state(), pos + 1); // Push item. return 1 + callback(pos++); }); return 1; // Num. results. } // Gets the next element. template static int Next(const LuaEnvironment* env, const int length, const ItemCallback& item_callback) { int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop) ? 0 : env->Read(/*index=*/kIndexStackTop); if (pos < length) { // Push next key. lua_pushinteger(env->state(), pos + 1); // Push item. return 1 + item_callback(pos); } else { lua_pushnil(env->state()); return 1; } } // Returns the length of the container the iterator processes. static int Length(const LuaEnvironment* env, const int length) { lua_pushinteger(env->state(), length); return 1; // Num. results. } // Handles item queries to the iterator. // Elements of the container can either be queried by name or index. // Dispatch will check how an element is accessed and // calls `key_callback` for access by name and `item_callback` for access by // index. template static int Dispatch(const LuaEnvironment* env, const int length, const ItemCallback& item_callback, const KeyCallback& key_callback) { switch (lua_type(env->state(), kIndexStackTop)) { case LUA_TNUMBER: { // Lua is one based, so adjust the index here. const int64 index = env->Read(/*index=*/kIndexStackTop) - 1; if (index < 0 || index >= length) { TC3_LOG(ERROR) << "Invalid index: " << index; lua_error(env->state()); return 0; } return item_callback(index); } case LUA_TSTRING: { return key_callback(env->ReadString(kIndexStackTop)); } default: TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(env->state(), kIndexStackTop); lua_error(env->state()); return 0; } } template static int Dispatch(const LuaEnvironment* env, const int length, const ItemCallback& item_callback) { switch (lua_type(env->state(), kIndexStackTop)) { case LUA_TNUMBER: { // Lua is one based, so adjust the index here. const int64 index = env->Read(/*index=*/kIndexStackTop) - 1; if (index < 0 || index >= length) { TC3_LOG(ERROR) << "Invalid index: " << index; lua_error(env->state()); return 0; } return item_callback(index); } default: TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(env->state(), kIndexStackTop); lua_error(env->state()); return 0; } } }; // Calls the deconstructor from a previously pushed function. template static int ReleaseFunction(lua_State* state) { static_cast(lua_touserdata(state, 1))->~T(); return 0; } template static int CallFunction(lua_State* state) { return (*static_cast(lua_touserdata(state, lua_upvalueindex(1))))(); } // Auxiliary methods to expose (reflective) flatbuffer based data to Lua. void PushFlatbuffer(const reflection::Schema* schema, const reflection::Object* type, const flatbuffers::Table* table) const; int GetField(const reflection::Schema* schema, const reflection::Object* type, const flatbuffers::Table* table) const; // Reads a repeated field from lua. template void ReadRepeatedField(const int index, RepeatedField* result) const { for (const T& element : ReadVector(index)) { result->Add(element); } } template <> void ReadRepeatedField(const int index, RepeatedField* result) const { lua_pushnil(state_); while (Next(index - 1)) { ReadFlatbuffer(index, result->Add()); lua_pop(state_, 1); } } // Pushes a repeated field to the lua stack. template void PushRepeatedField(const flatbuffers::Vector* items) const { PushIterator(items ? items->size() : 0, [this, items](const int64 pos) { Push(items->Get(pos)); return 1; // Num. results. }); } void PushRepeatedFlatbufferField( const reflection::Schema* schema, const reflection::Object* type, const flatbuffers::Vector>* items) const { PushIterator(items ? items->size() : 0, [this, schema, type, items](const int64 pos) { PushFlatbuffer(schema, type, items->Get(pos)); return 1; // Num. results. }); } // Overloads Lua next function to use __next key on the metatable. // This allows us to treat lua objects and lazy objects provided by our // callbacks uniformly. int Next(int index) const { // Check whether the (meta)table of this object has an associated "__next" // entry. This means, we registered our own callback. So we explicitly call // that. if (luaL_getmetafield(state_, index, kNextKey)) { // Callback is now on top of the stack, so adjust relative indices by 1. if (index < 0) { index--; } // Copy the reference to the table. lua_pushvalue(state_, index); // Move the key to top to have it as second argument for the callback. // Copy the key to the top. lua_pushvalue(state_, -3); // Remove the copy of the key. lua_remove(state_, -4); // Call the callback with (key and table as arguments). lua_pcall(state_, /*nargs=*/2 /* table, key */, /*nresults=*/2 /* key, item */, 0); // Next returned nil, it's the end. if (lua_isnil(state_, kIndexStackTop)) { // Remove nil value. // Results will be padded to `nresults` specified above, so we need // to remove two elements here. lua_pop(state_, 2); return 0; } return 2; // Num. results. } else if (lua_istable(state_, index)) { return lua_next(state_, index); } // Remove the key. lua_pop(state_, 1); return 0; } static const ClassificationResult* GetAnnotationByName( const std::vector& annotations, StringPiece name) { // Lookup annotation by collection. for (const ClassificationResult& annotation : annotations) { if (name.Equals(annotation.collection)) { return &annotation; } } TC3_LOG(ERROR) << "No annotation with collection: " << name << " found."; return nullptr; } static const ActionSuggestionAnnotation* GetAnnotationByName( const std::vector& annotations, StringPiece name) { // Lookup annotation by name. for (const ActionSuggestionAnnotation& annotation : annotations) { if (name.Equals(annotation.name)) { return &annotation; } } TC3_LOG(ERROR) << "No annotation with name: " << name << " found."; return nullptr; } lua_State* state_; }; // namespace libtextclassifier3 bool Compile(StringPiece snippet, std::string* bytecode); } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_