• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "annotator/number/number.h"

#include <algorithm>
#include <cctype>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"
29 
30 namespace libtextclassifier3 {
31 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const32 bool NumberAnnotator::ClassifyText(
33     const UnicodeText& context, CodepointSpan selection_indices,
34     AnnotationUsecase annotation_usecase,
35     ClassificationResult* classification_result) const {
36   TC3_CHECK(classification_result != nullptr);
37 
38   const UnicodeText substring_selected = UnicodeText::Substring(
39       context, selection_indices.first, selection_indices.second);
40 
41   std::vector<AnnotatedSpan> results;
42   if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION,
43                &results)) {
44     return false;
45   }
46 
47   for (const AnnotatedSpan& result : results) {
48     if (result.classification.empty()) {
49       continue;
50     }
51 
52     // We make sure that the result span is equal to the stripped selection span
53     // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
54     // anyway only find valid numbers and percentages and a given selection with
55     // more than two tokens won't pass this check.
56     if (result.span.first + selection_indices.first ==
57             selection_indices.first &&
58         result.span.second + selection_indices.first ==
59             selection_indices.second) {
60       *classification_result = result.classification[0];
61       return true;
62     }
63   }
64   return false;
65 }
66 
IsCJTterm(UnicodeText::const_iterator token_begin_it,const int token_length) const67 bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
68                                 const int token_length) const {
69   auto token_end_it = token_begin_it;
70   std::advance(token_end_it, token_length);
71   for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
72     if (!unilib_->IsCJTletter(*char_it)) {
73       return false;
74     }
75   }
76   return true;
77 }
78 
TokensAreValidStart(const std::vector<Token> & tokens,const int start_index) const79 bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
80                                           const int start_index) const {
81   if (start_index < 0 || tokens[start_index].is_whitespace) {
82     return true;
83   }
84   return false;
85 }
86 
TokensAreValidNumberPrefix(const std::vector<Token> & tokens,const int prefix_end_index) const87 bool NumberAnnotator::TokensAreValidNumberPrefix(
88     const std::vector<Token>& tokens, const int prefix_end_index) const {
89   if (TokensAreValidStart(tokens, prefix_end_index)) {
90     return true;
91   }
92 
93   auto prefix_begin_it =
94       UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
95           .begin();
96   const int token_length =
97       tokens[prefix_end_index].end - tokens[prefix_end_index].start;
98   if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
99       TokensAreValidStart(tokens, prefix_end_index - 1)) {
100     return true;
101   }
102   if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
103       TokensAreValidStart(tokens, prefix_end_index - 1)) {
104     return true;
105   }
106   if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
107       prefix_end_index >= 1 &&
108       TokensAreValidStart(tokens, prefix_end_index - 2)) {
109     int64 int_val;
110     double double_val;
111     return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
112                                             /*do_copy=*/false),
113                           false, &int_val, &double_val);
114   }
115   if (IsCJTterm(prefix_begin_it, token_length)) {
116     return true;
117   }
118 
119   return false;
120 }
121 
TokensAreValidEnding(const std::vector<Token> & tokens,const int ending_index) const122 bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
123                                            const int ending_index) const {
124   if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
125     return true;
126   }
127 
128   auto ending_begin_it =
129       UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
130   if (ending_index == tokens.size() - 1 &&
131       tokens[ending_index].end - tokens[ending_index].start == 1 &&
132       unilib_->IsPunctuation(*ending_begin_it)) {
133     return true;
134   }
135   if (ending_index < tokens.size() - 1 &&
136       tokens[ending_index].end - tokens[ending_index].start == 1 &&
137       unilib_->IsPunctuation(*ending_begin_it) &&
138       tokens[ending_index + 1].is_whitespace) {
139     return true;
140   }
141 
142   return false;
143 }
144 
TokensAreValidNumberSuffix(const std::vector<Token> & tokens,const int suffix_start_index) const145 bool NumberAnnotator::TokensAreValidNumberSuffix(
146     const std::vector<Token>& tokens, const int suffix_start_index) const {
147   if (TokensAreValidEnding(tokens, suffix_start_index)) {
148     return true;
149   }
150 
151   auto suffix_begin_it =
152       UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
153           .begin();
154 
155   if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
156           percent_suffixes_.end() &&
157       TokensAreValidEnding(tokens, suffix_start_index + 1)) {
158     return true;
159   }
160 
161   const int token_length =
162       tokens[suffix_start_index].end - tokens[suffix_start_index].start;
163   if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
164       suffix_start_index <= tokens.size() - 2 &&
165       TokensAreValidEnding(tokens, suffix_start_index + 2)) {
166     int64 int_val;
167     double double_val;
168     return TryParseNumber(
169         UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
170                           /*do_copy=*/false),
171         false, &int_val, &double_val);
172   }
173   if (IsCJTterm(suffix_begin_it, token_length)) {
174     return true;
175   }
176 
177   return false;
178 }
179 
FindPercentSuffixEndCodepoint(const std::vector<Token> & tokens,const int suffix_token_start_index) const180 int NumberAnnotator::FindPercentSuffixEndCodepoint(
181     const std::vector<Token>& tokens,
182     const int suffix_token_start_index) const {
183   if (suffix_token_start_index >= tokens.size()) {
184     return -1;
185   }
186 
187   if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
188           percent_suffixes_.end() &&
189       TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
190     return tokens[suffix_token_start_index].end;
191   }
192   if (tokens[suffix_token_start_index].is_whitespace) {
193     return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
194   }
195 
196   return -1;
197 }
198 
TryParseNumber(const UnicodeText & token_text,const bool is_negative,int64 * parsed_int_value,double * parsed_double_value) const199 bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
200                                      const bool is_negative,
201                                      int64* parsed_int_value,
202                                      double* parsed_double_value) const {
203   if (token_text.ToUTF8String().size() >= max_number_of_digits_) {
204     return false;
205   }
206   const bool is_double = unilib_->ParseDouble(token_text, parsed_double_value);
207   if (!is_double) {
208     return false;
209   }
210   *parsed_int_value = std::trunc(*parsed_double_value);
211   if (is_negative) {
212     *parsed_int_value *= -1;
213     *parsed_double_value *= -1;
214   }
215 
216   return true;
217 }
218 
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,ModeFlag mode,std::vector<AnnotatedSpan> * result) const219 bool NumberAnnotator::FindAll(const UnicodeText& context,
220                               AnnotationUsecase annotation_usecase,
221                               ModeFlag mode,
222                               std::vector<AnnotatedSpan>* result) const {
223   if (!options_->enabled() || !(options_->enabled_modes() & mode)) {
224     return true;
225   }
226 
227   const std::vector<Token> tokens = tokenizer_.Tokenize(context);
228   for (int i = 0; i < tokens.size(); ++i) {
229     const Token token = tokens[i];
230     if (tokens[i].value.empty() ||
231         !unilib_->IsDigit(
232             *UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false).begin())) {
233       continue;
234     }
235 
236     const UnicodeText token_text =
237         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
238     int64 parsed_int_value;
239     double parsed_double_value;
240     bool is_negative =
241         (i > 0) &&
242         unilib_->IsMinus(
243             *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());
244     if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
245                         &parsed_double_value)) {
246       continue;
247     }
248     if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
249         !TokensAreValidNumberSuffix(tokens, i + 1)) {
250       continue;
251     }
252 
253     const bool has_decimal = !(parsed_int_value == parsed_double_value);
254     const int new_start_codepoint = is_negative ? token.start - 1 : token.start;
255 
256     if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
257         0) {
258       result->push_back(CreateAnnotatedSpan(
259           new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
260           Collections::Number(), options_->score(),
261           /*priority_score=*/
262           has_decimal ? options_->float_number_priority_score()
263                       : options_->priority_score()));
264     }
265 
266     const int percent_end_codepoint =
267         FindPercentSuffixEndCodepoint(tokens, i + 1);
268     if (percent_end_codepoint != -1 &&
269         ((1 << annotation_usecase) &
270          options_->percentage_annotation_usecases()) != 0) {
271       result->push_back(CreateAnnotatedSpan(
272           new_start_codepoint, percent_end_codepoint, parsed_int_value,
273           parsed_double_value, Collections::Percentage(), options_->score(),
274           options_->percentage_priority_score()));
275     }
276   }
277 
278   return true;
279 }
280 
CreateAnnotatedSpan(const int start,const int end,const int int_value,const double double_value,const std::string collection,const float score,const float priority_score) const281 AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
282     const int start, const int end, const int int_value,
283     const double double_value, const std::string collection, const float score,
284     const float priority_score) const {
285   ClassificationResult classification{collection, score};
286   classification.numeric_value = int_value;
287   classification.numeric_double_value = double_value;
288   classification.priority_score = priority_score;
289 
290   AnnotatedSpan annotated_span;
291   annotated_span.span = {start, end};
292   annotated_span.classification.push_back(classification);
293   return annotated_span;
294 }
295 
296 std::unordered_set<std::string>
FromFlatbufferStringToUnordredSet(const flatbuffers::String * flatbuffer_percent_strings)297 NumberAnnotator::FromFlatbufferStringToUnordredSet(
298     const flatbuffers::String* flatbuffer_percent_strings) {
299   std::unordered_set<std::string> strings_set;
300   if (flatbuffer_percent_strings == nullptr) {
301     return strings_set;
302   }
303 
304   const std::string percent_strings = flatbuffer_percent_strings->str();
305   for (StringPiece suffix : strings::Split(percent_strings, '\0')) {
306     std::string percent_suffix = suffix.ToString();
307     percent_suffix.erase(
308         std::remove_if(percent_suffix.begin(), percent_suffix.end(),
309                        [](unsigned char x) { return std::isspace(x); }),
310         percent_suffix.end());
311     strings_set.insert(percent_suffix);
312   }
313 
314   return strings_set;
315 }
316 
317 }  // namespace libtextclassifier3
318