1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_ 19 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "utils/base/integral_types.h" 26 #include "utils/grammar/rules_generated.h" 27 #include "utils/grammar/types.h" 28 #include "utils/grammar/utils/locale-shard-map.h" 29 30 namespace libtextclassifier3::grammar { 31 32 // Pre-defined nonterminal classes that the lexer can handle. 33 constexpr const char* kStartNonterm = "<^>"; 34 constexpr const char* kEndNonterm = "<$>"; 35 constexpr const char* kWordBreakNonterm = "<\b>"; 36 constexpr const char* kTokenNonterm = "<token>"; 37 constexpr const char* kUppercaseTokenNonterm = "<uppercase_token>"; 38 constexpr const char* kDigitsNonterm = "<digits>"; 39 constexpr const char* kNDigitsNonterm = "<%d_digits>"; 40 constexpr const int kMaxNDigitsNontermLength = 20; 41 42 // Low-level intermediate rules representation. 43 // In this representation, nonterminals are specified simply as integers 44 // (Nonterms), rather than strings which is more efficient. 45 // Rule set optimizations are done on this representation. 46 // 47 // Rules are represented in (mostly) Chomsky Normal Form, where all rules are 48 // of the following form, either: 49 // * <nonterm> ::= term 50 // * <nonterm> ::= <nonterm> 51 // * <nonterm> ::= <nonterm> <nonterm> 52 class Ir { 53 public: 54 // A rule callback as a callback id and parameter pair. 55 struct Callback { 56 bool operator==(const Callback& other) const { 57 return std::tie(id, param) == std::tie(other.id, other.param); 58 } 59 60 CallbackId id = kNoCallback; 61 int64 param = 0; 62 }; 63 64 // Constraints for triggering a rule. 65 struct Preconditions { 66 bool operator==(const Preconditions& other) const { 67 return max_whitespace_gap == other.max_whitespace_gap; 68 } 69 70 // The maximum allowed whitespace between parts of the rule. 71 // The default of -1 allows for unbounded whitespace. 72 int8 max_whitespace_gap = -1; 73 }; 74 75 struct Lhs { 76 bool operator==(const Lhs& other) const { 77 return std::tie(nonterminal, callback, preconditions) == 78 std::tie(other.nonterminal, other.callback, other.preconditions); 79 } 80 81 Nonterm nonterminal = kUnassignedNonterm; 82 Callback callback; 83 Preconditions preconditions; 84 }; 85 using LhsSet = std::vector<Lhs>; 86 87 // A rules shard. 88 struct RulesShard { 89 // Terminal rules. 90 std::unordered_map<std::string, LhsSet> terminal_rules; 91 std::unordered_map<std::string, LhsSet> lowercase_terminal_rules; 92 93 // Unary rules. 94 std::unordered_map<Nonterm, LhsSet> unary_rules; 95 96 // Binary rules. 97 std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules; 98 }; 99 Ir(const LocaleShardMap & locale_shard_map)100 explicit Ir(const LocaleShardMap& locale_shard_map) 101 : num_nonterminals_(0), 102 locale_shard_map_(locale_shard_map), 103 shards_(locale_shard_map_.GetNumberOfShards()) {} 104 105 // Adds a new non-terminal. 106 Nonterm AddNonterminal(const std::string& name = "") { 107 const Nonterm nonterminal = ++num_nonterminals_; 108 if (!name.empty()) { 109 // Record debug information. 110 SetNonterminal(name, nonterminal); 111 } 112 return nonterminal; 113 } 114 115 // Sets the name of a nonterminal. SetNonterminal(const std::string & name,const Nonterm nonterminal)116 void SetNonterminal(const std::string& name, const Nonterm nonterminal) { 117 nonterminal_names_[nonterminal] = name; 118 nonterminal_ids_[name] = nonterminal; 119 } 120 121 // Defines a nonterminal if not yet defined. DefineNonterminal(Nonterm nonterminal)122 Nonterm DefineNonterminal(Nonterm nonterminal) { 123 return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal(); 124 } 125 126 // Defines a new non-terminal that cannot be shared internally. 127 Nonterm AddUnshareableNonterminal(const std::string& name = "") { 128 const Nonterm nonterminal = AddNonterminal(name); 129 nonshareable_.insert(nonterminal); 130 return nonterminal; 131 } 132 133 // Gets the non-terminal for a given name, if it was previously defined. GetNonterminalForName(const std::string & name)134 Nonterm GetNonterminalForName(const std::string& name) const { 135 const auto it = nonterminal_ids_.find(name); 136 if (it == nonterminal_ids_.end()) { 137 return kUnassignedNonterm; 138 } 139 return it->second; 140 } 141 142 // Adds a terminal rule <lhs> ::= terminal. 143 Nonterm Add(const Lhs& lhs, const std::string& terminal, 144 bool case_sensitive = false, int shard = 0); 145 Nonterm Add(const Nonterm lhs, const std::string& terminal, 146 bool case_sensitive = false, int shard = 0) { 147 return Add(Lhs{lhs}, terminal, case_sensitive, shard); 148 } 149 150 // Adds a unary rule <lhs> ::= <rhs>. 151 Nonterm Add(const Lhs& lhs, Nonterm rhs, int shard = 0) { 152 return AddRule(lhs, rhs, &shards_[shard].unary_rules); 153 } 154 Nonterm Add(Nonterm lhs, Nonterm rhs, int shard = 0) { 155 return Add(Lhs{lhs}, rhs, shard); 156 } 157 158 // Adds a binary rule <lhs> ::= <rhs_1> <rhs_2>. 159 Nonterm Add(const Lhs& lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) { 160 return AddRule(lhs, {rhs_1, rhs_2}, &shards_[shard].binary_rules); 161 } 162 Nonterm Add(Nonterm lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) { 163 return Add(Lhs{lhs}, rhs_1, rhs_2, shard); 164 } 165 166 // Adds a rule <lhs> ::= <rhs_1> <rhs_2> ... <rhs_k> 167 // 168 // If k > 2, we internally create a series of Nonterms representing prefixes 169 // of the full rhs. 170 // <temp_1> ::= <RHS_1> <RHS_2> 171 // <temp_2> ::= <temp_1> <RHS_3> 172 // ... 173 // <LHS> ::= <temp_(k-1)> <RHS_k> 174 Nonterm Add(const Lhs& lhs, const std::vector<Nonterm>& rhs, int shard = 0); 175 Nonterm Add(Nonterm lhs, const std::vector<Nonterm>& rhs, int shard = 0) { 176 return Add(Lhs{lhs}, rhs, shard); 177 } 178 179 // Adds a regex rule <lhs> ::= <regex_pattern>. 180 Nonterm AddRegex(Nonterm lhs, const std::string& regex_pattern); 181 182 // Adds a definition for a nonterminal provided by a text annotation. 183 void AddAnnotation(Nonterm lhs, const std::string& annotation); 184 185 // Serializes a rule set in the intermediate representation into the 186 // memory mappable inference format. 187 void Serialize(bool include_debug_information, RulesSetT* output) const; 188 189 std::string SerializeAsFlatbuffer( 190 bool include_debug_information = false) const; 191 shards()192 const std::vector<RulesShard>& shards() const { return shards_; } regex_rules()193 const std::vector<std::pair<std::string, Nonterm>>& regex_rules() const { 194 return regex_rules_; 195 } annotations()196 const std::vector<std::pair<std::string, Nonterm>>& annotations() const { 197 return annotations_; 198 } 199 200 private: 201 template <typename R, typename H> AddRule(const Lhs & lhs,const R & rhs,std::unordered_map<R,LhsSet,H> * rules)202 Nonterm AddRule(const Lhs& lhs, const R& rhs, 203 std::unordered_map<R, LhsSet, H>* rules) { 204 const auto it = rules->find(rhs); 205 206 // Rhs was not yet used. 207 if (it == rules->end()) { 208 const Nonterm nonterminal = DefineNonterminal(lhs.nonterminal); 209 rules->insert(it, 210 {rhs, {Lhs{nonterminal, lhs.callback, lhs.preconditions}}}); 211 return nonterminal; 212 } 213 214 return AddToSet(lhs, &it->second); 215 } 216 217 // Adds a new callback to an lhs set, potentially sharing nonterminal ids and 218 // existing callbacks. 219 Nonterm AddToSet(const Lhs& lhs, LhsSet* lhs_set); 220 221 // Serializes the sharded terminal rules. 222 void SerializeTerminalRules( 223 RulesSetT* rules_set, 224 std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const; 225 226 // The defined non-terminals. 227 Nonterm num_nonterminals_; 228 std::unordered_set<Nonterm> nonshareable_; 229 230 // Locale information for Rules 231 const LocaleShardMap& locale_shard_map_; 232 // The sharded rules. 233 std::vector<RulesShard> shards_; 234 235 // The regex rules. 236 std::vector<std::pair<std::string, Nonterm>> regex_rules_; 237 238 // Mapping from annotation name to nonterminal. 239 std::vector<std::pair<std::string, Nonterm>> annotations_; 240 241 // Debug information. 242 std::unordered_map<Nonterm, std::string> nonterminal_names_; 243 std::unordered_map<std::string, Nonterm> nonterminal_ids_; 244 }; 245 246 } // namespace libtextclassifier3::grammar 247 248 #endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_ 249