• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "annotator/datetime/parser.h"

#include <algorithm>
#include <functional>
#include <set>
#include <unordered_set>

#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
28 
29 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)30 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
31     const DatetimeModel* model, const UniLib* unilib,
32     const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
33   std::unique_ptr<DatetimeParser> result(
34       new DatetimeParser(model, unilib, calendarlib, decompressor));
35   if (!result->initialized_) {
36     result.reset();
37   }
38   return result;
39 }
40 
DatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)41 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
42                                const CalendarLib* calendarlib,
43                                ZlibDecompressor* decompressor)
44     : unilib_(*unilib), calendarlib_(*calendarlib) {
45   initialized_ = false;
46 
47   if (model == nullptr) {
48     return;
49   }
50 
51   if (model->patterns() != nullptr) {
52     for (const DatetimeModelPattern* pattern : *model->patterns()) {
53       if (pattern->regexes()) {
54         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
55           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
56               UncompressMakeRegexPattern(
57                   unilib_, regex->pattern(), regex->compressed_pattern(),
58                   model->lazy_regex_compilation(), decompressor);
59           if (!regex_pattern) {
60             TC3_LOG(ERROR) << "Couldn't create rule pattern.";
61             return;
62           }
63           rules_.push_back({std::move(regex_pattern), regex, pattern});
64           if (pattern->locales()) {
65             for (int locale : *pattern->locales()) {
66               locale_to_rules_[locale].push_back(rules_.size() - 1);
67             }
68           }
69         }
70       }
71     }
72   }
73 
74   if (model->extractors() != nullptr) {
75     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
76       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
77           UncompressMakeRegexPattern(
78               unilib_, extractor->pattern(), extractor->compressed_pattern(),
79               model->lazy_regex_compilation(), decompressor);
80       if (!regex_pattern) {
81         TC3_LOG(ERROR) << "Couldn't create extractor pattern";
82         return;
83       }
84       extractor_rules_.push_back(std::move(regex_pattern));
85 
86       if (extractor->locales()) {
87         for (int locale : *extractor->locales()) {
88           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
89               extractor_rules_.size() - 1;
90         }
91       }
92     }
93   }
94 
95   if (model->locales() != nullptr) {
96     for (int i = 0; i < model->locales()->size(); ++i) {
97       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
98     }
99   }
100 
101   if (model->default_locales() != nullptr) {
102     for (const int locale : *model->default_locales()) {
103       default_locale_ids_.push_back(locale);
104     }
105   }
106 
107   use_extractors_for_locating_ = model->use_extractors_for_locating();
108   generate_alternative_interpretations_when_ambiguous_ =
109       model->generate_alternative_interpretations_when_ambiguous();
110   prefer_future_for_unspecified_date_ =
111       model->prefer_future_for_unspecified_date();
112 
113   initialized_ = true;
114 }
115 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const116 bool DatetimeParser::Parse(
117     const std::string& input, const int64 reference_time_ms_utc,
118     const std::string& reference_timezone, const std::string& locales,
119     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
120     std::vector<DatetimeParseResultSpan>* results) const {
121   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
122                reference_time_ms_utc, reference_timezone, locales, mode,
123                annotation_usecase, anchor_start_end, results);
124 }
125 
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules,std::vector<DatetimeParseResultSpan> * found_spans) const126 bool DatetimeParser::FindSpansUsingLocales(
127     const std::vector<int>& locale_ids, const UnicodeText& input,
128     const int64 reference_time_ms_utc, const std::string& reference_timezone,
129     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
130     const std::string& reference_locale,
131     std::unordered_set<int>* executed_rules,
132     std::vector<DatetimeParseResultSpan>* found_spans) const {
133   for (const int locale_id : locale_ids) {
134     auto rules_it = locale_to_rules_.find(locale_id);
135     if (rules_it == locale_to_rules_.end()) {
136       continue;
137     }
138 
139     for (const int rule_id : rules_it->second) {
140       // Skip rules that were already executed in previous locales.
141       if (executed_rules->find(rule_id) != executed_rules->end()) {
142         continue;
143       }
144 
145       if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
146            (1 << annotation_usecase)) == 0) {
147         continue;
148       }
149 
150       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
151         continue;
152       }
153 
154       executed_rules->insert(rule_id);
155 
156       if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
157                          reference_timezone, reference_locale, locale_id,
158                          anchor_start_end, found_spans)) {
159         return false;
160       }
161     }
162   }
163   return true;
164 }
165 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const166 bool DatetimeParser::Parse(
167     const UnicodeText& input, const int64 reference_time_ms_utc,
168     const std::string& reference_timezone, const std::string& locales,
169     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
170     std::vector<DatetimeParseResultSpan>* results) const {
171   std::vector<DatetimeParseResultSpan> found_spans;
172   std::unordered_set<int> executed_rules;
173   std::string reference_locale;
174   const std::vector<int> requested_locales =
175       ParseAndExpandLocales(locales, &reference_locale);
176   if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
177                              reference_timezone, mode, annotation_usecase,
178                              anchor_start_end, reference_locale,
179                              &executed_rules, &found_spans)) {
180     return false;
181   }
182 
183   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
184   indexed_found_spans.reserve(found_spans.size());
185   for (int i = 0; i < found_spans.size(); i++) {
186     indexed_found_spans.push_back({found_spans[i], i});
187   }
188 
189   // Resolve conflicts by always picking the longer span and breaking ties by
190   // selecting the earlier entry in the list for a given locale.
191   std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
192             [](const std::pair<DatetimeParseResultSpan, int>& a,
193                const std::pair<DatetimeParseResultSpan, int>& b) {
194               if ((a.first.span.second - a.first.span.first) !=
195                   (b.first.span.second - b.first.span.first)) {
196                 return (a.first.span.second - a.first.span.first) >
197                        (b.first.span.second - b.first.span.first);
198               } else {
199                 return a.second < b.second;
200               }
201             });
202 
203   found_spans.clear();
204   for (auto& span_index_pair : indexed_found_spans) {
205     found_spans.push_back(span_index_pair.first);
206   }
207 
208   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
209       [&found_spans](int a, int b) {
210         return found_spans[a].span.first < found_spans[b].span.first;
211       });
212   for (int i = 0; i < found_spans.size(); ++i) {
213     if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
214       chosen_indices_set.insert(i);
215       results->push_back(found_spans[i]);
216     }
217   }
218 
219   return true;
220 }
221 
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResultSpan> * result) const222 bool DatetimeParser::HandleParseMatch(
223     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
224     int64 reference_time_ms_utc, const std::string& reference_timezone,
225     const std::string& reference_locale, int locale_id,
226     std::vector<DatetimeParseResultSpan>* result) const {
227   int status = UniLib::RegexMatcher::kNoError;
228   const int start = matcher.Start(&status);
229   if (status != UniLib::RegexMatcher::kNoError) {
230     return false;
231   }
232 
233   const int end = matcher.End(&status);
234   if (status != UniLib::RegexMatcher::kNoError) {
235     return false;
236   }
237 
238   DatetimeParseResultSpan parse_result;
239   std::vector<DatetimeParseResult> alternatives;
240   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
241                        reference_locale, locale_id, &alternatives,
242                        &parse_result.span)) {
243     return false;
244   }
245 
246   if (!use_extractors_for_locating_) {
247     parse_result.span = {start, end};
248   }
249 
250   if (parse_result.span.first != kInvalidIndex &&
251       parse_result.span.second != kInvalidIndex) {
252     parse_result.target_classification_score =
253         rule.pattern->target_classification_score();
254     parse_result.priority_score = rule.pattern->priority_score();
255 
256     for (DatetimeParseResult& alternative : alternatives) {
257       parse_result.data.push_back(alternative);
258     }
259   }
260   result->push_back(parse_result);
261   return true;
262 }
263 
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * result) const264 bool DatetimeParser::ParseWithRule(
265     const CompiledRule& rule, const UnicodeText& input,
266     const int64 reference_time_ms_utc, const std::string& reference_timezone,
267     const std::string& reference_locale, const int locale_id,
268     bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
269   std::unique_ptr<UniLib::RegexMatcher> matcher =
270       rule.compiled_regex->Matcher(input);
271   int status = UniLib::RegexMatcher::kNoError;
272   if (anchor_start_end) {
273     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
274       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
275                             reference_timezone, reference_locale, locale_id,
276                             result)) {
277         return false;
278       }
279     }
280   } else {
281     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
282       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
283                             reference_timezone, reference_locale, locale_id,
284                             result)) {
285         return false;
286       }
287     }
288   }
289   return true;
290 }
291 
ParseAndExpandLocales(const std::string & locales,std::string * reference_locale) const292 std::vector<int> DatetimeParser::ParseAndExpandLocales(
293     const std::string& locales, std::string* reference_locale) const {
294   std::vector<StringPiece> split_locales = strings::Split(locales, ',');
295   if (!split_locales.empty()) {
296     *reference_locale = split_locales[0].ToString();
297   } else {
298     *reference_locale = "";
299   }
300 
301   std::vector<int> result;
302   for (const StringPiece& locale_str : split_locales) {
303     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
304     if (locale_it != locale_string_to_id_.end()) {
305       result.push_back(locale_it->second);
306     }
307 
308     const Locale locale = Locale::FromBCP47(locale_str.ToString());
309     if (!locale.IsValid()) {
310       continue;
311     }
312 
313     const std::string language = locale.Language();
314     const std::string script = locale.Script();
315     const std::string region = locale.Region();
316 
317     // First, try adding *-region locale.
318     if (!region.empty()) {
319       locale_it = locale_string_to_id_.find("*-" + region);
320       if (locale_it != locale_string_to_id_.end()) {
321         result.push_back(locale_it->second);
322       }
323     }
324     // Second, try adding language-script-* locale.
325     if (!script.empty()) {
326       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
327       if (locale_it != locale_string_to_id_.end()) {
328         result.push_back(locale_it->second);
329       }
330     }
331     // Third, try adding language-* locale.
332     if (!language.empty()) {
333       locale_it = locale_string_to_id_.find(language + "-*");
334       if (locale_it != locale_string_to_id_.end()) {
335         result.push_back(locale_it->second);
336       }
337     }
338   }
339 
340   // Add the default locales if they haven't been added already.
341   const std::unordered_set<int> result_set(result.begin(), result.end());
342   for (const int default_locale_id : default_locale_ids_) {
343     if (result_set.find(default_locale_id) == result_set.end()) {
344       result.push_back(default_locale_id);
345     }
346   }
347 
348   return result;
349 }
350 
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const351 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
352                                      const UniLib::RegexMatcher& matcher,
353                                      const int64 reference_time_ms_utc,
354                                      const std::string& reference_timezone,
355                                      const std::string& reference_locale,
356                                      int locale_id,
357                                      std::vector<DatetimeParseResult>* results,
358                                      CodepointSpan* result_span) const {
359   DatetimeParsedData parse;
360   DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
361                               extractor_rules_,
362                               type_and_locale_to_extractor_rule_);
363   if (!extractor.Extract(&parse, result_span)) {
364     return false;
365   }
366   std::vector<DatetimeParsedData> interpretations;
367   if (generate_alternative_interpretations_when_ambiguous_) {
368     FillInterpretations(parse, calendarlib_.GetGranularity(parse),
369                         &interpretations);
370   } else {
371     interpretations.push_back(parse);
372   }
373 
374   results->reserve(results->size() + interpretations.size());
375   for (const DatetimeParsedData& interpretation : interpretations) {
376     std::vector<DatetimeComponent> date_components;
377     interpretation.GetDatetimeComponents(&date_components);
378     DatetimeParseResult result;
379     // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM
380     //               which is encoded in the pair of DatetimeParseResult; both
381     //               corresponding to the same date, but one corresponding to
382     //               “AM” and the other one corresponding to “PM”.
383     //               Remove multiple DatetimeParseResult per datetime span,
384     //               once the ambiguities/DatetimeComponents are added in the
385     //               response. For Details see b/130355975
386     if (!calendarlib_.InterpretParseData(
387             interpretation, reference_time_ms_utc, reference_timezone,
388             reference_locale, prefer_future_for_unspecified_date_,
389             &(result.time_ms_utc), &(result.granularity))) {
390       return false;
391     }
392 
393     // Sort the date time units by component type.
394     std::sort(date_components.begin(), date_components.end(),
395               [](DatetimeComponent a, DatetimeComponent b) {
396                 return a.component_type > b.component_type;
397               });
398     result.datetime_components.swap(date_components);
399     results->push_back(result);
400   }
401   return true;
402 }
403 
404 }  // namespace libtextclassifier3
405