1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/grammar-actions.h"
18
19 #include "actions/feature-processor.h"
20 #include "actions/utils.h"
21 #include "annotator/types.h"
22 #include "utils/base/arena.h"
23 #include "utils/base/statusor.h"
24 #include "utils/utf8/unicodetext.h"
25
26 namespace libtextclassifier3 {
27
// Constructs a grammar-backed action suggester.
//
// `unilib` and `grammar_rules` are dereferenced in the initializer list and
// therefore must not be null. `entity_data_builder` is only stored here;
// InstantiateActionsFromMatch checks it for null before using it.
// `smart_reply_action_type` is stored and later forwarded when text replies
// are suggested from capturing matches.
//
// NOTE(review): the raw pointers are stored or forwarded without taking
// ownership as far as this file shows — presumably the caller must keep them
// alive for the lifetime of this object; confirm against the header.
GrammarActions::GrammarActions(
    const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
    const MutableFlatbufferBuilder* entity_data_builder,
    const std::string& smart_reply_action_type)
    : unilib_(*unilib),
      grammar_rules_(grammar_rules),
      // The tokenizer is configured from the tokenizer options stored in the
      // grammar rules.
      tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
      entity_data_builder_(entity_data_builder),
      // The analyzer is handed the tokenizer created above.
      // NOTE(review): this relies on `tokenizer_` being declared before
      // `analyzer_` in the class, since members are initialized in
      // declaration order — confirm in the header.
      analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
      smart_reply_action_type_(smart_reply_action_type) {}
38
InstantiateActionsFromMatch(const grammar::TextContext & text_context,const int message_index,const grammar::Derivation & derivation,std::vector<ActionSuggestion> * result) const39 bool GrammarActions::InstantiateActionsFromMatch(
40 const grammar::TextContext& text_context, const int message_index,
41 const grammar::Derivation& derivation,
42 std::vector<ActionSuggestion>* result) const {
43 const RulesModel_::GrammarRules_::RuleMatch* rule_match =
44 grammar_rules_->rule_match()->Get(derivation.rule_id);
45 if (rule_match == nullptr || rule_match->action_id() == nullptr) {
46 TC3_LOG(ERROR) << "No rule action defined.";
47 return false;
48 }
49
50 // Gather active capturing matches.
51 std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
52 for (const grammar::MappingNode* mapping_node :
53 grammar::SelectAllOfType<grammar::MappingNode>(
54 derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
55 capturing_matches[mapping_node->id] = mapping_node;
56 }
57
58 // Instantiate actions from the rule match.
59 for (const uint16 action_id : *rule_match->action_id()) {
60 const RulesModel_::RuleActionSpec* action_spec =
61 grammar_rules_->actions()->Get(action_id);
62 std::vector<ActionSuggestionAnnotation> annotations;
63
64 std::unique_ptr<MutableFlatbuffer> entity_data =
65 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
66 : nullptr;
67
68 // Set information from capturing matches.
69 if (action_spec->capturing_group() != nullptr) {
70 for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
71 *action_spec->capturing_group()) {
72 auto it = capturing_matches.find(group->group_id());
73 if (it == capturing_matches.end()) {
74 // Capturing match is not active, skip.
75 continue;
76 }
77
78 const grammar::ParseTree* capturing_match = it->second;
79 const UnicodeText match_text =
80 text_context.Span(capturing_match->codepoint_span);
81 UnicodeText normalized_match_text =
82 NormalizeMatchText(unilib_, group, match_text);
83
84 if (!MergeEntityDataFromCapturingMatch(
85 group, normalized_match_text.ToUTF8String(),
86 entity_data.get())) {
87 TC3_LOG(ERROR)
88 << "Could not merge entity data from a capturing match.";
89 return false;
90 }
91
92 // Add smart reply suggestions.
93 SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
94 normalized_match_text,
95 smart_reply_action_type_, result);
96
97 // Add annotation.
98 ActionSuggestionAnnotation annotation;
99 if (FillAnnotationFromCapturingMatch(
100 /*span=*/capturing_match->codepoint_span, group,
101 /*message_index=*/message_index, match_text.ToUTF8String(),
102 &annotation)) {
103 if (group->use_annotation_match()) {
104 std::vector<const grammar::AnnotationNode*> annotations =
105 grammar::SelectAllOfType<grammar::AnnotationNode>(
106 capturing_match, grammar::ParseTree::Type::kAnnotation);
107 if (annotations.size() != 1) {
108 TC3_LOG(ERROR) << "Could not get annotation for match.";
109 return false;
110 }
111 annotation.entity = *annotations.front()->annotation;
112 }
113 annotations.push_back(std::move(annotation));
114 }
115 }
116 }
117
118 if (action_spec->action() != nullptr) {
119 ActionSuggestion suggestion;
120 suggestion.annotations = annotations;
121 FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
122 &suggestion);
123 result->push_back(std::move(suggestion));
124 }
125 }
126 return true;
127 }
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const128 bool GrammarActions::SuggestActions(
129 const Conversation& conversation,
130 std::vector<ActionSuggestion>* result) const {
131 if (grammar_rules_->rules()->rules() == nullptr ||
132 conversation.messages.back().text.empty()) {
133 // Nothing to do.
134 return true;
135 }
136
137 std::vector<Locale> locales;
138 if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
139 &locales)) {
140 TC3_LOG(ERROR) << "Could not parse locales of input text.";
141 return false;
142 }
143
144 const int message_index = conversation.messages.size() - 1;
145 grammar::TextContext text = analyzer_.BuildTextContextForInput(
146 UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
147 locales);
148 text.annotations = conversation.messages.back().annotations;
149
150 UnsafeArena arena(/*block_size=*/16 << 10);
151 StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
152 analyzer_.Parse(text, &arena);
153 // TODO(b/171294882): Return the status here and below.
154 if (!evaluated_derivations.ok()) {
155 TC3_LOG(ERROR) << "Could not run grammar analyzer: "
156 << evaluated_derivations.status().error_message();
157 return false;
158 }
159
160 for (const grammar::EvaluatedDerivation& evaluated_derivation :
161 evaluated_derivations.ValueOrDie()) {
162 if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation,
163 result)) {
164 TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
165 return false;
166 }
167 }
168
169 return true;
170 }
171
172 } // namespace libtextclassifier3
173