1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/datetime/grammar-parser.h"
18
#include <set>
#include <unordered_set>
#include <utility>
21
22 #include "annotator/datetime/datetime-grounder.h"
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/grammar/analyzer.h"
26 #include "utils/grammar/evaluated-derivation.h"
27 #include "utils/grammar/parsing/derivation.h"
28
29 using ::libtextclassifier3::grammar::EvaluatedDerivation;
30 using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
31
32 namespace libtextclassifier3 {
33
GrammarDatetimeParser(const grammar::Analyzer & analyzer,const DatetimeGrounder & datetime_grounder,const float target_classification_score,const float priority_score,ModeFlag enabled_modes)34 GrammarDatetimeParser::GrammarDatetimeParser(
35 const grammar::Analyzer& analyzer,
36 const DatetimeGrounder& datetime_grounder,
37 const float target_classification_score, const float priority_score,
38 ModeFlag enabled_modes)
39 : analyzer_(analyzer),
40 datetime_grounder_(datetime_grounder),
41 target_classification_score_(target_classification_score),
42 priority_score_(priority_score),
43 enabled_modes_(enabled_modes) {}
44
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const45 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
46 const std::string& input, const int64 reference_time_ms_utc,
47 const std::string& reference_timezone, const LocaleList& locale_list,
48 ModeFlag mode, AnnotationUsecase annotation_usecase,
49 bool anchor_start_end) const {
50 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
51 reference_time_ms_utc, reference_timezone, locale_list, mode,
52 annotation_usecase, anchor_start_end);
53 }
54
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const55 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
56 const UnicodeText& input, const int64 reference_time_ms_utc,
57 const std::string& reference_timezone, const LocaleList& locale_list,
58 ModeFlag mode, AnnotationUsecase annotation_usecase,
59 bool anchor_start_end) const {
60 if (!(enabled_modes_ & mode)) {
61 return std::vector<DatetimeParseResultSpan>();
62 }
63
64 std::vector<DatetimeParseResultSpan> results;
65 UnsafeArena arena(/*block_size=*/16 << 10);
66 std::vector<Locale> locales = locale_list.GetLocales();
67 // If the locale list is empty then datetime regex expression will still
68 // execute but in grammar based parser the rules are associated with local
69 // and engine will not run if the locale list is empty. In an unlikely
70 // scenario when locale is not mentioned fallback to en-*.
71 if (locales.empty()) {
72 locales.emplace_back(Locale::FromBCP47("en"));
73 }
74 TC3_ASSIGN_OR_RETURN(
75 const std::vector<EvaluatedDerivation> evaluated_derivations,
76 analyzer_.Parse(input, locales, &arena,
77 /*deduplicate_derivations=*/false));
78
79 std::vector<EvaluatedDerivation> valid_evaluated_derivations;
80 for (const EvaluatedDerivation& evaluated_derivation :
81 evaluated_derivations) {
82 if (evaluated_derivation.value) {
83 if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
84 const UngroundedDatetime* ungrounded_datetime =
85 evaluated_derivation.value->Table<UngroundedDatetime>();
86 if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
87 valid_evaluated_derivations.emplace_back(evaluated_derivation);
88 }
89 }
90 }
91 }
92 valid_evaluated_derivations =
93 grammar::DeduplicateDerivations(valid_evaluated_derivations);
94 for (const EvaluatedDerivation& evaluated_derivation :
95 valid_evaluated_derivations) {
96 if (evaluated_derivation.value) {
97 if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
98 const UngroundedDatetime* ungrounded_datetime =
99 evaluated_derivation.value->Table<UngroundedDatetime>();
100 if ((ungrounded_datetime->annotation_usecases() &
101 (1 << annotation_usecase)) == 0) {
102 continue;
103 }
104 const StatusOr<std::vector<DatetimeParseResult>>&
105 datetime_parse_results = datetime_grounder_.Ground(
106 reference_time_ms_utc, reference_timezone,
107 locale_list.GetReferenceLocale(), ungrounded_datetime);
108 TC3_ASSIGN_OR_RETURN(
109 const std::vector<DatetimeParseResult>& parse_datetime,
110 datetime_parse_results);
111 DatetimeParseResultSpan datetime_parse_result_span;
112 datetime_parse_result_span.target_classification_score =
113 target_classification_score_;
114 datetime_parse_result_span.priority_score = priority_score_;
115 datetime_parse_result_span.data.reserve(parse_datetime.size());
116 datetime_parse_result_span.data.insert(
117 datetime_parse_result_span.data.end(), parse_datetime.begin(),
118 parse_datetime.end());
119 datetime_parse_result_span.span =
120 evaluated_derivation.parse_tree->codepoint_span;
121
122 results.emplace_back(datetime_parse_result_span);
123 }
124 }
125 }
126 return results;
127 }
128 } // namespace libtextclassifier3
129