1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
18 #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
19
20 #include <functional>
21 #include <vector>
22
23 #include "utils/flatbuffers.h"
24 #include "utils/strings/stringpiece.h"
25 #include "utils/variant.h"
26 #include "flatbuffers/reflection_generated.h"
27
28 #ifdef __cplusplus
29 extern "C" {
30 #endif
31 #include "lauxlib.h"
32 #include "lua.h"
33 #include "lualib.h"
34 #ifdef __cplusplus
35 }
36 #endif
37
38 namespace libtextclassifier3 {
39
// Metatable field names under which the Lua C callbacks in this header are
// installed (standard Lua metamethod names).
static constexpr char const *kLengthKey{"__len"};
static constexpr char const *kPairsKey{"__pairs"};
static constexpr char const *kIndexKey{"__index"};
43
// Converts a pointer into the untyped `void *` form Lua uses for light
// userdata. Constness is dropped because Lua only stores plain `void *`;
// callers must not write through the result when the input was const.
template <typename T>
void *AsUserData(const T *value) {
  T *const writable = const_cast<T *>(value);
  return static_cast<void *>(writable);
}

// Generic overload taking the value itself (e.g. a pointer-sized integer or
// a pointer) and reinterpreting it as a `void *`.
template <typename T>
void *AsUserData(const T value) {
  void *const as_pointer = reinterpret_cast<void *>(value);
  return as_pointer;
}
53
54 // Retrieves up-values.
55 template <typename T>
FromUpValue(const int index,lua_State * state)56 T FromUpValue(const int index, lua_State *state) {
57 return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
58 }
59
60 class LuaEnvironment {
61 public:
62 // Wrapper for handling an iterator.
63 class Iterator {
64 public:
~Iterator()65 virtual ~Iterator() {}
66 static int NextCallback(lua_State *state);
67 static int LengthCallback(lua_State *state);
68 static int ItemCallback(lua_State *state);
69 static int IteritemsCallback(lua_State *state);
70
71 // Called when the next element of an iterator is fetched.
72 virtual int Next(lua_State *state) const = 0;
73
74 // Called when the length of the iterator is queried.
75 virtual int Length(lua_State *state) const = 0;
76
77 // Called when an item is queried.
78 virtual int Item(lua_State *state) const = 0;
79
80 // Called when a new iterator is started.
81 virtual int Iteritems(lua_State *state) const = 0;
82
83 protected:
84 static constexpr int kIteratorArgId = 1;
85 };
86
87 template <typename T>
88 class ItemIterator : public Iterator {
89 public:
NewIterator(StringPiece name,const T * items,lua_State * state)90 void NewIterator(StringPiece name, const T *items, lua_State *state) const {
91 lua_newtable(state);
92 luaL_newmetatable(state, name.data());
93 lua_pushlightuserdata(state, AsUserData(this));
94 lua_pushlightuserdata(state, AsUserData(items));
95 lua_pushcclosure(state, &Iterator::ItemCallback, 2);
96 lua_setfield(state, -2, kIndexKey);
97 lua_pushlightuserdata(state, AsUserData(this));
98 lua_pushlightuserdata(state, AsUserData(items));
99 lua_pushcclosure(state, &Iterator::LengthCallback, 2);
100 lua_setfield(state, -2, kLengthKey);
101 lua_pushlightuserdata(state, AsUserData(this));
102 lua_pushlightuserdata(state, AsUserData(items));
103 lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
104 lua_setfield(state, -2, kPairsKey);
105 lua_setmetatable(state, -2);
106 }
107
Iteritems(lua_State * state)108 int Iteritems(lua_State *state) const override {
109 lua_pushlightuserdata(state, AsUserData(this));
110 lua_pushlightuserdata(
111 state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
112 lua_pushnumber(state, 0);
113 lua_pushcclosure(state, &Iterator::NextCallback, 3);
114 return /*num results=*/1;
115 }
116
Length(lua_State * state)117 int Length(lua_State *state) const override {
118 lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
119 return /*num results=*/1;
120 }
121
Next(lua_State * state)122 int Next(lua_State *state) const override {
123 return Next(FromUpValue<T *>(kItemsArgId, state),
124 lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
125 state);
126 }
127
Next(const T * items,const int64 pos,lua_State * state)128 int Next(const T *items, const int64 pos, lua_State *state) const {
129 if (pos >= items->size()) {
130 return 0;
131 }
132
133 // Update iterator value.
134 lua_pushnumber(state, pos + 1);
135 lua_replace(state, lua_upvalueindex(3));
136
137 // Push key.
138 lua_pushinteger(state, pos + 1);
139
140 // Push item.
141 return 1 + Item(items, pos, state);
142 }
143
Item(lua_State * state)144 int Item(lua_State *state) const override {
145 const T *items = FromUpValue<T *>(kItemsArgId, state);
146 switch (lua_type(state, -1)) {
147 case LUA_TNUMBER: {
148 // Lua is one based, so adjust the index here.
149 const int64 index =
150 static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
151 if (index < 0 || index >= items->size()) {
152 TC3_LOG(ERROR) << "Invalid index: " << index;
153 lua_error(state);
154 return 0;
155 }
156 return Item(items, index, state);
157 }
158 case LUA_TSTRING: {
159 size_t key_length = 0;
160 const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
161 return Item(items, StringPiece(key, key_length), state);
162 }
163 default:
164 TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
165 lua_error(state);
166 return 0;
167 }
168 }
169
170 virtual int Item(const T *items, const int64 pos,
171 lua_State *state) const = 0;
172
Item(const T * items,StringPiece key,lua_State * state)173 virtual int Item(const T *items, StringPiece key, lua_State *state) const {
174 TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
175 lua_error(state);
176 return 0;
177 }
178
179 protected:
180 static constexpr int kItemsArgId = 2;
181 static constexpr int kIterValueArgId = 3;
182 };
183
184 virtual ~LuaEnvironment();
185 LuaEnvironment();
186
187 // Compile a lua snippet into binary bytecode.
188 // NOTE: The compiled bytecode might not be compatible across Lua versions
189 // and platforms.
190 bool Compile(StringPiece snippet, std::string *bytecode);
191
192 typedef int (*CallbackHandler)(lua_State *);
193
194 // Loads default libraries.
195 void LoadDefaultLibraries();
196
197 // Provides a callback to Lua.
198 template <typename T, int (T::*handler)()>
Bind()199 void Bind() {
200 lua_pushlightuserdata(state_, static_cast<void *>(this));
201 lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
202 }
203
204 // Setup a named table that callsback whenever a member is accessed.
205 // This allows to lazily provide required information to the script.
206 template <typename T, int (T::*handler)()>
BindTable(const char * name)207 void BindTable(const char *name) {
208 lua_newtable(state_);
209 luaL_newmetatable(state_, name);
210 lua_pushlightuserdata(state_, static_cast<void *>(this));
211 lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
212 lua_setfield(state_, -2, kIndexKey);
213 lua_setmetatable(state_, -2);
214 }
215
216 void PushValue(const Variant &value);
217
218 // Reads a string from the stack.
219 StringPiece ReadString(const int index) const;
220
221 // Pushes a string to the stack.
222 void PushString(const StringPiece str);
223
224 // Pushes a flatbuffer to the stack.
225 void PushFlatbuffer(const reflection::Schema *schema,
226 const flatbuffers::Table *table);
227
228 // Reads a flatbuffer from the stack.
229 int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
230
231 // Runs a closure in protected mode.
232 // `func`: closure to run in protected mode.
233 // `num_lua_args`: number of arguments from the lua stack to process.
234 // `num_results`: number of result values pushed on the stack.
235 int RunProtected(const std::function<int()> &func, const int num_args = 0,
236 const int num_results = 0);
237
state()238 lua_State *state() const { return state_; }
239
240 protected:
241 lua_State *state_;
242
243 private:
244 // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
245 static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
246 const reflection::Object *type,
247 const flatbuffers::Table *table, lua_State *state);
248 static int GetFieldCallback(lua_State *state);
249 static int GetField(const reflection::Schema *schema,
250 const reflection::Object *type,
251 const flatbuffers::Table *table, lua_State *state);
252
253 template <typename T, int (T::*handler)()>
Dispatch(lua_State * state)254 static int Dispatch(lua_State *state) {
255 T *env = FromUpValue<T *>(1, state);
256 return ((*env).*handler)();
257 }
258 };
259
260 bool Compile(StringPiece snippet, std::string *bytecode);
261
262 } // namespace libtextclassifier3
263
264 #endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
265