1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/grammar-actions.h"
18
19 #include <algorithm>
20 #include <unordered_map>
21
22 #include "actions/feature-processor.h"
23 #include "actions/utils.h"
24 #include "annotator/types.h"
25 #include "utils/grammar/callback-delegate.h"
26 #include "utils/grammar/match.h"
27 #include "utils/grammar/matcher.h"
28 #include "utils/grammar/rules-utils.h"
29 #include "utils/i18n/language-tag_generated.h"
30 #include "utils/utf8/unicodetext.h"
31
32 namespace libtextclassifier3 {
33 namespace {
34
35 class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
36 public:
GrammarActionsCallbackDelegate(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules)37 GrammarActionsCallbackDelegate(const UniLib* unilib,
38 const RulesModel_::GrammarRules* grammar_rules)
39 : unilib_(*unilib), grammar_rules_(grammar_rules) {}
40
41 // Handle a grammar rule match in the actions grammar.
MatchFound(const grammar::Match * match,grammar::CallbackId type,int64 value,grammar::Matcher * matcher)42 void MatchFound(const grammar::Match* match, grammar::CallbackId type,
43 int64 value, grammar::Matcher* matcher) override {
44 switch (static_cast<GrammarActions::Callback>(type)) {
45 case GrammarActions::Callback::kActionRuleMatch: {
46 HandleRuleMatch(match, /*rule_id=*/value);
47 return;
48 }
49 default:
50 grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
51 }
52 }
53
54 // Deduplicate, verify and populate actions from grammar matches.
GetActions(const Conversation & conversation,const std::string & smart_reply_action_type,const ReflectiveFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * action_suggestions) const55 bool GetActions(const Conversation& conversation,
56 const std::string& smart_reply_action_type,
57 const ReflectiveFlatbufferBuilder* entity_data_builder,
58 std::vector<ActionSuggestion>* action_suggestions) const {
59 std::vector<UnicodeText::const_iterator> codepoint_offsets;
60 const UnicodeText message_unicode =
61 UTF8ToUnicodeText(conversation.messages.back().text,
62 /*do_copy=*/false);
63 for (auto it = message_unicode.begin(); it != message_unicode.end(); it++) {
64 codepoint_offsets.push_back(it);
65 }
66 codepoint_offsets.push_back(message_unicode.end());
67 for (const grammar::Derivation& candidate :
68 grammar::DeduplicateDerivations(candidates_)) {
69 // Check that assertions are fulfilled.
70 if (!VerifyAssertions(candidate.match)) {
71 continue;
72 }
73 if (!InstantiateActionsFromMatch(
74 codepoint_offsets,
75 /*message_index=*/conversation.messages.size() - 1,
76 smart_reply_action_type, candidate, entity_data_builder,
77 action_suggestions)) {
78 return false;
79 }
80 }
81 return true;
82 }
83
84 private:
85 // Handles action rule matches.
HandleRuleMatch(const grammar::Match * match,const int64 rule_id)86 void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
87 candidates_.push_back(grammar::Derivation{match, rule_id});
88 }
89
90 // Instantiates action suggestions from verified and deduplicated rule matches
91 // and appends them to the result.
92 // Expects the message as codepoints for text extraction from capturing
93 // matches as well as the index of the message, for correct span production.
InstantiateActionsFromMatch(const std::vector<UnicodeText::const_iterator> & message_codepoint_offsets,int message_index,const std::string & smart_reply_action_type,const grammar::Derivation & candidate,const ReflectiveFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * result) const94 bool InstantiateActionsFromMatch(
95 const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
96 int message_index, const std::string& smart_reply_action_type,
97 const grammar::Derivation& candidate,
98 const ReflectiveFlatbufferBuilder* entity_data_builder,
99 std::vector<ActionSuggestion>* result) const {
100 const RulesModel_::GrammarRules_::RuleMatch* rule_match =
101 grammar_rules_->rule_match()->Get(candidate.rule_id);
102 if (rule_match == nullptr || rule_match->action_id() == nullptr) {
103 TC3_LOG(ERROR) << "No rule action defined.";
104 return false;
105 }
106
107 // Gather active capturing matches.
108 std::unordered_map<uint16, const grammar::Match*> capturing_matches;
109 for (const grammar::MappingMatch* match :
110 grammar::SelectAllOfType<grammar::MappingMatch>(
111 candidate.match, grammar::Match::kMappingMatch)) {
112 capturing_matches[match->id] = match;
113 }
114
115 // Instantiate actions from the rule match.
116 for (const uint16 action_id : *rule_match->action_id()) {
117 const RulesModel_::RuleActionSpec* action_spec =
118 grammar_rules_->actions()->Get(action_id);
119 std::vector<ActionSuggestionAnnotation> annotations;
120
121 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
122 entity_data_builder != nullptr ? entity_data_builder->NewRoot()
123 : nullptr;
124
125 // Set information from capturing matches.
126 if (action_spec->capturing_group() != nullptr) {
127 for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
128 *action_spec->capturing_group()) {
129 auto it = capturing_matches.find(group->group_id());
130 if (it == capturing_matches.end()) {
131 // Capturing match is not active, skip.
132 continue;
133 }
134
135 const grammar::Match* capturing_match = it->second;
136 StringPiece match_text = StringPiece(
137 message_codepoint_offsets[capturing_match->codepoint_span.first]
138 .utf8_data(),
139 message_codepoint_offsets[capturing_match->codepoint_span.second]
140 .utf8_data() -
141 message_codepoint_offsets[capturing_match->codepoint_span
142 .first]
143 .utf8_data());
144 UnicodeText normalized_match_text =
145 NormalizeMatchText(unilib_, group, match_text);
146
147 if (!MergeEntityDataFromCapturingMatch(
148 group, normalized_match_text.ToUTF8String(),
149 entity_data.get())) {
150 TC3_LOG(ERROR)
151 << "Could not merge entity data from a capturing match.";
152 return false;
153 }
154
155 // Add smart reply suggestions.
156 SuggestTextRepliesFromCapturingMatch(entity_data_builder, group,
157 normalized_match_text,
158 smart_reply_action_type, result);
159
160 // Add annotation.
161 ActionSuggestionAnnotation annotation;
162 if (FillAnnotationFromCapturingMatch(
163 /*span=*/capturing_match->codepoint_span, group,
164 /*message_index=*/message_index, match_text, &annotation)) {
165 if (group->use_annotation_match()) {
166 const grammar::AnnotationMatch* annotation_match =
167 grammar::SelectFirstOfType<grammar::AnnotationMatch>(
168 capturing_match, grammar::Match::kAnnotationMatch);
169 if (!annotation_match) {
170 TC3_LOG(ERROR) << "Could not get annotation for match.";
171 return false;
172 }
173 annotation.entity = *annotation_match->annotation;
174 }
175 annotations.push_back(std::move(annotation));
176 }
177 }
178 }
179
180 if (action_spec->action() != nullptr) {
181 ActionSuggestion suggestion;
182 suggestion.annotations = annotations;
183 FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
184 &suggestion);
185 result->push_back(std::move(suggestion));
186 }
187 }
188 return true;
189 }
190
191 const UniLib& unilib_;
192 const RulesModel_::GrammarRules* grammar_rules_;
193
194 // All action rule match candidates.
195 // Grammar rule matches are recorded, deduplicated, verified and then
196 // instantiated.
197 std::vector<grammar::Derivation> candidates_;
198 };
199 } // namespace
200
GrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const ReflectiveFlatbufferBuilder * entity_data_builder,const std::string & smart_reply_action_type)201 GrammarActions::GrammarActions(
202 const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
203 const ReflectiveFlatbufferBuilder* entity_data_builder,
204 const std::string& smart_reply_action_type)
205 : unilib_(*unilib),
206 grammar_rules_(grammar_rules),
207 tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
208 lexer_(unilib, grammar_rules->rules()),
209 entity_data_builder_(entity_data_builder),
210 smart_reply_action_type_(smart_reply_action_type),
211 rules_locales_(ParseRulesLocales(grammar_rules->rules())) {}
212
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const213 bool GrammarActions::SuggestActions(
214 const Conversation& conversation,
215 std::vector<ActionSuggestion>* result) const {
216 if (grammar_rules_->rules()->rules() == nullptr) {
217 // Nothing to do.
218 return true;
219 }
220
221 std::vector<Locale> locales;
222 if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
223 &locales)) {
224 TC3_LOG(ERROR) << "Could not parse locales of input text.";
225 return false;
226 }
227
228 // Select locale matching rules.
229 std::vector<const grammar::RulesSet_::Rules*> locale_rules =
230 SelectLocaleMatchingShards(grammar_rules_->rules(), rules_locales_,
231 locales);
232 if (locale_rules.empty()) {
233 // Nothing to do.
234 return true;
235 }
236
237 GrammarActionsCallbackDelegate callback_handler(&unilib_, grammar_rules_);
238 grammar::Matcher matcher(&unilib_, grammar_rules_->rules(), locale_rules,
239 &callback_handler);
240
241 const UnicodeText text =
242 UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false);
243
244 // Run grammar on last message.
245 lexer_.Process(text, tokenizer_->Tokenize(text),
246 /*annotations=*/&conversation.messages.back().annotations,
247 &matcher);
248
249 // Populate results.
250 return callback_handler.GetActions(conversation, smart_reply_action_type_,
251 entity_data_builder_, result);
252 }
253
254 } // namespace libtextclassifier3
255