1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/datetime/grammar-parser.h"
18
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
21
22 #include "annotator/datetime/datetime-grounder.h"
23 #include "annotator/types.h"
24 #include "utils/grammar/analyzer.h"
25 #include "utils/grammar/evaluated-derivation.h"
26 #include "utils/grammar/parsing/derivation.h"
27
28 using ::libtextclassifier3::grammar::EvaluatedDerivation;
29 using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
30
31 namespace libtextclassifier3 {
32
GrammarDatetimeParser(const grammar::Analyzer & analyzer,const DatetimeGrounder & datetime_grounder,const float target_classification_score,const float priority_score)33 GrammarDatetimeParser::GrammarDatetimeParser(
34 const grammar::Analyzer& analyzer,
35 const DatetimeGrounder& datetime_grounder,
36 const float target_classification_score, const float priority_score)
37 : analyzer_(analyzer),
38 datetime_grounder_(datetime_grounder),
39 target_classification_score_(target_classification_score),
40 priority_score_(priority_score) {}
41
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const42 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
43 const std::string& input, const int64 reference_time_ms_utc,
44 const std::string& reference_timezone, const LocaleList& locale_list,
45 ModeFlag mode, AnnotationUsecase annotation_usecase,
46 bool anchor_start_end) const {
47 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
48 reference_time_ms_utc, reference_timezone, locale_list, mode,
49 annotation_usecase, anchor_start_end);
50 }
51
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const52 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
53 const UnicodeText& input, const int64 reference_time_ms_utc,
54 const std::string& reference_timezone, const LocaleList& locale_list,
55 ModeFlag mode, AnnotationUsecase annotation_usecase,
56 bool anchor_start_end) const {
57 std::vector<DatetimeParseResultSpan> results;
58 UnsafeArena arena(/*block_size=*/16 << 10);
59 std::vector<Locale> locales = locale_list.GetLocales();
60 // If the locale list is empty then datetime regex expression will still
61 // execute but in grammar based parser the rules are associated with local
62 // and engine will not run if the locale list is empty. In an unlikely
63 // scenario when locale is not mentioned fallback to en-*.
64 if (locales.empty()) {
65 locales.emplace_back(Locale::FromBCP47("en"));
66 }
67 TC3_ASSIGN_OR_RETURN(
68 const std::vector<EvaluatedDerivation> evaluated_derivations,
69 analyzer_.Parse(input, locales, &arena,
70 /*deduplicate_derivations=*/false));
71
72 std::vector<EvaluatedDerivation> valid_evaluated_derivations;
73 for (const EvaluatedDerivation& evaluated_derivation :
74 evaluated_derivations) {
75 if (evaluated_derivation.value) {
76 if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
77 const UngroundedDatetime* ungrounded_datetime =
78 evaluated_derivation.value->Table<UngroundedDatetime>();
79 if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
80 valid_evaluated_derivations.emplace_back(evaluated_derivation);
81 }
82 }
83 }
84 }
85 valid_evaluated_derivations =
86 grammar::DeduplicateDerivations(valid_evaluated_derivations);
87 for (const EvaluatedDerivation& evaluated_derivation :
88 valid_evaluated_derivations) {
89 if (evaluated_derivation.value) {
90 if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
91 const UngroundedDatetime* ungrounded_datetime =
92 evaluated_derivation.value->Table<UngroundedDatetime>();
93 if ((ungrounded_datetime->annotation_usecases() &
94 (1 << annotation_usecase)) == 0) {
95 continue;
96 }
97 const StatusOr<std::vector<DatetimeParseResult>>&
98 datetime_parse_results = datetime_grounder_.Ground(
99 reference_time_ms_utc, reference_timezone,
100 locale_list.GetReferenceLocale(), ungrounded_datetime);
101 TC3_ASSIGN_OR_RETURN(
102 const std::vector<DatetimeParseResult>& parse_datetime,
103 datetime_parse_results);
104 DatetimeParseResultSpan datetime_parse_result_span;
105 datetime_parse_result_span.target_classification_score =
106 target_classification_score_;
107 datetime_parse_result_span.priority_score = priority_score_;
108 datetime_parse_result_span.data.reserve(parse_datetime.size());
109 datetime_parse_result_span.data.insert(
110 datetime_parse_result_span.data.end(), parse_datetime.begin(),
111 parse_datetime.end());
112 datetime_parse_result_span.span =
113 evaluated_derivation.parse_tree->codepoint_span;
114
115 results.emplace_back(datetime_parse_result_span);
116 }
117 }
118 }
119 return results;
120 }
121 } // namespace libtextclassifier3
122