1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/grammar-actions.h"
18
19 #include <iostream>
20 #include <memory>
21
22 #include "actions/actions_model_generated.h"
23 #include "actions/test-utils.h"
24 #include "actions/types.h"
25 #include "utils/flatbuffers/flatbuffers.h"
26 #include "utils/flatbuffers/mutable.h"
27 #include "utils/grammar/rules_generated.h"
28 #include "utils/grammar/types.h"
29 #include "utils/grammar/utils/rules.h"
30 #include "utils/jvm-test-utils.h"
31 #include "gmock/gmock.h"
32 #include "gtest/gtest.h"
33
34 namespace libtextclassifier3 {
35 namespace {
36
37 using ::testing::ElementsAre;
38 using ::testing::IsEmpty;
39
40 using ::libtextclassifier3::grammar::LocaleShardMap;
41
42 class TestGrammarActions : public GrammarActions {
43 public:
TestGrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder=nullptr)44 explicit TestGrammarActions(
45 const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
46 const MutableFlatbufferBuilder* entity_data_builder = nullptr)
47 : GrammarActions(unilib, grammar_rules, entity_data_builder,
48
49 /*smart_reply_action_type=*/"text_reply") {}
50 };
51
52 class GrammarActionsTest : public testing::Test {
53 protected:
54 struct AnnotationSpec {
55 int group_id = 0;
56 std::string annotation_name = "";
57 bool use_annotation_match = false;
58 };
59
GrammarActionsTest()60 GrammarActionsTest()
61 : unilib_(CreateUniLibForTesting()),
62 serialized_entity_data_schema_(TestEntityDataSchema()),
63 entity_data_builder_(new MutableFlatbufferBuilder(
64 flatbuffers::GetRoot<reflection::Schema>(
65 serialized_entity_data_schema_.data()))) {}
66
SetTokenizerOptions(RulesModel_::GrammarRulesT * action_grammar_rules) const67 void SetTokenizerOptions(
68 RulesModel_::GrammarRulesT* action_grammar_rules) const {
69 action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
70 action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
71 action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
72 false;
73 }
74
AddActionSpec(const std::string & type,const std::string & response_text,const std::vector<AnnotationSpec> & annotations,RulesModel_::GrammarRulesT * action_grammar_rules) const75 int AddActionSpec(const std::string& type, const std::string& response_text,
76 const std::vector<AnnotationSpec>& annotations,
77 RulesModel_::GrammarRulesT* action_grammar_rules) const {
78 const int action_id = action_grammar_rules->actions.size();
79 action_grammar_rules->actions.emplace_back(
80 new RulesModel_::RuleActionSpecT);
81 RulesModel_::RuleActionSpecT* actions_spec =
82 action_grammar_rules->actions.back().get();
83 actions_spec->action.reset(new ActionSuggestionSpecT);
84 actions_spec->action->response_text = response_text;
85 actions_spec->action->priority_score = 1.0;
86 actions_spec->action->score = 1.0;
87 actions_spec->action->type = type;
88 // Create annotations for specified capturing groups.
89 for (const AnnotationSpec& annotation : annotations) {
90 actions_spec->capturing_group.emplace_back(
91 new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
92 actions_spec->capturing_group.back()->group_id = annotation.group_id;
93 actions_spec->capturing_group.back()->annotation_name =
94 annotation.annotation_name;
95 actions_spec->capturing_group.back()->annotation_type =
96 annotation.annotation_name;
97 actions_spec->capturing_group.back()->use_annotation_match =
98 annotation.use_annotation_match;
99 }
100
101 return action_id;
102 }
103
AddSmartReplySpec(const std::string & response_text,RulesModel_::GrammarRulesT * action_grammar_rules) const104 int AddSmartReplySpec(
105 const std::string& response_text,
106 RulesModel_::GrammarRulesT* action_grammar_rules) const {
107 return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
108 }
109
AddCapturingMatchSmartReplySpec(const int match_id,RulesModel_::GrammarRulesT * action_grammar_rules) const110 int AddCapturingMatchSmartReplySpec(
111 const int match_id,
112 RulesModel_::GrammarRulesT* action_grammar_rules) const {
113 const int action_id = action_grammar_rules->actions.size();
114 action_grammar_rules->actions.emplace_back(
115 new RulesModel_::RuleActionSpecT);
116 RulesModel_::RuleActionSpecT* actions_spec =
117 action_grammar_rules->actions.back().get();
118 actions_spec->capturing_group.emplace_back(
119 new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
120 actions_spec->capturing_group.back()->group_id = match_id;
121 actions_spec->capturing_group.back()->text_reply.reset(
122 new ActionSuggestionSpecT);
123 actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
124 actions_spec->capturing_group.back()->text_reply->score = 1.0;
125 return action_id;
126 }
127
AddRuleMatch(const std::vector<int> & action_ids,RulesModel_::GrammarRulesT * action_grammar_rules) const128 int AddRuleMatch(const std::vector<int>& action_ids,
129 RulesModel_::GrammarRulesT* action_grammar_rules) const {
130 const int rule_match_id = action_grammar_rules->rule_match.size();
131 action_grammar_rules->rule_match.emplace_back(
132 new RulesModel_::GrammarRules_::RuleMatchT);
133 action_grammar_rules->rule_match.back()->action_id.insert(
134 action_grammar_rules->rule_match.back()->action_id.end(),
135 action_ids.begin(), action_ids.end());
136 return rule_match_id;
137 }
138
139 std::unique_ptr<UniLib> unilib_;
140 const std::string serialized_entity_data_schema_;
141 std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
142 };
143
// A single root rule that fires two smart replies for one message.
TEST_F(GrammarActionsTest, ProducesSmartReplies) {
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
  // The specs are registered in the order in which the replies are expected.
  const int who_is_there_id =
      AddSmartReplySpec("Who's there?", &action_grammar_rules);
  const int yes_id = AddSmartReplySpec("Yes?", &action_grammar_rules);
  const int knock_match_id =
      AddRuleMatch({who_is_there_id, yes_id}, &action_grammar_rules);
  rules.Add(
      "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/knock_match_id);

  // Serialize the rules and build the actions engine on the packed model.
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> suggestions;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}},
      &suggestions));

  EXPECT_THAT(suggestions,
              ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
}
174
// The smart reply text is taken from the message span matched by a
// capturing group rather than from a fixed response text.
TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Rule: ^Text <reply> to <command>
  const int captured_reply_spec_id = AddCapturingMatchSmartReplySpec(
      /*match_id=*/0, &action_grammar_rules);
  const int scripted_reply_match_id =
      AddRuleMatch({captured_reply_spec_id}, &action_grammar_rules);
  rules.Add(
      "<scripted_reply>",
      {"<^>", "text", "<captured_reply>", "to", "<command>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/scripted_reply_match_id);

  // <command> ::= unsubscribe | cancel | confirm | receive
  for (const char* command : {"unsubscribe", "cancel", "confirm", "receive"}) {
    rules.Add("<command>", {command});
  }

  // <reply> ::= help | stop | cancel | yes
  for (const char* reply : {"help", "stop", "cancel", "yes"}) {
    rules.Add("<reply>", {reply});
  }
  // Value 0 ties the captured reply to capturing group 0 of the spec above.
  rules.AddValueMapping("<captured_reply>", {"<reply>"},
                        /*value=*/0);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Runs the engine on a single-message conversation and returns whatever
  // it suggested; a failing call is reported as a non-fatal failure.
  const auto suggest_for = [&grammar_actions](const std::string& text) {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/text}}}, &suggestions));
    return suggestions;
  };

  // The reply keeps the casing of the message text.
  EXPECT_THAT(suggest_for("Text YES to confirm your subscription"),
              ElementsAre(IsSmartReply("YES")));
  EXPECT_THAT(suggest_for("text Stop to cancel your order"),
              ElementsAre(IsSmartReply("Stop")));
}
232
// An action whose capturing group produces an annotation covering the
// matched phone number.
TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "phone"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  // phone ::= +00 00 000 00 00
  rules.AddValueMapping("<phone>",
                        {"+", "<2_digits>", "<2_digits>", "<3_digits>",
                         "<2_digits>", "<2_digits>"},
                        /*value=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));

  // Fatal assertion: the next check calls result.front(), which is undefined
  // on an empty vector, so the test must stop here if nothing was suggested.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}
272
// Rules in a locale-restricted shard only fire for messages whose detected
// language matches, while rules in the default shard fire for both.
TEST_F(GrammarActionsTest, HandlesLocales) {
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map =
      LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
  grammar::Rules rules(locale_shard_map);

  // Default shard. Rule: ^knock knock.?$ -> "Who's there?"
  const int english_reply_id =
      AddSmartReplySpec("Who's there?", &action_grammar_rules);
  const int english_match_id =
      AddRuleMatch({english_reply_id}, &action_grammar_rules);
  rules.Add(
      "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/english_match_id);

  // Shard 1. Rule: <knock> -> "Qui est là?"
  const int french_reply_id =
      AddSmartReplySpec("Qui est là?", &action_grammar_rules);
  const int french_match_id =
      AddRuleMatch({french_reply_id}, &action_grammar_rules);
  rules.Add(
      "<toc>", {"<knock>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/french_match_id,
      /*max_whitespace_gap=*/-1,
      /*case_sensitive=*/false,
      /*shard=*/1);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  // Restrict the last serialized rules shard to language "fr".
  auto& last_shard_locales = action_grammar_rules.rules->rules.back()->locale;
  last_shard_locales.emplace_back(new LanguageTagT);
  last_shard_locales.back()->language = "fr";

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Suggests actions for "knock knock" tagged with the given detected
  // language tags and returns the suggestions.
  const auto suggest_for_language =
      [&grammar_actions](const std::string& language_tags) {
        std::vector<ActionSuggestion> suggestions;
        EXPECT_TRUE(grammar_actions.SuggestActions(
            {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                            /*reference_time_ms_utc=*/0,
                            /*reference_timezone=*/"UTC", /*annotations=*/{},
                            /*detected_text_language_tags=*/language_tags}}},
            &suggestions));
        return suggestions;
      };

  // Check default: only the default shard applies to "en".
  EXPECT_THAT(suggest_for_language("en"),
              ElementsAre(IsSmartReply("Who's there?")));

  // Check fr: both the default shard and the "fr" shard apply to "fr-CH".
  EXPECT_THAT(suggest_for_language("fr-CH"),
              ElementsAre(IsSmartReply("Who's there?"),
                          IsSmartReply("Qui est là?")));
}
337
// A negative right-context assertion suppresses flight matches that are
// directly followed by an optional "." and more digits.
TEST_F(GrammarActionsTest, HandlesAssertions) {
  // Create test rules.
  // Rule: <flight> -> Track flight.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Capture flight code.
  rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
                        /*value=*/0);

  // Flight: carrier + flight code and check right context.
  rules.Add(
      "<track_flight>", {"<flight>", "<context_assertion>?"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "flight"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));

  // Exclude matches like: LX 38.00 etc.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
      &result));

  // Fatal assertion: result[0] and result[1] are indexed below, which is
  // only valid if exactly the two expected actions were produced. The third
  // candidate ("LX 38.38") must have been rejected by the assertion.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
                                  IsActionOfType("track_flight")));
  EXPECT_THAT(result[0].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
                                                       CodepointSpan{0, 4})));
  EXPECT_THAT(result[1].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
                                                       CodepointSpan{5, 10})));
}
392
// Entity data configured statically on the action spec (both the serialized
// form and the object form) ends up in the suggested action.
TEST_F(GrammarActionsTest, SetsFixedEntityData) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();
  action_grammar_rules.actions[spec_id]->action->entity_data.reset(
      new ActionsEntityDataT);
  action_grammar_rules.actions[spec_id]->action->entity_data->text =
      "I have the high ground.";

  rules.Add(
      "<greeting>", {"<^>", "hello", "there", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read
  // below and must exist.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. The buffer must be non-empty before it is interpreted
  // as a flatbuffer, and every string field must be present before it is
  // dereferenced; otherwise the test would crash instead of failing.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  const flatbuffers::String* text =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(text, nullptr);
  EXPECT_EQ(text->str(), "I have the high ground.");

  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(person, nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}
446
// Text matched by capturing groups is written into the entity data fields
// named by each group's `entity_field`, next to the static entity data.
TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();

  // Specify results for capturing matches.
  const int greeting_match_id = 0;
  const int location_match_id = 1;
  {
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = greeting_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "greeting";
  }
  {
    action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
        new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        action_grammar_rules.actions[spec_id]->capturing_group.back().get();
    group->group_id = location_match_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = "location";
  }

  rules.Add("<location>", {"there"});
  rules.Add("<location>", {"here"});
  rules.AddValueMapping("<captured_location>", {"<location>"},
                        /*value=*/location_match_id);
  rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
                        /*value=*/greeting_match_id);
  rules.Add(
      "<test>", {"<^>", "<greeting>", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read
  // below and must exist.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. The buffer must be non-empty before it is interpreted
  // as a flatbuffer, and every string field must be present before it is
  // dereferenced; otherwise the test would crash instead of failing.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Text captured by the greeting group (whole "<greeting>" match).
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(greeting, nullptr);
  EXPECT_EQ(greeting->str(), "Hello there");

  // Text captured by the nested location group.
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_NE(location, nullptr);
  EXPECT_EQ(location->str(), "there");

  // Static entity data set on the spec.
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(person, nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}
529
// Fixed entity data attached to a capturing group is emitted when that
// group matches.
TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
  // Create test rules.
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
      action_grammar_rules.actions[spec_id]->capturing_group.back().get();
  group->group_id = 0;
  group->entity_data.reset(new ActionsEntityDataT);
  group->entity_data->text = "You are a bold one.";

  rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<greeting>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. Fatal, because result[0] is read
  // below and must exist.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. The buffer must be non-empty before it is interpreted
  // as a flatbuffer, and the string field must be present before it is
  // dereferenced; otherwise the test would crash instead of failing.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  const flatbuffers::String* text =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(text, nullptr);
  EXPECT_EQ(text->str(), "You are a bold one.");
}
580
// A rule whose "<phone>" nonterminal is fed by an externally supplied
// "phone" annotation instead of by grammar rules over the text.
TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/
                                  {{0 /*value*/, "phone",
                                    /*use_annotation_match=*/true}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  rules.AddValueMapping("<phone>", {"<phone_annotation>"},
                        /*value=*/0);

  // The Ir is kept in a variable because the nonterminal id of
  // "<phone_annotation>" is looked up from it after serialization.
  grammar::Ir ir = rules.Finalize(
      /*predefined_nonterminals=*/{"<phone_annotation>"});
  ir.Serialize(/*include_debug_information=*/false,
               action_grammar_rules.rules.get());

  // Map "phone" annotation to "<phone_annotation>" nonterminal.
  action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
      new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
      ir.GetNonterminalForName("<phone_annotation>");

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Sanity check that no results are produced when no annotations are
  // provided. Each call gets its own output vector so that the second check
  // cannot be influenced by what the first call left behind.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
        &result));
    EXPECT_THAT(result, IsEmpty());
  }

  // With a "phone" annotation over the number, the action is produced.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0,
             /*text=*/"Please dial +41 79 123 45 67",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"UTC",
             /*annotations=*/
             {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
        &result));
    // Fatal assertion: result.front() below is undefined on an empty vector.
    ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
    EXPECT_THAT(result.front().annotations,
                ElementsAre(IsActionSuggestionAnnotation(
                    "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
  }
}
641
// A rule added with an exclusion matches any two tokens except the excluded
// phrase "be safe".
TEST_F(GrammarActionsTest, HandlesExclusions) {
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // <tokens_but_not_excluded> ::= <token> <token>, unless it is "be safe".
  rules.Add("<excluded>", {"be", "safe"});
  rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
                         /*excluded_nonterminal=*/"<excluded>");

  // Rule: do not forget to <tokens_but_not_excluded> -> set_reminder
  const int reminder_spec_id =
      AddActionSpec("set_reminder", /*response_text=*/"",
                    /*annotations=*/{}, &action_grammar_rules);
  const int reminder_match_id =
      AddRuleMatch({reminder_spec_id}, &action_grammar_rules);
  rules.Add(
      "<set_reminder>",
      {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/reminder_match_id);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  // Runs the engine on a single-message conversation and returns whatever
  // it suggested; a failing call is reported as a non-fatal failure.
  const auto suggest_for = [&grammar_actions](const std::string& text) {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/text}}}, &suggestions));
    return suggestions;
  };

  // Token pairs that are not the excluded phrase trigger the reminder, even
  // when they share one word with it.
  EXPECT_THAT(suggest_for("do not forget to bring milk"),
              ElementsAre(IsActionOfType("set_reminder")));
  EXPECT_THAT(suggest_for("do not forget to be there!"),
              ElementsAre(IsActionOfType("set_reminder")));
  EXPECT_THAT(suggest_for("do not forget to buy safe or vault!"),
              ElementsAre(IsActionOfType("set_reminder")));

  // The excluded phrase itself produces nothing.
  EXPECT_THAT(suggest_for("do not forget to be safe!"), IsEmpty());
}
706
707 } // namespace
708 } // namespace libtextclassifier3
709