1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "struct_def.h"

#include <utility>

#include "fields/all_fields.h"
#include "util.h"
21
StructDef(std::string name,FieldList fields)22 StructDef::StructDef(std::string name, FieldList fields) : StructDef(name, fields, nullptr) {}
StructDef(std::string name,FieldList fields,StructDef * parent)23 StructDef::StructDef(std::string name, FieldList fields, StructDef* parent)
24 : ParentDef(name, fields, parent), total_size_(GetSize(true)) {}
25
GetNewField(const std::string & name,ParseLocation loc) const26 PacketField* StructDef::GetNewField(const std::string& name, ParseLocation loc) const {
27 if (fields_.HasBody()) {
28 return new VariableLengthStructField(name, name_, loc);
29 } else {
30 return new StructField(name, name_, total_size_, loc);
31 }
32 }
33
GetDefinitionType() const34 TypeDef::Type StructDef::GetDefinitionType() const {
35 return TypeDef::Type::STRUCT;
36 }
37
GenSpecialize(std::ostream & s) const38 void StructDef::GenSpecialize(std::ostream& s) const {
39 if (parent_ == nullptr) {
40 return;
41 }
42 s << "static " << name_ << "* Specialize(" << parent_->name_ << "* parent) {";
43 s << "ASSERT(" << name_ << "::IsInstance(*parent));";
44 s << "return static_cast<" << name_ << "*>(parent);";
45 s << "}";
46 }
47
GenToString(std::ostream & s) const48 void StructDef::GenToString(std::ostream& s) const {
49 s << "std::string ToString() {";
50 s << "std::stringstream ss;";
51 s << "ss << std::hex << std::showbase << \"" << name_ << " { \";";
52
53 if (fields_.size() > 0) {
54 s << "ss";
55 bool firstfield = true;
56 for (const auto& field : fields_) {
57 if (field->GetFieldType() == ReservedField::kFieldType ||
58 field->GetFieldType() == ChecksumStartField::kFieldType ||
59 field->GetFieldType() == FixedScalarField::kFieldType || field->GetFieldType() == CountField::kFieldType ||
60 field->GetFieldType() == SizeField::kFieldType)
61 continue;
62
63 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
64
65 field->GenStringRepresentation(s, field->GetName() + "_");
66
67 if (firstfield) {
68 firstfield = false;
69 }
70 }
71 s << ";";
72 }
73
74 s << "ss << \" }\";";
75 s << "return ss.str();";
76 s << "}\n";
77 }
78
GenParse(std::ostream & s) const79 void StructDef::GenParse(std::ostream& s) const {
80 std::string iterator = (is_little_endian_ ? "Iterator<kLittleEndian>" : "Iterator<!kLittleEndian>");
81
82 if (fields_.HasBody()) {
83 s << "static std::optional<" << iterator << ">";
84 } else {
85 s << "static " << iterator;
86 }
87
88 s << " Parse(" << name_ << "* to_fill, " << iterator << " struct_begin_it ";
89
90 if (parent_ != nullptr) {
91 s << ", bool fill_parent = true) {";
92 } else {
93 s << ") {";
94 }
95 s << "auto to_bound = struct_begin_it;";
96
97 if (parent_ != nullptr) {
98 s << "if (fill_parent) {";
99 std::string parent_param = (parent_->parent_ == nullptr ? "" : ", true");
100 if (parent_->fields_.HasBody()) {
101 s << "auto parent_optional_it = " << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
102 if (fields_.HasBody()) {
103 s << "if (!parent_optional_it) { return {}; }";
104 } else {
105 s << "ASSERT(parent_optional_it);";
106 }
107 } else {
108 s << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
109 }
110 s << "}";
111 }
112
113 if (!fields_.HasBody()) {
114 s << "size_t end_index = struct_begin_it.NumBytesRemaining();";
115 s << "if (end_index < " << GetSize().bytes() << ")";
116 s << "{ return struct_begin_it.Subrange(0,0);}";
117 }
118
119 Size total_bits{0};
120 for (const auto& field : fields_) {
121 if (field->GetFieldType() != ReservedField::kFieldType && field->GetFieldType() != BodyField::kFieldType &&
122 field->GetFieldType() != FixedScalarField::kFieldType &&
123 field->GetFieldType() != ChecksumStartField::kFieldType && field->GetFieldType() != ChecksumField::kFieldType &&
124 field->GetFieldType() != CountField::kFieldType) {
125 total_bits += field->GetSize().bits();
126 }
127 }
128 s << "{";
129 s << "if (to_bound.NumBytesRemaining() < " << total_bits.bytes() << ")";
130 if (!fields_.HasBody()) {
131 s << "{ return to_bound.Subrange(to_bound.NumBytesRemaining(),0);}";
132 } else {
133 s << "{ return {};}";
134 }
135 s << "}";
136 for (const auto& field : fields_) {
137 if (field->GetFieldType() != ReservedField::kFieldType && field->GetFieldType() != BodyField::kFieldType &&
138 field->GetFieldType() != FixedScalarField::kFieldType && field->GetFieldType() != SizeField::kFieldType &&
139 field->GetFieldType() != ChecksumStartField::kFieldType && field->GetFieldType() != ChecksumField::kFieldType &&
140 field->GetFieldType() != CountField::kFieldType) {
141 s << "{";
142 int num_leading_bits =
143 field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(), field->GetStructSize());
144 s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_;";
145 field->GenExtractor(s, num_leading_bits, true);
146 s << "}";
147 }
148 if (field->GetFieldType() == CountField::kFieldType || field->GetFieldType() == SizeField::kFieldType) {
149 s << "{";
150 int num_leading_bits =
151 field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(), field->GetStructSize());
152 s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_extracted_;";
153 field->GenExtractor(s, num_leading_bits, true);
154 s << "}";
155 }
156 }
157 s << "return struct_begin_it + to_fill->size();";
158 s << "}";
159 }
160
GenParseFunctionPrototype(std::ostream & s) const161 void StructDef::GenParseFunctionPrototype(std::ostream& s) const {
162 s << "std::unique_ptr<" << name_ << "> Parse" << name_ << "(";
163 if (is_little_endian_) {
164 s << "Iterator<kLittleEndian>";
165 } else {
166 s << "Iterator<!kLittleEndian>";
167 }
168 s << "it);";
169 }
170
GenDefinition(std::ostream & s) const171 void StructDef::GenDefinition(std::ostream& s) const {
172 s << "class " << name_;
173 if (parent_ != nullptr) {
174 s << " : public " << parent_->name_;
175 } else {
176 if (is_little_endian_) {
177 s << " : public PacketStruct<kLittleEndian>";
178 } else {
179 s << " : public PacketStruct<!kLittleEndian>";
180 }
181 }
182 s << " {";
183 s << " public:";
184
185 GenConstructor(s);
186
187 s << " public:\n";
188 s << " virtual ~" << name_ << "() = default;\n";
189
190 GenSerialize(s);
191 s << "\n";
192
193 GenParse(s);
194 s << "\n";
195
196 GenSize(s);
197 s << "\n";
198
199 GenInstanceOf(s);
200 s << "\n";
201
202 GenSpecialize(s);
203 s << "\n";
204
205 GenToString(s);
206 s << "\n";
207
208 GenMembers(s);
209 for (const auto& field : fields_) {
210 if (field->GetFieldType() == CountField::kFieldType || field->GetFieldType() == SizeField::kFieldType) {
211 s << "\n private:\n";
212 s << " mutable " << field->GetDataType() << " " << field->GetName() << "_extracted_{0};";
213 }
214 }
215 s << "};\n";
216
217 if (fields_.HasBody()) {
218 GenParseFunctionPrototype(s);
219 }
220 s << "\n";
221 }
222
GenDefinitionPybind11(std::ostream & s) const223 void StructDef::GenDefinitionPybind11(std::ostream& s) const {
224 s << "py::class_<" << name_;
225 if (parent_ != nullptr) {
226 s << ", " << parent_->name_;
227 } else {
228 if (is_little_endian_) {
229 s << ", PacketStruct<kLittleEndian>";
230 } else {
231 s << ", PacketStruct<!kLittleEndian>";
232 }
233 }
234 s << ", std::shared_ptr<" << name_ << ">";
235 s << ">(m, \"" << name_ << "\")";
236 s << ".def(py::init<>())";
237 s << ".def(\"Serialize\", [](" << GetTypeName() << "& obj){";
238 s << "std::vector<uint8_t> bytes;";
239 s << "BitInserter bi(bytes);";
240 s << "obj.Serialize(bi);";
241 s << "return bytes;})";
242 s << ".def(\"Parse\", &" << name_ << "::Parse)";
243 s << ".def(\"size\", &" << name_ << "::size)";
244 for (const auto& field : fields_) {
245 if (field->GetBuilderParameterType().empty()) {
246 continue;
247 }
248 s << ".def_readwrite(\"" << field->GetName() << "\", &" << name_ << "::" << field->GetName() << "_)";
249 }
250 s << ";\n";
251 }
252
GenConstructor(std::ostream & s) const253 void StructDef::GenConstructor(std::ostream& s) const {
254 if (parent_ != nullptr) {
255 s << name_ << "(const " << parent_->name_ << "& parent) : " << parent_->name_ << "(parent) {}";
256 s << name_ << "() : " << parent_->name_ << "() {";
257 } else {
258 s << name_ << "() {";
259 }
260
261 // Get the list of parent params.
262 FieldList parent_params;
263 if (parent_ != nullptr) {
264 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
265 PayloadField::kFieldType,
266 BodyField::kFieldType,
267 });
268
269 // Set constrained parent fields to their correct values.
270 for (const auto& field : parent_params) {
271 const auto& constraint = parent_constraints_.find(field->GetName());
272 if (constraint != parent_constraints_.end()) {
273 s << parent_->name_ << "::" << field->GetName() << "_ = ";
274 if (field->GetFieldType() == ScalarField::kFieldType) {
275 s << std::get<int64_t>(constraint->second) << ";";
276 } else if (field->GetFieldType() == EnumField::kFieldType) {
277 s << std::get<std::string>(constraint->second) << ";";
278 } else {
279 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
280 }
281 }
282 }
283 }
284
285 s << "}\n";
286 }
287
GetStructOffsetForField(std::string field_name) const288 Size StructDef::GetStructOffsetForField(std::string field_name) const {
289 auto size = Size(0);
290 for (auto it = fields_.begin(); it != fields_.end(); it++) {
291 // We've reached the field, end the loop.
292 if ((*it)->GetName() == field_name) break;
293 const auto& field = *it;
294 // When we need to parse this field, all previous fields should already be parsed.
295 if (field->GetStructSize().empty()) {
296 ERROR() << "Empty size for field " << (*it)->GetName() << " finding the offset for field: " << field_name;
297 }
298 size += field->GetStructSize();
299 }
300
301 // We need the offset until a body field.
302 if (parent_ != nullptr) {
303 auto parent_body_offset = static_cast<StructDef*>(parent_)->GetStructOffsetForField("body");
304 if (parent_body_offset.empty()) {
305 ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
306 }
307 size += parent_body_offset;
308 }
309
310 return size;
311 }
312
GenRustFieldNameAndType(std::ostream & s,bool include_fixed) const313 void StructDef::GenRustFieldNameAndType(std::ostream& s, bool include_fixed) const {
314 auto fields = fields_.GetFieldsWithoutTypes({
315 BodyField::kFieldType,
316 CountField::kFieldType,
317 PaddingField::kFieldType,
318 ReservedField::kFieldType,
319 SizeField::kFieldType,
320 });
321 for (const auto& field : fields) {
322 if (!include_fixed && field->GetFieldType() == FixedScalarField::kFieldType) {
323 continue;
324 }
325 field->GenRustNameAndType(s);
326 s << ", ";
327 }
328 }
329
GenRustFieldNames(std::ostream & s) const330 void StructDef::GenRustFieldNames(std::ostream& s) const {
331 auto fields = fields_.GetFieldsWithoutTypes({
332 BodyField::kFieldType,
333 CountField::kFieldType,
334 PaddingField::kFieldType,
335 ReservedField::kFieldType,
336 SizeField::kFieldType,
337 });
338 for (const auto& field : fields) {
339 s << field->GetName();
340 s << ", ";
341 }
342 }
343
GenRustDeclarations(std::ostream & s) const344 void StructDef::GenRustDeclarations(std::ostream& s) const {
345 s << "#[derive(Debug, Clone)] ";
346 s << "pub struct " << name_ << "{";
347
348 // Generate struct fields
349 auto fields = fields_.GetFieldsWithoutTypes({
350 BodyField::kFieldType,
351 CountField::kFieldType,
352 PaddingField::kFieldType,
353 ReservedField::kFieldType,
354 SizeField::kFieldType,
355 });
356 for (const auto& field : fields) {
357 s << "pub ";
358 field->GenRustNameAndType(s);
359 s << ", ";
360 }
361 s << "}\n";
362 }
363
GenRustImpls(std::ostream & s) const364 void StructDef::GenRustImpls(std::ostream& s) const {
365 s << "impl " << name_ << "{";
366
367 s << "fn conforms(bytes: &[u8]) -> bool {";
368 GenRustConformanceCheck(s);
369 s << " true";
370 s << "}";
371
372 s << "pub fn parse(bytes: &[u8]) -> Result<Self> {";
373 auto fields = fields_.GetFieldsWithoutTypes({
374 BodyField::kFieldType,
375 });
376
377 for (const auto& field : fields) {
378 auto start_field_offset = GetOffsetForField(field->GetName(), false);
379 auto end_field_offset = GetOffsetForField(field->GetName(), true);
380
381 if (start_field_offset.empty() && end_field_offset.empty()) {
382 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
383 << "no method exists to determine field location from begin() or end().\n";
384 }
385
386 field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
387 field->GenRustGetter(s, start_field_offset, end_field_offset, name_);
388 }
389
390 fields = fields_.GetFieldsWithoutTypes({
391 BodyField::kFieldType,
392 CountField::kFieldType,
393 PaddingField::kFieldType,
394 ReservedField::kFieldType,
395 SizeField::kFieldType,
396 });
397
398 s << "Ok(Self {";
399 for (const auto& field : fields) {
400 if (field->GetFieldType() == FixedScalarField::kFieldType) {
401 s << field->GetName() << ": ";
402 static_cast<FixedScalarField*>(field)->GenValue(s);
403 } else {
404 s << field->GetName();
405 }
406 s << ", ";
407 }
408 s << "})}\n";
409
410 // write_to function
411 s << "fn write_to(&self, buffer: &mut [u8]) {";
412 GenRustWriteToFields(s);
413 s << "}\n";
414
415 s << "fn get_total_size(&self) -> usize {";
416 GenSizeRetVal(s);
417 s << "}";
418 s << "}\n";
419 }
420
GenRustDef(std::ostream & s) const421 void StructDef::GenRustDef(std::ostream& s) const {
422 GenRustDeclarations(s);
423 GenRustImpls(s);
424 }
425