1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/utils/rules.h"
18
19 #include <set>
20
21 #include "utils/grammar/utils/ir.h"
22 #include "utils/strings/append.h"
23 #include "utils/strings/stringpiece.h"
24
25 namespace libtextclassifier3::grammar {
26 namespace {
27
28 // Returns whether a nonterminal is a pre-defined one.
IsPredefinedNonterminal(const std::string & nonterminal_name)29 bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
30 if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
31 nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
32 nonterminal_name == kWordBreakNonterm) {
33 return true;
34 }
35 for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
36 if (nonterminal_name == strings::StringPrintf(kNDigitsNonterm, digits)) {
37 return true;
38 }
39 }
40 return false;
41 }
42
43 // Gets an assigned Nonterm for a nonterminal or kUnassignedNonterm if not yet
44 // assigned.
GetAssignedIdForNonterminal(const int nonterminal,const std::unordered_map<int,Nonterm> & assignment)45 Nonterm GetAssignedIdForNonterminal(
46 const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
47 const auto it = assignment.find(nonterminal);
48 if (it == assignment.end()) {
49 return kUnassignedNonterm;
50 }
51 return it->second;
52 }
53
54 // Checks whether all the nonterminals in the rhs of a rule have already been
55 // assigned Nonterm values.
IsRhsAssigned(const Rules::Rule & rule,const std::unordered_map<int,Nonterm> & nonterminals)56 bool IsRhsAssigned(const Rules::Rule& rule,
57 const std::unordered_map<int, Nonterm>& nonterminals) {
58 for (const Rules::RhsElement& element : rule.rhs) {
59 // Terminals are always considered assigned, check only for non-terminals.
60 if (element.is_terminal) {
61 continue;
62 }
63 if (GetAssignedIdForNonterminal(element.nonterminal, nonterminals) ==
64 kUnassignedNonterm) {
65 return false;
66 }
67 }
68
69 // Check that all parts of an exclusion are defined.
70 if (rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
71 if (GetAssignedIdForNonterminal(rule.callback_param, nonterminals) ==
72 kUnassignedNonterm) {
73 return false;
74 }
75 }
76
77 return true;
78 }
79
80 // Lowers a single high-level rule down into the intermediate representation.
LowerRule(const int lhs_index,const Rules::Rule & rule,std::unordered_map<int,Nonterm> * nonterminals,Ir * ir)81 void LowerRule(const int lhs_index, const Rules::Rule& rule,
82 std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
83 const CallbackId callback = rule.callback;
84 int64 callback_param = rule.callback_param;
85
86 // Resolve id of excluded nonterminal in exclusion rules.
87 if (callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
88 callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
89 TC3_CHECK_NE(callback_param, kUnassignedNonterm);
90 }
91
92 // Special case for terminal rules.
93 if (rule.rhs.size() == 1 && rule.rhs.front().is_terminal) {
94 (*nonterminals)[lhs_index] =
95 ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
96 /*callback=*/{callback, callback_param},
97 /*preconditions=*/{rule.max_whitespace_gap}},
98 rule.rhs.front().terminal, rule.case_sensitive, rule.shard);
99 return;
100 }
101
102 // Nonterminal rules.
103 std::vector<Nonterm> rhs_nonterms;
104 for (const Rules::RhsElement& element : rule.rhs) {
105 if (element.is_terminal) {
106 rhs_nonterms.push_back(ir->Add(Ir::Lhs{kUnassignedNonterm},
107 element.terminal, rule.case_sensitive,
108 rule.shard));
109 } else {
110 Nonterm nonterminal_id =
111 GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
112 TC3_CHECK_NE(nonterminal_id, kUnassignedNonterm);
113 rhs_nonterms.push_back(nonterminal_id);
114 }
115 }
116 (*nonterminals)[lhs_index] =
117 ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
118 /*callback=*/{callback, callback_param},
119 /*preconditions=*/{rule.max_whitespace_gap}},
120 rhs_nonterms, rule.shard);
121 }
122 // Check whether this component is a non-terminal.
IsNonterminal(StringPiece rhs_component)123 bool IsNonterminal(StringPiece rhs_component) {
124 return rhs_component[0] == '<' &&
125 rhs_component[rhs_component.size() - 1] == '>';
126 }
127
128 // Sanity check for common typos -- '<' or '>' in a terminal.
ValidateTerminal(StringPiece rhs_component)129 void ValidateTerminal(StringPiece rhs_component) {
130 TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
131 << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
132 TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
133 << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
134 TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
135 << "Rhs terminal `" << rhs_component << "` contains a question mark.";
136 }
137
138 } // namespace
139
AddNonterminal(const std::string & nonterminal_name)140 int Rules::AddNonterminal(const std::string& nonterminal_name) {
141 std::string key = nonterminal_name;
142 auto alias_it = nonterminal_alias_.find(key);
143 if (alias_it != nonterminal_alias_.end()) {
144 key = alias_it->second;
145 }
146 auto it = nonterminal_names_.find(key);
147 if (it != nonterminal_names_.end()) {
148 return it->second;
149 }
150 const int index = nonterminals_.size();
151 nonterminals_.push_back(NontermInfo{key});
152 nonterminal_names_.insert(it, {key, index});
153 return index;
154 }
155
AddNewNonterminal()156 int Rules::AddNewNonterminal() {
157 const int index = nonterminals_.size();
158 nonterminals_.push_back(NontermInfo{});
159 return index;
160 }
161
AddAlias(const std::string & nonterminal_name,const std::string & alias)162 void Rules::AddAlias(const std::string& nonterminal_name,
163 const std::string& alias) {
164 #ifndef TC3_USE_CXX14
165 TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
166 .first->second,
167 nonterminal_name)
168 << "Cannot redefine alias: " << alias;
169 #else
170 nonterminal_alias_[alias] = nonterminal_name;
171 TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
172 << "Cannot redefine alias: " << alias;
173 #endif
174 }
175
176 // Defines a nonterminal for an externally provided annotation.
AddAnnotation(const std::string & annotation_name)177 int Rules::AddAnnotation(const std::string& annotation_name) {
178 auto [it, inserted] =
179 annotation_nonterminals_.insert({annotation_name, nonterminals_.size()});
180 if (inserted) {
181 nonterminals_.push_back(NontermInfo{});
182 }
183 return it->second;
184 }
185
BindAnnotation(const std::string & nonterminal_name,const std::string & annotation_name)186 void Rules::BindAnnotation(const std::string& nonterminal_name,
187 const std::string& annotation_name) {
188 auto [_, inserted] = annotation_nonterminals_.insert(
189 {annotation_name, AddNonterminal(nonterminal_name)});
190 TC3_CHECK(inserted);
191 }
192
IsNonterminalOfName(const RhsElement & element,const std::string & nonterminal) const193 bool Rules::IsNonterminalOfName(const RhsElement& element,
194 const std::string& nonterminal) const {
195 if (element.is_terminal) {
196 return false;
197 }
198 return (nonterminals_[element.nonterminal].name == nonterminal);
199 }
200
201 // Note: For k optional components this creates 2^k rules, but it would be
202 // possible to be smarter about this and only use 2k rules instead.
203 // However that might be slower as it requires an extra rule firing at match
204 // time for every omitted optional element.
ExpandOptionals(const int lhs,const std::vector<RhsElement> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard,std::vector<int>::const_iterator optional_element_indices,std::vector<int>::const_iterator optional_element_indices_end,std::vector<bool> * omit_these)205 void Rules::ExpandOptionals(
206 const int lhs, const std::vector<RhsElement>& rhs,
207 const CallbackId callback, const int64 callback_param,
208 const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
209 std::vector<int>::const_iterator optional_element_indices,
210 std::vector<int>::const_iterator optional_element_indices_end,
211 std::vector<bool>* omit_these) {
212 if (optional_element_indices == optional_element_indices_end) {
213 // Nothing is optional, so just generate a rule.
214 Rule r;
215 for (uint32 i = 0; i < rhs.size(); i++) {
216 if (!omit_these->at(i)) {
217 r.rhs.push_back(rhs[i]);
218 }
219 }
220 r.callback = callback;
221 r.callback_param = callback_param;
222 r.max_whitespace_gap = max_whitespace_gap;
223 r.case_sensitive = case_sensitive;
224 r.shard = shard;
225 nonterminals_[lhs].rules.push_back(rules_.size());
226 rules_.push_back(r);
227 return;
228 }
229
230 const int next_optional_part = *optional_element_indices;
231 ++optional_element_indices;
232
233 // Recursive call 1: The optional part is omitted.
234 (*omit_these)[next_optional_part] = true;
235 ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
236 case_sensitive, shard, optional_element_indices,
237 optional_element_indices_end, omit_these);
238
239 // Recursive call 2: The optional part is required.
240 (*omit_these)[next_optional_part] = false;
241 ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
242 case_sensitive, shard, optional_element_indices,
243 optional_element_indices_end, omit_these);
244 }
245
ResolveAnchors(const std::vector<RhsElement> & rhs) const246 std::vector<Rules::RhsElement> Rules::ResolveAnchors(
247 const std::vector<RhsElement>& rhs) const {
248 if (rhs.size() <= 2) {
249 return rhs;
250 }
251 auto begin = rhs.begin();
252 auto end = rhs.end();
253 if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
254 IsNonterminalOfName(rhs[1], kFiller)) {
255 // Skip start anchor and filler.
256 begin += 2;
257 }
258 if (IsNonterminalOfName(rhs.back(), kEndNonterm) &&
259 IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
260 // Skip filler and end anchor.
261 end -= 2;
262 }
263 return std::vector<Rules::RhsElement>(begin, end);
264 }
265
ResolveFillers(const std::vector<RhsElement> & rhs,int shard)266 std::vector<Rules::RhsElement> Rules::ResolveFillers(
267 const std::vector<RhsElement>& rhs, int shard) {
268 std::vector<RhsElement> result;
269 for (int i = 0; i < rhs.size();) {
270 if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
271 rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
272 result.push_back(rhs[i]);
273 i++;
274 continue;
275 }
276
277 // We have the case:
278 // <a> <filler>
279 // rewrite as:
280 // <a_with_tokens> ::= <a>
281 // <a_with_tokens> ::= <a_with_tokens> <token>
282 const int with_tokens_nonterminal = AddNewNonterminal();
283 const RhsElement token(AddNonterminal(kTokenNonterm),
284 /*is_optional=*/false);
285 if (rhs[i + 1].is_optional) {
286 // <a_with_tokens> ::= <a>
287 Add(with_tokens_nonterminal, {rhs[i]},
288 /*callback=*/kNoCallback,
289 /*callback_param=*/0,
290 /*max_whitespace_gap=*/-1,
291 /*case_sensitive=*/false, shard);
292 } else {
293 // <a_with_tokens> ::= <a> <token>
294 Add(with_tokens_nonterminal, {rhs[i], token},
295 /*callback=*/kNoCallback,
296 /*callback_param=*/0,
297 /*max_whitespace_gap=*/-1,
298 /*case_sensitive=*/false, shard);
299 }
300 // <a_with_tokens> ::= <a_with_tokens> <token>
301 const RhsElement with_tokens(with_tokens_nonterminal,
302 /*is_optional=*/false);
303 Add(with_tokens_nonterminal, {with_tokens, token},
304 /*callback=*/kNoCallback,
305 /*callback_param=*/0,
306 /*max_whitespace_gap=*/-1,
307 /*case_sensitive=*/false, shard);
308 result.push_back(with_tokens);
309 i += 2;
310 }
311 return result;
312 }
313
OptimizeRhs(const std::vector<RhsElement> & rhs,int shard)314 std::vector<Rules::RhsElement> Rules::OptimizeRhs(
315 const std::vector<RhsElement>& rhs, int shard) {
316 return ResolveFillers(ResolveAnchors(rhs), shard);
317 }
318
Add(const int lhs,const std::vector<RhsElement> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)319 void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
320 const CallbackId callback, const int64 callback_param,
321 const int8 max_whitespace_gap, const bool case_sensitive,
322 const int shard) {
323 // Resolve anchors and fillers.
324 const std::vector optimized_rhs = OptimizeRhs(rhs);
325
326 std::vector<int> optional_element_indices;
327 TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
328 << "Rhs must contain at least one non-optional element.";
329 for (int i = 0; i < optimized_rhs.size(); i++) {
330 if (optimized_rhs[i].is_optional) {
331 optional_element_indices.push_back(i);
332 }
333 }
334 std::vector<bool> omit_these(optimized_rhs.size(), false);
335 ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
336 max_whitespace_gap, case_sensitive, shard,
337 optional_element_indices.begin(),
338 optional_element_indices.end(), &omit_these);
339 }
340
Add(const std::string & lhs,const std::vector<std::string> & rhs,const CallbackId callback,const int64 callback_param,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)341 void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
342 const CallbackId callback, const int64 callback_param,
343 const int8 max_whitespace_gap, const bool case_sensitive,
344 const int shard) {
345 TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
346 TC3_CHECK(!IsPredefinedNonterminal(lhs));
347 std::vector<RhsElement> rhs_elements;
348 rhs_elements.reserve(rhs.size());
349 for (StringPiece rhs_component : rhs) {
350 // Check whether this component is optional.
351 bool is_optional = false;
352 if (rhs_component[rhs_component.size() - 1] == '?') {
353 rhs_component.RemoveSuffix(1);
354 is_optional = true;
355 }
356 // Check whether this component is a non-terminal.
357 if (IsNonterminal(rhs_component)) {
358 rhs_elements.push_back(
359 RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
360 } else {
361 // A terminal.
362 // Sanity check for common typos -- '<' or '>' in a terminal.
363 ValidateTerminal(rhs_component);
364 rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
365 }
366 }
367 Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
368 max_whitespace_gap, case_sensitive, shard);
369 }
370
AddWithExclusion(const std::string & lhs,const std::vector<std::string> & rhs,const std::string & excluded_nonterminal,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)371 void Rules::AddWithExclusion(const std::string& lhs,
372 const std::vector<std::string>& rhs,
373 const std::string& excluded_nonterminal,
374 const int8 max_whitespace_gap,
375 const bool case_sensitive, const int shard) {
376 Add(lhs, rhs,
377 /*callback=*/static_cast<CallbackId>(DefaultCallback::kExclusion),
378 /*callback_param=*/AddNonterminal(excluded_nonterminal),
379 max_whitespace_gap, case_sensitive, shard);
380 }
381
AddAssertion(const std::string & lhs,const std::vector<std::string> & rhs,const bool negative,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)382 void Rules::AddAssertion(const std::string& lhs,
383 const std::vector<std::string>& rhs,
384 const bool negative, const int8 max_whitespace_gap,
385 const bool case_sensitive, const int shard) {
386 Add(lhs, rhs,
387 /*callback=*/static_cast<CallbackId>(DefaultCallback::kAssertion),
388 /*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
389 }
390
AddValueMapping(const std::string & lhs,const std::vector<std::string> & rhs,const int64 value,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)391 void Rules::AddValueMapping(const std::string& lhs,
392 const std::vector<std::string>& rhs,
393 const int64 value, const int8 max_whitespace_gap,
394 const bool case_sensitive, const int shard) {
395 Add(lhs, rhs,
396 /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
397 /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
398 }
399
AddValueMapping(const int lhs,const std::vector<RhsElement> & rhs,int64 value,const int8 max_whitespace_gap,const bool case_sensitive,const int shard)400 void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
401 int64 value, const int8 max_whitespace_gap,
402 const bool case_sensitive, const int shard) {
403 Add(lhs, rhs,
404 /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
405 /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
406 }
407
AddRegex(const std::string & lhs,const std::string & regex_pattern)408 void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
409 AddRegex(AddNonterminal(lhs), regex_pattern);
410 }
411
AddRegex(int lhs,const std::string & regex_pattern)412 void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
413 nonterminals_[lhs].regex_rules.push_back(regex_rules_.size());
414 regex_rules_.push_back(regex_pattern);
415 }
416
UsesFillers() const417 bool Rules::UsesFillers() const {
418 for (const Rule& rule : rules_) {
419 for (const RhsElement& rhs_element : rule.rhs) {
420 if (IsNonterminalOfName(rhs_element, kFiller)) {
421 return true;
422 }
423 }
424 }
425 return false;
426 }
427
Finalize(const std::set<std::string> & predefined_nonterminals) const428 Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
429 Ir rules(locale_shard_map_);
430 std::unordered_map<int, Nonterm> nonterminal_ids;
431
432 // Pending rules to process.
433 std::set<std::pair<int, int>> scheduled_rules;
434
435 // Define all used predefined nonterminals.
436 for (const auto& it : nonterminal_names_) {
437 if (IsPredefinedNonterminal(it.first) ||
438 predefined_nonterminals.find(it.first) !=
439 predefined_nonterminals.end()) {
440 nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
441 }
442 }
443
444 // Assign (unmergeable) Nonterm values to any nonterminals that have
445 // multiple rules.
446 for (int i = 0; i < nonterminals_.size(); i++) {
447 const NontermInfo& nonterminal = nonterminals_[i];
448
449 // Skip predefined nonterminals, they have already been assigned.
450 if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
451 continue;
452 }
453
454 bool unmergeable =
455 (nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
456 !nonterminal.regex_rules.empty());
457 for (const int rule_index : nonterminal.rules) {
458 // Schedule rule.
459 scheduled_rules.insert({i, rule_index});
460 }
461
462 if (unmergeable) {
463 // Define unique nonterminal id.
464 nonterminal_ids[i] = rules.AddUnshareableNonterminal(nonterminal.name);
465 } else {
466 nonterminal_ids[i] = rules.AddNonterminal(nonterminal.name);
467 }
468
469 // Define regex rules.
470 for (const int regex_rule : nonterminal.regex_rules) {
471 rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
472 }
473 }
474
475 // Define annotations.
476 for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
477 rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
478 }
479
480 // Check whether fillers are still referenced (if they couldn't get optimized
481 // away).
482 if (UsesFillers()) {
483 TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
484 "matching performance might be impacted.";
485
486 // Add a definition for the filler:
487 // <filler> = <token>
488 // <filler> = <token> <filler>
489 const Nonterm filler = rules.GetNonterminalForName(kFiller);
490 const Nonterm token =
491 rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
492 rules.Add(filler, token);
493 rules.Add(filler, std::vector<Nonterm>{token, filler});
494 }
495
496 // Now, keep adding eligible rules (rules whose rhs is completely assigned)
497 // until we can't make any more progress.
498 // Note: The following code is quadratic in the worst case.
499 // This seems fine as this will only run as part of the compilation of the
500 // grammar rules during model assembly.
501 bool changed = true;
502 while (changed) {
503 changed = false;
504 for (auto nt_and_rule = scheduled_rules.begin();
505 nt_and_rule != scheduled_rules.end();) {
506 const Rule& rule = rules_[nt_and_rule->second];
507 if (IsRhsAssigned(rule, nonterminal_ids)) {
508 // Compile the rule.
509 LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
510 &rules);
511 scheduled_rules.erase(
512 nt_and_rule++); // Iterator is advanced before erase.
513 changed = true;
514 break;
515 } else {
516 nt_and_rule++;
517 }
518 }
519 }
520 TC3_CHECK(scheduled_rules.empty());
521 return rules;
522 }
523
524 } // namespace libtextclassifier3::grammar
525