/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "utils/testing/annotator.h" #include "utils/flatbuffers/mutable.h" #include "flatbuffers/reflection.h" namespace libtextclassifier3 { std::string FirstResult(const std::vector& results) { if (results.empty()) { return ""; } return results[0].collection; } std::string ReadFile(const std::string& file_name) { std::ifstream file_stream(file_name); return std::string(std::istreambuf_iterator(file_stream), {}); } std::unique_ptr MakePattern( const std::string& collection_name, const std::string& pattern, const bool enabled_for_classification, const bool enabled_for_selection, const bool enabled_for_annotation, const float score, const float priority_score) { std::unique_ptr result(new RegexModel_::PatternT); result->collection_name = collection_name; result->pattern = pattern; // We cannot directly operate with |= on the flag, so use an int here. int enabled_modes = ModeFlag_NONE; if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION; if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION; if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION; result->enabled_modes = static_cast(enabled_modes); result->target_classification_score = score; result->priority_score = priority_score; return result; } // Shortcut function that doesn't need to specify the priority score. std::unique_ptr MakePattern( const std::string& collection_name, const std::string& pattern, const bool enabled_for_classification, const bool enabled_for_selection, const bool enabled_for_annotation, const float score) { return MakePattern(collection_name, pattern, enabled_for_classification, enabled_for_selection, enabled_for_annotation, /*score=*/score, /*priority_score=*/score); } void AddTestRegexModel(ModelT* unpacked_model) { // Add test regex models. unpacked_model->regex_model->patterns.push_back(MakePattern( "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old", /*enabled_for_classification=*/true, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0)); // Use meta data to generate custom serialized entity data. MutableFlatbufferBuilder entity_data_builder( flatbuffers::GetRoot( unpacked_model->entity_data_schema.data())); RegexModel_::PatternT* pattern = unpacked_model->regex_model->patterns.back().get(); { std::unique_ptr entity_data = entity_data_builder.NewRoot(); entity_data->Set("is_alive", true); pattern->serialized_entity_data = entity_data->Serialize(); } pattern->capturing_group.emplace_back(new CapturingGroupT); pattern->capturing_group.emplace_back(new CapturingGroupT); pattern->capturing_group.emplace_back(new CapturingGroupT); pattern->capturing_group.emplace_back(new CapturingGroupT); // Group 0 is the full match, capturing groups starting at 1. pattern->capturing_group[1]->entity_field_path.reset( new FlatbufferFieldPathT); pattern->capturing_group[1]->entity_field_path->field.emplace_back( new FlatbufferFieldT); pattern->capturing_group[1]->entity_field_path->field.back()->field_name = "first_name"; pattern->capturing_group[2]->entity_field_path.reset( new FlatbufferFieldPathT); pattern->capturing_group[2]->entity_field_path->field.emplace_back( new FlatbufferFieldT); pattern->capturing_group[2]->entity_field_path->field.back()->field_name = "last_name"; // Set `former_us_president` field if we match Obama. { std::unique_ptr entity_data = entity_data_builder.NewRoot(); entity_data->Set("former_us_president", true); pattern->capturing_group[2]->serialized_entity_data = entity_data->Serialize(); } pattern->capturing_group[3]->entity_field_path.reset( new FlatbufferFieldPathT); pattern->capturing_group[3]->entity_field_path->field.emplace_back( new FlatbufferFieldT); pattern->capturing_group[3]->entity_field_path->field.back()->field_name = "age"; } std::string CreateEmptyModel( const std::function model_update_fn) { ModelT model; model_update_fn(&model); flatbuffers::FlatBufferBuilder builder; FinishModelBuffer(builder, Model::Pack(builder, &model)); return std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize()); } // Create fake entity data schema meta data. void AddTestEntitySchemaData(ModelT* unpacked_model) { // Cannot use object oriented API here as that is not available for the // reflection schema. flatbuffers::FlatBufferBuilder schema_builder; std::vector> fields = { reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("first_name"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::String), /*id=*/0, /*offset=*/4), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("is_alive"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Bool), /*id=*/1, /*offset=*/6), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("last_name"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::String), /*id=*/2, /*offset=*/8), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("age"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Int), /*id=*/3, /*offset=*/10), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("former_us_president"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Bool), /*id=*/4, /*offset=*/12)}; std::vector> enums; std::vector> objects = { reflection::CreateObject( schema_builder, /*name=*/schema_builder.CreateString("EntityData"), /*fields=*/ schema_builder.CreateVectorOfSortedTables(&fields))}; schema_builder.Finish(reflection::CreateSchema( schema_builder, schema_builder.CreateVectorOfSortedTables(&objects), schema_builder.CreateVectorOfSortedTables(&enums), /*(unused) file_ident=*/0, /*(unused) file_ext=*/0, /*root_table*/ objects[0])); unpacked_model->entity_data_schema.assign( schema_builder.GetBufferPointer(), schema_builder.GetBufferPointer() + schema_builder.GetSize()); } AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span, const std::string& collection, const float score, AnnotatedSpan::Source source) { AnnotatedSpan result; result.span = span; result.classification.push_back({collection, score}); result.source = source; return result; } } // namespace libtextclassifier3