• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/regex-actions.h"
18 
19 #include "actions/utils.h"
20 #include "utils/base/logging.h"
21 #include "utils/regex-match.h"
22 #include "utils/utf8/unicodetext.h"
23 #include "utils/zlib/zlib_regex.h"
24 
25 namespace libtextclassifier3 {
26 namespace {
27 
28 // Creates an annotation from a regex capturing group.
FillAnnotationFromMatchGroup(const UniLib::RegexMatcher * matcher,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const std::string & group_match_text,const int message_index,ActionSuggestionAnnotation * annotation)29 bool FillAnnotationFromMatchGroup(
30     const UniLib::RegexMatcher* matcher,
31     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
32     const std::string& group_match_text, const int message_index,
33     ActionSuggestionAnnotation* annotation) {
34   if (group->annotation_name() != nullptr ||
35       group->annotation_type() != nullptr) {
36     int status = UniLib::RegexMatcher::kNoError;
37     const CodepointSpan span = {matcher->Start(group->group_id(), &status),
38                                 matcher->End(group->group_id(), &status)};
39     if (status != UniLib::RegexMatcher::kNoError) {
40       TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
41       return false;
42     }
43     return FillAnnotationFromCapturingMatch(span, group, message_index,
44                                             group_match_text, annotation);
45   }
46   return true;
47 }
48 
49 }  // namespace
50 
InitializeRules(const RulesModel * rules,const RulesModel * low_confidence_rules,const TriggeringPreconditions * triggering_preconditions_overlay,ZlibDecompressor * decompressor)51 bool RegexActions::InitializeRules(
52     const RulesModel* rules, const RulesModel* low_confidence_rules,
53     const TriggeringPreconditions* triggering_preconditions_overlay,
54     ZlibDecompressor* decompressor) {
55   if (rules != nullptr) {
56     if (!InitializeRulesModel(rules, decompressor, &rules_)) {
57       TC3_LOG(ERROR) << "Could not initialize action rules.";
58       return false;
59     }
60   }
61 
62   if (low_confidence_rules != nullptr) {
63     if (!InitializeRulesModel(low_confidence_rules, decompressor,
64                               &low_confidence_rules_)) {
65       TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
66       return false;
67     }
68   }
69 
70   // Extend by rules provided by the overwrite.
71   // NOTE: The rules from the original models are *not* cleared.
72   if (triggering_preconditions_overlay != nullptr &&
73       triggering_preconditions_overlay->low_confidence_rules() != nullptr) {
74     // These rules are optionally compressed, but separately.
75     std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
76         ZlibDecompressor::Instance();
77     if (overwrite_decompressor == nullptr) {
78       TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
79       return false;
80     }
81     if (!InitializeRulesModel(
82             triggering_preconditions_overlay->low_confidence_rules(),
83             overwrite_decompressor.get(), &low_confidence_rules_)) {
84       TC3_LOG(ERROR)
85           << "Could not initialize low confidence rules from overwrite.";
86       return false;
87     }
88   }
89 
90   return true;
91 }
92 
InitializeRulesModel(const RulesModel * rules,ZlibDecompressor * decompressor,std::vector<CompiledRule> * compiled_rules) const93 bool RegexActions::InitializeRulesModel(
94     const RulesModel* rules, ZlibDecompressor* decompressor,
95     std::vector<CompiledRule>* compiled_rules) const {
96   for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
97     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
98         UncompressMakeRegexPattern(
99             unilib_, rule->pattern(), rule->compressed_pattern(),
100             rules->lazy_regex_compilation(), decompressor);
101     if (compiled_pattern == nullptr) {
102       TC3_LOG(ERROR) << "Failed to load rule pattern.";
103       return false;
104     }
105 
106     // Check whether there is a check on the output.
107     std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
108     if (rule->output_pattern() != nullptr ||
109         rule->compressed_output_pattern() != nullptr) {
110       compiled_output_pattern = UncompressMakeRegexPattern(
111           unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
112           rules->lazy_regex_compilation(), decompressor);
113       if (compiled_output_pattern == nullptr) {
114         TC3_LOG(ERROR) << "Failed to load rule output pattern.";
115         return false;
116       }
117     }
118 
119     compiled_rules->emplace_back(rule, std::move(compiled_pattern),
120                                  std::move(compiled_output_pattern));
121   }
122 
123   return true;
124 }
125 
IsLowConfidenceInput(const Conversation & conversation,const int num_messages,std::vector<const UniLib::RegexPattern * > * post_check_rules) const126 bool RegexActions::IsLowConfidenceInput(
127     const Conversation& conversation, const int num_messages,
128     std::vector<const UniLib::RegexPattern*>* post_check_rules) const {
129   for (int i = 1; i <= num_messages; i++) {
130     const std::string& message =
131         conversation.messages[conversation.messages.size() - i].text;
132     const UnicodeText message_unicode(
133         UTF8ToUnicodeText(message, /*do_copy=*/false));
134     for (int low_confidence_rule = 0;
135          low_confidence_rule < low_confidence_rules_.size();
136          low_confidence_rule++) {
137       const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
138       const std::unique_ptr<UniLib::RegexMatcher> matcher =
139           rule.pattern->Matcher(message_unicode);
140       int status = UniLib::RegexMatcher::kNoError;
141       if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
142         // Rule only applies to input-output pairs, so defer the check.
143         if (rule.output_pattern != nullptr) {
144           post_check_rules->push_back(rule.output_pattern.get());
145           continue;
146         }
147         return true;
148       }
149     }
150   }
151   return false;
152 }
153 
FilterConfidenceOutput(const std::vector<const UniLib::RegexPattern * > & post_check_rules,std::vector<ActionSuggestion> * actions) const154 bool RegexActions::FilterConfidenceOutput(
155     const std::vector<const UniLib::RegexPattern*>& post_check_rules,
156     std::vector<ActionSuggestion>* actions) const {
157   if (post_check_rules.empty() || actions->empty()) {
158     return true;
159   }
160   std::vector<ActionSuggestion> filtered_text_replies;
161   for (const ActionSuggestion& action : *actions) {
162     if (action.response_text.empty()) {
163       filtered_text_replies.push_back(action);
164       continue;
165     }
166     bool passes_post_check = true;
167     const UnicodeText text_reply_unicode(
168         UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
169     for (const UniLib::RegexPattern* post_check_rule : post_check_rules) {
170       const std::unique_ptr<UniLib::RegexMatcher> matcher =
171           post_check_rule->Matcher(text_reply_unicode);
172       if (matcher == nullptr) {
173         TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
174         return false;
175       }
176       int status = UniLib::RegexMatcher::kNoError;
177       if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
178         passes_post_check = false;
179         break;
180       }
181     }
182     if (passes_post_check) {
183       filtered_text_replies.push_back(action);
184     }
185   }
186   *actions = std::move(filtered_text_replies);
187   return true;
188 }
189 
SuggestActions(const Conversation & conversation,const ReflectiveFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * actions) const190 bool RegexActions::SuggestActions(
191     const Conversation& conversation,
192     const ReflectiveFlatbufferBuilder* entity_data_builder,
193     std::vector<ActionSuggestion>* actions) const {
194   // Create actions based on rules checking the last message.
195   const int message_index = conversation.messages.size() - 1;
196   const std::string& message = conversation.messages.back().text;
197   const UnicodeText message_unicode(
198       UTF8ToUnicodeText(message, /*do_copy=*/false));
199   for (const CompiledRule& rule : rules_) {
200     const std::unique_ptr<UniLib::RegexMatcher> matcher =
201         rule.pattern->Matcher(message_unicode);
202     int status = UniLib::RegexMatcher::kNoError;
203     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
204       for (const RulesModel_::RuleActionSpec* rule_action :
205            *rule.rule->actions()) {
206         const ActionSuggestionSpec* action = rule_action->action();
207         std::vector<ActionSuggestionAnnotation> annotations;
208 
209         std::unique_ptr<ReflectiveFlatbuffer> entity_data =
210             entity_data_builder != nullptr ? entity_data_builder->NewRoot()
211                                            : nullptr;
212 
213         // Add entity data from rule capturing groups.
214         if (rule_action->capturing_group() != nullptr) {
215           for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
216                *rule_action->capturing_group()) {
217             Optional<std::string> group_match_text =
218                 GetCapturingGroupText(matcher.get(), group->group_id());
219             if (!group_match_text.has_value()) {
220               // The group was not part of the match, ignore and continue.
221               continue;
222             }
223 
224             UnicodeText normalized_group_match_text =
225                 NormalizeMatchText(unilib_, group, group_match_text.value());
226 
227             if (!MergeEntityDataFromCapturingMatch(
228                     group, normalized_group_match_text.ToUTF8String(),
229                     entity_data.get())) {
230               TC3_LOG(ERROR)
231                   << "Could not merge entity data from a capturing match.";
232               return false;
233             }
234 
235             // Create a text annotation for the group span.
236             ActionSuggestionAnnotation annotation;
237             if (FillAnnotationFromMatchGroup(matcher.get(), group,
238                                              group_match_text.value(),
239                                              message_index, &annotation)) {
240               annotations.push_back(annotation);
241             }
242 
243             // Create text reply.
244             SuggestTextRepliesFromCapturingMatch(
245                 entity_data_builder, group, normalized_group_match_text,
246                 smart_reply_action_type_, actions);
247           }
248         }
249 
250         if (action != nullptr) {
251           ActionSuggestion suggestion;
252           suggestion.annotations = annotations;
253           FillSuggestionFromSpec(action, entity_data.get(), &suggestion);
254           actions->push_back(suggestion);
255         }
256       }
257     }
258   }
259   return true;
260 }
261 
262 }  // namespace libtextclassifier3
263