1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/number/number.h"
18
19 #include <climits>
20 #include <cstdlib>
21 #include <string>
22
23 #include "annotator/collections.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "utils/strings/split.h"
27 #include "utils/utf8/unicodetext.h"
28
29 namespace libtextclassifier3 {
30
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const31 bool NumberAnnotator::ClassifyText(
32 const UnicodeText& context, CodepointSpan selection_indices,
33 AnnotationUsecase annotation_usecase,
34 ClassificationResult* classification_result) const {
35 TC3_CHECK(classification_result != nullptr);
36
37 const UnicodeText substring_selected = UnicodeText::Substring(
38 context, selection_indices.first, selection_indices.second);
39
40 std::vector<AnnotatedSpan> results;
41 if (!FindAll(substring_selected, annotation_usecase, &results)) {
42 return false;
43 }
44
45 for (const AnnotatedSpan& result : results) {
46 if (result.classification.empty()) {
47 continue;
48 }
49
50 // We make sure that the result span is equal to the stripped selection span
51 // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
52 // anyway only find valid numbers and percentages and a given selection with
53 // more than two tokens won't pass this check.
54 if (result.span.first + selection_indices.first ==
55 selection_indices.first &&
56 result.span.second + selection_indices.first ==
57 selection_indices.second) {
58 *classification_result = result.classification[0];
59 return true;
60 }
61 }
62 return false;
63 }
64
IsCJTterm(UnicodeText::const_iterator token_begin_it,const int token_length) const65 bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
66 const int token_length) const {
67 auto token_end_it = token_begin_it;
68 std::advance(token_end_it, token_length);
69 for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
70 if (!unilib_->IsCJTletter(*char_it)) {
71 return false;
72 }
73 }
74 return true;
75 }
76
TokensAreValidStart(const std::vector<Token> & tokens,const int start_index) const77 bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
78 const int start_index) const {
79 if (start_index < 0 || tokens[start_index].is_whitespace) {
80 return true;
81 }
82 return false;
83 }
84
TokensAreValidNumberPrefix(const std::vector<Token> & tokens,const int prefix_end_index) const85 bool NumberAnnotator::TokensAreValidNumberPrefix(
86 const std::vector<Token>& tokens, const int prefix_end_index) const {
87 if (TokensAreValidStart(tokens, prefix_end_index)) {
88 return true;
89 }
90
91 auto prefix_begin_it =
92 UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
93 .begin();
94 const int token_length =
95 tokens[prefix_end_index].end - tokens[prefix_end_index].start;
96 if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
97 TokensAreValidStart(tokens, prefix_end_index - 1)) {
98 return true;
99 }
100 if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
101 TokensAreValidStart(tokens, prefix_end_index - 1)) {
102 return true;
103 }
104 if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
105 prefix_end_index >= 1 &&
106 TokensAreValidStart(tokens, prefix_end_index - 2)) {
107 int64 int_val;
108 double double_val;
109 return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
110 /*do_copy=*/false),
111 false, &int_val, &double_val);
112 }
113 if (IsCJTterm(prefix_begin_it, token_length)) {
114 return true;
115 }
116
117 return false;
118 }
119
TokensAreValidEnding(const std::vector<Token> & tokens,const int ending_index) const120 bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
121 const int ending_index) const {
122 if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
123 return true;
124 }
125
126 auto ending_begin_it =
127 UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
128 if (ending_index == tokens.size() - 1 &&
129 tokens[ending_index].end - tokens[ending_index].start == 1 &&
130 unilib_->IsPunctuation(*ending_begin_it)) {
131 return true;
132 }
133 if (ending_index < tokens.size() - 1 &&
134 tokens[ending_index].end - tokens[ending_index].start == 1 &&
135 unilib_->IsPunctuation(*ending_begin_it) &&
136 tokens[ending_index + 1].is_whitespace) {
137 return true;
138 }
139
140 return false;
141 }
142
TokensAreValidNumberSuffix(const std::vector<Token> & tokens,const int suffix_start_index) const143 bool NumberAnnotator::TokensAreValidNumberSuffix(
144 const std::vector<Token>& tokens, const int suffix_start_index) const {
145 if (TokensAreValidEnding(tokens, suffix_start_index)) {
146 return true;
147 }
148
149 auto suffix_begin_it =
150 UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
151 .begin();
152
153 if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
154 percent_suffixes_.end() &&
155 TokensAreValidEnding(tokens, suffix_start_index + 1)) {
156 return true;
157 }
158
159 const int token_length =
160 tokens[suffix_start_index].end - tokens[suffix_start_index].start;
161 if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
162 suffix_start_index <= tokens.size() - 2 &&
163 TokensAreValidEnding(tokens, suffix_start_index + 2)) {
164 int64 int_val;
165 double double_val;
166 return TryParseNumber(
167 UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
168 /*do_copy=*/false),
169 false, &int_val, &double_val);
170 }
171 if (IsCJTterm(suffix_begin_it, token_length)) {
172 return true;
173 }
174
175 return false;
176 }
177
FindPercentSuffixEndCodepoint(const std::vector<Token> & tokens,const int suffix_token_start_index) const178 int NumberAnnotator::FindPercentSuffixEndCodepoint(
179 const std::vector<Token>& tokens,
180 const int suffix_token_start_index) const {
181 if (suffix_token_start_index >= tokens.size()) {
182 return -1;
183 }
184
185 if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
186 percent_suffixes_.end() &&
187 TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
188 return tokens[suffix_token_start_index].end;
189 }
190 if (tokens[suffix_token_start_index].is_whitespace) {
191 return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
192 }
193
194 return -1;
195 }
196
TryParseNumber(const UnicodeText & token_text,const bool is_negative,int64 * parsed_int_value,double * parsed_double_value) const197 bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
198 const bool is_negative,
199 int64* parsed_int_value,
200 double* parsed_double_value) const {
201 if (token_text.ToUTF8String().size() >= max_number_of_digits_) {
202 return false;
203 }
204 const bool is_double = unilib_->ParseDouble(token_text, parsed_double_value);
205 if (!is_double) {
206 return false;
207 }
208 *parsed_int_value = std::trunc(*parsed_double_value);
209 if (is_negative) {
210 *parsed_int_value *= -1;
211 *parsed_double_value *= -1;
212 }
213
214 return true;
215 }
216
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * result) const217 bool NumberAnnotator::FindAll(const UnicodeText& context,
218 AnnotationUsecase annotation_usecase,
219 std::vector<AnnotatedSpan>* result) const {
220 if (!options_->enabled()) {
221 return true;
222 }
223
224 const std::vector<Token> tokens = tokenizer_.Tokenize(context);
225 for (int i = 0; i < tokens.size(); ++i) {
226 const Token token = tokens[i];
227 if (tokens[i].value.empty() ||
228 !unilib_->IsDigit(
229 *UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false).begin())) {
230 continue;
231 }
232
233 const UnicodeText token_text =
234 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
235 int64 parsed_int_value;
236 double parsed_double_value;
237 bool is_negative =
238 (i > 0) &&
239 unilib_->IsMinus(
240 *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());
241 if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
242 &parsed_double_value)) {
243 continue;
244 }
245 if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
246 !TokensAreValidNumberSuffix(tokens, i + 1)) {
247 continue;
248 }
249
250 const bool has_decimal = !(parsed_int_value == parsed_double_value);
251 const int new_start_codepoint = is_negative ? token.start - 1 : token.start;
252
253 if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
254 0) {
255 result->push_back(CreateAnnotatedSpan(
256 new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
257 Collections::Number(), options_->score(),
258 /*priority_score=*/
259 has_decimal ? options_->float_number_priority_score()
260 : options_->priority_score()));
261 }
262
263 const int percent_end_codepoint =
264 FindPercentSuffixEndCodepoint(tokens, i + 1);
265 if (percent_end_codepoint != -1 &&
266 ((1 << annotation_usecase) &
267 options_->percentage_annotation_usecases()) != 0) {
268 result->push_back(CreateAnnotatedSpan(
269 new_start_codepoint, percent_end_codepoint, parsed_int_value,
270 parsed_double_value, Collections::Percentage(), options_->score(),
271 options_->percentage_priority_score()));
272 }
273 }
274
275 return true;
276 }
277
CreateAnnotatedSpan(const int start,const int end,const int int_value,const double double_value,const std::string collection,const float score,const float priority_score) const278 AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
279 const int start, const int end, const int int_value,
280 const double double_value, const std::string collection, const float score,
281 const float priority_score) const {
282 ClassificationResult classification{collection, score};
283 classification.numeric_value = int_value;
284 classification.numeric_double_value = double_value;
285 classification.priority_score = priority_score;
286
287 AnnotatedSpan annotated_span;
288 annotated_span.span = {start, end};
289 annotated_span.classification.push_back(classification);
290 return annotated_span;
291 }
292
293 std::unordered_set<std::string>
FromFlatbufferStringToUnordredSet(const flatbuffers::String * flatbuffer_percent_strings)294 NumberAnnotator::FromFlatbufferStringToUnordredSet(
295 const flatbuffers::String* flatbuffer_percent_strings) {
296 std::unordered_set<std::string> strings_set;
297 if (flatbuffer_percent_strings == nullptr) {
298 return strings_set;
299 }
300
301 const std::string percent_strings = flatbuffer_percent_strings->str();
302 for (StringPiece suffix : strings::Split(percent_strings, '\0')) {
303 std::string percent_suffix = suffix.ToString();
304 percent_suffix.erase(
305 std::remove_if(percent_suffix.begin(), percent_suffix.end(),
306 [](unsigned char x) { return std::isspace(x); }),
307 percent_suffix.end());
308 strings_set.insert(percent_suffix);
309 }
310
311 return strings_set;
312 }
313
314 } // namespace libtextclassifier3
315