• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/parsing/parser.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "utils/grammar/parsing/derivation.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/testing/utils.h"
25 #include "utils/grammar/types.h"
26 #include "utils/grammar/utils/ir.h"
27 #include "utils/grammar/utils/rules.h"
28 #include "utils/i18n/locale.h"
29 #include "utils/tokenizer.h"
30 #include "utils/utf8/unicodetext.h"
31 #include "utils/utf8/unilib.h"
32 #include "gmock/gmock.h"
33 #include "gtest/gtest.h"
34 
35 namespace libtextclassifier3::grammar {
36 namespace {
37 
38 using ::testing::ElementsAre;
39 using ::testing::IsEmpty;
40 
41 class ParserTest : public GrammarTest {};
42 
TEST_F(ParserTest, ParsesSimpleRules) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);

  // <date> ::= <year> "/" <month> "/" <day>, where the year has four digits
  // and the month and day have two digits each.
  test_rules.Add("<day>", {"<2_digits>"});
  test_rules.Add("<month>", {"<2_digits>"});
  test_rules.Add("<year>", {"<4_digits>"});
  constexpr int kDate = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  test_rules.Add("<date>", {"<year>", "/", "<month>", "/", "<day>"},
                 root_callback, kDate);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // The date "2020/05/08" spans codepoints [7, 17) of the input.
  const TextContext event_input = TextContextForText("Event: 2020/05/08");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(event_input, &arena_)),
      ElementsAre(IsDerivation(kDate, 7, 17)));
}
61 
TEST_F(ParserTest, HandlesEmptyInput) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  constexpr int kTest = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  test_rules.Add("<test>", {"test"}, root_callback, kTest);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // Sanity check: with real input the keyword matches at [7, 11).
  const TextContext keyword_input = TextContextForText("Event: test");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(keyword_input, &arena_)),
      ElementsAre(IsDerivation(kTest, 7, 11)));

  // Empty and whitespace-only inputs must bail out without derivations.
  for (const char* blank_text : {"", "    "}) {
    const TextContext blank_input = TextContextForText(blank_text);
    EXPECT_THAT(
        ValidDeduplicatedDerivations(grammar_parser.Parse(blank_input, &arena_)),
        IsEmpty());
  }
}
85 
TEST_F(ParserTest, HandlesUppercaseTokens) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  constexpr int kScriptedReply = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  // An optional "please", then "reply", then a token written in upper case.
  test_rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
                 root_callback, kScriptedReply);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "STOP" is upper case, so "Reply STOP" matches at [0, 10).
  const TextContext uppercase_input =
      TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(uppercase_input, &arena_)),
              ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // "stop" is lower case, so nothing matches.
  const TextContext lowercase_input =
      TextContextForText("Reply stop to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(lowercase_input, &arena_)),
              IsEmpty());
}
106 
TEST_F(ParserTest, HandlesAnchors) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  constexpr int kScriptedReply = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  // The whole input must be "reply <UPPERCASE>": <^> and <$> anchor the rule
  // to the beginning and the end of the text.
  test_rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
                 root_callback, kScriptedReply);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  const TextContext anchored_input = TextContextForText("Reply STOP");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(anchored_input, &arena_)),
              ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // Surrounding text breaks both anchors, so nothing matches.
  const TextContext embedded_input =
      TextContextForText("Please reply STOP to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(embedded_input, &arena_)),
              IsEmpty());
}
127 
TEST_F(ParserTest, HandlesWordBreaks) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  test_rules.Add("<carrier>", {"lx"});
  test_rules.Add("<carrier>", {"aa"});
  constexpr int kFlight = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  // A carrier and a number that must be followed by a word break.
  test_rules.Add("<flight>", {"<carrier>", "<digits>", "<\b>"}, root_callback,
                 kFlight);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "LX 38" at [14, 19) is recognized.
  const TextContext sentence_input =
      TextContextForText("My flight is: LX 38. Arriving later");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(sentence_input, &arena_)),
              ElementsAre(IsDerivation(kFlight, 14, 19)));

  // "LX 38.00" is not a flight: the number continues past "38".
  const TextContext decimal_input = TextContextForText("LX 38.00");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(decimal_input, &arena_)),
      IsEmpty());
}
152 
TEST_F(ParserTest, HandlesAnnotations) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  constexpr int kCallPhone = 0;
  // "dial" followed by a span that was externally annotated as a phone.
  rules.Add("<call_phone>", {"dial", "<phone>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone);
  // <phone> is satisfied by annotations of type "phone" in the input context.
  rules.BindAnnotation("<phone>", "phone");
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  TextContext context = TextContextForText("Please dial 911");

  // Sanity check that we don't trigger if we don't feed the correct
  // annotations.
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
              IsEmpty());

  // Create a phone annotation covering "911" (codepoints [12, 15)); the rule
  // then matches "dial 911" at [7, 15).
  AnnotatedSpan phone_span;
  phone_span.span = CodepointSpan{12, 15};
  phone_span.classification.emplace_back("phone", 1.0);
  context.annotations.push_back(phone_span);
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
              ElementsAre(IsDerivation(kCallPhone, 7, 15)));
}
180 
TEST_F(ParserTest, HandlesRegexAnnotators) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  // <code> is a quoted word, or an (optionally quoted) all-caps token with
  // optional trailing digits, or a single digit.
  test_rules.AddRegex("<code>",
                      "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
  constexpr int kScriptedReply = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  test_rules.Add("<test>", {"please?", "reply", "<code>"}, root_callback,
                 kScriptedReply);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "STOP" satisfies the regex.
  const TextContext caps_input = TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(caps_input, &arena_)),
      ElementsAre(IsDerivation(kScriptedReply, 0, 10)));

  // Unquoted mixed-case "Stop" does not.
  const TextContext mixed_case_input =
      TextContextForText("Reply Stop to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(mixed_case_input, &arena_)),
              IsEmpty());
}
203 
TEST_F(ParserTest, HandlesExclusions) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  // Any two tokens, except the phrase "be safe".
  test_rules.Add("<excluded>", {"be", "safe"});
  test_rules.AddWithExclusion("<tokens_but_not_excluded>",
                              {"<token>", "<token>"},
                              /*excluded_nonterminal=*/"<excluded>");
  constexpr int kSetReminder = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  test_rules.Add("<set_reminder>",
                 {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
                 root_callback, kSetReminder);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  const TextContext allowed_input =
      TextContextForText("do not forget to be there");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(allowed_input, &arena_)),
      ElementsAre(IsDerivation(kSetReminder, 0, 25)));

  // The trailing tokens form the excluded phrase, so nothing matches.
  const TextContext excluded_input =
      TextContextForText("do not forget to be safe");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(excluded_input, &arena_)),
              IsEmpty());
}
227 
TEST_F(ParserTest, HandlesFillers) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  constexpr int kSetReminder = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  // The trailing <filler> absorbs the remaining words ("be there").
  test_rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"},
                 root_callback, kSetReminder);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  const TextContext reminder_input =
      TextContextForText("do not forget to be there");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(reminder_input, &arena_)),
              ElementsAre(IsDerivation(kSetReminder, 0, 25)));
}
243 
TEST_F(ParserTest, HandlesAssertions) {
  LocaleShardMap shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  Rules test_rules(shard_map);
  test_rules.Add("<carrier>", {"lx"});
  test_rules.Add("<carrier>", {"aa"});
  test_rules.Add("<flight_code>", {"<2_digits>"});
  test_rules.Add("<flight_code>", {"<3_digits>"});
  test_rules.Add("<flight_code>", {"<4_digits>"});

  // A flight is a carrier plus a flight code, subject to a check on the
  // context to the right of the code.
  constexpr int kFlight = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  test_rules.Add("<track_flight>",
                 {"<carrier>", "<flight_code>", "<context_assertion>?"},
                 root_callback, kFlight);
  // Negative assertion: the code must not be followed by an optional single
  // character and more digits (e.g. "LX 38.00").
  test_rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                          /*negative=*/true);

  const std::string serialized_rules =
      test_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "LX38" and "aa 44" match; "LX 38.38" is rejected by the assertion.
  const TextContext flights_input =
      TextContextForText("LX38 aa 44 LX 38.38");
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(flights_input, &arena_)),
      ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
}
270 
TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});
  // Flight: carrier immediately followed by the flight code. With
  // max_whitespace_gap=0 no whitespace is allowed between the two parts.
  constexpr int kFlight = 0;
  rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight,
            /*max_whitespace_gap=*/0);
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  // Only "LX38" matches; "aa 44" and "LX 38" have a space inside the match.
  EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
                  TextContextForText("LX38 aa 44 LX 38"), &arena_)),
              ElementsAre(IsDerivation(kFlight, 0, 4)));
}
293 
TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  // Carriers are matched case-sensitively: only "Lx" and "AA" as written.
  rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0,
            /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0,
            /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});
  // Flight: carrier followed by a flight code.
  constexpr int kFlight = 0;
  rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
  const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));

  // "Lx38" and "AA 44" match; "LX 38" does not because "LX" differs in case
  // from the case-sensitive carrier "Lx".
  EXPECT_THAT(
      ValidDeduplicatedDerivations(
          parser.Parse(TextContextForText("Lx38 AA 44 LX 38"), &arena_)),
      ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
}
318 
319 }  // namespace
320 }  // namespace libtextclassifier3::grammar
321