1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "utils/flatbuffers/reflection.h"

#include <string>
18
19 namespace libtextclassifier3 {
20
// Looks up the field named `field_name` in `type` using the keyed (by name)
// lookup on the fields vector.
// Returns nullptr if no field has that name. CHECK-fails if `type` or its
// field list is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const StringPiece field_name) {
  TC3_CHECK(type != nullptr && type->fields() != nullptr);
  // The generated key comparison for string keys uses strcmp, so it needs a
  // NUL-terminated string. A StringPiece is not guaranteed to be terminated
  // (e.g. when it views a substring), so make a terminated copy first.
  const std::string name(field_name.data(), field_name.size());
  return type->fields()->LookupByKey(name.c_str());
}
26
// Scans the fields of `type` linearly for the one whose offset() equals
// `field_offset`.
// Returns nullptr when `type` has no field list or no field matches.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const int field_offset) {
  const auto* all_fields = type->fields();
  if (all_fields != nullptr) {
    for (const reflection::Field* candidate : *all_fields) {
      if (candidate->offset() == field_offset) {
        return candidate;
      }
    }
  }
  return nullptr;
}
39
// Looks up a field of `type` either by name or by offset.
// If `field_name` is non-empty, only the name lookup is performed; there is no
// fallback to `field_offset` when the name is not found. Otherwise the field is
// looked up by `field_offset`.
// Returns nullptr if no matching field exists.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const StringPiece field_name,
                                        const int field_offset) {
  // Fields are sorted by name in the schema data, so the keyed name lookup is
  // preferred over the linear scan by offset when a name is available.
  if (!field_name.empty()) {
    // Pass the StringPiece itself (not .data()) so its explicit length is
    // preserved; converting .data() back to a StringPiece would re-measure it
    // with strlen and overrun a non-terminated piece.
    return GetFieldOrNull(type, field_name);
  }
  return GetFieldOrNull(type, field_offset);
}
50
// Resolves a serialized FlatbufferField reference against `type`.
// Without a field name, resolution is by offset only. With a name, both the
// name and the offset are forwarded to the (type, name, offset) overload.
// CHECK-fails if `type` or `field` is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const FlatbufferField* field) {
  TC3_CHECK(type != nullptr && field != nullptr);
  const auto* name = field->field_name();
  if (name != nullptr) {
    return GetFieldOrNull(type, StringPiece(name->data(), name->size()),
                          field->field_offset());
  }
  return GetFieldOrNull(type, field->field_offset());
}
62
// Resolves an unpacked FlatbufferFieldT reference against `type` by
// forwarding its name and offset to the (type, name, offset) overload.
// CHECK-fails if `type` or `field` is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const FlatbufferFieldT* field) {
  TC3_CHECK(type != nullptr && field != nullptr);
  const StringPiece name = field->field_name;
  const int offset = field->field_offset;
  return GetFieldOrNull(type, name, offset);
}
68
// Returns the object definition in `schema` whose name equals `type_name`, or
// nullptr if there is none. The objects are scanned linearly.
const reflection::Object* TypeForName(const reflection::Schema* schema,
                                      const StringPiece type_name) {
  for (const reflection::Object* object : *schema->objects()) {
    // Compare against a view of the flatbuffer string instead of building a
    // std::string copy of every candidate name.
    const auto* name = object->name();
    if (type_name.Equals(StringPiece(name->c_str(), name->size()))) {
      return object;
    }
  }
  return nullptr;
}
78
TypeIdForObject(const reflection::Schema * schema,const reflection::Object * type)79 Optional<int> TypeIdForObject(const reflection::Schema* schema,
80 const reflection::Object* type) {
81 for (int i = 0; i < schema->objects()->size(); i++) {
82 if (schema->objects()->Get(i) == type) {
83 return Optional<int>(i);
84 }
85 }
86 return Optional<int>();
87 }
88
TypeIdForName(const reflection::Schema * schema,const StringPiece type_name)89 Optional<int> TypeIdForName(const reflection::Schema* schema,
90 const StringPiece type_name) {
91 for (int i = 0; i < schema->objects()->size(); i++) {
92 if (type_name.Equals(schema->objects()->Get(i)->name()->str())) {
93 return Optional<int>(i);
94 }
95 }
96 return Optional<int>();
97 }
98
SwapFieldNamesForOffsetsInPath(const reflection::Schema * schema,FlatbufferFieldPathT * path)99 bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
100 FlatbufferFieldPathT* path) {
101 if (schema == nullptr || !schema->root_table()) {
102 TC3_LOG(ERROR) << "Empty schema provided.";
103 return false;
104 }
105
106 reflection::Object const* type = schema->root_table();
107 for (int i = 0; i < path->field.size(); i++) {
108 const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
109 if (field == nullptr) {
110 TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
111 return false;
112 }
113 path->field[i]->field_name.clear();
114 path->field[i]->field_offset = field->offset();
115
116 // Descend.
117 if (i < path->field.size() - 1) {
118 if (field->type()->base_type() != reflection::Obj) {
119 TC3_LOG(ERROR) << "Field: " << field->name()->str()
120 << " is not of type `Object`.";
121 return false;
122 }
123 type = schema->objects()->Get(field->type()->index());
124 }
125 }
126 return true;
127 }
128
ParseEnumValue(const reflection::Schema * schema,const reflection::Type * type,StringPiece value)129 Variant ParseEnumValue(const reflection::Schema* schema,
130 const reflection::Type* type, StringPiece value) {
131 TC3_DCHECK(IsEnum(type));
132 TC3_CHECK_NE(schema->enums(), nullptr);
133 const auto* enum_values = schema->enums()->Get(type->index())->values();
134 if (enum_values == nullptr) {
135 TC3_LOG(ERROR) << "Enum has no specified values.";
136 return Variant();
137 }
138 for (const reflection::EnumVal* enum_value : *enum_values) {
139 if (value.Equals(StringPiece(enum_value->name()->c_str(),
140 enum_value->name()->size()))) {
141 const int64 value = enum_value->value();
142 switch (type->base_type()) {
143 case reflection::BaseType::Byte:
144 return Variant(static_cast<int8>(value));
145 case reflection::BaseType::UByte:
146 return Variant(static_cast<uint8>(value));
147 case reflection::BaseType::Short:
148 return Variant(static_cast<int16>(value));
149 case reflection::BaseType::UShort:
150 return Variant(static_cast<uint16>(value));
151 case reflection::BaseType::Int:
152 return Variant(static_cast<int32>(value));
153 case reflection::BaseType::UInt:
154 return Variant(static_cast<uint32>(value));
155 case reflection::BaseType::Long:
156 return Variant(value);
157 case reflection::BaseType::ULong:
158 return Variant(static_cast<uint64>(value));
159 default:
160 break;
161 }
162 }
163 }
164 return Variant();
165 }
166
167 } // namespace libtextclassifier3
168