/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <fstream>
#include <iterator>
#include <memory>

#include "actions/actions_model_generated.h"
#include "actions/test_utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "annotator/types.h"
#include "utils/flatbuffers.h"
#include "utils/flatbuffers_generated.h"
#include "utils/hash/farmhash.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"

namespace libtextclassifier3 {
namespace {

using testing::_;

constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";

std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}

std::string GetModelPath() { return ""; }

class ActionsSuggestionsTest : public testing::Test {
 protected:
  ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}

  std::unique_ptr<ActionsSuggestions> LoadTestModel() {
    return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
                                        &unilib_);
  }

  std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
    return ActionsSuggestions::FromPath(
        GetModelPath() + kHashGramModelFileName, &unilib_);
  }

  UniLib unilib_;
};

TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  EXPECT_THAT(LoadTestModel(), testing::NotNull());
}

TEST_F(ActionsSuggestionsTest, SuggestActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies */);
}

TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}

TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation}, /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}

TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation}, /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "home");
}

TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}

TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {55, 68};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}

ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
    const std::function<void(ActionsModelT*)>& set_config_fn,
    const UniLib* unilib = nullptr) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Set custom config.
  set_config_fn(actions_model.get());

  // Disable smart reply for easier testing.
  actions_model->preconditions->min_smart_reply_triggering_score = 1.0;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib);

  AnnotatedSpan flight_annotation;
  flight_annotation.span = {15, 19};
  flight_annotation.classification = {ClassificationResult("flight", 2.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {0, 16};
  email_annotation.classification = {ClassificationResult("email", 1.0)};

  return actions_suggestions->SuggestActions(
      {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hehe@android.com",
         /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{email_annotation}, /*locales=*/"en"},
        {/*user_id=*/2, "yoyo@android.com",
         /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{email_annotation}, /*locales=*/"en"},
        {/*user_id=*/1, "test@android.com",
         /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{email_annotation}, /*locales=*/"en"},
        {/*user_id=*/1, "I am on flight LX38.",
         /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{flight_annotation}, /*locales=*/"en"}}});
}

TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person = 1;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");
}

TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person = 3;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}

TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person = 1;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}

TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person = 1;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}

TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person = 1;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}

TEST_F(ActionsSuggestionsTest,
       SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages = true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person = 1;
      },
      &unilib_);
  EXPECT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}

void TestSuggestActionsWithThreshold(
    const std::function<void(ActionsModelT*)>& set_value_fn,
    const UniLib* unilib = nullptr, const int expected_size = 0,
    const std::string& preconditions_overwrite = "") {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  set_value_fn(actions_model.get());
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib, preconditions_overwrite);
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
Where are you?", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_LE(response.actions.size(), expected_size); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_smart_reply_triggering_score = 1.0; }, &unilib_, /*expected_size=*/1 /*no smart reply, only actions*/ ); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_reply_score_threshold = 1.0; }, &unilib_, /*expected_size=*/1 /*no smart reply, only actions*/ ); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->max_sensitive_topic_score = 0.0; }, &unilib_, /*expected_size=*/4 /* no sensitive prediction in test model*/); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->max_input_length = 0; }, &unilib_); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->min_input_length = 100; }, &unilib_); } TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) { TriggeringPreconditionsT preconditions_overwrite; preconditions_overwrite.max_input_length = 0; flatbuffers::FlatBufferBuilder builder; builder.Finish( TriggeringPreconditions::Pack(builder, &preconditions_overwrite)); TestSuggestActionsWithThreshold( // Keep model untouched. [](ActionsModelT* actions_model) {}, &unilib_, /*expected_size=*/0, std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize())); } #ifdef TC3_UNILIB_ICU TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) { TestSuggestActionsWithThreshold( [](ActionsModelT* actions_model) { actions_model->preconditions->suppress_on_low_confidence_input = true; actions_model->low_confidence_rules.reset(new RulesModelT); actions_model->low_confidence_rules->rule.emplace_back( new RulesModel_::RuleT); actions_model->low_confidence_rules->rule.back()->pattern = "low-ground"; }, &unilib_); } TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) { const std::string actions_model_string = ReadFile(GetModelPath() + kModelFileName); std::unique_ptr actions_model = UnPackActionsModel(actions_model_string.c_str()); // Add custom triggering rule. 
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Desaster!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }
  {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Kenobi!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }

  // Add input-output low confidence rule.
  actions_model->preconditions->suppress_on_low_confidence_input = true;
  actions_model->low_confidence_rules.reset(new RulesModelT);
  actions_model->low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
  actions_model->low_confidence_rules->rule.back()->pattern = "hello";
  actions_model->low_confidence_rules->rule.back()->output_pattern =
      "(?i:desaster)";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}

TEST_F(ActionsSuggestionsTest,
       SuggestActionsLowConfidenceInputOutputOverwrite) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  actions_model->low_confidence_rules.reset();

  // Add custom triggering rule.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Desaster!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }
  {
    std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
        new RulesModel_::Rule_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = "General Kenobi!";
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  }

  // Add custom triggering rule via overwrite.
  actions_model->preconditions->low_confidence_rules.reset();
  TriggeringPreconditionsT preconditions;
  preconditions.suppress_on_low_confidence_input = true;
  preconditions.low_confidence_rules.reset(new RulesModelT);
  preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
  preconditions.low_confidence_rules->rule.back()->pattern = "hello";
  preconditions.low_confidence_rules->rule.back()->output_pattern =
      "(?i:desaster)";
  flatbuffers::FlatBufferBuilder preconditions_builder;
  preconditions_builder.Finish(
      TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
  std::string serialize_preconditions = std::string(
      reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
      preconditions_builder.GetSize());

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_, serialize_preconditions);
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
#endif

TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced.
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}

TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation}, /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}

TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation}, /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  EXPECT_EQ(response.actions[0].annotations.size(), 1);
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}

#ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
      rule->actions.back()->capturing_group.back().get();
  location_group->group_id = 1;
  location_group->entity_field.reset(new FlatbufferFieldPathT);
  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  location_group->entity_field->field.back()->field_name = "location";

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  // Use meta data to generate custom serialized entity data.
  ReflectiveFlatbufferBuilder entity_data_builder(
      flatbuffers::GetRoot<reflection::Schema>(
          actions_model->actions_entity_data_schema.data()));
  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder.NewRoot();
  entity_data->Set("person", "Kenobi");
  action->serialized_entity_data = entity_data->Serialize();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"}}});
  EXPECT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
            "hello there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
            "there");
  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
            "Kenobi");
}

TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);

  // Set capturing groups for entity data.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib_);

  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
             /*locales=*/"en"}}});
  EXPECT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "STOP");
}

TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
         /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestions::kShareLocation) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  const int num_actions = response.actions.size();

  // Add custom rule for location sharing.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
  actions_model->rules->rule.back()->actions.emplace_back(
      new RulesModel_::Rule_::RuleActionSpecT);
  actions_model->rules->rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestions::kShareLocation;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
         /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), num_actions);
}

TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation},
         /*locales=*/"en"}}});

  // Check that the flight tracking action is present.
  EXPECT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");

  // Add custom rule.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->score = 1.0f;
  action->priority_score = 2.0f;
  action->type = "test_code";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->annotation_name = "code";
  code_group->annotation_type = "code";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{annotation},
         /*locales=*/"en"}}});
  EXPECT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "test_code");
}
#endif

TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse& response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations, /*locales=*/"en"}}});
  EXPECT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}

TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
                                      [](const ActionsModel* model) {
                                        if (model == nullptr) {
                                          return false;
                                        }
                                        return true;
                                      }));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", [](const ActionsModel* model) {
        if (model == nullptr) {
          return false;
        }
        return true;
      }));
}

TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);
  {
    const ActionsSuggestionsResponse& response =
        actions_suggestions->SuggestActions(
            {{{/*user_id=*/1, "hello", /*reference_time_ms_utc=*/0,
               /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{},
               /*locales=*/"en"}}});
    EXPECT_THAT(response.actions, testing::IsEmpty());
  }
  {
    const ActionsSuggestionsResponse& response =
        actions_suggestions->SuggestActions(
            {{{/*user_id=*/1, "where are you", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_THAT( response.actions, ElementsAre(testing::Field(&ActionSuggestion::type, "share_location"))); } { const ActionsSuggestionsResponse& response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "do you know johns number", /*reference_time_ms_utc=*/0, /*reference_timezone=*/"Europe/Zurich", /*annotations=*/{}, /*locales=*/"en"}}}); EXPECT_THAT( response.actions, ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact"))); } } // Test class to expose token embedding methods for testing. class TestingMessageEmbedder : private ActionsSuggestions { public: explicit TestingMessageEmbedder(const ActionsModel* model); using ActionsSuggestions::EmbedAndFlattenTokens; using ActionsSuggestions::EmbedTokensPerMessage; protected: // EmbeddingExecutor that always returns features based on // the id of the sparse features. class FakeEmbeddingExecutor : public EmbeddingExecutor { public: bool AddEmbedding(const TensorView& sparse_features, float* dest, const int dest_size) const override { TC3_CHECK_GE(dest_size, 1); EXPECT_EQ(sparse_features.size(), 1); dest[0] = sparse_features.data()[0]; return true; } }; }; TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) { model_ = model; const ActionsTokenFeatureProcessorOptions* options = model->feature_processor_options(); feature_processor_.reset( new ActionsFeatureProcessor(options, /*unilib=*/nullptr)); embedding_executor_.reset(new FakeEmbeddingExecutor()); EXPECT_TRUE( EmbedTokenId(options->padding_token_id(), &embedded_padding_token_)); EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_)); EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_)); token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize(); EXPECT_EQ(token_embedding_size_, 1); } class EmbeddingTest : public testing::Test { protected: EmbeddingTest() { model_.feature_processor_options.reset( new ActionsTokenFeatureProcessorOptionsT); options_ = model_.feature_processor_options.get(); options_->chargram_orders = {1}; options_->num_buckets = 1000; options_->embedding_size = 1; options_->start_token_id = 0; options_->end_token_id = 1; options_->padding_token_id = 2; options_->tokenizer_options.reset(new ActionsTokenizerOptionsT); } TestingMessageEmbedder CreateTestingMessageEmbedder() { flatbuffers::FlatBufferBuilder builder; FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_)); buffer_ = builder.ReleaseBufferPointer(); return TestingMessageEmbedder( flatbuffers::GetRoot(buffer_.data())); } flatbuffers::DetachedBuffer buffer_; ActionsModelT model_; ActionsTokenFeatureProcessorOptionsT* options_; }; TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) { const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); std::vector> tokens = { {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; std::vector embeddings; int max_num_tokens_per_message = 0; EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings, &max_num_tokens_per_message)); EXPECT_EQ(max_num_tokens_per_message, 3); EXPECT_EQ(embeddings.size(), 3); EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets)); EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets)); EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets)); } TEST_F(EmbeddingTest, 
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  EXPECT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
}

TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  EXPECT_EQ(embeddings.size(), 2);
  EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
}

TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[3], testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
}

TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  EXPECT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[3], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
}

TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  EXPECT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[3], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
}

TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  EXPECT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
}

TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  EXPECT_EQ(embeddings.size(), 9);
  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[2], testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[3], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[6], testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[7], testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
}

TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  EXPECT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[3], testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[4], testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[5], testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) % options_->num_buckets));
  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
}

}  // namespace
}  // namespace libtextclassifier3