1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "actions/lua-ranker.h"

#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
29
30 namespace libtextclassifier3 {
31
32 std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)33 ActionsSuggestionsLuaRanker::Create(
34 const Conversation& conversation, const std::string& ranker_code,
35 const reflection::Schema* entity_data_schema,
36 const reflection::Schema* annotations_entity_data_schema,
37 ActionsSuggestionsResponse* response) {
38 auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
39 new ActionsSuggestionsLuaRanker(
40 conversation, ranker_code, entity_data_schema,
41 annotations_entity_data_schema, response));
42 if (!ranker->Initialize()) {
43 TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
44 return nullptr;
45 }
46 return ranker;
47 }
48
Initialize()49 bool ActionsSuggestionsLuaRanker::Initialize() {
50 return RunProtected([this] {
51 LoadDefaultLibraries();
52
53 // Expose generated actions.
54 actions_iterator_.NewIterator("actions", &response_->actions,
55 state_);
56 lua_setglobal(state_, "actions");
57
58 // Expose conversation message stream.
59 conversation_iterator_.NewIterator("messages",
60 &conversation_.messages, state_);
61 lua_setglobal(state_, "messages");
62 return LUA_OK;
63 }) == LUA_OK;
64 }
65
ReadActionsRanking()66 int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
67 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
68 TC3_LOG(ERROR) << "Expected actions table, got: "
69 << lua_type(state_, /*idx=*/-1);
70 lua_pop(state_, 1);
71 lua_error(state_);
72 return LUA_ERRRUN;
73 }
74 std::vector<ActionSuggestion> ranked_actions;
75 lua_pushnil(state_);
76 while (lua_next(state_, /*idx=*/-2)) {
77 const int action_id =
78 static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1;
79 lua_pop(state_, 1);
80 if (action_id < 0 || action_id >= response_->actions.size()) {
81 TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82 lua_error(state_);
83 return LUA_ERRRUN;
84 }
85 ranked_actions.push_back(response_->actions[action_id]);
86 }
87 lua_pop(state_, 1);
88 response_->actions = ranked_actions;
89 return LUA_OK;
90 }
91
RankActions()92 bool ActionsSuggestionsLuaRanker::RankActions() {
93 if (response_->actions.empty()) {
94 // Nothing to do.
95 return true;
96 }
97
98 if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99 /*name=*/nullptr) != LUA_OK) {
100 TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101 return false;
102 }
103
104 if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105 TC3_LOG(ERROR) << "Could not run ranking snippet.";
106 return false;
107 }
108
109 if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110 LUA_OK) {
111 TC3_LOG(ERROR) << "Could not read lua result.";
112 return false;
113 }
114 return true;
115 }
116
117 } // namespace libtextclassifier3
118