• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
19 
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "utils/base/integral_types.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
#include "utils/grammar/utils/locale-shard-map.h"
29 
30 namespace libtextclassifier3::grammar {
31 
// Pre-defined nonterminal classes that the lexer can handle.
// Declared `inline constexpr` so every translation unit that includes this
// header shares a single definition instead of getting its own
// internal-linkage copy.
inline constexpr const char* kStartNonterm = "<^>";
inline constexpr const char* kEndNonterm = "<$>";
// Note: "\b" is the C++ escape for the backspace control character (0x08), not
// a regex word boundary, so this literal is the three characters '<', 0x08,
// '>'. NOTE(review): presumably deliberate so it cannot collide with a
// user-written nonterminal name — confirm before changing.
inline constexpr const char* kWordBreakNonterm = "<\b>";
inline constexpr const char* kTokenNonterm = "<token>";
inline constexpr const char* kUppercaseTokenNonterm = "<uppercase_token>";
inline constexpr const char* kDigitsNonterm = "<digits>";
// printf-style format string: formatting it with an int N yields
// "<N_digits>".
inline constexpr const char* kNDigitsNonterm = "<%d_digits>";
// Presumably the largest N for which an "<N_digits>" nonterminal is
// supported — TODO confirm against the lexer.
inline constexpr int kMaxNDigitsNontermLength = 20;
41 
42 // Low-level intermediate rules representation.
43 // In this representation, nonterminals are specified simply as integers
44 // (Nonterms), rather than strings which is more efficient.
45 // Rule set optimizations are done on this representation.
46 //
47 // Rules are represented in (mostly) Chomsky Normal Form, where all rules are
48 // of the following form, either:
49 //   * <nonterm> ::= term
50 //   * <nonterm> ::= <nonterm>
51 //   * <nonterm> ::= <nonterm> <nonterm>
52 class Ir {
53  public:
54   // A rule callback as a callback id and parameter pair.
55   struct Callback {
56     bool operator==(const Callback& other) const {
57       return std::tie(id, param) == std::tie(other.id, other.param);
58     }
59 
60     CallbackId id = kNoCallback;
61     int64 param = 0;
62   };
63 
64   // Constraints for triggering a rule.
65   struct Preconditions {
66     bool operator==(const Preconditions& other) const {
67       return max_whitespace_gap == other.max_whitespace_gap;
68     }
69 
70     // The maximum allowed whitespace between parts of the rule.
71     // The default of -1 allows for unbounded whitespace.
72     int8 max_whitespace_gap = -1;
73   };
74 
75   struct Lhs {
76     bool operator==(const Lhs& other) const {
77       return std::tie(nonterminal, callback, preconditions) ==
78              std::tie(other.nonterminal, other.callback, other.preconditions);
79     }
80 
81     Nonterm nonterminal = kUnassignedNonterm;
82     Callback callback;
83     Preconditions preconditions;
84   };
85   using LhsSet = std::vector<Lhs>;
86 
87   // A rules shard.
88   struct RulesShard {
89     // Terminal rules.
90     std::unordered_map<std::string, LhsSet> terminal_rules;
91     std::unordered_map<std::string, LhsSet> lowercase_terminal_rules;
92 
93     // Unary rules.
94     std::unordered_map<Nonterm, LhsSet> unary_rules;
95 
96     // Binary rules.
97     std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
98   };
99 
Ir(const LocaleShardMap & locale_shard_map)100   explicit Ir(const LocaleShardMap& locale_shard_map)
101       : num_nonterminals_(0),
102         locale_shard_map_(locale_shard_map),
103         shards_(locale_shard_map_.GetNumberOfShards()) {}
104 
105   // Adds a new non-terminal.
106   Nonterm AddNonterminal(const std::string& name = "") {
107     const Nonterm nonterminal = ++num_nonterminals_;
108     if (!name.empty()) {
109       // Record debug information.
110       SetNonterminal(name, nonterminal);
111     }
112     return nonterminal;
113   }
114 
115   // Sets the name of a nonterminal.
SetNonterminal(const std::string & name,const Nonterm nonterminal)116   void SetNonterminal(const std::string& name, const Nonterm nonterminal) {
117     nonterminal_names_[nonterminal] = name;
118     nonterminal_ids_[name] = nonterminal;
119   }
120 
121   // Defines a nonterminal if not yet defined.
DefineNonterminal(Nonterm nonterminal)122   Nonterm DefineNonterminal(Nonterm nonterminal) {
123     return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal();
124   }
125 
126   // Defines a new non-terminal that cannot be shared internally.
127   Nonterm AddUnshareableNonterminal(const std::string& name = "") {
128     const Nonterm nonterminal = AddNonterminal(name);
129     nonshareable_.insert(nonterminal);
130     return nonterminal;
131   }
132 
133   // Gets the non-terminal for a given name, if it was previously defined.
GetNonterminalForName(const std::string & name)134   Nonterm GetNonterminalForName(const std::string& name) const {
135     const auto it = nonterminal_ids_.find(name);
136     if (it == nonterminal_ids_.end()) {
137       return kUnassignedNonterm;
138     }
139     return it->second;
140   }
141 
142   // Adds a terminal rule <lhs> ::= terminal.
143   Nonterm Add(const Lhs& lhs, const std::string& terminal,
144               bool case_sensitive = false, int shard = 0);
145   Nonterm Add(const Nonterm lhs, const std::string& terminal,
146               bool case_sensitive = false, int shard = 0) {
147     return Add(Lhs{lhs}, terminal, case_sensitive, shard);
148   }
149 
150   // Adds a unary rule <lhs> ::= <rhs>.
151   Nonterm Add(const Lhs& lhs, Nonterm rhs, int shard = 0) {
152     return AddRule(lhs, rhs, &shards_[shard].unary_rules);
153   }
154   Nonterm Add(Nonterm lhs, Nonterm rhs, int shard = 0) {
155     return Add(Lhs{lhs}, rhs, shard);
156   }
157 
158   // Adds a binary rule <lhs> ::= <rhs_1> <rhs_2>.
159   Nonterm Add(const Lhs& lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
160     return AddRule(lhs, {rhs_1, rhs_2}, &shards_[shard].binary_rules);
161   }
162   Nonterm Add(Nonterm lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
163     return Add(Lhs{lhs}, rhs_1, rhs_2, shard);
164   }
165 
166   // Adds a rule <lhs> ::= <rhs_1> <rhs_2> ... <rhs_k>
167   //
168   // If k > 2, we internally create a series of Nonterms representing prefixes
169   // of the full rhs.
170   //     <temp_1> ::= <RHS_1> <RHS_2>
171   //     <temp_2> ::= <temp_1> <RHS_3>
172   //     ...
173   //     <LHS> ::= <temp_(k-1)> <RHS_k>
174   Nonterm Add(const Lhs& lhs, const std::vector<Nonterm>& rhs, int shard = 0);
175   Nonterm Add(Nonterm lhs, const std::vector<Nonterm>& rhs, int shard = 0) {
176     return Add(Lhs{lhs}, rhs, shard);
177   }
178 
179   // Adds a regex rule <lhs> ::= <regex_pattern>.
180   Nonterm AddRegex(Nonterm lhs, const std::string& regex_pattern);
181 
182   // Adds a definition for a nonterminal provided by a text annotation.
183   void AddAnnotation(Nonterm lhs, const std::string& annotation);
184 
185   // Serializes a rule set in the intermediate representation into the
186   // memory mappable inference format.
187   void Serialize(bool include_debug_information, RulesSetT* output) const;
188 
189   std::string SerializeAsFlatbuffer(
190       bool include_debug_information = false) const;
191 
shards()192   const std::vector<RulesShard>& shards() const { return shards_; }
regex_rules()193   const std::vector<std::pair<std::string, Nonterm>>& regex_rules() const {
194     return regex_rules_;
195   }
annotations()196   const std::vector<std::pair<std::string, Nonterm>>& annotations() const {
197     return annotations_;
198   }
199 
200  private:
201   template <typename R, typename H>
AddRule(const Lhs & lhs,const R & rhs,std::unordered_map<R,LhsSet,H> * rules)202   Nonterm AddRule(const Lhs& lhs, const R& rhs,
203                   std::unordered_map<R, LhsSet, H>* rules) {
204     const auto it = rules->find(rhs);
205 
206     // Rhs was not yet used.
207     if (it == rules->end()) {
208       const Nonterm nonterminal = DefineNonterminal(lhs.nonterminal);
209       rules->insert(it,
210                     {rhs, {Lhs{nonterminal, lhs.callback, lhs.preconditions}}});
211       return nonterminal;
212     }
213 
214     return AddToSet(lhs, &it->second);
215   }
216 
217   // Adds a new callback to an lhs set, potentially sharing nonterminal ids and
218   // existing callbacks.
219   Nonterm AddToSet(const Lhs& lhs, LhsSet* lhs_set);
220 
221   // Serializes the sharded terminal rules.
222   void SerializeTerminalRules(
223       RulesSetT* rules_set,
224       std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const;
225 
226   // The defined non-terminals.
227   Nonterm num_nonterminals_;
228   std::unordered_set<Nonterm> nonshareable_;
229 
230   // Locale information for Rules
231   const LocaleShardMap& locale_shard_map_;
232   // The sharded rules.
233   std::vector<RulesShard> shards_;
234 
235   // The regex rules.
236   std::vector<std::pair<std::string, Nonterm>> regex_rules_;
237 
238   // Mapping from annotation name to nonterminal.
239   std::vector<std::pair<std::string, Nonterm>> annotations_;
240 
241   // Debug information.
242   std::unordered_map<Nonterm, std::string> nonterminal_names_;
243   std::unordered_map<std::string, Nonterm> nonterminal_ids_;
244 };
245 
246 }  // namespace libtextclassifier3::grammar
247 
248 #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
249