/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/grammar/utils/ir.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <tuple>

#include "utils/i18n/locale.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
#include "utils/zlib/zlib.h"

namespace libtextclassifier3::grammar {
namespace {

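// Upper bound on the number of buckets used for the serialized binary rule
// hash tables.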
constexpr size_t kMaxHashTableSize = 100;

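// Sorts a table of owned entries by their `key` field so that the runtime can
// look them up via binary search.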
template <typename T>
void SortForBinarySearchLookup(T* entries) {
  std::sort(entries->begin(), entries->end(),
            [](const auto& a, const auto& b) { return a->key < b->key; });
}

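// Same as above, but for tables whose entries are stored by value and expose
// their key via a `key()` accessor.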
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
  std::sort(entries->begin(), entries->end(),
            [](const auto& a, const auto& b) { return a.key() < b.key(); });
}

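// Checks whether an in-memory lhs matches an already serialized lhs field by
// field.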
bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
  return (lhs.nonterminal == other.nonterminal() &&
          lhs.callback.id == other.callback_id() &&
          lhs.callback.param == other.callback_param() &&
          lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap());
}

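// Checks whether `lhs` matches the serialized lhs set entry `lhs_entry`:
// positive entries directly encode a nonterminal without callback, other
// entries index into the `candidates` lhs table.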
bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
                    const std::vector<RulesSet_::Lhs>& candidates) {
  // Simple case: direct encoding of the nonterminal.
  if (lhs_entry > 0) {
    return (lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
            lhs.preconditions.max_whitespace_gap == -1);
  }

  // Entry is an index into the callback lookup table.
  return IsSameLhs(lhs, candidates[-lhs_entry]);
}

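// Checks whether an lhs set matches an already serialized lhs set entry by
// entry.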
bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
                  const RulesSet_::LhsSetT& candidate,
                  const std::vector<RulesSet_::Lhs>& candidates) {
  if (lhs_set.size() != candidate.lhs.size()) {
    return false;
  }

  for (int i = 0; i < lhs_set.size(); i++) {
    // Check that entries are the same.
    if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
      return false;
    }
  }

  return true;
}

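// Returns a copy of `lhs_set` with its entries in a canonical order, so that
// sets that only differ in the order of their entries can be deduplicated.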
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
  Ir::LhsSet sorted_lhs = lhs_set;
  std::sort(sorted_lhs.begin(), sorted_lhs.end(),
            [](const Ir::Lhs& a, const Ir::Lhs& b) {
              return std::tie(a.nonterminal, a.callback.id, a.callback.param,
                              a.preconditions.max_whitespace_gap) <
                     std::tie(b.nonterminal, b.callback.id, b.callback.param,
                              b.preconditions.max_whitespace_gap);
            });
  return sorted_lhs;
}

// Adds a new lhs match set to the output.
// Reuses the same set if it was previously observed.
int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
  Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
  // Check whether we can reuse an entry.
  const int output_size = rules_set->lhs_set.size();
  for (int i = 0; i < output_size; i++) {
    if (IsSameLhsSet(sorted_lhs, *rules_set->lhs_set[i], rules_set->lhs)) {
      return i;
    }
  }

  // Add new entry.
  rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
  RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
  for (const Ir::Lhs& lhs : sorted_lhs) {
    // Simple case: no callback and no special requirements, so we directly
    // encode the nonterminal.
    if (lhs.callback.id == kNoCallback &&
        lhs.preconditions.max_whitespace_gap < 0) {
      serialized_lhs_set->lhs.push_back(lhs.nonterminal);
    } else {
      // Check whether we can reuse a callback entry.
      const int lhs_size = rules_set->lhs.size();
      bool found_entry = false;
      for (int i = 0; i < lhs_size; i++) {
        if (IsSameLhs(lhs, rules_set->lhs[i])) {
          found_entry = true;
          serialized_lhs_set->lhs.push_back(-i);
          break;
        }
      }

      // We were able to reuse an existing entry.
      if (found_entry) {
        continue;
      }

      // Add a new one.
      rules_set->lhs.push_back(
          RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
                         lhs.preconditions.max_whitespace_gap));
      serialized_lhs_set->lhs.push_back(-lhs_size);
    }
  }
  return output_size;
}

// Serializes a unary rules table.
void SerializeUnaryRulesShard(
    const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
    RulesSetT* rules_set, RulesSet_::RulesT* rules) {
  for (const auto& it : unary_rules) {
    rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
        it.first, AddLhsSet(it.second, rules_set)));
  }
  SortStructsForBinarySearchLookup(&rules->unary_rules);
}

// Serializes a binary rules table.
void SerializeBinaryRulesShard(
    const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
        binary_rules,
    RulesSetT* rules_set, RulesSet_::RulesT* rules) {
  const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
  for (int i = 0; i < num_buckets; i++) {
    rules->binary_rules.emplace_back(
        new RulesSet_::Rules_::BinaryRuleTableBucketT());
  }

  // Serialize the table.
  BinaryRuleHasher hash;
  for (const auto& it : binary_rules) {
    const TwoNonterms key = it.first;
    uint32 bucket_index = hash(key) % num_buckets;

    // Add entry to bucket chain list.
    rules->binary_rules[bucket_index]->rules.push_back(
        RulesSet_::Rules_::BinaryRule(key.first, key.second,
                                      AddLhsSet(it.second, rules_set)));
  }
}

}  // namespace

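// Adds `lhs` to `lhs_set` and returns the nonterminal to use for the rule.
// An existing entry is reused if nonterminal, callback and preconditions
// allow sharing; otherwise a new nonterminal is defined.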
Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
  const int lhs_set_size = lhs_set->size();
  Nonterm shareable_nonterm = lhs.nonterminal;
  for (int i = 0; i < lhs_set_size; i++) {
    Lhs* candidate = &lhs_set->at(i);

    // Exact match, just reuse the rule.
    if (lhs == *candidate) {
      return candidate->nonterminal;
    }

    // Cannot reuse unshareable ids.
    if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
        nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
      continue;
    }

    // Cannot reuse the id if the preconditions are different.
    if (!(lhs.preconditions == candidate->preconditions)) {
      continue;
    }

    // If the nonterminal is already defined, it must match for sharing.
    if (lhs.nonterminal != kUnassignedNonterm &&
        lhs.nonterminal != candidate->nonterminal) {
      continue;
    }

    // Check whether the callbacks match.
    if (lhs.callback == candidate->callback) {
      return candidate->nonterminal;
    }

    // We can reuse the candidate if one of the output callbacks is not used.
    if (lhs.callback.id == kNoCallback) {
      return candidate->nonterminal;
    } else if (candidate->callback.id == kNoCallback) {
      // The existing entry has no output callback yet; attach the new
      // callback to it.
      candidate->callback = lhs.callback;
      return candidate->nonterminal;
    }

    // We could share the nonterminal, but would need to add an entry with the
    // new output callback. Defer this, as we might still find a fully
    // shareable entry.
    shareable_nonterm = candidate->nonterminal;
  }

  // No fully matching entry was found, so create a new one, sharing the
  // nonterminal if one was found above.
  shareable_nonterm = DefineNonterminal(shareable_nonterm);
  lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
  return shareable_nonterm;
}

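// Adds a terminal rule `lhs ::= terminal` to the given shard, keeping
// case-sensitive and case-insensitive terminals in separate tables.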
Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
                const bool case_sensitive, const int shard) {
  TC3_CHECK_LT(shard, shards_.size());
  if (case_sensitive) {
    return AddRule(lhs, terminal, &shards_[shard].terminal_rules);
  } else {
    return AddRule(lhs, terminal, &shards_[shard].lowercase_terminal_rules);
  }
}

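// Adds a rule `lhs ::= rhs_1 rhs_2 ... rhs_n`. As the rules set only supports
// unary and binary rules, longer right-hand sides are lowered into a
// left-branching chain of binary rules over intermediate nonterminals.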
Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
                const int shard) {
  // Add a new unary rule.
  if (rhs.size() == 1) {
    return Add(lhs, rhs.front(), shard);
  }

  // Add a chain of (rhs.size() - 1) binary rules.
  Nonterm prev = rhs.front();
  for (int i = 1; i < rhs.size() - 1; i++) {
    prev = Add(kUnassignedNonterm, prev, rhs[i], shard);
  }
  return Add(lhs, prev, rhs.back(), shard);
}

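// Registers `regex_pattern` as a rule producing the nonterminal `lhs`,
// defining a fresh nonterminal if `lhs` is unassigned.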
Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
  lhs = DefineNonterminal(lhs);
  regex_rules_.emplace_back(regex_pattern, lhs);
  return lhs;
}

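// Associates the annotation name `annotation` with the nonterminal `lhs`; the
// mapping is serialized into the nonterminals table.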
void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
  annotations_.emplace_back(annotation, lhs);
}

// Serializes the terminal rules tables.
void Ir::SerializeTerminalRules(
    RulesSetT* rules_set,
    std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
  // Use a common pool for all terminals.
  struct TerminalEntry {
    std::string terminal;
    int set_index;
    int index;
    Ir::LhsSet lhs_set;
  };
  std::vector<TerminalEntry> terminal_rules;

  // Merge all terminals into a common pool.
  // We use one shared pool, but still need to track which set each terminal
  // belongs to.
  std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
      terminal_rules_sets;
  std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
  terminal_rules_sets.reserve(2 * shards_.size());
  rules_maps.reserve(terminal_rules_sets.size());
  for (int i = 0; i < shards_.size(); i++) {
    terminal_rules_sets.push_back(&shards_[i].terminal_rules);
    terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
    rules_shards->at(i)->terminal_rules.reset(
        new RulesSet_::Rules_::TerminalRulesMapT());
    rules_shards->at(i)->lowercase_terminal_rules.reset(
        new RulesSet_::Rules_::TerminalRulesMapT());
    rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
    rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
  }
  for (int i = 0; i < terminal_rules_sets.size(); i++) {
    for (const auto& it : *terminal_rules_sets[i]) {
      terminal_rules.push_back(
          TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
    }
  }
  std::sort(terminal_rules.begin(), terminal_rules.end(),
            [](const TerminalEntry& a, const TerminalEntry& b) {
              return a.terminal < b.terminal;
            });

  // Index the entries in sorted order.
  std::vector<int> index(terminal_rules_sets.size(), 0);
  for (int i = 0; i < terminal_rules.size(); i++) {
    terminal_rules[i].index = index[terminal_rules[i].set_index]++;
  }

  // We store the terminal strings sorted in a buffer and keep offsets into
  // that buffer. This way, terminals that are suffixes of others need no
  // extra space.

  // Find terminals that are a suffix of others (O(n^2) algorithm).
  constexpr int kInvalidIndex = -1;
  std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
  for (int i = 0; i < terminal_rules.size(); i++) {
    const StringPiece terminal(terminal_rules[i].terminal);

    // Check whether the ith terminal is a suffix of another.
    for (int j = 0; j < terminal_rules.size(); j++) {
      if (i == j) {
        continue;
      }
      if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
        // If both terminals are identical, keep only one of them to avoid
        // cyclic dependencies. This can happen if multiple shards use the
        // same terminals, such as punctuation.
        if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
          continue;
        }
        suffix[i] = j;
        break;
      }
    }
  }

  rules_set->terminals = "";

  for (int i = 0; i < terminal_rules_sets.size(); i++) {
    rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
    rules_maps[i]->max_terminal_length = 0;
    rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
  }

  for (int i = 0; i < terminal_rules.size(); i++) {
    const TerminalEntry& entry = terminal_rules[i];

    // Update bounds.
    rules_maps[entry.set_index]->min_terminal_length =
        std::min(rules_maps[entry.set_index]->min_terminal_length,
                 static_cast<int>(entry.terminal.size()));
    rules_maps[entry.set_index]->max_terminal_length =
        std::max(rules_maps[entry.set_index]->max_terminal_length,
                 static_cast<int>(entry.terminal.size()));

    // Only include terminals that are not suffixes of others.
    if (suffix[i] != kInvalidIndex) {
      continue;
    }

    rules_maps[entry.set_index]->terminal_offsets[entry.index] =
        rules_set->terminals.length();
    rules_set->terminals += entry.terminal + '\0';
  }

  // Store just an offset into the existing terminal data for the terminals
  // that are suffixes of others.
  for (int i = 0; i < terminal_rules.size(); i++) {
    int canonical_index = i;
    if (suffix[canonical_index] == kInvalidIndex) {
      continue;
    }

    // Find the overlapping string that was included in the data.
    while (suffix[canonical_index] != kInvalidIndex) {
      canonical_index = suffix[canonical_index];
    }

    const TerminalEntry& entry = terminal_rules[i];
    const TerminalEntry& canonical_entry = terminal_rules[canonical_index];

    // The offset is the offset of the overlapping (canonical) string plus the
    // offset of the suffix within that string.
    rules_maps[entry.set_index]->terminal_offsets[entry.index] =
        rules_maps[canonical_entry.set_index]
            ->terminal_offsets[canonical_entry.index] +
        (canonical_entry.terminal.length() - entry.terminal.length());
  }

  for (const TerminalEntry& entry : terminal_rules) {
    rules_maps[entry.set_index]->lhs_set_index.push_back(
        AddLhsSet(entry.lhs_set, rules_set));
  }
}

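// Serializes the intermediate representation into the object API
// representation `output` of the rules flatbuffer.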
void Ir::Serialize(const bool include_debug_information,
                   RulesSetT* output) const {
  // Add information about predefined nonterminal classes.
  output->nonterminals.reset(new RulesSet_::NonterminalsT);
  output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
  output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
  output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
  output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
  output->nonterminals->uppercase_token_nt =
      GetNonterminalForName(kUppercaseTokenNonterm);
  output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
  for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
    if (const Nonterm n_digits_nt =
            GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
      output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
      output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
    }
  }
  for (const auto& [annotation, annotation_nt] : annotations_) {
    output->nonterminals->annotation_nt.emplace_back(
        new RulesSet_::Nonterminals_::AnnotationNtEntryT);
    output->nonterminals->annotation_nt.back()->key = annotation;
    output->nonterminals->annotation_nt.back()->value = annotation_nt;
  }
  SortForBinarySearchLookup(&output->nonterminals->annotation_nt);

  if (include_debug_information) {
    output->debug_information.reset(new RulesSet_::DebugInformationT);
    // Keep the original nonterminal names.
    for (const auto& it : nonterminal_names_) {
      output->debug_information->nonterminal_names.emplace_back(
          new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
      output->debug_information->nonterminal_names.back()->key = it.first;
      output->debug_information->nonterminal_names.back()->value = it.second;
    }
    SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
  }

  // Add regex rules.
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  for (auto [pattern, lhs] : regex_rules_) {
    output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
    output->regex_annotator.back()->compressed_pattern.reset(
        new CompressedBufferT);
    compressor->Compress(
        pattern, output->regex_annotator.back()->compressed_pattern.get());
    output->regex_annotator.back()->nonterminal = lhs;
  }

  // Serialize the unary and binary rules.
  for (int i = 0; i < shards_.size(); i++) {
    output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
    RulesSet_::RulesT* rules = output->rules.back().get();
    for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
      if (shard_locale.IsValid()) {
        // The special locale '*' (all languages) is not serialized: to stay
        // consistent with the on-device parser, we leave the language tag
        // list empty instead of adding an explicit entry for it.
        rules->locale.emplace_back(
            std::make_unique<libtextclassifier3::LanguageTagT>());
        libtextclassifier3::LanguageTagT* language_tag =
            rules->locale.back().get();
        language_tag->language = shard_locale.Language();
        language_tag->region = shard_locale.Region();
        language_tag->script = shard_locale.Script();
      }
    }

    // Serialize the unary rules.
    SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
    // Serialize the binary rules.
    SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
  }
  // Serialize the terminal rules.
  // We keep the rules separate by shard but merge the actual terminals into
  // one shared string pool to most effectively exploit reuse.
  SerializeTerminalRules(output, &output->rules);
}

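// Serializes the rules into a flatbuffer and returns it as a string.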
std::string Ir::SerializeAsFlatbuffer(
    const bool include_debug_information) const {
  RulesSetT output;
  Serialize(include_debug_information, &output);
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RulesSet::Pack(builder, &output));
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}

}  // namespace libtextclassifier3::grammar