1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/duration/duration.h"
18
19 #include <climits>
20 #include <cstdlib>
21
22 #include "annotator/collections.h"
23 #include "annotator/types.h"
24 #include "utils/base/logging.h"
25 #include "utils/base/macros.h"
26 #include "utils/strings/numbers.h"
27 #include "utils/utf8/unicodetext.h"
28
29 namespace libtextclassifier3 {
30
31 using DurationUnit = internal::DurationUnit;
32
33 namespace internal {
34
35 namespace {
ToLowerString(const std::string & str,const UniLib * unilib)36 std::string ToLowerString(const std::string& str, const UniLib* unilib) {
37 return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
38 .ToUTF8String();
39 }
40
FillDurationUnitMap(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * expressions,DurationUnit duration_unit,std::unordered_map<std::string,DurationUnit> * target_map,const UniLib * unilib)41 void FillDurationUnitMap(
42 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
43 expressions,
44 DurationUnit duration_unit,
45 std::unordered_map<std::string, DurationUnit>* target_map,
46 const UniLib* unilib) {
47 if (expressions == nullptr) {
48 return;
49 }
50
51 for (const flatbuffers::String* expression_string : *expressions) {
52 (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
53 duration_unit;
54 }
55 }
56 } // namespace
57
BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions * options,const UniLib * unilib)58 std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
59 const DurationAnnotatorOptions* options, const UniLib* unilib) {
60 std::unordered_map<std::string, DurationUnit> mapping;
61 FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
62 unilib);
63 FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
64 unilib);
65 FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
66 unilib);
67 FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
68 &mapping, unilib);
69 FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
70 &mapping, unilib);
71 return mapping;
72 }
73
BuildStringSet(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * strings,const UniLib * unilib)74 std::unordered_set<std::string> BuildStringSet(
75 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
76 strings,
77 const UniLib* unilib) {
78 std::unordered_set<std::string> result;
79 if (strings == nullptr) {
80 return result;
81 }
82
83 for (const flatbuffers::String* string_value : *strings) {
84 result.insert(ToLowerString(string_value->c_str(), unilib));
85 }
86
87 return result;
88 }
89
BuildInt32Set(const flatbuffers::Vector<int32> * ints)90 std::unordered_set<int32> BuildInt32Set(
91 const flatbuffers::Vector<int32>* ints) {
92 std::unordered_set<int32> result;
93 if (ints == nullptr) {
94 return result;
95 }
96
97 for (const int32 int_value : *ints) {
98 result.insert(int_value);
99 }
100
101 return result;
102 }
103
104 // Get the dangling quantity unit e.g. for 2 hours 10, 10 would have the unit
105 // "minute".
GetDanglingQuantityUnit(const DurationUnit main_unit)106 DurationUnit GetDanglingQuantityUnit(const DurationUnit main_unit) {
107 switch (main_unit) {
108 case DurationUnit::HOUR:
109 return DurationUnit::MINUTE;
110 case DurationUnit::MINUTE:
111 return DurationUnit::SECOND;
112 case DurationUnit::UNKNOWN:
113 TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
114 TC3_FALLTHROUGH_INTENDED;
115 case DurationUnit::WEEK:
116 case DurationUnit::DAY:
117 case DurationUnit::SECOND:
118 // We only support dangling units for hours and minutes.
119 return DurationUnit::UNKNOWN;
120 }
121 }
122 } // namespace internal
123
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const124 bool DurationAnnotator::ClassifyText(
125 const UnicodeText& context, CodepointSpan selection_indices,
126 AnnotationUsecase annotation_usecase,
127 ClassificationResult* classification_result) const {
128 if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
129 (1 << annotation_usecase))) == 0) {
130 return false;
131 }
132
133 const UnicodeText selection =
134 UnicodeText::Substring(context, selection_indices.first,
135 selection_indices.second, /*do_copy=*/false);
136 const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
137
138 AnnotatedSpan annotated_span;
139 if (tokens.empty() ||
140 FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
141 tokens.size()) {
142 return false;
143 }
144
145 TC3_DCHECK(!annotated_span.classification.empty());
146
147 *classification_result = annotated_span.classification[0];
148 return true;
149 }
150
FindAll(const UnicodeText & context,const std::vector<Token> & tokens,AnnotationUsecase annotation_usecase,std::vector<AnnotatedSpan> * results) const151 bool DurationAnnotator::FindAll(const UnicodeText& context,
152 const std::vector<Token>& tokens,
153 AnnotationUsecase annotation_usecase,
154 std::vector<AnnotatedSpan>* results) const {
155 if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
156 (1 << annotation_usecase))) == 0) {
157 return true;
158 }
159
160 for (int i = 0; i < tokens.size();) {
161 AnnotatedSpan span;
162 const int next_i = FindDurationStartingAt(context, tokens, i, &span);
163 if (next_i != i) {
164 results->push_back(span);
165 i = next_i;
166 } else {
167 i++;
168 }
169 }
170 return true;
171 }
172
FindDurationStartingAt(const UnicodeText & context,const std::vector<Token> & tokens,int start_token_index,AnnotatedSpan * result) const173 int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
174 const std::vector<Token>& tokens,
175 int start_token_index,
176 AnnotatedSpan* result) const {
177 CodepointIndex start_index = kInvalidIndex;
178 CodepointIndex end_index = kInvalidIndex;
179
180 bool has_quantity = false;
181 ParsedDurationAtom parsed_duration;
182
183 std::vector<ParsedDurationAtom> parsed_duration_atoms;
184
185 // This is the core algorithm for finding the duration expressions. It
186 // basically iterates over tokens and changes the state variables above as it
187 // goes.
188 int token_index;
189 int quantity_end_index;
190 for (token_index = start_token_index; token_index < tokens.size();
191 token_index++) {
192 const Token& token = tokens[token_index];
193
194 if (ParseQuantityToken(token, &parsed_duration)) {
195 has_quantity = true;
196 if (start_index == kInvalidIndex) {
197 start_index = token.start;
198 }
199 quantity_end_index = token.end;
200 } else if (((!options_->require_quantity() || has_quantity) &&
201 ParseDurationUnitToken(token, &parsed_duration.unit)) ||
202 ParseQuantityDurationUnitToken(token, &parsed_duration)) {
203 if (start_index == kInvalidIndex) {
204 start_index = token.start;
205 }
206 end_index = token.end;
207 parsed_duration_atoms.push_back(parsed_duration);
208 has_quantity = false;
209 parsed_duration = ParsedDurationAtom();
210 } else if (ParseFillerToken(token)) {
211 } else {
212 break;
213 }
214 }
215
216 if (parsed_duration_atoms.empty()) {
217 return start_token_index;
218 }
219
220 const bool parse_ended_without_unit_for_last_mentioned_quantity =
221 has_quantity;
222
223 if (parse_ended_without_unit_for_last_mentioned_quantity) {
224 const DurationUnit main_unit = parsed_duration_atoms.rbegin()->unit;
225 if (parsed_duration.plus_half) {
226 // Process "and half" suffix.
227 end_index = quantity_end_index;
228 ParsedDurationAtom atom = ParsedDurationAtom::Half();
229 atom.unit = main_unit;
230 parsed_duration_atoms.push_back(atom);
231 } else if (options_->enable_dangling_quantity_interpretation()) {
232 // Process dangling quantity.
233 ParsedDurationAtom atom;
234 atom.value = parsed_duration.value;
235 atom.unit = GetDanglingQuantityUnit(main_unit);
236 if (atom.unit != DurationUnit::UNKNOWN) {
237 end_index = quantity_end_index;
238 parsed_duration_atoms.push_back(atom);
239 }
240 }
241 }
242
243 ClassificationResult classification{Collections::Duration(),
244 options_->score()};
245 classification.priority_score = options_->priority_score();
246 classification.duration_ms =
247 ParsedDurationAtomsToMillis(parsed_duration_atoms);
248
249 result->span = feature_processor_->StripBoundaryCodepoints(
250 context, {start_index, end_index});
251 result->classification.push_back(classification);
252 result->source = AnnotatedSpan::Source::DURATION;
253
254 return token_index;
255 }
256
ParsedDurationAtomsToMillis(const std::vector<ParsedDurationAtom> & atoms) const257 int64 DurationAnnotator::ParsedDurationAtomsToMillis(
258 const std::vector<ParsedDurationAtom>& atoms) const {
259 int64 result = 0;
260 for (auto atom : atoms) {
261 int multiplier;
262 switch (atom.unit) {
263 case DurationUnit::WEEK:
264 multiplier = 7 * 24 * 60 * 60 * 1000;
265 break;
266 case DurationUnit::DAY:
267 multiplier = 24 * 60 * 60 * 1000;
268 break;
269 case DurationUnit::HOUR:
270 multiplier = 60 * 60 * 1000;
271 break;
272 case DurationUnit::MINUTE:
273 multiplier = 60 * 1000;
274 break;
275 case DurationUnit::SECOND:
276 multiplier = 1000;
277 break;
278 case DurationUnit::UNKNOWN:
279 TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
280 return -1;
281 break;
282 }
283
284 double value = atom.value;
285 // This condition handles expressions like "an hour", where the quantity is
286 // not specified. In this case we assume quantity 1. Except for cases like
287 // "half hour".
288 if (value == 0 && !atom.plus_half) {
289 value = 1;
290 }
291 result += value * multiplier;
292 result += atom.plus_half * multiplier / 2;
293 }
294 return result;
295 }
296
ParseQuantityToken(const Token & token,ParsedDurationAtom * value) const297 bool DurationAnnotator::ParseQuantityToken(const Token& token,
298 ParsedDurationAtom* value) const {
299 if (token.value.empty()) {
300 return false;
301 }
302
303 std::string token_value_buffer;
304 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
305 token.value, &token_value_buffer);
306 const std::string& lowercase_token_value =
307 internal::ToLowerString(token_value, unilib_);
308
309 if (half_expressions_.find(lowercase_token_value) !=
310 half_expressions_.end()) {
311 value->plus_half = true;
312 return true;
313 }
314
315 double parsed_value;
316 if (ParseDouble(lowercase_token_value.c_str(), &parsed_value)) {
317 value->value = parsed_value;
318 return true;
319 }
320
321 return false;
322 }
323
ParseDurationUnitToken(const Token & token,DurationUnit * duration_unit) const324 bool DurationAnnotator::ParseDurationUnitToken(
325 const Token& token, DurationUnit* duration_unit) const {
326 std::string token_value_buffer;
327 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
328 token.value, &token_value_buffer);
329 const std::string& lowercase_token_value =
330 internal::ToLowerString(token_value, unilib_);
331
332 const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
333 if (it == token_value_to_duration_unit_.end()) {
334 return false;
335 }
336
337 *duration_unit = it->second;
338 return true;
339 }
340
ParseQuantityDurationUnitToken(const Token & token,ParsedDurationAtom * value) const341 bool DurationAnnotator::ParseQuantityDurationUnitToken(
342 const Token& token, ParsedDurationAtom* value) const {
343 if (token.value.empty()) {
344 return false;
345 }
346
347 Token sub_token;
348 bool has_quantity = false;
349 for (const char c : token.value) {
350 if (sub_token_separator_codepoints_.find(c) !=
351 sub_token_separator_codepoints_.end()) {
352 if (has_quantity || !ParseQuantityToken(sub_token, value)) {
353 return false;
354 }
355 has_quantity = true;
356
357 sub_token = Token();
358 } else {
359 sub_token.value += c;
360 }
361 }
362
363 return (!options_->require_quantity() || has_quantity) &&
364 ParseDurationUnitToken(sub_token, &(value->unit));
365 }
366
ParseFillerToken(const Token & token) const367 bool DurationAnnotator::ParseFillerToken(const Token& token) const {
368 std::string token_value_buffer;
369 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
370 token.value, &token_value_buffer);
371 const std::string& lowercase_token_value =
372 internal::ToLowerString(token_value, unilib_);
373
374 if (filler_expressions_.find(lowercase_token_value) ==
375 filler_expressions_.end()) {
376 return false;
377 }
378
379 return true;
380 }
381
382 } // namespace libtextclassifier3
383