1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/duration/duration.h"
18
19 #include <climits>
20 #include <cstdlib>
21
22 #include "annotator/collections.h"
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "utils/base/macros.h"
27 #include "utils/strings/numbers.h"
28 #include "utils/utf8/unicodetext.h"
29
30 namespace libtextclassifier3 {
31
32 using DurationUnit = internal::DurationUnit;
33
34 namespace internal {
35
36 namespace {
ToLowerString(const std::string & str,const UniLib * unilib)37 std::string ToLowerString(const std::string& str, const UniLib* unilib) {
38 return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
39 .ToUTF8String();
40 }
41
FillDurationUnitMap(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * expressions,DurationUnit duration_unit,std::unordered_map<std::string,DurationUnit> * target_map,const UniLib * unilib)42 void FillDurationUnitMap(
43 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
44 expressions,
45 DurationUnit duration_unit,
46 std::unordered_map<std::string, DurationUnit>* target_map,
47 const UniLib* unilib) {
48 if (expressions == nullptr) {
49 return;
50 }
51
52 for (const flatbuffers::String* expression_string : *expressions) {
53 (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
54 duration_unit;
55 }
56 }
57 } // namespace
58
BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions * options,const UniLib * unilib)59 std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
60 const DurationAnnotatorOptions* options, const UniLib* unilib) {
61 std::unordered_map<std::string, DurationUnit> mapping;
62 FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
63 unilib);
64 FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
65 unilib);
66 FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
67 unilib);
68 FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
69 &mapping, unilib);
70 FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
71 &mapping, unilib);
72 return mapping;
73 }
74
BuildStringSet(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * strings,const UniLib * unilib)75 std::unordered_set<std::string> BuildStringSet(
76 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
77 strings,
78 const UniLib* unilib) {
79 std::unordered_set<std::string> result;
80 if (strings == nullptr) {
81 return result;
82 }
83
84 for (const flatbuffers::String* string_value : *strings) {
85 result.insert(ToLowerString(string_value->c_str(), unilib));
86 }
87
88 return result;
89 }
90
BuildInt32Set(const flatbuffers::Vector<int32> * ints)91 std::unordered_set<int32> BuildInt32Set(
92 const flatbuffers::Vector<int32>* ints) {
93 std::unordered_set<int32> result;
94 if (ints == nullptr) {
95 return result;
96 }
97
98 for (const int32 int_value : *ints) {
99 result.insert(int_value);
100 }
101
102 return result;
103 }
104
105 // Get the dangling quantity unit e.g. for 2 hours 10, 10 would have the unit
106 // "minute".
GetDanglingQuantityUnit(const DurationUnit main_unit)107 DurationUnit GetDanglingQuantityUnit(const DurationUnit main_unit) {
108 switch (main_unit) {
109 case DurationUnit::HOUR:
110 return DurationUnit::MINUTE;
111 case DurationUnit::MINUTE:
112 return DurationUnit::SECOND;
113 case DurationUnit::UNKNOWN:
114 TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
115 TC3_FALLTHROUGH_INTENDED;
116 case DurationUnit::WEEK:
117 case DurationUnit::DAY:
118 case DurationUnit::SECOND:
119 // We only support dangling units for hours and minutes.
120 return DurationUnit::UNKNOWN;
121 }
122 }
123 } // namespace internal
124
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const125 bool DurationAnnotator::ClassifyText(
126 const UnicodeText& context, CodepointSpan selection_indices,
127 AnnotationUsecase annotation_usecase,
128 ClassificationResult* classification_result) const {
129 if (!options_->enabled() ||
130 ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
131 0 ||
132 !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
133 return false;
134 }
135
136 const UnicodeText selection =
137 UnicodeText::Substring(context, selection_indices.first,
138 selection_indices.second, /*do_copy=*/false);
139 const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
140
141 AnnotatedSpan annotated_span;
142 if (tokens.empty() ||
143 FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
144 tokens.size()) {
145 return false;
146 }
147
148 TC3_DCHECK(!annotated_span.classification.empty());
149
150 *classification_result = annotated_span.classification[0];
151 return true;
152 }
153
FindAll(const UnicodeText & context,const std::vector<Token> & tokens,AnnotationUsecase annotation_usecase,ModeFlag mode,std::vector<AnnotatedSpan> * results) const154 bool DurationAnnotator::FindAll(const UnicodeText& context,
155 const std::vector<Token>& tokens,
156 AnnotationUsecase annotation_usecase,
157 ModeFlag mode,
158 std::vector<AnnotatedSpan>* results) const {
159 if (!options_->enabled() ||
160 ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
161 0 ||
162 !(options_->enabled_modes() & mode)) {
163 return true;
164 }
165
166 for (int i = 0; i < tokens.size();) {
167 AnnotatedSpan span;
168 const int next_i = FindDurationStartingAt(context, tokens, i, &span);
169 if (next_i != i) {
170 results->push_back(span);
171 i = next_i;
172 } else {
173 i++;
174 }
175 }
176 return true;
177 }
178
FindDurationStartingAt(const UnicodeText & context,const std::vector<Token> & tokens,int start_token_index,AnnotatedSpan * result) const179 int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
180 const std::vector<Token>& tokens,
181 int start_token_index,
182 AnnotatedSpan* result) const {
183 CodepointIndex start_index = kInvalidIndex;
184 CodepointIndex end_index = kInvalidIndex;
185
186 bool has_quantity = false;
187 ParsedDurationAtom parsed_duration;
188
189 std::vector<ParsedDurationAtom> parsed_duration_atoms;
190
191 // This is the core algorithm for finding the duration expressions. It
192 // basically iterates over tokens and changes the state variables above as it
193 // goes.
194 int token_index;
195 int quantity_end_index;
196 for (token_index = start_token_index; token_index < tokens.size();
197 token_index++) {
198 const Token& token = tokens[token_index];
199
200 if (ParseQuantityToken(token, &parsed_duration)) {
201 has_quantity = true;
202 if (start_index == kInvalidIndex) {
203 start_index = token.start;
204 }
205 quantity_end_index = token.end;
206 } else if (((!options_->require_quantity() || has_quantity) &&
207 ParseDurationUnitToken(token, &parsed_duration.unit)) ||
208 ParseQuantityDurationUnitToken(token, &parsed_duration)) {
209 if (start_index == kInvalidIndex) {
210 start_index = token.start;
211 }
212 end_index = token.end;
213 parsed_duration_atoms.push_back(parsed_duration);
214 has_quantity = false;
215 parsed_duration = ParsedDurationAtom();
216 } else if (ParseFillerToken(token)) {
217 } else {
218 break;
219 }
220 }
221
222 if (parsed_duration_atoms.empty()) {
223 return start_token_index;
224 }
225
226 const bool parse_ended_without_unit_for_last_mentioned_quantity =
227 has_quantity;
228
229 if (parse_ended_without_unit_for_last_mentioned_quantity) {
230 const DurationUnit main_unit = parsed_duration_atoms.rbegin()->unit;
231 if (parsed_duration.plus_half) {
232 // Process "and half" suffix.
233 end_index = quantity_end_index;
234 ParsedDurationAtom atom = ParsedDurationAtom::Half();
235 atom.unit = main_unit;
236 parsed_duration_atoms.push_back(atom);
237 } else if (options_->enable_dangling_quantity_interpretation()) {
238 // Process dangling quantity.
239 ParsedDurationAtom atom;
240 atom.value = parsed_duration.value;
241 atom.unit = GetDanglingQuantityUnit(main_unit);
242 if (atom.unit != DurationUnit::UNKNOWN) {
243 end_index = quantity_end_index;
244 parsed_duration_atoms.push_back(atom);
245 }
246 }
247 }
248
249 ClassificationResult classification{Collections::Duration(),
250 options_->score()};
251 classification.priority_score = options_->priority_score();
252 classification.duration_ms =
253 ParsedDurationAtomsToMillis(parsed_duration_atoms);
254
255 result->span = feature_processor_->StripBoundaryCodepoints(
256 context, {start_index, end_index});
257 result->classification.push_back(classification);
258 result->source = AnnotatedSpan::Source::DURATION;
259
260 return token_index;
261 }
262
ParsedDurationAtomsToMillis(const std::vector<ParsedDurationAtom> & atoms) const263 int64 DurationAnnotator::ParsedDurationAtomsToMillis(
264 const std::vector<ParsedDurationAtom>& atoms) const {
265 int64 result = 0;
266 for (auto atom : atoms) {
267 int multiplier;
268 switch (atom.unit) {
269 case DurationUnit::WEEK:
270 multiplier = 7 * 24 * 60 * 60 * 1000;
271 break;
272 case DurationUnit::DAY:
273 multiplier = 24 * 60 * 60 * 1000;
274 break;
275 case DurationUnit::HOUR:
276 multiplier = 60 * 60 * 1000;
277 break;
278 case DurationUnit::MINUTE:
279 multiplier = 60 * 1000;
280 break;
281 case DurationUnit::SECOND:
282 multiplier = 1000;
283 break;
284 case DurationUnit::UNKNOWN:
285 TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
286 return -1;
287 break;
288 }
289
290 double value = atom.value;
291 // This condition handles expressions like "an hour", where the quantity is
292 // not specified. In this case we assume quantity 1. Except for cases like
293 // "half hour".
294 if (value == 0 && !atom.plus_half) {
295 value = 1;
296 }
297 result += value * multiplier;
298 result += atom.plus_half * multiplier / 2;
299 }
300 return result;
301 }
302
ParseQuantityToken(const Token & token,ParsedDurationAtom * value) const303 bool DurationAnnotator::ParseQuantityToken(const Token& token,
304 ParsedDurationAtom* value) const {
305 if (token.value.empty()) {
306 return false;
307 }
308
309 std::string token_value_buffer;
310 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
311 token.value, &token_value_buffer);
312 const std::string& lowercase_token_value =
313 internal::ToLowerString(token_value, unilib_);
314
315 if (half_expressions_.find(lowercase_token_value) !=
316 half_expressions_.end()) {
317 value->plus_half = true;
318 return true;
319 }
320
321 double parsed_value;
322 if (ParseDouble(lowercase_token_value.c_str(), &parsed_value)) {
323 value->value = parsed_value;
324 return true;
325 }
326
327 return false;
328 }
329
ParseDurationUnitToken(const Token & token,DurationUnit * duration_unit) const330 bool DurationAnnotator::ParseDurationUnitToken(
331 const Token& token, DurationUnit* duration_unit) const {
332 std::string token_value_buffer;
333 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
334 token.value, &token_value_buffer);
335 const std::string& lowercase_token_value =
336 internal::ToLowerString(token_value, unilib_);
337
338 const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
339 if (it == token_value_to_duration_unit_.end()) {
340 return false;
341 }
342
343 *duration_unit = it->second;
344 return true;
345 }
346
ParseQuantityDurationUnitToken(const Token & token,ParsedDurationAtom * value) const347 bool DurationAnnotator::ParseQuantityDurationUnitToken(
348 const Token& token, ParsedDurationAtom* value) const {
349 if (token.value.empty()) {
350 return false;
351 }
352
353 Token sub_token;
354 bool has_quantity = false;
355 for (const char c : token.value) {
356 if (sub_token_separator_codepoints_.find(c) !=
357 sub_token_separator_codepoints_.end()) {
358 if (has_quantity || !ParseQuantityToken(sub_token, value)) {
359 return false;
360 }
361 has_quantity = true;
362
363 sub_token = Token();
364 } else {
365 sub_token.value += c;
366 }
367 }
368
369 return (!options_->require_quantity() || has_quantity) &&
370 ParseDurationUnitToken(sub_token, &(value->unit));
371 }
372
ParseFillerToken(const Token & token) const373 bool DurationAnnotator::ParseFillerToken(const Token& token) const {
374 std::string token_value_buffer;
375 const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
376 token.value, &token_value_buffer);
377 const std::string& lowercase_token_value =
378 internal::ToLowerString(token_value, unilib_);
379
380 if (filler_expressions_.find(lowercase_token_value) ==
381 filler_expressions_.end()) {
382 return false;
383 }
384
385 return true;
386 }
387
388 } // namespace libtextclassifier3
389