• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
18 #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
19 
20 #include <vector>
21 
22 #include "actions/types.h"
23 #include "annotator/types.h"
24 #include "utils/flatbuffers/mutable.h"
25 #include "utils/strings/stringpiece.h"
26 #include "utils/variant.h"
27 #include "flatbuffers/reflection_generated.h"
28 
29 #ifdef __cplusplus
30 extern "C" {
31 #endif
32 #include "lauxlib.h"
33 #include "lua.h"
34 #include "lualib.h"
35 #ifdef __cplusplus
36 }
37 #endif
38 
39 namespace libtextclassifier3 {
40 
// Names of the metatable entries installed on tables and closures handed to
// Lua. `__len`, `__pairs`, `__index` and `__gc` are Lua metamethods;
// `__next` is not a standard Lua metamethod but a custom entry that
// LuaEnvironment::Next looks up explicitly.
static constexpr char kLengthKey[] = "__len";
static constexpr char kPairsKey[] = "__pairs";
static constexpr char kIndexKey[] = "__index";
static constexpr char kGcKey[] = "__gc";
static constexpr char kNextKey[] = "__next";

// Relative Lua stack index referring to the topmost stack element.
static constexpr int kIndexStackTop = -1;
48 
49 // Casts to the lua user data type.
50 template <typename T>
AsUserData(const T * value)51 void* AsUserData(const T* value) {
52   return static_cast<void*>(const_cast<T*>(value));
53 }
54 template <typename T>
AsUserData(const T value)55 void* AsUserData(const T value) {
56   return reinterpret_cast<void*>(value);
57 }
58 
59 // Retrieves up-values.
60 template <typename T>
FromUpValue(const int index,lua_State * state)61 T FromUpValue(const int index, lua_State* state) {
62   return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
63 }
64 
65 class LuaEnvironment {
66  public:
67   virtual ~LuaEnvironment();
68   explicit LuaEnvironment();
69 
70   // Compile a lua snippet into binary bytecode.
71   // NOTE: The compiled bytecode might not be compatible across Lua versions
72   // and platforms.
73   bool Compile(StringPiece snippet, std::string* bytecode) const;
74 
75   // Loads default libraries.
76   void LoadDefaultLibraries();
77 
78   // Provides a callback to Lua.
79   template <typename T>
PushFunction(int (T::* handler)())80   void PushFunction(int (T::*handler)()) {
81     PushFunction(std::bind(handler, static_cast<T*>(this)));
82   }
83 
84   template <typename F>
PushFunction(const F & func)85   void PushFunction(const F& func) const {
86     // Copy closure to the lua stack.
87     new (lua_newuserdata(state_, sizeof(func))) F(func);
88 
89     // Register garbage collection callback.
90     lua_newtable(state_);
91     lua_pushcfunction(state_, &ReleaseFunction<F>);
92     lua_setfield(state_, -2, kGcKey);
93     lua_setmetatable(state_, -2);
94 
95     // Push dispatch.
96     lua_pushcclosure(state_, &CallFunction<F>, 1);
97   }
98 
99   // Sets up a named table that calls back whenever a member is accessed.
100   // This allows to lazily provide required information to the script.
101   template <typename T>
PushLazyObject(int (T::* handler)())102   void PushLazyObject(int (T::*handler)()) {
103     PushLazyObject(std::bind(handler, static_cast<T*>(this)));
104   }
105 
106   template <typename F>
PushLazyObject(const F & func)107   void PushLazyObject(const F& func) const {
108     lua_newtable(state_);
109     lua_newtable(state_);
110     PushFunction(func);
111     lua_setfield(state_, -2, kIndexKey);
112     lua_setmetatable(state_, -2);
113   }
114 
Push(const int64 value)115   void Push(const int64 value) const { lua_pushinteger(state_, value); }
Push(const uint64 value)116   void Push(const uint64 value) const { lua_pushinteger(state_, value); }
Push(const int32 value)117   void Push(const int32 value) const { lua_pushinteger(state_, value); }
Push(const uint32 value)118   void Push(const uint32 value) const { lua_pushinteger(state_, value); }
Push(const int16 value)119   void Push(const int16 value) const { lua_pushinteger(state_, value); }
Push(const uint16 value)120   void Push(const uint16 value) const { lua_pushinteger(state_, value); }
Push(const int8 value)121   void Push(const int8 value) const { lua_pushinteger(state_, value); }
Push(const uint8 value)122   void Push(const uint8 value) const { lua_pushinteger(state_, value); }
Push(const float value)123   void Push(const float value) const { lua_pushnumber(state_, value); }
Push(const double value)124   void Push(const double value) const { lua_pushnumber(state_, value); }
Push(const bool value)125   void Push(const bool value) const { lua_pushboolean(state_, value); }
Push(const StringPiece value)126   void Push(const StringPiece value) const { PushString(value); }
Push(const flatbuffers::String * value)127   void Push(const flatbuffers::String* value) const {
128     if (value == nullptr) {
129       PushString("");
130     } else {
131       PushString(StringPiece(value->c_str(), value->size()));
132     }
133   }
134 
135   template <typename T>
136   T Read(const int index = -1) const;
137 
138   template <>
139   int64 Read<int64>(const int index) const {
140     return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
141   }
142 
143   template <>
144   uint64 Read<uint64>(const int index) const {
145     return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
146   }
147 
148   template <>
149   int32 Read<int32>(const int index) const {
150     return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
151   }
152 
153   template <>
154   uint32 Read<uint32>(const int index) const {
155     return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
156   }
157 
158   template <>
159   int16 Read<int16>(const int index) const {
160     return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
161   }
162 
163   template <>
164   uint16 Read<uint16>(const int index) const {
165     return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
166   }
167 
168   template <>
169   int8 Read<int8>(const int index) const {
170     return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
171   }
172 
173   template <>
174   uint8 Read<uint8>(const int index) const {
175     return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
176   }
177 
178   template <>
179   float Read<float>(const int index) const {
180     return static_cast<float>(lua_tonumber(state_, /*idx=*/index));
181   }
182 
183   template <>
184   double Read<double>(const int index) const {
185     return static_cast<double>(lua_tonumber(state_, /*idx=*/index));
186   }
187 
188   template <>
189   bool Read<bool>(const int index) const {
190     return lua_toboolean(state_, /*idx=*/index);
191   }
192 
193   template <>
194   StringPiece Read<StringPiece>(const int index) const {
195     return ReadString(index);
196   }
197 
198   template <>
199   std::string Read<std::string>(const int index) const {
200     return ReadString(index).ToString();
201   }
202 
203   // Reads a string from the stack.
204   StringPiece ReadString(int index) const;
205 
206   // Pushes a string to the stack.
207   void PushString(const StringPiece str) const;
208 
209   // Pushes a flatbuffer to the stack.
PushFlatbuffer(const reflection::Schema * schema,const flatbuffers::Table * table)210   void PushFlatbuffer(const reflection::Schema* schema,
211                       const flatbuffers::Table* table) const {
212     PushFlatbuffer(schema, schema->root_table(), table);
213   }
214 
215   // Reads a flatbuffer from the stack.
216   int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;
217 
218   // Pushes an iterator.
219   template <typename ItemCallback, typename KeyCallback>
PushIterator(const int length,const ItemCallback & item_callback,const KeyCallback & key_callback)220   void PushIterator(const int length, const ItemCallback& item_callback,
221                     const KeyCallback& key_callback) const {
222     lua_newtable(state_);
223     CreateIteratorMetatable(length, item_callback);
224     PushFunction([this, length, item_callback, key_callback]() {
225       return Iterator::Dispatch(this, length, item_callback, key_callback);
226     });
227     lua_setfield(state_, -2, kIndexKey);
228     lua_setmetatable(state_, -2);
229   }
230 
231   template <typename ItemCallback>
PushIterator(const int length,const ItemCallback & item_callback)232   void PushIterator(const int length, const ItemCallback& item_callback) const {
233     lua_newtable(state_);
234     CreateIteratorMetatable(length, item_callback);
235     PushFunction([this, length, item_callback]() {
236       return Iterator::Dispatch(this, length, item_callback);
237     });
238     lua_setfield(state_, -2, kIndexKey);
239     lua_setmetatable(state_, -2);
240   }
241 
242   template <typename ItemCallback>
CreateIteratorMetatable(const int length,const ItemCallback & item_callback)243   void CreateIteratorMetatable(const int length,
244                                const ItemCallback& item_callback) const {
245     lua_newtable(state_);
246     PushFunction([this, length]() { return Iterator::Length(this, length); });
247     lua_setfield(state_, -2, kLengthKey);
248     PushFunction([this, length, item_callback]() {
249       return Iterator::IterItems(this, length, item_callback);
250     });
251     lua_setfield(state_, -2, kPairsKey);
252     PushFunction([this, length, item_callback]() {
253       return Iterator::Next(this, length, item_callback);
254     });
255     lua_setfield(state_, -2, kNextKey);
256   }
257 
258   template <typename T>
PushVectorIterator(const std::vector<T> * items)259   void PushVectorIterator(const std::vector<T>* items) const {
260     PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
261       this->Push(items->at(pos));
262       return 1;
263     });
264   }
265 
266   template <typename T>
PushVector(const std::vector<T> & items)267   void PushVector(const std::vector<T>& items) const {
268     lua_newtable(state_);
269     for (int i = 0; i < items.size(); i++) {
270       // Key: index, 1-based.
271       Push(i + 1);
272 
273       // Value.
274       Push(items[i]);
275       lua_settable(state_, /*idx=*/-3);
276     }
277   }
278 
PushEmptyVector()279   void PushEmptyVector() const { lua_newtable(state_); }
280 
281   template <typename T>
282   std::vector<T> ReadVector(const int index = -1) const {
283     std::vector<T> result;
284     if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
285       TC3_LOG(ERROR) << "Expected a table, got: "
286                      << lua_type(state_, /*idx=*/kIndexStackTop);
287       lua_pop(state_, 1);
288       return {};
289     }
290     lua_pushnil(state_);
291     while (Next(index - 1)) {
292       result.push_back(Read<T>(/*index=*/kIndexStackTop));
293       lua_pop(state_, 1);
294     }
295     return result;
296   }
297 
298   // Runs a closure in protected mode.
299   // `func`: closure to run in protected mode.
300   // `num_lua_args`: number of arguments from the lua stack to process.
301   // `num_results`: number of result values pushed on the stack.
302   template <typename F>
303   int RunProtected(const F& func, const int num_args = 0,
304                    const int num_results = 0) const {
305     PushFunction(func);
306     // Put the closure before the arguments on the stack.
307     if (num_args > 0) {
308       lua_insert(state_, -(1 + num_args));
309     }
310     return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
311   }
312 
313   // Auxiliary methods to handle model results.
314   // Provides an annotation to lua.
315   void PushAnnotation(const ClassificationResult& classification,
316                       const reflection::Schema* entity_data_schema) const;
317   void PushAnnotation(const ClassificationResult& classification,
318                       StringPiece text,
319                       const reflection::Schema* entity_data_schema) const;
320   void PushAnnotation(const ActionSuggestionAnnotation& annotation,
321                       const reflection::Schema* entity_data_schema) const;
322 
323   template <typename Annotation>
PushAnnotations(const std::vector<Annotation> * annotations,const reflection::Schema * entity_data_schema)324   void PushAnnotations(const std::vector<Annotation>* annotations,
325                        const reflection::Schema* entity_data_schema) const {
326     PushIterator(
327         annotations ? annotations->size() : 0,
328         [this, annotations, entity_data_schema](const int64 index) {
329           PushAnnotation(annotations->at(index), entity_data_schema);
330           return 1;
331         },
332         [this, annotations, entity_data_schema](StringPiece name) {
333           if (const Annotation* annotation =
334                   GetAnnotationByName(*annotations, name)) {
335             PushAnnotation(*annotation, entity_data_schema);
336             return 1;
337           } else {
338             return 0;
339           }
340         });
341   }
342 
343   // Pushes a span to the lua stack.
344   void PushAnnotatedSpan(const AnnotatedSpan& annotated_span,
345                          const reflection::Schema* entity_data_schema) const;
346   void PushAnnotatedSpans(const std::vector<AnnotatedSpan>* annotated_spans,
347                           const reflection::Schema* entity_data_schema) const;
348 
349   // Reads a message text span from lua.
350   MessageTextSpan ReadSpan() const;
351 
352   ActionSuggestionAnnotation ReadAnnotation(
353       const reflection::Schema* entity_data_schema) const;
354   int ReadAnnotations(
355       const reflection::Schema* entity_data_schema,
356       std::vector<ActionSuggestionAnnotation>* annotations) const;
357   ClassificationResult ReadClassificationResult(
358       const reflection::Schema* entity_data_schema) const;
359 
360   // Provides an action to lua.
361   void PushAction(
362       const ActionSuggestion& action,
363       const reflection::Schema* actions_entity_data_schema,
364       const reflection::Schema* annotations_entity_data_schema) const;
365 
366   void PushActions(
367       const std::vector<ActionSuggestion>* actions,
368       const reflection::Schema* actions_entity_data_schema,
369       const reflection::Schema* annotations_entity_data_schema) const;
370 
371   ActionSuggestion ReadAction(
372       const reflection::Schema* actions_entity_data_schema,
373       const reflection::Schema* annotations_entity_data_schema) const;
374 
375   int ReadActions(const reflection::Schema* actions_entity_data_schema,
376                   const reflection::Schema* annotations_entity_data_schema,
377                   std::vector<ActionSuggestion>* actions) const;
378 
379   // Conversation message iterator.
380   void PushConversation(
381       const std::vector<ConversationMessage>* conversation,
382       const reflection::Schema* annotations_entity_data_schema) const;
383 
state()384   lua_State* state() const { return state_; }
385 
386  protected:
387   // Wrapper for handling iteration over containers.
388   class Iterator {
389    public:
390     // Starts a new key-value pair iterator.
391     template <typename ItemCallback>
IterItems(const LuaEnvironment * env,const int length,const ItemCallback & callback)392     static int IterItems(const LuaEnvironment* env, const int length,
393                          const ItemCallback& callback) {
394       env->PushFunction([env, callback, length, pos = 0]() mutable {
395         if (pos >= length) {
396           lua_pushnil(env->state());
397           return 1;
398         }
399 
400         // Push key.
401         lua_pushinteger(env->state(), pos + 1);
402 
403         // Push item.
404         return 1 + callback(pos++);
405       });
406       return 1;  // Num. results.
407     }
408 
409     // Gets the next element.
410     template <typename ItemCallback>
Next(const LuaEnvironment * env,const int length,const ItemCallback & item_callback)411     static int Next(const LuaEnvironment* env, const int length,
412                     const ItemCallback& item_callback) {
413       int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop)
414                       ? 0
415                       : env->Read<int64>(/*index=*/kIndexStackTop);
416       if (pos < length) {
417         // Push next key.
418         lua_pushinteger(env->state(), pos + 1);
419 
420         // Push item.
421         return 1 + item_callback(pos);
422       } else {
423         lua_pushnil(env->state());
424         return 1;
425       }
426     }
427 
428     // Returns the length of the container the iterator processes.
Length(const LuaEnvironment * env,const int length)429     static int Length(const LuaEnvironment* env, const int length) {
430       lua_pushinteger(env->state(), length);
431       return 1;  // Num. results.
432     }
433 
434     // Handles item queries to the iterator.
435     // Elements of the container can either be queried by name or index.
436     // Dispatch will check how an element is accessed and
437     // calls `key_callback` for access by name and `item_callback` for access by
438     // index.
439     template <typename ItemCallback, typename KeyCallback>
Dispatch(const LuaEnvironment * env,const int length,const ItemCallback & item_callback,const KeyCallback & key_callback)440     static int Dispatch(const LuaEnvironment* env, const int length,
441                         const ItemCallback& item_callback,
442                         const KeyCallback& key_callback) {
443       switch (lua_type(env->state(), kIndexStackTop)) {
444         case LUA_TNUMBER: {
445           // Lua is one based, so adjust the index here.
446           const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
447           if (index < 0 || index >= length) {
448             TC3_LOG(ERROR) << "Invalid index: " << index;
449             lua_error(env->state());
450             return 0;
451           }
452           return item_callback(index);
453         }
454         case LUA_TSTRING: {
455           return key_callback(env->ReadString(kIndexStackTop));
456         }
457         default:
458           TC3_LOG(ERROR) << "Unexpected access type: "
459                          << lua_type(env->state(), kIndexStackTop);
460           lua_error(env->state());
461           return 0;
462       }
463     }
464 
465     template <typename ItemCallback>
Dispatch(const LuaEnvironment * env,const int length,const ItemCallback & item_callback)466     static int Dispatch(const LuaEnvironment* env, const int length,
467                         const ItemCallback& item_callback) {
468       switch (lua_type(env->state(), kIndexStackTop)) {
469         case LUA_TNUMBER: {
470           // Lua is one based, so adjust the index here.
471           const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
472           if (index < 0 || index >= length) {
473             TC3_LOG(ERROR) << "Invalid index: " << index;
474             lua_error(env->state());
475             return 0;
476           }
477           return item_callback(index);
478         }
479         default:
480           TC3_LOG(ERROR) << "Unexpected access type: "
481                          << lua_type(env->state(), kIndexStackTop);
482           lua_error(env->state());
483           return 0;
484       }
485     }
486   };
487 
488   // Calls the deconstructor from a previously pushed function.
489   template <typename T>
ReleaseFunction(lua_State * state)490   static int ReleaseFunction(lua_State* state) {
491     static_cast<T*>(lua_touserdata(state, 1))->~T();
492     return 0;
493   }
494 
495   template <typename T>
CallFunction(lua_State * state)496   static int CallFunction(lua_State* state) {
497     return (*static_cast<T*>(lua_touserdata(state, lua_upvalueindex(1))))();
498   }
499 
500   // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
501   void PushFlatbuffer(const reflection::Schema* schema,
502                       const reflection::Object* type,
503                       const flatbuffers::Table* table) const;
504   int GetField(const reflection::Schema* schema, const reflection::Object* type,
505                const flatbuffers::Table* table) const;
506 
507   // Reads a repeated field from lua.
508   template <typename T>
ReadRepeatedField(const int index,RepeatedField * result)509   void ReadRepeatedField(const int index, RepeatedField* result) const {
510     for (const T& element : ReadVector<T>(index)) {
511       result->Add(element);
512     }
513   }
514 
515   template <>
516   void ReadRepeatedField<MutableFlatbuffer>(const int index,
517                                             RepeatedField* result) const {
518     lua_pushnil(state_);
519     while (Next(index - 1)) {
520       ReadFlatbuffer(index, result->Add());
521       lua_pop(state_, 1);
522     }
523   }
524 
525   // Pushes a repeated field to the lua stack.
526   template <typename T>
PushRepeatedField(const flatbuffers::Vector<T> * items)527   void PushRepeatedField(const flatbuffers::Vector<T>* items) const {
528     PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
529       Push(items->Get(pos));
530       return 1;  // Num. results.
531     });
532   }
533 
PushRepeatedFlatbufferField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>> * items)534   void PushRepeatedFlatbufferField(
535       const reflection::Schema* schema, const reflection::Object* type,
536       const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* items)
537       const {
538     PushIterator(items ? items->size() : 0,
539                  [this, schema, type, items](const int64 pos) {
540                    PushFlatbuffer(schema, type, items->Get(pos));
541                    return 1;  // Num. results.
542                  });
543   }
544 
545   // Overloads Lua next function to use __next key on the metatable.
546   // This allows us to treat lua objects and lazy objects provided by our
547   // callbacks uniformly.
Next(int index)548   int Next(int index) const {
549     // Check whether the (meta)table of this object has an associated "__next"
550     // entry. This means, we registered our own callback. So we explicitly call
551     // that.
552     if (luaL_getmetafield(state_, index, kNextKey)) {
553       // Callback is now on top of the stack, so adjust relative indices by 1.
554       if (index < 0) {
555         index--;
556       }
557 
558       // Copy the reference to the table.
559       lua_pushvalue(state_, index);
560 
561       // Move the key to top to have it as second argument for the callback.
562       // Copy the key to the top.
563       lua_pushvalue(state_, -3);
564 
565       // Remove the copy of the key.
566       lua_remove(state_, -4);
567 
568       // Call the callback with (key and table as arguments).
569       lua_pcall(state_, /*nargs=*/2 /* table, key */,
570                 /*nresults=*/2 /* key, item */, 0);
571 
572       // Next returned nil, it's the end.
573       if (lua_isnil(state_, kIndexStackTop)) {
574         // Remove nil value.
575         // Results will be padded to `nresults` specified above, so we need
576         // to remove two elements here.
577         lua_pop(state_, 2);
578         return 0;
579       }
580 
581       return 2;  // Num. results.
582     } else if (lua_istable(state_, index)) {
583       return lua_next(state_, index);
584     }
585 
586     // Remove the key.
587     lua_pop(state_, 1);
588     return 0;
589   }
590 
GetAnnotationByName(const std::vector<ClassificationResult> & annotations,StringPiece name)591   static const ClassificationResult* GetAnnotationByName(
592       const std::vector<ClassificationResult>& annotations, StringPiece name) {
593     // Lookup annotation by collection.
594     for (const ClassificationResult& annotation : annotations) {
595       if (name.Equals(annotation.collection)) {
596         return &annotation;
597       }
598     }
599     TC3_LOG(ERROR) << "No annotation with collection: " << name << " found.";
600     return nullptr;
601   }
602 
GetAnnotationByName(const std::vector<ActionSuggestionAnnotation> & annotations,StringPiece name)603   static const ActionSuggestionAnnotation* GetAnnotationByName(
604       const std::vector<ActionSuggestionAnnotation>& annotations,
605       StringPiece name) {
606     // Lookup annotation by name.
607     for (const ActionSuggestionAnnotation& annotation : annotations) {
608       if (name.Equals(annotation.name)) {
609         return &annotation;
610       }
611     }
612     TC3_LOG(ERROR) << "No annotation with name: " << name << " found.";
613     return nullptr;
614   }
615 
616   lua_State* state_;
617 };  // namespace libtextclassifier3
618 
619 bool Compile(StringPiece snippet, std::string* bytecode);
620 
621 }  // namespace libtextclassifier3
622 
623 #endif  // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
624