1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/actions-suggestions.h"
18
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22
23 #include "actions/actions_model_generated.h"
24 #include "actions/test_utils.h"
25 #include "actions/zlib-utils.h"
26 #include "annotator/collections.h"
27 #include "annotator/types.h"
28 #include "utils/flatbuffers.h"
29 #include "utils/flatbuffers_generated.h"
30 #include "utils/hash/farmhash.h"
31 #include "gmock/gmock.h"
32 #include "gtest/gtest.h"
33 #include "flatbuffers/flatbuffers.h"
34 #include "flatbuffers/reflection.h"
35
36 namespace libtextclassifier3 {
37 namespace {
38 using testing::_;
39
40 constexpr char kModelFileName[] = "actions_suggestions_test.model";
41 constexpr char kHashGramModelFileName[] =
42 "actions_suggestions_test.hashgram.model";
43
// Reads the entire content of `file_name` and returns it as a string.
//
// The stream is opened in binary mode: the callers in this file feed the
// result to a flatbuffer unpacker, so the bytes must not be altered by
// text-mode translation (e.g. line-ending conversion on some platforms).
// If the file cannot be opened, nothing is read and the result is empty.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
48
// Returns the prefix that gets prepended to the test model file names.
// It is the empty string here, so the models are addressed by bare file name.
std::string GetModelPath() {
  return std::string();
}
52
53 class ActionsSuggestionsTest : public testing::Test {
54 protected:
ActionsSuggestionsTest()55 ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
LoadTestModel()56 std::unique_ptr<ActionsSuggestions> LoadTestModel() {
57 return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
58 &unilib_);
59 }
LoadHashGramTestModel()60 std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
61 return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
62 &unilib_);
63 }
64 UniLib unilib_;
65 };
66
// Loading the default test model must produce a non-null instance.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> model = LoadTestModel();
  EXPECT_THAT(model, testing::NotNull());
}
70
// A single "Where are you?" message is expected to yield three suggestions:
// the share_location action plus two smart replies.
TEST_F(ActionsSuggestionsTest, SuggestActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
}
80
// The same message tagged with the locale "zz" must not produce any action.
TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
90
// An "address" annotation on the message is expected to surface a view_map
// action as the first suggestion, carrying the annotation's score.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // Codepoint span of "home" in the message below.
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  // Fatal assertion: front() must not be called on an empty result.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
107
// Replaces the model's annotation mapping with a single "address" ->
// "save_location" mapping that writes into the `location` entity field, and
// checks that the resulting action carries the annotated text as entity data.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  // Re-serialize the modified model; `builder` must outlive
  // `actions_suggestions` because the buffer is not owned by it.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be instantiated.
  ASSERT_TRUE(actions_suggestions);

  AnnotatedSpan annotation;
  // Codepoint span of "home" in the message below.
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Both the serialized buffer and the field pointer are verified
  // first so that a missing value is reported as a test failure rather than
  // an invalid memory access.
  ASSERT_FALSE(response.actions.front().serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "home");
}
160
// Two "flight" annotations covering the same text ("LX38") plus one "email"
// annotation: the expectations below require a single track_flight action
// with the higher score (3.0), followed by the send_email action.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  // Fatal assertion: the first two entries are indexed below.
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
189
// Same input as the duplicated-annotations test, but with
// `deduplicate_annotations` switched off in the model: both flight
// annotations are expected to produce their own track_flight action.
TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be instantiated.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  // Fatal assertion: the first three entries are indexed below.
  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
232
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)233 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
234 const std::function<void(ActionsModelT*)>& set_config_fn,
235 const UniLib* unilib = nullptr) {
236 const std::string actions_model_string =
237 ReadFile(GetModelPath() + kModelFileName);
238 std::unique_ptr<ActionsModelT> actions_model =
239 UnPackActionsModel(actions_model_string.c_str());
240
241 // Set custom config.
242 set_config_fn(actions_model.get());
243
244 // Disable smart reply for easier testing.
245 actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
246
247 flatbuffers::FlatBufferBuilder builder;
248 FinishActionsModelBuffer(builder,
249 ActionsModel::Pack(builder, actions_model.get()));
250 std::unique_ptr<ActionsSuggestions> actions_suggestions =
251 ActionsSuggestions::FromUnownedBuffer(
252 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
253 builder.GetSize(), unilib);
254
255 AnnotatedSpan flight_annotation;
256 flight_annotation.span = {15, 19};
257 flight_annotation.classification = {ClassificationResult("flight", 2.0)};
258 AnnotatedSpan email_annotation;
259 email_annotation.span = {0, 16};
260 email_annotation.classification = {ClassificationResult("email", 1.0)};
261
262 return actions_suggestions->SuggestActions(
263 {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
264 "hehe@android.com",
265 /*reference_time_ms_utc=*/0,
266 /*reference_timezone=*/"Europe/Zurich",
267 /*annotations=*/
268 {email_annotation},
269 /*locales=*/"en"},
270 {/*user_id=*/2,
271 "yoyo@android.com",
272 /*reference_time_ms_utc=*/0,
273 /*reference_timezone=*/"Europe/Zurich",
274 /*annotations=*/
275 {email_annotation},
276 /*locales=*/"en"},
277 {/*user_id=*/1,
278 "test@android.com",
279 /*reference_time_ms_utc=*/0,
280 /*reference_timezone=*/"Europe/Zurich",
281 /*annotations=*/
282 {email_annotation},
283 /*locales=*/"en"},
284 {/*user_id=*/1,
285 "I am on flight LX38.",
286 /*reference_time_ms_utc=*/0,
287 /*reference_timezone=*/"Europe/Zurich",
288 /*annotations=*/
289 {flight_annotation},
290 /*locales=*/"en"}}});
291 }
292
// With both history limits set to 1, only the flight annotation of the last
// message is expected to produce an action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
307
// With `max_history_from_last_person` raised to 3, the previous message of the
// last sender is expected to contribute a send_email action as well.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
323
// With `max_history_from_any_person` set to 2, two actions are expected: the
// flight action and one send_email action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
339
// With `max_history_from_any_person` set to 3, three actions are expected:
// the flight action followed by two send_email actions.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
357
// Even with `max_history_from_any_person` set to 5, only three actions are
// expected while `include_local_user_messages` is false.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
375
// With `include_local_user_messages` enabled and `only_until_last_sent`
// disabled, all four messages are expected to contribute an action.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal assertion: indexing `actions` below would be out of bounds if the
  // number of suggestions differed.
  ASSERT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
394
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")395 void TestSuggestActionsWithThreshold(
396 const std::function<void(ActionsModelT*)>& set_value_fn,
397 const UniLib* unilib = nullptr, const int expected_size = 0,
398 const std::string& preconditions_overwrite = "") {
399 const std::string actions_model_string =
400 ReadFile(GetModelPath() + kModelFileName);
401 std::unique_ptr<ActionsModelT> actions_model =
402 UnPackActionsModel(actions_model_string.c_str());
403 set_value_fn(actions_model.get());
404 flatbuffers::FlatBufferBuilder builder;
405 FinishActionsModelBuffer(builder,
406 ActionsModel::Pack(builder, actions_model.get()));
407 std::unique_ptr<ActionsSuggestions> actions_suggestions =
408 ActionsSuggestions::FromUnownedBuffer(
409 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
410 builder.GetSize(), unilib, preconditions_overwrite);
411 ASSERT_TRUE(actions_suggestions);
412 const ActionsSuggestionsResponse& response =
413 actions_suggestions->SuggestActions(
414 {{{/*user_id=*/1, "I have the low-ground. Where are you?",
415 /*reference_time_ms_utc=*/0,
416 /*reference_timezone=*/"Europe/Zurich",
417 /*annotations=*/{}, /*locales=*/"en"}}});
418 EXPECT_LE(response.actions.size(), expected_size);
419 }
420
// Raising the smart reply triggering score to 1.0 should leave at most one
// suggestion (no smart reply, only actions).
TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* model) {
    model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, &unilib_,
                                  /*expected_size=*/1);
}
430
// Raising the minimum reply score threshold to 1.0 should leave at most one
// suggestion (no smart reply, only actions).
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
  const auto raise_reply_threshold = [](ActionsModelT* model) {
    model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_reply_threshold, &unilib_,
                                  /*expected_size=*/1);
}
440
// Lowering the maximum sensitive topic score to 0.0 still allows up to four
// suggestions (no sensitive prediction in test model).
TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
  const auto lower_sensitive_topic_score = [](ActionsModelT* model) {
    model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(lower_sensitive_topic_score, &unilib_,
                                  /*expected_size=*/4);
}
449
// A maximum input length of 0 must suppress every suggestion (the helper's
// default upper bound is zero actions).
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
  const auto forbid_any_input = [](ActionsModelT* model) {
    model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(forbid_any_input, &unilib_);
}
457
// A minimum input length of 100 exceeds the test message, so no suggestion is
// expected (the helper's default upper bound is zero actions).
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
  const auto require_long_input = [](ActionsModelT* model) {
    model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(require_long_input, &unilib_);
}
465
// Leaves the model itself untouched and instead passes serialized triggering
// preconditions with `max_input_length = 0` as an overwrite; no suggestion is
// expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;

  // Serialize the overwrite into a standalone flatbuffer.
  flatbuffers::FlatBufferBuilder overwrite_builder;
  overwrite_builder.Finish(
      TriggeringPreconditions::Pack(overwrite_builder, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(overwrite_builder.GetBufferPointer()),
      overwrite_builder.GetSize());

  // Keep model untouched.
  const auto keep_model = [](ActionsModelT*) {};
  TestSuggestActionsWithThreshold(keep_model, &unilib_,
                                  /*expected_size=*/0, serialized_overwrite);
}
479
480 #ifdef TC3_UNILIB_ICU
// Installs a low confidence rule matching "low-ground" and enables
// suppression on low confidence input; the test message contains that phrase,
// so no suggestion is expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
  const auto add_low_confidence_rule = [](ActionsModelT* model) {
    model->preconditions->suppress_on_low_confidence_input = true;
    model->low_confidence_rules.reset(new RulesModelT);
    model->low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
    model->low_confidence_rules->rule.back()->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(add_low_confidence_rule, &unilib_);
}
493
// A rule produces two text replies for "hello there". A low confidence rule
// with an input pattern ("hello") and an output pattern ("(?i:desaster)") is
// added on top; the first remaining suggestion is expected to be the reply
// that does not match the output pattern.
TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Add custom triggering rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  {
    // First reply: matches the low confidence output pattern below.
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Desaster!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }
  {
    // Second reply: does not match the low confidence output pattern.
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Kenobi!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }

  // Add input-output low confidence rule.
  actions_model->preconditions->suppress_on_low_confidence_input = true;
  actions_model->low_confidence_rules.reset(new RulesModelT);
  actions_model->low_confidence_rules->rule.emplace_back(
      new RulesModel_::RuleT);
  actions_model->low_confidence_rules->rule.back()->pattern = "hello";
  actions_model->low_confidence_rules->rule.back()->output_pattern =
      "(?i:desaster)";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal assertion: the first entry is indexed below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
551
// Same scenario as the input/output low confidence test, but the low
// confidence rule is not stored in the model: it is supplied through
// serialized triggering preconditions passed as an overwrite when the
// ActionsSuggestions instance is created.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsLowConfidenceInputOutputOverwrite) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  actions_model->low_confidence_rules.reset();

  // Add custom triggering rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  {
    // First reply: matches the low confidence output pattern below.
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Desaster!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }
  {
    // Second reply: does not match the low confidence output pattern.
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Kenobi!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }

  // Add the input-output low confidence rule via a preconditions overwrite
  // rather than through the model itself.
  actions_model->preconditions->low_confidence_rules.reset();
  TriggeringPreconditionsT preconditions;
  preconditions.suppress_on_low_confidence_input = true;
  preconditions.low_confidence_rules.reset(new RulesModelT);
  preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
  preconditions.low_confidence_rules->rule.back()->pattern = "hello";
  preconditions.low_confidence_rules->rule.back()->output_pattern =
      "(?i:desaster)";
  flatbuffers::FlatBufferBuilder preconditions_builder;
  preconditions_builder.Finish(
      TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
  std::string serialize_preconditions = std::string(
      reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
      preconditions_builder.GetSize());

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_, serialize_preconditions);

  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal assertion: the first entry is indexed below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
620 #endif
621
// With `max_sensitive_topic_score` lowered to 0.0 and
// `suppress_on_sensitive_topic` enabled, no action is expected even though the
// message carries an address annotation. The test returns early (and passes)
// when the model's `output_sensitive_topic_score` is negative.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be instantiated.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  // Codepoint span of "home" in the message below.
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
655
// Raises `max_conversation_history_length` to 10 and feeds a two-message
// conversation; the address annotation on the last message is expected to
// surface a view_map action as the first suggestion.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  // UnPackActionsModel only receives a raw pointer and cannot detect an empty
  // buffer, so stop here if the model file could not be read.
  ASSERT_FALSE(actions_model_string.empty());
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be instantiated.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  // Fatal assertion: the first entry is indexed below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
691
// A flight annotation is expected to produce a track_flight action as the
// first suggestion, and that action must reference the originating annotation
// (message index 0 and the same span).
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // Fail the test cleanly instead of dereferencing a null pointer when the
  // model could not be loaded.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Fatal assertion: annotations[0] is indexed below and would be out of
  // bounds if the action carried no annotation.
  ASSERT_EQ(response.actions[0].annotations.size(), 1);
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
714
715 #ifdef TC3_UNILIB_ICU
// Verifies that a regular-expression rule produces a text reply and that the
// entity data of the resulting action contains both the values of the rule's
// capturing groups and the entity data serialized on the action spec.
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Replace the model's rules with a single rule that has one action.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data: group 0 is written to the
  // "greeting" field and group 1 to the "location" field.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
      rule->actions.back()->capturing_group.back().get();
  location_group->group_id = 1;
  location_group->entity_field.reset(new FlatbufferFieldPathT);
  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  location_group->entity_field->field.back()->field_name = "location";

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  // Use meta data to generate custom serialized entity data.
  ReflectiveFlatbufferBuilder entity_data_builder(
      flatbuffers::GetRoot<reflection::Schema>(
          actions_model->actions_entity_data_schema.data()));
  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder.NewRoot();
  entity_data->Set("person", "Kenobi");
  action->serialized_entity_data = entity_data->Serialize();

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal assertion: the first action is dereferenced below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data. The buffer must be non-empty before it is parsed.
  ASSERT_FALSE(response.actions[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);

  // NOTE(review): offsets 4, 6 and 8 presumably address the greeting,
  // location and person fields of the test schema -- confirm against
  // SetTestEntityDataSchema.
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  // A field that was not set yields a null pointer; fail instead of crashing.
  ASSERT_TRUE(greeting != nullptr);
  ASSERT_TRUE(location != nullptr);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(greeting->str(), "hello there");
  EXPECT_EQ(location->str(), "there");
  EXPECT_EQ(person->str(), "Kenobi");
}
792
// Verifies that a rule whose capturing group carries a `text_reply` spec
// produces a reply suggestion containing the text matched by that group.
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Replace the model's rules with a single rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);

  // Set capturing groups for entity data: group 1 becomes a text reply.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal assertion: the first action is dereferenced below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "STOP");
}
834
// Verifies that adding a rule which suggests an action type the model already
// produces does not increase the number of returned suggestions.
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  // Iterate by const reference to avoid copying every suggestion.
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestions::kShareLocation) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  // Keep the container's own size type so the comparison below is not a
  // signed/unsigned mix.
  const auto num_actions = response.actions.size();

  // Add custom rule for location sharing.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
  actions_model->rules->rule.back()->actions.emplace_back(
      new RulesModel_::Rule_::RuleActionSpecT);
  actions_model->rules->rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestions::kShareLocation;

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  // The same request must yield the same number of actions as before.
  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), num_actions);
}
885
// Verifies that when an annotation-based action and a rule-based action are
// both produced for the same message, the rule-based action (configured here
// with priority score 2.0) ends up first in the response.
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});

  // Check that the flight action is present. The assertion is fatal because
  // the first action is dereferenced right after it.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");

  // Add custom rule.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->score = 1.0f;
  action->priority_score = 2.0f;
  action->type = "test_code";
  // Capturing group 1 is exposed as an annotation named and typed "code".
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->annotation_name = "code";
  code_group->annotation_type = "code";

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});
  // Fatal assertion: the first action is dereferenced below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "test_code");
}
944 #endif
945
// Verifies that actions created from annotations are returned ordered by
// descending score.
TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();

  // Two address annotations; the second one has the higher score.
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});

  // Fatal assertion: the first two actions are dereferenced below.
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
966
// Exercises VisitActionsModel with the test model file and with a path that
// does not name a model file.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Visitor that reports whether it was handed a model at all.
  const auto received_model = [](const ActionsModel* model) {
    return model != nullptr;
  };

  const std::string existing_path = GetModelPath() + kModelFileName;
  EXPECT_TRUE(VisitActionsModel<bool>(existing_path, received_model));

  const std::string missing_path = GetModelPath() + "non_existing_model.fb";
  EXPECT_FALSE(VisitActionsModel<bool>(missing_path, received_model));
}
983
// Runs three single-message conversations through the hash-gram test model
// and checks which action types are suggested for each of them.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
  const std::unique_ptr<ActionsSuggestions> hashgram_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(hashgram_suggestions != nullptr);

  // Requests suggestions for a conversation consisting only of `text`.
  const auto suggest_for = [&hashgram_suggestions](const std::string& text) {
    return hashgram_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  const ActionsSuggestionsResponse greeting_response = suggest_for("hello");
  EXPECT_THAT(greeting_response.actions, testing::IsEmpty());

  const ActionsSuggestionsResponse location_response =
      suggest_for("where are you");
  EXPECT_THAT(location_response.actions,
              testing::ElementsAre(
                  testing::Field(&ActionSuggestion::type, "share_location")));

  const ActionsSuggestionsResponse contact_response =
      suggest_for("do you know johns number");
  EXPECT_THAT(contact_response.actions,
              testing::ElementsAre(
                  testing::Field(&ActionSuggestion::type, "share_contact")));
}
1023
1024 // Test class to expose token embedding methods for testing.
1025 class TestingMessageEmbedder : private ActionsSuggestions {
1026 public:
1027 explicit TestingMessageEmbedder(const ActionsModel* model);
1028
1029 using ActionsSuggestions::EmbedAndFlattenTokens;
1030 using ActionsSuggestions::EmbedTokensPerMessage;
1031
1032 protected:
1033 // EmbeddingExecutor that always returns features based on
1034 // the id of the sparse features.
1035 class FakeEmbeddingExecutor : public EmbeddingExecutor {
1036 public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1037 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1038 const int dest_size) const override {
1039 TC3_CHECK_GE(dest_size, 1);
1040 EXPECT_EQ(sparse_features.size(), 1);
1041 dest[0] = sparse_features.data()[0];
1042 return true;
1043 }
1044 };
1045 };
1046
TestingMessageEmbedder(const ActionsModel * model)1047 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
1048 model_ = model;
1049 const ActionsTokenFeatureProcessorOptions* options =
1050 model->feature_processor_options();
1051 feature_processor_.reset(
1052 new ActionsFeatureProcessor(options, /*unilib=*/nullptr));
1053 embedding_executor_.reset(new FakeEmbeddingExecutor());
1054 EXPECT_TRUE(
1055 EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1056 EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1057 EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1058 token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1059 EXPECT_EQ(token_embedding_size_, 1);
1060 }
1061
1062 class EmbeddingTest : public testing::Test {
1063 protected:
EmbeddingTest()1064 EmbeddingTest() {
1065 model_.feature_processor_options.reset(
1066 new ActionsTokenFeatureProcessorOptionsT);
1067 options_ = model_.feature_processor_options.get();
1068 options_->chargram_orders = {1};
1069 options_->num_buckets = 1000;
1070 options_->embedding_size = 1;
1071 options_->start_token_id = 0;
1072 options_->end_token_id = 1;
1073 options_->padding_token_id = 2;
1074 options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1075 }
1076
CreateTestingMessageEmbedder()1077 TestingMessageEmbedder CreateTestingMessageEmbedder() {
1078 flatbuffers::FlatBufferBuilder builder;
1079 FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1080 buffer_ = builder.ReleaseBufferPointer();
1081 return TestingMessageEmbedder(
1082 flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1083 }
1084
1085 flatbuffers::DetachedBuffer buffer_;
1086 ActionsModelT model_;
1087 ActionsTokenFeatureProcessorOptionsT* options_;
1088 };
1089
// With no per-message token bounds configured, every token of the single
// message is embedded and nothing is padded or dropped.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings, testing::ElementsAre(bucket_of("a"), bucket_of("b"),
                                               bucket_of("c")));
}
1112
// With a minimum of five tokens per message, a three-token message is
// followed by two padding-token embeddings.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto padding = testing::FloatEq(options_->padding_token_id);

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(bucket_of("a"), bucket_of("b"),
                                   bucket_of("c"), padding, padding));
}
1138
// With a maximum of two tokens per message, the leading token of a
// three-token message is dropped and the last two are kept.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(bucket_of("b"), bucket_of("c")));
}
1159
// With two messages of three and two tokens and no bounds configured, the
// reported maximum is three and the shorter message is followed by one
// padding-token embedding.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto padding = testing::FloatEq(options_->padding_token_id);

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Previously the six elements were indexed without any size check;
  // ElementsAre validates the length (two messages of three slots each)
  // together with the contents.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(bucket_of("a"), bucket_of("b"),
                                   bucket_of("c"), bucket_of("d"),
                                   bucket_of("e"), padding));
}
1189
// With no total-token bounds configured, the flattened embedding of a single
// message is its tokens wrapped in a start and an end token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto start = testing::FloatEq(options_->start_token_id);
  const auto end = testing::FloatEq(options_->end_token_id);

  EXPECT_EQ(total_token_count, 5);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(start, bucket_of("a"), bucket_of("b"),
                                   bucket_of("c"), end));
}
1214
// With a minimum of seven total tokens, the five-element flattened sequence
// (start, three tokens, end) is followed by two padding-token embeddings.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto start = testing::FloatEq(options_->start_token_id);
  const auto end = testing::FloatEq(options_->end_token_id);
  const auto padding = testing::FloatEq(options_->padding_token_id);

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(start, bucket_of("a"), bucket_of("b"),
                                   bucket_of("c"), end, padding, padding));
}
1242
// With a maximum of three total tokens, the front of the flattened sequence
// (the start token and the first token) is dropped and the tail is kept.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto end = testing::FloatEq(options_->end_token_id);

  EXPECT_EQ(total_token_count, 3);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(bucket_of("b"), bucket_of("c"), end));
}
1264
// With two messages and no bounds configured, each message is wrapped in its
// own start and end token and the results are concatenated.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto start = testing::FloatEq(options_->start_token_id);
  const auto end = testing::FloatEq(options_->end_token_id);

  EXPECT_EQ(total_token_count, 9);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(start, bucket_of("a"), bucket_of("b"),
                                   bucket_of("c"), end, start, bucket_of("d"),
                                   bucket_of("e"), end));
}
1298
// With two three-token messages and a maximum of seven total tokens, the
// front of the flattened sequence is dropped: only the tail of the first
// message survives, while the second message is kept in full.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Matcher for the value the fake executor emits for a one-character token:
  // the bucket id of its unigram.
  const auto bucket_of = [this](const char* chargram) {
    return testing::FloatEq(tc3farmhash::Fingerprint64(chargram, 1) %
                            options_->num_buckets);
  };
  const auto start = testing::FloatEq(options_->start_token_id);
  const auto end = testing::FloatEq(options_->end_token_id);

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre checks size and contents together, so a wrong size can no
  // longer lead to out-of-bounds indexing.
  EXPECT_THAT(embeddings,
              testing::ElementsAre(bucket_of("c"), end, start, bucket_of("d"),
                                   bucket_of("e"), bucket_of("f"), end));
}
1330
1331 } // namespace
1332 } // namespace libtextclassifier3
1333