1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/grammar/grammar-annotator.h"
18
19 #include "annotator/feature-processor.h"
20 #include "annotator/grammar/utils.h"
21 #include "annotator/types.h"
22 #include "utils/base/arena.h"
23 #include "utils/base/logging.h"
24 #include "utils/normalization.h"
25 #include "utils/optional.h"
26 #include "utils/utf8/unicodetext.h"
27
28 namespace libtextclassifier3 {
29 namespace {
30
31 // Retrieves all capturing nodes from a parse tree.
GetCapturingNodes(const grammar::ParseTree * parse_tree)32 std::unordered_map<uint16, const grammar::ParseTree*> GetCapturingNodes(
33 const grammar::ParseTree* parse_tree) {
34 std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes;
35 for (const grammar::MappingNode* mapping_node :
36 grammar::SelectAllOfType<grammar::MappingNode>(
37 parse_tree, grammar::ParseTree::Type::kMapping)) {
38 capturing_nodes[mapping_node->id] = mapping_node;
39 }
40 return capturing_nodes;
41 }
42
43 // Computes the selection boundaries from a parse tree.
MatchSelectionBoundaries(const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * classification)44 CodepointSpan MatchSelectionBoundaries(
45 const grammar::ParseTree* parse_tree,
46 const GrammarModel_::RuleClassificationResult* classification) {
47 if (classification->capturing_group() == nullptr) {
48 // Use full match as selection span.
49 return parse_tree->codepoint_span;
50 }
51
52 // Set information from capturing matches.
53 CodepointSpan span{kInvalidIndex, kInvalidIndex};
54 std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
55 GetCapturingNodes(parse_tree);
56
57 // Compute span boundaries.
58 for (int i = 0; i < classification->capturing_group()->size(); i++) {
59 auto it = capturing_nodes.find(i);
60 if (it == capturing_nodes.end()) {
61 // Capturing group is not active, skip.
62 continue;
63 }
64 const CapturingGroup* group = classification->capturing_group()->Get(i);
65 if (group->extend_selection()) {
66 if (span.first == kInvalidIndex) {
67 span = it->second->codepoint_span;
68 } else {
69 span.first = std::min(span.first, it->second->codepoint_span.first);
70 span.second = std::max(span.second, it->second->codepoint_span.second);
71 }
72 }
73 }
74 return span;
75 }
76
77 } // namespace
78
GrammarAnnotator(const UniLib * unilib,const GrammarModel * model,const MutableFlatbufferBuilder * entity_data_builder)79 GrammarAnnotator::GrammarAnnotator(
80 const UniLib* unilib, const GrammarModel* model,
81 const MutableFlatbufferBuilder* entity_data_builder)
82 : unilib_(*unilib),
83 model_(model),
84 tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
85 entity_data_builder_(entity_data_builder),
86 analyzer_(unilib, model->rules(), &tokenizer_) {}
87
88 // Filters out results that do not overlap with a reference span.
OverlappingDerivations(const CodepointSpan & selection,const std::vector<grammar::Derivation> & derivations,const bool only_exact_overlap) const89 std::vector<grammar::Derivation> GrammarAnnotator::OverlappingDerivations(
90 const CodepointSpan& selection,
91 const std::vector<grammar::Derivation>& derivations,
92 const bool only_exact_overlap) const {
93 std::vector<grammar::Derivation> result;
94 for (const grammar::Derivation& derivation : derivations) {
95 // Discard matches that do not match the selection.
96 // Simple check.
97 if (!SpansOverlap(selection, derivation.parse_tree->codepoint_span)) {
98 continue;
99 }
100
101 // Compute exact selection boundaries (without assertions and
102 // non-capturing parts).
103 const CodepointSpan span = MatchSelectionBoundaries(
104 derivation.parse_tree,
105 model_->rule_classification_result()->Get(derivation.rule_id));
106 if (!SpansOverlap(selection, span) ||
107 (only_exact_overlap && span != selection)) {
108 continue;
109 }
110 result.push_back(derivation);
111 }
112 return result;
113 }
114
InstantiateAnnotatedSpanFromDerivation(const grammar::TextContext & input_context,const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * interpretation,AnnotatedSpan * result) const115 bool GrammarAnnotator::InstantiateAnnotatedSpanFromDerivation(
116 const grammar::TextContext& input_context,
117 const grammar::ParseTree* parse_tree,
118 const GrammarModel_::RuleClassificationResult* interpretation,
119 AnnotatedSpan* result) const {
120 result->span = MatchSelectionBoundaries(parse_tree, interpretation);
121 ClassificationResult classification;
122 if (!InstantiateClassificationFromDerivation(
123 input_context, parse_tree, interpretation, &classification)) {
124 return false;
125 }
126 result->classification.push_back(classification);
127 return true;
128 }
129
130 // Instantiates a classification result from a rule match.
InstantiateClassificationFromDerivation(const grammar::TextContext & input_context,const grammar::ParseTree * parse_tree,const GrammarModel_::RuleClassificationResult * interpretation,ClassificationResult * classification) const131 bool GrammarAnnotator::InstantiateClassificationFromDerivation(
132 const grammar::TextContext& input_context,
133 const grammar::ParseTree* parse_tree,
134 const GrammarModel_::RuleClassificationResult* interpretation,
135 ClassificationResult* classification) const {
136 classification->collection = interpretation->collection_name()->str();
137 classification->score = interpretation->target_classification_score();
138 classification->priority_score = interpretation->priority_score();
139
140 // Assemble entity data.
141 if (entity_data_builder_ == nullptr) {
142 return true;
143 }
144 std::unique_ptr<MutableFlatbuffer> entity_data =
145 entity_data_builder_->NewRoot();
146 if (interpretation->serialized_entity_data() != nullptr) {
147 entity_data->MergeFromSerializedFlatbuffer(
148 StringPiece(interpretation->serialized_entity_data()->data(),
149 interpretation->serialized_entity_data()->size()));
150 }
151 if (interpretation->entity_data() != nullptr) {
152 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
153 interpretation->entity_data()));
154 }
155
156 // Populate entity data from the capturing matches.
157 if (interpretation->capturing_group() != nullptr) {
158 // Gather active capturing matches.
159 std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
160 GetCapturingNodes(parse_tree);
161
162 for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
163 auto it = capturing_nodes.find(i);
164 if (it == capturing_nodes.end()) {
165 // Capturing group is not active, skip.
166 continue;
167 }
168 const CapturingGroup* group = interpretation->capturing_group()->Get(i);
169
170 // Add static entity data.
171 if (group->serialized_entity_data() != nullptr) {
172 entity_data->MergeFromSerializedFlatbuffer(
173 StringPiece(interpretation->serialized_entity_data()->data(),
174 interpretation->serialized_entity_data()->size()));
175 }
176
177 // Set entity field from captured text.
178 if (group->entity_field_path() != nullptr) {
179 const grammar::ParseTree* capturing_match = it->second;
180 UnicodeText match_text =
181 input_context.Span(capturing_match->codepoint_span);
182 if (group->normalization_options() != nullptr) {
183 match_text = NormalizeText(unilib_, group->normalization_options(),
184 match_text);
185 }
186 if (!entity_data->ParseAndSet(group->entity_field_path(),
187 match_text.ToUTF8String())) {
188 TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
189 return false;
190 }
191 }
192 }
193 }
194
195 if (entity_data && entity_data->HasExplicitlySetFields()) {
196 classification->serialized_entity_data = entity_data->Serialize();
197 }
198 return true;
199 }
200
Annotate(const std::vector<Locale> & locales,const UnicodeText & text,std::vector<AnnotatedSpan> * result) const201 bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
202 const UnicodeText& text,
203 std::vector<AnnotatedSpan>* result) const {
204 grammar::TextContext input_context =
205 analyzer_.BuildTextContextForInput(text, locales);
206
207 UnsafeArena arena(/*block_size=*/16 << 10);
208
209 for (const grammar::Derivation& derivation : ValidDeduplicatedDerivations(
210 analyzer_.parser().Parse(input_context, &arena))) {
211 const GrammarModel_::RuleClassificationResult* interpretation =
212 model_->rule_classification_result()->Get(derivation.rule_id);
213 if ((interpretation->enabled_modes() & ModeFlag_ANNOTATION) == 0) {
214 continue;
215 }
216 result->emplace_back();
217 if (!InstantiateAnnotatedSpanFromDerivation(
218 input_context, derivation.parse_tree, interpretation,
219 &result->back())) {
220 return false;
221 }
222 }
223
224 return true;
225 }
226
SuggestSelection(const std::vector<Locale> & locales,const UnicodeText & text,const CodepointSpan & selection,AnnotatedSpan * result) const227 bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
228 const UnicodeText& text,
229 const CodepointSpan& selection,
230 AnnotatedSpan* result) const {
231 if (!selection.IsValid() || selection.IsEmpty()) {
232 return false;
233 }
234
235 grammar::TextContext input_context =
236 analyzer_.BuildTextContextForInput(text, locales);
237
238 UnsafeArena arena(/*block_size=*/16 << 10);
239
240 const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
241 const grammar::ParseTree* best_match = nullptr;
242 for (const grammar::Derivation& derivation :
243 ValidDeduplicatedDerivations(OverlappingDerivations(
244 selection, analyzer_.parser().Parse(input_context, &arena),
245 /*only_exact_overlap=*/false))) {
246 const GrammarModel_::RuleClassificationResult* interpretation =
247 model_->rule_classification_result()->Get(derivation.rule_id);
248 if ((interpretation->enabled_modes() & ModeFlag_SELECTION) == 0) {
249 continue;
250 }
251 if (best_interpretation == nullptr ||
252 interpretation->priority_score() >
253 best_interpretation->priority_score()) {
254 best_interpretation = interpretation;
255 best_match = derivation.parse_tree;
256 }
257 }
258
259 if (best_interpretation == nullptr) {
260 return false;
261 }
262
263 return InstantiateAnnotatedSpanFromDerivation(input_context, best_match,
264 best_interpretation, result);
265 }
266
ClassifyText(const std::vector<Locale> & locales,const UnicodeText & text,const CodepointSpan & selection,ClassificationResult * classification_result) const267 bool GrammarAnnotator::ClassifyText(
268 const std::vector<Locale>& locales, const UnicodeText& text,
269 const CodepointSpan& selection,
270 ClassificationResult* classification_result) const {
271 if (!selection.IsValid() || selection.IsEmpty()) {
272 // Nothing to do.
273 return false;
274 }
275
276 grammar::TextContext input_context =
277 analyzer_.BuildTextContextForInput(text, locales);
278
279 if (const TokenSpan context_span = CodepointSpanToTokenSpan(
280 input_context.tokens, selection,
281 /*snap_boundaries_to_containing_tokens=*/true);
282 context_span.IsValid()) {
283 if (model_->context_left_num_tokens() != kInvalidIndex) {
284 input_context.context_span.first =
285 std::max(0, context_span.first - model_->context_left_num_tokens());
286 }
287 if (model_->context_right_num_tokens() != kInvalidIndex) {
288 input_context.context_span.second =
289 std::min(static_cast<int>(input_context.tokens.size()),
290 context_span.second + model_->context_right_num_tokens());
291 }
292 }
293
294 UnsafeArena arena(/*block_size=*/16 << 10);
295
296 const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
297 const grammar::ParseTree* best_match = nullptr;
298 for (const grammar::Derivation& derivation :
299 ValidDeduplicatedDerivations(OverlappingDerivations(
300 selection, analyzer_.parser().Parse(input_context, &arena),
301 /*only_exact_overlap=*/true))) {
302 const GrammarModel_::RuleClassificationResult* interpretation =
303 model_->rule_classification_result()->Get(derivation.rule_id);
304 if ((interpretation->enabled_modes() & ModeFlag_CLASSIFICATION) == 0) {
305 continue;
306 }
307 if (best_interpretation == nullptr ||
308 interpretation->priority_score() >
309 best_interpretation->priority_score()) {
310 best_interpretation = interpretation;
311 best_match = derivation.parse_tree;
312 }
313 }
314
315 if (best_interpretation == nullptr) {
316 return false;
317 }
318
319 return InstantiateClassificationFromDerivation(
320 input_context, best_match, best_interpretation, classification_result);
321 }
322
323 } // namespace libtextclassifier3
324