1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/datetime/regex-parser.h"
18
19 #include <iterator>
20 #include <set>
21 #include <unordered_set>
22
23 #include "annotator/datetime/extractor.h"
24 #include "annotator/datetime/utils.h"
25 #include "utils/base/statusor.h"
26 #include "utils/calendar/calendar.h"
27 #include "utils/i18n/locale.h"
28 #include "utils/strings/split.h"
29 #include "utils/zlib/zlib_regex.h"
30
31 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)32 std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
33 const DatetimeModel* model, const UniLib* unilib,
34 const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
35 std::unique_ptr<RegexDatetimeParser> result(
36 new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
37 if (!result->initialized_) {
38 result.reset();
39 }
40 return result;
41 }
42
RegexDatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)43 RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
44 const UniLib* unilib,
45 const CalendarLib* calendarlib,
46 ZlibDecompressor* decompressor)
47 : unilib_(*unilib), calendarlib_(*calendarlib) {
48 initialized_ = false;
49
50 if (model == nullptr) {
51 return;
52 }
53
54 if (model->patterns() != nullptr) {
55 for (const DatetimeModelPattern* pattern : *model->patterns()) {
56 if (pattern->regexes()) {
57 for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
58 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
59 UncompressMakeRegexPattern(
60 unilib_, regex->pattern(), regex->compressed_pattern(),
61 model->lazy_regex_compilation(), decompressor);
62 if (!regex_pattern) {
63 TC3_LOG(ERROR) << "Couldn't create rule pattern.";
64 return;
65 }
66 rules_.push_back({std::move(regex_pattern), regex, pattern});
67 if (pattern->locales()) {
68 for (int locale : *pattern->locales()) {
69 locale_to_rules_[locale].push_back(rules_.size() - 1);
70 }
71 }
72 }
73 }
74 }
75 }
76
77 if (model->extractors() != nullptr) {
78 for (const DatetimeModelExtractor* extractor : *model->extractors()) {
79 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
80 UncompressMakeRegexPattern(
81 unilib_, extractor->pattern(), extractor->compressed_pattern(),
82 model->lazy_regex_compilation(), decompressor);
83 if (!regex_pattern) {
84 TC3_LOG(ERROR) << "Couldn't create extractor pattern";
85 return;
86 }
87 extractor_rules_.push_back(std::move(regex_pattern));
88
89 if (extractor->locales()) {
90 for (int locale : *extractor->locales()) {
91 type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
92 extractor_rules_.size() - 1;
93 }
94 }
95 }
96 }
97
98 if (model->locales() != nullptr) {
99 for (int i = 0; i < model->locales()->size(); ++i) {
100 locale_string_to_id_[model->locales()->Get(i)->str()] = i;
101 }
102 }
103
104 if (model->default_locales() != nullptr) {
105 for (const int locale : *model->default_locales()) {
106 default_locale_ids_.push_back(locale);
107 }
108 }
109
110 use_extractors_for_locating_ = model->use_extractors_for_locating();
111 generate_alternative_interpretations_when_ambiguous_ =
112 model->generate_alternative_interpretations_when_ambiguous();
113 prefer_future_for_unspecified_date_ =
114 model->prefer_future_for_unspecified_date();
115
116 initialized_ = true;
117 }
118
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const119 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
120 const std::string& input, const int64 reference_time_ms_utc,
121 const std::string& reference_timezone, const LocaleList& locale_list,
122 ModeFlag mode, AnnotationUsecase annotation_usecase,
123 bool anchor_start_end) const {
124 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
125 reference_time_ms_utc, reference_timezone, locale_list, mode,
126 annotation_usecase, anchor_start_end);
127 }
128
129 StatusOr<std::vector<DatetimeParseResultSpan>>
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules) const130 RegexDatetimeParser::FindSpansUsingLocales(
131 const std::vector<int>& locale_ids, const UnicodeText& input,
132 const int64 reference_time_ms_utc, const std::string& reference_timezone,
133 ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
134 const std::string& reference_locale,
135 std::unordered_set<int>* executed_rules) const {
136 std::vector<DatetimeParseResultSpan> found_spans;
137 for (const int locale_id : locale_ids) {
138 auto rules_it = locale_to_rules_.find(locale_id);
139 if (rules_it == locale_to_rules_.end()) {
140 continue;
141 }
142
143 for (const int rule_id : rules_it->second) {
144 // Skip rules that were already executed in previous locales.
145 if (executed_rules->find(rule_id) != executed_rules->end()) {
146 continue;
147 }
148
149 if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
150 (1 << annotation_usecase)) == 0) {
151 continue;
152 }
153
154 if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
155 continue;
156 }
157
158 executed_rules->insert(rule_id);
159 TC3_ASSIGN_OR_RETURN(
160 const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
161 ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
162 reference_timezone, reference_locale, locale_id,
163 anchor_start_end));
164 found_spans.insert(std::end(found_spans),
165 std::begin(found_spans_per_rule),
166 std::end(found_spans_per_rule));
167 }
168 }
169 return found_spans;
170 }
171
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const172 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
173 const UnicodeText& input, const int64 reference_time_ms_utc,
174 const std::string& reference_timezone, const LocaleList& locale_list,
175 ModeFlag mode, AnnotationUsecase annotation_usecase,
176 bool anchor_start_end) const {
177 std::unordered_set<int> executed_rules;
178 const std::vector<int> requested_locales =
179 ParseAndExpandLocales(locale_list.GetLocaleTags());
180 TC3_ASSIGN_OR_RETURN(
181 const std::vector<DatetimeParseResultSpan>& found_spans,
182 FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
183 reference_timezone, mode, annotation_usecase,
184 anchor_start_end, locale_list.GetReferenceLocale(),
185 &executed_rules));
186 std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
187 indexed_found_spans.reserve(found_spans.size());
188 for (int i = 0; i < found_spans.size(); i++) {
189 indexed_found_spans.push_back({found_spans[i], i});
190 }
191
192 // Resolve conflicts by always picking the longer span and breaking ties by
193 // selecting the earlier entry in the list for a given locale.
194 std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
195 [](const std::pair<DatetimeParseResultSpan, int>& a,
196 const std::pair<DatetimeParseResultSpan, int>& b) {
197 if ((a.first.span.second - a.first.span.first) !=
198 (b.first.span.second - b.first.span.first)) {
199 return (a.first.span.second - a.first.span.first) >
200 (b.first.span.second - b.first.span.first);
201 } else {
202 return a.second < b.second;
203 }
204 });
205
206 std::vector<DatetimeParseResultSpan> results;
207 std::vector<DatetimeParseResultSpan> resolved_found_spans;
208 resolved_found_spans.reserve(indexed_found_spans.size());
209 for (auto& span_index_pair : indexed_found_spans) {
210 resolved_found_spans.push_back(span_index_pair.first);
211 }
212
213 std::set<int, std::function<bool(int, int)>> chosen_indices_set(
214 [&resolved_found_spans](int a, int b) {
215 return resolved_found_spans[a].span.first <
216 resolved_found_spans[b].span.first;
217 });
218 for (int i = 0; i < resolved_found_spans.size(); ++i) {
219 if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
220 chosen_indices_set.insert(i);
221 results.push_back(resolved_found_spans[i]);
222 }
223 }
224 return results;
225 }
226
227 StatusOr<std::vector<DatetimeParseResultSpan>>
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id) const228 RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
229 const UniLib::RegexMatcher& matcher,
230 int64 reference_time_ms_utc,
231 const std::string& reference_timezone,
232 const std::string& reference_locale,
233 int locale_id) const {
234 std::vector<DatetimeParseResultSpan> results;
235 int status = UniLib::RegexMatcher::kNoError;
236 const int start = matcher.Start(&status);
237 if (status != UniLib::RegexMatcher::kNoError) {
238 return Status(StatusCode::INTERNAL,
239 "Failed to gets the start offset of the last match.");
240 }
241
242 const int end = matcher.End(&status);
243 if (status != UniLib::RegexMatcher::kNoError) {
244 return Status(StatusCode::INTERNAL,
245 "Failed to gets the end offset of the last match.");
246 }
247
248 DatetimeParseResultSpan parse_result;
249 std::vector<DatetimeParseResult> alternatives;
250 if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
251 reference_locale, locale_id, &alternatives,
252 &parse_result.span)) {
253 return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
254 }
255
256 if (!use_extractors_for_locating_) {
257 parse_result.span = {start, end};
258 }
259
260 if (parse_result.span.first != kInvalidIndex &&
261 parse_result.span.second != kInvalidIndex) {
262 parse_result.target_classification_score =
263 rule.pattern->target_classification_score();
264 parse_result.priority_score = rule.pattern->priority_score();
265
266 for (DatetimeParseResult& alternative : alternatives) {
267 parse_result.data.push_back(alternative);
268 }
269 }
270 results.push_back(parse_result);
271 return results;
272 }
273
274 StatusOr<std::vector<DatetimeParseResultSpan>>
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end) const275 RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
276 const UnicodeText& input,
277 const int64 reference_time_ms_utc,
278 const std::string& reference_timezone,
279 const std::string& reference_locale,
280 const int locale_id,
281 bool anchor_start_end) const {
282 std::vector<DatetimeParseResultSpan> results;
283 std::unique_ptr<UniLib::RegexMatcher> matcher =
284 rule.compiled_regex->Matcher(input);
285 int status = UniLib::RegexMatcher::kNoError;
286 if (anchor_start_end) {
287 if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
288 return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
289 reference_timezone, reference_locale, locale_id);
290 }
291 } else {
292 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
293 TC3_ASSIGN_OR_RETURN(
294 const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
295 HandleParseMatch(rule, *matcher, reference_time_ms_utc,
296 reference_timezone, reference_locale, locale_id));
297 results.insert(std::end(results), std::begin(pattern_occurrence),
298 std::end(pattern_occurrence));
299 }
300 }
301 return results;
302 }
303
ParseAndExpandLocales(const std::vector<StringPiece> & locales) const304 std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
305 const std::vector<StringPiece>& locales) const {
306 std::vector<int> result;
307 for (const StringPiece& locale_str : locales) {
308 auto locale_it = locale_string_to_id_.find(locale_str.ToString());
309 if (locale_it != locale_string_to_id_.end()) {
310 result.push_back(locale_it->second);
311 }
312
313 const Locale locale = Locale::FromBCP47(locale_str.ToString());
314 if (!locale.IsValid()) {
315 continue;
316 }
317
318 const std::string language = locale.Language();
319 const std::string script = locale.Script();
320 const std::string region = locale.Region();
321
322 // First, try adding *-region locale.
323 if (!region.empty()) {
324 locale_it = locale_string_to_id_.find("*-" + region);
325 if (locale_it != locale_string_to_id_.end()) {
326 result.push_back(locale_it->second);
327 }
328 }
329 // Second, try adding language-script-* locale.
330 if (!script.empty()) {
331 locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
332 if (locale_it != locale_string_to_id_.end()) {
333 result.push_back(locale_it->second);
334 }
335 }
336 // Third, try adding language-* locale.
337 if (!language.empty()) {
338 locale_it = locale_string_to_id_.find(language + "-*");
339 if (locale_it != locale_string_to_id_.end()) {
340 result.push_back(locale_it->second);
341 }
342 }
343 }
344
345 // Add the default locales if they haven't been added already.
346 const std::unordered_set<int> result_set(result.begin(), result.end());
347 for (const int default_locale_id : default_locale_ids_) {
348 if (result_set.find(default_locale_id) == result_set.end()) {
349 result.push_back(default_locale_id);
350 }
351 }
352
353 return result;
354 }
355
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const356 bool RegexDatetimeParser::ExtractDatetime(
357 const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
358 const int64 reference_time_ms_utc, const std::string& reference_timezone,
359 const std::string& reference_locale, int locale_id,
360 std::vector<DatetimeParseResult>* results,
361 CodepointSpan* result_span) const {
362 DatetimeParsedData parse;
363 DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
364 extractor_rules_,
365 type_and_locale_to_extractor_rule_);
366 if (!extractor.Extract(&parse, result_span)) {
367 return false;
368 }
369 std::vector<DatetimeParsedData> interpretations;
370 if (generate_alternative_interpretations_when_ambiguous_) {
371 FillInterpretations(parse, calendarlib_.GetGranularity(parse),
372 &interpretations);
373 } else {
374 interpretations.push_back(parse);
375 }
376
377 results->reserve(results->size() + interpretations.size());
378 for (const DatetimeParsedData& interpretation : interpretations) {
379 std::vector<DatetimeComponent> date_components;
380 interpretation.GetDatetimeComponents(&date_components);
381 DatetimeParseResult result;
382 // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM”
383 // which is encoded in the pair of DatetimeParseResult; both
384 // corresponding to the same date, but one corresponding to
385 // “AM” and the other one corresponding to “PM”.
386 // Remove multiple DatetimeParseResult per datetime span,
387 // once the ambiguities/DatetimeComponents are added in the
388 // response. For Details see b/130355975
389 if (!calendarlib_.InterpretParseData(
390 interpretation, reference_time_ms_utc, reference_timezone,
391 reference_locale, prefer_future_for_unspecified_date_,
392 &(result.time_ms_utc), &(result.granularity))) {
393 return false;
394 }
395
396 // Sort the date time units by component type.
397 std::sort(date_components.begin(), date_components.end(),
398 [](DatetimeComponent a, DatetimeComponent b) {
399 return a.component_type > b.component_type;
400 });
401 result.datetime_components.swap(date_components);
402 results->push_back(result);
403 }
404 return true;
405 }
406
407 } // namespace libtextclassifier3
408