• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "actions/lua-ranker.h"

#include <memory>
#include <string>
#include <vector>

#include "actions/types.h"
#include "utils/flatbuffers/mutable.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
25 
26 namespace libtextclassifier3 {
27 namespace {
28 
29 MATCHER_P2(IsAction, type, response_text, "") {
30   return testing::Value(arg.type, type) &&
31          testing::Value(arg.response_text, response_text);
32 }
33 
34 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
35 
TestEntitySchema()36 std::string TestEntitySchema() {
37   // Create fake entity data schema meta data.
38   // Cannot use object oriented API here as that is not available for the
39   // reflection schema.
40   flatbuffers::FlatBufferBuilder schema_builder;
41   std::vector<flatbuffers::Offset<reflection::Field>> fields = {
42       reflection::CreateField(
43           schema_builder,
44           /*name=*/schema_builder.CreateString("test"),
45           /*type=*/
46           reflection::CreateType(schema_builder,
47                                  /*base_type=*/reflection::String),
48           /*id=*/0,
49           /*offset=*/4)};
50   std::vector<flatbuffers::Offset<reflection::Enum>> enums;
51   std::vector<flatbuffers::Offset<reflection::Object>> objects = {
52       reflection::CreateObject(
53           schema_builder,
54           /*name=*/schema_builder.CreateString("EntityData"),
55           /*fields=*/
56           schema_builder.CreateVectorOfSortedTables(&fields))};
57   schema_builder.Finish(reflection::CreateSchema(
58       schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
59       schema_builder.CreateVectorOfSortedTables(&enums),
60       /*(unused) file_ident=*/0,
61       /*(unused) file_ext=*/0,
62       /*root_table*/ objects[0]));
63   return std::string(
64       reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
65       schema_builder.GetSize());
66 }
67 
// A script that returns the indices 1..#actions in order leaves the actions
// unchanged.
TEST(LuaRankingTest, PassThrough) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, i)
    end
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
93 
// A script that returns an empty index list removes every action.
TEST(LuaRankingTest, Filtering) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    return {}
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
112 
// A script that returns index 1 once per action yields #actions copies of
// the first action.
TEST(LuaRankingTest, Duplication) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, 1)
    end
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("text_reply"),
                                         IsActionType("text_reply")}));
}
138 
// A script that sorts the indices with an ascending-score comparator reorders
// the actions from lowest to highest score.
TEST(LuaRankingTest, SortByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    function testScoreSorter(a, b)
      return actions[a].score < actions[b].score
    end
    local result = {}
    for i=1,#actions do
      result[i] = i
    end
    table.sort(result, testScoreSorter)
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("add_to_collection"),
                                         IsActionType("share_location"),
                                         IsActionType("text_reply")}));
}
168 
// A script that keeps every action except those of type "text_reply".
TEST(LuaRankingTest, SuppressType) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  // NOTE(review): the ordered expectation below assumes pairs(actions)
  // visits actions in index order. Plain Lua tables do not guarantee this,
  // so confirm how the ranker exposes `actions`.
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
195 
// The script can read the conversation through `messages`. It only filters
// out "text_reply" actions when the first message text is "hello hello";
// otherwise it returns an empty list.
TEST(LuaRankingTest, HandlesConversation) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    if messages[1].text ~= "hello hello" then
      return result
    end
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
225 
// The script reads the entity-data field "test" through `action.test`. It
// keeps only the "test"-typed action whose entity data holds "value_a".
TEST(LuaRankingTest, HandlesEntityData) {
  const std::string serialized_schema = TestEntitySchema();
  const reflection::Schema* entity_data_schema =
      flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());

  // Create two entity-data payloads that differ only in the "test" field.
  MutableFlatbufferBuilder builder(entity_data_schema);
  std::unique_ptr<MutableFlatbuffer> buffer = builder.NewRoot();
  ASSERT_NE(buffer, nullptr);
  ASSERT_TRUE(buffer->Set("test", "value_a"));
  const std::string serialized_entity_data_a = buffer->Serialize();
  ASSERT_TRUE(buffer->Set("test", "value_b"));
  const std::string serialized_entity_data_b = buffer->Serialize();

  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_a},
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_b},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type == "test" and action.test == "value_a" then
        table.insert(result, id)
      end
    end
    return result
  )";

  // Check the factory result before dereferencing it. If Create() returns
  // null, this reports a test failure instead of crashing the test binary.
  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, entity_data_schema,
      /*annotations_entity_data_schema=*/nullptr, &response);
  ASSERT_NE(ranker, nullptr);
  EXPECT_TRUE(ranker->RankActions());
  ASSERT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("test")}));
  // Both candidate actions have type "test", so the type check alone cannot
  // tell which one survived. Pin that it is the "value_a" payload.
  EXPECT_EQ(response.actions[0].serialized_entity_data,
            serialized_entity_data_a);
}
267 
268 }  // namespace
269 }  // namespace libtextclassifier3
270