• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/duration/duration.h"
18 
19 #include <climits>
20 #include <cstdlib>
21 
22 #include "annotator/collections.h"
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/base/logging.h"
26 #include "utils/base/macros.h"
27 #include "utils/strings/numbers.h"
28 #include "utils/utf8/unicodetext.h"
29 
30 namespace libtextclassifier3 {
31 
32 using DurationUnit = internal::DurationUnit;
33 
34 namespace internal {
35 
36 namespace {
ToLowerString(const std::string & str,const UniLib * unilib)37 std::string ToLowerString(const std::string& str, const UniLib* unilib) {
38   return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
39       .ToUTF8String();
40 }
41 
FillDurationUnitMap(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * expressions,DurationUnit duration_unit,std::unordered_map<std::string,DurationUnit> * target_map,const UniLib * unilib)42 void FillDurationUnitMap(
43     const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
44         expressions,
45     DurationUnit duration_unit,
46     std::unordered_map<std::string, DurationUnit>* target_map,
47     const UniLib* unilib) {
48   if (expressions == nullptr) {
49     return;
50   }
51 
52   for (const flatbuffers::String* expression_string : *expressions) {
53     (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
54         duration_unit;
55   }
56 }
57 }  // namespace
58 
BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions * options,const UniLib * unilib)59 std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
60     const DurationAnnotatorOptions* options, const UniLib* unilib) {
61   std::unordered_map<std::string, DurationUnit> mapping;
62   FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
63                       unilib);
64   FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
65                       unilib);
66   FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
67                       unilib);
68   FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
69                       &mapping, unilib);
70   FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
71                       &mapping, unilib);
72   return mapping;
73 }
74 
BuildStringSet(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * strings,const UniLib * unilib)75 std::unordered_set<std::string> BuildStringSet(
76     const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
77         strings,
78     const UniLib* unilib) {
79   std::unordered_set<std::string> result;
80   if (strings == nullptr) {
81     return result;
82   }
83 
84   for (const flatbuffers::String* string_value : *strings) {
85     result.insert(ToLowerString(string_value->c_str(), unilib));
86   }
87 
88   return result;
89 }
90 
BuildInt32Set(const flatbuffers::Vector<int32> * ints)91 std::unordered_set<int32> BuildInt32Set(
92     const flatbuffers::Vector<int32>* ints) {
93   std::unordered_set<int32> result;
94   if (ints == nullptr) {
95     return result;
96   }
97 
98   for (const int32 int_value : *ints) {
99     result.insert(int_value);
100   }
101 
102   return result;
103 }
104 
105 // Get the dangling quantity unit e.g. for 2 hours 10, 10 would have the unit
106 // "minute".
GetDanglingQuantityUnit(const DurationUnit main_unit)107 DurationUnit GetDanglingQuantityUnit(const DurationUnit main_unit) {
108   switch (main_unit) {
109     case DurationUnit::HOUR:
110       return DurationUnit::MINUTE;
111     case DurationUnit::MINUTE:
112       return DurationUnit::SECOND;
113     case DurationUnit::UNKNOWN:
114       TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
115       TC3_FALLTHROUGH_INTENDED;
116     case DurationUnit::WEEK:
117     case DurationUnit::DAY:
118     case DurationUnit::SECOND:
119       // We only support dangling units for hours and minutes.
120       return DurationUnit::UNKNOWN;
121   }
122 }
123 }  // namespace internal
124 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,AnnotationUsecase annotation_usecase,ClassificationResult * classification_result) const125 bool DurationAnnotator::ClassifyText(
126     const UnicodeText& context, CodepointSpan selection_indices,
127     AnnotationUsecase annotation_usecase,
128     ClassificationResult* classification_result) const {
129   if (!options_->enabled() ||
130       ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
131           0 ||
132       !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
133     return false;
134   }
135 
136   const UnicodeText selection =
137       UnicodeText::Substring(context, selection_indices.first,
138                              selection_indices.second, /*do_copy=*/false);
139   const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
140 
141   AnnotatedSpan annotated_span;
142   if (tokens.empty() ||
143       FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
144           tokens.size()) {
145     return false;
146   }
147 
148   TC3_DCHECK(!annotated_span.classification.empty());
149 
150   *classification_result = annotated_span.classification[0];
151   return true;
152 }
153 
FindAll(const UnicodeText & context,const std::vector<Token> & tokens,AnnotationUsecase annotation_usecase,ModeFlag mode,std::vector<AnnotatedSpan> * results) const154 bool DurationAnnotator::FindAll(const UnicodeText& context,
155                                 const std::vector<Token>& tokens,
156                                 AnnotationUsecase annotation_usecase,
157                                 ModeFlag mode,
158                                 std::vector<AnnotatedSpan>* results) const {
159   if (!options_->enabled() ||
160       ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
161           0 ||
162       !(options_->enabled_modes() & mode)) {
163     return true;
164   }
165 
166   for (int i = 0; i < tokens.size();) {
167     AnnotatedSpan span;
168     const int next_i = FindDurationStartingAt(context, tokens, i, &span);
169     if (next_i != i) {
170       results->push_back(span);
171       i = next_i;
172     } else {
173       i++;
174     }
175   }
176   return true;
177 }
178 
FindDurationStartingAt(const UnicodeText & context,const std::vector<Token> & tokens,int start_token_index,AnnotatedSpan * result) const179 int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
180                                               const std::vector<Token>& tokens,
181                                               int start_token_index,
182                                               AnnotatedSpan* result) const {
183   CodepointIndex start_index = kInvalidIndex;
184   CodepointIndex end_index = kInvalidIndex;
185 
186   bool has_quantity = false;
187   ParsedDurationAtom parsed_duration;
188 
189   std::vector<ParsedDurationAtom> parsed_duration_atoms;
190 
191   // This is the core algorithm for finding the duration expressions. It
192   // basically iterates over tokens and changes the state variables above as it
193   // goes.
194   int token_index;
195   int quantity_end_index;
196   for (token_index = start_token_index; token_index < tokens.size();
197        token_index++) {
198     const Token& token = tokens[token_index];
199 
200     if (ParseQuantityToken(token, &parsed_duration)) {
201       has_quantity = true;
202       if (start_index == kInvalidIndex) {
203         start_index = token.start;
204       }
205       quantity_end_index = token.end;
206     } else if (((!options_->require_quantity() || has_quantity) &&
207                 ParseDurationUnitToken(token, &parsed_duration.unit)) ||
208                ParseQuantityDurationUnitToken(token, &parsed_duration)) {
209       if (start_index == kInvalidIndex) {
210         start_index = token.start;
211       }
212       end_index = token.end;
213       parsed_duration_atoms.push_back(parsed_duration);
214       has_quantity = false;
215       parsed_duration = ParsedDurationAtom();
216     } else if (ParseFillerToken(token)) {
217     } else {
218       break;
219     }
220   }
221 
222   if (parsed_duration_atoms.empty()) {
223     return start_token_index;
224   }
225 
226   const bool parse_ended_without_unit_for_last_mentioned_quantity =
227       has_quantity;
228 
229   if (parse_ended_without_unit_for_last_mentioned_quantity) {
230     const DurationUnit main_unit = parsed_duration_atoms.rbegin()->unit;
231     if (parsed_duration.plus_half) {
232       // Process "and half" suffix.
233       end_index = quantity_end_index;
234       ParsedDurationAtom atom = ParsedDurationAtom::Half();
235       atom.unit = main_unit;
236       parsed_duration_atoms.push_back(atom);
237     } else if (options_->enable_dangling_quantity_interpretation()) {
238       // Process dangling quantity.
239       ParsedDurationAtom atom;
240       atom.value = parsed_duration.value;
241       atom.unit = GetDanglingQuantityUnit(main_unit);
242       if (atom.unit != DurationUnit::UNKNOWN) {
243         end_index = quantity_end_index;
244         parsed_duration_atoms.push_back(atom);
245       }
246     }
247   }
248 
249   ClassificationResult classification{Collections::Duration(),
250                                       options_->score()};
251   classification.priority_score = options_->priority_score();
252   classification.duration_ms =
253       ParsedDurationAtomsToMillis(parsed_duration_atoms);
254 
255   result->span = feature_processor_->StripBoundaryCodepoints(
256       context, {start_index, end_index});
257   result->classification.push_back(classification);
258   result->source = AnnotatedSpan::Source::DURATION;
259 
260   return token_index;
261 }
262 
ParsedDurationAtomsToMillis(const std::vector<ParsedDurationAtom> & atoms) const263 int64 DurationAnnotator::ParsedDurationAtomsToMillis(
264     const std::vector<ParsedDurationAtom>& atoms) const {
265   int64 result = 0;
266   for (auto atom : atoms) {
267     int multiplier;
268     switch (atom.unit) {
269       case DurationUnit::WEEK:
270         multiplier = 7 * 24 * 60 * 60 * 1000;
271         break;
272       case DurationUnit::DAY:
273         multiplier = 24 * 60 * 60 * 1000;
274         break;
275       case DurationUnit::HOUR:
276         multiplier = 60 * 60 * 1000;
277         break;
278       case DurationUnit::MINUTE:
279         multiplier = 60 * 1000;
280         break;
281       case DurationUnit::SECOND:
282         multiplier = 1000;
283         break;
284       case DurationUnit::UNKNOWN:
285         TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
286         return -1;
287         break;
288     }
289 
290     double value = atom.value;
291     // This condition handles expressions like "an hour", where the quantity is
292     // not specified. In this case we assume quantity 1. Except for cases like
293     // "half hour".
294     if (value == 0 && !atom.plus_half) {
295       value = 1;
296     }
297     result += value * multiplier;
298     result += atom.plus_half * multiplier / 2;
299   }
300   return result;
301 }
302 
ParseQuantityToken(const Token & token,ParsedDurationAtom * value) const303 bool DurationAnnotator::ParseQuantityToken(const Token& token,
304                                            ParsedDurationAtom* value) const {
305   if (token.value.empty()) {
306     return false;
307   }
308 
309   std::string token_value_buffer;
310   const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
311       token.value, &token_value_buffer);
312   const std::string& lowercase_token_value =
313       internal::ToLowerString(token_value, unilib_);
314 
315   if (half_expressions_.find(lowercase_token_value) !=
316       half_expressions_.end()) {
317     value->plus_half = true;
318     return true;
319   }
320 
321   double parsed_value;
322   if (ParseDouble(lowercase_token_value.c_str(), &parsed_value)) {
323     value->value = parsed_value;
324     return true;
325   }
326 
327   return false;
328 }
329 
ParseDurationUnitToken(const Token & token,DurationUnit * duration_unit) const330 bool DurationAnnotator::ParseDurationUnitToken(
331     const Token& token, DurationUnit* duration_unit) const {
332   std::string token_value_buffer;
333   const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
334       token.value, &token_value_buffer);
335   const std::string& lowercase_token_value =
336       internal::ToLowerString(token_value, unilib_);
337 
338   const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
339   if (it == token_value_to_duration_unit_.end()) {
340     return false;
341   }
342 
343   *duration_unit = it->second;
344   return true;
345 }
346 
ParseQuantityDurationUnitToken(const Token & token,ParsedDurationAtom * value) const347 bool DurationAnnotator::ParseQuantityDurationUnitToken(
348     const Token& token, ParsedDurationAtom* value) const {
349   if (token.value.empty()) {
350     return false;
351   }
352 
353   Token sub_token;
354   bool has_quantity = false;
355   for (const char c : token.value) {
356     if (sub_token_separator_codepoints_.find(c) !=
357         sub_token_separator_codepoints_.end()) {
358       if (has_quantity || !ParseQuantityToken(sub_token, value)) {
359         return false;
360       }
361       has_quantity = true;
362 
363       sub_token = Token();
364     } else {
365       sub_token.value += c;
366     }
367   }
368 
369   return (!options_->require_quantity() || has_quantity) &&
370          ParseDurationUnitToken(sub_token, &(value->unit));
371 }
372 
ParseFillerToken(const Token & token) const373 bool DurationAnnotator::ParseFillerToken(const Token& token) const {
374   std::string token_value_buffer;
375   const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
376       token.value, &token_value_buffer);
377   const std::string& lowercase_token_value =
378       internal::ToLowerString(token_value, unilib_);
379 
380   if (filler_expressions_.find(lowercase_token_value) ==
381       filler_expressions_.end()) {
382     return false;
383   }
384 
385   return true;
386 }
387 
388 }  // namespace libtextclassifier3
389