1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/datetime/parser.h"

#include <algorithm>
#include <functional>
#include <set>
#include <unordered_set>
#include <utility>

#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
28
29 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)30 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
31 const DatetimeModel* model, const UniLib* unilib,
32 const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
33 std::unique_ptr<DatetimeParser> result(
34 new DatetimeParser(model, unilib, calendarlib, decompressor));
35 if (!result->initialized_) {
36 result.reset();
37 }
38 return result;
39 }
40
DatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)41 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
42 const CalendarLib* calendarlib,
43 ZlibDecompressor* decompressor)
44 : unilib_(*unilib), calendarlib_(*calendarlib) {
45 initialized_ = false;
46
47 if (model == nullptr) {
48 return;
49 }
50
51 if (model->patterns() != nullptr) {
52 for (const DatetimeModelPattern* pattern : *model->patterns()) {
53 if (pattern->regexes()) {
54 for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
55 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
56 UncompressMakeRegexPattern(
57 unilib_, regex->pattern(), regex->compressed_pattern(),
58 model->lazy_regex_compilation(), decompressor);
59 if (!regex_pattern) {
60 TC3_LOG(ERROR) << "Couldn't create rule pattern.";
61 return;
62 }
63 rules_.push_back({std::move(regex_pattern), regex, pattern});
64 if (pattern->locales()) {
65 for (int locale : *pattern->locales()) {
66 locale_to_rules_[locale].push_back(rules_.size() - 1);
67 }
68 }
69 }
70 }
71 }
72 }
73
74 if (model->extractors() != nullptr) {
75 for (const DatetimeModelExtractor* extractor : *model->extractors()) {
76 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
77 UncompressMakeRegexPattern(
78 unilib_, extractor->pattern(), extractor->compressed_pattern(),
79 model->lazy_regex_compilation(), decompressor);
80 if (!regex_pattern) {
81 TC3_LOG(ERROR) << "Couldn't create extractor pattern";
82 return;
83 }
84 extractor_rules_.push_back(std::move(regex_pattern));
85
86 if (extractor->locales()) {
87 for (int locale : *extractor->locales()) {
88 type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
89 extractor_rules_.size() - 1;
90 }
91 }
92 }
93 }
94
95 if (model->locales() != nullptr) {
96 for (int i = 0; i < model->locales()->size(); ++i) {
97 locale_string_to_id_[model->locales()->Get(i)->str()] = i;
98 }
99 }
100
101 if (model->default_locales() != nullptr) {
102 for (const int locale : *model->default_locales()) {
103 default_locale_ids_.push_back(locale);
104 }
105 }
106
107 use_extractors_for_locating_ = model->use_extractors_for_locating();
108 generate_alternative_interpretations_when_ambiguous_ =
109 model->generate_alternative_interpretations_when_ambiguous();
110 prefer_future_for_unspecified_date_ =
111 model->prefer_future_for_unspecified_date();
112
113 initialized_ = true;
114 }
115
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const116 bool DatetimeParser::Parse(
117 const std::string& input, const int64 reference_time_ms_utc,
118 const std::string& reference_timezone, const std::string& locales,
119 ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
120 std::vector<DatetimeParseResultSpan>* results) const {
121 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
122 reference_time_ms_utc, reference_timezone, locales, mode,
123 annotation_usecase, anchor_start_end, results);
124 }
125
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules,std::vector<DatetimeParseResultSpan> * found_spans) const126 bool DatetimeParser::FindSpansUsingLocales(
127 const std::vector<int>& locale_ids, const UnicodeText& input,
128 const int64 reference_time_ms_utc, const std::string& reference_timezone,
129 ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
130 const std::string& reference_locale,
131 std::unordered_set<int>* executed_rules,
132 std::vector<DatetimeParseResultSpan>* found_spans) const {
133 for (const int locale_id : locale_ids) {
134 auto rules_it = locale_to_rules_.find(locale_id);
135 if (rules_it == locale_to_rules_.end()) {
136 continue;
137 }
138
139 for (const int rule_id : rules_it->second) {
140 // Skip rules that were already executed in previous locales.
141 if (executed_rules->find(rule_id) != executed_rules->end()) {
142 continue;
143 }
144
145 if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
146 (1 << annotation_usecase)) == 0) {
147 continue;
148 }
149
150 if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
151 continue;
152 }
153
154 executed_rules->insert(rule_id);
155
156 if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
157 reference_timezone, reference_locale, locale_id,
158 anchor_start_end, found_spans)) {
159 return false;
160 }
161 }
162 }
163 return true;
164 }
165
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * results) const166 bool DatetimeParser::Parse(
167 const UnicodeText& input, const int64 reference_time_ms_utc,
168 const std::string& reference_timezone, const std::string& locales,
169 ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
170 std::vector<DatetimeParseResultSpan>* results) const {
171 std::vector<DatetimeParseResultSpan> found_spans;
172 std::unordered_set<int> executed_rules;
173 std::string reference_locale;
174 const std::vector<int> requested_locales =
175 ParseAndExpandLocales(locales, &reference_locale);
176 if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
177 reference_timezone, mode, annotation_usecase,
178 anchor_start_end, reference_locale,
179 &executed_rules, &found_spans)) {
180 return false;
181 }
182
183 std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
184 indexed_found_spans.reserve(found_spans.size());
185 for (int i = 0; i < found_spans.size(); i++) {
186 indexed_found_spans.push_back({found_spans[i], i});
187 }
188
189 // Resolve conflicts by always picking the longer span and breaking ties by
190 // selecting the earlier entry in the list for a given locale.
191 std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
192 [](const std::pair<DatetimeParseResultSpan, int>& a,
193 const std::pair<DatetimeParseResultSpan, int>& b) {
194 if ((a.first.span.second - a.first.span.first) !=
195 (b.first.span.second - b.first.span.first)) {
196 return (a.first.span.second - a.first.span.first) >
197 (b.first.span.second - b.first.span.first);
198 } else {
199 return a.second < b.second;
200 }
201 });
202
203 found_spans.clear();
204 for (auto& span_index_pair : indexed_found_spans) {
205 found_spans.push_back(span_index_pair.first);
206 }
207
208 std::set<int, std::function<bool(int, int)>> chosen_indices_set(
209 [&found_spans](int a, int b) {
210 return found_spans[a].span.first < found_spans[b].span.first;
211 });
212 for (int i = 0; i < found_spans.size(); ++i) {
213 if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
214 chosen_indices_set.insert(i);
215 results->push_back(found_spans[i]);
216 }
217 }
218
219 return true;
220 }
221
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResultSpan> * result) const222 bool DatetimeParser::HandleParseMatch(
223 const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
224 int64 reference_time_ms_utc, const std::string& reference_timezone,
225 const std::string& reference_locale, int locale_id,
226 std::vector<DatetimeParseResultSpan>* result) const {
227 int status = UniLib::RegexMatcher::kNoError;
228 const int start = matcher.Start(&status);
229 if (status != UniLib::RegexMatcher::kNoError) {
230 return false;
231 }
232
233 const int end = matcher.End(&status);
234 if (status != UniLib::RegexMatcher::kNoError) {
235 return false;
236 }
237
238 DatetimeParseResultSpan parse_result;
239 std::vector<DatetimeParseResult> alternatives;
240 if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
241 reference_locale, locale_id, &alternatives,
242 &parse_result.span)) {
243 return false;
244 }
245
246 if (!use_extractors_for_locating_) {
247 parse_result.span = {start, end};
248 }
249
250 if (parse_result.span.first != kInvalidIndex &&
251 parse_result.span.second != kInvalidIndex) {
252 parse_result.target_classification_score =
253 rule.pattern->target_classification_score();
254 parse_result.priority_score = rule.pattern->priority_score();
255
256 for (DatetimeParseResult& alternative : alternatives) {
257 parse_result.data.push_back(alternative);
258 }
259 }
260 result->push_back(parse_result);
261 return true;
262 }
263
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end,std::vector<DatetimeParseResultSpan> * result) const264 bool DatetimeParser::ParseWithRule(
265 const CompiledRule& rule, const UnicodeText& input,
266 const int64 reference_time_ms_utc, const std::string& reference_timezone,
267 const std::string& reference_locale, const int locale_id,
268 bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
269 std::unique_ptr<UniLib::RegexMatcher> matcher =
270 rule.compiled_regex->Matcher(input);
271 int status = UniLib::RegexMatcher::kNoError;
272 if (anchor_start_end) {
273 if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
274 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
275 reference_timezone, reference_locale, locale_id,
276 result)) {
277 return false;
278 }
279 }
280 } else {
281 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
282 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
283 reference_timezone, reference_locale, locale_id,
284 result)) {
285 return false;
286 }
287 }
288 }
289 return true;
290 }
291
ParseAndExpandLocales(const std::string & locales,std::string * reference_locale) const292 std::vector<int> DatetimeParser::ParseAndExpandLocales(
293 const std::string& locales, std::string* reference_locale) const {
294 std::vector<StringPiece> split_locales = strings::Split(locales, ',');
295 if (!split_locales.empty()) {
296 *reference_locale = split_locales[0].ToString();
297 } else {
298 *reference_locale = "";
299 }
300
301 std::vector<int> result;
302 for (const StringPiece& locale_str : split_locales) {
303 auto locale_it = locale_string_to_id_.find(locale_str.ToString());
304 if (locale_it != locale_string_to_id_.end()) {
305 result.push_back(locale_it->second);
306 }
307
308 const Locale locale = Locale::FromBCP47(locale_str.ToString());
309 if (!locale.IsValid()) {
310 continue;
311 }
312
313 const std::string language = locale.Language();
314 const std::string script = locale.Script();
315 const std::string region = locale.Region();
316
317 // First, try adding *-region locale.
318 if (!region.empty()) {
319 locale_it = locale_string_to_id_.find("*-" + region);
320 if (locale_it != locale_string_to_id_.end()) {
321 result.push_back(locale_it->second);
322 }
323 }
324 // Second, try adding language-script-* locale.
325 if (!script.empty()) {
326 locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
327 if (locale_it != locale_string_to_id_.end()) {
328 result.push_back(locale_it->second);
329 }
330 }
331 // Third, try adding language-* locale.
332 if (!language.empty()) {
333 locale_it = locale_string_to_id_.find(language + "-*");
334 if (locale_it != locale_string_to_id_.end()) {
335 result.push_back(locale_it->second);
336 }
337 }
338 }
339
340 // Add the default locales if they haven't been added already.
341 const std::unordered_set<int> result_set(result.begin(), result.end());
342 for (const int default_locale_id : default_locale_ids_) {
343 if (result_set.find(default_locale_id) == result_set.end()) {
344 result.push_back(default_locale_id);
345 }
346 }
347
348 return result;
349 }
350
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const351 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
352 const UniLib::RegexMatcher& matcher,
353 const int64 reference_time_ms_utc,
354 const std::string& reference_timezone,
355 const std::string& reference_locale,
356 int locale_id,
357 std::vector<DatetimeParseResult>* results,
358 CodepointSpan* result_span) const {
359 DatetimeParsedData parse;
360 DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
361 extractor_rules_,
362 type_and_locale_to_extractor_rule_);
363 if (!extractor.Extract(&parse, result_span)) {
364 return false;
365 }
366 std::vector<DatetimeParsedData> interpretations;
367 if (generate_alternative_interpretations_when_ambiguous_) {
368 FillInterpretations(parse, calendarlib_.GetGranularity(parse),
369 &interpretations);
370 } else {
371 interpretations.push_back(parse);
372 }
373
374 results->reserve(results->size() + interpretations.size());
375 for (const DatetimeParsedData& interpretation : interpretations) {
376 std::vector<DatetimeComponent> date_components;
377 interpretation.GetDatetimeComponents(&date_components);
378 DatetimeParseResult result;
379 // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM”
380 // which is encoded in the pair of DatetimeParseResult; both
381 // corresponding to the same date, but one corresponding to
382 // “AM” and the other one corresponding to “PM”.
383 // Remove multiple DatetimeParseResult per datetime span,
384 // once the ambiguities/DatetimeComponents are added in the
385 // response. For Details see b/130355975
386 if (!calendarlib_.InterpretParseData(
387 interpretation, reference_time_ms_utc, reference_timezone,
388 reference_locale, prefer_future_for_unspecified_date_,
389 &(result.time_ms_utc), &(result.granularity))) {
390 return false;
391 }
392
393 // Sort the date time units by component type.
394 std::sort(date_components.begin(), date_components.end(),
395 [](DatetimeComponent a, DatetimeComponent b) {
396 return a.component_type > b.component_type;
397 });
398 result.datetime_components.swap(date_components);
399 results->push_back(result);
400 }
401 return true;
402 }
403
404 } // namespace libtextclassifier3
405