1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "parent_def.h"
18
19 #include "fields/all_fields.h"
20 #include "util.h"
21
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24 : TypeDef(name), fields_(fields), parent_(parent) {}
25
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name, std::variant<int64_t, std::string> value) {
27 // NOTE: This could end up being very slow if there are a lot of constraints.
28 const auto& parent_params = parent_->GetParamList();
29 const auto& constrained_field = parent_params.GetField(field_name);
30 if (constrained_field == nullptr) {
31 ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
32 << ", but no such field exists.";
33 }
34
35 if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
36 if (!std::holds_alternative<int64_t>(value)) {
37 ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in " << parent_->name_;
38 }
39 } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
40 if (!std::holds_alternative<std::string>(value)) {
41 ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in " << parent_->name_;
42 }
43 const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
44 if (!enum_def.HasEntry(std::get<std::string>(value))) {
45 ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
46 << "\" for constraint on enum in parent " << parent_->name_ << ".";
47 }
48
49 // For enums, we have to qualify the value using the enum type name.
50 value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
51 } else {
52 ERROR(constrained_field) << "Field in parent " << parent_->name_ << " is not viable for constraining.";
53 }
54
55 parent_constraints_.insert(std::pair(field_name, value));
56 }
57
AddTestCase(std::string packet_bytes)58 void ParentDef::AddTestCase(std::string packet_bytes) {
59 test_cases_.insert(std::move(packet_bytes));
60 }
61
62 // Assign all size fields to their corresponding variable length fields.
63 // Will crash if
64 // - there aren't any fields that don't match up to a field.
65 // - the size field points to a fixed size field.
66 // - if the size field comes after the variable length field.
AssignSizeFields()67 void ParentDef::AssignSizeFields() {
68 for (const auto& field : fields_) {
69 DEBUG() << "field name: " << field->GetName();
70
71 if (field->GetFieldType() != SizeField::kFieldType && field->GetFieldType() != CountField::kFieldType) {
72 continue;
73 }
74
75 const SizeField* size_field = static_cast<SizeField*>(field);
76 // Check to see if a corresponding field can be found.
77 const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
78 if (var_len_field == nullptr) {
79 ERROR(field) << "Could not find corresponding field for size/count field.";
80 }
81
82 // Do the ordering check to ensure the size field comes before the
83 // variable length field.
84 for (auto it = fields_.begin(); *it != size_field; it++) {
85 DEBUG() << "field name: " << (*it)->GetName();
86 if (*it == var_len_field) {
87 ERROR(var_len_field, size_field) << "Size/count field must come before the variable length field it describes.";
88 }
89 }
90
91 if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
92 const auto& payload_field = static_cast<PayloadField*>(var_len_field);
93 payload_field->SetSizeField(size_field);
94 continue;
95 }
96
97 if (var_len_field->GetFieldType() == BodyField::kFieldType) {
98 const auto& body_field = static_cast<BodyField*>(var_len_field);
99 body_field->SetSizeField(size_field);
100 continue;
101 }
102
103 if (var_len_field->GetFieldType() == VectorField::kFieldType) {
104 const auto& vector_field = static_cast<VectorField*>(var_len_field);
105 vector_field->SetSizeField(size_field);
106 continue;
107 }
108
109 // If we've reached this point then the field wasn't a variable length field.
110 // Check to see if the field is a variable length field
111 ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
112 }
113 }
114
SetEndianness(bool is_little_endian)115 void ParentDef::SetEndianness(bool is_little_endian) {
116 is_little_endian_ = is_little_endian;
117 }
118
119 // Get the size. You scan specify without_payload in order to exclude payload fields as children will be overriding it.
GetSize(bool without_payload) const120 Size ParentDef::GetSize(bool without_payload) const {
121 auto size = Size(0);
122
123 for (const auto& field : fields_) {
124 if (without_payload &&
125 (field->GetFieldType() == PayloadField::kFieldType || field->GetFieldType() == BodyField::kFieldType)) {
126 continue;
127 }
128
129 // The offset to the field must be passed in as an argument for dynamically sized custom fields.
130 if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
131 std::stringstream custom_field_size;
132
133 // Custom fields are special as their size field takes an argument.
134 custom_field_size << field->GetSize().dynamic_string() << "(begin()";
135
136 // Check if we can determine offset from begin(), otherwise error because by this point,
137 // the size of the custom field is unknown and can't be subtracted from end() to get the
138 // offset.
139 auto offset = GetOffsetForField(field->GetName(), false);
140 if (offset.empty()) {
141 ERROR(field) << "Custom Field offset can not be determined from begin().";
142 }
143
144 if (offset.bits() % 8 != 0) {
145 ERROR(field) << "Custom fields must be byte aligned.";
146 }
147 if (offset.has_bits()) custom_field_size << " + " << offset.bits() / 8;
148 if (offset.has_dynamic()) custom_field_size << " + " << offset.dynamic_string();
149 custom_field_size << ")";
150
151 size += custom_field_size.str();
152 continue;
153 }
154
155 size += field->GetSize();
156 }
157
158 if (parent_ != nullptr) {
159 size += parent_->GetSize(true);
160 }
161
162 return size;
163 }
164
165 // Get the offset until the field is reached, if there is no field
166 // returns an empty Size. from_end requests the offset to the field
167 // starting from the end() iterator. If there is a field with an unknown
168 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const169 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
170 // Check first if the field exists.
171 if (fields_.GetField(field_name) == nullptr) {
172 ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in " << name_;
173 }
174
175 PacketField* padded_field = nullptr;
176 {
177 PacketField* last_field = nullptr;
178 for (const auto field : fields_) {
179 if (field->GetFieldType() == PaddingField::kFieldType) {
180 padded_field = last_field;
181 }
182 last_field = field;
183 }
184 }
185
186 // We have to use a generic lambda to conditionally change iteration direction
187 // due to iterator and reverse_iterator being different types.
188 auto size_lambda = [&](auto from, auto to) -> Size {
189 auto size = Size(0);
190 for (auto it = from; it != to; it++) {
191 // We've reached the field, end the loop.
192 if ((*it)->GetName() == field_name) break;
193 const auto& field = *it;
194 // If there is a field with an unknown size before the field, return an empty Size.
195 if (field->GetSize().empty() && padded_field != field) {
196 return Size();
197 }
198 if (field != padded_field) {
199 if (!from_end || field->GetFieldType() != PaddingField::kFieldType) {
200 size += field->GetSize();
201 }
202 }
203 }
204 return size;
205 };
206
207 // Change iteration direction based on from_end.
208 auto size = Size();
209 if (from_end)
210 size = size_lambda(fields_.rbegin(), fields_.rend());
211 else
212 size = size_lambda(fields_.begin(), fields_.end());
213 if (size.empty()) return size;
214
215 // We need the offset until a payload or body field.
216 if (parent_ != nullptr) {
217 if (parent_->fields_.HasPayload()) {
218 auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
219 if (parent_payload_offset.empty()) {
220 ERROR() << "Empty offset for payload in " << parent_->name_ << " finding the offset for field: " << field_name;
221 }
222 size += parent_payload_offset;
223 } else {
224 auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
225 if (parent_body_offset.empty()) {
226 ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
227 }
228 size += parent_body_offset;
229 }
230 }
231
232 return size;
233 }
234
GetParamList() const235 FieldList ParentDef::GetParamList() const {
236 FieldList params;
237
238 std::set<std::string> param_types = {
239 ScalarField::kFieldType,
240 EnumField::kFieldType,
241 ArrayField::kFieldType,
242 VectorField::kFieldType,
243 CustomField::kFieldType,
244 StructField::kFieldType,
245 VariableLengthStructField::kFieldType,
246 PayloadField::kFieldType,
247 };
248
249 if (parent_ != nullptr) {
250 auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
251
252 // Do not include constrained fields in the params
253 for (const auto& field : parent_params) {
254 if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
255 params.AppendField(field);
256 }
257 }
258 }
259 // Add our parameters.
260 return params.Merge(fields_.GetFieldsWithTypes(param_types));
261 }
262
GenMembers(std::ostream & s) const263 void ParentDef::GenMembers(std::ostream& s) const {
264 // Add the parameter list.
265 for (const auto& field : fields_) {
266 if (field->GenBuilderMember(s)) {
267 s << "_{};";
268 }
269 }
270 }
271
GenSize(std::ostream & s) const272 void ParentDef::GenSize(std::ostream& s) const {
273 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
274 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
275
276 Size padded_size;
277 const PacketField* padded_field = nullptr;
278 const PacketField* last_field = nullptr;
279 for (const auto& field : fields_) {
280 if (field->GetFieldType() == PaddingField::kFieldType) {
281 if (!padded_size.empty()) {
282 ERROR() << "Only one padding field is allowed. Second field: " << field->GetName();
283 }
284 padded_field = last_field;
285 padded_size = field->GetSize();
286 }
287 last_field = field;
288 }
289
290 s << "protected:";
291 s << "size_t BitsOfHeader() const {";
292 s << "return 0";
293
294 if (parent_ != nullptr) {
295 if (parent_->GetDefinitionType() == Type::PACKET) {
296 s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
297 } else {
298 s << " + " << parent_->name_ << "::BitsOfHeader() ";
299 }
300 }
301
302 for (const auto& field : header_fields) {
303 if (field == padded_field) {
304 s << " + " << padded_size;
305 } else {
306 s << " + " << field->GetBuilderSize();
307 }
308 }
309 s << ";";
310
311 s << "}\n\n";
312
313 s << "size_t BitsOfFooter() const {";
314 s << "return 0";
315 for (const auto& field : footer_fields) {
316 if (field == padded_field) {
317 s << " + " << padded_size;
318 } else {
319 s << " + " << field->GetBuilderSize();
320 }
321 }
322
323 if (parent_ != nullptr) {
324 if (parent_->GetDefinitionType() == Type::PACKET) {
325 s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
326 } else {
327 s << " + " << parent_->name_ << "::BitsOfFooter() ";
328 }
329 }
330 s << ";";
331 s << "}\n\n";
332
333 if (fields_.HasPayload()) {
334 s << "size_t GetPayloadSize() const {";
335 s << "if (payload_ != nullptr) {return payload_->size();}";
336 s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
337 s << ";}\n\n";
338 }
339
340 s << "public:";
341 s << "virtual size_t size() const override {";
342 s << "return (BitsOfHeader() / 8)";
343 if (fields_.HasPayload()) {
344 s << "+ payload_->size()";
345 }
346 if (fields_.HasBody()) {
347 for (const auto& field : header_fields) {
348 if (field->GetFieldType() == SizeField::kFieldType) {
349 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
350 if (field_name == "body") {
351 s << "+ body_size_extracted_";
352 }
353 }
354 }
355 }
356 s << " + (BitsOfFooter() / 8);";
357 s << "}\n";
358 }
359
GenSerialize(std::ostream & s) const360 void ParentDef::GenSerialize(std::ostream& s) const {
361 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
362 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
363
364 s << "protected:";
365 s << "void SerializeHeader(BitInserter&";
366 if (parent_ != nullptr || header_fields.size() != 0) {
367 s << " i ";
368 }
369 s << ") const {";
370
371 if (parent_ != nullptr) {
372 if (parent_->GetDefinitionType() == Type::PACKET) {
373 s << parent_->name_ << "Builder::SerializeHeader(i);";
374 } else {
375 s << parent_->name_ << "::SerializeHeader(i);";
376 }
377 }
378
379 const PacketField* padded_field = nullptr;
380 {
381 PacketField* last_field = nullptr;
382 for (const auto field : header_fields) {
383 if (field->GetFieldType() == PaddingField::kFieldType) {
384 padded_field = last_field;
385 }
386 last_field = field;
387 }
388 }
389
390 for (const auto& field : header_fields) {
391 if (field->GetFieldType() == SizeField::kFieldType) {
392 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
393 const auto& sized_field = fields_.GetField(field_name);
394 if (sized_field == nullptr) {
395 ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
396 }
397 if (sized_field->GetFieldType() == PayloadField::kFieldType) {
398 s << "size_t payload_bytes = GetPayloadSize();";
399 std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
400 if (modifier != "") {
401 s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
402 s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
403 }
404 s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
405 s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
406 } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
407 s << field->GetName() << "_extracted_ = 0;";
408 s << "size_t local_size = " << name_ << "::size();";
409
410 s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
411 s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i," << field->GetSize().bits()
412 << ");";
413 } else {
414 if (sized_field->GetFieldType() != VectorField::kFieldType) {
415 ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
416 }
417 const auto& vector_name = field_name + "_";
418 const VectorField* vector = (VectorField*)sized_field;
419 s << "size_t " << vector_name + "bytes = 0;";
420 if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
421 s << "for (auto elem : " << vector_name << ") {";
422 s << vector_name + "bytes += elem.size(); }";
423 } else {
424 s << vector_name + "bytes = ";
425 s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
426 }
427 std::string modifier = vector->GetSizeModifier();
428 if (modifier != "") {
429 s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
430 s << vector_name << "bytes = ";
431 s << vector_name << "bytes + (" << modifier << ") / 8;";
432 }
433 s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
434 s << "insert(" << vector_name << "bytes, i, ";
435 s << field->GetSize().bits() << ");";
436 }
437 } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
438 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
439 const auto& started_field = fields_.GetField(field_name);
440 if (started_field == nullptr) {
441 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
442 << ")";
443 }
444 s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
445 s << "shared_checksum_ptr->Initialize();";
446 s << "i.RegisterObserver(packet::ByteObserver(";
447 s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
448 s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
449 } else if (field->GetFieldType() == PaddingField::kFieldType) {
450 s << "ASSERT(unpadded_size <= " << field->GetSize().bytes() << ");";
451 s << "size_t padding_bytes = ";
452 s << field->GetSize().bytes() << " - unpadded_size;";
453 s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
454 } else if (field->GetFieldType() == CountField::kFieldType) {
455 const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
456 s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
457 } else {
458 if (field == padded_field) {
459 s << "size_t unpadded_size = (" << field->GetBuilderSize() << ") / 8;";
460 }
461 field->GenInserter(s);
462 }
463 }
464 s << "}\n\n";
465
466 s << "void SerializeFooter(BitInserter&";
467 if (parent_ != nullptr || footer_fields.size() != 0) {
468 s << " i ";
469 }
470 s << ") const {";
471
472 for (const auto& field : footer_fields) {
473 field->GenInserter(s);
474 }
475 if (parent_ != nullptr) {
476 if (parent_->GetDefinitionType() == Type::PACKET) {
477 s << parent_->name_ << "Builder::SerializeFooter(i);";
478 } else {
479 s << parent_->name_ << "::SerializeFooter(i);";
480 }
481 }
482 s << "}\n\n";
483
484 s << "public:";
485 s << "virtual void Serialize(BitInserter& i) const override {";
486 s << "SerializeHeader(i);";
487 if (fields_.HasPayload()) {
488 s << "payload_->Serialize(i);";
489 }
490 s << "SerializeFooter(i);";
491
492 s << "}\n";
493 }
494
GenInstanceOf(std::ostream & s) const495 void ParentDef::GenInstanceOf(std::ostream& s) const {
496 if (parent_ != nullptr && parent_constraints_.size() > 0) {
497 s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
498 // Get the list of parent params.
499 FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
500 PayloadField::kFieldType,
501 BodyField::kFieldType,
502 });
503
504 // Check if constrained parent fields are set to their correct values.
505 for (const auto& field : parent_params) {
506 const auto& constraint = parent_constraints_.find(field->GetName());
507 if (constraint != parent_constraints_.end()) {
508 s << "if (parent." << field->GetName() << "_ != ";
509 if (field->GetFieldType() == ScalarField::kFieldType) {
510 s << std::get<int64_t>(constraint->second) << ")";
511 s << "{ return false;}";
512 } else if (field->GetFieldType() == EnumField::kFieldType) {
513 s << std::get<std::string>(constraint->second) << ")";
514 s << "{ return false;}";
515 } else {
516 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
517 }
518 }
519 }
520 s << "return true;}";
521 }
522 }
523
GetRootDef() const524 const ParentDef* ParentDef::GetRootDef() const {
525 if (parent_ == nullptr) {
526 return this;
527 }
528
529 return parent_->GetRootDef();
530 }
531
GetAncestors() const532 std::vector<const ParentDef*> ParentDef::GetAncestors() const {
533 std::vector<const ParentDef*> res;
534 auto parent = parent_;
535 while (parent != nullptr) {
536 res.push_back(parent);
537 parent = parent->parent_;
538 }
539 std::reverse(res.begin(), res.end());
540 return res;
541 }
542
GetAllConstraints() const543 std::map<std::string, std::variant<int64_t, std::string>> ParentDef::GetAllConstraints() const {
544 std::map<std::string, std::variant<int64_t, std::string>> res;
545 res.insert(parent_constraints_.begin(), parent_constraints_.end());
546 for (auto parent : GetAncestors()) {
547 res.insert(parent->parent_constraints_.begin(), parent->parent_constraints_.end());
548 }
549 return res;
550 }
551
HasAncestorNamed(std::string name) const552 bool ParentDef::HasAncestorNamed(std::string name) const {
553 auto parent = parent_;
554 while (parent != nullptr) {
555 if (parent->name_ == name) {
556 return true;
557 }
558 parent = parent->parent_;
559 }
560 return false;
561 }
562
FindConstraintField() const563 std::string ParentDef::FindConstraintField() const {
564 std::string res;
565 for (const auto& child : children_) {
566 if (!child->parent_constraints_.empty()) {
567 return child->parent_constraints_.begin()->first;
568 }
569 res = child->FindConstraintField();
570 }
571 return res;
572 }
573
574 std::map<const ParentDef*, const std::variant<int64_t, std::string>>
FindDescendantsWithConstraint(std::string constraint_name) const575 ParentDef::FindDescendantsWithConstraint(
576 std::string constraint_name) const {
577 std::map<const ParentDef*, const std::variant<int64_t, std::string>> res;
578
579 for (auto const& child : children_) {
580 auto constraint = child->parent_constraints_.find(constraint_name);
581 if (constraint != child->parent_constraints_.end()) {
582 res.insert(std::pair(child, constraint->second));
583 }
584 auto m = child->FindDescendantsWithConstraint(constraint_name);
585 res.insert(m.begin(), m.end());
586 }
587 return res;
588 }
589
FindPathToDescendant(std::string descendant) const590 std::vector<const ParentDef*> ParentDef::FindPathToDescendant(std::string descendant) const {
591 std::vector<const ParentDef*> res;
592
593 for (auto const& child : children_) {
594 auto v = child->FindPathToDescendant(descendant);
595 if (v.size() > 0) {
596 res.insert(res.begin(), v.begin(), v.end());
597 res.push_back(child);
598 }
599 if (child->name_ == descendant) {
600 res.push_back(child);
601 return res;
602 }
603 }
604 return res;
605 }
606
HasChildEnums() const607 bool ParentDef::HasChildEnums() const {
608 return !children_.empty() || fields_.HasPayload();
609 }
610
GenRustConformanceCheck(std::ostream & s) const611 void ParentDef::GenRustConformanceCheck(std::ostream& s) const {
612 auto fields = fields_.GetFieldsWithTypes({
613 FixedScalarField::kFieldType,
614 });
615
616 for (auto const& field : fields) {
617 auto start_offset = GetOffsetForField(field->GetName(), false);
618 auto end_offset = GetOffsetForField(field->GetName(), true);
619
620 auto f = (FixedScalarField*)field;
621 f->GenRustGetter(s, start_offset, end_offset);
622 s << "if " << f->GetName() << " != ";
623 f->GenValue(s);
624 s << " { return false; } ";
625 }
626 }
627
GenRustWriteToFields(std::ostream & s) const628 void ParentDef::GenRustWriteToFields(std::ostream& s) const {
629 auto fields = fields_.GetFieldsWithoutTypes({
630 BodyField::kFieldType,
631 PaddingField::kFieldType,
632 ReservedField::kFieldType,
633 });
634
635 for (auto const& field : fields) {
636 auto start_field_offset = GetOffsetForField(field->GetName(), false);
637 auto end_field_offset = GetOffsetForField(field->GetName(), true);
638
639 if (start_field_offset.empty() && end_field_offset.empty()) {
640 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
641 << "no method exists to determine field location from begin() or end().\n";
642 }
643
644 if (field->GetFieldType() == SizeField::kFieldType) {
645 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
646 const auto& sized_field = fields_.GetField(field_name);
647 if (sized_field == nullptr) {
648 ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
649 }
650 if (sized_field->GetFieldType() == PayloadField::kFieldType) {
651 std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
652 if (modifier != "") {
653 ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name;
654 }
655
656 s << "let " << field->GetName() << " = " << field->GetRustDataType()
657 << "::try_from(self.child.get_total_size()).expect(\"payload size did not fit\");";
658 } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
659 s << "let " << field->GetName() << " = " << field->GetRustDataType()
660 << "::try_from(self.get_total_size() - self.get_size()).expect(\"payload size did not fit\");";
661 } else if (sized_field->GetFieldType() == VectorField::kFieldType) {
662 const auto& vector_name = field_name + "_bytes";
663 const VectorField* vector = (VectorField*)sized_field;
664 if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
665 s << "let " << vector_name + " = self." << field_name
666 << ".iter().fold(0, |acc, x| acc + x.get_total_size());";
667 } else {
668 s << "let " << vector_name + " = self." << field_name << ".len() * ((" << vector->element_size_ << ") / 8);";
669 }
670 std::string modifier = vector->GetSizeModifier();
671 if (modifier != "") {
672 s << "let " << vector_name << " = " << vector_name << " + (" << modifier.substr(1) << ") / 8;";
673 }
674
675 s << "let " << field->GetName() << " = " << field->GetRustDataType() << "::try_from(" << vector_name
676 << ").expect(\"payload size did not fit\");";
677 } else {
678 ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
679 }
680 }
681
682 field->GenRustWriter(s, start_field_offset, end_field_offset);
683 }
684 }
685
GenSizeRetVal(std::ostream & s) const686 void ParentDef::GenSizeRetVal(std::ostream& s) const {
687 int size = 0;
688 auto fields = fields_.GetFieldsWithoutTypes({
689 BodyField::kFieldType,
690 });
691 const PacketField* padded_field = nullptr;
692 auto padding_fields = fields_.GetFieldsWithTypes({
693 PaddingField::kFieldType,
694 });
695 if (padding_fields.size()) {
696 PacketField* last_field = nullptr;
697 for (const auto field : fields) {
698 if (field->GetFieldType() == PaddingField::kFieldType) {
699 padded_field = last_field;
700 }
701 last_field = field;
702 }
703 }
704
705 s << "let ret = 0;";
706 for (const auto field : fields) {
707 bool is_vector = field->GetFieldType() == VectorField::kFieldType;
708 if (field != padded_field) { // Skip the size of padded fields
709 if (is_vector) {
710 if (size > 0) {
711 if (size % 8 != 0) {
712 ERROR() << "size is not a multiple of 8!\n";
713 }
714 s << "let ret = ret + " << size / 8 << ";";
715 size = 0;
716 }
717
718 const VectorField* vector = (VectorField*)field;
719 if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
720 s << "let ret = ret + self." << vector->GetName() << ".iter().fold(0, |acc, x| acc + x.get_total_size());";
721 } else {
722 s << "let ret = ret + (self." << vector->GetName() << ".len() * ((" << vector->element_size_ << ") / 8));";
723 }
724 } else {
725 size += field->GetSize().bits();
726 }
727 } else {
728 s << "/* Skipping " << field->GetName() << " since it is padded */";
729 }
730 }
731 if (size > 0) {
732 if (size % 8 != 0) {
733 ERROR() << "size is not a multiple of 8!\n";
734 }
735 s << "let ret = ret + " << size / 8 << ";";
736 }
737
738 s << "ret";
739 }
740