1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/regex-actions.h"
18
19 #include "actions/utils.h"
20 #include "utils/base/logging.h"
21 #include "utils/regex-match.h"
22 #include "utils/utf8/unicodetext.h"
23 #include "utils/zlib/zlib_regex.h"
24
25 namespace libtextclassifier3 {
26 namespace {
27
28 // Creates an annotation from a regex capturing group.
FillAnnotationFromMatchGroup(const UniLib::RegexMatcher * matcher,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const std::string & group_match_text,const int message_index,ActionSuggestionAnnotation * annotation)29 bool FillAnnotationFromMatchGroup(
30 const UniLib::RegexMatcher* matcher,
31 const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
32 const std::string& group_match_text, const int message_index,
33 ActionSuggestionAnnotation* annotation) {
34 if (group->annotation_name() != nullptr ||
35 group->annotation_type() != nullptr) {
36 int status = UniLib::RegexMatcher::kNoError;
37 const CodepointSpan span = {matcher->Start(group->group_id(), &status),
38 matcher->End(group->group_id(), &status)};
39 if (status != UniLib::RegexMatcher::kNoError) {
40 TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
41 return false;
42 }
43 return FillAnnotationFromCapturingMatch(span, group, message_index,
44 group_match_text, annotation);
45 }
46 return true;
47 }
48
49 } // namespace
50
InitializeRules(const RulesModel * rules,const RulesModel * low_confidence_rules,const TriggeringPreconditions * triggering_preconditions_overlay,ZlibDecompressor * decompressor)51 bool RegexActions::InitializeRules(
52 const RulesModel* rules, const RulesModel* low_confidence_rules,
53 const TriggeringPreconditions* triggering_preconditions_overlay,
54 ZlibDecompressor* decompressor) {
55 if (rules != nullptr) {
56 if (!InitializeRulesModel(rules, decompressor, &rules_)) {
57 TC3_LOG(ERROR) << "Could not initialize action rules.";
58 return false;
59 }
60 }
61
62 if (low_confidence_rules != nullptr) {
63 if (!InitializeRulesModel(low_confidence_rules, decompressor,
64 &low_confidence_rules_)) {
65 TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
66 return false;
67 }
68 }
69
70 // Extend by rules provided by the overwrite.
71 // NOTE: The rules from the original models are *not* cleared.
72 if (triggering_preconditions_overlay != nullptr &&
73 triggering_preconditions_overlay->low_confidence_rules() != nullptr) {
74 // These rules are optionally compressed, but separately.
75 std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
76 ZlibDecompressor::Instance();
77 if (overwrite_decompressor == nullptr) {
78 TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
79 return false;
80 }
81 if (!InitializeRulesModel(
82 triggering_preconditions_overlay->low_confidence_rules(),
83 overwrite_decompressor.get(), &low_confidence_rules_)) {
84 TC3_LOG(ERROR)
85 << "Could not initialize low confidence rules from overwrite.";
86 return false;
87 }
88 }
89
90 return true;
91 }
92
InitializeRulesModel(const RulesModel * rules,ZlibDecompressor * decompressor,std::vector<CompiledRule> * compiled_rules) const93 bool RegexActions::InitializeRulesModel(
94 const RulesModel* rules, ZlibDecompressor* decompressor,
95 std::vector<CompiledRule>* compiled_rules) const {
96 if (rules->regex_rule() == nullptr) {
97 return true;
98 }
99 for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
100 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
101 UncompressMakeRegexPattern(
102 unilib_, rule->pattern(), rule->compressed_pattern(),
103 rules->lazy_regex_compilation(), decompressor);
104 if (compiled_pattern == nullptr) {
105 TC3_LOG(ERROR) << "Failed to load rule pattern.";
106 return false;
107 }
108
109 // Check whether there is a check on the output.
110 std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
111 if (rule->output_pattern() != nullptr ||
112 rule->compressed_output_pattern() != nullptr) {
113 compiled_output_pattern = UncompressMakeRegexPattern(
114 unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
115 rules->lazy_regex_compilation(), decompressor);
116 if (compiled_output_pattern == nullptr) {
117 TC3_LOG(ERROR) << "Failed to load rule output pattern.";
118 return false;
119 }
120 }
121
122 compiled_rules->emplace_back(rule, std::move(compiled_pattern),
123 std::move(compiled_output_pattern));
124 }
125
126 return true;
127 }
128
IsLowConfidenceInput(const Conversation & conversation,const int num_messages,std::vector<const UniLib::RegexPattern * > * post_check_rules) const129 bool RegexActions::IsLowConfidenceInput(
130 const Conversation& conversation, const int num_messages,
131 std::vector<const UniLib::RegexPattern*>* post_check_rules) const {
132 for (int i = 1; i <= num_messages; i++) {
133 const std::string& message =
134 conversation.messages[conversation.messages.size() - i].text;
135 const UnicodeText message_unicode(
136 UTF8ToUnicodeText(message, /*do_copy=*/false));
137 for (int low_confidence_rule = 0;
138 low_confidence_rule < low_confidence_rules_.size();
139 low_confidence_rule++) {
140 const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
141 const std::unique_ptr<UniLib::RegexMatcher> matcher =
142 rule.pattern->Matcher(message_unicode);
143 int status = UniLib::RegexMatcher::kNoError;
144 if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
145 // Rule only applies to input-output pairs, so defer the check.
146 if (rule.output_pattern != nullptr) {
147 post_check_rules->push_back(rule.output_pattern.get());
148 continue;
149 }
150 return true;
151 }
152 }
153 }
154 return false;
155 }
156
FilterConfidenceOutput(const std::vector<const UniLib::RegexPattern * > & post_check_rules,std::vector<ActionSuggestion> * actions) const157 bool RegexActions::FilterConfidenceOutput(
158 const std::vector<const UniLib::RegexPattern*>& post_check_rules,
159 std::vector<ActionSuggestion>* actions) const {
160 if (post_check_rules.empty() || actions->empty()) {
161 return true;
162 }
163 std::vector<ActionSuggestion> filtered_text_replies;
164 for (const ActionSuggestion& action : *actions) {
165 if (action.response_text.empty()) {
166 filtered_text_replies.push_back(action);
167 continue;
168 }
169 bool passes_post_check = true;
170 const UnicodeText text_reply_unicode(
171 UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
172 for (const UniLib::RegexPattern* post_check_rule : post_check_rules) {
173 const std::unique_ptr<UniLib::RegexMatcher> matcher =
174 post_check_rule->Matcher(text_reply_unicode);
175 if (matcher == nullptr) {
176 TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
177 return false;
178 }
179 int status = UniLib::RegexMatcher::kNoError;
180 if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
181 passes_post_check = false;
182 break;
183 }
184 }
185 if (passes_post_check) {
186 filtered_text_replies.push_back(action);
187 }
188 }
189 *actions = std::move(filtered_text_replies);
190 return true;
191 }
192
SuggestActions(const Conversation & conversation,const MutableFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * actions) const193 bool RegexActions::SuggestActions(
194 const Conversation& conversation,
195 const MutableFlatbufferBuilder* entity_data_builder,
196 std::vector<ActionSuggestion>* actions) const {
197 // Create actions based on rules checking the last message.
198 const int message_index = conversation.messages.size() - 1;
199 const std::string& message = conversation.messages.back().text;
200 const UnicodeText message_unicode(
201 UTF8ToUnicodeText(message, /*do_copy=*/false));
202 for (const CompiledRule& rule : rules_) {
203 const std::unique_ptr<UniLib::RegexMatcher> matcher =
204 rule.pattern->Matcher(message_unicode);
205 int status = UniLib::RegexMatcher::kNoError;
206 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
207 for (const RulesModel_::RuleActionSpec* rule_action :
208 *rule.rule->actions()) {
209 const ActionSuggestionSpec* action = rule_action->action();
210 std::vector<ActionSuggestionAnnotation> annotations;
211
212 std::unique_ptr<MutableFlatbuffer> entity_data =
213 entity_data_builder != nullptr ? entity_data_builder->NewRoot()
214 : nullptr;
215
216 // Add entity data from rule capturing groups.
217 if (rule_action->capturing_group() != nullptr) {
218 for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
219 *rule_action->capturing_group()) {
220 Optional<std::string> group_match_text =
221 GetCapturingGroupText(matcher.get(), group->group_id());
222 if (!group_match_text.has_value()) {
223 // The group was not part of the match, ignore and continue.
224 continue;
225 }
226
227 UnicodeText normalized_group_match_text =
228 NormalizeMatchText(unilib_, group, group_match_text.value());
229
230 if (!MergeEntityDataFromCapturingMatch(
231 group, normalized_group_match_text.ToUTF8String(),
232 entity_data.get())) {
233 TC3_LOG(ERROR)
234 << "Could not merge entity data from a capturing match.";
235 return false;
236 }
237
238 // Create a text annotation for the group span.
239 ActionSuggestionAnnotation annotation;
240 if (FillAnnotationFromMatchGroup(matcher.get(), group,
241 group_match_text.value(),
242 message_index, &annotation)) {
243 annotations.push_back(annotation);
244 }
245
246 // Create text reply.
247 SuggestTextRepliesFromCapturingMatch(
248 entity_data_builder, group, normalized_group_match_text,
249 smart_reply_action_type_, actions);
250 }
251 }
252
253 if (action != nullptr) {
254 ActionSuggestion suggestion;
255 suggestion.annotations = annotations;
256 FillSuggestionFromSpec(action, entity_data.get(), &suggestion);
257 actions->push_back(suggestion);
258 }
259 }
260 }
261 }
262 return true;
263 }
264
265 } // namespace libtextclassifier3
266