• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-ranker.h"
18 #include "utils/base/logging.h"
19 #include "utils/lua-utils.h"
20 
21 #ifdef __cplusplus
22 extern "C" {
23 #endif
24 #include "lauxlib.h"
25 #include "lualib.h"
26 #ifdef __cplusplus
27 }
28 #endif
29 
30 namespace libtextclassifier3 {
31 
32 std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)33 ActionsSuggestionsLuaRanker::Create(
34     const Conversation& conversation, const std::string& ranker_code,
35     const reflection::Schema* entity_data_schema,
36     const reflection::Schema* annotations_entity_data_schema,
37     ActionsSuggestionsResponse* response) {
38   auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
39       new ActionsSuggestionsLuaRanker(
40           conversation, ranker_code, entity_data_schema,
41           annotations_entity_data_schema, response));
42   if (!ranker->Initialize()) {
43     TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
44     return nullptr;
45   }
46   return ranker;
47 }
48 
Initialize()49 bool ActionsSuggestionsLuaRanker::Initialize() {
50   return RunProtected([this] {
51            LoadDefaultLibraries();
52 
53            // Expose generated actions.
54            actions_iterator_.NewIterator("actions", &response_->actions,
55                                          state_);
56            lua_setglobal(state_, "actions");
57 
58            // Expose conversation message stream.
59            conversation_iterator_.NewIterator("messages",
60                                               &conversation_.messages, state_);
61            lua_setglobal(state_, "messages");
62            return LUA_OK;
63          }) == LUA_OK;
64 }
65 
ReadActionsRanking()66 int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
67   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
68     TC3_LOG(ERROR) << "Expected actions table, got: "
69                    << lua_type(state_, /*idx=*/-1);
70     lua_pop(state_, 1);
71     lua_error(state_);
72     return LUA_ERRRUN;
73   }
74   std::vector<ActionSuggestion> ranked_actions;
75   lua_pushnil(state_);
76   while (lua_next(state_, /*idx=*/-2)) {
77     const int action_id =
78         static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1;
79     lua_pop(state_, 1);
80     if (action_id < 0 || action_id >= response_->actions.size()) {
81       TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82       lua_error(state_);
83       return LUA_ERRRUN;
84     }
85     ranked_actions.push_back(response_->actions[action_id]);
86   }
87   lua_pop(state_, 1);
88   response_->actions = ranked_actions;
89   return LUA_OK;
90 }
91 
RankActions()92 bool ActionsSuggestionsLuaRanker::RankActions() {
93   if (response_->actions.empty()) {
94     // Nothing to do.
95     return true;
96   }
97 
98   if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99                       /*name=*/nullptr) != LUA_OK) {
100     TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101     return false;
102   }
103 
104   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105     TC3_LOG(ERROR) << "Could not run ranking snippet.";
106     return false;
107   }
108 
109   if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110       LUA_OK) {
111     TC3_LOG(ERROR) << "Could not read lua result.";
112     return false;
113   }
114   return true;
115 }
116 
117 }  // namespace libtextclassifier3
118