/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "actions/grammar-actions.h" #include "actions/feature-processor.h" #include "actions/utils.h" #include "annotator/types.h" #include "utils/base/arena.h" #include "utils/base/statusor.h" #include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { GrammarActions::GrammarActions( const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules, const MutableFlatbufferBuilder* entity_data_builder, const std::string& smart_reply_action_type) : unilib_(*unilib), grammar_rules_(grammar_rules), tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)), entity_data_builder_(entity_data_builder), analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()), smart_reply_action_type_(smart_reply_action_type) {} bool GrammarActions::InstantiateActionsFromMatch( const grammar::TextContext& text_context, const int message_index, const grammar::Derivation& derivation, std::vector* result) const { const RulesModel_::GrammarRules_::RuleMatch* rule_match = grammar_rules_->rule_match()->Get(derivation.rule_id); if (rule_match == nullptr || rule_match->action_id() == nullptr) { TC3_LOG(ERROR) << "No rule action defined."; return false; } // Gather active capturing matches. std::unordered_map capturing_matches; for (const grammar::MappingNode* mapping_node : grammar::SelectAllOfType( derivation.parse_tree, grammar::ParseTree::Type::kMapping)) { capturing_matches[mapping_node->id] = mapping_node; } // Instantiate actions from the rule match. for (const uint16 action_id : *rule_match->action_id()) { const RulesModel_::RuleActionSpec* action_spec = grammar_rules_->actions()->Get(action_id); std::vector annotations; std::unique_ptr entity_data = entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() : nullptr; // Set information from capturing matches. if (action_spec->capturing_group() != nullptr) { for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group : *action_spec->capturing_group()) { auto it = capturing_matches.find(group->group_id()); if (it == capturing_matches.end()) { // Capturing match is not active, skip. continue; } const grammar::ParseTree* capturing_match = it->second; const UnicodeText match_text = text_context.Span(capturing_match->codepoint_span); UnicodeText normalized_match_text = NormalizeMatchText(unilib_, group, match_text); if (!MergeEntityDataFromCapturingMatch( group, normalized_match_text.ToUTF8String(), entity_data.get())) { TC3_LOG(ERROR) << "Could not merge entity data from a capturing match."; return false; } // Add smart reply suggestions. SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group, normalized_match_text, smart_reply_action_type_, result); // Add annotation. ActionSuggestionAnnotation annotation; if (FillAnnotationFromCapturingMatch( /*span=*/capturing_match->codepoint_span, group, /*message_index=*/message_index, match_text.ToUTF8String(), &annotation)) { if (group->use_annotation_match()) { std::vector annotations = grammar::SelectAllOfType( capturing_match, grammar::ParseTree::Type::kAnnotation); if (annotations.size() != 1) { TC3_LOG(ERROR) << "Could not get annotation for match."; return false; } annotation.entity = *annotations.front()->annotation; } annotations.push_back(std::move(annotation)); } } } if (action_spec->action() != nullptr) { ActionSuggestion suggestion; suggestion.annotations = annotations; FillSuggestionFromSpec(action_spec->action(), entity_data.get(), &suggestion); result->push_back(std::move(suggestion)); } } return true; } bool GrammarActions::SuggestActions( const Conversation& conversation, std::vector* result) const { if (grammar_rules_->rules()->rules() == nullptr || conversation.messages.back().text.empty()) { // Nothing to do. return true; } std::vector locales; if (!ParseLocales(conversation.messages.back().detected_text_language_tags, &locales)) { TC3_LOG(ERROR) << "Could not parse locales of input text."; return false; } const int message_index = conversation.messages.size() - 1; grammar::TextContext text = analyzer_.BuildTextContextForInput( UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false), locales); text.annotations = conversation.messages.back().annotations; UnsafeArena arena(/*block_size=*/16 << 10); StatusOr> evaluated_derivations = analyzer_.Parse(text, &arena); // TODO(b/171294882): Return the status here and below. if (!evaluated_derivations.ok()) { TC3_LOG(ERROR) << "Could not run grammar analyzer: " << evaluated_derivations.status().error_message(); return false; } for (const grammar::EvaluatedDerivation& evaluated_derivation : evaluated_derivations.ValueOrDie()) { if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation, result)) { TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match."; return false; } } return true; } } // namespace libtextclassifier3