1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/duration/duration.h"
18
19 #include <climits>
20 #include <cstdlib>
21
22 #include "annotator/collections.h"
23 #include "annotator/types.h"
24 #include "utils/base/logging.h"
25 #include "utils/strings/numbers.h"
26 #include "utils/utf8/unicodetext.h"
27
28 namespace libtextclassifier3 {
29
30 using DurationUnit = internal::DurationUnit;
31
32 namespace internal {
33
34 namespace {
ToLowerString(const std::string & str,const UniLib * unilib)35 std::string ToLowerString(const std::string& str, const UniLib* unilib) {
36 return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
37 .ToUTF8String();
38 }
39
FillDurationUnitMap(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * expressions,DurationUnit duration_unit,std::unordered_map<std::string,DurationUnit> * target_map,const UniLib * unilib)40 void FillDurationUnitMap(
41 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
42 expressions,
43 DurationUnit duration_unit,
44 std::unordered_map<std::string, DurationUnit>* target_map,
45 const UniLib* unilib) {
46 if (expressions == nullptr) {
47 return;
48 }
49
50 for (const flatbuffers::String* expression_string : *expressions) {
51 (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
52 duration_unit;
53 }
54 }
55 } // namespace
56
BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions * options,const UniLib * unilib)57 std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
58 const DurationAnnotatorOptions* options, const UniLib* unilib) {
59 std::unordered_map<std::string, DurationUnit> mapping;
60 FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
61 unilib);
62 FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
63 unilib);
64 FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
65 unilib);
66 FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
67 &mapping, unilib);
68 FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
69 &mapping, unilib);
70 return mapping;
71 }
72
BuildStringSet(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * strings,const UniLib * unilib)73 std::unordered_set<std::string> BuildStringSet(
74 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
75 strings,
76 const UniLib* unilib) {
77 std::unordered_set<std::string> result;
78 if (strings == nullptr) {
79 return result;
80 }
81
82 for (const flatbuffers::String* string_value : *strings) {
83 result.insert(ToLowerString(string_value->c_str(), unilib));
84 }
85
86 return result;
87 }
88
BuildInt32Set(const flatbuffers::Vector<int32> * ints)89 std::unordered_set<int32> BuildInt32Set(
90 const flatbuffers::Vector<int32>* ints) {
91 std::unordered_set<int32> result;
92 if (ints == nullptr) {
93 return result;
94 }
95
96 for (const int32 int_value : *ints) {
97 result.insert(int_value);
98 }
99
100 return result;
101 }
102
103 } // namespace internal
104
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const105 bool DurationAnnotator::ClassifyText(
106 const UnicodeText& context, CodepointSpan selection_indices,
107 AnnotationUsecase annotation_usecase,
108 ClassificationResult* classification_result) const {
109 if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
110 (1 << annotation_usecase))) == 0) {
111 return false;
112 }
113
114 const UnicodeText selection =
115 UnicodeText::Substring(context, selection_indices.first,
116 selection_indices.second, /*do_copy=*/false);
117 const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
118
119 AnnotatedSpan annotated_span;
120 if (tokens.empty() ||
121 FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
122 tokens.size()) {
123 return false;
124 }
125
126 TC3_DCHECK(!annotated_span.classification.empty());
127
128 *classification_result = annotated_span.classification[0];
129 return true;
130 }
131
FindAll(const UnicodeText & context,const std::vector<Token> & tokens,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * results) const132 bool DurationAnnotator::FindAll(const UnicodeText& context,
133 const std::vector<Token>& tokens,
134 AnnotationUsecase annotation_usecase,
135 std::vector<AnnotatedSpan>* results) const {
136 if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
137 (1 << annotation_usecase))) == 0) {
138 return true;
139 }
140
141 for (int i = 0; i < tokens.size();) {
142 AnnotatedSpan span;
143 const int next_i = FindDurationStartingAt(context, tokens, i, &span);
144 if (next_i != i) {
145 results->push_back(span);
146 i = next_i;
147 } else {
148 i++;
149 }
150 }
151 return true;
152 }
153
FindDurationStartingAt(const UnicodeText & context,const std::vector<Token> & tokens,int start_token_index,AnnotatedSpan * result) const154 int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
155 const std::vector<Token>& tokens,
156 int start_token_index,
157 AnnotatedSpan* result) const {
158 CodepointIndex start_index = kInvalidIndex;
159 CodepointIndex end_index = kInvalidIndex;
160
161 bool has_quantity = false;
162 ParsedDurationAtom parsed_duration;
163
164 std::vector<ParsedDurationAtom> parsed_duration_atoms;
165
166 // This is the core algorithm for finding the duration expressions. It
167 // basically iterates over tokens and changes the state variables above as it
168 // goes.
169 int token_index;
170 int quantity_end_index;
171 for (token_index = start_token_index; token_index < tokens.size();
172 token_index++) {
173 const Token& token = tokens[token_index];
174
175 if (ParseQuantityToken(token, &parsed_duration)) {
176 has_quantity = true;
177 if (start_index == kInvalidIndex) {
178 start_index = token.start;
179 }
180 quantity_end_index = token.end;
181 } else if (((!options_->require_quantity() || has_quantity) &&
182 ParseDurationUnitToken(token, &parsed_duration.unit)) ||
183 ParseQuantityDurationUnitToken(token, &parsed_duration)) {
184 if (start_index == kInvalidIndex) {
185 start_index = token.start;
186 }
187 end_index = token.end;
188 parsed_duration_atoms.push_back(parsed_duration);
189 has_quantity = false;
190 parsed_duration = ParsedDurationAtom();
191 } else if (ParseFillerToken(token)) {
192 } else {
193 break;
194 }
195 }
196
197 if (parsed_duration_atoms.empty()) {
198 return start_token_index;
199 }
200
201 const bool parse_ended_without_unit_for_last_mentioned_quantity =
202 has_quantity;
203
204 ClassificationResult classification{Collections::Duration(),
205 options_->score()};
206 classification.priority_score = options_->priority_score();
207 classification.duration_ms =
208 ParsedDurationAtomsToMillis(parsed_duration_atoms);
209
210 // Process suffix expressions like "and half" that don't have the
211 // duration_unit explicitly mentioned.
212 if (parse_ended_without_unit_for_last_mentioned_quantity) {
213 if (parsed_duration.plus_half) {
214 end_index = quantity_end_index;
215 ParsedDurationAtom atom = ParsedDurationAtom::Half();
216 atom.unit = parsed_duration_atoms.rbegin()->unit;
217 classification.duration_ms += ParsedDurationAtomsToMillis({atom});
218 } else if (options_->enable_dangling_quantity_interpretation()) {
219 end_index = quantity_end_index;
220 // TODO(b/144752747) Add dangling quantity to duration_ms.
221 }
222 }
223
224 result->span = feature_processor_->StripBoundaryCodepoints(
225 context, {start_index, end_index});
226 result->classification.push_back(classification);
227 result->source = AnnotatedSpan::Source::DURATION;
228
229 return token_index;
230 }
231
ParsedDurationAtomsToMillis(const std::vector<ParsedDurationAtom> & atoms) const232 int64 DurationAnnotator::ParsedDurationAtomsToMillis(
233 const std::vector<ParsedDurationAtom>& atoms) const {
234 int64 result = 0;
235 for (auto atom : atoms) {
236 int multiplier;
237 switch (atom.unit) {
238 case DurationUnit::WEEK:
239 multiplier = 7 * 24 * 60 * 60 * 1000;
240 break;
241 case DurationUnit::DAY:
242 multiplier = 24 * 60 * 60 * 1000;
243 break;
244 case DurationUnit::HOUR:
245 multiplier = 60 * 60 * 1000;
246 break;
247 case DurationUnit::MINUTE:
248 multiplier = 60 * 1000;
249 break;
250 case DurationUnit::SECOND:
251 multiplier = 1000;
252 break;
253 case DurationUnit::UNKNOWN:
254 TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
255 return -1;
256 break;
257 }
258
259 int64 value = atom.value;
260 // This condition handles expressions like "an hour", where the quantity is
261 // not specified. In this case we assume quantity 1. Except for cases like
262 // "half hour".
263 if (value == 0 && !atom.plus_half) {
264 value = 1;
265 }
266 result += value * multiplier;
267 result += atom.plus_half * multiplier / 2;
268 }
269 return result;
270 }
271
ParseQuantityToken(const Token & token,ParsedDurationAtom * value) const272 bool DurationAnnotator::ParseQuantityToken(const Token& token,
273 ParsedDurationAtom* value) const {
274 if (token.value.empty()) {
275 return false;
276 }
277
278 std::string token_value_buffer;
279 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
280 token.value, &token_value_buffer);
281 const std::string& lowercase_token_value =
282 internal::ToLowerString(token_value, unilib_);
283
284 if (half_expressions_.find(lowercase_token_value) !=
285 half_expressions_.end()) {
286 value->plus_half = true;
287 return true;
288 }
289
290 int32 parsed_value;
291 if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
292 value->value = parsed_value;
293 return true;
294 }
295
296 return false;
297 }
298
ParseDurationUnitToken(const Token & token,DurationUnit * duration_unit) const299 bool DurationAnnotator::ParseDurationUnitToken(
300 const Token& token, DurationUnit* duration_unit) const {
301 std::string token_value_buffer;
302 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
303 token.value, &token_value_buffer);
304 const std::string& lowercase_token_value =
305 internal::ToLowerString(token_value, unilib_);
306
307 const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
308 if (it == token_value_to_duration_unit_.end()) {
309 return false;
310 }
311
312 *duration_unit = it->second;
313 return true;
314 }
315
ParseQuantityDurationUnitToken(const Token & token,ParsedDurationAtom * value) const316 bool DurationAnnotator::ParseQuantityDurationUnitToken(
317 const Token& token, ParsedDurationAtom* value) const {
318 if (token.value.empty()) {
319 return false;
320 }
321
322 Token sub_token;
323 bool has_quantity = false;
324 for (const char c : token.value) {
325 if (sub_token_separator_codepoints_.find(c) !=
326 sub_token_separator_codepoints_.end()) {
327 if (has_quantity || !ParseQuantityToken(sub_token, value)) {
328 return false;
329 }
330 has_quantity = true;
331
332 sub_token = Token();
333 } else {
334 sub_token.value += c;
335 }
336 }
337
338 return (!options_->require_quantity() || has_quantity) &&
339 ParseDurationUnitToken(sub_token, &(value->unit));
340 }
341
ParseFillerToken(const Token & token) const342 bool DurationAnnotator::ParseFillerToken(const Token& token) const {
343 std::string token_value_buffer;
344 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
345 token.value, &token_value_buffer);
346 const std::string& lowercase_token_value =
347 internal::ToLowerString(token_value, unilib_);
348
349 if (filler_expressions_.find(lowercase_token_value) ==
350 filler_expressions_.end()) {
351 return false;
352 }
353
354 return true;
355 }
356
357 } // namespace libtextclassifier3
358