• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/actions-suggestions.h"
18 
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22 
23 #include "actions/actions_model_generated.h"
24 #include "actions/test_utils.h"
25 #include "actions/zlib-utils.h"
26 #include "annotator/collections.h"
27 #include "annotator/types.h"
28 #include "utils/flatbuffers.h"
29 #include "utils/flatbuffers_generated.h"
30 #include "utils/hash/farmhash.h"
31 #include "gmock/gmock.h"
32 #include "gtest/gtest.h"
33 #include "flatbuffers/flatbuffers.h"
34 #include "flatbuffers/reflection.h"
35 
36 namespace libtextclassifier3 {
37 namespace {
38 using testing::_;
39 
// Serialized test model used by most tests in this file.
constexpr char kModelFileName[]{"actions_suggestions_test.model"};
// Serialized hash-gram variant of the test model.
constexpr char kHashGramModelFileName[]{
    "actions_suggestions_test.hashgram.model"};
43 
// Reads the entire contents of `file_name` into a string.
//
// The file is opened in binary mode because the test models are serialized
// flatbuffers. In text mode the platform is allowed to translate line endings,
// which would corrupt the binary buffer. An empty string is returned if the
// file cannot be opened or is empty.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
48 
// Directory prefix prepended to the test model file names. It is empty, so
// model files are resolved relative to the current working directory.
std::string GetModelPath() {
  return std::string();
}
52 
53 class ActionsSuggestionsTest : public testing::Test {
54  protected:
ActionsSuggestionsTest()55   ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
LoadTestModel()56   std::unique_ptr<ActionsSuggestions> LoadTestModel() {
57     return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
58                                         &unilib_);
59   }
LoadHashGramTestModel()60   std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
61     return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
62                                         &unilib_);
63   }
64   UniLib unilib_;
65 };
66 
// The default test model can be loaded from disk.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> model = LoadTestModel();
  EXPECT_THAT(model, testing::NotNull());
}
70 
// A simple question should yield share_location plus two smart replies.
TEST_F(ActionsSuggestionsTest, SuggestActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // If the model failed to load, fail here instead of crashing on a null
  // dereference below.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
}
80 
// A message in an unknown locale ("zz") must not trigger any actions.
TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // If the model failed to load, fail here instead of crashing on a null
  // dereference below.
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
90 
// An "address" annotation on the message should make view_map the top action,
// with score 1.0.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // If the model failed to load, fail here instead of crashing on a null
  // dereference below.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
107 
// A custom annotation mapping turns an "address" annotation into a
// save_location action. Its `location` entity field (set through
// `entity_field`) must hold the annotated text.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // If the model could not be built, fail instead of dereferencing nullptr.
  ASSERT_TRUE(actions_suggestions);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Each step is guarded so a missing payload fails the test
  // instead of reading an empty buffer or dereferencing a null field pointer
  // (GetPointer returns nullptr for an absent field).
  const auto& serialized_entity_data =
      response.actions.front().serialized_entity_data;
  ASSERT_FALSE(serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          serialized_entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "home");
}
160 
// Two "flight" annotations on the same text ("LX38") plus an "email"
// annotation. Only one track_flight action is expected first, with score 3.0,
// followed by send_email.
TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  // If the model failed to load, fail here instead of crashing on a null
  // dereference below.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
189 
// With deduplicate_annotations disabled, both "flight" annotations produce
// their own track_flight action, ahead of send_email.
TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // If the model could not be built, fail instead of dereferencing nullptr.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
232 
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)233 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
234     const std::function<void(ActionsModelT*)>& set_config_fn,
235     const UniLib* unilib = nullptr) {
236   const std::string actions_model_string =
237       ReadFile(GetModelPath() + kModelFileName);
238   std::unique_ptr<ActionsModelT> actions_model =
239       UnPackActionsModel(actions_model_string.c_str());
240 
241   // Set custom config.
242   set_config_fn(actions_model.get());
243 
244   // Disable smart reply for easier testing.
245   actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
246 
247   flatbuffers::FlatBufferBuilder builder;
248   FinishActionsModelBuffer(builder,
249                            ActionsModel::Pack(builder, actions_model.get()));
250   std::unique_ptr<ActionsSuggestions> actions_suggestions =
251       ActionsSuggestions::FromUnownedBuffer(
252           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
253           builder.GetSize(), unilib);
254 
255   AnnotatedSpan flight_annotation;
256   flight_annotation.span = {15, 19};
257   flight_annotation.classification = {ClassificationResult("flight", 2.0)};
258   AnnotatedSpan email_annotation;
259   email_annotation.span = {0, 16};
260   email_annotation.classification = {ClassificationResult("email", 1.0)};
261 
262   return actions_suggestions->SuggestActions(
263       {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
264          "hehe@android.com",
265          /*reference_time_ms_utc=*/0,
266          /*reference_timezone=*/"Europe/Zurich",
267          /*annotations=*/
268          {email_annotation},
269          /*locales=*/"en"},
270         {/*user_id=*/2,
271          "yoyo@android.com",
272          /*reference_time_ms_utc=*/0,
273          /*reference_timezone=*/"Europe/Zurich",
274          /*annotations=*/
275          {email_annotation},
276          /*locales=*/"en"},
277         {/*user_id=*/1,
278          "test@android.com",
279          /*reference_time_ms_utc=*/0,
280          /*reference_timezone=*/"Europe/Zurich",
281          /*annotations=*/
282          {email_annotation},
283          /*locales=*/"en"},
284         {/*user_id=*/1,
285          "I am on flight LX38.",
286          /*reference_time_ms_utc=*/0,
287          /*reference_timezone=*/"Europe/Zurich",
288          /*annotations=*/
289          {flight_annotation},
290          /*locales=*/"en"}}});
291 }
292 
// With a one-message history window, only the last message's annotation
// (the flight number) should produce an action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal size check: indexing below is only valid if the action exists.
  ASSERT_EQ(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
307 
// A three-message window for the last person, and one message from anyone
// else. Expect the flight action followed by one e-mail action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      &unilib_);
  // Fatal size check: indexing below requires both actions to exist.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
323 
// A two-message window from any person. Expect the flight action followed by
// one e-mail action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal size check: indexing below requires both actions to exist.
  ASSERT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
339 
// A three-message window from any person. Expect the flight action followed
// by two e-mail actions.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal size check: indexing below requires all three actions to exist.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
357 
// A five-message window, with local-user messages excluded. The local user's
// e-mail must not add a fourth action.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal size check: indexing below requires all three actions to exist.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
375 
// A five-message window, with local-user messages included and no
// only_until_last_sent cut-off. All four annotated messages produce actions.
TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      &unilib_);
  // Fatal size check: indexing below requires all four actions to exist.
  ASSERT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
394 
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")395 void TestSuggestActionsWithThreshold(
396     const std::function<void(ActionsModelT*)>& set_value_fn,
397     const UniLib* unilib = nullptr, const int expected_size = 0,
398     const std::string& preconditions_overwrite = "") {
399   const std::string actions_model_string =
400       ReadFile(GetModelPath() + kModelFileName);
401   std::unique_ptr<ActionsModelT> actions_model =
402       UnPackActionsModel(actions_model_string.c_str());
403   set_value_fn(actions_model.get());
404   flatbuffers::FlatBufferBuilder builder;
405   FinishActionsModelBuffer(builder,
406                            ActionsModel::Pack(builder, actions_model.get()));
407   std::unique_ptr<ActionsSuggestions> actions_suggestions =
408       ActionsSuggestions::FromUnownedBuffer(
409           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
410           builder.GetSize(), unilib, preconditions_overwrite);
411   ASSERT_TRUE(actions_suggestions);
412   const ActionsSuggestionsResponse& response =
413       actions_suggestions->SuggestActions(
414           {{{/*user_id=*/1, "I have the low-ground. Where are you?",
415              /*reference_time_ms_utc=*/0,
416              /*reference_timezone=*/"Europe/Zurich",
417              /*annotations=*/{}, /*locales=*/"en"}}});
418   EXPECT_LE(response.actions.size(), expected_size);
419 }
420 
// A smart-reply triggering threshold of 1.0 suppresses smart replies, leaving
// at most one (non smart reply) action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, &unilib_,
                                  /*expected_size=*/1);
}
430 
// A reply score threshold of 1.0 filters out smart replies, leaving at most
// one (non smart reply) action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
  const auto raise_reply_threshold = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_reply_threshold, &unilib_,
                                  /*expected_size=*/1);
}
440 
// A sensitive-topic limit of 0.0 still allows up to four actions, since the
// test model produces no sensitive prediction.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
  const auto zero_sensitivity_limit = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitivity_limit, &unilib_,
                                  /*expected_size=*/4);
}
449 
// With a maximum input length of 0, no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
  const auto zero_max_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_length, &unilib_,
                                  /*expected_size=*/0);
}
457 
// A minimum input length of 100 exceeds the test message length, so no
// actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
  const auto long_min_length = [](ActionsModelT* actions_model) {
    actions_model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(long_min_length, &unilib_,
                                  /*expected_size=*/0);
}
465 
// Preconditions passed as a serialized overwrite take effect without modifying
// the model. A max input length of 0 means no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT preconditions_overwrite;
  preconditions_overwrite.max_input_length = 0;

  flatbuffers::FlatBufferBuilder builder;
  const auto packed_preconditions =
      TriggeringPreconditions::Pack(builder, &preconditions_overwrite);
  builder.Finish(packed_preconditions);
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize());

  TestSuggestActionsWithThreshold(
      // Keep model untouched.
      [](ActionsModelT*) {}, &unilib_,
      /*expected_size=*/0, serialized_overwrite);
}
479 
480 #ifdef TC3_UNILIB_ICU
// A low-confidence rule matching "low-ground" in the input, with suppression
// enabled, means no actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
  TestSuggestActionsWithThreshold(
      [](ActionsModelT* actions_model) {
        actions_model->preconditions->suppress_on_low_confidence_input = true;
        std::unique_ptr<RulesModelT> low_confidence_rules(new RulesModelT);
        low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
        low_confidence_rules->rule.back()->pattern = "low-ground";
        actions_model->low_confidence_rules = std::move(low_confidence_rules);
      },
      &unilib_);
}
493 
TEST_F(ActionsSuggestionsTest,SuggestActionsLowConfidenceInputOutput)494 TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
495   const std::string actions_model_string =
496       ReadFile(GetModelPath() + kModelFileName);
497   std::unique_ptr<ActionsModelT> actions_model =
498       UnPackActionsModel(actions_model_string.c_str());
499   // Add custom triggering rule.
500   actions_model->rules.reset(new RulesModelT());
501   actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
502   RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
503   rule->pattern = "^(?i:hello\\s(there))$";
504   {
505     std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
506         new RulesModel_::Rule_::RuleActionSpecT);
507     rule_action->action.reset(new ActionSuggestionSpecT);
508     rule_action->action->type = "text_reply";
509     rule_action->action->response_text = "General Desaster!";
510     rule_action->action->score = 1.0f;
511     rule_action->action->priority_score = 1.0f;
512     rule->actions.push_back(std::move(rule_action));
513   }
514   {
515     std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
516         new RulesModel_::Rule_::RuleActionSpecT);
517     rule_action->action.reset(new ActionSuggestionSpecT);
518     rule_action->action->type = "text_reply";
519     rule_action->action->response_text = "General Kenobi!";
520     rule_action->action->score = 1.0f;
521     rule_action->action->priority_score = 1.0f;
522     rule->actions.push_back(std::move(rule_action));
523   }
524 
525   // Add input-output low confidence rule.
526   actions_model->preconditions->suppress_on_low_confidence_input = true;
527   actions_model->low_confidence_rules.reset(new RulesModelT);
528   actions_model->low_confidence_rules->rule.emplace_back(
529       new RulesModel_::RuleT);
530   actions_model->low_confidence_rules->rule.back()->pattern = "hello";
531   actions_model->low_confidence_rules->rule.back()->output_pattern =
532       "(?i:desaster)";
533 
534   flatbuffers::FlatBufferBuilder builder;
535   FinishActionsModelBuffer(builder,
536                            ActionsModel::Pack(builder, actions_model.get()));
537   std::unique_ptr<ActionsSuggestions> actions_suggestions =
538       ActionsSuggestions::FromUnownedBuffer(
539           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
540           builder.GetSize(), &unilib_);
541   ASSERT_TRUE(actions_suggestions);
542   const ActionsSuggestionsResponse& response =
543       actions_suggestions->SuggestActions(
544           {{{/*user_id=*/1, "hello there",
545              /*reference_time_ms_utc=*/0,
546              /*reference_timezone=*/"Europe/Zurich",
547              /*annotations=*/{}, /*locales=*/"en"}}});
548   ASSERT_GE(response.actions.size(), 1);
549   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
550 }
551 
TEST_F(ActionsSuggestionsTest,SuggestActionsLowConfidenceInputOutputOverwrite)552 TEST_F(ActionsSuggestionsTest,
553        SuggestActionsLowConfidenceInputOutputOverwrite) {
554   const std::string actions_model_string =
555       ReadFile(GetModelPath() + kModelFileName);
556   std::unique_ptr<ActionsModelT> actions_model =
557       UnPackActionsModel(actions_model_string.c_str());
558   actions_model->low_confidence_rules.reset();
559 
560   // Add custom triggering rule.
561   actions_model->rules.reset(new RulesModelT());
562   actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
563   RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
564   rule->pattern = "^(?i:hello\\s(there))$";
565   {
566     std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
567         new RulesModel_::Rule_::RuleActionSpecT);
568     rule_action->action.reset(new ActionSuggestionSpecT);
569     rule_action->action->type = "text_reply";
570     rule_action->action->response_text = "General Desaster!";
571     rule_action->action->score = 1.0f;
572     rule_action->action->priority_score = 1.0f;
573     rule->actions.push_back(std::move(rule_action));
574   }
575   {
576     std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
577         new RulesModel_::Rule_::RuleActionSpecT);
578     rule_action->action.reset(new ActionSuggestionSpecT);
579     rule_action->action->type = "text_reply";
580     rule_action->action->response_text = "General Kenobi!";
581     rule_action->action->score = 1.0f;
582     rule_action->action->priority_score = 1.0f;
583     rule->actions.push_back(std::move(rule_action));
584   }
585 
586   // Add custom triggering rule via overwrite.
587   actions_model->preconditions->low_confidence_rules.reset();
588   TriggeringPreconditionsT preconditions;
589   preconditions.suppress_on_low_confidence_input = true;
590   preconditions.low_confidence_rules.reset(new RulesModelT);
591   preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
592   preconditions.low_confidence_rules->rule.back()->pattern = "hello";
593   preconditions.low_confidence_rules->rule.back()->output_pattern =
594       "(?i:desaster)";
595   flatbuffers::FlatBufferBuilder preconditions_builder;
596   preconditions_builder.Finish(
597       TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
598   std::string serialize_preconditions = std::string(
599       reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
600       preconditions_builder.GetSize());
601 
602   flatbuffers::FlatBufferBuilder builder;
603   FinishActionsModelBuffer(builder,
604                            ActionsModel::Pack(builder, actions_model.get()));
605   std::unique_ptr<ActionsSuggestions> actions_suggestions =
606       ActionsSuggestions::FromUnownedBuffer(
607           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
608           builder.GetSize(), &unilib_, serialize_preconditions);
609 
610   ASSERT_TRUE(actions_suggestions);
611   const ActionsSuggestionsResponse& response =
612       actions_suggestions->SuggestActions(
613           {{{/*user_id=*/1, "hello there",
614              /*reference_time_ms_utc=*/0,
615              /*reference_timezone=*/"Europe/Zurich",
616              /*annotations=*/{}, /*locales=*/"en"}}});
617   ASSERT_GE(response.actions.size(), 1);
618   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
619 }
620 #endif
621 
// With max_sensitive_topic_score at 0.0 and suppress_on_sensitive_topic set,
// an address annotation must not produce any actions. The test is skipped
// (returns early) when the model does not output a sensitivity score.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // If the model could not be built, fail instead of dereferencing nullptr.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
655 
// With a larger conversation history allowed, a two-message conversation
// whose last message carries an address annotation yields view_map as the top
// action.
TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  // If the model could not be built, fail instead of dereferencing nullptr.
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
691 
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Fatal check: annotations[0] is read below, so a missing annotation must
  // abort the test instead of reading out of bounds.
  ASSERT_EQ(response.actions[0].annotations.size(), 1u);
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
714 
715 #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest,CreateActionsFromRules)716 TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
717   const std::string actions_model_string =
718       ReadFile(GetModelPath() + kModelFileName);
719   std::unique_ptr<ActionsModelT> actions_model =
720       UnPackActionsModel(actions_model_string.c_str());
721   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
722 
723   actions_model->rules.reset(new RulesModelT());
724   actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
725   RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
726   rule->pattern = "^(?i:hello\\s(there))$";
727   rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
728   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
729   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
730   action->type = "text_reply";
731   action->response_text = "General Kenobi!";
732   action->score = 1.0f;
733   action->priority_score = 1.0f;
734 
735   // Set capturing groups for entity data.
736   rule->actions.back()->capturing_group.emplace_back(
737       new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
738   RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
739       rule->actions.back()->capturing_group.back().get();
740   greeting_group->group_id = 0;
741   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
742   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
743   greeting_group->entity_field->field.back()->field_name = "greeting";
744   rule->actions.back()->capturing_group.emplace_back(
745       new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
746   RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
747       rule->actions.back()->capturing_group.back().get();
748   location_group->group_id = 1;
749   location_group->entity_field.reset(new FlatbufferFieldPathT);
750   location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
751   location_group->entity_field->field.back()->field_name = "location";
752 
753   // Set test entity data schema.
754   SetTestEntityDataSchema(actions_model.get());
755 
756   // Use meta data to generate custom serialized entity data.
757   ReflectiveFlatbufferBuilder entity_data_builder(
758       flatbuffers::GetRoot<reflection::Schema>(
759           actions_model->actions_entity_data_schema.data()));
760   std::unique_ptr<ReflectiveFlatbuffer> entity_data =
761       entity_data_builder.NewRoot();
762   entity_data->Set("person", "Kenobi");
763   action->serialized_entity_data = entity_data->Serialize();
764 
765   flatbuffers::FlatBufferBuilder builder;
766   FinishActionsModelBuffer(builder,
767                            ActionsModel::Pack(builder, actions_model.get()));
768   std::unique_ptr<ActionsSuggestions> actions_suggestions =
769       ActionsSuggestions::FromUnownedBuffer(
770           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
771           builder.GetSize(), &unilib_);
772 
773   const ActionsSuggestionsResponse& response =
774       actions_suggestions->SuggestActions(
775           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
776              /*reference_timezone=*/"Europe/Zurich",
777              /*annotations=*/{}, /*locales=*/"en"}}});
778   EXPECT_GE(response.actions.size(), 1);
779   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
780 
781   // Check entity data.
782   const flatbuffers::Table* entity =
783       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
784           response.actions[0].serialized_entity_data.data()));
785   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
786             "hello there");
787   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
788             "there");
789   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
790             "Kenobi");
791 }
792 
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Rule matching "reply <code> to/for"; the captured code becomes a reply.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);

  // Set capturing groups for entity data: group 1 yields a text reply.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions != nullptr);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal: actions[0] is read on the next line.
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions[0].response_text, "STOP");
}
834 
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestions::kShareLocation) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  const size_t num_actions = response.actions.size();

  // Add custom rule for location sharing.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
  actions_model->rules->rule.back()->actions.emplace_back(
      new RulesModel_::Rule_::RuleActionSpecT);
  actions_model->rules->rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestions::kShareLocation;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  // The model borrows `builder`'s buffer, so it is declared after `builder`
  // and is therefore destroyed before the buffer goes away.
  std::unique_ptr<ActionsSuggestions> rule_actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(rule_actions_suggestions != nullptr);

  // The rule-produced location-sharing action must be merged with the
  // model's, leaving the number of actions unchanged.
  response = rule_actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), num_actions);
}
885 
TEST_F(ActionsSuggestionsTest,DeduplicateConflictingActions)886 TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
887   std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
888   AnnotatedSpan annotation;
889   annotation.span = {7, 11};
890   annotation.classification = {
891       ClassificationResult(Collections::Flight(), 1.0)};
892   ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
893       {{{/*user_id=*/1, "I'm on LX38",
894          /*reference_time_ms_utc=*/0,
895          /*reference_timezone=*/"Europe/Zurich",
896          /*annotations=*/{annotation},
897          /*locales=*/"en"}}});
898 
899   // Check that the phone actions are present.
900   EXPECT_GE(response.actions.size(), 1);
901   EXPECT_EQ(response.actions[0].type, "track_flight");
902 
903   // Add custom rule.
904   const std::string actions_model_string =
905       ReadFile(GetModelPath() + kModelFileName);
906   std::unique_ptr<ActionsModelT> actions_model =
907       UnPackActionsModel(actions_model_string.c_str());
908   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
909 
910   actions_model->rules.reset(new RulesModelT());
911   actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
912   RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
913   rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
914   rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
915   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
916   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
917   action->score = 1.0f;
918   action->priority_score = 2.0f;
919   action->type = "test_code";
920   rule->actions.back()->capturing_group.emplace_back(
921       new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
922   RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
923       rule->actions.back()->capturing_group.back().get();
924   code_group->group_id = 1;
925   code_group->annotation_name = "code";
926   code_group->annotation_type = "code";
927 
928   flatbuffers::FlatBufferBuilder builder;
929   FinishActionsModelBuffer(builder,
930                            ActionsModel::Pack(builder, actions_model.get()));
931   actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
932       reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
933       builder.GetSize(), &unilib_);
934 
935   response = actions_suggestions->SuggestActions(
936       {{{/*user_id=*/1, "I'm on LX38",
937          /*reference_time_ms_utc=*/0,
938          /*reference_timezone=*/"Europe/Zurich",
939          /*annotations=*/{annotation},
940          /*locales=*/"en"}}});
941   EXPECT_GE(response.actions.size(), 1);
942   EXPECT_EQ(response.actions[0].type, "test_code");
943 }
944 #endif
945 
TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  // Two address annotations ("home" and "work") with different scores; the
  // higher-scored one is expected to be ranked first.
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Fatal: the two leading actions are indexed below.
  ASSERT_GE(response.actions.size(), 2u);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
966 
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Visitor that reports whether it was handed a model at all.
  const auto model_was_loaded = [](const ActionsModel* model) {
    return model != nullptr;
  };

  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
                                      model_was_loaded));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", model_was_loaded));
}
983 
TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);

  // Runs the model on a one-message conversation from user 1.
  const auto suggest_for = [&actions_suggestions](const std::string& text) {
    return actions_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  const ActionsSuggestionsResponse greeting = suggest_for("hello");
  EXPECT_THAT(greeting.actions, testing::IsEmpty());

  const ActionsSuggestionsResponse location_question =
      suggest_for("where are you");
  EXPECT_THAT(location_question.actions,
              testing::ElementsAre(testing::Field(&ActionSuggestion::type,
                                                  "share_location")));

  const ActionsSuggestionsResponse contact_question =
      suggest_for("do you know johns number");
  EXPECT_THAT(contact_question.actions,
              testing::ElementsAre(testing::Field(&ActionSuggestion::type,
                                                  "share_contact")));
}
1023 
1024 // Test class to expose token embedding methods for testing.
1025 class TestingMessageEmbedder : private ActionsSuggestions {
1026  public:
1027   explicit TestingMessageEmbedder(const ActionsModel* model);
1028 
1029   using ActionsSuggestions::EmbedAndFlattenTokens;
1030   using ActionsSuggestions::EmbedTokensPerMessage;
1031 
1032  protected:
1033   // EmbeddingExecutor that always returns features based on
1034   // the id of the sparse features.
1035   class FakeEmbeddingExecutor : public EmbeddingExecutor {
1036    public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1037     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1038                       const int dest_size) const override {
1039       TC3_CHECK_GE(dest_size, 1);
1040       EXPECT_EQ(sparse_features.size(), 1);
1041       dest[0] = sparse_features.data()[0];
1042       return true;
1043     }
1044   };
1045 };
1046 
TestingMessageEmbedder(const ActionsModel * model)1047 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
1048   model_ = model;
1049   const ActionsTokenFeatureProcessorOptions* options =
1050       model->feature_processor_options();
1051   feature_processor_.reset(
1052       new ActionsFeatureProcessor(options, /*unilib=*/nullptr));
1053   embedding_executor_.reset(new FakeEmbeddingExecutor());
1054   EXPECT_TRUE(
1055       EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1056   EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1057   EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1058   token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1059   EXPECT_EQ(token_embedding_size_, 1);
1060 }
1061 
1062 class EmbeddingTest : public testing::Test {
1063  protected:
EmbeddingTest()1064   EmbeddingTest() {
1065     model_.feature_processor_options.reset(
1066         new ActionsTokenFeatureProcessorOptionsT);
1067     options_ = model_.feature_processor_options.get();
1068     options_->chargram_orders = {1};
1069     options_->num_buckets = 1000;
1070     options_->embedding_size = 1;
1071     options_->start_token_id = 0;
1072     options_->end_token_id = 1;
1073     options_->padding_token_id = 2;
1074     options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1075   }
1076 
CreateTestingMessageEmbedder()1077   TestingMessageEmbedder CreateTestingMessageEmbedder() {
1078     flatbuffers::FlatBufferBuilder builder;
1079     FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1080     buffer_ = builder.ReleaseBufferPointer();
1081     return TestingMessageEmbedder(
1082         flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1083   }
1084 
1085   flatbuffers::DetachedBuffer buffer_;
1086   ActionsModelT model_;
1087   ActionsTokenFeatureProcessorOptionsT* options_;
1088 };
1089 
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 3u);
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("c")));
}
1112 
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 5u);
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("c")));
  // The message is padded up to min_num_tokens_per_message.
  EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
}
1138 
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 2u);
  // The earliest token ("a") is dropped to respect the limit.
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("c")));
}
1159 
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Two messages, each padded to the longest message's 3 tokens. Checked
  // fatally because embeddings[5] is read below.
  ASSERT_EQ(embeddings.size(), 6u);
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[3], testing::FloatEq(bucket("d")));
  EXPECT_THAT(embeddings[4], testing::FloatEq(bucket("e")));
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
}
1189 
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(total_token_count, 5);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 5u);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[3], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
}
1214 
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(total_token_count, 7);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 7u);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[3], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  // Padding is appended up to min_num_total_tokens.
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
}
1242 
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(total_token_count, 3);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 3u);
  // The start token and "a" are dropped from the front to fit the limit.
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
}
1264 
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(total_token_count, 9);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 9u);
  // First message, wrapped in start/end tokens.
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(bucket("a")));
  EXPECT_THAT(embeddings[2], testing::FloatEq(bucket("b")));
  EXPECT_THAT(embeddings[3], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  // Second message, wrapped in start/end tokens.
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[6], testing::FloatEq(bucket("d")));
  EXPECT_THAT(embeddings[7], testing::FloatEq(bucket("e")));
  EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
}
1298 
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected embedding of a single-character token (chargram order 1).
  const auto bucket = [this](const char* token) {
    return tc3farmhash::Fingerprint64(token, 1) % options_->num_buckets;
  };

  EXPECT_EQ(total_token_count, 7);
  // Fatal so that a short result cannot lead to out-of-bounds reads below.
  ASSERT_EQ(embeddings.size(), 7u);
  // Only the tail of the first message survives the total-token limit.
  EXPECT_THAT(embeddings[0], testing::FloatEq(bucket("c")));
  EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[3], testing::FloatEq(bucket("d")));
  EXPECT_THAT(embeddings[4], testing::FloatEq(bucket("e")));
  EXPECT_THAT(embeddings[5], testing::FloatEq(bucket("f")));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
}
1330 
1331 }  // namespace
1332 }  // namespace libtextclassifier3
1333