1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "datetime/parser.h"
18
#include <algorithm>
#include <functional>
#include <set>
#include <unordered_set>
21
22 #include "datetime/extractor.h"
23 #include "util/calendar/calendar.h"
24 #include "util/i18n/locale.h"
25 #include "util/strings/split.h"
26
27 namespace libtextclassifier2 {
Instance(const DatetimeModel * model,const UniLib & unilib,ZlibDecompressor * decompressor)28 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
29 const DatetimeModel* model, const UniLib& unilib,
30 ZlibDecompressor* decompressor) {
31 std::unique_ptr<DatetimeParser> result(
32 new DatetimeParser(model, unilib, decompressor));
33 if (!result->initialized_) {
34 result.reset();
35 }
36 return result;
37 }
38
DatetimeParser(const DatetimeModel * model,const UniLib & unilib,ZlibDecompressor * decompressor)39 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
40 ZlibDecompressor* decompressor)
41 : unilib_(unilib) {
42 initialized_ = false;
43
44 if (model == nullptr) {
45 return;
46 }
47
48 if (model->patterns() != nullptr) {
49 for (const DatetimeModelPattern* pattern : *model->patterns()) {
50 if (pattern->regexes()) {
51 for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
52 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
53 UncompressMakeRegexPattern(unilib, regex->pattern(),
54 regex->compressed_pattern(),
55 decompressor);
56 if (!regex_pattern) {
57 TC_LOG(ERROR) << "Couldn't create rule pattern.";
58 return;
59 }
60 rules_.push_back({std::move(regex_pattern), regex, pattern});
61 if (pattern->locales()) {
62 for (int locale : *pattern->locales()) {
63 locale_to_rules_[locale].push_back(rules_.size() - 1);
64 }
65 }
66 }
67 }
68 }
69 }
70
71 if (model->extractors() != nullptr) {
72 for (const DatetimeModelExtractor* extractor : *model->extractors()) {
73 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
74 UncompressMakeRegexPattern(unilib, extractor->pattern(),
75 extractor->compressed_pattern(),
76 decompressor);
77 if (!regex_pattern) {
78 TC_LOG(ERROR) << "Couldn't create extractor pattern";
79 return;
80 }
81 extractor_rules_.push_back(std::move(regex_pattern));
82
83 if (extractor->locales()) {
84 for (int locale : *extractor->locales()) {
85 type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
86 extractor_rules_.size() - 1;
87 }
88 }
89 }
90 }
91
92 if (model->locales() != nullptr) {
93 for (int i = 0; i < model->locales()->Length(); ++i) {
94 locale_string_to_id_[model->locales()->Get(i)->str()] = i;
95 }
96 }
97
98 if (model->default_locales() != nullptr) {
99 for (const int locale : *model->default_locales()) {
100 default_locale_ids_.push_back(locale);
101 }
102 }
103
104 use_extractors_for_locating_ = model->use_extractors_for_locating();
105
106 initialized_ = true;
107 }
108
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const109 bool DatetimeParser::Parse(
110 const std::string& input, const int64 reference_time_ms_utc,
111 const std::string& reference_timezone, const std::string& locales,
112 ModeFlag mode, bool anchor_start_end,
113 std::vector<DatetimeParseResultSpan>* results) const {
114 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
115 reference_time_ms_utc, reference_timezone, locales, mode,
116 anchor_start_end, results);
117 }
118
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules,std::vector<DatetimeParseResultSpan> * found_spans) const119 bool DatetimeParser::FindSpansUsingLocales(
120 const std::vector<int>& locale_ids, const UnicodeText& input,
121 const int64 reference_time_ms_utc, const std::string& reference_timezone,
122 ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
123 std::unordered_set<int>* executed_rules,
124 std::vector<DatetimeParseResultSpan>* found_spans) const {
125 for (const int locale_id : locale_ids) {
126 auto rules_it = locale_to_rules_.find(locale_id);
127 if (rules_it == locale_to_rules_.end()) {
128 continue;
129 }
130
131 for (const int rule_id : rules_it->second) {
132 // Skip rules that were already executed in previous locales.
133 if (executed_rules->find(rule_id) != executed_rules->end()) {
134 continue;
135 }
136
137 if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
138 continue;
139 }
140
141 executed_rules->insert(rule_id);
142
143 if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
144 reference_timezone, reference_locale, locale_id,
145 anchor_start_end, found_spans)) {
146 return false;
147 }
148 }
149 }
150 return true;
151 }
152
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const153 bool DatetimeParser::Parse(
154 const UnicodeText& input, const int64 reference_time_ms_utc,
155 const std::string& reference_timezone, const std::string& locales,
156 ModeFlag mode, bool anchor_start_end,
157 std::vector<DatetimeParseResultSpan>* results) const {
158 std::vector<DatetimeParseResultSpan> found_spans;
159 std::unordered_set<int> executed_rules;
160 std::string reference_locale;
161 const std::vector<int> requested_locales =
162 ParseAndExpandLocales(locales, &reference_locale);
163 if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
164 reference_timezone, mode, anchor_start_end,
165 reference_locale, &executed_rules, &found_spans)) {
166 return false;
167 }
168
169 std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
170 int counter = 0;
171 for (const auto& found_span : found_spans) {
172 indexed_found_spans.push_back({found_span, counter});
173 counter++;
174 }
175
176 // Resolve conflicts by always picking the longer span and breaking ties by
177 // selecting the earlier entry in the list for a given locale.
178 std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
179 [](const std::pair<DatetimeParseResultSpan, int>& a,
180 const std::pair<DatetimeParseResultSpan, int>& b) {
181 if ((a.first.span.second - a.first.span.first) !=
182 (b.first.span.second - b.first.span.first)) {
183 return (a.first.span.second - a.first.span.first) >
184 (b.first.span.second - b.first.span.first);
185 } else {
186 return a.second < b.second;
187 }
188 });
189
190 found_spans.clear();
191 for (auto& span_index_pair : indexed_found_spans) {
192 found_spans.push_back(span_index_pair.first);
193 }
194
195 std::set<int, std::function<bool(int, int)>> chosen_indices_set(
196 [&found_spans](int a, int b) {
197 return found_spans[a].span.first < found_spans[b].span.first;
198 });
199 for (int i = 0; i < found_spans.size(); ++i) {
200 if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
201 chosen_indices_set.insert(i);
202 results->push_back(found_spans[i]);
203 }
204 }
205
206 return true;
207 }
208
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResultSpan> * result) const209 bool DatetimeParser::HandleParseMatch(
210 const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
211 int64 reference_time_ms_utc, const std::string& reference_timezone,
212 const std::string& reference_locale, int locale_id,
213 std::vector<DatetimeParseResultSpan>* result) const {
214 int status = UniLib::RegexMatcher::kNoError;
215 const int start = matcher.Start(&status);
216 if (status != UniLib::RegexMatcher::kNoError) {
217 return false;
218 }
219
220 const int end = matcher.End(&status);
221 if (status != UniLib::RegexMatcher::kNoError) {
222 return false;
223 }
224
225 DatetimeParseResultSpan parse_result;
226 if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
227 reference_locale, locale_id, &(parse_result.data),
228 &parse_result.span)) {
229 return false;
230 }
231 if (!use_extractors_for_locating_) {
232 parse_result.span = {start, end};
233 }
234 if (parse_result.span.first != kInvalidIndex &&
235 parse_result.span.second != kInvalidIndex) {
236 parse_result.target_classification_score =
237 rule.pattern->target_classification_score();
238 parse_result.priority_score = rule.pattern->priority_score();
239 result->push_back(parse_result);
240 }
241 return true;
242 }
243
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * result) const244 bool DatetimeParser::ParseWithRule(
245 const CompiledRule& rule, const UnicodeText& input,
246 const int64 reference_time_ms_utc, const std::string& reference_timezone,
247 const std::string& reference_locale, const int locale_id,
248 bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
249 std::unique_ptr<UniLib::RegexMatcher> matcher =
250 rule.compiled_regex->Matcher(input);
251 int status = UniLib::RegexMatcher::kNoError;
252 if (anchor_start_end) {
253 if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
254 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
255 reference_timezone, reference_locale, locale_id,
256 result)) {
257 return false;
258 }
259 }
260 } else {
261 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
262 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
263 reference_timezone, reference_locale, locale_id,
264 result)) {
265 return false;
266 }
267 }
268 }
269 return true;
270 }
271
ParseAndExpandLocales(const std::string & locales,std::string * reference_locale) const272 std::vector<int> DatetimeParser::ParseAndExpandLocales(
273 const std::string& locales, std::string* reference_locale) const {
274 std::vector<StringPiece> split_locales = strings::Split(locales, ',');
275 if (!split_locales.empty()) {
276 *reference_locale = split_locales[0].ToString();
277 } else {
278 *reference_locale = "";
279 }
280
281 std::vector<int> result;
282 for (const StringPiece& locale_str : split_locales) {
283 auto locale_it = locale_string_to_id_.find(locale_str.ToString());
284 if (locale_it != locale_string_to_id_.end()) {
285 result.push_back(locale_it->second);
286 }
287
288 const Locale locale = Locale::FromBCP47(locale_str.ToString());
289 if (!locale.IsValid()) {
290 continue;
291 }
292
293 const std::string language = locale.Language();
294 const std::string script = locale.Script();
295 const std::string region = locale.Region();
296
297 // First, try adding *-region locale.
298 if (!region.empty()) {
299 locale_it = locale_string_to_id_.find("*-" + region);
300 if (locale_it != locale_string_to_id_.end()) {
301 result.push_back(locale_it->second);
302 }
303 }
304 // Second, try adding language-script-* locale.
305 if (!script.empty()) {
306 locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
307 if (locale_it != locale_string_to_id_.end()) {
308 result.push_back(locale_it->second);
309 }
310 }
311 // Third, try adding language-* locale.
312 if (!language.empty()) {
313 locale_it = locale_string_to_id_.find(language + "-*");
314 if (locale_it != locale_string_to_id_.end()) {
315 result.push_back(locale_it->second);
316 }
317 }
318 }
319
320 // Add the default locales if they haven't been added already.
321 const std::unordered_set<int> result_set(result.begin(), result.end());
322 for (const int default_locale_id : default_locale_ids_) {
323 if (result_set.find(default_locale_id) == result_set.end()) {
324 result.push_back(default_locale_id);
325 }
326 }
327
328 return result;
329 }
330
331 namespace {
332
GetGranularity(const DateParseData & data)333 DatetimeGranularity GetGranularity(const DateParseData& data) {
334 DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
335 if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
336 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
337 (data.relation_type == DateParseData::RelationType::YEAR))) {
338 granularity = DatetimeGranularity::GRANULARITY_YEAR;
339 }
340 if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
341 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
342 (data.relation_type == DateParseData::RelationType::MONTH))) {
343 granularity = DatetimeGranularity::GRANULARITY_MONTH;
344 }
345 if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
346 (data.relation_type == DateParseData::RelationType::WEEK)) {
347 granularity = DatetimeGranularity::GRANULARITY_WEEK;
348 }
349 if (data.field_set_mask & DateParseData::DAY_FIELD ||
350 (data.field_set_mask & DateParseData::RELATION_FIELD &&
351 (data.relation == DateParseData::Relation::NOW ||
352 data.relation == DateParseData::Relation::TOMORROW ||
353 data.relation == DateParseData::Relation::YESTERDAY)) ||
354 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
355 (data.relation_type == DateParseData::RelationType::MONDAY ||
356 data.relation_type == DateParseData::RelationType::TUESDAY ||
357 data.relation_type == DateParseData::RelationType::WEDNESDAY ||
358 data.relation_type == DateParseData::RelationType::THURSDAY ||
359 data.relation_type == DateParseData::RelationType::FRIDAY ||
360 data.relation_type == DateParseData::RelationType::SATURDAY ||
361 data.relation_type == DateParseData::RelationType::SUNDAY ||
362 data.relation_type == DateParseData::RelationType::DAY))) {
363 granularity = DatetimeGranularity::GRANULARITY_DAY;
364 }
365 if (data.field_set_mask & DateParseData::HOUR_FIELD) {
366 granularity = DatetimeGranularity::GRANULARITY_HOUR;
367 }
368 if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
369 granularity = DatetimeGranularity::GRANULARITY_MINUTE;
370 }
371 if (data.field_set_mask & DateParseData::SECOND_FIELD) {
372 granularity = DatetimeGranularity::GRANULARITY_SECOND;
373 }
374 return granularity;
375 }
376
377 } // namespace
378
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,DatetimeParseResult * result,CodepointSpan * result_span) const379 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
380 const UniLib::RegexMatcher& matcher,
381 const int64 reference_time_ms_utc,
382 const std::string& reference_timezone,
383 const std::string& reference_locale,
384 int locale_id, DatetimeParseResult* result,
385 CodepointSpan* result_span) const {
386 DateParseData parse;
387 DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
388 extractor_rules_,
389 type_and_locale_to_extractor_rule_);
390 if (!extractor.Extract(&parse, result_span)) {
391 return false;
392 }
393
394 result->granularity = GetGranularity(parse);
395
396 if (!calendar_lib_.InterpretParseData(
397 parse, reference_time_ms_utc, reference_timezone, reference_locale,
398 result->granularity, &(result->time_ms_utc))) {
399 return false;
400 }
401
402 return true;
403 }
404
405 } // namespace libtextclassifier2
406