• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-actions.h"
18 
19 #include "utils/base/logging.h"
20 #include "utils/lua-utils.h"
21 
22 #ifdef __cplusplus
23 extern "C" {
24 #endif
25 #include "lauxlib.h"
26 #include "lualib.h"
27 #ifdef __cplusplus
28 }
29 #endif
30 
31 namespace libtextclassifier3 {
32 namespace {
33 
GetTensorViewForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)34 TensorView<float> GetTensorViewForOutput(
35     const TfLiteModelExecutor* model_executor,
36     const tflite::Interpreter* interpreter, int output) {
37   if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
38     return TensorView<float>::Invalid();
39   }
40   return model_executor->OutputView<float>(output, interpreter);
41 }
42 
GetStringTensorForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)43 std::vector<std::string> GetStringTensorForOutput(
44     const TfLiteModelExecutor* model_executor,
45     const tflite::Interpreter* interpreter, int output) {
46   if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
47     return {};
48   }
49   return model_executor->Output<std::string>(output, interpreter);
50 }
51 
52 }  // namespace
53 
54 std::unique_ptr<LuaActionsSuggestions>
CreateLuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)55 LuaActionsSuggestions::CreateLuaActionsSuggestions(
56     const std::string& snippet, const Conversation& conversation,
57     const TfLiteModelExecutor* model_executor,
58     const TensorflowLiteModelSpec* model_spec,
59     const tflite::Interpreter* interpreter,
60     const reflection::Schema* actions_entity_data_schema,
61     const reflection::Schema* annotations_entity_data_schema) {
62   auto lua_actions =
63       std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
64           snippet, conversation, model_executor, model_spec, interpreter,
65           actions_entity_data_schema, annotations_entity_data_schema));
66   if (!lua_actions->Initialize()) {
67     TC3_LOG(ERROR)
68         << "Could not initialize lua environment for actions suggestions.";
69     return nullptr;
70   }
71   return lua_actions;
72 }
73 
LuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)74 LuaActionsSuggestions::LuaActionsSuggestions(
75     const std::string& snippet, const Conversation& conversation,
76     const TfLiteModelExecutor* model_executor,
77     const TensorflowLiteModelSpec* model_spec,
78     const tflite::Interpreter* interpreter,
79     const reflection::Schema* actions_entity_data_schema,
80     const reflection::Schema* annotations_entity_data_schema)
81     : snippet_(snippet),
82       conversation_(conversation),
83       actions_scores_(
84           model_spec == nullptr
85               ? TensorView<float>::Invalid()
86               : GetTensorViewForOutput(model_executor, interpreter,
87                                        model_spec->output_actions_scores())),
88       smart_reply_scores_(
89           model_spec == nullptr
90               ? TensorView<float>::Invalid()
91               : GetTensorViewForOutput(model_executor, interpreter,
92                                        model_spec->output_replies_scores())),
93       sensitivity_score_(model_spec == nullptr
94                              ? TensorView<float>::Invalid()
95                              : GetTensorViewForOutput(
96                                    model_executor, interpreter,
97                                    model_spec->output_sensitive_topic_score())),
98       triggering_score_(
99           model_spec == nullptr
100               ? TensorView<float>::Invalid()
101               : GetTensorViewForOutput(model_executor, interpreter,
102                                        model_spec->output_triggering_score())),
103       smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
104                                            : GetStringTensorForOutput(
105                                                  model_executor, interpreter,
106                                                  model_spec->output_replies())),
107       actions_entity_data_schema_(actions_entity_data_schema),
108       annotations_entity_data_schema_(annotations_entity_data_schema) {}
109 
Initialize()110 bool LuaActionsSuggestions::Initialize() {
111   return RunProtected([this] {
112            LoadDefaultLibraries();
113 
114            // Expose conversation message stream.
115            PushConversation(&conversation_.messages,
116                             annotations_entity_data_schema_);
117            lua_setglobal(state_, "messages");
118 
119            // Expose ML model output.
120            lua_newtable(state_);
121 
122            PushTensor(&actions_scores_);
123            lua_setfield(state_, /*idx=*/-2, "actions_scores");
124 
125            PushTensor(&smart_reply_scores_);
126            lua_setfield(state_, /*idx=*/-2, "reply_scores");
127 
128            PushTensor(&sensitivity_score_);
129            lua_setfield(state_, /*idx=*/-2, "sensitivity");
130 
131            PushTensor(&triggering_score_);
132            lua_setfield(state_, /*idx=*/-2, "triggering_score");
133 
134            PushVectorIterator(&smart_replies_);
135            lua_setfield(state_, /*idx=*/-2, "reply");
136 
137            lua_setglobal(state_, "model");
138 
139            return LUA_OK;
140          }) == LUA_OK;
141 }
142 
SuggestActions(std::vector<ActionSuggestion> * actions)143 bool LuaActionsSuggestions::SuggestActions(
144     std::vector<ActionSuggestion>* actions) {
145   if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
146                       /*name=*/nullptr) != LUA_OK) {
147     TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
148     return false;
149   }
150 
151   if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
152     TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
153     return false;
154   }
155 
156   if (RunProtected(
157           [this, actions] {
158             return ReadActions(actions_entity_data_schema_,
159                                annotations_entity_data_schema_, actions);
160           },
161           /*num_args=*/1) != LUA_OK) {
162     TC3_LOG(ERROR) << "Could not read lua result.";
163     return false;
164   }
165   return true;
166 }
167 
168 }  // namespace libtextclassifier3
169