1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/flatbuffers.h"
18
19 #include <vector>
20 #include "utils/strings/numbers.h"
21 #include "utils/variant.h"
22
23 namespace libtextclassifier3 {
24 namespace {
CreateRepeatedField(const reflection::Schema * schema,const reflection::Type * type,std::unique_ptr<ReflectiveFlatbuffer::RepeatedField> * repeated_field)25 bool CreateRepeatedField(
26 const reflection::Schema* schema, const reflection::Type* type,
27 std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
28 switch (type->element()) {
29 case reflection::Bool:
30 repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
31 return true;
32 case reflection::Int:
33 repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
34 return true;
35 case reflection::Long:
36 repeated_field->reset(
37 new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
38 return true;
39 case reflection::Float:
40 repeated_field->reset(
41 new ReflectiveFlatbuffer::TypedRepeatedField<float>);
42 return true;
43 case reflection::Double:
44 repeated_field->reset(
45 new ReflectiveFlatbuffer::TypedRepeatedField<double>);
46 return true;
47 case reflection::String:
48 repeated_field->reset(
49 new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
50 return true;
51 case reflection::Obj:
52 repeated_field->reset(
53 new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
54 schema, type));
55 return true;
56 default:
57 TC3_LOG(ERROR) << "Unsupported type: " << type->element();
58 return false;
59 }
60 }
61 } // namespace
62
63 template <>
FlatbufferFileIdentifier()64 const char* FlatbufferFileIdentifier<Model>() {
65 return ModelIdentifier();
66 }
67
NewRoot() const68 std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
69 const {
70 if (!schema_->root_table()) {
71 TC3_LOG(ERROR) << "No root table specified.";
72 return nullptr;
73 }
74 return std::unique_ptr<ReflectiveFlatbuffer>(
75 new ReflectiveFlatbuffer(schema_, schema_->root_table()));
76 }
77
NewTable(StringPiece table_name) const78 std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
79 StringPiece table_name) const {
80 for (const reflection::Object* object : *schema_->objects()) {
81 if (table_name.Equals(object->name()->str())) {
82 return std::unique_ptr<ReflectiveFlatbuffer>(
83 new ReflectiveFlatbuffer(schema_, object));
84 }
85 }
86 return nullptr;
87 }
88
GetFieldOrNull(const StringPiece field_name) const89 const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
90 const StringPiece field_name) const {
91 return type_->fields()->LookupByKey(field_name.data());
92 }
93
GetFieldOrNull(const FlatbufferField * field) const94 const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
95 const FlatbufferField* field) const {
96 // Lookup by name might be faster as the fields are sorted by name in the
97 // schema data, so try that first.
98 if (field->field_name() != nullptr) {
99 return GetFieldOrNull(field->field_name()->str());
100 }
101 return GetFieldByOffsetOrNull(field->field_offset());
102 }
103
GetFieldWithParent(const FlatbufferFieldPath * field_path,ReflectiveFlatbuffer ** parent,reflection::Field const ** field)104 bool ReflectiveFlatbuffer::GetFieldWithParent(
105 const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
106 reflection::Field const** field) {
107 const auto* path = field_path->field();
108 if (path == nullptr || path->size() == 0) {
109 return false;
110 }
111
112 for (int i = 0; i < path->size(); i++) {
113 *parent = (i == 0 ? this : (*parent)->Mutable(*field));
114 if (*parent == nullptr) {
115 return false;
116 }
117 *field = (*parent)->GetFieldOrNull(path->Get(i));
118 if (*field == nullptr) {
119 return false;
120 }
121 }
122
123 return true;
124 }
125
GetFieldByOffsetOrNull(const int field_offset) const126 const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
127 const int field_offset) const {
128 if (type_->fields() == nullptr) {
129 return nullptr;
130 }
131 for (const reflection::Field* field : *type_->fields()) {
132 if (field->offset() == field_offset) {
133 return field;
134 }
135 }
136 return nullptr;
137 }
138
IsMatchingType(const reflection::Field * field,const Variant & value) const139 bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
140 const Variant& value) const {
141 switch (field->type()->base_type()) {
142 case reflection::Bool:
143 return value.HasBool();
144 case reflection::Int:
145 return value.HasInt();
146 case reflection::Long:
147 return value.HasInt64();
148 case reflection::Float:
149 return value.HasFloat();
150 case reflection::Double:
151 return value.HasDouble();
152 case reflection::String:
153 return value.HasString();
154 default:
155 return false;
156 }
157 }
158
ParseAndSet(const reflection::Field * field,const std::string & value)159 bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
160 const std::string& value) {
161 switch (field->type()->base_type()) {
162 case reflection::String:
163 return Set(field, value);
164 case reflection::Int: {
165 int32 int_value;
166 if (!ParseInt32(value.data(), &int_value)) {
167 TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
168 return false;
169 }
170 return Set(field, int_value);
171 }
172 case reflection::Long: {
173 int64 int_value;
174 if (!ParseInt64(value.data(), &int_value)) {
175 TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
176 return false;
177 }
178 return Set(field, int_value);
179 }
180 case reflection::Float: {
181 double double_value;
182 if (!ParseDouble(value.data(), &double_value)) {
183 TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
184 return false;
185 }
186 return Set(field, static_cast<float>(double_value));
187 }
188 case reflection::Double: {
189 double double_value;
190 if (!ParseDouble(value.data(), &double_value)) {
191 TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
192 return false;
193 }
194 return Set(field, double_value);
195 }
196 default:
197 TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
198 return false;
199 }
200 }
201
ParseAndSet(const FlatbufferFieldPath * path,const std::string & value)202 bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
203 const std::string& value) {
204 ReflectiveFlatbuffer* parent;
205 const reflection::Field* field;
206 if (!GetFieldWithParent(path, &parent, &field)) {
207 return false;
208 }
209 return parent->ParseAndSet(field, value);
210 }
211
Mutable(const StringPiece field_name)212 ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
213 const StringPiece field_name) {
214 if (const reflection::Field* field = GetFieldOrNull(field_name)) {
215 return Mutable(field);
216 }
217 TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
218 return nullptr;
219 }
220
Mutable(const reflection::Field * field)221 ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
222 const reflection::Field* field) {
223 if (field->type()->base_type() != reflection::Obj) {
224 TC3_LOG(ERROR) << "Field is not of type Object.";
225 return nullptr;
226 }
227 const auto entry = children_.find(field);
228 if (entry != children_.end()) {
229 return entry->second.get();
230 }
231 const auto it = children_.insert(
232 /*hint=*/entry,
233 std::make_pair(
234 field,
235 std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
236 schema_, schema_->objects()->Get(field->type()->index())))));
237 return it->second.get();
238 }
239
Repeated(StringPiece field_name)240 ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
241 StringPiece field_name) {
242 if (const reflection::Field* field = GetFieldOrNull(field_name)) {
243 return Repeated(field);
244 }
245 TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
246 return nullptr;
247 }
248
Repeated(const reflection::Field * field)249 ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
250 const reflection::Field* field) {
251 if (field->type()->base_type() != reflection::Vector) {
252 TC3_LOG(ERROR) << "Field is not of type Vector.";
253 return nullptr;
254 }
255
256 // If the repeated field was already set, return its instance.
257 const auto entry = repeated_fields_.find(field);
258 if (entry != repeated_fields_.end()) {
259 return entry->second.get();
260 }
261
262 // Otherwise, create a new instance and store it.
263 std::unique_ptr<RepeatedField> repeated_field;
264 if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
265 TC3_LOG(ERROR) << "Could not create repeated field.";
266 return nullptr;
267 }
268 const auto it = repeated_fields_.insert(
269 /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
270 return it->second.get();
271 }
272
Serialize(flatbuffers::FlatBufferBuilder * builder) const273 flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
274 flatbuffers::FlatBufferBuilder* builder) const {
275 // Build all children before we can start with this table.
276 std::vector<
277 std::pair</* field vtable offset */ int,
278 /* field data offset in buffer */ flatbuffers::uoffset_t>>
279 offsets;
280 offsets.reserve(children_.size() + repeated_fields_.size());
281 for (const auto& it : children_) {
282 offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
283 }
284
285 // Create strings.
286 for (const auto& it : fields_) {
287 if (it.second.HasString()) {
288 offsets.push_back({it.first->offset(),
289 builder->CreateString(it.second.StringValue()).o});
290 }
291 }
292
293 // Build the repeated fields.
294 for (const auto& it : repeated_fields_) {
295 offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
296 }
297
298 // Build the table now.
299 const flatbuffers::uoffset_t table_start = builder->StartTable();
300
301 // Add scalar fields.
302 for (const auto& it : fields_) {
303 switch (it.second.GetType()) {
304 case Variant::TYPE_BOOL_VALUE:
305 builder->AddElement<uint8_t>(
306 it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
307 static_cast<uint8_t>(it.first->default_integer()));
308 continue;
309 case Variant::TYPE_INT_VALUE:
310 builder->AddElement<int32>(
311 it.first->offset(), it.second.IntValue(),
312 static_cast<int32>(it.first->default_integer()));
313 continue;
314 case Variant::TYPE_INT64_VALUE:
315 builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
316 it.first->default_integer());
317 continue;
318 case Variant::TYPE_FLOAT_VALUE:
319 builder->AddElement<float>(
320 it.first->offset(), it.second.FloatValue(),
321 static_cast<float>(it.first->default_real()));
322 continue;
323 case Variant::TYPE_DOUBLE_VALUE:
324 builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
325 it.first->default_real());
326 continue;
327 default:
328 continue;
329 }
330 }
331
332 // Add strings, subtables and repeated fields.
333 for (const auto& it : offsets) {
334 builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
335 }
336
337 return builder->EndTable(table_start);
338 }
339
Serialize() const340 std::string ReflectiveFlatbuffer::Serialize() const {
341 flatbuffers::FlatBufferBuilder builder;
342 builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
343 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
344 builder.GetSize());
345 }
346
MergeFrom(const flatbuffers::Table * from)347 bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
348 // No fields to set.
349 if (type_->fields() == nullptr) {
350 return true;
351 }
352
353 for (const reflection::Field* field : *type_->fields()) {
354 // Skip fields that are not explicitly set.
355 if (!from->CheckField(field->offset())) {
356 continue;
357 }
358 const reflection::BaseType type = field->type()->base_type();
359 switch (type) {
360 case reflection::Bool:
361 Set<bool>(field, from->GetField<uint8_t>(field->offset(),
362 field->default_integer()));
363 break;
364 case reflection::Int:
365 Set<int32>(field, from->GetField<int32>(field->offset(),
366 field->default_integer()));
367 break;
368 case reflection::Long:
369 Set<int64>(field, from->GetField<int64>(field->offset(),
370 field->default_integer()));
371 break;
372 case reflection::Float:
373 Set<float>(field, from->GetField<float>(field->offset(),
374 field->default_real()));
375 break;
376 case reflection::Double:
377 Set<double>(field, from->GetField<double>(field->offset(),
378 field->default_real()));
379 break;
380 case reflection::String:
381 Set<std::string>(
382 field, from->GetPointer<const flatbuffers::String*>(field->offset())
383 ->str());
384 break;
385 case reflection::Obj:
386 if (!Mutable(field)->MergeFrom(
387 from->GetPointer<const flatbuffers::Table* const>(
388 field->offset()))) {
389 return false;
390 }
391 break;
392 default:
393 TC3_LOG(ERROR) << "Unsupported type: " << type;
394 return false;
395 }
396 }
397 return true;
398 }
399
MergeFromSerializedFlatbuffer(StringPiece from)400 bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
401 return MergeFrom(flatbuffers::GetAnyRoot(
402 reinterpret_cast<const unsigned char*>(from.data())));
403 }
404
AsFlatMap(const std::string & key_separator,const std::string & key_prefix,std::map<std::string,Variant> * result) const405 void ReflectiveFlatbuffer::AsFlatMap(
406 const std::string& key_separator, const std::string& key_prefix,
407 std::map<std::string, Variant>* result) const {
408 // Add direct fields.
409 for (auto it : fields_) {
410 (*result)[key_prefix + it.first->name()->str()] = it.second;
411 }
412
413 // Add nested messages.
414 for (auto& it : children_) {
415 it.second->AsFlatMap(key_separator,
416 key_prefix + it.first->name()->str() + key_separator,
417 result);
418 }
419 }
420
421 } // namespace libtextclassifier3
422