1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/number/number.h"
18
19 #include <climits>
20 #include <cstdlib>
21
22 #include "annotator/collections.h"
23 #include "utils/base/logging.h"
24
25 namespace libtextclassifier3 {
26
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const27 bool NumberAnnotator::ClassifyText(
28 const UnicodeText& context, CodepointSpan selection_indices,
29 AnnotationUsecase annotation_usecase,
30 ClassificationResult* classification_result) const {
31 int64 parsed_value;
32 int num_prefix_codepoints;
33 int num_suffix_codepoints;
34 if (ParseNumber(UnicodeText::Substring(context, selection_indices.first,
35 selection_indices.second),
36 &parsed_value, &num_prefix_codepoints,
37 &num_suffix_codepoints)) {
38 ClassificationResult classification{Collections::Number(), 1.0};
39 TC3_CHECK(classification_result != nullptr);
40 classification_result->collection = Collections::Number();
41 classification_result->score = options_->score();
42 classification_result->priority_score = options_->priority_score();
43 classification_result->numeric_value = parsed_value;
44 return true;
45 }
46 return false;
47 }
48
FindAll(const UnicodeText & context,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * result) const49 bool NumberAnnotator::FindAll(const UnicodeText& context,
50 AnnotationUsecase annotation_usecase,
51 std::vector<AnnotatedSpan>* result) const {
52 if (!options_->enabled() || ((1 << annotation_usecase) &
53 options_->enabled_annotation_usecases()) == 0) {
54 return true;
55 }
56
57 const std::vector<Token> tokens = feature_processor_->Tokenize(context);
58 for (const Token& token : tokens) {
59 const UnicodeText token_text =
60 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
61 int64 parsed_value;
62 int num_prefix_codepoints;
63 int num_suffix_codepoints;
64 if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints,
65 &num_suffix_codepoints)) {
66 ClassificationResult classification{Collections::Number(),
67 options_->score()};
68 classification.numeric_value = parsed_value;
69 classification.priority_score = options_->priority_score();
70
71 AnnotatedSpan annotated_span;
72 annotated_span.span = {token.start + num_prefix_codepoints,
73 token.end - num_suffix_codepoints};
74 annotated_span.classification.push_back(classification);
75
76 result->push_back(annotated_span);
77 }
78 }
79
80 return true;
81 }
82
FlatbuffersVectorToSet(const flatbuffers::Vector<int32_t> * codepoints)83 std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
84 const flatbuffers::Vector<int32_t>* codepoints) {
85 if (codepoints == nullptr) {
86 return std::unordered_set<int>{};
87 }
88
89 std::unordered_set<int> result;
90 for (const int codepoint : *codepoints) {
91 result.insert(codepoint);
92 }
93 return result;
94 }
95
96 namespace {
ConsumeAndParseNumber(const UnicodeText::const_iterator & it_begin,const UnicodeText::const_iterator & it_end,int64 * result)97 UnicodeText::const_iterator ConsumeAndParseNumber(
98 const UnicodeText::const_iterator& it_begin,
99 const UnicodeText::const_iterator& it_end, int64* result) {
100 *result = 0;
101
102 // See if there's a sign in the beginning of the number.
103 int sign = 1;
104 auto it = it_begin;
105 if (it != it_end) {
106 if (*it == '-') {
107 ++it;
108 sign = -1;
109 } else if (*it == '+') {
110 ++it;
111 sign = 1;
112 }
113 }
114
115 while (it != it_end) {
116 if (*it >= '0' && *it <= '9') {
117 // When overflow is imminent we'll fail to parse the number.
118 if (*result > INT64_MAX / 10) {
119 return it_begin;
120 }
121 *result *= 10;
122 *result += *it - '0';
123 } else {
124 *result *= sign;
125 return it;
126 }
127
128 ++it;
129 }
130
131 *result *= sign;
132 return it_end;
133 }
134 } // namespace
135
ParseNumber(const UnicodeText & text,int64 * result,int * num_prefix_codepoints,int * num_suffix_codepoints) const136 bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
137 int* num_prefix_codepoints,
138 int* num_suffix_codepoints) const {
139 TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
140 num_suffix_codepoints != nullptr);
141 auto it = text.begin();
142 auto it_end = text.end();
143
144 // Strip boundary codepoints from both ends.
145 const CodepointSpan original_span{0, text.size_codepoints()};
146 const CodepointSpan stripped_span =
147 feature_processor_->StripBoundaryCodepoints(text, original_span);
148 const int num_stripped_end = (original_span.second - stripped_span.second);
149 std::advance(it, stripped_span.first);
150 std::advance(it_end, -num_stripped_end);
151
152 // Consume prefix codepoints.
153 *num_prefix_codepoints = stripped_span.first;
154 while (it != text.end()) {
155 if (allowed_prefix_codepoints_.find(*it) ==
156 allowed_prefix_codepoints_.end()) {
157 break;
158 }
159
160 ++it;
161 ++(*num_prefix_codepoints);
162 }
163
164 auto it_start = it;
165 it = ConsumeAndParseNumber(it, text.end(), result);
166 if (it == it_start) {
167 return false;
168 }
169
170 // Consume suffix codepoints.
171 bool valid_suffix = true;
172 *num_suffix_codepoints = 0;
173 while (it != it_end) {
174 if (allowed_suffix_codepoints_.find(*it) ==
175 allowed_suffix_codepoints_.end()) {
176 valid_suffix = false;
177 break;
178 }
179
180 ++it;
181 ++(*num_suffix_codepoints);
182 }
183 *num_suffix_codepoints += num_stripped_end;
184 return valid_suffix;
185 }
186
187 } // namespace libtextclassifier3
188