• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/lua-utils.h"
18 
// lua_dump takes an extra argument "strip" in Lua 5.3, but not in 5.2. Calls
// in this file use the 5.3 signature; on older Lua versions this shim drops
// the "strip" argument. Keyed on the Lua version itself (LUA_VERSION_NUM,
// from lua.h) rather than on a build flag that only indirectly implies it.
#if LUA_VERSION_NUM < 503
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
#endif
23 
24 namespace libtextclassifier3 {
25 namespace {
26 // Upvalue indices for the flatbuffer callback.
27 static constexpr int kSchemaArgId = 1;
28 static constexpr int kTypeArgId = 2;
29 static constexpr int kTableArgId = 3;
30 
31 static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
32                                            {LUA_TABLIBNAME, luaopen_table},
33                                            {LUA_STRLIBNAME, luaopen_string},
34                                            {LUA_BITLIBNAME, luaopen_bit32},
35                                            {LUA_MATHLIBNAME, luaopen_math},
36                                            {nullptr, nullptr}};
37 
38 // Implementation of a lua_Writer that appends the data to a string.
LuaStringWriter(lua_State * state,const void * data,size_t size,void * result)39 int LuaStringWriter(lua_State *state, const void *data, size_t size,
40                     void *result) {
41   std::string *const result_string = static_cast<std::string *>(result);
42   result_string->insert(result_string->size(), static_cast<const char *>(data),
43                         size);
44   return LUA_OK;
45 }
46 
47 }  // namespace
48 
LuaEnvironment()49 LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
50 
~LuaEnvironment()51 LuaEnvironment::~LuaEnvironment() {
52   if (state_ != nullptr) {
53     lua_close(state_);
54   }
55 }
56 
NextCallback(lua_State * state)57 int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
58   return FromUpValue<Iterator *>(kIteratorArgId, state)->Next(state);
59 }
60 
LengthCallback(lua_State * state)61 int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
62   return FromUpValue<Iterator *>(kIteratorArgId, state)->Length(state);
63 }
64 
ItemCallback(lua_State * state)65 int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
66   return FromUpValue<Iterator *>(kIteratorArgId, state)->Item(state);
67 }
68 
IteritemsCallback(lua_State * state)69 int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
70   return FromUpValue<Iterator *>(kIteratorArgId, state)->Iteritems(state);
71 }
72 
PushFlatbuffer(const char * name,const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table,lua_State * state)73 void LuaEnvironment::PushFlatbuffer(const char *name,
74                                     const reflection::Schema *schema,
75                                     const reflection::Object *type,
76                                     const flatbuffers::Table *table,
77                                     lua_State *state) {
78   lua_newtable(state);
79   luaL_newmetatable(state, name);
80   lua_pushlightuserdata(state, AsUserData(schema));
81   lua_pushlightuserdata(state, AsUserData(type));
82   lua_pushlightuserdata(state, AsUserData(table));
83   lua_pushcclosure(state, &GetFieldCallback, 3);
84   lua_setfield(state, -2, kIndexKey);
85   lua_setmetatable(state, -2);
86 }
87 
GetFieldCallback(lua_State * state)88 int LuaEnvironment::GetFieldCallback(lua_State *state) {
89   // Fetch the arguments.
90   const reflection::Schema *schema =
91       FromUpValue<reflection::Schema *>(kSchemaArgId, state);
92   const reflection::Object *type =
93       FromUpValue<reflection::Object *>(kTypeArgId, state);
94   const flatbuffers::Table *table =
95       FromUpValue<flatbuffers::Table *>(kTableArgId, state);
96   return GetField(schema, type, table, state);
97 }
98 
GetField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table,lua_State * state)99 int LuaEnvironment::GetField(const reflection::Schema *schema,
100                              const reflection::Object *type,
101                              const flatbuffers::Table *table,
102                              lua_State *state) {
103   const char *field_name = lua_tostring(state, -1);
104   const reflection::Field *field = type->fields()->LookupByKey(field_name);
105   if (field == nullptr) {
106     lua_error(state);
107     return 0;
108   }
109   // Provide primitive fields directly.
110   const reflection::BaseType field_type = field->type()->base_type();
111   switch (field_type) {
112     case reflection::Bool:
113       lua_pushboolean(state, table->GetField<uint8_t>(
114                                  field->offset(), field->default_integer()));
115       break;
116     case reflection::Int:
117       lua_pushinteger(state, table->GetField<int32>(field->offset(),
118                                                     field->default_integer()));
119       break;
120     case reflection::Long:
121       lua_pushinteger(state, table->GetField<int64>(field->offset(),
122                                                     field->default_integer()));
123       break;
124     case reflection::Float:
125       lua_pushnumber(state, table->GetField<float>(field->offset(),
126                                                    field->default_real()));
127       break;
128     case reflection::Double:
129       lua_pushnumber(state, table->GetField<double>(field->offset(),
130                                                     field->default_real()));
131       break;
132     case reflection::String: {
133       const flatbuffers::String *string_value =
134           table->GetPointer<const flatbuffers::String *>(field->offset());
135       if (string_value != nullptr) {
136         lua_pushlstring(state, string_value->data(), string_value->Length());
137       } else {
138         lua_pushlstring(state, "", 0);
139       }
140       break;
141     }
142     case reflection::Obj: {
143       const flatbuffers::Table *field_table =
144           table->GetPointer<const flatbuffers::Table *>(field->offset());
145       if (field_table == nullptr) {
146         TC3_LOG(ERROR) << "Field was not set in entity data.";
147         lua_error(state);
148         return 0;
149       }
150       const reflection::Object *field_type =
151           schema->objects()->Get(field->type()->index());
152       PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table,
153                      state);
154       break;
155     }
156     default:
157       TC3_LOG(ERROR) << "Unsupported type: " << field_type;
158       lua_error(state);
159       return 0;
160   }
161   return 1;
162 }
163 
ReadFlatbuffer(ReflectiveFlatbuffer * buffer)164 int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
165   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
166     TC3_LOG(ERROR) << "Expected actions table, got: "
167                    << lua_type(state_, /*idx=*/-1);
168     lua_error(state_);
169     return LUA_ERRRUN;
170   }
171 
172   lua_pushnil(state_);
173   while (lua_next(state_, /*idx=*/-2)) {
174     const StringPiece key = ReadString(/*index=*/-2);
175     const reflection::Field *field = buffer->GetFieldOrNull(key);
176     if (field == nullptr) {
177       TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
178       lua_error(state_);
179       return LUA_ERRRUN;
180     }
181     switch (field->type()->base_type()) {
182       case reflection::Obj:
183         return ReadFlatbuffer(buffer->Mutable(field));
184       case reflection::Bool:
185         buffer->Set(field,
186                     static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
187         break;
188       case reflection::Int:
189         buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
190         break;
191       case reflection::Long:
192         buffer->Set(field,
193                     static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
194         break;
195       case reflection::Float:
196         buffer->Set(field,
197                     static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
198         break;
199       case reflection::Double:
200         buffer->Set(field,
201                     static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
202         break;
203       case reflection::String: {
204         buffer->Set(field, ReadString(/*index=*/-1));
205         break;
206       }
207       default:
208         TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
209         lua_error(state_);
210         return LUA_ERRRUN;
211     }
212     lua_pop(state_, 1);
213   }
214   // lua_pop(state_, /*n=*/1);
215   return LUA_OK;
216 }
217 
LoadDefaultLibraries()218 void LuaEnvironment::LoadDefaultLibraries() {
219   for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
220     luaL_requiref(state_, lib->name, lib->func, 1);
221     lua_pop(state_, 1); /* remove lib */
222   }
223 }
224 
PushValue(const Variant & value)225 void LuaEnvironment::PushValue(const Variant &value) {
226   if (value.HasInt()) {
227     lua_pushnumber(state_, value.IntValue());
228   } else if (value.HasInt64()) {
229     lua_pushnumber(state_, value.Int64Value());
230   } else if (value.HasBool()) {
231     lua_pushboolean(state_, value.BoolValue());
232   } else if (value.HasFloat()) {
233     lua_pushnumber(state_, value.FloatValue());
234   } else if (value.HasDouble()) {
235     lua_pushnumber(state_, value.DoubleValue());
236   } else if (value.HasString()) {
237     lua_pushlstring(state_, value.StringValue().data(),
238                     value.StringValue().size());
239   } else {
240     TC3_LOG(FATAL) << "Unknown value type.";
241   }
242 }
243 
ReadString(const int index) const244 StringPiece LuaEnvironment::ReadString(const int index) const {
245   size_t length = 0;
246   const char *data = lua_tolstring(state_, index, &length);
247   return StringPiece(data, length);
248 }
249 
PushString(const StringPiece str)250 void LuaEnvironment::PushString(const StringPiece str) {
251   lua_pushlstring(state_, str.data(), str.size());
252 }
253 
PushFlatbuffer(const reflection::Schema * schema,const flatbuffers::Table * table)254 void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
255                                     const flatbuffers::Table *table) {
256   PushFlatbuffer(schema->root_table()->name()->c_str(), schema,
257                  schema->root_table(), table, state_);
258 }
259 
RunProtected(const std::function<int ()> & func,const int num_args,const int num_results)260 int LuaEnvironment::RunProtected(const std::function<int()> &func,
261                                  const int num_args, const int num_results) {
262   struct ProtectedCall {
263     std::function<int()> func;
264 
265     static int run(lua_State *state) {
266       // Read the pointer to the ProtectedCall struct.
267       ProtectedCall *p = static_cast<ProtectedCall *>(
268           lua_touserdata(state, lua_upvalueindex(1)));
269       return p->func();
270     }
271   };
272   ProtectedCall protected_call = {func};
273   lua_pushlightuserdata(state_, &protected_call);
274   lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
275   // Put the closure before the arguments on the stack.
276   if (num_args > 0) {
277     lua_insert(state_, -(1 + num_args));
278   }
279   return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
280 }
281 
Compile(StringPiece snippet,std::string * bytecode)282 bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
283   if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
284                       /*name=*/nullptr) != LUA_OK) {
285     TC3_LOG(ERROR) << "Could not compile lua snippet: "
286                    << ReadString(/*index=*/-1).ToString();
287     lua_pop(state_, 1);
288     return false;
289   }
290   if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
291     TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
292     lua_pop(state_, 1);
293     return false;
294   }
295   lua_pop(state_, 1);
296   return true;
297 }
298 
Compile(StringPiece snippet,std::string * bytecode)299 bool Compile(StringPiece snippet, std::string *bytecode) {
300   return LuaEnvironment().Compile(snippet, bytecode);
301 }
302 
303 }  // namespace libtextclassifier3
304