1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/parsing/parser.h"
18
19 #include <string>
20 #include <vector>
21
22 #include "utils/grammar/parsing/derivation.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/testing/utils.h"
25 #include "utils/grammar/types.h"
26 #include "utils/grammar/utils/ir.h"
27 #include "utils/grammar/utils/rules.h"
28 #include "utils/i18n/locale.h"
29 #include "utils/tokenizer.h"
30 #include "utils/utf8/unicodetext.h"
31 #include "utils/utf8/unilib.h"
32 #include "gmock/gmock.h"
33 #include "gtest/gtest.h"
34
35 namespace libtextclassifier3::grammar {
36 namespace {
37
38 using ::testing::ElementsAre;
39 using ::testing::IsEmpty;
40
41 class ParserTest : public GrammarTest {};
42
// Verifies that a root rule assembled from digit nonterminals and "/"
// terminals produces exactly one derivation for the date embedded in the text.
TEST_F(ParserTest, ParsesSimpleRules) {
  constexpr int kDateRuleId = 0;
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);

  // Grammar: <date> ::= <year> "/" <month> "/" <day>.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<day>", {"<2_digits>"});
  grammar_rules.Add("<month>", {"<2_digits>"});
  grammar_rules.Add("<year>", {"<4_digits>"});
  grammar_rules.Add("<date>", {"<year>", "/", "<month>", "/", "<day>"},
                    root_callback, kDateRuleId);

  // Serialize the rules and point the parser at the flatbuffer root. The
  // buffer must stay alive for as long as the parser is used.
  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // The date occupies the span (7, 17) of the input.
  const TextContext input = TextContextForText("Event: 2020/05/08");
  const auto derivations =
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_));
  EXPECT_THAT(derivations, ElementsAre(IsDerivation(kDateRuleId, 7, 17)));
}
61
// Verifies that empty or blank text yields no derivations, after first
// confirming that the same grammar does match on regular input.
TEST_F(ParserTest, HandlesEmptyInput) {
  constexpr int kTestRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<test>", {"test"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kTestRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // Baseline: the rule fires on ordinary text.
  const TextContext regular_input = TextContextForText("Event: test");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(regular_input, &arena_)),
              ElementsAre(IsDerivation(kTestRuleId, 7, 11)));

  // Check that we bail out in case of empty input.
  for (const char* blank_text : {"", " "}) {
    const TextContext blank_input = TextContextForText(blank_text);
    EXPECT_THAT(ValidDeduplicatedDerivations(
                    grammar_parser.Parse(blank_input, &arena_)),
                IsEmpty());
  }
}
85
// Verifies the <uppercase_token> nonterminal: the rule matches when the token
// following "reply" is upper case and does not match when it is lower case.
TEST_F(ParserTest, HandlesUppercaseTokens) {
  constexpr int kScriptedReplyRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReplyRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "STOP" is upper case, so "Reply STOP" is derived.
  const TextContext uppercase_input =
      TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(uppercase_input, &arena_)),
              ElementsAre(IsDerivation(kScriptedReplyRuleId, 0, 10)));

  // "stop" is lower case, so nothing is derived.
  const TextContext lowercase_input =
      TextContextForText("Reply stop to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(lowercase_input, &arena_)),
              IsEmpty());
}
106
// Verifies the <^> and <$> anchor nonterminals: the rule matches only when it
// covers the whole input, not when it is surrounded by other text.
TEST_F(ParserTest, HandlesAnchors) {
  constexpr int kScriptedReplyRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReplyRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // The phrase is the entire input: expect a match.
  const TextContext anchored_input = TextContextForText("Reply STOP");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(anchored_input, &arena_)),
              ElementsAre(IsDerivation(kScriptedReplyRuleId, 0, 10)));

  // The phrase has text before and after it: expect no match.
  const TextContext embedded_input =
      TextContextForText("Please reply STOP to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(embedded_input, &arena_)),
              IsEmpty());
}
127
// Verifies the word-break nonterminal that closes the <flight> rule: a flight
// number followed by a sentence end matches, one continued by ".00" does not.
TEST_F(ParserTest, HandlesWordBreaks) {
  constexpr int kFlightRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<carrier>", {"lx"});
  grammar_rules.Add("<carrier>", {"aa"});
  grammar_rules.Add("<flight>", {"<carrier>", "<digits>", "<\b>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlightRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // Make sure the grammar recognizes "LX 38".
  const TextContext flight_input =
      TextContextForText("My flight is: LX 38. Arriving later");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(flight_input, &arena_)),
              ElementsAre(IsDerivation(kFlightRuleId, 14, 19)));

  // Make sure the grammar doesn't trigger on "LX 38.00".
  const TextContext decimal_input = TextContextForText("LX 38.00");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(decimal_input, &arena_)),
              IsEmpty());
}
152
// Verifies that a nonterminal bound to an annotation ("phone") only matches
// once a corresponding annotation is present in the text context.
TEST_F(ParserTest, HandlesAnnotations) {
  constexpr int kCallPhoneRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<flight>", {"dial", "<phone>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kCallPhoneRuleId);
  grammar_rules.BindAnnotation("<phone>", "phone");

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  TextContext input = TextContextForText("Please dial 911");

  // Without any annotation in the context the rule must not trigger.
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_)),
      IsEmpty());

  // Attach a phone annotation covering "911" and parse the same context again.
  AnnotatedSpan phone_annotation;
  phone_annotation.span = CodepointSpan{12, 15};
  phone_annotation.classification.emplace_back("phone", 1.0);
  input.annotations.push_back(phone_annotation);
  EXPECT_THAT(
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_)),
      ElementsAre(IsDerivation(kCallPhoneRuleId, 7, 15)));
}
180
// Verifies that a nonterminal defined through a regular expression can be used
// inside a rule: "STOP" satisfies the <code> pattern, "Stop" does not.
TEST_F(ParserTest, HandlesRegexAnnotators) {
  constexpr int kScriptedReplyRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.AddRegex(
      "<code>", "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
  grammar_rules.Add("<test>", {"please?", "reply", "<code>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kScriptedReplyRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // An all-caps code is accepted.
  const TextContext matching_input =
      TextContextForText("Reply STOP to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(matching_input, &arena_)),
              ElementsAre(IsDerivation(kScriptedReplyRuleId, 0, 10)));

  // A mixed-case, unquoted word is rejected.
  const TextContext rejected_input =
      TextContextForText("Reply Stop to cancel.");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(rejected_input, &arena_)),
              IsEmpty());
}
203
// Verifies rules added with an exclusion: two arbitrary tokens match unless
// they also form the excluded nonterminal ("be safe").
TEST_F(ParserTest, HandlesExclusions) {
  constexpr int kSetReminderRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<excluded>", {"be", "safe"});
  grammar_rules.AddWithExclusion("<tokens_but_not_excluded>",
                                 {"<token>", "<token>"},
                                 /*excluded_nonterminal=*/"<excluded>");
  grammar_rules.Add("<set_reminder>",
                    {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kSetReminderRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "be there" is not excluded, so the full sentence is derived.
  const TextContext allowed_input =
      TextContextForText("do not forget to be there");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(allowed_input, &arena_)),
              ElementsAre(IsDerivation(kSetReminderRuleId, 0, 25)));

  // "be safe" matches <excluded>, so nothing is derived.
  const TextContext excluded_input =
      TextContextForText("do not forget to be safe");
  EXPECT_THAT(ValidDeduplicatedDerivations(
                  grammar_parser.Parse(excluded_input, &arena_)),
              IsEmpty());
}
227
// Verifies the <filler> nonterminal: the trailing words after "do not forget
// to" are absorbed so that the derivation spans the whole sentence.
TEST_F(ParserTest, HandlesFillers) {
  constexpr int kSetReminderRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<set_reminder>",
                    {"do", "not", "forget", "to", "<filler>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kSetReminderRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  const TextContext input = TextContextForText("do not forget to be there");
  const auto derivations =
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_));
  EXPECT_THAT(derivations,
              ElementsAre(IsDerivation(kSetReminderRuleId, 0, 25)));
}
243
// Verifies negative assertions: a flight (carrier + flight code) is derived
// only when it is not followed by an optional "." and further digits.
TEST_F(ParserTest, HandlesAssertions) {
  constexpr int kFlightRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<carrier>", {"lx"});
  grammar_rules.Add("<carrier>", {"aa"});
  for (const char* digits_nonterminal :
       {"<2_digits>", "<3_digits>", "<4_digits>"}) {
    grammar_rules.Add("<flight_code>", {digits_nonterminal});
  }

  // Flight: carrier + flight code and check right context.
  grammar_rules.Add("<track_flight>",
                    {"<carrier>", "<flight_code>", "<context_assertion>?"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlightRuleId);
  // Exclude matches like: LX 38.00 etc.
  grammar_rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                             /*negative=*/true);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "LX38" and "aa 44" are derived; "LX 38.38" is not.
  const TextContext input = TextContextForText("LX38 aa 44 LX 38.38");
  const auto derivations =
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_));
  EXPECT_THAT(derivations, ElementsAre(IsDerivation(kFlightRuleId, 0, 4),
                                       IsDerivation(kFlightRuleId, 5, 10)));
}
270
// Verifies the max_whitespace_gap rule option: with a limit of 0 only the
// flight written without a space between carrier and code is derived.
TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
  constexpr int kFlightRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  grammar_rules.Add("<carrier>", {"lx"});
  grammar_rules.Add("<carrier>", {"aa"});
  for (const char* digits_nonterminal :
       {"<2_digits>", "<3_digits>", "<4_digits>"}) {
    grammar_rules.Add("<flight_code>", {digits_nonterminal});
  }

  // Flight: carrier + flight code, with no whitespace gap allowed.
  grammar_rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlightRuleId,
                    /*max_whitespace_gap=*/0);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // Only "LX38" qualifies; "aa 44" and "LX 38" contain a space.
  const TextContext input = TextContextForText("LX38 aa 44 LX 38");
  const auto derivations =
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_));
  EXPECT_THAT(derivations, ElementsAre(IsDerivation(kFlightRuleId, 0, 4)));
}
293
// Verifies the case_sensitive rule option: carriers registered as "Lx" and
// "AA" match only in exactly that casing, so "LX 38" is not derived.
TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
  constexpr int kFlightRuleId = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_rules(shard_map);
  for (const char* carrier_code : {"Lx", "AA"}) {
    grammar_rules.Add("<carrier>", {carrier_code}, /*callback=*/kNoCallback,
                      /*callback_param=*/0,
                      /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  }
  for (const char* digits_nonterminal :
       {"<2_digits>", "<3_digits>", "<4_digits>"}) {
    grammar_rules.Add("<flight_code>", {digits_nonterminal});
  }

  // Flight: case-sensitive carrier + flight code.
  grammar_rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
                    static_cast<CallbackId>(DefaultCallback::kRootRule),
                    kFlightRuleId);

  const std::string serialized_rules =
      grammar_rules.Finalize().SerializeAsFlatbuffer();
  const RulesSet* rules_set =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Parser grammar_parser(unilib_.get(), rules_set);

  // "Lx38" and "AA 44" are derived; "LX 38" has the wrong casing.
  const TextContext input = TextContextForText("Lx38 AA 44 LX 38");
  const auto derivations =
      ValidDeduplicatedDerivations(grammar_parser.Parse(input, &arena_));
  EXPECT_THAT(derivations, ElementsAre(IsDerivation(kFlightRuleId, 0, 4),
                                       IsDerivation(kFlightRuleId, 5, 10)));
}
318
319 } // namespace
320 } // namespace libtextclassifier3::grammar
321