• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "datetime/parser.h"
18 
#include <algorithm>
#include <functional>
#include <set>
#include <unordered_set>
21 
22 #include "datetime/extractor.h"
23 #include "util/calendar/calendar.h"
24 #include "util/i18n/locale.h"
25 #include "util/strings/split.h"
26 
27 namespace libtextclassifier2 {
Instance(const DatetimeModel * model,const UniLib & unilib,ZlibDecompressor * decompressor)28 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
29     const DatetimeModel* model, const UniLib& unilib,
30     ZlibDecompressor* decompressor) {
31   std::unique_ptr<DatetimeParser> result(
32       new DatetimeParser(model, unilib, decompressor));
33   if (!result->initialized_) {
34     result.reset();
35   }
36   return result;
37 }
38 
DatetimeParser(const DatetimeModel * model,const UniLib & unilib,ZlibDecompressor * decompressor)39 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
40                                ZlibDecompressor* decompressor)
41     : unilib_(unilib) {
42   initialized_ = false;
43 
44   if (model == nullptr) {
45     return;
46   }
47 
48   if (model->patterns() != nullptr) {
49     for (const DatetimeModelPattern* pattern : *model->patterns()) {
50       if (pattern->regexes()) {
51         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
52           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
53               UncompressMakeRegexPattern(unilib, regex->pattern(),
54                                          regex->compressed_pattern(),
55                                          decompressor);
56           if (!regex_pattern) {
57             TC_LOG(ERROR) << "Couldn't create rule pattern.";
58             return;
59           }
60           rules_.push_back({std::move(regex_pattern), regex, pattern});
61           if (pattern->locales()) {
62             for (int locale : *pattern->locales()) {
63               locale_to_rules_[locale].push_back(rules_.size() - 1);
64             }
65           }
66         }
67       }
68     }
69   }
70 
71   if (model->extractors() != nullptr) {
72     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
73       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
74           UncompressMakeRegexPattern(unilib, extractor->pattern(),
75                                      extractor->compressed_pattern(),
76                                      decompressor);
77       if (!regex_pattern) {
78         TC_LOG(ERROR) << "Couldn't create extractor pattern";
79         return;
80       }
81       extractor_rules_.push_back(std::move(regex_pattern));
82 
83       if (extractor->locales()) {
84         for (int locale : *extractor->locales()) {
85           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
86               extractor_rules_.size() - 1;
87         }
88       }
89     }
90   }
91 
92   if (model->locales() != nullptr) {
93     for (int i = 0; i < model->locales()->Length(); ++i) {
94       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
95     }
96   }
97 
98   if (model->default_locales() != nullptr) {
99     for (const int locale : *model->default_locales()) {
100       default_locale_ids_.push_back(locale);
101     }
102   }
103 
104   use_extractors_for_locating_ = model->use_extractors_for_locating();
105 
106   initialized_ = true;
107 }
108 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const109 bool DatetimeParser::Parse(
110     const std::string& input, const int64 reference_time_ms_utc,
111     const std::string& reference_timezone, const std::string& locales,
112     ModeFlag mode, bool anchor_start_end,
113     std::vector<DatetimeParseResultSpan>* results) const {
114   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
115                reference_time_ms_utc, reference_timezone, locales, mode,
116                anchor_start_end, results);
117 }
118 
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules,std::vector<DatetimeParseResultSpan> * found_spans) const119 bool DatetimeParser::FindSpansUsingLocales(
120     const std::vector<int>& locale_ids, const UnicodeText& input,
121     const int64 reference_time_ms_utc, const std::string& reference_timezone,
122     ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
123     std::unordered_set<int>* executed_rules,
124     std::vector<DatetimeParseResultSpan>* found_spans) const {
125   for (const int locale_id : locale_ids) {
126     auto rules_it = locale_to_rules_.find(locale_id);
127     if (rules_it == locale_to_rules_.end()) {
128       continue;
129     }
130 
131     for (const int rule_id : rules_it->second) {
132       // Skip rules that were already executed in previous locales.
133       if (executed_rules->find(rule_id) != executed_rules->end()) {
134         continue;
135       }
136 
137       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
138         continue;
139       }
140 
141       executed_rules->insert(rule_id);
142 
143       if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
144                          reference_timezone, reference_locale, locale_id,
145                          anchor_start_end, found_spans)) {
146         return false;
147       }
148     }
149   }
150   return true;
151 }
152 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const153 bool DatetimeParser::Parse(
154     const UnicodeText& input, const int64 reference_time_ms_utc,
155     const std::string& reference_timezone, const std::string& locales,
156     ModeFlag mode, bool anchor_start_end,
157     std::vector<DatetimeParseResultSpan>* results) const {
158   std::vector<DatetimeParseResultSpan> found_spans;
159   std::unordered_set<int> executed_rules;
160   std::string reference_locale;
161   const std::vector<int> requested_locales =
162       ParseAndExpandLocales(locales, &reference_locale);
163   if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
164                              reference_timezone, mode, anchor_start_end,
165                              reference_locale, &executed_rules, &found_spans)) {
166     return false;
167   }
168 
169   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
170   int counter = 0;
171   for (const auto& found_span : found_spans) {
172     indexed_found_spans.push_back({found_span, counter});
173     counter++;
174   }
175 
176   // Resolve conflicts by always picking the longer span and breaking ties by
177   // selecting the earlier entry in the list for a given locale.
178   std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
179             [](const std::pair<DatetimeParseResultSpan, int>& a,
180                const std::pair<DatetimeParseResultSpan, int>& b) {
181               if ((a.first.span.second - a.first.span.first) !=
182                   (b.first.span.second - b.first.span.first)) {
183                 return (a.first.span.second - a.first.span.first) >
184                        (b.first.span.second - b.first.span.first);
185               } else {
186                 return a.second < b.second;
187               }
188             });
189 
190   found_spans.clear();
191   for (auto& span_index_pair : indexed_found_spans) {
192     found_spans.push_back(span_index_pair.first);
193   }
194 
195   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
196       [&found_spans](int a, int b) {
197         return found_spans[a].span.first < found_spans[b].span.first;
198       });
199   for (int i = 0; i < found_spans.size(); ++i) {
200     if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
201       chosen_indices_set.insert(i);
202       results->push_back(found_spans[i]);
203     }
204   }
205 
206   return true;
207 }
208 
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResultSpan> * result) const209 bool DatetimeParser::HandleParseMatch(
210     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
211     int64 reference_time_ms_utc, const std::string& reference_timezone,
212     const std::string& reference_locale, int locale_id,
213     std::vector<DatetimeParseResultSpan>* result) const {
214   int status = UniLib::RegexMatcher::kNoError;
215   const int start = matcher.Start(&status);
216   if (status != UniLib::RegexMatcher::kNoError) {
217     return false;
218   }
219 
220   const int end = matcher.End(&status);
221   if (status != UniLib::RegexMatcher::kNoError) {
222     return false;
223   }
224 
225   DatetimeParseResultSpan parse_result;
226   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
227                        reference_locale, locale_id, &(parse_result.data),
228                        &parse_result.span)) {
229     return false;
230   }
231   if (!use_extractors_for_locating_) {
232     parse_result.span = {start, end};
233   }
234   if (parse_result.span.first != kInvalidIndex &&
235       parse_result.span.second != kInvalidIndex) {
236     parse_result.target_classification_score =
237         rule.pattern->target_classification_score();
238     parse_result.priority_score = rule.pattern->priority_score();
239     result->push_back(parse_result);
240   }
241   return true;
242 }
243 
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * result) const244 bool DatetimeParser::ParseWithRule(
245     const CompiledRule& rule, const UnicodeText& input,
246     const int64 reference_time_ms_utc, const std::string& reference_timezone,
247     const std::string& reference_locale, const int locale_id,
248     bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
249   std::unique_ptr<UniLib::RegexMatcher> matcher =
250       rule.compiled_regex->Matcher(input);
251   int status = UniLib::RegexMatcher::kNoError;
252   if (anchor_start_end) {
253     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
254       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
255                             reference_timezone, reference_locale, locale_id,
256                             result)) {
257         return false;
258       }
259     }
260   } else {
261     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
262       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
263                             reference_timezone, reference_locale, locale_id,
264                             result)) {
265         return false;
266       }
267     }
268   }
269   return true;
270 }
271 
ParseAndExpandLocales(const std::string & locales,std::string * reference_locale) const272 std::vector<int> DatetimeParser::ParseAndExpandLocales(
273     const std::string& locales, std::string* reference_locale) const {
274   std::vector<StringPiece> split_locales = strings::Split(locales, ',');
275   if (!split_locales.empty()) {
276     *reference_locale = split_locales[0].ToString();
277   } else {
278     *reference_locale = "";
279   }
280 
281   std::vector<int> result;
282   for (const StringPiece& locale_str : split_locales) {
283     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
284     if (locale_it != locale_string_to_id_.end()) {
285       result.push_back(locale_it->second);
286     }
287 
288     const Locale locale = Locale::FromBCP47(locale_str.ToString());
289     if (!locale.IsValid()) {
290       continue;
291     }
292 
293     const std::string language = locale.Language();
294     const std::string script = locale.Script();
295     const std::string region = locale.Region();
296 
297     // First, try adding *-region locale.
298     if (!region.empty()) {
299       locale_it = locale_string_to_id_.find("*-" + region);
300       if (locale_it != locale_string_to_id_.end()) {
301         result.push_back(locale_it->second);
302       }
303     }
304     // Second, try adding language-script-* locale.
305     if (!script.empty()) {
306       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
307       if (locale_it != locale_string_to_id_.end()) {
308         result.push_back(locale_it->second);
309       }
310     }
311     // Third, try adding language-* locale.
312     if (!language.empty()) {
313       locale_it = locale_string_to_id_.find(language + "-*");
314       if (locale_it != locale_string_to_id_.end()) {
315         result.push_back(locale_it->second);
316       }
317     }
318   }
319 
320   // Add the default locales if they haven't been added already.
321   const std::unordered_set<int> result_set(result.begin(), result.end());
322   for (const int default_locale_id : default_locale_ids_) {
323     if (result_set.find(default_locale_id) == result_set.end()) {
324       result.push_back(default_locale_id);
325     }
326   }
327 
328   return result;
329 }
330 
331 namespace {
332 
GetGranularity(const DateParseData & data)333 DatetimeGranularity GetGranularity(const DateParseData& data) {
334   DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
335   if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
336       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
337        (data.relation_type == DateParseData::RelationType::YEAR))) {
338     granularity = DatetimeGranularity::GRANULARITY_YEAR;
339   }
340   if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
341       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
342        (data.relation_type == DateParseData::RelationType::MONTH))) {
343     granularity = DatetimeGranularity::GRANULARITY_MONTH;
344   }
345   if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
346       (data.relation_type == DateParseData::RelationType::WEEK)) {
347     granularity = DatetimeGranularity::GRANULARITY_WEEK;
348   }
349   if (data.field_set_mask & DateParseData::DAY_FIELD ||
350       (data.field_set_mask & DateParseData::RELATION_FIELD &&
351        (data.relation == DateParseData::Relation::NOW ||
352         data.relation == DateParseData::Relation::TOMORROW ||
353         data.relation == DateParseData::Relation::YESTERDAY)) ||
354       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
355        (data.relation_type == DateParseData::RelationType::MONDAY ||
356         data.relation_type == DateParseData::RelationType::TUESDAY ||
357         data.relation_type == DateParseData::RelationType::WEDNESDAY ||
358         data.relation_type == DateParseData::RelationType::THURSDAY ||
359         data.relation_type == DateParseData::RelationType::FRIDAY ||
360         data.relation_type == DateParseData::RelationType::SATURDAY ||
361         data.relation_type == DateParseData::RelationType::SUNDAY ||
362         data.relation_type == DateParseData::RelationType::DAY))) {
363     granularity = DatetimeGranularity::GRANULARITY_DAY;
364   }
365   if (data.field_set_mask & DateParseData::HOUR_FIELD) {
366     granularity = DatetimeGranularity::GRANULARITY_HOUR;
367   }
368   if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
369     granularity = DatetimeGranularity::GRANULARITY_MINUTE;
370   }
371   if (data.field_set_mask & DateParseData::SECOND_FIELD) {
372     granularity = DatetimeGranularity::GRANULARITY_SECOND;
373   }
374   return granularity;
375 }
376 
377 }  // namespace
378 
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,DatetimeParseResult * result,CodepointSpan * result_span) const379 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
380                                      const UniLib::RegexMatcher& matcher,
381                                      const int64 reference_time_ms_utc,
382                                      const std::string& reference_timezone,
383                                      const std::string& reference_locale,
384                                      int locale_id, DatetimeParseResult* result,
385                                      CodepointSpan* result_span) const {
386   DateParseData parse;
387   DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
388                               extractor_rules_,
389                               type_and_locale_to_extractor_rule_);
390   if (!extractor.Extract(&parse, result_span)) {
391     return false;
392   }
393 
394   result->granularity = GetGranularity(parse);
395 
396   if (!calendar_lib_.InterpretParseData(
397           parse, reference_time_ms_utc, reference_timezone, reference_locale,
398           result->granularity, &(result->time_ms_utc))) {
399     return false;
400   }
401 
402   return true;
403 }
404 
405 }  // namespace libtextclassifier2
406