1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "packet_def.h"

#include <iomanip>
#include <list>
#include <set>
#include <utility>

#include "fields/all_fields.h"
#include "packet_dependency.h"
#include "util.h"
26
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31 return nullptr; // Packets can't be fields
32 }
33
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const34 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
35 s << "class " << name_ << "View";
36 if (parent_ != nullptr) {
37 s << " : public " << parent_->name_ << "View {";
38 } else {
39 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40 }
41 s << " public:";
42
43 // Specialize function
44 if (parent_ != nullptr) {
45 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46 s << "{ return " << name_ << "View(std::move(parent)); }";
47 } else {
48 s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
49 s << "{ return " << name_ << "View(std::move(packet)); }";
50 }
51
52 if (generate_fuzzing || generate_tests) {
53 GenTestingParserFromBytes(s);
54 }
55
56 std::set<std::string> fixed_types = {
57 FixedScalarField::kFieldType,
58 FixedEnumField::kFieldType,
59 };
60
61 // Print all of the public fields which are all the fields minus the fixed fields.
62 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
63 bool has_fixed_fields = public_fields.size() != fields_.size();
64 for (const auto& field : public_fields) {
65 GenParserFieldGetter(s, field);
66 s << "\n";
67 }
68 GenValidator(s);
69 s << "\n";
70
71 s << " public:";
72 GenParserToString(s);
73 s << "\n";
74
75 s << " protected:\n";
76 // Constructor from a View
77 if (parent_ != nullptr) {
78 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
79 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
80 } else {
81 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
82 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
83 }
84
85 // Print the private fields which are the fixed fields.
86 if (has_fixed_fields) {
87 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
88 s << " private:\n";
89 for (const auto& field : private_fields) {
90 GenParserFieldGetter(s, field);
91 s << "\n";
92 }
93 }
94 s << "};\n";
95 }
96
GenTestingParserFromBytes(std::ostream & s) const97 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
98 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
99
100 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
101 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
102 s << "return " << name_ << "View::Create(";
103 auto ancestor_ptr = parent_;
104 size_t parent_parens = 0;
105 while (ancestor_ptr != nullptr) {
106 s << ancestor_ptr->name_ << "View::Create(";
107 parent_parens++;
108 ancestor_ptr = ancestor_ptr->parent_;
109 }
110 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
111 for (size_t i = 0; i < parent_parens; i++) {
112 s << ")";
113 }
114 s << ");";
115 s << "}";
116
117 s << "\n#endif\n";
118 }
119
GenParserDefinitionPybind11(std::ostream & s) const120 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
121 s << "py::class_<" << name_ << "View";
122 if (parent_ != nullptr) {
123 s << ", " << parent_->name_ << "View";
124 } else {
125 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
126 }
127 s << ">(m, \"" << name_ << "View\")";
128 if (parent_ != nullptr) {
129 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
130 } else {
131 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
132 }
133 s << "auto view =" << name_ << "View::Create(std::move(parent));";
134 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
135 s << "return view; }))";
136
137 s << ".def(py::init(&" << name_ << "View::Create))";
138 std::set<std::string> protected_field_types = {
139 FixedScalarField::kFieldType,
140 FixedEnumField::kFieldType,
141 SizeField::kFieldType,
142 CountField::kFieldType,
143 };
144 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
145 for (const auto& field : public_fields) {
146 auto getter_func_name = field->GetGetterFunctionName();
147 if (getter_func_name.empty()) {
148 continue;
149 }
150 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
151 }
152 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
153 s << ";\n";
154 }
155
GenParserFieldGetter(std::ostream & s,const PacketField * field) const156 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
157 // Start field offset
158 auto start_field_offset = GetOffsetForField(field->GetName(), false);
159 auto end_field_offset = GetOffsetForField(field->GetName(), true);
160
161 if (start_field_offset.empty() && end_field_offset.empty()) {
162 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
163 << "no method exists to determine field location from begin() or end().\n";
164 }
165
166 field->GenGetter(s, start_field_offset, end_field_offset);
167 }
168
GetDefinitionType() const169 TypeDef::Type PacketDef::GetDefinitionType() const {
170 return TypeDef::Type::PACKET;
171 }
172
GenValidator(std::ostream & s) const173 void PacketDef::GenValidator(std::ostream& s) const {
174 // Get the static offset for all of our fields.
175 int bits_size = 0;
176 for (const auto& field : fields_) {
177 if (field->GetFieldType() != PaddingField::kFieldType) {
178 bits_size += field->GetSize().bits();
179 }
180 }
181
182 // Generate the public validator IsValid().
183 // The method only needs to be generated for the top most class.
184 if (parent_ == nullptr) {
185 s << "bool IsValid() {" << std::endl;
186 s << " if (was_validated_) {" << std::endl;
187 s << " return true;" << std::endl;
188 s << " } else {" << std::endl;
189 s << " was_validated_ = true;" << std::endl;
190 s << " return (was_validated_ = Validate());" << std::endl;
191 s << " }" << std::endl;
192 s << "}" << std::endl;
193 }
194
195 // Generate the private validator Validate().
196 // The method is overridden by all child classes.
197 s << "protected:" << std::endl;
198 if (parent_ == nullptr) {
199 s << "virtual bool Validate() const {" << std::endl;
200 } else {
201 s << "bool Validate() const override {" << std::endl;
202 s << " if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
203 s << " return false;" << std::endl;
204 s << " }" << std::endl;
205 }
206
207 // Offset by the parents known size. We know that any dynamic fields can
208 // already be called since the parent must have already been validated by
209 // this point.
210 auto parent_size = Size(0);
211 if (parent_ != nullptr) {
212 parent_size = parent_->GetSize(true);
213 }
214
215 s << "auto it = begin() + (" << parent_size << ") / 8;";
216
217 // Check if you can extract the static fields.
218 // At this point you know you can use the size getters without crashing
219 // as long as they follow the instruction that size fields cant come before
220 // their corrisponding variable length field.
221 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
222 s << "if (it > end()) return false;";
223
224 // For any variable length fields, use their size check.
225 for (const auto& field : fields_) {
226 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
227 auto offset = GetOffsetForField(field->GetName(), false);
228 if (!offset.empty()) {
229 s << "size_t sum_index = (" << offset << ") / 8;";
230 } else {
231 offset = GetOffsetForField(field->GetName(), true);
232 if (offset.empty()) {
233 ERROR(field) << "Checksum Start Field offset can not be determined.";
234 }
235 s << "size_t sum_index = size() - (" << offset << ") / 8;";
236 }
237
238 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
239 const auto& started_field = fields_.GetField(field_name);
240 if (started_field == nullptr) {
241 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
242 << ")";
243 }
244 auto end_offset = GetOffsetForField(started_field->GetName(), false);
245 if (!end_offset.empty()) {
246 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
247 } else {
248 end_offset = GetOffsetForField(started_field->GetName(), true);
249 if (end_offset.empty()) {
250 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
251 }
252 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
253 }
254 s << "if (end_sum_index >= size()) { return false; }";
255 if (is_little_endian_) {
256 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
257 } else {
258 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
259 }
260 s << started_field->GetDataType() << " checksum;";
261 s << "checksum.Initialize();";
262 s << "for (uint8_t byte : checksum_view) { ";
263 s << "checksum.AddByte(byte);}";
264 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
265 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
266
267 continue;
268 }
269
270 auto field_size = field->GetSize();
271 // Fixed size fields have already been handled.
272 if (!field_size.has_dynamic()) {
273 continue;
274 }
275
276 // Custom fields with dynamic size must have the offset for the field passed in as well
277 // as the end iterator so that they may ensure that they don't try to read past the end.
278 // Custom fields with fixed sizes will be handled in the static offset checking.
279 if (field->GetFieldType() == CustomField::kFieldType) {
280 // Check if we can determine offset from begin(), otherwise error because by this point,
281 // the size of the custom field is unknown and can't be subtracted from end() to get the
282 // offset.
283 auto offset = GetOffsetForField(field->GetName(), false);
284 if (offset.empty()) {
285 ERROR(field) << "Custom Field offset can not be determined from begin().";
286 }
287
288 if (offset.bits() % 8 != 0) {
289 ERROR(field) << "Custom fields must be byte aligned.";
290 }
291
292 // Custom fields are special as their size field takes an argument.
293 const auto& custom_size_var = field->GetName() + "_size";
294 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
295 s << "(begin() + (" << offset << ") / 8);";
296
297 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
298 s << "it += *" << custom_size_var << ";";
299 s << "if (it > end()) return false;";
300 continue;
301 } else {
302 s << "it += (" << field_size.dynamic_string() << ") / 8;";
303 s << "if (it > end()) return false;";
304 }
305 }
306
307 // Validate constraints after validating the size
308 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
309 ERROR() << "Can't have a constraint on a NULL parent";
310 }
311
312 for (const auto& constraint : parent_constraints_) {
313 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
314 const auto& field = parent_->GetParamList().GetField(constraint.first);
315 if (field->GetFieldType() == ScalarField::kFieldType) {
316 s << std::get<int64_t>(constraint.second);
317 } else {
318 s << std::get<std::string>(constraint.second);
319 }
320 s << ") return false;";
321 }
322
323 // Validate the packets fields last
324 for (const auto& field : fields_) {
325 field->GenValidator(s);
326 s << "\n";
327 }
328
329 s << "return true;";
330 s << "}\n";
331 if (parent_ == nullptr) {
332 s << "bool was_validated_{false};\n";
333 }
334 }
335
GenParserToString(std::ostream & s) const336 void PacketDef::GenParserToString(std::ostream& s) const {
337 s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
338 s << "std::stringstream ss;";
339 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
340
341 if (fields_.size() > 0) {
342 s << "ss << \"\" ";
343 bool firstfield = true;
344 for (const auto& field : fields_) {
345 if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
346 field->GetFieldType() == ChecksumStartField::kFieldType)
347 continue;
348
349 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
350
351 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
352
353 if (firstfield) {
354 firstfield = false;
355 }
356 }
357 s << ";";
358 }
359
360 s << "ss << \" }\";";
361 s << "return ss.str();";
362 s << "}\n";
363 }
364
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const365 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
366 s << "class " << name_ << "Builder";
367 if (parent_ != nullptr) {
368 s << " : public " << parent_->name_ << "Builder";
369 } else {
370 if (is_little_endian_) {
371 s << " : public PacketBuilder<kLittleEndian>";
372 } else {
373 s << " : public PacketBuilder<!kLittleEndian>";
374 }
375 }
376 s << " {";
377 s << " public:";
378 s << " virtual ~" << name_ << "Builder() = default;";
379
380 if (!fields_.HasBody()) {
381 GenBuilderCreate(s);
382 s << "\n";
383
384 if (generate_fuzzing || generate_tests) {
385 GenTestingFromView(s);
386 s << "\n";
387 }
388 }
389
390 GenSerialize(s);
391 s << "\n";
392
393 GenSize(s);
394 s << "\n";
395
396 s << " protected:\n";
397 GenBuilderConstructor(s);
398 s << "\n";
399
400 GenBuilderParameterChecker(s);
401 s << "\n";
402
403 GenMembers(s);
404 s << "};\n";
405
406 if (generate_tests) {
407 GenTestDefine(s);
408 s << "\n";
409 }
410
411 if (generate_fuzzing || generate_tests) {
412 GenReflectTestDefine(s);
413 s << "\n";
414 }
415
416 if (generate_fuzzing) {
417 GenFuzzTestDefine(s);
418 s << "\n";
419 }
420 }
421
GenTestingFromView(std::ostream & s) const422 void PacketDef::GenTestingFromView(std::ostream& s) const {
423 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
424
425 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
426 s << "return " << name_ << "Builder::Create(";
427 FieldList params = GetParamList().GetFieldsWithoutTypes({
428 BodyField::kFieldType,
429 });
430 for (std::size_t i = 0; i < params.size(); i++) {
431 params[i]->GenBuilderParameterFromView(s);
432 if (i != params.size() - 1) {
433 s << ", ";
434 }
435 }
436 s << ");";
437 s << "}";
438
439 s << "\n#endif\n";
440 }
441
GenBuilderDefinitionPybind11(std::ostream & s) const442 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
443 s << "py::class_<" << name_ << "Builder";
444 if (parent_ != nullptr) {
445 s << ", " << parent_->name_ << "Builder";
446 } else {
447 if (is_little_endian_) {
448 s << ", PacketBuilder<kLittleEndian>";
449 } else {
450 s << ", PacketBuilder<!kLittleEndian>";
451 }
452 }
453 s << ", std::shared_ptr<" << name_ << "Builder>";
454 s << ">(m, \"" << name_ << "Builder\")";
455 if (!fields_.HasBody()) {
456 GenBuilderCreatePybind11(s);
457 }
458 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
459 s << "std::vector<uint8_t> bytes;";
460 s << "BitInserter bi(bytes);";
461 s << "builder.Serialize(bi);";
462 s << "return bytes;})";
463 s << ";\n";
464 }
465
GenTestDefine(std::ostream & s) const466 void PacketDef::GenTestDefine(std::ostream& s) const {
467 s << "#ifdef PACKET_TESTING\n";
468 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
469 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
470 s << "public: ";
471 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
472 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
473 s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
474 s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
475 s << "ASSERT_TRUE(view.IsValid());";
476 s << "auto packet = " << name_ << "Builder::FromView(view);";
477 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
478 s << "packet_bytes->reserve(packet->size());";
479 s << "BitInserter it(*packet_bytes);";
480 s << "packet->Serialize(it);";
481 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
482 s << "}";
483 s << "};";
484 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
485 s << "CompareBytes(GetParam());";
486 s << "}";
487 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
488 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
489 int i = 0;
490 for (const auto& bytes : test_cases_) {
491 s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
492 s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
493 s << name_ << "_test_bytes_" << i << ",";
494 s << name_ << "_test_bytes_" << i << " + sizeof(";
495 s << name_ << "_test_bytes_" << i << ") - 1);";
496 i++;
497 }
498 if (!test_cases_.empty()) {
499 i = 0;
500 s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
501 for (auto bytes : test_cases_) {
502 if (i > 0) {
503 s << ",";
504 }
505 s << name_ << "_test_vec_" << i++;
506 }
507 s << ");";
508 }
509 s << "\n#endif";
510 }
511
GenReflectTestDefine(std::ostream & s) const512 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
513 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
514 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
515 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
516 s << "auto vec = std::vector<uint8_t>(data, data + size);";
517 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
518 s << "if (!view.IsValid()) { return; }";
519 s << "auto packet = " << name_ << "Builder::FromView(view);";
520 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
521 s << "packet_bytes->reserve(packet->size());";
522 s << "BitInserter it(*packet_bytes);";
523 s << "packet->Serialize(it);";
524 s << "}";
525 s << "\n#endif\n";
526 }
527
GenFuzzTestDefine(std::ostream & s) const528 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
529 s << "#ifdef PACKET_FUZZ_TESTING\n";
530 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
531 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
532 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
533 s << "public: ";
534 s << "explicit " << name_
535 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
536 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
537 s << "}}; ";
538 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
539 s << "\n#endif";
540 }
541
GetParametersToValidate() const542 FieldList PacketDef::GetParametersToValidate() const {
543 FieldList params_to_validate;
544 for (const auto& field : GetParamList()) {
545 if (field->HasParameterValidator()) {
546 params_to_validate.AppendField(field);
547 }
548 }
549 return params_to_validate;
550 }
551
GenBuilderCreate(std::ostream & s) const552 void PacketDef::GenBuilderCreate(std::ostream& s) const {
553 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
554
555 auto params = GetParamList();
556 for (std::size_t i = 0; i < params.size(); i++) {
557 params[i]->GenBuilderParameter(s);
558 if (i != params.size() - 1) {
559 s << ", ";
560 }
561 }
562 s << ") {";
563
564 // Call the constructor
565 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
566
567 params = params.GetFieldsWithoutTypes({
568 PayloadField::kFieldType,
569 BodyField::kFieldType,
570 });
571 // Add the parameters.
572 for (std::size_t i = 0; i < params.size(); i++) {
573 if (params[i]->BuilderParameterMustBeMoved()) {
574 s << "std::move(" << params[i]->GetName() << ")";
575 } else {
576 s << params[i]->GetName();
577 }
578 if (i != params.size() - 1) {
579 s << ", ";
580 }
581 }
582
583 s << "));";
584 if (fields_.HasPayload()) {
585 s << "builder->payload_ = std::move(payload);";
586 }
587 s << "return builder;";
588 s << "}\n";
589 }
590
GenBuilderCreatePybind11(std::ostream & s) const591 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
592 s << ".def(py::init([](";
593 auto params = GetParamList();
594 std::vector<std::string> constructor_args;
595 int i = 1;
596 for (const auto& param : params) {
597 i++;
598 std::stringstream ss;
599 auto param_type = param->GetBuilderParameterType();
600 if (param_type.empty()) {
601 continue;
602 }
603 // Use shared_ptr instead of unique_ptr for the Python interface
604 if (param->BuilderParameterMustBeMoved()) {
605 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
606 }
607 ss << param_type << " " << param->GetName();
608 constructor_args.push_back(ss.str());
609 }
610 s << util::StringJoin(",", constructor_args) << "){";
611
612 // Deal with move only args
613 for (const auto& param : params) {
614 std::stringstream ss;
615 auto param_type = param->GetBuilderParameterType();
616 if (param_type.empty()) {
617 continue;
618 }
619 if (!param->BuilderParameterMustBeMoved()) {
620 continue;
621 }
622 auto move_only_param_name = param->GetName() + "_move_only";
623 s << param_type << " " << move_only_param_name << ";";
624 if (param->IsContainerField()) {
625 // Assume single layer container and copy it
626 auto struct_type = param->GetElementField()->GetDataType();
627 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
628 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
629 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
630 // Serialize each struct
631 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
632 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
633 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
634 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
635 // Parse it again
636 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
637 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
638 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
639 // Push it into a new container
640 if (param->GetFieldType() == VectorField::kFieldType) {
641 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
642 } else if (param->GetFieldType() == ArrayField::kFieldType) {
643 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
644 } else {
645 ERROR() << param << " is not supported by Pybind11";
646 }
647 s << "}";
648 } else {
649 // Serialize the parameter and pass the bytes in a RawBuilder
650 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
651 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
652 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
653 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
654 s << move_only_param_name << " = ";
655 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
656 }
657 }
658 s << "return " << name_ << "Builder::Create(";
659 std::vector<std::string> builder_vars;
660 for (const auto& param : params) {
661 std::stringstream ss;
662 auto param_type = param->GetBuilderParameterType();
663 if (param_type.empty()) {
664 continue;
665 }
666 auto param_name = param->GetName();
667 if (param->BuilderParameterMustBeMoved()) {
668 ss << "std::move(" << param_name << "_move_only)";
669 } else {
670 ss << param_name;
671 }
672 builder_vars.push_back(ss.str());
673 }
674 s << util::StringJoin(",", builder_vars) << ");}";
675 s << "))";
676 }
677
GenBuilderParameterChecker(std::ostream & s) const678 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
679 FieldList params_to_validate = GetParametersToValidate();
680
681 // Skip writing this function if there is nothing to validate.
682 if (params_to_validate.size() == 0) {
683 return;
684 }
685
686 // Generate function arguments.
687 s << "void CheckParameterValues(";
688 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
689 params_to_validate[i]->GenBuilderParameter(s);
690 if (i != params_to_validate.size() - 1) {
691 s << ", ";
692 }
693 }
694 s << ") {";
695
696 // Check the parameters.
697 for (const auto& field : params_to_validate) {
698 field->GenParameterValidator(s);
699 }
700 s << "}\n";
701 }
702
GenBuilderConstructor(std::ostream & s) const703 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
704 s << "explicit " << name_ << "Builder(";
705
706 // Generate the constructor parameters.
707 auto params = GetParamList().GetFieldsWithoutTypes({
708 PayloadField::kFieldType,
709 BodyField::kFieldType,
710 });
711 for (std::size_t i = 0; i < params.size(); i++) {
712 params[i]->GenBuilderParameter(s);
713 if (i != params.size() - 1) {
714 s << ", ";
715 }
716 }
717 if (params.size() > 0 || parent_constraints_.size() > 0) {
718 s << ") :";
719 } else {
720 s << ")";
721 }
722
723 // Get the list of parent params to call the parent constructor with.
724 FieldList parent_params;
725 if (parent_ != nullptr) {
726 // Pass parameters to the parent constructor
727 s << parent_->name_ << "Builder(";
728 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
729 PayloadField::kFieldType,
730 BodyField::kFieldType,
731 });
732
733 // Go through all the fields and replace constrained fields with fixed values
734 // when calling the parent constructor.
735 for (std::size_t i = 0; i < parent_params.size(); i++) {
736 const auto& field = parent_params[i];
737 const auto& constraint = parent_constraints_.find(field->GetName());
738 if (constraint != parent_constraints_.end()) {
739 if (field->GetFieldType() == ScalarField::kFieldType) {
740 s << std::get<int64_t>(constraint->second);
741 } else if (field->GetFieldType() == EnumField::kFieldType) {
742 s << std::get<std::string>(constraint->second);
743 } else {
744 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
745 }
746
747 s << "/* " << field->GetName() << "_ */";
748 } else {
749 s << field->GetName();
750 }
751
752 if (i != parent_params.size() - 1) {
753 s << ", ";
754 }
755 }
756 s << ") ";
757 }
758
759 // Build a list of parameters that excludes all parent parameters.
760 FieldList saved_params;
761 for (const auto& field : params) {
762 if (parent_params.GetField(field->GetName()) == nullptr) {
763 saved_params.AppendField(field);
764 }
765 }
766 if (parent_ != nullptr && saved_params.size() > 0) {
767 s << ",";
768 }
769 for (std::size_t i = 0; i < saved_params.size(); i++) {
770 const auto& saved_param_name = saved_params[i]->GetName();
771 if (saved_params[i]->BuilderParameterMustBeMoved()) {
772 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
773 } else {
774 s << saved_param_name << "_(" << saved_param_name << ")";
775 }
776 if (i != saved_params.size() - 1) {
777 s << ",";
778 }
779 }
780 s << " {";
781
782 FieldList params_to_validate = GetParametersToValidate();
783
784 if (params_to_validate.size() > 0) {
785 s << "CheckParameterValues(";
786 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
787 s << params_to_validate[i]->GetName() << "_";
788 if (i != params_to_validate.size() - 1) {
789 s << ", ";
790 }
791 }
792 s << ");";
793 }
794
795 s << "}\n";
796 }
797
GenRustChildEnums(std::ostream & s) const798 void PacketDef::GenRustChildEnums(std::ostream& s) const {
799 if (HasChildEnums()) {
800 bool payload = fields_.HasPayload();
801 s << "#[derive(Debug)] ";
802 s << "enum " << name_ << "DataChild {";
803 for (const auto& child : children_) {
804 s << child->name_ << "(Arc<" << child->name_ << "Data>),";
805 }
806 if (payload) {
807 s << "Payload(Bytes),";
808 }
809 s << "None,";
810 s << "}\n";
811
812 s << "impl " << name_ << "DataChild {";
813 s << "fn get_total_size(&self) -> usize {";
814 s << "match self {";
815 for (const auto& child : children_) {
816 s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
817 }
818 if (payload) {
819 s << name_ << "DataChild::Payload(p) => p.len(),";
820 }
821 s << name_ << "DataChild::None => 0,";
822 s << "}\n";
823 s << "}\n";
824 s << "}\n";
825
826 s << "#[derive(Debug)] ";
827 s << "pub enum " << name_ << "Child {";
828 for (const auto& child : children_) {
829 s << child->name_ << "(" << child->name_ << "Packet),";
830 }
831 if (payload) {
832 s << "Payload(Bytes),";
833 }
834 s << "None,";
835 s << "}\n";
836 }
837 }
838
GenRustStructDeclarations(std::ostream & s) const839 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
840 s << "#[derive(Debug)] ";
841 s << "struct " << name_ << "Data {";
842
843 // Generate struct fields
844 GenRustStructFieldNameAndType(s);
845 if (HasChildEnums()) {
846 s << "child: " << name_ << "DataChild,";
847 }
848 s << "}\n";
849
850 // Generate accessor struct
851 s << "#[derive(Debug, Clone)] ";
852 s << "pub struct " << name_ << "Packet {";
853 auto lineage = GetAncestors();
854 lineage.push_back(this);
855 for (auto it = lineage.begin(); it != lineage.end(); it++) {
856 auto def = *it;
857 s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
858 }
859 s << "}\n";
860
861 // Generate builder struct
862 s << "#[derive(Debug)] ";
863 s << "pub struct " << name_ << "Builder {";
864 auto params = GetParamList().GetFieldsWithoutTypes({
865 PayloadField::kFieldType,
866 BodyField::kFieldType,
867 });
868 for (auto param : params) {
869 s << "pub ";
870 param->GenRustNameAndType(s);
871 s << ", ";
872 }
873 if (fields_.HasPayload()) {
874 s << "pub payload: Option<Bytes>,";
875 }
876 s << "}\n";
877 }
878
GenRustStructFieldNameAndType(std::ostream & s) const879 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
880 auto fields = fields_.GetFieldsWithoutTypes({
881 BodyField::kFieldType,
882 CountField::kFieldType,
883 PaddingField::kFieldType,
884 ReservedField::kFieldType,
885 SizeField::kFieldType,
886 PayloadField::kFieldType,
887 FixedScalarField::kFieldType,
888 });
889 if (fields.size() == 0) {
890 return false;
891 }
892 for (const auto& field : fields) {
893 field->GenRustNameAndType(s);
894 s << ", ";
895 }
896 return true;
897 }
898
GenRustStructFieldNames(std::ostream & s) const899 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
900 auto fields = fields_.GetFieldsWithoutTypes({
901 BodyField::kFieldType,
902 CountField::kFieldType,
903 PaddingField::kFieldType,
904 ReservedField::kFieldType,
905 SizeField::kFieldType,
906 PayloadField::kFieldType,
907 FixedScalarField::kFieldType,
908 });
909 for (const auto field : fields) {
910 s << field->GetName();
911 s << ", ";
912 }
913 }
914
GenRustStructImpls(std::ostream & s) const915 void PacketDef::GenRustStructImpls(std::ostream& s) const {
916 auto packet_dep = PacketDependency(GetRootDef());
917
918 s << "impl " << name_ << "Data {";
919 // conforms function
920 s << "fn conforms(bytes: &[u8]) -> bool {";
921 GenRustConformanceCheck(s);
922
923 auto fields = fields_.GetFieldsWithTypes({
924 StructField::kFieldType,
925 });
926
927 for (auto const& field : fields) {
928 auto start_offset = GetOffsetForField(field->GetName(), false);
929 auto end_offset = GetOffsetForField(field->GetName(), true);
930
931 s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
932 s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
933 }
934
935 s << " true";
936 s << "}";
937
938 auto parse_params = packet_dep.GetDependencies(name_);
939 s << "fn parse(bytes: &[u8]";
940 for (auto field_name : parse_params) {
941 auto constraint_field = GetParamList().GetField(field_name);
942 auto constraint_type = constraint_field->GetRustDataType();
943 s << ", " << field_name << ": " << constraint_type;
944 }
945 s << ") -> Result<Self> {";
946
947 fields = fields_.GetFieldsWithoutTypes({
948 BodyField::kFieldType,
949 });
950
951 for (auto const& field : fields) {
952 auto start_field_offset = GetOffsetForField(field->GetName(), false);
953 auto end_field_offset = GetOffsetForField(field->GetName(), true);
954
955 if (start_field_offset.empty() && end_field_offset.empty()) {
956 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
957 << "no method exists to determine field location from begin() or end().\n";
958 }
959
960 field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
961 field->GenRustGetter(s, start_field_offset, end_field_offset, name_);
962 }
963
964 auto payload_field = fields_.GetFieldsWithTypes({
965 PayloadField::kFieldType,
966 });
967
968 Size payload_offset;
969
970 if (payload_field.HasPayload()) {
971 payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
972 }
973
974 if (children_.size() > 1) {
975 auto match_on_variables = packet_dep.GetChildrenDependencies(name_);
976 // If match_on_variables is empty, this means there are multiple abstract packets which will
977 // specialize to a child down the packet tree.
978 // In this case match variables will be the union of parent fields and parse params of children.
979 if (match_on_variables.empty()) {
980 for (auto& field : fields_) {
981 if (std::any_of(children_.begin(), children_.end(), [&](auto child) {
982 auto pass_me = packet_dep.GetDependencies(child->name_);
983 return std::find(pass_me.begin(), pass_me.end(), field->GetName()) != pass_me.end();
984 })) {
985 match_on_variables.push_back(field->GetName());
986 }
987 }
988 }
989
990 s << "let child = match (";
991
992 for (auto var : match_on_variables) {
993 if (var == match_on_variables[match_on_variables.size() - 1]) {
994 s << var;
995 } else {
996 s << var << ", ";
997 }
998 }
999 s << ") {";
1000
1001 auto get_match_val = [&](
1002 std::string& match_var,
1003 std::variant<int64_t,
1004 std::string> constraint) -> std::string {
1005 auto constraint_field = GetParamList().GetField(match_var);
1006 auto constraint_type = constraint_field->GetFieldType();
1007
1008 if (constraint_type == EnumField::kFieldType) {
1009 auto type = std::get<std::string>(constraint);
1010 auto variant_name = type.substr(type.find("::") + 2, type.length());
1011 auto enum_type = type.substr(0, type.find("::"));
1012 return enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
1013 }
1014 if (constraint_type == ScalarField::kFieldType) {
1015 return std::to_string(std::get<int64_t>(constraint));
1016 }
1017 return "_";
1018 };
1019
1020 for (auto& child : children_) {
1021 s << "(";
1022 for (auto var : match_on_variables) {
1023 std::string match_val = "_";
1024
1025 if (child->parent_constraints_.find(var) != child->parent_constraints_.end()) {
1026 match_val = get_match_val(var, child->parent_constraints_[var]);
1027 } else {
1028 auto dcs = child->FindDescendantsWithConstraint(var);
1029 std::vector<std::string> all_match_vals;
1030 for (auto& desc : dcs) {
1031 all_match_vals.push_back(get_match_val(var, desc.second));
1032 }
1033 match_val = "";
1034 for (std::size_t i = 0; i < all_match_vals.size(); ++i) {
1035 match_val += all_match_vals[i];
1036 if (i != all_match_vals.size() - 1) {
1037 match_val += " | ";
1038 }
1039 }
1040 match_val = (match_val == "") ? "_" : match_val;
1041 }
1042
1043 if (var == match_on_variables[match_on_variables.size() - 1]) {
1044 s << match_val << ")";
1045 } else {
1046 s << match_val << ", ";
1047 }
1048 }
1049 s << " if " << child->name_ << "Data::conforms(&bytes[..])";
1050 s << " => {";
1051 s << name_ << "DataChild::";
1052 s << child->name_ << "(Arc::new(";
1053
1054 auto child_parse_params = packet_dep.GetDependencies(child->name_);
1055 if (child_parse_params.size() == 0) {
1056 s << child->name_ << "Data::parse(&bytes[..]";
1057 } else {
1058 s << child->name_ << "Data::parse(&bytes[..], ";
1059 }
1060
1061 for (auto var : child_parse_params) {
1062 if (var == child_parse_params[child_parse_params.size() - 1]) {
1063 s << var;
1064 } else {
1065 s << var << ", ";
1066 }
1067 }
1068 s << ")?))";
1069 s << "}\n";
1070 }
1071
1072 s << "(";
1073 for (int i = 1; i <= match_on_variables.size(); i++) {
1074 if (i == match_on_variables.size()) {
1075 s << "_";
1076 } else {
1077 s << "_, ";
1078 }
1079 }
1080 s << ")";
1081 s << " => return Err(Error::InvalidPacketError),";
1082 s << "};\n";
1083 } else if (children_.size() == 1) {
1084 auto child = children_.at(0);
1085 auto params = packet_dep.GetDependencies(child->name_);
1086 s << "let child = match " << child->name_ << "Data::parse(&bytes[..]";
1087 for (auto field_name : params) {
1088 s << ", " << field_name;
1089 }
1090 s << ") {";
1091 s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
1092 s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
1093 s << " },";
1094 s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
1095 s << " _ => return Err(Error::InvalidPacketError),";
1096 s << "};";
1097 } else if (fields_.HasPayload()) {
1098 s << "let child = if payload.len() > 0 {";
1099 s << name_ << "DataChild::Payload(Bytes::from(payload))";
1100 s << "} else {";
1101 s << name_ << "DataChild::None";
1102 s << "};";
1103 }
1104
1105 s << "Ok(Self {";
1106 fields = fields_.GetFieldsWithoutTypes({
1107 BodyField::kFieldType,
1108 CountField::kFieldType,
1109 PaddingField::kFieldType,
1110 ReservedField::kFieldType,
1111 SizeField::kFieldType,
1112 PayloadField::kFieldType,
1113 FixedScalarField::kFieldType,
1114 });
1115
1116 if (fields.size() > 0) {
1117 for (const auto& field : fields) {
1118 auto field_type = field->GetFieldType();
1119 s << field->GetName();
1120 s << ", ";
1121 }
1122 }
1123
1124 if (HasChildEnums()) {
1125 s << "child,";
1126 }
1127 s << "})\n";
1128 s << "}\n";
1129
1130 // write_to function
1131 s << "fn write_to(&self, buffer: &mut BytesMut) {";
1132 GenRustWriteToFields(s);
1133
1134 if (HasChildEnums()) {
1135 s << "match &self.child {";
1136 for (const auto& child : children_) {
1137 s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1138 }
1139 if (fields_.HasPayload()) {
1140 auto offset = GetOffsetForField("payload");
1141 s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1142 }
1143 s << name_ << "DataChild::None => {}";
1144 s << "}";
1145 }
1146
1147 s << "}\n";
1148
1149 s << "fn get_total_size(&self) -> usize {";
1150 if (HasChildEnums()) {
1151 s << "self.get_size() + self.child.get_total_size()";
1152 } else {
1153 s << "self.get_size()";
1154 }
1155 s << "}\n";
1156
1157 s << "fn get_size(&self) -> usize {";
1158 GenSizeRetVal(s);
1159 s << "}\n";
1160 s << "}\n";
1161 }
1162
GenRustAccessStructImpls(std::ostream & s) const1163 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1164 if (complement_ != nullptr) {
1165 auto complement_root = complement_->GetRootDef();
1166 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1167 s << "impl CommandExpectations for " << name_ << "Packet {";
1168 s << " type ResponseType = " << complement_->name_ << "Packet;";
1169 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1170 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1171 << ".unwrap()";
1172 s << " }";
1173 s << "}";
1174 }
1175
1176 s << "impl Packet for " << name_ << "Packet {";
1177 auto root = GetRootDef();
1178 auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1179
1180 s << "fn to_bytes(self) -> Bytes {";
1181 s << " let mut buffer = BytesMut::new();";
1182 s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1183 s << " self." << root_accessor << ".write_to(&mut buffer);";
1184 s << " buffer.freeze()";
1185 s << "}\n";
1186
1187 s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1188 s << "}";
1189
1190 s << "impl From<" << name_ << "Packet"
1191 << "> for Bytes {\n";
1192 s << "fn from(packet: " << name_ << "Packet"
1193 << ") -> Self {\n";
1194 s << "packet.to_bytes()\n";
1195 s << "}\n";
1196 s << "}\n";
1197
1198 s << "impl From<" << name_ << "Packet"
1199 << "> for Vec<u8> {\n";
1200 s << "fn from(packet: " << name_ << "Packet"
1201 << ") -> Self {\n";
1202 s << "packet.to_vec()\n";
1203 s << "}\n";
1204 s << "}\n";
1205
1206 if (root != this) {
1207 s << "impl TryFrom<" << root->name_ << "Packet"
1208 << "> for " << name_ << "Packet {\n";
1209 s << "type Error = TryFromError;\n";
1210 s << "fn try_from(value: " << root->name_ << "Packet)"
1211 << " -> std::result::Result<Self, Self::Error> {\n";
1212 s << "Self::new(value." << root_accessor << ").map_err(TryFromError)\n", s << "}\n";
1213 s << "}\n";
1214 }
1215
1216 s << "impl " << name_ << "Packet {";
1217 if (parent_ == nullptr) {
1218 s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1219 s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)).unwrap())";
1220 s << "}";
1221 }
1222
1223 if (HasChildEnums()) {
1224 s << " pub fn specialize(&self) -> " << name_ << "Child {";
1225 s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1226 for (const auto& child : children_) {
1227 s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1228 << child->name_ << "Packet::new(self." << root_accessor << ".clone()).unwrap()),";
1229 }
1230 if (fields_.HasPayload()) {
1231 s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1232 }
1233 s << name_ << "DataChild::None => " << name_ << "Child::None,";
1234 s << "}}";
1235 }
1236 auto lineage = GetAncestors();
1237 lineage.push_back(this);
1238 const ParentDef* prev = nullptr;
1239
1240 s << " fn new(root: Arc<" << root->name_ << "Data>) -> std::result::Result<Self, &'static str> {";
1241 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1242 auto def = *it;
1243 auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1244 if (prev == nullptr) {
1245 s << "let " << accessor_name << " = root;";
1246 } else {
1247 s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1248 s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1249 s << "_ => return Err(\"inconsistent state - child was not " << def->name_ << "\"),";
1250 s << "};";
1251 }
1252 prev = def;
1253 }
1254 s << "Ok(Self {";
1255 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1256 auto def = *it;
1257 s << util::CamelCaseToUnderScore(def->name_) << ",";
1258 }
1259 s << "})}";
1260
1261 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1262 auto def = *it;
1263 auto fields = def->fields_.GetFieldsWithoutTypes({
1264 BodyField::kFieldType,
1265 CountField::kFieldType,
1266 PaddingField::kFieldType,
1267 ReservedField::kFieldType,
1268 SizeField::kFieldType,
1269 PayloadField::kFieldType,
1270 FixedScalarField::kFieldType,
1271 });
1272
1273 for (auto const& field : fields) {
1274 if (field->GetterIsByRef()) {
1275 s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1276 s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1277 s << "}\n";
1278 } else {
1279 s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1280 s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1281 s << "}\n";
1282 }
1283 }
1284 }
1285
1286 s << "}\n";
1287
1288 lineage = GetAncestors();
1289 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1290 auto def = *it;
1291 s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1292 s << " fn into(self) -> " << def->name_ << "Packet {";
1293 s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"
1294 << ".unwrap()";
1295 s << " }";
1296 s << "}\n";
1297 }
1298 }
1299
GenRustBuilderStructImpls(std::ostream & s) const1300 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1301 if (complement_ != nullptr) {
1302 auto complement_root = complement_->GetRootDef();
1303 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1304 s << "impl CommandExpectations for " << name_ << "Builder {";
1305 s << " type ResponseType = " << complement_->name_ << "Packet;";
1306 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1307 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1308 << ".unwrap()";
1309 s << " }";
1310 s << "}";
1311 }
1312
1313 s << "impl " << name_ << "Builder {";
1314 s << "pub fn build(self) -> " << name_ << "Packet {";
1315 auto lineage = GetAncestors();
1316 lineage.push_back(this);
1317 std::reverse(lineage.begin(), lineage.end());
1318
1319 auto all_constraints = GetAllConstraints();
1320
1321 const ParentDef* prev = nullptr;
1322 for (auto ancestor : lineage) {
1323 auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1324 BodyField::kFieldType,
1325 CountField::kFieldType,
1326 PaddingField::kFieldType,
1327 ReservedField::kFieldType,
1328 SizeField::kFieldType,
1329 PayloadField::kFieldType,
1330 FixedScalarField::kFieldType,
1331 });
1332
1333 auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1334 s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1335 for (auto field : fields) {
1336 auto constraint = all_constraints.find(field->GetName());
1337 s << field->GetName() << ": ";
1338 if (constraint != all_constraints.end()) {
1339 if (field->GetFieldType() == ScalarField::kFieldType) {
1340 s << std::get<int64_t>(constraint->second);
1341 } else if (field->GetFieldType() == EnumField::kFieldType) {
1342 auto value = std::get<std::string>(constraint->second);
1343 auto constant = value.substr(value.find("::") + 2, std::string::npos);
1344 s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1345 ;
1346 } else {
1347 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1348 }
1349 } else {
1350 s << "self." << field->GetName();
1351 }
1352 s << ", ";
1353 }
1354 if (ancestor->HasChildEnums()) {
1355 if (prev == nullptr) {
1356 if (ancestor->fields_.HasPayload()) {
1357 s << "child: match self.payload { ";
1358 s << "None => " << name_ << "DataChild::None,";
1359 s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1360 s << "},";
1361 } else {
1362 s << "child: " << name_ << "DataChild::None,";
1363 }
1364 } else {
1365 s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1366 << util::CamelCaseToUnderScore(prev->name_) << "),";
1367 }
1368 }
1369 s << "});";
1370 prev = ancestor;
1371 }
1372
1373 s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ").unwrap()";
1374 s << "}\n";
1375
1376 s << "}\n";
1377 for (const auto ancestor : GetAncestors()) {
1378 s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1379 s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1380 s << "}\n";
1381 }
1382 }
1383
GenRustBuilderTest(std::ostream & s) const1384 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1385 auto lineage = GetAncestors();
1386 lineage.push_back(this);
1387 if (!lineage.empty() && !test_cases_.empty()) {
1388 s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1389 s << "($($name:ident: $byte_string:expr,)*) => {";
1390 s << "$(";
1391 s << "\n#[test]\n";
1392 s << "pub fn $name() { ";
1393 s << "let raw_bytes = $byte_string;";
1394 for (size_t i = 0; i < lineage.size(); i++) {
1395 s << "/* (" << i << ") */\n";
1396 if (i == 0) {
1397 s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1398 s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1399 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1400 } else if (i != lineage.size() - 1) {
1401 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1402 s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1403 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1404 } else {
1405 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1406 s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1407 FieldList params = GetParamList();
1408 if (params.HasBody()) {
1409 ERROR() << "Packets with body fields can't be auto-tested. Test a child.";
1410 }
1411 for (const auto param : params) {
1412 s << param->GetName() << " : packet.";
1413 if (param->GetFieldType() == VectorField::kFieldType) {
1414 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1415 } else if (param->GetFieldType() == ArrayField::kFieldType) {
1416 const auto array_param = static_cast<const ArrayField*>(param);
1417 const auto element_field = array_param->GetElementField();
1418 if (element_field->GetFieldType() == StructField::kFieldType) {
1419 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1420 } else {
1421 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1422 }
1423 } else if (param->GetFieldType() == StructField::kFieldType) {
1424 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1425 } else {
1426 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1427 }
1428 }
1429 s << "};";
1430 s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1431 s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1432 s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1433 s << "}";
1434 }
1435 }
1436 for (size_t i = 1; i < lineage.size(); i++) {
1437 s << "_ => {";
1438 s << "panic!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1439 s << "\n {:#02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1440 s << "}}}";
1441 }
1442
1443 s << ",";
1444 s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1445 s << "}";
1446 s << "}";
1447 s << ")*";
1448 s << "}";
1449 s << "}";
1450
1451 s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1452 int number = 0;
1453 for (const auto& test_case : test_cases_) {
1454 s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1455 s << std::setfill('0') << std::setw(2) << number++ << ": ";
1456 s << "b\"" << test_case << "\",";
1457 }
1458 s << "}";
1459 s << "\n";
1460 }
1461 }
1462
GenRustDef(std::ostream & s) const1463 void PacketDef::GenRustDef(std::ostream& s) const {
1464 GenRustChildEnums(s);
1465 GenRustStructDeclarations(s);
1466 GenRustStructImpls(s);
1467 GenRustAccessStructImpls(s);
1468 GenRustBuilderStructImpls(s);
1469 GenRustBuilderTest(s);
1470 }
1471