• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "packet_def.h"
18 
19 #include <iomanip>
20 #include <list>
21 #include <set>
22 
23 #include "fields/all_fields.h"
24 #include "packet_dependency.h"
25 #include "util.h"
26 
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29 
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31   return nullptr;  // Packets can't be fields
32 }
33 
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const34 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
35   s << "class " << name_ << "View";
36   if (parent_ != nullptr) {
37     s << " : public " << parent_->name_ << "View {";
38   } else {
39     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40   }
41   s << " public:";
42 
43   // Specialize function
44   if (parent_ != nullptr) {
45     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46     s << "{ return " << name_ << "View(std::move(parent)); }";
47   } else {
48     s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
49     s << "{ return " << name_ << "View(std::move(packet)); }";
50   }
51 
52   if (generate_fuzzing || generate_tests) {
53     GenTestingParserFromBytes(s);
54   }
55 
56   std::set<std::string> fixed_types = {
57       FixedScalarField::kFieldType,
58       FixedEnumField::kFieldType,
59   };
60 
61   // Print all of the public fields which are all the fields minus the fixed fields.
62   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
63   bool has_fixed_fields = public_fields.size() != fields_.size();
64   for (const auto& field : public_fields) {
65     GenParserFieldGetter(s, field);
66     s << "\n";
67   }
68   GenValidator(s);
69   s << "\n";
70 
71   s << " public:";
72   GenParserToString(s);
73   s << "\n";
74 
75   s << " protected:\n";
76   // Constructor from a View
77   if (parent_ != nullptr) {
78     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
79     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
80   } else {
81     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
82     s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
83   }
84 
85   // Print the private fields which are the fixed fields.
86   if (has_fixed_fields) {
87     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
88     s << " private:\n";
89     for (const auto& field : private_fields) {
90       GenParserFieldGetter(s, field);
91       s << "\n";
92     }
93   }
94   s << "};\n";
95 }
96 
GenTestingParserFromBytes(std::ostream & s) const97 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
98   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
99 
100   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
101   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
102   s << "return " << name_ << "View::Create(";
103   auto ancestor_ptr = parent_;
104   size_t parent_parens = 0;
105   while (ancestor_ptr != nullptr) {
106     s << ancestor_ptr->name_ << "View::Create(";
107     parent_parens++;
108     ancestor_ptr = ancestor_ptr->parent_;
109   }
110   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
111   for (size_t i = 0; i < parent_parens; i++) {
112     s << ")";
113   }
114   s << ");";
115   s << "}";
116 
117   s << "\n#endif\n";
118 }
119 
GenParserDefinitionPybind11(std::ostream & s) const120 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
121   s << "py::class_<" << name_ << "View";
122   if (parent_ != nullptr) {
123     s << ", " << parent_->name_ << "View";
124   } else {
125     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
126   }
127   s << ">(m, \"" << name_ << "View\")";
128   if (parent_ != nullptr) {
129     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
130   } else {
131     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
132   }
133   s << "auto view =" << name_ << "View::Create(std::move(parent));";
134   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
135   s << "return view; }))";
136 
137   s << ".def(py::init(&" << name_ << "View::Create))";
138   std::set<std::string> protected_field_types = {
139       FixedScalarField::kFieldType,
140       FixedEnumField::kFieldType,
141       SizeField::kFieldType,
142       CountField::kFieldType,
143   };
144   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
145   for (const auto& field : public_fields) {
146     auto getter_func_name = field->GetGetterFunctionName();
147     if (getter_func_name.empty()) {
148       continue;
149     }
150     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
151   }
152   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
153   s << ";\n";
154 }
155 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const156 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
157   // Start field offset
158   auto start_field_offset = GetOffsetForField(field->GetName(), false);
159   auto end_field_offset = GetOffsetForField(field->GetName(), true);
160 
161   if (start_field_offset.empty() && end_field_offset.empty()) {
162     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
163                  << "no method exists to determine field location from begin() or end().\n";
164   }
165 
166   field->GenGetter(s, start_field_offset, end_field_offset);
167 }
168 
GetDefinitionType() const169 TypeDef::Type PacketDef::GetDefinitionType() const {
170   return TypeDef::Type::PACKET;
171 }
172 
GenValidator(std::ostream & s) const173 void PacketDef::GenValidator(std::ostream& s) const {
174   // Get the static offset for all of our fields.
175   int bits_size = 0;
176   for (const auto& field : fields_) {
177     if (field->GetFieldType() != PaddingField::kFieldType) {
178       bits_size += field->GetSize().bits();
179     }
180   }
181 
182   // Generate the public validator IsValid().
183   // The method only needs to be generated for the top most class.
184   if (parent_ == nullptr) {
185     s << "bool IsValid() {" << std::endl;
186     s << "  if (was_validated_) {" << std::endl;
187     s << "    return true;" << std::endl;
188     s << "  } else {" << std::endl;
189     s << "    was_validated_ = true;" << std::endl;
190     s << "    return (was_validated_ = Validate());" << std::endl;
191     s << "  }" << std::endl;
192     s << "}" << std::endl;
193   }
194 
195   // Generate the private validator Validate().
196   // The method is overridden by all child classes.
197   s << "protected:" << std::endl;
198   if (parent_ == nullptr) {
199     s << "virtual bool Validate() const {" << std::endl;
200   } else {
201     s << "bool Validate() const override {" << std::endl;
202     s << "  if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
203     s << "    return false;" << std::endl;
204     s << "  }" << std::endl;
205   }
206 
207   // Offset by the parents known size. We know that any dynamic fields can
208   // already be called since the parent must have already been validated by
209   // this point.
210   auto parent_size = Size(0);
211   if (parent_ != nullptr) {
212     parent_size = parent_->GetSize(true);
213   }
214 
215   s << "auto it = begin() + (" << parent_size << ") / 8;";
216 
217   // Check if you can extract the static fields.
218   // At this point you know you can use the size getters without crashing
219   // as long as they follow the instruction that size fields cant come before
220   // their corrisponding variable length field.
221   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
222   s << "if (it > end()) return false;";
223 
224   // For any variable length fields, use their size check.
225   for (const auto& field : fields_) {
226     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
227       auto offset = GetOffsetForField(field->GetName(), false);
228       if (!offset.empty()) {
229         s << "size_t sum_index = (" << offset << ") / 8;";
230       } else {
231         offset = GetOffsetForField(field->GetName(), true);
232         if (offset.empty()) {
233           ERROR(field) << "Checksum Start Field offset can not be determined.";
234         }
235         s << "size_t sum_index = size() - (" << offset << ") / 8;";
236       }
237 
238       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
239       const auto& started_field = fields_.GetField(field_name);
240       if (started_field == nullptr) {
241         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
242                      << ")";
243       }
244       auto end_offset = GetOffsetForField(started_field->GetName(), false);
245       if (!end_offset.empty()) {
246         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
247       } else {
248         end_offset = GetOffsetForField(started_field->GetName(), true);
249         if (end_offset.empty()) {
250           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
251         }
252         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
253       }
254       s << "if (end_sum_index >= size()) { return false; }";
255       if (is_little_endian_) {
256         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
257       } else {
258         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
259       }
260       s << started_field->GetDataType() << " checksum;";
261       s << "checksum.Initialize();";
262       s << "for (uint8_t byte : checksum_view) { ";
263       s << "checksum.AddByte(byte);}";
264       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
265         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
266 
267       continue;
268     }
269 
270     auto field_size = field->GetSize();
271     // Fixed size fields have already been handled.
272     if (!field_size.has_dynamic()) {
273       continue;
274     }
275 
276     // Custom fields with dynamic size must have the offset for the field passed in as well
277     // as the end iterator so that they may ensure that they don't try to read past the end.
278     // Custom fields with fixed sizes will be handled in the static offset checking.
279     if (field->GetFieldType() == CustomField::kFieldType) {
280       // Check if we can determine offset from begin(), otherwise error because by this point,
281       // the size of the custom field is unknown and can't be subtracted from end() to get the
282       // offset.
283       auto offset = GetOffsetForField(field->GetName(), false);
284       if (offset.empty()) {
285         ERROR(field) << "Custom Field offset can not be determined from begin().";
286       }
287 
288       if (offset.bits() % 8 != 0) {
289         ERROR(field) << "Custom fields must be byte aligned.";
290       }
291 
292       // Custom fields are special as their size field takes an argument.
293       const auto& custom_size_var = field->GetName() + "_size";
294       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
295       s << "(begin() + (" << offset << ") / 8);";
296 
297       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
298       s << "it += *" << custom_size_var << ";";
299       s << "if (it > end()) return false;";
300       continue;
301     } else {
302       s << "it += (" << field_size.dynamic_string() << ") / 8;";
303       s << "if (it > end()) return false;";
304     }
305   }
306 
307   // Validate constraints after validating the size
308   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
309     ERROR() << "Can't have a constraint on a NULL parent";
310   }
311 
312   for (const auto& constraint : parent_constraints_) {
313     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
314     const auto& field = parent_->GetParamList().GetField(constraint.first);
315     if (field->GetFieldType() == ScalarField::kFieldType) {
316       s << std::get<int64_t>(constraint.second);
317     } else {
318       s << std::get<std::string>(constraint.second);
319     }
320     s << ") return false;";
321   }
322 
323   // Validate the packets fields last
324   for (const auto& field : fields_) {
325     field->GenValidator(s);
326     s << "\n";
327   }
328 
329   s << "return true;";
330   s << "}\n";
331   if (parent_ == nullptr) {
332     s << "bool was_validated_{false};\n";
333   }
334 }
335 
GenParserToString(std::ostream & s) const336 void PacketDef::GenParserToString(std::ostream& s) const {
337   s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
338   s << "std::stringstream ss;";
339   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
340 
341   if (fields_.size() > 0) {
342     s << "ss << \"\" ";
343     bool firstfield = true;
344     for (const auto& field : fields_) {
345       if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
346           field->GetFieldType() == ChecksumStartField::kFieldType)
347         continue;
348 
349       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
350 
351       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
352 
353       if (firstfield) {
354         firstfield = false;
355       }
356     }
357     s << ";";
358   }
359 
360   s << "ss << \" }\";";
361   s << "return ss.str();";
362   s << "}\n";
363 }
364 
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const365 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
366   s << "class " << name_ << "Builder";
367   if (parent_ != nullptr) {
368     s << " : public " << parent_->name_ << "Builder";
369   } else {
370     if (is_little_endian_) {
371       s << " : public PacketBuilder<kLittleEndian>";
372     } else {
373       s << " : public PacketBuilder<!kLittleEndian>";
374     }
375   }
376   s << " {";
377   s << " public:";
378   s << "  virtual ~" << name_ << "Builder() = default;";
379 
380   if (!fields_.HasBody()) {
381     GenBuilderCreate(s);
382     s << "\n";
383 
384     if (generate_fuzzing || generate_tests) {
385       GenTestingFromView(s);
386       s << "\n";
387     }
388   }
389 
390   GenSerialize(s);
391   s << "\n";
392 
393   GenSize(s);
394   s << "\n";
395 
396   s << " protected:\n";
397   GenBuilderConstructor(s);
398   s << "\n";
399 
400   GenBuilderParameterChecker(s);
401   s << "\n";
402 
403   GenMembers(s);
404   s << "};\n";
405 
406   if (generate_tests) {
407     GenTestDefine(s);
408     s << "\n";
409   }
410 
411   if (generate_fuzzing || generate_tests) {
412     GenReflectTestDefine(s);
413     s << "\n";
414   }
415 
416   if (generate_fuzzing) {
417     GenFuzzTestDefine(s);
418     s << "\n";
419   }
420 }
421 
GenTestingFromView(std::ostream & s) const422 void PacketDef::GenTestingFromView(std::ostream& s) const {
423   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
424 
425   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
426   s << "return " << name_ << "Builder::Create(";
427   FieldList params = GetParamList().GetFieldsWithoutTypes({
428       BodyField::kFieldType,
429   });
430   for (std::size_t i = 0; i < params.size(); i++) {
431     params[i]->GenBuilderParameterFromView(s);
432     if (i != params.size() - 1) {
433       s << ", ";
434     }
435   }
436   s << ");";
437   s << "}";
438 
439   s << "\n#endif\n";
440 }
441 
GenBuilderDefinitionPybind11(std::ostream & s) const442 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
443   s << "py::class_<" << name_ << "Builder";
444   if (parent_ != nullptr) {
445     s << ", " << parent_->name_ << "Builder";
446   } else {
447     if (is_little_endian_) {
448       s << ", PacketBuilder<kLittleEndian>";
449     } else {
450       s << ", PacketBuilder<!kLittleEndian>";
451     }
452   }
453   s << ", std::shared_ptr<" << name_ << "Builder>";
454   s << ">(m, \"" << name_ << "Builder\")";
455   if (!fields_.HasBody()) {
456     GenBuilderCreatePybind11(s);
457   }
458   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
459   s << "std::vector<uint8_t> bytes;";
460   s << "BitInserter bi(bytes);";
461   s << "builder.Serialize(bi);";
462   s << "return bytes;})";
463   s << ";\n";
464 }
465 
GenTestDefine(std::ostream & s) const466 void PacketDef::GenTestDefine(std::ostream& s) const {
467   s << "#ifdef PACKET_TESTING\n";
468   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
469   s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
470   s << "public: ";
471   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
472   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
473   s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
474   s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
475   s << "ASSERT_TRUE(view.IsValid());";
476   s << "auto packet = " << name_ << "Builder::FromView(view);";
477   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
478   s << "packet_bytes->reserve(packet->size());";
479   s << "BitInserter it(*packet_bytes);";
480   s << "packet->Serialize(it);";
481   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
482   s << "}";
483   s << "};";
484   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
485   s << "CompareBytes(GetParam());";
486   s << "}";
487   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
488   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
489   int i = 0;
490   for (const auto& bytes : test_cases_) {
491     s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
492     s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
493     s << name_ << "_test_bytes_" << i << ",";
494     s << name_ << "_test_bytes_" << i << " + sizeof(";
495     s << name_ << "_test_bytes_" << i << ") - 1);";
496     i++;
497   }
498   if (!test_cases_.empty()) {
499     i = 0;
500     s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
501     for (auto bytes : test_cases_) {
502       if (i > 0) {
503         s << ",";
504       }
505       s << name_ << "_test_vec_" << i++;
506     }
507     s << ");";
508   }
509   s << "\n#endif";
510 }
511 
GenReflectTestDefine(std::ostream & s) const512 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
513   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
514   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
515   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
516   s << "auto vec = std::vector<uint8_t>(data, data + size);";
517   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
518   s << "if (!view.IsValid()) { return; }";
519   s << "auto packet = " << name_ << "Builder::FromView(view);";
520   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
521   s << "packet_bytes->reserve(packet->size());";
522   s << "BitInserter it(*packet_bytes);";
523   s << "packet->Serialize(it);";
524   s << "}";
525   s << "\n#endif\n";
526 }
527 
GenFuzzTestDefine(std::ostream & s) const528 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
529   s << "#ifdef PACKET_FUZZ_TESTING\n";
530   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
531   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
532   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
533   s << "public: ";
534   s << "explicit " << name_
535     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
536   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
537   s << "}}; ";
538   s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
539   s << "\n#endif";
540 }
541 
GetParametersToValidate() const542 FieldList PacketDef::GetParametersToValidate() const {
543   FieldList params_to_validate;
544   for (const auto& field : GetParamList()) {
545     if (field->HasParameterValidator()) {
546       params_to_validate.AppendField(field);
547     }
548   }
549   return params_to_validate;
550 }
551 
GenBuilderCreate(std::ostream & s) const552 void PacketDef::GenBuilderCreate(std::ostream& s) const {
553   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
554 
555   auto params = GetParamList();
556   for (std::size_t i = 0; i < params.size(); i++) {
557     params[i]->GenBuilderParameter(s);
558     if (i != params.size() - 1) {
559       s << ", ";
560     }
561   }
562   s << ") {";
563 
564   // Call the constructor
565   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
566 
567   params = params.GetFieldsWithoutTypes({
568       PayloadField::kFieldType,
569       BodyField::kFieldType,
570   });
571   // Add the parameters.
572   for (std::size_t i = 0; i < params.size(); i++) {
573     if (params[i]->BuilderParameterMustBeMoved()) {
574       s << "std::move(" << params[i]->GetName() << ")";
575     } else {
576       s << params[i]->GetName();
577     }
578     if (i != params.size() - 1) {
579       s << ", ";
580     }
581   }
582 
583   s << "));";
584   if (fields_.HasPayload()) {
585     s << "builder->payload_ = std::move(payload);";
586   }
587   s << "return builder;";
588   s << "}\n";
589 }
590 
GenBuilderCreatePybind11(std::ostream & s) const591 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
592   s << ".def(py::init([](";
593   auto params = GetParamList();
594   std::vector<std::string> constructor_args;
595   int i = 1;
596   for (const auto& param : params) {
597     i++;
598     std::stringstream ss;
599     auto param_type = param->GetBuilderParameterType();
600     if (param_type.empty()) {
601       continue;
602     }
603     // Use shared_ptr instead of unique_ptr for the Python interface
604     if (param->BuilderParameterMustBeMoved()) {
605       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
606     }
607     ss << param_type << " " << param->GetName();
608     constructor_args.push_back(ss.str());
609   }
610   s << util::StringJoin(",", constructor_args) << "){";
611 
612   // Deal with move only args
613   for (const auto& param : params) {
614     std::stringstream ss;
615     auto param_type = param->GetBuilderParameterType();
616     if (param_type.empty()) {
617       continue;
618     }
619     if (!param->BuilderParameterMustBeMoved()) {
620       continue;
621     }
622     auto move_only_param_name = param->GetName() + "_move_only";
623     s << param_type << " " << move_only_param_name << ";";
624     if (param->IsContainerField()) {
625       // Assume single layer container and copy it
626       auto struct_type = param->GetElementField()->GetDataType();
627       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
628       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
629       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
630       // Serialize each struct
631       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
632       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
633       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
634       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
635       // Parse it again
636       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
637       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
638       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
639       // Push it into a new container
640       if (param->GetFieldType() == VectorField::kFieldType) {
641         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
642       } else if (param->GetFieldType() == ArrayField::kFieldType) {
643         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
644       } else {
645         ERROR() << param << " is not supported by Pybind11";
646       }
647       s << "}";
648     } else {
649       // Serialize the parameter and pass the bytes in a RawBuilder
650       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
651       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
652       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
653       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
654       s << move_only_param_name << " = ";
655       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
656     }
657   }
658   s << "return " << name_ << "Builder::Create(";
659   std::vector<std::string> builder_vars;
660   for (const auto& param : params) {
661     std::stringstream ss;
662     auto param_type = param->GetBuilderParameterType();
663     if (param_type.empty()) {
664       continue;
665     }
666     auto param_name = param->GetName();
667     if (param->BuilderParameterMustBeMoved()) {
668       ss << "std::move(" << param_name << "_move_only)";
669     } else {
670       ss << param_name;
671     }
672     builder_vars.push_back(ss.str());
673   }
674   s << util::StringJoin(",", builder_vars) << ");}";
675   s << "))";
676 }
677 
GenBuilderParameterChecker(std::ostream & s) const678 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
679   FieldList params_to_validate = GetParametersToValidate();
680 
681   // Skip writing this function if there is nothing to validate.
682   if (params_to_validate.size() == 0) {
683     return;
684   }
685 
686   // Generate function arguments.
687   s << "void CheckParameterValues(";
688   for (std::size_t i = 0; i < params_to_validate.size(); i++) {
689     params_to_validate[i]->GenBuilderParameter(s);
690     if (i != params_to_validate.size() - 1) {
691       s << ", ";
692     }
693   }
694   s << ") {";
695 
696   // Check the parameters.
697   for (const auto& field : params_to_validate) {
698     field->GenParameterValidator(s);
699   }
700   s << "}\n";
701 }
702 
GenBuilderConstructor(std::ostream & s) const703 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
704   s << "explicit " << name_ << "Builder(";
705 
706   // Generate the constructor parameters.
707   auto params = GetParamList().GetFieldsWithoutTypes({
708       PayloadField::kFieldType,
709       BodyField::kFieldType,
710   });
711   for (std::size_t i = 0; i < params.size(); i++) {
712     params[i]->GenBuilderParameter(s);
713     if (i != params.size() - 1) {
714       s << ", ";
715     }
716   }
717   if (params.size() > 0 || parent_constraints_.size() > 0) {
718     s << ") :";
719   } else {
720     s << ")";
721   }
722 
723   // Get the list of parent params to call the parent constructor with.
724   FieldList parent_params;
725   if (parent_ != nullptr) {
726     // Pass parameters to the parent constructor
727     s << parent_->name_ << "Builder(";
728     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
729         PayloadField::kFieldType,
730         BodyField::kFieldType,
731     });
732 
733     // Go through all the fields and replace constrained fields with fixed values
734     // when calling the parent constructor.
735     for (std::size_t i = 0; i < parent_params.size(); i++) {
736       const auto& field = parent_params[i];
737       const auto& constraint = parent_constraints_.find(field->GetName());
738       if (constraint != parent_constraints_.end()) {
739         if (field->GetFieldType() == ScalarField::kFieldType) {
740           s << std::get<int64_t>(constraint->second);
741         } else if (field->GetFieldType() == EnumField::kFieldType) {
742           s << std::get<std::string>(constraint->second);
743         } else {
744           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
745         }
746 
747         s << "/* " << field->GetName() << "_ */";
748       } else {
749         s << field->GetName();
750       }
751 
752       if (i != parent_params.size() - 1) {
753         s << ", ";
754       }
755     }
756     s << ") ";
757   }
758 
759   // Build a list of parameters that excludes all parent parameters.
760   FieldList saved_params;
761   for (const auto& field : params) {
762     if (parent_params.GetField(field->GetName()) == nullptr) {
763       saved_params.AppendField(field);
764     }
765   }
766   if (parent_ != nullptr && saved_params.size() > 0) {
767     s << ",";
768   }
769   for (std::size_t i = 0; i < saved_params.size(); i++) {
770     const auto& saved_param_name = saved_params[i]->GetName();
771     if (saved_params[i]->BuilderParameterMustBeMoved()) {
772       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
773     } else {
774       s << saved_param_name << "_(" << saved_param_name << ")";
775     }
776     if (i != saved_params.size() - 1) {
777       s << ",";
778     }
779   }
780   s << " {";
781 
782   FieldList params_to_validate = GetParametersToValidate();
783 
784   if (params_to_validate.size() > 0) {
785     s << "CheckParameterValues(";
786     for (std::size_t i = 0; i < params_to_validate.size(); i++) {
787       s << params_to_validate[i]->GetName() << "_";
788       if (i != params_to_validate.size() - 1) {
789         s << ", ";
790       }
791     }
792     s << ");";
793   }
794 
795   s << "}\n";
796 }
797 
GenRustChildEnums(std::ostream & s) const798 void PacketDef::GenRustChildEnums(std::ostream& s) const {
799   if (HasChildEnums()) {
800     bool payload = fields_.HasPayload();
801     s << "#[derive(Debug)] ";
802     s << "enum " << name_ << "DataChild {";
803     for (const auto& child : children_) {
804       s << child->name_ << "(Arc<" << child->name_ << "Data>),";
805     }
806     if (payload) {
807       s << "Payload(Bytes),";
808     }
809     s << "None,";
810     s << "}\n";
811 
812     s << "impl " << name_ << "DataChild {";
813     s << "fn get_total_size(&self) -> usize {";
814     s << "match self {";
815     for (const auto& child : children_) {
816       s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
817     }
818     if (payload) {
819       s << name_ << "DataChild::Payload(p) => p.len(),";
820     }
821     s << name_ << "DataChild::None => 0,";
822     s << "}\n";
823     s << "}\n";
824     s << "}\n";
825 
826     s << "#[derive(Debug)] ";
827     s << "pub enum " << name_ << "Child {";
828     for (const auto& child : children_) {
829       s << child->name_ << "(" << child->name_ << "Packet),";
830     }
831     if (payload) {
832       s << "Payload(Bytes),";
833     }
834     s << "None,";
835     s << "}\n";
836   }
837 }
838 
GenRustStructDeclarations(std::ostream & s) const839 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
840   s << "#[derive(Debug)] ";
841   s << "struct " << name_ << "Data {";
842 
843   // Generate struct fields
844   GenRustStructFieldNameAndType(s);
845   if (HasChildEnums()) {
846     s << "child: " << name_ << "DataChild,";
847   }
848   s << "}\n";
849 
850   // Generate accessor struct
851   s << "#[derive(Debug, Clone)] ";
852   s << "pub struct " << name_ << "Packet {";
853   auto lineage = GetAncestors();
854   lineage.push_back(this);
855   for (auto it = lineage.begin(); it != lineage.end(); it++) {
856     auto def = *it;
857     s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
858   }
859   s << "}\n";
860 
861   // Generate builder struct
862   s << "#[derive(Debug)] ";
863   s << "pub struct " << name_ << "Builder {";
864   auto params = GetParamList().GetFieldsWithoutTypes({
865       PayloadField::kFieldType,
866       BodyField::kFieldType,
867   });
868   for (auto param : params) {
869     s << "pub ";
870     param->GenRustNameAndType(s);
871     s << ", ";
872   }
873   if (fields_.HasPayload()) {
874     s << "pub payload: Option<Bytes>,";
875   }
876   s << "}\n";
877 }
878 
GenRustStructFieldNameAndType(std::ostream & s) const879 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
880   auto fields = fields_.GetFieldsWithoutTypes({
881       BodyField::kFieldType,
882       CountField::kFieldType,
883       PaddingField::kFieldType,
884       ReservedField::kFieldType,
885       SizeField::kFieldType,
886       PayloadField::kFieldType,
887       FixedScalarField::kFieldType,
888   });
889   if (fields.size() == 0) {
890     return false;
891   }
892   for (const auto& field : fields) {
893     field->GenRustNameAndType(s);
894     s << ", ";
895   }
896   return true;
897 }
898 
GenRustStructFieldNames(std::ostream & s) const899 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
900   auto fields = fields_.GetFieldsWithoutTypes({
901       BodyField::kFieldType,
902       CountField::kFieldType,
903       PaddingField::kFieldType,
904       ReservedField::kFieldType,
905       SizeField::kFieldType,
906       PayloadField::kFieldType,
907       FixedScalarField::kFieldType,
908   });
909   for (const auto field : fields) {
910     s << field->GetName();
911     s << ", ";
912   }
913 }
914 
GenRustStructImpls(std::ostream & s) const915 void PacketDef::GenRustStructImpls(std::ostream& s) const {
916   auto packet_dep = PacketDependency(GetRootDef());
917 
918   s << "impl " << name_ << "Data {";
919   // conforms function
920   s << "fn conforms(bytes: &[u8]) -> bool {";
921   GenRustConformanceCheck(s);
922 
923   auto fields = fields_.GetFieldsWithTypes({
924       StructField::kFieldType,
925   });
926 
927   for (auto const& field : fields) {
928     auto start_offset = GetOffsetForField(field->GetName(), false);
929     auto end_offset = GetOffsetForField(field->GetName(), true);
930 
931     s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
932     s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
933   }
934 
935   s << " true";
936   s << "}";
937 
938   auto parse_params = packet_dep.GetDependencies(name_);
939   s << "fn parse(bytes: &[u8]";
940   for (auto field_name : parse_params) {
941     auto constraint_field = GetParamList().GetField(field_name);
942     auto constraint_type = constraint_field->GetRustDataType();
943     s << ", " << field_name << ": " << constraint_type;
944   }
945   s << ") -> Result<Self> {";
946 
947   fields = fields_.GetFieldsWithoutTypes({
948       BodyField::kFieldType,
949   });
950 
951   for (auto const& field : fields) {
952     auto start_field_offset = GetOffsetForField(field->GetName(), false);
953     auto end_field_offset = GetOffsetForField(field->GetName(), true);
954 
955     if (start_field_offset.empty() && end_field_offset.empty()) {
956       ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
957                    << "no method exists to determine field location from begin() or end().\n";
958     }
959 
960     field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
961     field->GenRustGetter(s, start_field_offset, end_field_offset, name_);
962   }
963 
964   auto payload_field = fields_.GetFieldsWithTypes({
965     PayloadField::kFieldType,
966   });
967 
968   Size payload_offset;
969 
970   if (payload_field.HasPayload()) {
971     payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
972   }
973 
974   if (children_.size() > 1) {
975     auto match_on_variables = packet_dep.GetChildrenDependencies(name_);
976     // If match_on_variables is empty, this means there are multiple abstract packets which will
977     // specialize to a child down the packet tree.
978     // In this case match variables will be the union of parent fields and parse params of children.
979     if (match_on_variables.empty()) {
980       for (auto& field : fields_) {
981         if (std::any_of(children_.begin(), children_.end(), [&](auto child) {
982               auto pass_me = packet_dep.GetDependencies(child->name_);
983               return std::find(pass_me.begin(), pass_me.end(), field->GetName()) != pass_me.end();
984             })) {
985           match_on_variables.push_back(field->GetName());
986         }
987       }
988     }
989 
990     s << "let child = match (";
991 
992     for (auto var : match_on_variables) {
993       if (var == match_on_variables[match_on_variables.size() - 1]) {
994         s << var;
995       } else {
996         s << var << ", ";
997       }
998     }
999     s << ") {";
1000 
1001     auto get_match_val = [&](
1002         std::string& match_var,
1003         std::variant<int64_t,
1004         std::string> constraint) -> std::string {
1005       auto constraint_field = GetParamList().GetField(match_var);
1006       auto constraint_type = constraint_field->GetFieldType();
1007 
1008       if (constraint_type == EnumField::kFieldType) {
1009         auto type = std::get<std::string>(constraint);
1010         auto variant_name = type.substr(type.find("::") + 2, type.length());
1011         auto enum_type = type.substr(0, type.find("::"));
1012         return enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
1013       }
1014       if (constraint_type == ScalarField::kFieldType) {
1015         return std::to_string(std::get<int64_t>(constraint));
1016       }
1017       return "_";
1018     };
1019 
1020     for (auto& child : children_) {
1021       s << "(";
1022       for (auto var : match_on_variables) {
1023         std::string match_val = "_";
1024 
1025         if (child->parent_constraints_.find(var) != child->parent_constraints_.end()) {
1026           match_val = get_match_val(var, child->parent_constraints_[var]);
1027         } else {
1028           auto dcs = child->FindDescendantsWithConstraint(var);
1029           std::vector<std::string> all_match_vals;
1030           for (auto& desc : dcs) {
1031             all_match_vals.push_back(get_match_val(var, desc.second));
1032           }
1033           match_val = "";
1034           for (std::size_t i = 0; i < all_match_vals.size(); ++i) {
1035             match_val += all_match_vals[i];
1036             if (i != all_match_vals.size() - 1) {
1037               match_val += " | ";
1038             }
1039           }
1040           match_val = (match_val == "") ? "_" : match_val;
1041         }
1042 
1043         if (var == match_on_variables[match_on_variables.size() - 1]) {
1044           s << match_val << ")";
1045         } else {
1046           s << match_val << ", ";
1047         }
1048       }
1049       s << " if " << child->name_ << "Data::conforms(&bytes[..])";
1050       s << " => {";
1051       s << name_ << "DataChild::";
1052       s << child->name_ << "(Arc::new(";
1053 
1054       auto child_parse_params = packet_dep.GetDependencies(child->name_);
1055       if (child_parse_params.size() == 0) {
1056         s << child->name_ << "Data::parse(&bytes[..]";
1057       } else {
1058         s << child->name_ << "Data::parse(&bytes[..], ";
1059       }
1060 
1061       for (auto var : child_parse_params) {
1062         if (var == child_parse_params[child_parse_params.size() - 1]) {
1063           s << var;
1064         } else {
1065           s << var << ", ";
1066         }
1067       }
1068       s << ")?))";
1069       s << "}\n";
1070     }
1071 
1072     s << "(";
1073     for (int i = 1; i <= match_on_variables.size(); i++) {
1074       if (i == match_on_variables.size()) {
1075         s << "_";
1076       } else {
1077         s << "_, ";
1078       }
1079     }
1080     s << ")";
1081     s << " => return Err(Error::InvalidPacketError),";
1082     s << "};\n";
1083   } else if (children_.size() == 1) {
1084     auto child = children_.at(0);
1085     auto params = packet_dep.GetDependencies(child->name_);
1086     s << "let child = match " << child->name_ << "Data::parse(&bytes[..]";
1087     for (auto field_name : params) {
1088       s << ", " << field_name;
1089     }
1090     s << ") {";
1091     s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
1092     s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
1093     s << " },";
1094     s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
1095     s << " _ => return Err(Error::InvalidPacketError),";
1096     s << "};";
1097   } else if (fields_.HasPayload()) {
1098     s << "let child = if payload.len() > 0 {";
1099     s << name_ << "DataChild::Payload(Bytes::from(payload))";
1100     s << "} else {";
1101     s << name_ << "DataChild::None";
1102     s << "};";
1103   }
1104 
1105   s << "Ok(Self {";
1106   fields = fields_.GetFieldsWithoutTypes({
1107       BodyField::kFieldType,
1108       CountField::kFieldType,
1109       PaddingField::kFieldType,
1110       ReservedField::kFieldType,
1111       SizeField::kFieldType,
1112       PayloadField::kFieldType,
1113       FixedScalarField::kFieldType,
1114   });
1115 
1116   if (fields.size() > 0) {
1117     for (const auto& field : fields) {
1118       auto field_type = field->GetFieldType();
1119       s << field->GetName();
1120       s << ", ";
1121     }
1122   }
1123 
1124   if (HasChildEnums()) {
1125     s << "child,";
1126   }
1127   s << "})\n";
1128   s << "}\n";
1129 
1130   // write_to function
1131   s << "fn write_to(&self, buffer: &mut BytesMut) {";
1132   GenRustWriteToFields(s);
1133 
1134   if (HasChildEnums()) {
1135     s << "match &self.child {";
1136     for (const auto& child : children_) {
1137       s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1138     }
1139     if (fields_.HasPayload()) {
1140       auto offset = GetOffsetForField("payload");
1141       s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1142     }
1143     s << name_ << "DataChild::None => {}";
1144     s << "}";
1145   }
1146 
1147   s << "}\n";
1148 
1149   s << "fn get_total_size(&self) -> usize {";
1150   if (HasChildEnums()) {
1151     s << "self.get_size() + self.child.get_total_size()";
1152   } else {
1153     s << "self.get_size()";
1154   }
1155   s << "}\n";
1156 
1157   s << "fn get_size(&self) -> usize {";
1158   GenSizeRetVal(s);
1159   s << "}\n";
1160   s << "}\n";
1161 }
1162 
GenRustAccessStructImpls(std::ostream & s) const1163 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1164   if (complement_ != nullptr) {
1165     auto complement_root = complement_->GetRootDef();
1166     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1167     s << "impl CommandExpectations for " << name_ << "Packet {";
1168     s << " type ResponseType = " << complement_->name_ << "Packet;";
1169     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1170     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1171       << ".unwrap()";
1172     s << " }";
1173     s << "}";
1174   }
1175 
1176   s << "impl Packet for " << name_ << "Packet {";
1177   auto root = GetRootDef();
1178   auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1179 
1180   s << "fn to_bytes(self) -> Bytes {";
1181   s << " let mut buffer = BytesMut::new();";
1182   s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1183   s << " self." << root_accessor << ".write_to(&mut buffer);";
1184   s << " buffer.freeze()";
1185   s << "}\n";
1186 
1187   s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1188   s << "}";
1189 
1190   s << "impl From<" << name_ << "Packet"
1191     << "> for Bytes {\n";
1192   s << "fn from(packet: " << name_ << "Packet"
1193     << ") -> Self {\n";
1194   s << "packet.to_bytes()\n";
1195   s << "}\n";
1196   s << "}\n";
1197 
1198   s << "impl From<" << name_ << "Packet"
1199     << "> for Vec<u8> {\n";
1200   s << "fn from(packet: " << name_ << "Packet"
1201     << ") -> Self {\n";
1202   s << "packet.to_vec()\n";
1203   s << "}\n";
1204   s << "}\n";
1205 
1206   if (root != this) {
1207     s << "impl TryFrom<" << root->name_ << "Packet"
1208       << "> for " << name_ << "Packet {\n";
1209     s << "type Error = TryFromError;\n";
1210     s << "fn try_from(value: " << root->name_ << "Packet)"
1211       << " -> std::result::Result<Self, Self::Error> {\n";
1212     s << "Self::new(value." << root_accessor << ").map_err(TryFromError)\n", s << "}\n";
1213     s << "}\n";
1214   }
1215 
1216   s << "impl " << name_ << "Packet {";
1217   if (parent_ == nullptr) {
1218     s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1219     s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)).unwrap())";
1220     s << "}";
1221   }
1222 
1223   if (HasChildEnums()) {
1224     s << " pub fn specialize(&self) -> " << name_ << "Child {";
1225     s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1226     for (const auto& child : children_) {
1227       s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1228         << child->name_ << "Packet::new(self." << root_accessor << ".clone()).unwrap()),";
1229     }
1230     if (fields_.HasPayload()) {
1231       s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1232     }
1233     s << name_ << "DataChild::None => " << name_ << "Child::None,";
1234     s << "}}";
1235   }
1236   auto lineage = GetAncestors();
1237   lineage.push_back(this);
1238   const ParentDef* prev = nullptr;
1239 
1240   s << " fn new(root: Arc<" << root->name_ << "Data>) -> std::result::Result<Self, &'static str> {";
1241   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1242     auto def = *it;
1243     auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1244     if (prev == nullptr) {
1245       s << "let " << accessor_name << " = root;";
1246     } else {
1247       s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1248       s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1249       s << "_ => return Err(\"inconsistent state - child was not " << def->name_ << "\"),";
1250       s << "};";
1251     }
1252     prev = def;
1253   }
1254   s << "Ok(Self {";
1255   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1256     auto def = *it;
1257     s << util::CamelCaseToUnderScore(def->name_) << ",";
1258   }
1259   s << "})}";
1260 
1261   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1262     auto def = *it;
1263     auto fields = def->fields_.GetFieldsWithoutTypes({
1264         BodyField::kFieldType,
1265         CountField::kFieldType,
1266         PaddingField::kFieldType,
1267         ReservedField::kFieldType,
1268         SizeField::kFieldType,
1269         PayloadField::kFieldType,
1270         FixedScalarField::kFieldType,
1271     });
1272 
1273     for (auto const& field : fields) {
1274       if (field->GetterIsByRef()) {
1275         s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1276         s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1277         s << "}\n";
1278       } else {
1279         s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1280         s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1281         s << "}\n";
1282       }
1283     }
1284   }
1285 
1286   s << "}\n";
1287 
1288   lineage = GetAncestors();
1289   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1290     auto def = *it;
1291     s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1292     s << " fn into(self) -> " << def->name_ << "Packet {";
1293     s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"
1294       << ".unwrap()";
1295     s << " }";
1296     s << "}\n";
1297   }
1298 }
1299 
GenRustBuilderStructImpls(std::ostream & s) const1300 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1301   if (complement_ != nullptr) {
1302     auto complement_root = complement_->GetRootDef();
1303     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1304     s << "impl CommandExpectations for " << name_ << "Builder {";
1305     s << " type ResponseType = " << complement_->name_ << "Packet;";
1306     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1307     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1308       << ".unwrap()";
1309     s << " }";
1310     s << "}";
1311   }
1312 
1313   s << "impl " << name_ << "Builder {";
1314   s << "pub fn build(self) -> " << name_ << "Packet {";
1315   auto lineage = GetAncestors();
1316   lineage.push_back(this);
1317   std::reverse(lineage.begin(), lineage.end());
1318 
1319   auto all_constraints = GetAllConstraints();
1320 
1321   const ParentDef* prev = nullptr;
1322   for (auto ancestor : lineage) {
1323     auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1324         BodyField::kFieldType,
1325         CountField::kFieldType,
1326         PaddingField::kFieldType,
1327         ReservedField::kFieldType,
1328         SizeField::kFieldType,
1329         PayloadField::kFieldType,
1330         FixedScalarField::kFieldType,
1331     });
1332 
1333     auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1334     s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1335     for (auto field : fields) {
1336       auto constraint = all_constraints.find(field->GetName());
1337       s << field->GetName() << ": ";
1338       if (constraint != all_constraints.end()) {
1339         if (field->GetFieldType() == ScalarField::kFieldType) {
1340           s << std::get<int64_t>(constraint->second);
1341         } else if (field->GetFieldType() == EnumField::kFieldType) {
1342           auto value = std::get<std::string>(constraint->second);
1343           auto constant = value.substr(value.find("::") + 2, std::string::npos);
1344           s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1345           ;
1346         } else {
1347           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1348         }
1349       } else {
1350         s << "self." << field->GetName();
1351       }
1352       s << ", ";
1353     }
1354     if (ancestor->HasChildEnums()) {
1355       if (prev == nullptr) {
1356         if (ancestor->fields_.HasPayload()) {
1357           s << "child: match self.payload { ";
1358           s << "None => " << name_ << "DataChild::None,";
1359           s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1360           s << "},";
1361         } else {
1362           s << "child: " << name_ << "DataChild::None,";
1363         }
1364       } else {
1365         s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1366           << util::CamelCaseToUnderScore(prev->name_) << "),";
1367       }
1368     }
1369     s << "});";
1370     prev = ancestor;
1371   }
1372 
1373   s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ").unwrap()";
1374   s << "}\n";
1375 
1376   s << "}\n";
1377   for (const auto ancestor : GetAncestors()) {
1378     s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1379     s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1380     s << "}\n";
1381   }
1382 }
1383 
GenRustBuilderTest(std::ostream & s) const1384 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1385   auto lineage = GetAncestors();
1386   lineage.push_back(this);
1387   if (!lineage.empty() && !test_cases_.empty()) {
1388     s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1389     s << "($($name:ident: $byte_string:expr,)*) => {";
1390     s << "$(";
1391     s << "\n#[test]\n";
1392     s << "pub fn $name() { ";
1393     s << "let raw_bytes = $byte_string;";
1394     for (size_t i = 0; i < lineage.size(); i++) {
1395       s << "/* (" << i << ") */\n";
1396       if (i == 0) {
1397         s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1398         s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1399         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1400       } else if (i != lineage.size() - 1) {
1401         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1402         s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1403         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1404       } else {
1405         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1406         s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1407         FieldList params = GetParamList();
1408         if (params.HasBody()) {
1409           ERROR() << "Packets with body fields can't be auto-tested.  Test a child.";
1410         }
1411         for (const auto param : params) {
1412           s << param->GetName() << " : packet.";
1413           if (param->GetFieldType() == VectorField::kFieldType) {
1414             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1415           } else if (param->GetFieldType() == ArrayField::kFieldType) {
1416             const auto array_param = static_cast<const ArrayField*>(param);
1417             const auto element_field = array_param->GetElementField();
1418             if (element_field->GetFieldType() == StructField::kFieldType) {
1419               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1420             } else {
1421               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1422             }
1423           } else if (param->GetFieldType() == StructField::kFieldType) {
1424             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1425           } else {
1426             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1427           }
1428         }
1429         s << "};";
1430         s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1431         s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1432         s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1433         s << "}";
1434       }
1435     }
1436     for (size_t i = 1; i < lineage.size(); i++) {
1437       s << "_ => {";
1438       s << "panic!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1439       s << "\n {:#02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1440       s << "}}}";
1441     }
1442 
1443     s << ",";
1444     s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1445     s << "}";
1446     s << "}";
1447     s << ")*";
1448     s << "}";
1449     s << "}";
1450 
1451     s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1452     int number = 0;
1453     for (const auto& test_case : test_cases_) {
1454       s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1455       s << std::setfill('0') << std::setw(2) << number++ << ": ";
1456       s << "b\"" << test_case << "\",";
1457     }
1458     s << "}";
1459     s << "\n";
1460   }
1461 }
1462 
GenRustDef(std::ostream & s) const1463 void PacketDef::GenRustDef(std::ostream& s) const {
1464   GenRustChildEnums(s);
1465   GenRustStructDeclarations(s);
1466   GenRustStructImpls(s);
1467   GenRustAccessStructImpls(s);
1468   GenRustBuilderStructImpls(s);
1469   GenRustBuilderTest(s);
1470 }
1471