• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/grammar-actions.h"
18 
19 #include <algorithm>
20 #include <unordered_map>
21 
22 #include "actions/feature-processor.h"
23 #include "actions/utils.h"
24 #include "annotator/types.h"
25 #include "utils/grammar/callback-delegate.h"
26 #include "utils/grammar/match.h"
27 #include "utils/grammar/matcher.h"
28 #include "utils/grammar/rules-utils.h"
29 #include "utils/i18n/language-tag_generated.h"
30 #include "utils/utf8/unicodetext.h"
31 
32 namespace libtextclassifier3 {
33 namespace {
34 
35 class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
36  public:
GrammarActionsCallbackDelegate(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules)37   GrammarActionsCallbackDelegate(const UniLib* unilib,
38                                  const RulesModel_::GrammarRules* grammar_rules)
39       : unilib_(*unilib), grammar_rules_(grammar_rules) {}
40 
41   // Handle a grammar rule match in the actions grammar.
MatchFound(const grammar::Match * match,grammar::CallbackId type,int64 value,grammar::Matcher * matcher)42   void MatchFound(const grammar::Match* match, grammar::CallbackId type,
43                   int64 value, grammar::Matcher* matcher) override {
44     switch (static_cast<GrammarActions::Callback>(type)) {
45       case GrammarActions::Callback::kActionRuleMatch: {
46         HandleRuleMatch(match, /*rule_id=*/value);
47         return;
48       }
49       default:
50         grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
51     }
52   }
53 
54   // Deduplicate, verify and populate actions from grammar matches.
GetActions(const Conversation & conversation,const std::string & smart_reply_action_type,const ReflectiveFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * action_suggestions) const55   bool GetActions(const Conversation& conversation,
56                   const std::string& smart_reply_action_type,
57                   const ReflectiveFlatbufferBuilder* entity_data_builder,
58                   std::vector<ActionSuggestion>* action_suggestions) const {
59     std::vector<UnicodeText::const_iterator> codepoint_offsets;
60     const UnicodeText message_unicode =
61         UTF8ToUnicodeText(conversation.messages.back().text,
62                           /*do_copy=*/false);
63     for (auto it = message_unicode.begin(); it != message_unicode.end(); it++) {
64       codepoint_offsets.push_back(it);
65     }
66     codepoint_offsets.push_back(message_unicode.end());
67     for (const grammar::Derivation& candidate :
68          grammar::DeduplicateDerivations(candidates_)) {
69       // Check that assertions are fulfilled.
70       if (!VerifyAssertions(candidate.match)) {
71         continue;
72       }
73       if (!InstantiateActionsFromMatch(
74               codepoint_offsets,
75               /*message_index=*/conversation.messages.size() - 1,
76               smart_reply_action_type, candidate, entity_data_builder,
77               action_suggestions)) {
78         return false;
79       }
80     }
81     return true;
82   }
83 
84  private:
85   // Handles action rule matches.
HandleRuleMatch(const grammar::Match * match,const int64 rule_id)86   void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
87     candidates_.push_back(grammar::Derivation{match, rule_id});
88   }
89 
90   // Instantiates action suggestions from verified and deduplicated rule matches
91   // and appends them to the result.
92   // Expects the message as codepoints for text extraction from capturing
93   // matches as well as the index of the message, for correct span production.
InstantiateActionsFromMatch(const std::vector<UnicodeText::const_iterator> & message_codepoint_offsets,int message_index,const std::string & smart_reply_action_type,const grammar::Derivation & candidate,const ReflectiveFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * result) const94   bool InstantiateActionsFromMatch(
95       const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
96       int message_index, const std::string& smart_reply_action_type,
97       const grammar::Derivation& candidate,
98       const ReflectiveFlatbufferBuilder* entity_data_builder,
99       std::vector<ActionSuggestion>* result) const {
100     const RulesModel_::GrammarRules_::RuleMatch* rule_match =
101         grammar_rules_->rule_match()->Get(candidate.rule_id);
102     if (rule_match == nullptr || rule_match->action_id() == nullptr) {
103       TC3_LOG(ERROR) << "No rule action defined.";
104       return false;
105     }
106 
107     // Gather active capturing matches.
108     std::unordered_map<uint16, const grammar::Match*> capturing_matches;
109     for (const grammar::MappingMatch* match :
110          grammar::SelectAllOfType<grammar::MappingMatch>(
111              candidate.match, grammar::Match::kMappingMatch)) {
112       capturing_matches[match->id] = match;
113     }
114 
115     // Instantiate actions from the rule match.
116     for (const uint16 action_id : *rule_match->action_id()) {
117       const RulesModel_::RuleActionSpec* action_spec =
118           grammar_rules_->actions()->Get(action_id);
119       std::vector<ActionSuggestionAnnotation> annotations;
120 
121       std::unique_ptr<ReflectiveFlatbuffer> entity_data =
122           entity_data_builder != nullptr ? entity_data_builder->NewRoot()
123                                          : nullptr;
124 
125       // Set information from capturing matches.
126       if (action_spec->capturing_group() != nullptr) {
127         for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
128              *action_spec->capturing_group()) {
129           auto it = capturing_matches.find(group->group_id());
130           if (it == capturing_matches.end()) {
131             // Capturing match is not active, skip.
132             continue;
133           }
134 
135           const grammar::Match* capturing_match = it->second;
136           StringPiece match_text = StringPiece(
137               message_codepoint_offsets[capturing_match->codepoint_span.first]
138                   .utf8_data(),
139               message_codepoint_offsets[capturing_match->codepoint_span.second]
140                       .utf8_data() -
141                   message_codepoint_offsets[capturing_match->codepoint_span
142                                                 .first]
143                       .utf8_data());
144           UnicodeText normalized_match_text =
145               NormalizeMatchText(unilib_, group, match_text);
146 
147           if (!MergeEntityDataFromCapturingMatch(
148                   group, normalized_match_text.ToUTF8String(),
149                   entity_data.get())) {
150             TC3_LOG(ERROR)
151                 << "Could not merge entity data from a capturing match.";
152             return false;
153           }
154 
155           // Add smart reply suggestions.
156           SuggestTextRepliesFromCapturingMatch(entity_data_builder, group,
157                                                normalized_match_text,
158                                                smart_reply_action_type, result);
159 
160           // Add annotation.
161           ActionSuggestionAnnotation annotation;
162           if (FillAnnotationFromCapturingMatch(
163                   /*span=*/capturing_match->codepoint_span, group,
164                   /*message_index=*/message_index, match_text, &annotation)) {
165             if (group->use_annotation_match()) {
166               const grammar::AnnotationMatch* annotation_match =
167                   grammar::SelectFirstOfType<grammar::AnnotationMatch>(
168                       capturing_match, grammar::Match::kAnnotationMatch);
169               if (!annotation_match) {
170                 TC3_LOG(ERROR) << "Could not get annotation for match.";
171                 return false;
172               }
173               annotation.entity = *annotation_match->annotation;
174             }
175             annotations.push_back(std::move(annotation));
176           }
177         }
178       }
179 
180       if (action_spec->action() != nullptr) {
181         ActionSuggestion suggestion;
182         suggestion.annotations = annotations;
183         FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
184                                &suggestion);
185         result->push_back(std::move(suggestion));
186       }
187     }
188     return true;
189   }
190 
191   const UniLib& unilib_;
192   const RulesModel_::GrammarRules* grammar_rules_;
193 
194   // All action rule match candidates.
195   // Grammar rule matches are recorded, deduplicated, verified and then
196   // instantiated.
197   std::vector<grammar::Derivation> candidates_;
198 };
199 }  // namespace
200 
GrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const ReflectiveFlatbufferBuilder * entity_data_builder,const std::string & smart_reply_action_type)201 GrammarActions::GrammarActions(
202     const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
203     const ReflectiveFlatbufferBuilder* entity_data_builder,
204     const std::string& smart_reply_action_type)
205     : unilib_(*unilib),
206       grammar_rules_(grammar_rules),
207       tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
208       lexer_(unilib, grammar_rules->rules()),
209       entity_data_builder_(entity_data_builder),
210       smart_reply_action_type_(smart_reply_action_type),
211       rules_locales_(ParseRulesLocales(grammar_rules->rules())) {}
212 
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const213 bool GrammarActions::SuggestActions(
214     const Conversation& conversation,
215     std::vector<ActionSuggestion>* result) const {
216   if (grammar_rules_->rules()->rules() == nullptr) {
217     // Nothing to do.
218     return true;
219   }
220 
221   std::vector<Locale> locales;
222   if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
223                     &locales)) {
224     TC3_LOG(ERROR) << "Could not parse locales of input text.";
225     return false;
226   }
227 
228   // Select locale matching rules.
229   std::vector<const grammar::RulesSet_::Rules*> locale_rules =
230       SelectLocaleMatchingShards(grammar_rules_->rules(), rules_locales_,
231                                  locales);
232   if (locale_rules.empty()) {
233     // Nothing to do.
234     return true;
235   }
236 
237   GrammarActionsCallbackDelegate callback_handler(&unilib_, grammar_rules_);
238   grammar::Matcher matcher(&unilib_, grammar_rules_->rules(), locale_rules,
239                            &callback_handler);
240 
241   const UnicodeText text =
242       UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false);
243 
244   // Run grammar on last message.
245   lexer_.Process(text, tokenizer_->Tokenize(text),
246                  /*annotations=*/&conversation.messages.back().annotations,
247                  &matcher);
248 
249   // Populate results.
250   return callback_handler.GetActions(conversation, smart_reply_action_type_,
251                                      entity_data_builder_, result);
252 }
253 
254 }  // namespace libtextclassifier3
255