• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/datetime/regex-parser.h"
18 
19 #include <iterator>
20 #include <set>
21 #include <unordered_set>
22 
23 #include "annotator/datetime/extractor.h"
24 #include "annotator/datetime/utils.h"
25 #include "utils/base/statusor.h"
26 #include "utils/calendar/calendar.h"
27 #include "utils/i18n/locale.h"
28 #include "utils/strings/split.h"
29 #include "utils/zlib/zlib_regex.h"
30 
31 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)32 std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
33     const DatetimeModel* model, const UniLib* unilib,
34     const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
35   std::unique_ptr<RegexDatetimeParser> result(
36       new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
37   if (!result->initialized_) {
38     result.reset();
39   }
40   return result;
41 }
42 
RegexDatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)43 RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
44                                          const UniLib* unilib,
45                                          const CalendarLib* calendarlib,
46                                          ZlibDecompressor* decompressor)
47     : unilib_(*unilib), calendarlib_(*calendarlib) {
48   initialized_ = false;
49 
50   if (model == nullptr) {
51     return;
52   }
53 
54   if (model->patterns() != nullptr) {
55     for (const DatetimeModelPattern* pattern : *model->patterns()) {
56       if (pattern->regexes()) {
57         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
58           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
59               UncompressMakeRegexPattern(
60                   unilib_, regex->pattern(), regex->compressed_pattern(),
61                   model->lazy_regex_compilation(), decompressor);
62           if (!regex_pattern) {
63             TC3_LOG(ERROR) << "Couldn't create rule pattern.";
64             return;
65           }
66           rules_.push_back({std::move(regex_pattern), regex, pattern});
67           if (pattern->locales()) {
68             for (int locale : *pattern->locales()) {
69               locale_to_rules_[locale].push_back(rules_.size() - 1);
70             }
71           }
72         }
73       }
74     }
75   }
76 
77   if (model->extractors() != nullptr) {
78     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
79       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
80           UncompressMakeRegexPattern(
81               unilib_, extractor->pattern(), extractor->compressed_pattern(),
82               model->lazy_regex_compilation(), decompressor);
83       if (!regex_pattern) {
84         TC3_LOG(ERROR) << "Couldn't create extractor pattern";
85         return;
86       }
87       extractor_rules_.push_back(std::move(regex_pattern));
88 
89       if (extractor->locales()) {
90         for (int locale : *extractor->locales()) {
91           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
92               extractor_rules_.size() - 1;
93         }
94       }
95     }
96   }
97 
98   if (model->locales() != nullptr) {
99     for (int i = 0; i < model->locales()->size(); ++i) {
100       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
101     }
102   }
103 
104   if (model->default_locales() != nullptr) {
105     for (const int locale : *model->default_locales()) {
106       default_locale_ids_.push_back(locale);
107     }
108   }
109 
110   use_extractors_for_locating_ = model->use_extractors_for_locating();
111   generate_alternative_interpretations_when_ambiguous_ =
112       model->generate_alternative_interpretations_when_ambiguous();
113   prefer_future_for_unspecified_date_ =
114       model->prefer_future_for_unspecified_date();
115 
116   initialized_ = true;
117 }
118 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const119 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
120     const std::string& input, const int64 reference_time_ms_utc,
121     const std::string& reference_timezone, const LocaleList& locale_list,
122     ModeFlag mode, AnnotationUsecase annotation_usecase,
123     bool anchor_start_end) const {
124   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
125                reference_time_ms_utc, reference_timezone, locale_list, mode,
126                annotation_usecase, anchor_start_end);
127 }
128 
129 StatusOr<std::vector<DatetimeParseResultSpan>>
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules) const130 RegexDatetimeParser::FindSpansUsingLocales(
131     const std::vector<int>& locale_ids, const UnicodeText& input,
132     const int64 reference_time_ms_utc, const std::string& reference_timezone,
133     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
134     const std::string& reference_locale,
135     std::unordered_set<int>* executed_rules) const {
136   std::vector<DatetimeParseResultSpan> found_spans;
137   for (const int locale_id : locale_ids) {
138     auto rules_it = locale_to_rules_.find(locale_id);
139     if (rules_it == locale_to_rules_.end()) {
140       continue;
141     }
142 
143     for (const int rule_id : rules_it->second) {
144       // Skip rules that were already executed in previous locales.
145       if (executed_rules->find(rule_id) != executed_rules->end()) {
146         continue;
147       }
148 
149       if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
150            (1 << annotation_usecase)) == 0) {
151         continue;
152       }
153 
154       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
155         continue;
156       }
157 
158       executed_rules->insert(rule_id);
159       TC3_ASSIGN_OR_RETURN(
160           const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
161           ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
162                         reference_timezone, reference_locale, locale_id,
163                         anchor_start_end));
164       found_spans.insert(std::end(found_spans),
165                          std::begin(found_spans_per_rule),
166                          std::end(found_spans_per_rule));
167     }
168   }
169   return found_spans;
170 }
171 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const172 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
173     const UnicodeText& input, const int64 reference_time_ms_utc,
174     const std::string& reference_timezone, const LocaleList& locale_list,
175     ModeFlag mode, AnnotationUsecase annotation_usecase,
176     bool anchor_start_end) const {
177   std::unordered_set<int> executed_rules;
178   const std::vector<int> requested_locales =
179       ParseAndExpandLocales(locale_list.GetLocaleTags());
180   TC3_ASSIGN_OR_RETURN(
181       const std::vector<DatetimeParseResultSpan>& found_spans,
182       FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
183                             reference_timezone, mode, annotation_usecase,
184                             anchor_start_end, locale_list.GetReferenceLocale(),
185                             &executed_rules));
186   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
187   indexed_found_spans.reserve(found_spans.size());
188   for (int i = 0; i < found_spans.size(); i++) {
189     indexed_found_spans.push_back({found_spans[i], i});
190   }
191 
192   // Resolve conflicts by always picking the longer span and breaking ties by
193   // selecting the earlier entry in the list for a given locale.
194   std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
195             [](const std::pair<DatetimeParseResultSpan, int>& a,
196                const std::pair<DatetimeParseResultSpan, int>& b) {
197               if ((a.first.span.second - a.first.span.first) !=
198                   (b.first.span.second - b.first.span.first)) {
199                 return (a.first.span.second - a.first.span.first) >
200                        (b.first.span.second - b.first.span.first);
201               } else {
202                 return a.second < b.second;
203               }
204             });
205 
206   std::vector<DatetimeParseResultSpan> results;
207   std::vector<DatetimeParseResultSpan> resolved_found_spans;
208   resolved_found_spans.reserve(indexed_found_spans.size());
209   for (auto& span_index_pair : indexed_found_spans) {
210     resolved_found_spans.push_back(span_index_pair.first);
211   }
212 
213   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
214       [&resolved_found_spans](int a, int b) {
215         return resolved_found_spans[a].span.first <
216                resolved_found_spans[b].span.first;
217       });
218   for (int i = 0; i < resolved_found_spans.size(); ++i) {
219     if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
220       chosen_indices_set.insert(i);
221       results.push_back(resolved_found_spans[i]);
222     }
223   }
224   return results;
225 }
226 
227 StatusOr<std::vector<DatetimeParseResultSpan>>
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id) const228 RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
229                                       const UniLib::RegexMatcher& matcher,
230                                       int64 reference_time_ms_utc,
231                                       const std::string& reference_timezone,
232                                       const std::string& reference_locale,
233                                       int locale_id) const {
234   std::vector<DatetimeParseResultSpan> results;
235   int status = UniLib::RegexMatcher::kNoError;
236   const int start = matcher.Start(&status);
237   if (status != UniLib::RegexMatcher::kNoError) {
238     return Status(StatusCode::INTERNAL,
239                   "Failed to gets the start offset of the last match.");
240   }
241 
242   const int end = matcher.End(&status);
243   if (status != UniLib::RegexMatcher::kNoError) {
244     return Status(StatusCode::INTERNAL,
245                   "Failed to gets the end offset of the last match.");
246   }
247 
248   DatetimeParseResultSpan parse_result;
249   std::vector<DatetimeParseResult> alternatives;
250   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
251                        reference_locale, locale_id, &alternatives,
252                        &parse_result.span)) {
253     return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
254   }
255 
256   if (!use_extractors_for_locating_) {
257     parse_result.span = {start, end};
258   }
259 
260   if (parse_result.span.first != kInvalidIndex &&
261       parse_result.span.second != kInvalidIndex) {
262     parse_result.target_classification_score =
263         rule.pattern->target_classification_score();
264     parse_result.priority_score = rule.pattern->priority_score();
265 
266     for (DatetimeParseResult& alternative : alternatives) {
267       parse_result.data.push_back(alternative);
268     }
269   }
270   results.push_back(parse_result);
271   return results;
272 }
273 
274 StatusOr<std::vector<DatetimeParseResultSpan>>
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end) const275 RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
276                                    const UnicodeText& input,
277                                    const int64 reference_time_ms_utc,
278                                    const std::string& reference_timezone,
279                                    const std::string& reference_locale,
280                                    const int locale_id,
281                                    bool anchor_start_end) const {
282   std::vector<DatetimeParseResultSpan> results;
283   std::unique_ptr<UniLib::RegexMatcher> matcher =
284       rule.compiled_regex->Matcher(input);
285   int status = UniLib::RegexMatcher::kNoError;
286   if (anchor_start_end) {
287     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
288       return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
289                               reference_timezone, reference_locale, locale_id);
290     }
291   } else {
292     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
293       TC3_ASSIGN_OR_RETURN(
294           const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
295           HandleParseMatch(rule, *matcher, reference_time_ms_utc,
296                            reference_timezone, reference_locale, locale_id));
297       results.insert(std::end(results), std::begin(pattern_occurrence),
298                      std::end(pattern_occurrence));
299     }
300   }
301   return results;
302 }
303 
ParseAndExpandLocales(const std::vector<StringPiece> & locales) const304 std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
305     const std::vector<StringPiece>& locales) const {
306   std::vector<int> result;
307   for (const StringPiece& locale_str : locales) {
308     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
309     if (locale_it != locale_string_to_id_.end()) {
310       result.push_back(locale_it->second);
311     }
312 
313     const Locale locale = Locale::FromBCP47(locale_str.ToString());
314     if (!locale.IsValid()) {
315       continue;
316     }
317 
318     const std::string language = locale.Language();
319     const std::string script = locale.Script();
320     const std::string region = locale.Region();
321 
322     // First, try adding *-region locale.
323     if (!region.empty()) {
324       locale_it = locale_string_to_id_.find("*-" + region);
325       if (locale_it != locale_string_to_id_.end()) {
326         result.push_back(locale_it->second);
327       }
328     }
329     // Second, try adding language-script-* locale.
330     if (!script.empty()) {
331       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
332       if (locale_it != locale_string_to_id_.end()) {
333         result.push_back(locale_it->second);
334       }
335     }
336     // Third, try adding language-* locale.
337     if (!language.empty()) {
338       locale_it = locale_string_to_id_.find(language + "-*");
339       if (locale_it != locale_string_to_id_.end()) {
340         result.push_back(locale_it->second);
341       }
342     }
343   }
344 
345   // Add the default locales if they haven't been added already.
346   const std::unordered_set<int> result_set(result.begin(), result.end());
347   for (const int default_locale_id : default_locale_ids_) {
348     if (result_set.find(default_locale_id) == result_set.end()) {
349       result.push_back(default_locale_id);
350     }
351   }
352 
353   return result;
354 }
355 
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const356 bool RegexDatetimeParser::ExtractDatetime(
357     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
358     const int64 reference_time_ms_utc, const std::string& reference_timezone,
359     const std::string& reference_locale, int locale_id,
360     std::vector<DatetimeParseResult>* results,
361     CodepointSpan* result_span) const {
362   DatetimeParsedData parse;
363   DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
364                               extractor_rules_,
365                               type_and_locale_to_extractor_rule_);
366   if (!extractor.Extract(&parse, result_span)) {
367     return false;
368   }
369   std::vector<DatetimeParsedData> interpretations;
370   if (generate_alternative_interpretations_when_ambiguous_) {
371     FillInterpretations(parse, calendarlib_.GetGranularity(parse),
372                         &interpretations);
373   } else {
374     interpretations.push_back(parse);
375   }
376 
377   results->reserve(results->size() + interpretations.size());
378   for (const DatetimeParsedData& interpretation : interpretations) {
379     std::vector<DatetimeComponent> date_components;
380     interpretation.GetDatetimeComponents(&date_components);
381     DatetimeParseResult result;
382     // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM
383     //               which is encoded in the pair of DatetimeParseResult; both
384     //               corresponding to the same date, but one corresponding to
385     //               “AM” and the other one corresponding to “PM”.
386     //               Remove multiple DatetimeParseResult per datetime span,
387     //               once the ambiguities/DatetimeComponents are added in the
388     //               response. For Details see b/130355975
389     if (!calendarlib_.InterpretParseData(
390             interpretation, reference_time_ms_utc, reference_timezone,
391             reference_locale, prefer_future_for_unspecified_date_,
392             &(result.time_ms_utc), &(result.granularity))) {
393       return false;
394     }
395 
396     // Sort the date time units by component type.
397     std::sort(date_components.begin(), date_components.end(),
398               [](DatetimeComponent a, DatetimeComponent b) {
399                 return a.component_type > b.component_type;
400               });
401     result.datetime_components.swap(date_components);
402     results->push_back(result);
403   }
404   return true;
405 }
406 
407 }  // namespace libtextclassifier3
408