1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
19
#include <algorithm>
#include <cstddef>
#include <type_traits>
#include <vector>
21
22 #include "utils/grammar/parsing/parse-tree.h"
23
24 namespace libtextclassifier3::grammar {
25
26 // A parse tree for a root rule.
27 struct Derivation {
28 const ParseTree* parse_tree;
29 int64 rule_id;
30
31 // Checks that all assertions are fulfilled.
32 bool IsValid() const;
GetRuleIdDerivation33 int64 GetRuleId() const { return rule_id; }
GetParseTreeDerivation34 const ParseTree* GetParseTree() const { return parse_tree; }
35 };
36
37 // Deduplicates rule derivations by containing overlap.
38 // The grammar system can output multiple candidates for optional parts.
39 // For example if a rule has an optional suffix, we
40 // will get two rule derivations when the suffix is present: one with and one
41 // without the suffix. We therefore deduplicate by containing overlap, viz. from
42 // two candidates we keep the longer one if it completely contains the shorter.
43 // This factory function works with any type T that extends Derivation.
44 template <typename T, typename std::enable_if<std::is_base_of<
45 Derivation, T>::value>::type* = nullptr>
46 // std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations);
DeduplicateDerivations(const std::vector<T> & derivations)47 std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations) {
48 std::vector<T> sorted_candidates = derivations;
49
50 std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(),
51 [](const T& a, const T& b) {
52 // Sort by id.
53 if (a.GetRuleId() != b.GetRuleId()) {
54 return a.GetRuleId() < b.GetRuleId();
55 }
56
57 // Sort by increasing start.
58 if (a.GetParseTree()->codepoint_span.first !=
59 b.GetParseTree()->codepoint_span.first) {
60 return a.GetParseTree()->codepoint_span.first <
61 b.GetParseTree()->codepoint_span.first;
62 }
63
64 // Sort by decreasing end.
65 return a.GetParseTree()->codepoint_span.second >
66 b.GetParseTree()->codepoint_span.second;
67 });
68
69 // Deduplicate by overlap.
70 std::vector<T> result;
71 for (int i = 0; i < sorted_candidates.size(); i++) {
72 const T& candidate = sorted_candidates[i];
73 bool eliminated = false;
74
75 // Due to the sorting above, the candidate can only be completely
76 // intersected by a match before it in the sorted order.
77 for (int j = i - 1; j >= 0; j--) {
78 if (sorted_candidates[j].rule_id != candidate.rule_id) {
79 break;
80 }
81 if (sorted_candidates[j].parse_tree->codepoint_span.first <=
82 candidate.parse_tree->codepoint_span.first &&
83 sorted_candidates[j].parse_tree->codepoint_span.second >=
84 candidate.parse_tree->codepoint_span.second) {
85 eliminated = true;
86 break;
87 }
88 }
89 if (!eliminated) {
90 result.push_back(candidate);
91 }
92 }
93 return result;
94 }
95
96 // Deduplicates and validates rule derivations.
97 std::vector<Derivation> ValidDeduplicatedDerivations(
98 const std::vector<Derivation>& derivations);
99
100 } // namespace libtextclassifier3::grammar
101
102 #endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
103