• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/number/number.h"
18 
19 #include <climits>
20 #include <cstdlib>
21 #include <string>
22 
23 #include "annotator/collections.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "utils/strings/split.h"
27 #include "utils/utf8/unicodetext.h"
28 
29 namespace libtextclassifier3 {
30 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const31 bool NumberAnnotator::ClassifyText(
32     const UnicodeText& context, CodepointSpan selection_indices,
33     AnnotationUsecase annotation_usecase,
34     ClassificationResult* classification_result) const {
35   TC3_CHECK(classification_result != nullptr);
36 
37   const UnicodeText substring_selected = UnicodeText::Substring(
38       context, selection_indices.first, selection_indices.second);
39 
40   std::vector<AnnotatedSpan> results;
41   if (!FindAll(substring_selected, annotation_usecase, &results)) {
42     return false;
43   }
44 
45   for (const AnnotatedSpan& result : results) {
46     if (result.classification.empty()) {
47       continue;
48     }
49 
50     // We make sure that the result span is equal to the stripped selection span
51     // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
52     // anyway only find valid numbers and percentages and a given selection with
53     // more than two tokens won't pass this check.
54     if (result.span.first + selection_indices.first ==
55             selection_indices.first &&
56         result.span.second + selection_indices.first ==
57             selection_indices.second) {
58       *classification_result = result.classification[0];
59       return true;
60     }
61   }
62   return false;
63 }
64 
IsCJTterm(UnicodeText::const_iterator token_begin_it,const int token_length) const65 bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
66                                 const int token_length) const {
67   auto token_end_it = token_begin_it;
68   std::advance(token_end_it, token_length);
69   for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
70     if (!unilib_->IsCJTletter(*char_it)) {
71       return false;
72     }
73   }
74   return true;
75 }
76 
TokensAreValidStart(const std::vector<Token> & tokens,const int start_index) const77 bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
78                                           const int start_index) const {
79   if (start_index < 0 || tokens[start_index].is_whitespace) {
80     return true;
81   }
82   return false;
83 }
84 
TokensAreValidNumberPrefix(const std::vector<Token> & tokens,const int prefix_end_index) const85 bool NumberAnnotator::TokensAreValidNumberPrefix(
86     const std::vector<Token>& tokens, const int prefix_end_index) const {
87   if (TokensAreValidStart(tokens, prefix_end_index)) {
88     return true;
89   }
90 
91   auto prefix_begin_it =
92       UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
93           .begin();
94   const int token_length =
95       tokens[prefix_end_index].end - tokens[prefix_end_index].start;
96   if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
97       TokensAreValidStart(tokens, prefix_end_index - 1)) {
98     return true;
99   }
100   if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
101       TokensAreValidStart(tokens, prefix_end_index - 1)) {
102     return true;
103   }
104   if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
105       prefix_end_index >= 1 &&
106       TokensAreValidStart(tokens, prefix_end_index - 2)) {
107     int64 int_val;
108     double double_val;
109     return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
110                                             /*do_copy=*/false),
111                           false, &int_val, &double_val);
112   }
113   if (IsCJTterm(prefix_begin_it, token_length)) {
114     return true;
115   }
116 
117   return false;
118 }
119 
TokensAreValidEnding(const std::vector<Token> & tokens,const int ending_index) const120 bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
121                                            const int ending_index) const {
122   if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
123     return true;
124   }
125 
126   auto ending_begin_it =
127       UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
128   if (ending_index == tokens.size() - 1 &&
129       tokens[ending_index].end - tokens[ending_index].start == 1 &&
130       unilib_->IsPunctuation(*ending_begin_it)) {
131     return true;
132   }
133   if (ending_index < tokens.size() - 1 &&
134       tokens[ending_index].end - tokens[ending_index].start == 1 &&
135       unilib_->IsPunctuation(*ending_begin_it) &&
136       tokens[ending_index + 1].is_whitespace) {
137     return true;
138   }
139 
140   return false;
141 }
142 
TokensAreValidNumberSuffix(const std::vector<Token> & tokens,const int suffix_start_index) const143 bool NumberAnnotator::TokensAreValidNumberSuffix(
144     const std::vector<Token>& tokens, const int suffix_start_index) const {
145   if (TokensAreValidEnding(tokens, suffix_start_index)) {
146     return true;
147   }
148 
149   auto suffix_begin_it =
150       UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
151           .begin();
152 
153   if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
154           percent_suffixes_.end() &&
155       TokensAreValidEnding(tokens, suffix_start_index + 1)) {
156     return true;
157   }
158 
159   const int token_length =
160       tokens[suffix_start_index].end - tokens[suffix_start_index].start;
161   if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
162       suffix_start_index <= tokens.size() - 2 &&
163       TokensAreValidEnding(tokens, suffix_start_index + 2)) {
164     int64 int_val;
165     double double_val;
166     return TryParseNumber(
167         UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
168                           /*do_copy=*/false),
169         false, &int_val, &double_val);
170   }
171   if (IsCJTterm(suffix_begin_it, token_length)) {
172     return true;
173   }
174 
175   return false;
176 }
177 
FindPercentSuffixEndCodepoint(const std::vector<Token> & tokens,const int suffix_token_start_index) const178 int NumberAnnotator::FindPercentSuffixEndCodepoint(
179     const std::vector<Token>& tokens,
180     const int suffix_token_start_index) const {
181   if (suffix_token_start_index >= tokens.size()) {
182     return -1;
183   }
184 
185   if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
186           percent_suffixes_.end() &&
187       TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
188     return tokens[suffix_token_start_index].end;
189   }
190   if (tokens[suffix_token_start_index].is_whitespace) {
191     return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
192   }
193 
194   return -1;
195 }
196 
TryParseNumber(const UnicodeText & token_text,const bool is_negative,int64 * parsed_int_value,double * parsed_double_value) const197 bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
198                                      const bool is_negative,
199                                      int64* parsed_int_value,
200                                      double* parsed_double_value) const {
201   if (token_text.ToUTF8String().size() >= max_number_of_digits_) {
202     return false;
203   }
204   const bool is_double = unilib_->ParseDouble(token_text, parsed_double_value);
205   if (!is_double) {
206     return false;
207   }
208   *parsed_int_value = std::trunc(*parsed_double_value);
209   if (is_negative) {
210     *parsed_int_value *= -1;
211     *parsed_double_value *= -1;
212   }
213 
214   return true;
215 }
216 
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * result) const217 bool NumberAnnotator::FindAll(const UnicodeText& context,
218                               AnnotationUsecase annotation_usecase,
219                               std::vector<AnnotatedSpan>* result) const {
220   if (!options_->enabled()) {
221     return true;
222   }
223 
224   const std::vector<Token> tokens = tokenizer_.Tokenize(context);
225   for (int i = 0; i < tokens.size(); ++i) {
226     const Token token = tokens[i];
227     if (tokens[i].value.empty() ||
228         !unilib_->IsDigit(
229             *UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false).begin())) {
230       continue;
231     }
232 
233     const UnicodeText token_text =
234         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
235     int64 parsed_int_value;
236     double parsed_double_value;
237     bool is_negative =
238         (i > 0) &&
239         unilib_->IsMinus(
240             *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());
241     if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
242                         &parsed_double_value)) {
243       continue;
244     }
245     if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
246         !TokensAreValidNumberSuffix(tokens, i + 1)) {
247       continue;
248     }
249 
250     const bool has_decimal = !(parsed_int_value == parsed_double_value);
251     const int new_start_codepoint = is_negative ? token.start - 1 : token.start;
252 
253     if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
254         0) {
255       result->push_back(CreateAnnotatedSpan(
256           new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
257           Collections::Number(), options_->score(),
258           /*priority_score=*/
259           has_decimal ? options_->float_number_priority_score()
260                       : options_->priority_score()));
261     }
262 
263     const int percent_end_codepoint =
264         FindPercentSuffixEndCodepoint(tokens, i + 1);
265     if (percent_end_codepoint != -1 &&
266         ((1 << annotation_usecase) &
267          options_->percentage_annotation_usecases()) != 0) {
268       result->push_back(CreateAnnotatedSpan(
269           new_start_codepoint, percent_end_codepoint, parsed_int_value,
270           parsed_double_value, Collections::Percentage(), options_->score(),
271           options_->percentage_priority_score()));
272     }
273   }
274 
275   return true;
276 }
277 
CreateAnnotatedSpan(const int start,const int end,const int int_value,const double double_value,const std::string collection,const float score,const float priority_score) const278 AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
279     const int start, const int end, const int int_value,
280     const double double_value, const std::string collection, const float score,
281     const float priority_score) const {
282   ClassificationResult classification{collection, score};
283   classification.numeric_value = int_value;
284   classification.numeric_double_value = double_value;
285   classification.priority_score = priority_score;
286 
287   AnnotatedSpan annotated_span;
288   annotated_span.span = {start, end};
289   annotated_span.classification.push_back(classification);
290   return annotated_span;
291 }
292 
293 std::unordered_set<std::string>
FromFlatbufferStringToUnordredSet(const flatbuffers::String * flatbuffer_percent_strings)294 NumberAnnotator::FromFlatbufferStringToUnordredSet(
295     const flatbuffers::String* flatbuffer_percent_strings) {
296   std::unordered_set<std::string> strings_set;
297   if (flatbuffer_percent_strings == nullptr) {
298     return strings_set;
299   }
300 
301   const std::string percent_strings = flatbuffer_percent_strings->str();
302   for (StringPiece suffix : strings::Split(percent_strings, '\0')) {
303     std::string percent_suffix = suffix.ToString();
304     percent_suffix.erase(
305         std::remove_if(percent_suffix.begin(), percent_suffix.end(),
306                        [](unsigned char x) { return std::isspace(x); }),
307         percent_suffix.end());
308     strings_set.insert(percent_suffix);
309   }
310 
311   return strings_set;
312 }
313 
314 }  // namespace libtextclassifier3
315