1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-ranker.h"
18
19 #include <string>
20
21 #include "actions/types.h"
22 #include "utils/flatbuffers/mutable.h"
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25
26 namespace libtextclassifier3 {
27 namespace {
28
29 MATCHER_P2(IsAction, type, response_text, "") {
30 return testing::Value(arg.type, type) &&
31 testing::Value(arg.response_text, response_text);
32 }
33
34 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
35
TestEntitySchema()36 std::string TestEntitySchema() {
37 // Create fake entity data schema meta data.
38 // Cannot use object oriented API here as that is not available for the
39 // reflection schema.
40 flatbuffers::FlatBufferBuilder schema_builder;
41 std::vector<flatbuffers::Offset<reflection::Field>> fields = {
42 reflection::CreateField(
43 schema_builder,
44 /*name=*/schema_builder.CreateString("test"),
45 /*type=*/
46 reflection::CreateType(schema_builder,
47 /*base_type=*/reflection::String),
48 /*id=*/0,
49 /*offset=*/4)};
50 std::vector<flatbuffers::Offset<reflection::Enum>> enums;
51 std::vector<flatbuffers::Offset<reflection::Object>> objects = {
52 reflection::CreateObject(
53 schema_builder,
54 /*name=*/schema_builder.CreateString("EntityData"),
55 /*fields=*/
56 schema_builder.CreateVectorOfSortedTables(&fields))};
57 schema_builder.Finish(reflection::CreateSchema(
58 schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
59 schema_builder.CreateVectorOfSortedTables(&enums),
60 /*(unused) file_ident=*/0,
61 /*(unused) file_ext=*/0,
62 /*root_table*/ objects[0]));
63 return std::string(
64 reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
65 schema_builder.GetSize());
66 }
67
// A snippet that returns every action index in its original order must leave
// the actions unchanged.
TEST(LuaRankingTest, PassThrough) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, i)
    end
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
93
// A snippet that returns an empty index list must drop every action.
TEST(LuaRankingTest, Filtering) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    return {}
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
112
// Returned indices may repeat: emitting index 1 (Lua is 1-based) once per
// action yields that many copies of the first action.
TEST(LuaRankingTest, Duplication) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, 1)
    end
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("text_reply"),
                                         IsActionType("text_reply")}));
}
138
// The snippet can reorder actions: here the indices are sorted by ascending
// action score.
TEST(LuaRankingTest, SortByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    function testScoreSorter(a, b)
      return actions[a].score < actions[b].score
    end
    local result = {}
    for i=1,#actions do
      result[i] = i
    end
    table.sort(result, testScoreSorter)
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("add_to_collection"),
                                         IsActionType("share_location"),
                                         IsActionType("text_reply")}));
}
168
// The snippet can filter by action fields: every "text_reply" action is
// dropped and the rest keep their relative order.
TEST(LuaRankingTest, SuppressType) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
195
// The snippet can read the conversation through `messages`. It only filters
// out "text_reply" actions when the first message text is "hello hello";
// otherwise it returns an empty list.
TEST(LuaRankingTest, HandlesConversation) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    if messages[1].text ~= "hello hello" then
      return result
    end
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
225
// Entity data fields are visible to the snippet on each action
// (`action.test`). Two "test" actions differ only in their entity data, and
// only the one whose "test" field is "value_a" may survive ranking.
TEST(LuaRankingTest, HandlesEntityData) {
  const std::string serialized_schema = TestEntitySchema();
  const reflection::Schema* entity_data_schema =
      flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());

  // Create two entity data payloads that differ only in the "test" field.
  // Setup failures abort the test here rather than surfacing later as a
  // confusing ranking mismatch.
  MutableFlatbufferBuilder builder(entity_data_schema);
  std::unique_ptr<MutableFlatbuffer> buffer = builder.NewRoot();
  ASSERT_NE(buffer, nullptr);
  ASSERT_TRUE(buffer->Set("test", "value_a"));
  const std::string serialized_entity_data_a = buffer->Serialize();
  ASSERT_TRUE(buffer->Set("test", "value_b"));
  const std::string serialized_entity_data_b = buffer->Serialize();

  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_a},
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_b},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type == "test" and action.test == "value_a" then
        table.insert(result, id)
      end
    end
    return result
  )";

  auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, entity_data_schema,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Report a failed Create() as a test failure instead of dereferencing null.
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  ASSERT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("test")}));
  // Both candidates have type "test", so the type alone does not show that the
  // entity data was used. Check that the survivor is the "value_a" action.
  EXPECT_EQ(response.actions[0].serialized_entity_data,
            serialized_entity_data_a);
}
267
268 } // namespace
269 } // namespace libtextclassifier3
270