1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/lua-utils.h"
18
// Lua 5.3's lua_dump has a fourth "strip" parameter that Lua 5.2's does not.
// Outside of AOSP builds, the 4-argument call is rewritten to the 3-argument
// (5.2) form so the call sites below can always pass "strip".
#if !defined(TC3_AOSP)
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
#endif
23
24 namespace libtextclassifier3 {
25 namespace {
26 // Upvalue indices for the flatbuffer callback.
27 static constexpr int kSchemaArgId = 1;
28 static constexpr int kTypeArgId = 2;
29 static constexpr int kTableArgId = 3;
30
31 static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
32 {LUA_TABLIBNAME, luaopen_table},
33 {LUA_STRLIBNAME, luaopen_string},
34 {LUA_BITLIBNAME, luaopen_bit32},
35 {LUA_MATHLIBNAME, luaopen_math},
36 {nullptr, nullptr}};
37
38 // Implementation of a lua_Writer that appends the data to a string.
LuaStringWriter(lua_State * state,const void * data,size_t size,void * result)39 int LuaStringWriter(lua_State *state, const void *data, size_t size,
40 void *result) {
41 std::string *const result_string = static_cast<std::string *>(result);
42 result_string->insert(result_string->size(), static_cast<const char *>(data),
43 size);
44 return LUA_OK;
45 }
46
47 } // namespace
48
LuaEnvironment()49 LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
50
~LuaEnvironment()51 LuaEnvironment::~LuaEnvironment() {
52 if (state_ != nullptr) {
53 lua_close(state_);
54 }
55 }
56
NextCallback(lua_State * state)57 int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
58 return FromUpValue<Iterator *>(kIteratorArgId, state)->Next(state);
59 }
60
LengthCallback(lua_State * state)61 int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
62 return FromUpValue<Iterator *>(kIteratorArgId, state)->Length(state);
63 }
64
ItemCallback(lua_State * state)65 int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
66 return FromUpValue<Iterator *>(kIteratorArgId, state)->Item(state);
67 }
68
IteritemsCallback(lua_State * state)69 int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
70 return FromUpValue<Iterator *>(kIteratorArgId, state)->Iteritems(state);
71 }
72
PushFlatbuffer(const char * name,const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table,lua_State * state)73 void LuaEnvironment::PushFlatbuffer(const char *name,
74 const reflection::Schema *schema,
75 const reflection::Object *type,
76 const flatbuffers::Table *table,
77 lua_State *state) {
78 lua_newtable(state);
79 luaL_newmetatable(state, name);
80 lua_pushlightuserdata(state, AsUserData(schema));
81 lua_pushlightuserdata(state, AsUserData(type));
82 lua_pushlightuserdata(state, AsUserData(table));
83 lua_pushcclosure(state, &GetFieldCallback, 3);
84 lua_setfield(state, -2, kIndexKey);
85 lua_setmetatable(state, -2);
86 }
87
GetFieldCallback(lua_State * state)88 int LuaEnvironment::GetFieldCallback(lua_State *state) {
89 // Fetch the arguments.
90 const reflection::Schema *schema =
91 FromUpValue<reflection::Schema *>(kSchemaArgId, state);
92 const reflection::Object *type =
93 FromUpValue<reflection::Object *>(kTypeArgId, state);
94 const flatbuffers::Table *table =
95 FromUpValue<flatbuffers::Table *>(kTableArgId, state);
96 return GetField(schema, type, table, state);
97 }
98
GetField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table,lua_State * state)99 int LuaEnvironment::GetField(const reflection::Schema *schema,
100 const reflection::Object *type,
101 const flatbuffers::Table *table,
102 lua_State *state) {
103 const char *field_name = lua_tostring(state, -1);
104 const reflection::Field *field = type->fields()->LookupByKey(field_name);
105 if (field == nullptr) {
106 lua_error(state);
107 return 0;
108 }
109 // Provide primitive fields directly.
110 const reflection::BaseType field_type = field->type()->base_type();
111 switch (field_type) {
112 case reflection::Bool:
113 lua_pushboolean(state, table->GetField<uint8_t>(
114 field->offset(), field->default_integer()));
115 break;
116 case reflection::Int:
117 lua_pushinteger(state, table->GetField<int32>(field->offset(),
118 field->default_integer()));
119 break;
120 case reflection::Long:
121 lua_pushinteger(state, table->GetField<int64>(field->offset(),
122 field->default_integer()));
123 break;
124 case reflection::Float:
125 lua_pushnumber(state, table->GetField<float>(field->offset(),
126 field->default_real()));
127 break;
128 case reflection::Double:
129 lua_pushnumber(state, table->GetField<double>(field->offset(),
130 field->default_real()));
131 break;
132 case reflection::String: {
133 const flatbuffers::String *string_value =
134 table->GetPointer<const flatbuffers::String *>(field->offset());
135 if (string_value != nullptr) {
136 lua_pushlstring(state, string_value->data(), string_value->Length());
137 } else {
138 lua_pushlstring(state, "", 0);
139 }
140 break;
141 }
142 case reflection::Obj: {
143 const flatbuffers::Table *field_table =
144 table->GetPointer<const flatbuffers::Table *>(field->offset());
145 if (field_table == nullptr) {
146 TC3_LOG(ERROR) << "Field was not set in entity data.";
147 lua_error(state);
148 return 0;
149 }
150 const reflection::Object *field_type =
151 schema->objects()->Get(field->type()->index());
152 PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table,
153 state);
154 break;
155 }
156 default:
157 TC3_LOG(ERROR) << "Unsupported type: " << field_type;
158 lua_error(state);
159 return 0;
160 }
161 return 1;
162 }
163
ReadFlatbuffer(ReflectiveFlatbuffer * buffer)164 int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
165 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
166 TC3_LOG(ERROR) << "Expected actions table, got: "
167 << lua_type(state_, /*idx=*/-1);
168 lua_error(state_);
169 return LUA_ERRRUN;
170 }
171
172 lua_pushnil(state_);
173 while (lua_next(state_, /*idx=*/-2)) {
174 const StringPiece key = ReadString(/*index=*/-2);
175 const reflection::Field *field = buffer->GetFieldOrNull(key);
176 if (field == nullptr) {
177 TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
178 lua_error(state_);
179 return LUA_ERRRUN;
180 }
181 switch (field->type()->base_type()) {
182 case reflection::Obj:
183 return ReadFlatbuffer(buffer->Mutable(field));
184 case reflection::Bool:
185 buffer->Set(field,
186 static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
187 break;
188 case reflection::Int:
189 buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
190 break;
191 case reflection::Long:
192 buffer->Set(field,
193 static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
194 break;
195 case reflection::Float:
196 buffer->Set(field,
197 static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
198 break;
199 case reflection::Double:
200 buffer->Set(field,
201 static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
202 break;
203 case reflection::String: {
204 buffer->Set(field, ReadString(/*index=*/-1));
205 break;
206 }
207 default:
208 TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
209 lua_error(state_);
210 return LUA_ERRRUN;
211 }
212 lua_pop(state_, 1);
213 }
214 // lua_pop(state_, /*n=*/1);
215 return LUA_OK;
216 }
217
LoadDefaultLibraries()218 void LuaEnvironment::LoadDefaultLibraries() {
219 for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
220 luaL_requiref(state_, lib->name, lib->func, 1);
221 lua_pop(state_, 1); /* remove lib */
222 }
223 }
224
PushValue(const Variant & value)225 void LuaEnvironment::PushValue(const Variant &value) {
226 if (value.HasInt()) {
227 lua_pushnumber(state_, value.IntValue());
228 } else if (value.HasInt64()) {
229 lua_pushnumber(state_, value.Int64Value());
230 } else if (value.HasBool()) {
231 lua_pushboolean(state_, value.BoolValue());
232 } else if (value.HasFloat()) {
233 lua_pushnumber(state_, value.FloatValue());
234 } else if (value.HasDouble()) {
235 lua_pushnumber(state_, value.DoubleValue());
236 } else if (value.HasString()) {
237 lua_pushlstring(state_, value.StringValue().data(),
238 value.StringValue().size());
239 } else {
240 TC3_LOG(FATAL) << "Unknown value type.";
241 }
242 }
243
ReadString(const int index) const244 StringPiece LuaEnvironment::ReadString(const int index) const {
245 size_t length = 0;
246 const char *data = lua_tolstring(state_, index, &length);
247 return StringPiece(data, length);
248 }
249
PushString(const StringPiece str)250 void LuaEnvironment::PushString(const StringPiece str) {
251 lua_pushlstring(state_, str.data(), str.size());
252 }
253
PushFlatbuffer(const reflection::Schema * schema,const flatbuffers::Table * table)254 void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
255 const flatbuffers::Table *table) {
256 PushFlatbuffer(schema->root_table()->name()->c_str(), schema,
257 schema->root_table(), table, state_);
258 }
259
RunProtected(const std::function<int ()> & func,const int num_args,const int num_results)260 int LuaEnvironment::RunProtected(const std::function<int()> &func,
261 const int num_args, const int num_results) {
262 struct ProtectedCall {
263 std::function<int()> func;
264
265 static int run(lua_State *state) {
266 // Read the pointer to the ProtectedCall struct.
267 ProtectedCall *p = static_cast<ProtectedCall *>(
268 lua_touserdata(state, lua_upvalueindex(1)));
269 return p->func();
270 }
271 };
272 ProtectedCall protected_call = {func};
273 lua_pushlightuserdata(state_, &protected_call);
274 lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
275 // Put the closure before the arguments on the stack.
276 if (num_args > 0) {
277 lua_insert(state_, -(1 + num_args));
278 }
279 return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
280 }
281
Compile(StringPiece snippet,std::string * bytecode)282 bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
283 if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
284 /*name=*/nullptr) != LUA_OK) {
285 TC3_LOG(ERROR) << "Could not compile lua snippet: "
286 << ReadString(/*index=*/-1).ToString();
287 lua_pop(state_, 1);
288 return false;
289 }
290 if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
291 TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
292 lua_pop(state_, 1);
293 return false;
294 }
295 lua_pop(state_, 1);
296 return true;
297 }
298
Compile(StringPiece snippet,std::string * bytecode)299 bool Compile(StringPiece snippet, std::string *bytecode) {
300 return LuaEnvironment().Compile(snippet, bytecode);
301 }
302
303 } // namespace libtextclassifier3
304