• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-actions.h"
18 #include "utils/base/logging.h"
19 #include "utils/lua-utils.h"
20 
21 #ifdef __cplusplus
22 extern "C" {
23 #endif
24 #include "lauxlib.h"
25 #include "lualib.h"
26 #ifdef __cplusplus
27 }
28 #endif
29 
30 namespace libtextclassifier3 {
31 namespace {
GetTensorViewForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)32 TensorView<float> GetTensorViewForOutput(
33     const TfLiteModelExecutor* model_executor,
34     const tflite::Interpreter* interpreter, int output) {
35   if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
36     return TensorView<float>::Invalid();
37   }
38   return model_executor->OutputView<float>(output, interpreter);
39 }
40 }  // namespace
41 
Item(const TensorView<float> * tensor,const int64 index,lua_State * state) const42 int LuaActionsSuggestions::TensorViewIterator::Item(
43     const TensorView<float>* tensor, const int64 index,
44     lua_State* state) const {
45   lua_pushnumber(state, tensor->data()[index]);
46   return 1;
47 }
48 
49 std::unique_ptr<LuaActionsSuggestions>
CreateLuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)50 LuaActionsSuggestions::CreateLuaActionsSuggestions(
51     const std::string& snippet, const Conversation& conversation,
52     const TfLiteModelExecutor* model_executor,
53     const TensorflowLiteModelSpec* model_spec,
54     const tflite::Interpreter* interpreter,
55     const reflection::Schema* actions_entity_data_schema,
56     const reflection::Schema* annotations_entity_data_schema) {
57   auto lua_actions =
58       std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
59           snippet, conversation, model_executor, model_spec, interpreter,
60           actions_entity_data_schema, annotations_entity_data_schema));
61   if (!lua_actions->Initialize()) {
62     TC3_LOG(ERROR)
63         << "Could not initialize lua environment for actions suggestions.";
64     return nullptr;
65   }
66   return lua_actions;
67 }
68 
LuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)69 LuaActionsSuggestions::LuaActionsSuggestions(
70     const std::string& snippet, const Conversation& conversation,
71     const TfLiteModelExecutor* model_executor,
72     const TensorflowLiteModelSpec* model_spec,
73     const tflite::Interpreter* interpreter,
74     const reflection::Schema* actions_entity_data_schema,
75     const reflection::Schema* annotations_entity_data_schema)
76     : snippet_(snippet),
77       conversation_(conversation),
78       conversation_iterator_(annotations_entity_data_schema, this),
79       actions_scores_(
80           model_spec == nullptr
81               ? TensorView<float>::Invalid()
82               : GetTensorViewForOutput(model_executor, interpreter,
83                                        model_spec->output_actions_scores())),
84       smart_reply_scores_(
85           model_spec == nullptr
86               ? TensorView<float>::Invalid()
87               : GetTensorViewForOutput(model_executor, interpreter,
88                                        model_spec->output_replies_scores())),
89       sensitivity_score_(model_spec == nullptr
90                              ? TensorView<float>::Invalid()
91                              : GetTensorViewForOutput(
92                                    model_executor, interpreter,
93                                    model_spec->output_sensitive_topic_score())),
94       triggering_score_(
95           model_spec == nullptr
96               ? TensorView<float>::Invalid()
97               : GetTensorViewForOutput(model_executor, interpreter,
98                                        model_spec->output_triggering_score())),
99       actions_entity_data_schema_(actions_entity_data_schema),
100       annotations_entity_data_schema_(annotations_entity_data_schema) {}
101 
Initialize()102 bool LuaActionsSuggestions::Initialize() {
103   return RunProtected([this] {
104            LoadDefaultLibraries();
105 
106            // Expose conversation message stream.
107            conversation_iterator_.NewIterator("messages",
108                                               &conversation_.messages, state_);
109            lua_setglobal(state_, "messages");
110 
111            // Expose ML model output.
112            lua_newtable(state_);
113            {
114              tensor_iterator_.NewIterator("actions_scores", &actions_scores_,
115                                           state_);
116              lua_setfield(state_, /*idx=*/-2, "actions_scores");
117            }
118            {
119              tensor_iterator_.NewIterator("reply_scores", &smart_reply_scores_,
120                                           state_);
121              lua_setfield(state_, /*idx=*/-2, "reply_scores");
122            }
123            {
124              tensor_iterator_.NewIterator("sensitivity", &sensitivity_score_,
125                                           state_);
126              lua_setfield(state_, /*idx=*/-2, "sensitivity");
127            }
128            {
129              tensor_iterator_.NewIterator("triggering_score",
130                                           &triggering_score_, state_);
131              lua_setfield(state_, /*idx=*/-2, "triggering_score");
132            }
133            lua_setglobal(state_, "model");
134 
135            return LUA_OK;
136          }) == LUA_OK;
137 }
138 
SuggestActions(std::vector<ActionSuggestion> * actions)139 bool LuaActionsSuggestions::SuggestActions(
140     std::vector<ActionSuggestion>* actions) {
141   if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
142                       /*name=*/nullptr) != LUA_OK) {
143     TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
144     return false;
145   }
146 
147   if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
148     TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
149     return false;
150   }
151 
152   if (RunProtected(
153           [this, actions] {
154             return ReadActions(actions_entity_data_schema_,
155                                annotations_entity_data_schema_, this, actions);
156           },
157           /*num_args=*/1) != LUA_OK) {
158     TC3_LOG(ERROR) << "Could not read lua result.";
159     return false;
160   }
161   return true;
162 }
163 
164 }  // namespace libtextclassifier3
165