1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/number/number.h"

#include <algorithm>
#include <cctype>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"
29
30 namespace libtextclassifier3 {
31
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const32 bool NumberAnnotator::ClassifyText(
33 const UnicodeText& context, CodepointSpan selection_indices,
34 AnnotationUsecase annotation_usecase,
35 ClassificationResult* classification_result) const {
36 TC3_CHECK(classification_result != nullptr);
37
38 const UnicodeText substring_selected = UnicodeText::Substring(
39 context, selection_indices.first, selection_indices.second);
40
41 std::vector<AnnotatedSpan> results;
42 if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION,
43 &results)) {
44 return false;
45 }
46
47 for (const AnnotatedSpan& result : results) {
48 if (result.classification.empty()) {
49 continue;
50 }
51
52 // We make sure that the result span is equal to the stripped selection span
53 // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
54 // anyway only find valid numbers and percentages and a given selection with
55 // more than two tokens won't pass this check.
56 if (result.span.first + selection_indices.first ==
57 selection_indices.first &&
58 result.span.second + selection_indices.first ==
59 selection_indices.second) {
60 *classification_result = result.classification[0];
61 return true;
62 }
63 }
64 return false;
65 }
66
IsCJTterm(UnicodeText::const_iterator token_begin_it,const int token_length) const67 bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
68 const int token_length) const {
69 auto token_end_it = token_begin_it;
70 std::advance(token_end_it, token_length);
71 for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
72 if (!unilib_->IsCJTletter(*char_it)) {
73 return false;
74 }
75 }
76 return true;
77 }
78
TokensAreValidStart(const std::vector<Token> & tokens,const int start_index) const79 bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
80 const int start_index) const {
81 if (start_index < 0 || tokens[start_index].is_whitespace) {
82 return true;
83 }
84 return false;
85 }
86
TokensAreValidNumberPrefix(const std::vector<Token> & tokens,const int prefix_end_index) const87 bool NumberAnnotator::TokensAreValidNumberPrefix(
88 const std::vector<Token>& tokens, const int prefix_end_index) const {
89 if (TokensAreValidStart(tokens, prefix_end_index)) {
90 return true;
91 }
92
93 auto prefix_begin_it =
94 UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
95 .begin();
96 const int token_length =
97 tokens[prefix_end_index].end - tokens[prefix_end_index].start;
98 if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
99 TokensAreValidStart(tokens, prefix_end_index - 1)) {
100 return true;
101 }
102 if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
103 TokensAreValidStart(tokens, prefix_end_index - 1)) {
104 return true;
105 }
106 if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
107 prefix_end_index >= 1 &&
108 TokensAreValidStart(tokens, prefix_end_index - 2)) {
109 int64 int_val;
110 double double_val;
111 return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
112 /*do_copy=*/false),
113 false, &int_val, &double_val);
114 }
115 if (IsCJTterm(prefix_begin_it, token_length)) {
116 return true;
117 }
118
119 return false;
120 }
121
TokensAreValidEnding(const std::vector<Token> & tokens,const int ending_index) const122 bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
123 const int ending_index) const {
124 if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
125 return true;
126 }
127
128 auto ending_begin_it =
129 UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
130 if (ending_index == tokens.size() - 1 &&
131 tokens[ending_index].end - tokens[ending_index].start == 1 &&
132 unilib_->IsPunctuation(*ending_begin_it)) {
133 return true;
134 }
135 if (ending_index < tokens.size() - 1 &&
136 tokens[ending_index].end - tokens[ending_index].start == 1 &&
137 unilib_->IsPunctuation(*ending_begin_it) &&
138 tokens[ending_index + 1].is_whitespace) {
139 return true;
140 }
141
142 return false;
143 }
144
TokensAreValidNumberSuffix(const std::vector<Token> & tokens,const int suffix_start_index) const145 bool NumberAnnotator::TokensAreValidNumberSuffix(
146 const std::vector<Token>& tokens, const int suffix_start_index) const {
147 if (TokensAreValidEnding(tokens, suffix_start_index)) {
148 return true;
149 }
150
151 auto suffix_begin_it =
152 UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
153 .begin();
154
155 if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
156 percent_suffixes_.end() &&
157 TokensAreValidEnding(tokens, suffix_start_index + 1)) {
158 return true;
159 }
160
161 const int token_length =
162 tokens[suffix_start_index].end - tokens[suffix_start_index].start;
163 if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
164 suffix_start_index <= tokens.size() - 2 &&
165 TokensAreValidEnding(tokens, suffix_start_index + 2)) {
166 int64 int_val;
167 double double_val;
168 return TryParseNumber(
169 UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
170 /*do_copy=*/false),
171 false, &int_val, &double_val);
172 }
173 if (IsCJTterm(suffix_begin_it, token_length)) {
174 return true;
175 }
176
177 return false;
178 }
179
FindPercentSuffixEndCodepoint(const std::vector<Token> & tokens,const int suffix_token_start_index) const180 int NumberAnnotator::FindPercentSuffixEndCodepoint(
181 const std::vector<Token>& tokens,
182 const int suffix_token_start_index) const {
183 if (suffix_token_start_index >= tokens.size()) {
184 return -1;
185 }
186
187 if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
188 percent_suffixes_.end() &&
189 TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
190 return tokens[suffix_token_start_index].end;
191 }
192 if (tokens[suffix_token_start_index].is_whitespace) {
193 return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
194 }
195
196 return -1;
197 }
198
TryParseNumber(const UnicodeText & token_text,const bool is_negative,int64 * parsed_int_value,double * parsed_double_value) const199 bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
200 const bool is_negative,
201 int64* parsed_int_value,
202 double* parsed_double_value) const {
203 if (token_text.ToUTF8String().size() >= max_number_of_digits_) {
204 return false;
205 }
206 const bool is_double = unilib_->ParseDouble(token_text, parsed_double_value);
207 if (!is_double) {
208 return false;
209 }
210 *parsed_int_value = std::trunc(*parsed_double_value);
211 if (is_negative) {
212 *parsed_int_value *= -1;
213 *parsed_double_value *= -1;
214 }
215
216 return true;
217 }
218
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,ModeFlag mode,std::vector<AnnotatedSpan> * result) const219 bool NumberAnnotator::FindAll(const UnicodeText& context,
220 AnnotationUsecase annotation_usecase,
221 ModeFlag mode,
222 std::vector<AnnotatedSpan>* result) const {
223 if (!options_->enabled() || !(options_->enabled_modes() & mode)) {
224 return true;
225 }
226
227 const std::vector<Token> tokens = tokenizer_.Tokenize(context);
228 for (int i = 0; i < tokens.size(); ++i) {
229 const Token token = tokens[i];
230 if (tokens[i].value.empty() ||
231 !unilib_->IsDigit(
232 *UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false).begin())) {
233 continue;
234 }
235
236 const UnicodeText token_text =
237 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
238 int64 parsed_int_value;
239 double parsed_double_value;
240 bool is_negative =
241 (i > 0) &&
242 unilib_->IsMinus(
243 *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());
244 if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
245 &parsed_double_value)) {
246 continue;
247 }
248 if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
249 !TokensAreValidNumberSuffix(tokens, i + 1)) {
250 continue;
251 }
252
253 const bool has_decimal = !(parsed_int_value == parsed_double_value);
254 const int new_start_codepoint = is_negative ? token.start - 1 : token.start;
255
256 if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
257 0) {
258 result->push_back(CreateAnnotatedSpan(
259 new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
260 Collections::Number(), options_->score(),
261 /*priority_score=*/
262 has_decimal ? options_->float_number_priority_score()
263 : options_->priority_score()));
264 }
265
266 const int percent_end_codepoint =
267 FindPercentSuffixEndCodepoint(tokens, i + 1);
268 if (percent_end_codepoint != -1 &&
269 ((1 << annotation_usecase) &
270 options_->percentage_annotation_usecases()) != 0) {
271 result->push_back(CreateAnnotatedSpan(
272 new_start_codepoint, percent_end_codepoint, parsed_int_value,
273 parsed_double_value, Collections::Percentage(), options_->score(),
274 options_->percentage_priority_score()));
275 }
276 }
277
278 return true;
279 }
280
CreateAnnotatedSpan(const int start,const int end,const int int_value,const double double_value,const std::string collection,const float score,const float priority_score) const281 AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
282 const int start, const int end, const int int_value,
283 const double double_value, const std::string collection, const float score,
284 const float priority_score) const {
285 ClassificationResult classification{collection, score};
286 classification.numeric_value = int_value;
287 classification.numeric_double_value = double_value;
288 classification.priority_score = priority_score;
289
290 AnnotatedSpan annotated_span;
291 annotated_span.span = {start, end};
292 annotated_span.classification.push_back(classification);
293 return annotated_span;
294 }
295
296 std::unordered_set<std::string>
FromFlatbufferStringToUnordredSet(const flatbuffers::String * flatbuffer_percent_strings)297 NumberAnnotator::FromFlatbufferStringToUnordredSet(
298 const flatbuffers::String* flatbuffer_percent_strings) {
299 std::unordered_set<std::string> strings_set;
300 if (flatbuffer_percent_strings == nullptr) {
301 return strings_set;
302 }
303
304 const std::string percent_strings = flatbuffer_percent_strings->str();
305 for (StringPiece suffix : strings::Split(percent_strings, '\0')) {
306 std::string percent_suffix = suffix.ToString();
307 percent_suffix.erase(
308 std::remove_if(percent_suffix.begin(), percent_suffix.end(),
309 [](unsigned char x) { return std::isspace(x); }),
310 percent_suffix.end());
311 strings_set.insert(percent_suffix);
312 }
313
314 return strings_set;
315 }
316
317 } // namespace libtextclassifier3
318