1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-actions.h"
18
19 #include "utils/base/logging.h"
20 #include "utils/lua-utils.h"
21
22 #ifdef __cplusplus
23 extern "C" {
24 #endif
25 #include "lauxlib.h"
26 #include "lualib.h"
27 #ifdef __cplusplus
28 }
29 #endif
30
31 namespace libtextclassifier3 {
32 namespace {
33
GetTensorViewForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)34 TensorView<float> GetTensorViewForOutput(
35 const TfLiteModelExecutor* model_executor,
36 const tflite::Interpreter* interpreter, int output) {
37 if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
38 return TensorView<float>::Invalid();
39 }
40 return model_executor->OutputView<float>(output, interpreter);
41 }
42
GetStringTensorForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)43 std::vector<std::string> GetStringTensorForOutput(
44 const TfLiteModelExecutor* model_executor,
45 const tflite::Interpreter* interpreter, int output) {
46 if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
47 return {};
48 }
49 return model_executor->Output<std::string>(output, interpreter);
50 }
51
52 } // namespace
53
54 std::unique_ptr<LuaActionsSuggestions>
CreateLuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)55 LuaActionsSuggestions::CreateLuaActionsSuggestions(
56 const std::string& snippet, const Conversation& conversation,
57 const TfLiteModelExecutor* model_executor,
58 const TensorflowLiteModelSpec* model_spec,
59 const tflite::Interpreter* interpreter,
60 const reflection::Schema* actions_entity_data_schema,
61 const reflection::Schema* annotations_entity_data_schema) {
62 auto lua_actions =
63 std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
64 snippet, conversation, model_executor, model_spec, interpreter,
65 actions_entity_data_schema, annotations_entity_data_schema));
66 if (!lua_actions->Initialize()) {
67 TC3_LOG(ERROR)
68 << "Could not initialize lua environment for actions suggestions.";
69 return nullptr;
70 }
71 return lua_actions;
72 }
73
LuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)74 LuaActionsSuggestions::LuaActionsSuggestions(
75 const std::string& snippet, const Conversation& conversation,
76 const TfLiteModelExecutor* model_executor,
77 const TensorflowLiteModelSpec* model_spec,
78 const tflite::Interpreter* interpreter,
79 const reflection::Schema* actions_entity_data_schema,
80 const reflection::Schema* annotations_entity_data_schema)
81 : snippet_(snippet),
82 conversation_(conversation),
83 actions_scores_(
84 model_spec == nullptr
85 ? TensorView<float>::Invalid()
86 : GetTensorViewForOutput(model_executor, interpreter,
87 model_spec->output_actions_scores())),
88 smart_reply_scores_(
89 model_spec == nullptr
90 ? TensorView<float>::Invalid()
91 : GetTensorViewForOutput(model_executor, interpreter,
92 model_spec->output_replies_scores())),
93 sensitivity_score_(model_spec == nullptr
94 ? TensorView<float>::Invalid()
95 : GetTensorViewForOutput(
96 model_executor, interpreter,
97 model_spec->output_sensitive_topic_score())),
98 triggering_score_(
99 model_spec == nullptr
100 ? TensorView<float>::Invalid()
101 : GetTensorViewForOutput(model_executor, interpreter,
102 model_spec->output_triggering_score())),
103 smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
104 : GetStringTensorForOutput(
105 model_executor, interpreter,
106 model_spec->output_replies())),
107 actions_entity_data_schema_(actions_entity_data_schema),
108 annotations_entity_data_schema_(annotations_entity_data_schema) {}
109
Initialize()110 bool LuaActionsSuggestions::Initialize() {
111 return RunProtected([this] {
112 LoadDefaultLibraries();
113
114 // Expose conversation message stream.
115 PushConversation(&conversation_.messages,
116 annotations_entity_data_schema_);
117 lua_setglobal(state_, "messages");
118
119 // Expose ML model output.
120 lua_newtable(state_);
121
122 PushTensor(&actions_scores_);
123 lua_setfield(state_, /*idx=*/-2, "actions_scores");
124
125 PushTensor(&smart_reply_scores_);
126 lua_setfield(state_, /*idx=*/-2, "reply_scores");
127
128 PushTensor(&sensitivity_score_);
129 lua_setfield(state_, /*idx=*/-2, "sensitivity");
130
131 PushTensor(&triggering_score_);
132 lua_setfield(state_, /*idx=*/-2, "triggering_score");
133
134 PushVectorIterator(&smart_replies_);
135 lua_setfield(state_, /*idx=*/-2, "reply");
136
137 lua_setglobal(state_, "model");
138
139 return LUA_OK;
140 }) == LUA_OK;
141 }
142
SuggestActions(std::vector<ActionSuggestion> * actions)143 bool LuaActionsSuggestions::SuggestActions(
144 std::vector<ActionSuggestion>* actions) {
145 if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
146 /*name=*/nullptr) != LUA_OK) {
147 TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
148 return false;
149 }
150
151 if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
152 TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
153 return false;
154 }
155
156 if (RunProtected(
157 [this, actions] {
158 return ReadActions(actions_entity_data_schema_,
159 annotations_entity_data_schema_, actions);
160 },
161 /*num_args=*/1) != LUA_OK) {
162 TC3_LOG(ERROR) << "Could not read lua result.";
163 return false;
164 }
165 return true;
166 }
167
168 } // namespace libtextclassifier3
169