/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "actions/lua-actions.h"
#include "actions/types.h"
#include "actions/zlib-utils.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#include "utils/regex-match.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib_regex.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

// Well-known action type names. Each string is allocated once with `new` and
// never freed, so no destructor runs for these globals at program exit.
const std::string& ActionsSuggestions::kViewCalendarType =
    *[]() { return new std::string("view_calendar"); }();
const std::string& ActionsSuggestions::kViewMapType =
    *[]() { return new std::string("view_map"); }();
const std::string& ActionsSuggestions::kTrackFlightType =
    *[]() { return new std::string("track_flight"); }();
const std::string& ActionsSuggestions::kOpenUrlType =
    *[]() { return new std::string("open_url"); }();
const std::string& ActionsSuggestions::kSendSmsType =
    *[]() { return new std::string("send_sms"); }();
const std::string& ActionsSuggestions::kCallPhoneType =
    *[]() { return new std::string("call_phone"); }();
const std::string& ActionsSuggestions::kSendEmailType =
    *[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
    *[]() { return new std::string("share_location"); }();

namespace {

// Verifies that the `size` bytes at `addr` form a valid ActionsModel
// flatbuffer. Returns a view into the buffer, or nullptr if verification
// fails. The buffer is not copied and must outlive the returned pointer.
const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (!VerifyActionsModelBuffer(verifier)) {
    return nullptr;
  }
  return GetActionsModel(addr);
}

// Reads the scalar field at `field_offset` from the flatbuffer table
// `values`. Falls back to `default_value` when `values` is null or (per
// flatbuffers' GetField semantics) when the field is absent.
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
// A negative `max_conversation_history_length` means "no limit".
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  const int num_messages = conversation.messages.size();
  if (max_conversation_history_length < 0 ||
      num_messages < max_conversation_history_length) {
    return num_messages;
  }
  return max_conversation_history_length;
}

}  // namespace

// Creates an instance backed by a caller-owned model buffer (not copied; it
// must outlive the returned object). Returns nullptr if the model fails
// verification or initialization.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  // Verify before allocating anything, as the mmap-based factories do.
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  // `new` rather than make_unique: the constructor is not called from a
  // context make_unique can reach (presumably private).
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

// Creates an instance that takes ownership of `mmap` (which holds the model).
// `unilib` is borrowed; when null, an owned UniLib is created.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

// Same as above, but also takes ownership of `unilib`.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

// Maps the model region [offset, offset + size) of `fd`; when either value is
// negative the whole file is mapped instead.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

// Same as above, but takes ownership of `unilib`.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

// Maps the whole file behind `fd` as the model.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

// Maps the whole file behind `fd` as the model; takes ownership of `unilib`.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

// Maps the model file at `path`.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

// Maps the model file at `path`; takes ownership of `unilib`.
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

// Uses the borrowed `unilib` if given, otherwise creates and owns one.
void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

// Checks the mandatory model fields and builds all derived state:
// - preconditions, locales and the TFLite executor;
// - annotation entity types, rules and the entity data schema;
// - the Lua script, ranker, feature processor and ngram model.
// Returns false (after logging) on the first failure.
bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Collect the entity types the annotator is asked for (see
  // AnnotationOptionsForMessage).
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (!InitializeRules(decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize rules.";
    return false;
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // The Lua script is optional; it is only compiled when present.
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }

  ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      model_->ranking_options(), decompressor.get(),
      model_->smart_reply_action_type()->str());
  if (!ranker_) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache embedding of padding, start and end token.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    ngram_model_ = NGramModel::Create(
        model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer(),
        unilib_);
    if (ngram_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }

  return true;
}

// Resolves `preconditions_`. Each field comes from the optional overlay
// flatbuffer if it sets the field, otherwise from the model's defaults.
// Fails if a non-empty overlay does not verify or the model has no
// preconditions.
bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.diversification_distance_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
      defaults->diversification_distance_threshold());
  preconditions_.confidence_threshold =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
                     defaults->confidence_threshold());
  preconditions_.empirical_probability_factor = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
      defaults->empirical_probability_factor());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

// Appends the feature vector of the single token id `token_id` to
// `embedding`, with no dense features.
bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

// Compiles the model's rules and low-confidence rules. Low-confidence rules
// from the preconditions overlay are appended on top of the model's.
bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
  if (model_->rules() != nullptr) {
    if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
      TC3_LOG(ERROR) << "Could not initialize action rules.";
      return false;
    }
  }

  if (model_->low_confidence_rules() != nullptr) {
    if (!InitializeRules(decompressor, model_->low_confidence_rules(),
                         &low_confidence_rules_)) {
      TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
      return false;
    }
  }

  // Extend by rules provided by the overwrite.
  // NOTE: The rules from the original models are *not* cleared.
  if (triggering_preconditions_overlay_ != nullptr &&
      triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
    // These rules are optionally compressed, but separately.
    std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
        ZlibDecompressor::Instance();
    if (overwrite_decompressor == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize decompressor for overwrite rules.";
      return false;
    }
    if (!InitializeRules(
            overwrite_decompressor.get(),
            triggering_preconditions_overlay_->low_confidence_rules(),
            &low_confidence_rules_)) {
      TC3_LOG(ERROR)
          << "Could not initialize low confidence rules from overwrite.";
      return false;
    }
  }

  return true;
}

// Compiles every rule of `rules` and appends it to `compiled_rules`. This
// covers the (possibly compressed) input pattern and the optional output
// pattern. Returns false if any pattern fails to load.
bool ActionsSuggestions::InitializeRules(
    ZlibDecompressor* decompressor, const RulesModel* rules,
    std::vector<CompiledRule>* compiled_rules) const {
  if (rules->rule() == nullptr) {
    // A rules model without a rule list contributes no rules.
    return true;
  }
  for (const RulesModel_::Rule* rule : *rules->rule()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(
            *unilib_, rule->pattern(), rule->compressed_pattern(),
            rules->lazy_regex_compilation(), decompressor);
    if (compiled_pattern == nullptr) {
      TC3_LOG(ERROR) << "Failed to load rule pattern.";
      return false;
    }

    // Check whether there is a check on the output.
    std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
    if (rule->output_pattern() != nullptr ||
        rule->compressed_output_pattern() != nullptr) {
      compiled_output_pattern = UncompressMakeRegexPattern(
          *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
          rules->lazy_regex_compilation(), decompressor);
      if (compiled_output_pattern == nullptr) {
        TC3_LOG(ERROR) << "Failed to load rule output pattern.";
        return false;
      }
    }

    compiled_rules->emplace_back(rule, std::move(compiled_pattern),
                                 std::move(compiled_output_pattern));
  }

  return true;
}

// Checks the last `num_messages` messages against the ngram model and the
// low-confidence regex rules.
// - Returns true as soon as one message is flagged.
// - Rules that also carry an output pattern do not flag the input directly.
//   Their indices are appended to `post_check_rules` for
//   FilterConfidenceOutput.
bool ActionsSuggestions::IsLowConfidenceInput(
    const Conversation& conversation, const int num_messages,
    std::vector<int>* post_check_rules) const {
  for (int i = 1; i <= num_messages; i++) {
    const std::string& message =
        conversation.messages[conversation.messages.size() - i].text;
    const UnicodeText message_unicode(
        UTF8ToUnicodeText(message, /*do_copy=*/false));

    // Run ngram linear regression model.
    if (ngram_model_ != nullptr && ngram_model_->Eval(message_unicode)) {
      return true;
    }

    // Run the regex based rules.
    const int num_rules = low_confidence_rules_.size();
    for (int low_confidence_rule = 0; low_confidence_rule < num_rules;
         low_confidence_rule++) {
      const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          rule.pattern->Matcher(message_unicode);
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
        // Rule only applies to input-output pairs, so defer the check.
        if (rule.output_pattern != nullptr) {
          post_check_rules->push_back(low_confidence_rule);
          continue;
        }
        return true;
      }
    }
  }
  return false;
}

// Drops every suggestion with a non-empty response text that matches the
// output pattern of any rule in `post_check_rules`. A matcher error also
// drops it. Suggestions without response text are kept. Returns false only
// if a matcher cannot be created.
bool ActionsSuggestions::FilterConfidenceOutput(
    const std::vector<int>& post_check_rules,
    std::vector<ActionSuggestion>* actions) const {
  if (post_check_rules.empty() || actions->empty()) {
    return true;
  }
  std::vector<ActionSuggestion> filtered_text_replies;
  for (const ActionSuggestion& action : *actions) {
    if (action.response_text.empty()) {
      filtered_text_replies.push_back(action);
      continue;
    }
    bool passes_post_check = true;
    const UnicodeText text_reply_unicode(
        UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
    for (const int rule_id : post_check_rules) {
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          low_confidence_rules_[rule_id].output_pattern->Matcher(
              text_reply_unicode);
      if (matcher == nullptr) {
        TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
        return false;
      }
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
        passes_post_check = false;
        break;
      }
    }
    if (passes_post_check) {
      filtered_text_replies.push_back(action);
    }
  }
  *actions = std::move(filtered_text_replies);
  return true;
}

// Builds a suggestion from `action`. When `action` is null, the given
// defaults are used. When a string field is unset, its default is used.
ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
    const ActionSuggestionSpec* action, const std::string& default_type,
    const std::string& default_response_text,
    const std::string& default_serialized_entity_data,
    const float default_score, const float default_priority_score) const {
  ActionSuggestion suggestion;
  suggestion.score = action != nullptr ? action->score() : default_score;
  suggestion.priority_score =
      action != nullptr ? action->priority_score() : default_priority_score;
  suggestion.type = action != nullptr && action->type() != nullptr
                        ? action->type()->str()
                        : default_type;
  suggestion.response_text =
      action != nullptr && action->response_text() != nullptr
          ? action->response_text()->str()
          : default_response_text;
  suggestion.serialized_entity_data =
      action != nullptr && action->serialized_entity_data() != nullptr
          ? action->serialized_entity_data()->str()
          : default_serialized_entity_data;
  return suggestion;
}

// Tokenizes each message of `context` with the feature processor's
// tokenizer.
std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

// Embeds each message into a fixed number of token slots, written to
// `max_num_tokens_per_message`. That number is the longest message, clamped
// to the model's [min, max] tokens-per-message options. A non-positive max
// means "no upper limit". Longer messages lose tokens at their start; shorter
// ones are padded at the end. Embeddings are appended to `embeddings`.
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }

  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens and add paddings to pad tokens of each message to the
  // maximum number of tokens in a message of the conversation.
  // If a number of tokens is specified in the model config, tokens at the
  // beginning of a message are dropped if they don't fit in the limit.
  for (int i = 0; i < num_messages; i++) {
    // Compute in `int`: `size() - max` in size_t arithmetic would wrap around.
    const int num_message_tokens = tokens[i].size();
    const int start =
        std::max(num_message_tokens - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < num_message_tokens; pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = num_message_tokens; k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

// Embeds all messages into one flat sequence, each message laid out as
// [start, tokens..., end].
// - If `max_num_total_tokens` is positive, only the most recent entries that
//   fit into it are kept (the oldest message may be cut at its front).
// - The sequence is padded to `min_num_total_tokens`.
// `total_token_count` receives the number of entries appended, including
// padding.
// NOTE: `tokens` is taken by value to match the member declaration;
// changing the declaration to a const reference would avoid a copy.
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>> tokens, std::vector<float>* embeddings,
    int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
    // When the whole conversation fits into the budget the loop above runs
    // past the first message and leaves -1; start from the first message.
    start_message = std::max(start_message, 0);
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    // A non-zero offset first consumes the start token, then leading tokens.
    const int num_message_tokens = tokens[i].size();
    for (int pos = std::max(0, message_token_offset - 1);
         pos < num_message_tokens; pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

// When the model spec asks for it, resizes every used input tensor to the
// current conversation shape. Then allocates the interpreter's tensors.
// Returns false if any resize or the allocation fails.
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  const auto* spec = model_->tflite_model_spec();
  if (spec->resize_inputs()) {
    // Resizes the model input number `input_index` to `dims`.
    // A negative index means the model does not use that input.
    const auto resize_input = [interpreter](const int input_index,
                                            const std::vector<int>& dims) {
      if (input_index < 0) {
        return true;
      }
      return interpreter->ResizeInputTensor(interpreter->inputs()[input_index],
                                            dims) == kTfLiteOk;
    };
    if (!resize_input(spec->input_context(), {1, conversation_length}) ||
        !resize_input(spec->input_user_id(), {1, conversation_length}) ||
        !resize_input(spec->input_time_diffs(), {1, conversation_length}) ||
        !resize_input(spec->input_num_tokens(), {conversation_length, 1}) ||
        !resize_input(
            spec->input_token_embeddings(),
            {conversation_length, max_tokens, token_embedding_size_}) ||
        !resize_input(spec->input_flattened_token_embeddings(),
                      {1, total_token_count})) {
      TC3_LOG(ERROR) << "Could not resize TensorFlow Lite model inputs.";
      return false;
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

// Computes token embeddings (if the model needs them) and allocates the
// interpreter inputs. Then writes every input that the model spec maps to a
// non-negative index.
bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const float confidence_threshold, const float diversification_distance,
    const float empirical_probability_factor,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    model_executor_->SetInput<std::string>(
        model_->tflite_model_spec()->input_context(), context, interpreter);
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
                                   user_ids, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_diversification_distance(),
        diversification_distance, interpreter);
  }
  if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_confidence_threshold(),
        confidence_threshold, interpreter);
  }
  if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
    // Fixed: this input previously received `confidence_threshold`.
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_empirical_probability_factor(),
        empirical_probability_factor, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  return true;
}

// Reads the model outputs into `response`:
// - the triggering and sensitivity scores, each with its filter flag;
// - the smart replies that reach the reply score threshold;
// - the enabled action classes whose score reaches their minimum.
// Replies and actions are skipped when the sensitivity filter fires.
bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->output_filtered_sensitivity =
        (response->sensitivity_score >
         preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->output_filtered_sensitivity) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    const std::vector<tflite::StringRef> replies =
        model_executor_->Output<tflite::StringRef>(
            model_->tflite_model_spec()->output_replies(), interpreter);
    const TensorView<float> scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_replies_scores(), interpreter);
    for (int i = 0; i < replies.size(); i++) {
      if (replies[i].len == 0) continue;
      // Guard the score read; a missing/short scores tensor used to be an
      // out-of-bounds read.
      if (!scores.is_valid() || i >= scores.size()) {
        TC3_LOG(ERROR) << "Could not read smart reply score.";
        return false;
      }
      const float score = scores.data()[i];
      if (score < preconditions_.min_reply_score_threshold) {
        continue;
      }
      response->actions.push_back({std::string(replies[i].str, replies[i].len),
                                   model_->smart_reply_action_type()->str(),
                                   score});
    }
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->Length(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      if (!actions_scores.is_valid() || i >= actions_scores.size()) {
        TC3_LOG(ERROR) << "Could not read action score.";
        return false;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }
      ActionSuggestion suggestion =
          SuggestionFromSpec(action_type->action(),
                             /*default_type=*/action_type->name()->str());
      suggestion.score = score;
      response->actions.push_back(suggestion);
    }
  }

  return true;
}

// Runs the TFLite model over the last `num_messages` messages. The created
// interpreter is stored in `interpreter` so the caller can inspect it.
// Returns true without doing anything if the model has no TFLite spec.
bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  const int num_total_messages = conversation.messages.size();
  for (int i = num_total_messages - num_messages; i < num_total_messages;
       i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    // Seconds since the previous message with a known time; 0 when either
    // time is unknown (0), and clamped at 0 for out-of-order times.
    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(),
                       preconditions_.confidence_threshold,
                       preconditions_.diversification_distance_threshold,
                       preconditions_.empirical_probability_factor,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

// Builds annotator options for `message` from its language tags, reference
// time/timezone and the model's annotation actions spec.
AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Walks the conversation from the newest message backwards and turns
// annotations into actions via SuggestActionsFromAnnotation.
// - Existing message annotations are used; if a message has none and an
//   annotator is given, the annotator is run on it.
// - The walk is bounded by the spec's history limits and local-user
//   settings.
void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }
  // `messages.back()` below is undefined on an empty conversation.
  if (conversation.messages.empty()) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1;
       message_index >= 0; message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    if (annotations.empty() && annotator != nullptr) {
      annotations = annotator->Annotate(message.text,
                                        AnnotationOptionsForMessage(message));
    }
    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      // Only the first classification of each span is used.
      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(action_annotation);
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }
      ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
            entity_data_builder_->NewRoot();
        TC3_CHECK(entity_data != nullptr);

        // Merge existing static entity data.
if (!suggestion.serialized_entity_data.empty()) { entity_data->MergeFromSerializedFlatbuffer( StringPiece(suggestion.serialized_entity_data.c_str(), suggestion.serialized_entity_data.size())); } entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text); suggestion.serialized_entity_data = entity_data->Serialize(); } suggestion.annotations = {annotation}; actions->push_back(suggestion); } } } std::vector ActionsSuggestions::DeduplicateAnnotations( const std::vector& annotations) const { std::map, int> deduplicated_annotations; for (int i = 0; i < annotations.size(); i++) { const std::pair key = {annotations[i].name, annotations[i].span.text}; auto entry = deduplicated_annotations.find(key); if (entry != deduplicated_annotations.end()) { // Kepp the annotation with the higher score. if (annotations[entry->second].entity.score < annotations[i].entity.score) { entry->second = i; } continue; } deduplicated_annotations.insert(entry, {key, i}); } std::vector result; result.reserve(deduplicated_annotations.size()); for (const auto& key_and_annotation : deduplicated_annotations) { result.push_back(key_and_annotation.second); } return result; } bool ActionsSuggestions::FillAnnotationFromMatchGroup( const UniLib::RegexMatcher* matcher, const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group, const int message_index, ActionSuggestionAnnotation* annotation) const { if (group->annotation_name() != nullptr || group->annotation_type() != nullptr) { int status = UniLib::RegexMatcher::kNoError; const CodepointSpan span = {matcher->Start(group->group_id(), &status), matcher->End(group->group_id(), &status)}; std::string text = matcher->Group(group->group_id(), &status).ToUTF8String(); if (status != UniLib::RegexMatcher::kNoError) { TC3_LOG(ERROR) << "Could not extract span from rule capturing group."; return false; } // The capturing group was not part of the match. 
if (span.first == kInvalidIndex || span.second == kInvalidIndex) { return false; } annotation->span.span = span; annotation->span.message_index = message_index; annotation->span.text = text; if (group->annotation_name() != nullptr) { annotation->name = group->annotation_name()->str(); } if (group->annotation_type() != nullptr) { annotation->entity.collection = group->annotation_type()->str(); } } return true; } bool ActionsSuggestions::SuggestActionsFromRules( const Conversation& conversation, std::vector* actions) const { // Create actions based on rules checking the last message. const int message_index = conversation.messages.size() - 1; const std::string& message = conversation.messages.back().text; const UnicodeText message_unicode( UTF8ToUnicodeText(message, /*do_copy=*/false)); for (const CompiledRule& rule : rules_) { const std::unique_ptr matcher = rule.pattern->Matcher(message_unicode); int status = UniLib::RegexMatcher::kNoError; while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { for (const RulesModel_::Rule_::RuleActionSpec* rule_action : *rule.rule->actions()) { const ActionSuggestionSpec* action = rule_action->action(); std::vector annotations; bool sets_entity_data = false; std::unique_ptr entity_data = entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() : nullptr; // Set static entity data. if (action != nullptr && action->serialized_entity_data() != nullptr) { TC3_CHECK(entity_data != nullptr); sets_entity_data = true; entity_data->MergeFromSerializedFlatbuffer( StringPiece(action->serialized_entity_data()->c_str(), action->serialized_entity_data()->size())); } // Add entity data from rule capturing groups. 
if (rule_action->capturing_group() != nullptr) { for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group : *rule_action->capturing_group()) { if (group->entity_field() != nullptr) { TC3_CHECK(entity_data != nullptr); sets_entity_data = true; if (!SetFieldFromCapturingGroup( group->group_id(), group->entity_field(), matcher.get(), entity_data.get())) { TC3_LOG(ERROR) << "Could not set entity data from rule capturing group."; return false; } } // Create a text annotation for the group span. ActionSuggestionAnnotation annotation; if (FillAnnotationFromMatchGroup(matcher.get(), group, message_index, &annotation)) { annotations.push_back(annotation); } // Create text reply. if (group->text_reply() != nullptr) { int status = UniLib::RegexMatcher::kNoError; const std::string group_text = matcher->Group(group->group_id(), &status).ToUTF8String(); if (status != UniLib::RegexMatcher::kNoError) { TC3_LOG(ERROR) << "Could get text from capturing group."; return false; } if (group_text.empty()) { // The group was not part of the match, ignore and continue. 
continue; } actions->push_back(SuggestionFromSpec( group->text_reply(), /*default_type=*/model_->smart_reply_action_type()->str(), /*default_response_text=*/group_text)); } } } if (action != nullptr) { ActionSuggestion suggestion = SuggestionFromSpec(action); suggestion.annotations = annotations; if (sets_entity_data) { suggestion.serialized_entity_data = entity_data->Serialize(); } actions->push_back(suggestion); } } } } return true; } bool ActionsSuggestions::SuggestActionsFromLua( const Conversation& conversation, const TfLiteModelExecutor* model_executor, const tflite::Interpreter* interpreter, const reflection::Schema* annotation_entity_data_schema, std::vector* actions) const { if (lua_bytecode_.empty()) { return true; } auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions( lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(), interpreter, entity_data_schema_, annotation_entity_data_schema); if (lua_actions == nullptr) { TC3_LOG(ERROR) << "Could not create lua actions."; return false; } return lua_actions->SuggestActions(actions); } bool ActionsSuggestions::GatherActionsSuggestions( const Conversation& conversation, const Annotator* annotator, const ActionSuggestionOptions& options, ActionsSuggestionsResponse* response) const { if (conversation.messages.empty()) { return true; } const int num_messages = NumMessagesToConsider( conversation, model_->max_conversation_history_length()); if (num_messages <= 0) { TC3_LOG(INFO) << "No messages provided for actions suggestions."; return false; } SuggestActionsFromAnnotations(conversation, options, annotator, &response->actions); int input_text_length = 0; int num_matching_locales = 0; for (int i = conversation.messages.size() - num_messages; i < conversation.messages.size(); i++) { input_text_length += conversation.messages[i].text.length(); std::vector message_languages; if (!ParseLocales(conversation.messages[i].detected_text_language_tags, &message_languages)) { continue; } if 
(Locale::IsAnyLocaleSupported( message_languages, locales_, preconditions_.handle_unknown_locale_as_supported)) { ++num_matching_locales; } } // Bail out if we are provided with too few or too much input. if (input_text_length < preconditions_.min_input_length || (preconditions_.max_input_length >= 0 && input_text_length > preconditions_.max_input_length)) { TC3_LOG(INFO) << "Too much or not enough input for inference."; return response; } // Bail out if the text does not look like it can be handled by the model. const float matching_fraction = static_cast(num_matching_locales) / num_messages; if (matching_fraction < preconditions_.min_locale_match_fraction) { TC3_LOG(INFO) << "Not enough locale matches."; response->output_filtered_locale_mismatch = true; return true; } std::vector post_check_rules; if (preconditions_.suppress_on_low_confidence_input && IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) { response->output_filtered_low_confidence = true; return true; } std::unique_ptr interpreter; if (!SuggestActionsFromModel(conversation, num_messages, options, response, &interpreter)) { TC3_LOG(ERROR) << "Could not run model."; return false; } // Suppress all predictions if the conversation was deemed sensitive. if (preconditions_.suppress_on_sensitive_topic && response->output_filtered_sensitivity) { return true; } if (!SuggestActionsFromLua( conversation, model_executor_.get(), interpreter.get(), annotator != nullptr ? 
annotator->entity_data_schema() : nullptr, &response->actions)) { TC3_LOG(ERROR) << "Could not suggest actions from script."; return false; } if (!SuggestActionsFromRules(conversation, &response->actions)) { TC3_LOG(ERROR) << "Could not suggest actions from rules."; return false; } if (preconditions_.suppress_on_low_confidence_input && !FilterConfidenceOutput(post_check_rules, &response->actions)) { TC3_LOG(ERROR) << "Could not post-check actions."; return false; } return true; } ActionsSuggestionsResponse ActionsSuggestions::SuggestActions( const Conversation& conversation, const Annotator* annotator, const ActionSuggestionOptions& options) const { ActionsSuggestionsResponse response; if (!GatherActionsSuggestions(conversation, annotator, options, &response)) { TC3_LOG(ERROR) << "Could not gather actions suggestions."; response.actions.clear(); } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_, annotator != nullptr ? annotator->entity_data_schema() : nullptr)) { TC3_LOG(ERROR) << "Could not rank actions."; response.actions.clear(); } return response; } ActionsSuggestionsResponse ActionsSuggestions::SuggestActions( const Conversation& conversation, const ActionSuggestionOptions& options) const { return SuggestActions(conversation, /*annotator=*/nullptr, options); } const ActionsModel* ActionsSuggestions::model() const { return model_; } const reflection::Schema* ActionsSuggestions::entity_data_schema() const { return entity_data_schema_; } const ActionsModel* ViewActionsModel(const void* buffer, int size) { if (buffer == nullptr) { return nullptr; } return LoadAndVerifyModel(reinterpret_cast(buffer), size); } } // namespace libtextclassifier3