/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "actions/grammar-actions.h"
18 
19 #include <iostream>
20 #include <memory>
21 
22 #include "actions/actions_model_generated.h"
23 #include "actions/test-utils.h"
24 #include "actions/types.h"
25 #include "utils/flatbuffers/flatbuffers.h"
26 #include "utils/flatbuffers/mutable.h"
27 #include "utils/grammar/rules_generated.h"
28 #include "utils/grammar/types.h"
29 #include "utils/grammar/utils/rules.h"
30 #include "utils/jvm-test-utils.h"
31 #include "gmock/gmock.h"
32 #include "gtest/gtest.h"
33 
34 namespace libtextclassifier3 {
35 namespace {
36 
37 using ::testing::ElementsAre;
38 using ::testing::IsEmpty;
39 
40 using ::libtextclassifier3::grammar::LocaleShardMap;
41 
42 class TestGrammarActions : public GrammarActions {
43  public:
TestGrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder=nullptr)44   explicit TestGrammarActions(
45       const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
46       const MutableFlatbufferBuilder* entity_data_builder = nullptr)
47       : GrammarActions(unilib, grammar_rules, entity_data_builder,
48 
49                        /*smart_reply_action_type=*/"text_reply") {}
50 };
51 
52 class GrammarActionsTest : public testing::Test {
53  protected:
54   struct AnnotationSpec {
55     int group_id = 0;
56     std::string annotation_name = "";
57     bool use_annotation_match = false;
58   };
59 
GrammarActionsTest()60   GrammarActionsTest()
61       : unilib_(CreateUniLibForTesting()),
62         serialized_entity_data_schema_(TestEntityDataSchema()),
63         entity_data_builder_(new MutableFlatbufferBuilder(
64             flatbuffers::GetRoot<reflection::Schema>(
65                 serialized_entity_data_schema_.data()))) {}
66 
SetTokenizerOptions(RulesModel_::GrammarRulesT * action_grammar_rules) const67   void SetTokenizerOptions(
68       RulesModel_::GrammarRulesT* action_grammar_rules) const {
69     action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
70     action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
71     action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
72         false;
73   }
74 
AddActionSpec(const std::string & type,const std::string & response_text,const std::vector<AnnotationSpec> & annotations,RulesModel_::GrammarRulesT * action_grammar_rules) const75   int AddActionSpec(const std::string& type, const std::string& response_text,
76                     const std::vector<AnnotationSpec>& annotations,
77                     RulesModel_::GrammarRulesT* action_grammar_rules) const {
78     const int action_id = action_grammar_rules->actions.size();
79     action_grammar_rules->actions.emplace_back(
80         new RulesModel_::RuleActionSpecT);
81     RulesModel_::RuleActionSpecT* actions_spec =
82         action_grammar_rules->actions.back().get();
83     actions_spec->action.reset(new ActionSuggestionSpecT);
84     actions_spec->action->response_text = response_text;
85     actions_spec->action->priority_score = 1.0;
86     actions_spec->action->score = 1.0;
87     actions_spec->action->type = type;
88     // Create annotations for specified capturing groups.
89     for (const AnnotationSpec& annotation : annotations) {
90       actions_spec->capturing_group.emplace_back(
91           new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
92       actions_spec->capturing_group.back()->group_id = annotation.group_id;
93       actions_spec->capturing_group.back()->annotation_name =
94           annotation.annotation_name;
95       actions_spec->capturing_group.back()->annotation_type =
96           annotation.annotation_name;
97       actions_spec->capturing_group.back()->use_annotation_match =
98           annotation.use_annotation_match;
99     }
100 
101     return action_id;
102   }
103 
AddSmartReplySpec(const std::string & response_text,RulesModel_::GrammarRulesT * action_grammar_rules) const104   int AddSmartReplySpec(
105       const std::string& response_text,
106       RulesModel_::GrammarRulesT* action_grammar_rules) const {
107     return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
108   }
109 
AddCapturingMatchSmartReplySpec(const int match_id,RulesModel_::GrammarRulesT * action_grammar_rules) const110   int AddCapturingMatchSmartReplySpec(
111       const int match_id,
112       RulesModel_::GrammarRulesT* action_grammar_rules) const {
113     const int action_id = action_grammar_rules->actions.size();
114     action_grammar_rules->actions.emplace_back(
115         new RulesModel_::RuleActionSpecT);
116     RulesModel_::RuleActionSpecT* actions_spec =
117         action_grammar_rules->actions.back().get();
118     actions_spec->capturing_group.emplace_back(
119         new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
120     actions_spec->capturing_group.back()->group_id = match_id;
121     actions_spec->capturing_group.back()->text_reply.reset(
122         new ActionSuggestionSpecT);
123     actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
124     actions_spec->capturing_group.back()->text_reply->score = 1.0;
125     return action_id;
126   }
127 
AddRuleMatch(const std::vector<int> & action_ids,RulesModel_::GrammarRulesT * action_grammar_rules) const128   int AddRuleMatch(const std::vector<int>& action_ids,
129                    RulesModel_::GrammarRulesT* action_grammar_rules) const {
130     const int rule_match_id = action_grammar_rules->rule_match.size();
131     action_grammar_rules->rule_match.emplace_back(
132         new RulesModel_::GrammarRules_::RuleMatchT);
133     action_grammar_rules->rule_match.back()->action_id.insert(
134         action_grammar_rules->rule_match.back()->action_id.end(),
135         action_ids.begin(), action_ids.end());
136     return rule_match_id;
137   }
138 
139   std::unique_ptr<UniLib> unilib_;
140   const std::string serialized_entity_data_schema_;
141   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
142 };
143 
// A root rule with two attached smart reply specs yields both replies, in the
// order in which they were attached.
TEST_F(GrammarActionsTest, ProducesSmartReplies) {
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  // Register both replies and bundle them into one rule match.
  const int whos_there_id =
      AddSmartReplySpec("Who's there?", &action_grammar_rules);
  const int yes_id = AddSmartReplySpec("Yes?", &action_grammar_rules);
  const int knock_match_id =
      AddRuleMatch({whos_there_id, yes_id}, &action_grammar_rules);

  // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/knock_match_id);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> suggestions;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &suggestions));

  EXPECT_THAT(suggestions,
              ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
}
174 
// The text covered by a capturing match is returned as the smart reply, with
// the casing it has in the input message.
TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // The reply is taken from the capturing match with id 0.
  const int captured_reply_spec_id = AddCapturingMatchSmartReplySpec(
      /*match_id=*/0, &action_grammar_rules);
  const int scripted_reply_match_id =
      AddRuleMatch({captured_reply_spec_id}, &action_grammar_rules);

  // Rule: ^Text <reply> to <command>
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<scripted_reply>",
            {"<^>", "text", "<captured_reply>", "to", "<command>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/scripted_reply_match_id);

  // <command> ::= unsubscribe | cancel | confirm | receive
  rules.Add("<command>", {"unsubscribe"});
  rules.Add("<command>", {"cancel"});
  rules.Add("<command>", {"confirm"});
  rules.Add("<command>", {"receive"});

  // <reply> ::= help | stop | cancel | yes
  rules.Add("<reply>", {"help"});
  rules.Add("<reply>", {"stop"});
  rules.Add("<reply>", {"cancel"});
  rules.Add("<reply>", {"yes"});
  rules.AddValueMapping("<captured_reply>", {"<reply>"},
                        /*value=*/0);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Upper-case reply in the message.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"Text YES to confirm your subscription"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("YES")));
  }

  // Capitalized reply in the message.
  {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0,
                        /*text=*/"text Stop to cancel your order"}}},
        &suggestions));
    EXPECT_THAT(suggestions, ElementsAre(IsSmartReply("Stop")));
  }
}
232 
// An action whose spec declares an annotation for a capturing group carries
// that annotation, with the text and codepoint span of the captured match.
TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "phone"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  // phone ::= +00 00 000 00 00
  rules.AddValueMapping("<phone>",
                        {"+", "<2_digits>", "<2_digits>", "<3_digits>",
                         "<2_digits>", "<2_digits>"},
                        /*value=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));

  // Fatal assertion: `result.front()` below must not run on an empty vector.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}
272 
TEST_F(GrammarActionsTest,HandlesLocales)273 TEST_F(GrammarActionsTest, HandlesLocales) {
274   // Create test rules.
275   // Rule: ^knock knock.?$ -> "Who's there?"
276   RulesModel_::GrammarRulesT action_grammar_rules;
277   SetTokenizerOptions(&action_grammar_rules);
278   action_grammar_rules.rules.reset(new grammar::RulesSetT);
279   LocaleShardMap locale_shard_map =
280       LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
281   grammar::Rules rules(locale_shard_map);
282   rules.Add(
283       "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
284       /*callback=*/
285       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
286       /*callback_param=*/
287       AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
288                    &action_grammar_rules));
289   rules.Add(
290       "<toc>", {"<knock>"},
291       /*callback=*/
292       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
293       /*callback_param=*/
294       AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
295                    &action_grammar_rules),
296       /*max_whitespace_gap=*/-1,
297       /*case_sensitive=*/false,
298       /*shard=*/1);
299   rules.Finalize().Serialize(/*include_debug_information=*/false,
300                              action_grammar_rules.rules.get());
301   // Set locales for rules.
302   action_grammar_rules.rules->rules.back()->locale.emplace_back(
303       new LanguageTagT);
304   action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";
305 
306   OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
307       PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
308   TestGrammarActions grammar_actions(unilib_.get(), model.get());
309 
310   // Check default.
311   {
312     std::vector<ActionSuggestion> result;
313     EXPECT_TRUE(grammar_actions.SuggestActions(
314         {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
315                         /*reference_time_ms_utc=*/0,
316                         /*reference_timezone=*/"UTC", /*annotations=*/{},
317                         /*detected_text_language_tags=*/"en"}}},
318         &result));
319 
320     EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?")));
321   }
322 
323   // Check fr.
324   {
325     std::vector<ActionSuggestion> result;
326     EXPECT_TRUE(grammar_actions.SuggestActions(
327         {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
328                         /*reference_time_ms_utc=*/0,
329                         /*reference_timezone=*/"UTC", /*annotations=*/{},
330                         /*detected_text_language_tags=*/"fr-CH"}}},
331         &result));
332 
333     EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?"),
334                                     IsSmartReply("Qui est là?")));
335   }
336 }
337 
// A negative assertion on the right context suppresses a match: of the three
// flight-like candidates in the message, only the two that are not followed by
// further digits produce a "track_flight" action.
TEST_F(GrammarActionsTest, HandlesAssertions) {
  // Create test rules.
  // Rule: <flight> -> Track flight.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Capture flight code.
  rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
                        /*value=*/0);

  // Flight: carrier + flight code and check right context.
  rules.Add(
      "<track_flight>", {"<flight>", "<context_assertion>?"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "flight"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));

  // Exclude matches like: LX 38.00 etc.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
      &result));

  // Fatal assertion: `result[0]` and `result[1]` below require exactly the two
  // expected elements to be present.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
                                  IsActionOfType("track_flight")));
  EXPECT_THAT(result[0].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
                                                       CodepointSpan{0, 4})));
  EXPECT_THAT(result[1].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
                                                       CodepointSpan{5, 10})));
}
392 
// Entity data that is statically attached to an action spec, both as
// pre-serialized data and as an ActionsEntityDataT, ends up in the serialized
// entity data of the produced suggestion.
TEST_F(GrammarActionsTest, SetsFixedEntityData) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();
  action_grammar_rules.actions[spec_id]->action->entity_data.reset(
      new ActionsEntityDataT);
  action_grammar_rules.actions[spec_id]->action->entity_data->text =
      "I have the high ground.";

  rules.Add(
      "<greeting>", {"<^>", "hello", "there", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal assertion: `result[0]` below must
  // not be evaluated on an empty vector.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. Fields are read by vtable offset; presumably offset 4
  // is the text field and offset 8 the "person" field of the test schema
  // (TODO confirm against TestEntityDataSchema()).
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));
  EXPECT_THAT(
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
      "I have the high ground.");
  EXPECT_THAT(
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
      "Kenobi");
}
446 
TEST_F(GrammarActionsTest,SetsEntityDataFromCapturingMatches)447 TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
448   // Create test rules.
449   // Rule: ^hello there$
450   RulesModel_::GrammarRulesT action_grammar_rules;
451   SetTokenizerOptions(&action_grammar_rules);
452   action_grammar_rules.rules.reset(new grammar::RulesSetT);
453   LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
454   grammar::Rules rules(locale_shard_map);
455 
456   // Create smart reply and static entity data.
457   const int spec_id =
458       AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
459   std::unique_ptr<MutableFlatbuffer> entity_data =
460       entity_data_builder_->NewRoot();
461   entity_data->Set("person", "Kenobi");
462   action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
463       entity_data->Serialize();
464 
465   // Specify results for capturing matches.
466   const int greeting_match_id = 0;
467   const int location_match_id = 1;
468   {
469     action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
470         new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
471     RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
472         action_grammar_rules.actions[spec_id]->capturing_group.back().get();
473     group->group_id = greeting_match_id;
474     group->entity_field.reset(new FlatbufferFieldPathT);
475     group->entity_field->field.emplace_back(new FlatbufferFieldT);
476     group->entity_field->field.back()->field_name = "greeting";
477   }
478   {
479     action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
480         new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
481     RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
482         action_grammar_rules.actions[spec_id]->capturing_group.back().get();
483     group->group_id = location_match_id;
484     group->entity_field.reset(new FlatbufferFieldPathT);
485     group->entity_field->field.emplace_back(new FlatbufferFieldT);
486     group->entity_field->field.back()->field_name = "location";
487   }
488 
489   rules.Add("<location>", {"there"});
490   rules.Add("<location>", {"here"});
491   rules.AddValueMapping("<captured_location>", {"<location>"},
492                         /*value=*/location_match_id);
493   rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
494                         /*value=*/greeting_match_id);
495   rules.Add(
496       "<test>", {"<^>", "<greeting>", "<$>"},
497       /*callback=*/
498       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
499       /*callback_param=*/
500       AddRuleMatch({spec_id}, &action_grammar_rules));
501   rules.Finalize().Serialize(/*include_debug_information=*/false,
502                              action_grammar_rules.rules.get());
503   OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
504       PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
505   TestGrammarActions grammar_actions(unilib_.get(), model.get(),
506                                      entity_data_builder_.get());
507 
508   std::vector<ActionSuggestion> result;
509   EXPECT_TRUE(grammar_actions.SuggestActions(
510       {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
511 
512   // Check the produces smart replies.
513   EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
514 
515   // Check entity data.
516   const flatbuffers::Table* entity =
517       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
518           result[0].serialized_entity_data.data()));
519   EXPECT_THAT(
520       entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
521       "Hello there");
522   EXPECT_THAT(
523       entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
524       "there");
525   EXPECT_THAT(
526       entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
527       "Kenobi");
528 }
529 
// Fixed entity data attached to a capturing group is written to the entity
// data of the suggestion when that group matches.
TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply with a capturing group that carries fixed entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
      action_grammar_rules.actions[spec_id]->capturing_group.back().get();
  group->group_id = 0;
  group->entity_data.reset(new ActionsEntityDataT);
  group->entity_data->text = "You are a bold one.";

  rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<greeting>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal assertion: `result[0]` below must
  // not be evaluated on an empty vector.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. The field is read by vtable offset; presumably offset 4
  // is the text field (TODO confirm against TestEntityDataSchema()).
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));
  EXPECT_THAT(
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
      "You are a bold one.");
}
580 
TEST_F(GrammarActionsTest,ProducesActionsWithAnnotations)581 TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
582   // Create test rules.
583   // Rule: please dial <phone>
584   RulesModel_::GrammarRulesT action_grammar_rules;
585   SetTokenizerOptions(&action_grammar_rules);
586   action_grammar_rules.rules.reset(new grammar::RulesSetT);
587   LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
588   grammar::Rules rules(locale_shard_map);
589   rules.Add(
590       "<call_phone>", {"please", "dial", "<phone>"},
591       /*callback=*/
592       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
593       /*callback_param=*/
594       AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
595                                   /*annotations=*/
596                                   {{0 /*value*/, "phone",
597                                     /*use_annotation_match=*/true}},
598                                   &action_grammar_rules)},
599                    &action_grammar_rules));
600   rules.AddValueMapping("<phone>", {"<phone_annotation>"},
601                         /*value=*/0);
602 
603   grammar::Ir ir = rules.Finalize(
604       /*predefined_nonterminals=*/{"<phone_annotation>"});
605   ir.Serialize(/*include_debug_information=*/false,
606                action_grammar_rules.rules.get());
607 
608   // Map "phone" annotation to "<phone_annotation>" nonterminal.
609   action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
610       new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
611   action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
612   action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
613       ir.GetNonterminalForName("<phone_annotation>");
614 
615   OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
616       PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
617   TestGrammarActions grammar_actions(unilib_.get(), model.get());
618 
619   std::vector<ActionSuggestion> result;
620 
621   // Sanity check that no result are produced when no annotations are provided.
622   EXPECT_TRUE(grammar_actions.SuggestActions(
623       {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
624       &result));
625   EXPECT_THAT(result, IsEmpty());
626 
627   EXPECT_TRUE(grammar_actions.SuggestActions(
628       {/*messages=*/{
629           {/*user_id=*/0,
630            /*text=*/"Please dial +41 79 123 45 67",
631            /*reference_time_ms_utc=*/0,
632            /*reference_timezone=*/"UTC",
633            /*annotations=*/
634            {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
635       &result));
636   EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
637   EXPECT_THAT(result.front().annotations,
638               ElementsAre(IsActionSuggestionAnnotation(
639                   "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
640 }
641 
TEST_F(GrammarActionsTest,HandlesExclusions)642 TEST_F(GrammarActionsTest, HandlesExclusions) {
643   // Create test rules.
644   RulesModel_::GrammarRulesT action_grammar_rules;
645   SetTokenizerOptions(&action_grammar_rules);
646   action_grammar_rules.rules.reset(new grammar::RulesSetT);
647 
648   LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
649   grammar::Rules rules(locale_shard_map);
650   rules.Add("<excluded>", {"be", "safe"});
651   rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
652                          /*excluded_nonterminal=*/"<excluded>");
653 
654   rules.Add(
655       "<set_reminder>",
656       {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
657       /*callback=*/
658       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
659       /*callback_param=*/
660       AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
661                                   /*annotations=*/
662                                   {}, &action_grammar_rules)},
663                    &action_grammar_rules));
664 
665   rules.Finalize().Serialize(/*include_debug_information=*/false,
666                              action_grammar_rules.rules.get());
667   OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
668       PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
669   TestGrammarActions grammar_actions(unilib_.get(), model.get(),
670                                      entity_data_builder_.get());
671 
672   {
673     std::vector<ActionSuggestion> result;
674     EXPECT_TRUE(grammar_actions.SuggestActions(
675         {/*messages=*/{
676             {/*user_id=*/0, /*text=*/"do not forget to bring milk"}}},
677         &result));
678     EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
679   }
680 
681   {
682     std::vector<ActionSuggestion> result;
683     EXPECT_TRUE(grammar_actions.SuggestActions(
684         {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be there!"}}},
685         &result));
686     EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
687   }
688 
689   {
690     std::vector<ActionSuggestion> result;
691     EXPECT_TRUE(grammar_actions.SuggestActions(
692         {/*messages=*/{
693             {/*user_id=*/0, /*text=*/"do not forget to buy safe or vault!"}}},
694         &result));
695     EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
696   }
697 
698   {
699     std::vector<ActionSuggestion> result;
700     EXPECT_TRUE(grammar_actions.SuggestActions(
701         {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be safe!"}}},
702         &result));
703     EXPECT_THAT(result, IsEmpty());
704   }
705 }
706 
707 }  // namespace
708 }  // namespace libtextclassifier3
709