1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ 18 #define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ 19 20 #include "actions/types.h" 21 #include "annotator/types.h" 22 #include "utils/flatbuffers.h" 23 #include "utils/lua-utils.h" 24 25 #ifdef __cplusplus 26 extern "C" { 27 #endif 28 #include "lauxlib.h" 29 #include "lua.h" 30 #include "lualib.h" 31 #ifdef __cplusplus 32 } 33 #endif 34 35 // Action specific shared lua utilities. 36 namespace libtextclassifier3 { 37 38 // Provides an annotation to lua. 39 void PushAnnotation(const ClassificationResult& classification, 40 const reflection::Schema* entity_data_schema, 41 LuaEnvironment* env); 42 void PushAnnotation(const ClassificationResult& classification, 43 StringPiece text, 44 const reflection::Schema* entity_data_schema, 45 LuaEnvironment* env); 46 void PushAnnotation(const ActionSuggestionAnnotation& annotation, 47 const reflection::Schema* entity_data_schema, 48 LuaEnvironment* env); 49 50 // A lua iterator to enumerate annotation. 51 template <typename Annotation> 52 class AnnotationIterator 53 : public LuaEnvironment::ItemIterator<std::vector<Annotation>> { 54 public: AnnotationIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)55 AnnotationIterator(const reflection::Schema* entity_data_schema, 56 LuaEnvironment* env) 57 : env_(env), entity_data_schema_(entity_data_schema) {} Item(const std::vector<Annotation> * annotations,const int64 pos,lua_State * state)58 int Item(const std::vector<Annotation>* annotations, const int64 pos, 59 lua_State* state) const override { 60 PushAnnotation((*annotations)[pos], entity_data_schema_, env_); 61 return 1; 62 } 63 int Item(const std::vector<Annotation>* annotations, StringPiece key, 64 lua_State* state) const override; 65 66 private: 67 LuaEnvironment* env_; 68 const reflection::Schema* entity_data_schema_; 69 }; 70 71 template <> 72 int AnnotationIterator<ClassificationResult>::Item( 73 const std::vector<ClassificationResult>* annotations, StringPiece key, 74 lua_State* state) const; 75 76 template <> 77 int AnnotationIterator<ActionSuggestionAnnotation>::Item( 78 const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key, 79 lua_State* state) const; 80 81 void PushAnnotatedSpan( 82 const AnnotatedSpan& annotated_span, 83 const AnnotationIterator<ClassificationResult>& annotation_iterator, 84 LuaEnvironment* env); 85 86 MessageTextSpan ReadSpan(LuaEnvironment* env); 87 ActionSuggestionAnnotation ReadAnnotation( 88 const reflection::Schema* entity_data_schema, LuaEnvironment* env); 89 int ReadAnnotations(const reflection::Schema* entity_data_schema, 90 LuaEnvironment* env, 91 std::vector<ActionSuggestionAnnotation>* annotations); 92 ClassificationResult ReadClassificationResult( 93 const reflection::Schema* entity_data_schema, LuaEnvironment* env); 94 95 // A lua iterator to enumerate annotated spans. 96 class AnnotatedSpanIterator 97 : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> { 98 public: AnnotatedSpanIterator(const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)99 AnnotatedSpanIterator( 100 const AnnotationIterator<ClassificationResult>& annotation_iterator, 101 LuaEnvironment* env) 102 : env_(env), annotation_iterator_(annotation_iterator) {} AnnotatedSpanIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)103 AnnotatedSpanIterator(const reflection::Schema* entity_data_schema, 104 LuaEnvironment* env) 105 : env_(env), annotation_iterator_(entity_data_schema, env) {} 106 Item(const std::vector<AnnotatedSpan> * spans,const int64 pos,lua_State * state)107 int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos, 108 lua_State* state) const override { 109 PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_); 110 return /*num results=*/1; 111 } 112 113 private: 114 LuaEnvironment* env_; 115 AnnotationIterator<ClassificationResult> annotation_iterator_; 116 }; 117 118 // Provides an action to lua. 119 void PushAction( 120 const ActionSuggestion& action, 121 const reflection::Schema* entity_data_schema, 122 const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator, 123 LuaEnvironment* env); 124 125 ActionSuggestion ReadAction( 126 const reflection::Schema* actions_entity_data_schema, 127 const reflection::Schema* annotations_entity_data_schema, 128 LuaEnvironment* env); 129 int ReadActions(const reflection::Schema* actions_entity_data_schema, 130 const reflection::Schema* annotations_entity_data_schema, 131 LuaEnvironment* env, std::vector<ActionSuggestion>* actions); 132 133 // A lua iterator to enumerate actions suggestions. 134 class ActionsIterator 135 : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> { 136 public: ActionsIterator(const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env)137 ActionsIterator(const reflection::Schema* entity_data_schema, 138 const reflection::Schema* annotations_entity_data_schema, 139 LuaEnvironment* env) 140 : env_(env), 141 entity_data_schema_(entity_data_schema), 142 annotation_iterator_(annotations_entity_data_schema, env) {} Item(const std::vector<ActionSuggestion> * actions,const int64 pos,lua_State * state)143 int Item(const std::vector<ActionSuggestion>* actions, const int64 pos, 144 lua_State* state) const override { 145 PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_, 146 env_); 147 return /*num results=*/1; 148 } 149 150 private: 151 LuaEnvironment* env_; 152 const reflection::Schema* entity_data_schema_; 153 AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_; 154 }; 155 156 // Conversation message lua iterator. 157 class ConversationIterator 158 : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> { 159 public: ConversationIterator(const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)160 ConversationIterator( 161 const AnnotationIterator<ClassificationResult>& annotation_iterator, 162 LuaEnvironment* env) 163 : env_(env), 164 annotated_span_iterator_( 165 AnnotatedSpanIterator(annotation_iterator, env)) {} ConversationIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)166 ConversationIterator(const reflection::Schema* entity_data_schema, 167 LuaEnvironment* env) 168 : env_(env), 169 annotated_span_iterator_( 170 AnnotatedSpanIterator(entity_data_schema, env)) {} 171 172 int Item(const std::vector<ConversationMessage>* messages, const int64 pos, 173 lua_State* state) const override; 174 175 private: 176 LuaEnvironment* env_; 177 AnnotatedSpanIterator annotated_span_iterator_; 178 }; 179 180 } // namespace libtextclassifier3 181 182 #endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ 183