• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/datetime/grammar-parser.h"
18 
19 #include <set>
20 #include <unordered_set>
21 
22 #include "annotator/datetime/datetime-grounder.h"
23 #include "annotator/model_generated.h"
24 #include "annotator/types.h"
25 #include "utils/grammar/analyzer.h"
26 #include "utils/grammar/evaluated-derivation.h"
27 #include "utils/grammar/parsing/derivation.h"
28 
29 using ::libtextclassifier3::grammar::EvaluatedDerivation;
30 using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
31 
32 namespace libtextclassifier3 {
33 
GrammarDatetimeParser(const grammar::Analyzer & analyzer,const DatetimeGrounder & datetime_grounder,const float target_classification_score,const float priority_score,ModeFlag enabled_modes)34 GrammarDatetimeParser::GrammarDatetimeParser(
35     const grammar::Analyzer& analyzer,
36     const DatetimeGrounder& datetime_grounder,
37     const float target_classification_score, const float priority_score,
38     ModeFlag enabled_modes)
39     : analyzer_(analyzer),
40       datetime_grounder_(datetime_grounder),
41       target_classification_score_(target_classification_score),
42       priority_score_(priority_score),
43       enabled_modes_(enabled_modes) {}
44 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const45 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
46     const std::string& input, const int64 reference_time_ms_utc,
47     const std::string& reference_timezone, const LocaleList& locale_list,
48     ModeFlag mode, AnnotationUsecase annotation_usecase,
49     bool anchor_start_end) const {
50   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
51                reference_time_ms_utc, reference_timezone, locale_list, mode,
52                annotation_usecase, anchor_start_end);
53 }
54 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const55 StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
56     const UnicodeText& input, const int64 reference_time_ms_utc,
57     const std::string& reference_timezone, const LocaleList& locale_list,
58     ModeFlag mode, AnnotationUsecase annotation_usecase,
59     bool anchor_start_end) const {
60   if (!(enabled_modes_ & mode)) {
61     return std::vector<DatetimeParseResultSpan>();
62   }
63 
64   std::vector<DatetimeParseResultSpan> results;
65   UnsafeArena arena(/*block_size=*/16 << 10);
66   std::vector<Locale> locales = locale_list.GetLocales();
67   // If the locale list is empty then datetime regex expression will still
68   // execute but in grammar based parser the rules are associated with local
69   // and engine will not run if the locale list is empty. In an unlikely
70   // scenario when locale is not mentioned fallback to en-*.
71   if (locales.empty()) {
72     locales.emplace_back(Locale::FromBCP47("en"));
73   }
74   TC3_ASSIGN_OR_RETURN(
75       const std::vector<EvaluatedDerivation> evaluated_derivations,
76       analyzer_.Parse(input, locales, &arena,
77                       /*deduplicate_derivations=*/false));
78 
79   std::vector<EvaluatedDerivation> valid_evaluated_derivations;
80   for (const EvaluatedDerivation& evaluated_derivation :
81        evaluated_derivations) {
82     if (evaluated_derivation.value) {
83       if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
84         const UngroundedDatetime* ungrounded_datetime =
85             evaluated_derivation.value->Table<UngroundedDatetime>();
86         if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
87           valid_evaluated_derivations.emplace_back(evaluated_derivation);
88         }
89       }
90     }
91   }
92   valid_evaluated_derivations =
93       grammar::DeduplicateDerivations(valid_evaluated_derivations);
94   for (const EvaluatedDerivation& evaluated_derivation :
95        valid_evaluated_derivations) {
96     if (evaluated_derivation.value) {
97       if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
98         const UngroundedDatetime* ungrounded_datetime =
99             evaluated_derivation.value->Table<UngroundedDatetime>();
100         if ((ungrounded_datetime->annotation_usecases() &
101              (1 << annotation_usecase)) == 0) {
102           continue;
103         }
104         const StatusOr<std::vector<DatetimeParseResult>>&
105             datetime_parse_results = datetime_grounder_.Ground(
106                 reference_time_ms_utc, reference_timezone,
107                 locale_list.GetReferenceLocale(), ungrounded_datetime);
108         TC3_ASSIGN_OR_RETURN(
109             const std::vector<DatetimeParseResult>& parse_datetime,
110             datetime_parse_results);
111         DatetimeParseResultSpan datetime_parse_result_span;
112         datetime_parse_result_span.target_classification_score =
113             target_classification_score_;
114         datetime_parse_result_span.priority_score = priority_score_;
115         datetime_parse_result_span.data.reserve(parse_datetime.size());
116         datetime_parse_result_span.data.insert(
117             datetime_parse_result_span.data.end(), parse_datetime.begin(),
118             parse_datetime.end());
119         datetime_parse_result_span.span =
120             evaluated_derivation.parse_tree->codepoint_span;
121 
122         results.emplace_back(datetime_parse_result_span);
123       }
124     }
125   }
126   return results;
127 }
128 }  // namespace libtextclassifier3
129