1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/testing/annotator.h"
18
19 #include "utils/flatbuffers/mutable.h"
20 #include "flatbuffers/reflection.h"
21
22 namespace libtextclassifier3 {
23
FirstResult(const std::vector<ClassificationResult> & results)24 std::string FirstResult(const std::vector<ClassificationResult>& results) {
25 if (results.empty()) {
26 return "<INVALID RESULTS>";
27 }
28 return results[0].collection;
29 }
30
// Reads the whole file `file_name` and returns its content as a string of
// bytes.
//
// The stream is opened in binary mode so that the bytes are returned exactly
// as stored; a text-mode stream may translate line endings or stop early on
// some platforms, which would corrupt non-text content.
//
// If the file cannot be opened, an empty string is returned.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  // A default-constructed istreambuf_iterator is the end-of-stream sentinel.
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
35
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score,const float priority_score)36 std::unique_ptr<RegexModel_::PatternT> MakePattern(
37 const std::string& collection_name, const std::string& pattern,
38 const bool enabled_for_classification, const bool enabled_for_selection,
39 const bool enabled_for_annotation, const float score,
40 const float priority_score) {
41 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
42 result->collection_name = collection_name;
43 result->pattern = pattern;
44 // We cannot directly operate with |= on the flag, so use an int here.
45 int enabled_modes = ModeFlag_NONE;
46 if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
47 if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
48 if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
49 result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
50 result->target_classification_score = score;
51 result->priority_score = priority_score;
52 return result;
53 }
54
55 // Shortcut function that doesn't need to specify the priority score.
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score)56 std::unique_ptr<RegexModel_::PatternT> MakePattern(
57 const std::string& collection_name, const std::string& pattern,
58 const bool enabled_for_classification, const bool enabled_for_selection,
59 const bool enabled_for_annotation, const float score) {
60 return MakePattern(collection_name, pattern, enabled_for_classification,
61 enabled_for_selection, enabled_for_annotation,
62 /*score=*/score,
63 /*priority_score=*/score);
64 }
65
AddTestRegexModel(ModelT * unpacked_model)66 void AddTestRegexModel(ModelT* unpacked_model) {
67 // Add test regex models.
68 unpacked_model->regex_model->patterns.push_back(MakePattern(
69 "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
70 /*enabled_for_classification=*/true,
71 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
72
73 // Use meta data to generate custom serialized entity data.
74 MutableFlatbufferBuilder entity_data_builder(
75 flatbuffers::GetRoot<reflection::Schema>(
76 unpacked_model->entity_data_schema.data()));
77 RegexModel_::PatternT* pattern =
78 unpacked_model->regex_model->patterns.back().get();
79
80 {
81 std::unique_ptr<MutableFlatbuffer> entity_data =
82 entity_data_builder.NewRoot();
83 entity_data->Set("is_alive", true);
84 pattern->serialized_entity_data = entity_data->Serialize();
85 }
86 pattern->capturing_group.emplace_back(new CapturingGroupT);
87 pattern->capturing_group.emplace_back(new CapturingGroupT);
88 pattern->capturing_group.emplace_back(new CapturingGroupT);
89 pattern->capturing_group.emplace_back(new CapturingGroupT);
90 // Group 0 is the full match, capturing groups starting at 1.
91 pattern->capturing_group[1]->entity_field_path.reset(
92 new FlatbufferFieldPathT);
93 pattern->capturing_group[1]->entity_field_path->field.emplace_back(
94 new FlatbufferFieldT);
95 pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
96 "first_name";
97 pattern->capturing_group[2]->entity_field_path.reset(
98 new FlatbufferFieldPathT);
99 pattern->capturing_group[2]->entity_field_path->field.emplace_back(
100 new FlatbufferFieldT);
101 pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
102 "last_name";
103 // Set `former_us_president` field if we match Obama.
104 {
105 std::unique_ptr<MutableFlatbuffer> entity_data =
106 entity_data_builder.NewRoot();
107 entity_data->Set("former_us_president", true);
108 pattern->capturing_group[2]->serialized_entity_data =
109 entity_data->Serialize();
110 }
111 pattern->capturing_group[3]->entity_field_path.reset(
112 new FlatbufferFieldPathT);
113 pattern->capturing_group[3]->entity_field_path->field.emplace_back(
114 new FlatbufferFieldT);
115 pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
116 "age";
117 }
118
CreateEmptyModel(const std::function<void (ModelT * model)> model_update_fn)119 std::string CreateEmptyModel(
120 const std::function<void(ModelT* model)> model_update_fn) {
121 ModelT model;
122 model_update_fn(&model);
123
124 flatbuffers::FlatBufferBuilder builder;
125 FinishModelBuffer(builder, Model::Pack(builder, &model));
126 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
127 builder.GetSize());
128 }
129
130 // Create fake entity data schema meta data.
AddTestEntitySchemaData(ModelT * unpacked_model)131 void AddTestEntitySchemaData(ModelT* unpacked_model) {
132 // Cannot use object oriented API here as that is not available for the
133 // reflection schema.
134 flatbuffers::FlatBufferBuilder schema_builder;
135 std::vector<flatbuffers::Offset<reflection::Field>> fields = {
136 reflection::CreateField(
137 schema_builder,
138 /*name=*/schema_builder.CreateString("first_name"),
139 /*type=*/
140 reflection::CreateType(schema_builder,
141 /*base_type=*/reflection::String),
142 /*id=*/0,
143 /*offset=*/4),
144 reflection::CreateField(
145 schema_builder,
146 /*name=*/schema_builder.CreateString("is_alive"),
147 /*type=*/
148 reflection::CreateType(schema_builder,
149 /*base_type=*/reflection::Bool),
150 /*id=*/1,
151 /*offset=*/6),
152 reflection::CreateField(
153 schema_builder,
154 /*name=*/schema_builder.CreateString("last_name"),
155 /*type=*/
156 reflection::CreateType(schema_builder,
157 /*base_type=*/reflection::String),
158 /*id=*/2,
159 /*offset=*/8),
160 reflection::CreateField(
161 schema_builder,
162 /*name=*/schema_builder.CreateString("age"),
163 /*type=*/
164 reflection::CreateType(schema_builder,
165 /*base_type=*/reflection::Int),
166 /*id=*/3,
167 /*offset=*/10),
168 reflection::CreateField(
169 schema_builder,
170 /*name=*/schema_builder.CreateString("former_us_president"),
171 /*type=*/
172 reflection::CreateType(schema_builder,
173 /*base_type=*/reflection::Bool),
174 /*id=*/4,
175 /*offset=*/12)};
176 std::vector<flatbuffers::Offset<reflection::Enum>> enums;
177 std::vector<flatbuffers::Offset<reflection::Object>> objects = {
178 reflection::CreateObject(
179 schema_builder,
180 /*name=*/schema_builder.CreateString("EntityData"),
181 /*fields=*/
182 schema_builder.CreateVectorOfSortedTables(&fields))};
183 schema_builder.Finish(reflection::CreateSchema(
184 schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
185 schema_builder.CreateVectorOfSortedTables(&enums),
186 /*(unused) file_ident=*/0,
187 /*(unused) file_ext=*/0,
188 /*root_table*/ objects[0]));
189
190 unpacked_model->entity_data_schema.assign(
191 schema_builder.GetBufferPointer(),
192 schema_builder.GetBufferPointer() + schema_builder.GetSize());
193 }
194
MakeAnnotatedSpan(CodepointSpan span,const std::string & collection,const float score,AnnotatedSpan::Source source)195 AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
196 const std::string& collection,
197 const float score,
198 AnnotatedSpan::Source source) {
199 AnnotatedSpan result;
200 result.span = span;
201 result.classification.push_back({collection, score});
202 result.source = source;
203 return result;
204 }
205
206 } // namespace libtextclassifier3
207