/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/lua-utils.h"

// Lua 5.3's lua_dump() takes a fourth "strip" argument that Lua 5.2's does
// not. When TC3_AOSP is not defined (presumably a build against Lua 5.2 --
// confirm), this macro swallows that argument so call sites can always pass
// it.
#ifndef TC3_AOSP
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
#endif

namespace libtextclassifier3 {
namespace {

// Upvalue slots of the field-lookup closure built by
// LuaEnvironment::PushFlatbuffer(): schema, object type and table, in the
// order they are pushed there.
constexpr int kSchemaArgId = 1;
constexpr int kTypeArgId = 2;
constexpr int kTableArgId = 3;

// Libraries opened by LuaEnvironment::LoadDefaultLibraries(). The list ends
// with a {nullptr, nullptr} sentinel; the loader stops at the first entry
// whose function is null.
constexpr luaL_Reg defaultlibs[] = {
    {"_G", luaopen_base},
    {LUA_TABLIBNAME, luaopen_table},
    {LUA_STRLIBNAME, luaopen_string},
    {LUA_BITLIBNAME, luaopen_bit32},
    {LUA_MATHLIBNAME, luaopen_math},
    {nullptr, nullptr},
};

// lua_Writer callback: appends every chunk it receives to the std::string
// passed in as its user-data pointer.
int LuaStringWriter(lua_State *state, const void *data, size_t size, void *result) { std::string *const result_string = static_cast(result); result_string->insert(result_string->size(), static_cast(data), size); return LUA_OK; } } // namespace LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); } LuaEnvironment::~LuaEnvironment() { if (state_ != nullptr) { lua_close(state_); } } int LuaEnvironment::Iterator::NextCallback(lua_State *state) { return FromUpValue(kIteratorArgId, state)->Next(state); } int LuaEnvironment::Iterator::LengthCallback(lua_State *state) { return FromUpValue(kIteratorArgId, state)->Length(state); } int LuaEnvironment::Iterator::ItemCallback(lua_State *state) { return FromUpValue(kIteratorArgId, state)->Item(state); } int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) { return FromUpValue(kIteratorArgId, state)->Iteritems(state); } void LuaEnvironment::PushFlatbuffer(const char *name, const reflection::Schema *schema, const reflection::Object *type, const flatbuffers::Table *table, lua_State *state) { lua_newtable(state); luaL_newmetatable(state, name); lua_pushlightuserdata(state, AsUserData(schema)); lua_pushlightuserdata(state, AsUserData(type)); lua_pushlightuserdata(state, AsUserData(table)); lua_pushcclosure(state, &GetFieldCallback, 3); lua_setfield(state, -2, kIndexKey); lua_setmetatable(state, -2); } int LuaEnvironment::GetFieldCallback(lua_State *state) { // Fetch the arguments. 
const reflection::Schema *schema = FromUpValue(kSchemaArgId, state); const reflection::Object *type = FromUpValue(kTypeArgId, state); const flatbuffers::Table *table = FromUpValue(kTableArgId, state); return GetField(schema, type, table, state); } int LuaEnvironment::GetField(const reflection::Schema *schema, const reflection::Object *type, const flatbuffers::Table *table, lua_State *state) { const char *field_name = lua_tostring(state, -1); const reflection::Field *field = type->fields()->LookupByKey(field_name); if (field == nullptr) { lua_error(state); return 0; } // Provide primitive fields directly. const reflection::BaseType field_type = field->type()->base_type(); switch (field_type) { case reflection::Bool: lua_pushboolean(state, table->GetField( field->offset(), field->default_integer())); break; case reflection::Int: lua_pushinteger(state, table->GetField(field->offset(), field->default_integer())); break; case reflection::Long: lua_pushinteger(state, table->GetField(field->offset(), field->default_integer())); break; case reflection::Float: lua_pushnumber(state, table->GetField(field->offset(), field->default_real())); break; case reflection::Double: lua_pushnumber(state, table->GetField(field->offset(), field->default_real())); break; case reflection::String: { const flatbuffers::String *string_value = table->GetPointer(field->offset()); if (string_value != nullptr) { lua_pushlstring(state, string_value->data(), string_value->Length()); } else { lua_pushlstring(state, "", 0); } break; } case reflection::Obj: { const flatbuffers::Table *field_table = table->GetPointer(field->offset()); if (field_table == nullptr) { TC3_LOG(ERROR) << "Field was not set in entity data."; lua_error(state); return 0; } const reflection::Object *field_type = schema->objects()->Get(field->type()->index()); PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table, state); break; } default: TC3_LOG(ERROR) << "Unsupported type: " << field_type; lua_error(state); 
return 0; } return 1; } int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) { if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected actions table, got: " << lua_type(state_, /*idx=*/-1); lua_error(state_); return LUA_ERRRUN; } lua_pushnil(state_); while (lua_next(state_, /*idx=*/-2)) { const StringPiece key = ReadString(/*index=*/-2); const reflection::Field *field = buffer->GetFieldOrNull(key); if (field == nullptr) { TC3_LOG(ERROR) << "Unknown field: " << key.ToString(); lua_error(state_); return LUA_ERRRUN; } switch (field->type()->base_type()) { case reflection::Obj: return ReadFlatbuffer(buffer->Mutable(field)); case reflection::Bool: buffer->Set(field, static_cast(lua_toboolean(state_, /*idx=*/-1))); break; case reflection::Int: buffer->Set(field, static_cast(lua_tonumber(state_, /*idx=*/-1))); break; case reflection::Long: buffer->Set(field, static_cast(lua_tonumber(state_, /*idx=*/-1))); break; case reflection::Float: buffer->Set(field, static_cast(lua_tonumber(state_, /*idx=*/-1))); break; case reflection::Double: buffer->Set(field, static_cast(lua_tonumber(state_, /*idx=*/-1))); break; case reflection::String: { buffer->Set(field, ReadString(/*index=*/-1)); break; } default: TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type(); lua_error(state_); return LUA_ERRRUN; } lua_pop(state_, 1); } // lua_pop(state_, /*n=*/1); return LUA_OK; } void LuaEnvironment::LoadDefaultLibraries() { for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) { luaL_requiref(state_, lib->name, lib->func, 1); lua_pop(state_, 1); /* remove lib */ } } void LuaEnvironment::PushValue(const Variant &value) { if (value.HasInt()) { lua_pushnumber(state_, value.IntValue()); } else if (value.HasInt64()) { lua_pushnumber(state_, value.Int64Value()); } else if (value.HasBool()) { lua_pushboolean(state_, value.BoolValue()); } else if (value.HasFloat()) { lua_pushnumber(state_, value.FloatValue()); } else if (value.HasDouble()) { 
lua_pushnumber(state_, value.DoubleValue()); } else if (value.HasString()) { lua_pushlstring(state_, value.StringValue().data(), value.StringValue().size()); } else { TC3_LOG(FATAL) << "Unknown value type."; } } StringPiece LuaEnvironment::ReadString(const int index) const { size_t length = 0; const char *data = lua_tolstring(state_, index, &length); return StringPiece(data, length); } void LuaEnvironment::PushString(const StringPiece str) { lua_pushlstring(state_, str.data(), str.size()); } void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema, const flatbuffers::Table *table) { PushFlatbuffer(schema->root_table()->name()->c_str(), schema, schema->root_table(), table, state_); } int LuaEnvironment::RunProtected(const std::function &func, const int num_args, const int num_results) { struct ProtectedCall { std::function func; static int run(lua_State *state) { // Read the pointer to the ProtectedCall struct. ProtectedCall *p = static_cast( lua_touserdata(state, lua_upvalueindex(1))); return p->func(); } }; ProtectedCall protected_call = {func}; lua_pushlightuserdata(state_, &protected_call); lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1); // Put the closure before the arguments on the stack. 
if (num_args > 0) { lua_insert(state_, -(1 + num_args)); } return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0); } bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) { if (luaL_loadbuffer(state_, snippet.data(), snippet.size(), /*name=*/nullptr) != LUA_OK) { TC3_LOG(ERROR) << "Could not compile lua snippet: " << ReadString(/*index=*/-1).ToString(); lua_pop(state_, 1); return false; } if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) { TC3_LOG(ERROR) << "Could not dump compiled lua snippet."; lua_pop(state_, 1); return false; } lua_pop(state_, 1); return true; } bool Compile(StringPiece snippet, std::string *bytecode) { return LuaEnvironment().Compile(snippet, bytecode); } } // namespace libtextclassifier3