• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "packet_def.h"
18 
19 #include <iomanip>
20 #include <list>
21 #include <set>
22 
23 #include "fields/all_fields.h"
24 #include "packet_dependency.h"
25 #include "util.h"
26 
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29 
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31   return nullptr;  // Packets can't be fields
32 }
33 
GenParserDefinition(std::ostream & s) const34 void PacketDef::GenParserDefinition(std::ostream& s) const {
35   s << "class " << name_ << "View";
36   if (parent_ != nullptr) {
37     s << " : public " << parent_->name_ << "View {";
38   } else {
39     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40   }
41   s << " public:";
42 
43   // Specialize function
44   if (parent_ != nullptr) {
45     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46     s << "{ return " << name_ << "View(std::move(parent)); }";
47   } else {
48     s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
49     s << "{ return " << name_ << "View(std::move(packet)); }";
50   }
51 
52   GenTestingParserFromBytes(s);
53 
54   std::set<std::string> fixed_types = {
55       FixedScalarField::kFieldType,
56       FixedEnumField::kFieldType,
57   };
58 
59   // Print all of the public fields which are all the fields minus the fixed fields.
60   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
61   bool has_fixed_fields = public_fields.size() != fields_.size();
62   for (const auto& field : public_fields) {
63     GenParserFieldGetter(s, field);
64     s << "\n";
65   }
66   GenValidator(s);
67   s << "\n";
68 
69   s << " public:";
70   GenParserToString(s);
71   s << "\n";
72 
73   s << " protected:\n";
74   // Constructor from a View
75   if (parent_ != nullptr) {
76     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
77     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
78   } else {
79     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
80     s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
81   }
82 
83   // Print the private fields which are the fixed fields.
84   if (has_fixed_fields) {
85     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
86     s << " private:\n";
87     for (const auto& field : private_fields) {
88       GenParserFieldGetter(s, field);
89       s << "\n";
90     }
91   }
92   s << "};\n";
93 }
94 
GenTestingParserFromBytes(std::ostream & s) const95 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
96   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
97 
98   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
99   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
100   s << "return " << name_ << "View::Create(";
101   auto ancestor_ptr = parent_;
102   size_t parent_parens = 0;
103   while (ancestor_ptr != nullptr) {
104     s << ancestor_ptr->name_ << "View::Create(";
105     parent_parens++;
106     ancestor_ptr = ancestor_ptr->parent_;
107   }
108   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
109   for (size_t i = 0; i < parent_parens; i++) {
110     s << ")";
111   }
112   s << ");";
113   s << "}";
114 
115   s << "\n#endif\n";
116 }
117 
GenParserDefinitionPybind11(std::ostream & s) const118 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
119   s << "py::class_<" << name_ << "View";
120   if (parent_ != nullptr) {
121     s << ", " << parent_->name_ << "View";
122   } else {
123     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
124   }
125   s << ">(m, \"" << name_ << "View\")";
126   if (parent_ != nullptr) {
127     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
128   } else {
129     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
130   }
131   s << "auto view =" << name_ << "View::Create(std::move(parent));";
132   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
133   s << "return view; }))";
134 
135   s << ".def(py::init(&" << name_ << "View::Create))";
136   std::set<std::string> protected_field_types = {
137       FixedScalarField::kFieldType,
138       FixedEnumField::kFieldType,
139       SizeField::kFieldType,
140       CountField::kFieldType,
141   };
142   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
143   for (const auto& field : public_fields) {
144     auto getter_func_name = field->GetGetterFunctionName();
145     if (getter_func_name.empty()) {
146       continue;
147     }
148     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
149   }
150   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
151   s << ";\n";
152 }
153 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const154 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
155   // Start field offset
156   auto start_field_offset = GetOffsetForField(field->GetName(), false);
157   auto end_field_offset = GetOffsetForField(field->GetName(), true);
158 
159   if (start_field_offset.empty() && end_field_offset.empty()) {
160     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
161                  << "no method exists to determine field location from begin() or end().\n";
162   }
163 
164   field->GenGetter(s, start_field_offset, end_field_offset);
165 }
166 
GetDefinitionType() const167 TypeDef::Type PacketDef::GetDefinitionType() const {
168   return TypeDef::Type::PACKET;
169 }
170 
GenValidator(std::ostream & s) const171 void PacketDef::GenValidator(std::ostream& s) const {
172   // Get the static offset for all of our fields.
173   int bits_size = 0;
174   for (const auto& field : fields_) {
175     if (field->GetFieldType() != PaddingField::kFieldType) {
176       bits_size += field->GetSize().bits();
177     }
178   }
179 
180   // Write the function declaration.
181   s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
182   s << "if (was_validated_) { return true; } ";
183   s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
184   s << "}";
185 
186   s << "protected:";
187   s << "virtual bool IsValid_() const {";
188 
189   if (parent_ != nullptr) {
190     s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } ";
191   }
192 
193   // Offset by the parents known size. We know that any dynamic fields can
194   // already be called since the parent must have already been validated by
195   // this point.
196   auto parent_size = Size(0);
197   if (parent_ != nullptr) {
198     parent_size = parent_->GetSize(true);
199   }
200 
201   s << "auto it = begin() + (" << parent_size << ") / 8;";
202 
203   // Check if you can extract the static fields.
204   // At this point you know you can use the size getters without crashing
205   // as long as they follow the instruction that size fields cant come before
206   // their corrisponding variable length field.
207   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
208   s << "if (it > end()) return false;";
209 
210   // For any variable length fields, use their size check.
211   for (const auto& field : fields_) {
212     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
213       auto offset = GetOffsetForField(field->GetName(), false);
214       if (!offset.empty()) {
215         s << "size_t sum_index = (" << offset << ") / 8;";
216       } else {
217         offset = GetOffsetForField(field->GetName(), true);
218         if (offset.empty()) {
219           ERROR(field) << "Checksum Start Field offset can not be determined.";
220         }
221         s << "size_t sum_index = size() - (" << offset << ") / 8;";
222       }
223 
224       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
225       const auto& started_field = fields_.GetField(field_name);
226       if (started_field == nullptr) {
227         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
228                      << ")";
229       }
230       auto end_offset = GetOffsetForField(started_field->GetName(), false);
231       if (!end_offset.empty()) {
232         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
233       } else {
234         end_offset = GetOffsetForField(started_field->GetName(), true);
235         if (end_offset.empty()) {
236           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
237         }
238         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
239       }
240       if (is_little_endian_) {
241         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
242       } else {
243         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
244       }
245       s << started_field->GetDataType() << " checksum;";
246       s << "checksum.Initialize();";
247       s << "for (uint8_t byte : checksum_view) { ";
248       s << "checksum.AddByte(byte);}";
249       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
250         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
251 
252       continue;
253     }
254 
255     auto field_size = field->GetSize();
256     // Fixed size fields have already been handled.
257     if (!field_size.has_dynamic()) {
258       continue;
259     }
260 
261     // Custom fields with dynamic size must have the offset for the field passed in as well
262     // as the end iterator so that they may ensure that they don't try to read past the end.
263     // Custom fields with fixed sizes will be handled in the static offset checking.
264     if (field->GetFieldType() == CustomField::kFieldType) {
265       // Check if we can determine offset from begin(), otherwise error because by this point,
266       // the size of the custom field is unknown and can't be subtracted from end() to get the
267       // offset.
268       auto offset = GetOffsetForField(field->GetName(), false);
269       if (offset.empty()) {
270         ERROR(field) << "Custom Field offset can not be determined from begin().";
271       }
272 
273       if (offset.bits() % 8 != 0) {
274         ERROR(field) << "Custom fields must be byte aligned.";
275       }
276 
277       // Custom fields are special as their size field takes an argument.
278       const auto& custom_size_var = field->GetName() + "_size";
279       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
280       s << "(begin() + (" << offset << ") / 8);";
281 
282       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
283       s << "it += *" << custom_size_var << ";";
284       s << "if (it > end()) return false;";
285       continue;
286     } else {
287       s << "it += (" << field_size.dynamic_string() << ") / 8;";
288       s << "if (it > end()) return false;";
289     }
290   }
291 
292   // Validate constraints after validating the size
293   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
294     ERROR() << "Can't have a constraint on a NULL parent";
295   }
296 
297   for (const auto& constraint : parent_constraints_) {
298     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
299     const auto& field = parent_->GetParamList().GetField(constraint.first);
300     if (field->GetFieldType() == ScalarField::kFieldType) {
301       s << std::get<int64_t>(constraint.second);
302     } else {
303       s << std::get<std::string>(constraint.second);
304     }
305     s << ") return false;";
306   }
307 
308   // Validate the packets fields last
309   for (const auto& field : fields_) {
310     field->GenValidator(s);
311     s << "\n";
312   }
313 
314   s << "return true;";
315   s << "}\n";
316   if (parent_ == nullptr) {
317     s << "bool was_validated_{false};\n";
318   }
319 }
320 
GenParserToString(std::ostream & s) const321 void PacketDef::GenParserToString(std::ostream& s) const {
322   s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
323   s << "std::stringstream ss;";
324   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
325 
326   if (fields_.size() > 0) {
327     s << "ss << \"\" ";
328     bool firstfield = true;
329     for (const auto& field : fields_) {
330       if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
331           field->GetFieldType() == ChecksumStartField::kFieldType)
332         continue;
333 
334       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
335 
336       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
337 
338       if (firstfield) {
339         firstfield = false;
340       }
341     }
342     s << ";";
343   }
344 
345   s << "ss << \" }\";";
346   s << "return ss.str();";
347   s << "}\n";
348 }
349 
GenBuilderDefinition(std::ostream & s) const350 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
351   s << "class " << name_ << "Builder";
352   if (parent_ != nullptr) {
353     s << " : public " << parent_->name_ << "Builder";
354   } else {
355     if (is_little_endian_) {
356       s << " : public PacketBuilder<kLittleEndian>";
357     } else {
358       s << " : public PacketBuilder<!kLittleEndian>";
359     }
360   }
361   s << " {";
362   s << " public:";
363   s << "  virtual ~" << name_ << "Builder() = default;";
364 
365   if (!fields_.HasBody()) {
366     GenBuilderCreate(s);
367     s << "\n";
368 
369     GenTestingFromView(s);
370     s << "\n";
371   }
372 
373   GenSerialize(s);
374   s << "\n";
375 
376   GenSize(s);
377   s << "\n";
378 
379   s << " protected:\n";
380   GenBuilderConstructor(s);
381   s << "\n";
382 
383   GenBuilderParameterChecker(s);
384   s << "\n";
385 
386   GenMembers(s);
387   s << "};\n";
388 
389   GenTestDefine(s);
390   s << "\n";
391 
392   GenFuzzTestDefine(s);
393   s << "\n";
394 }
395 
GenTestingFromView(std::ostream & s) const396 void PacketDef::GenTestingFromView(std::ostream& s) const {
397   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
398 
399   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
400   s << "return " << name_ << "Builder::Create(";
401   FieldList params = GetParamList().GetFieldsWithoutTypes({
402       BodyField::kFieldType,
403   });
404   for (std::size_t i = 0; i < params.size(); i++) {
405     params[i]->GenBuilderParameterFromView(s);
406     if (i != params.size() - 1) {
407       s << ", ";
408     }
409   }
410   s << ");";
411   s << "}";
412 
413   s << "\n#endif\n";
414 }
415 
GenBuilderDefinitionPybind11(std::ostream & s) const416 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
417   s << "py::class_<" << name_ << "Builder";
418   if (parent_ != nullptr) {
419     s << ", " << parent_->name_ << "Builder";
420   } else {
421     if (is_little_endian_) {
422       s << ", PacketBuilder<kLittleEndian>";
423     } else {
424       s << ", PacketBuilder<!kLittleEndian>";
425     }
426   }
427   s << ", std::shared_ptr<" << name_ << "Builder>";
428   s << ">(m, \"" << name_ << "Builder\")";
429   if (!fields_.HasBody()) {
430     GenBuilderCreatePybind11(s);
431   }
432   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
433   s << "std::vector<uint8_t> bytes;";
434   s << "BitInserter bi(bytes);";
435   s << "builder.Serialize(bi);";
436   s << "return bytes;})";
437   s << ";\n";
438 }
439 
GenTestDefine(std::ostream & s) const440 void PacketDef::GenTestDefine(std::ostream& s) const {
441   s << "#ifdef PACKET_TESTING\n";
442   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
443   s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
444   s << "public: ";
445   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
446   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
447   s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
448   s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
449   s << "ASSERT_TRUE(view.IsValid());";
450   s << "auto packet = " << name_ << "Builder::FromView(view);";
451   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
452   s << "packet_bytes->reserve(packet->size());";
453   s << "BitInserter it(*packet_bytes);";
454   s << "packet->Serialize(it);";
455   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
456   s << "}";
457   s << "};";
458   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
459   s << "CompareBytes(GetParam());";
460   s << "}";
461   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
462   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
463   int i = 0;
464   for (const auto& bytes : test_cases_) {
465     s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
466     s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
467     s << name_ << "_test_bytes_" << i << ",";
468     s << name_ << "_test_bytes_" << i << " + sizeof(";
469     s << name_ << "_test_bytes_" << i << ") - 1);";
470     i++;
471   }
472   if (!test_cases_.empty()) {
473     i = 0;
474     s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
475     for (auto bytes : test_cases_) {
476       if (i > 0) {
477         s << ",";
478       }
479       s << name_ << "_test_vec_" << i++;
480     }
481     s << ");";
482   }
483   s << "\n#endif";
484 }
485 
GenFuzzTestDefine(std::ostream & s) const486 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
487   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
488   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
489   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
490   s << "auto vec = std::vector<uint8_t>(data, data + size);";
491   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
492   s << "if (!view.IsValid()) { return; }";
493   s << "auto packet = " << name_ << "Builder::FromView(view);";
494   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
495   s << "packet_bytes->reserve(packet->size());";
496   s << "BitInserter it(*packet_bytes);";
497   s << "packet->Serialize(it);";
498   s << "}";
499   s << "\n#endif\n";
500   s << "#ifdef PACKET_FUZZ_TESTING\n";
501   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
502   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
503   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
504   s << "public: ";
505   s << "explicit " << name_
506     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
507   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
508   s << "}}; ";
509   s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
510   s << "\n#endif";
511 }
512 
GetParametersToValidate() const513 FieldList PacketDef::GetParametersToValidate() const {
514   FieldList params_to_validate;
515   for (const auto& field : GetParamList()) {
516     if (field->HasParameterValidator()) {
517       params_to_validate.AppendField(field);
518     }
519   }
520   return params_to_validate;
521 }
522 
GenBuilderCreate(std::ostream & s) const523 void PacketDef::GenBuilderCreate(std::ostream& s) const {
524   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
525 
526   auto params = GetParamList();
527   for (std::size_t i = 0; i < params.size(); i++) {
528     params[i]->GenBuilderParameter(s);
529     if (i != params.size() - 1) {
530       s << ", ";
531     }
532   }
533   s << ") {";
534 
535   // Call the constructor
536   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
537 
538   params = params.GetFieldsWithoutTypes({
539       PayloadField::kFieldType,
540       BodyField::kFieldType,
541   });
542   // Add the parameters.
543   for (std::size_t i = 0; i < params.size(); i++) {
544     if (params[i]->BuilderParameterMustBeMoved()) {
545       s << "std::move(" << params[i]->GetName() << ")";
546     } else {
547       s << params[i]->GetName();
548     }
549     if (i != params.size() - 1) {
550       s << ", ";
551     }
552   }
553 
554   s << "));";
555   if (fields_.HasPayload()) {
556     s << "builder->payload_ = std::move(payload);";
557   }
558   s << "return builder;";
559   s << "}\n";
560 }
561 
GenBuilderCreatePybind11(std::ostream & s) const562 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
563   s << ".def(py::init([](";
564   auto params = GetParamList();
565   std::vector<std::string> constructor_args;
566   int i = 1;
567   for (const auto& param : params) {
568     i++;
569     std::stringstream ss;
570     auto param_type = param->GetBuilderParameterType();
571     if (param_type.empty()) {
572       continue;
573     }
574     // Use shared_ptr instead of unique_ptr for the Python interface
575     if (param->BuilderParameterMustBeMoved()) {
576       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
577     }
578     ss << param_type << " " << param->GetName();
579     constructor_args.push_back(ss.str());
580   }
581   s << util::StringJoin(",", constructor_args) << "){";
582 
583   // Deal with move only args
584   for (const auto& param : params) {
585     std::stringstream ss;
586     auto param_type = param->GetBuilderParameterType();
587     if (param_type.empty()) {
588       continue;
589     }
590     if (!param->BuilderParameterMustBeMoved()) {
591       continue;
592     }
593     auto move_only_param_name = param->GetName() + "_move_only";
594     s << param_type << " " << move_only_param_name << ";";
595     if (param->IsContainerField()) {
596       // Assume single layer container and copy it
597       auto struct_type = param->GetElementField()->GetDataType();
598       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
599       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
600       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
601       // Serialize each struct
602       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
603       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
604       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
605       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
606       // Parse it again
607       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
608       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
609       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
610       // Push it into a new container
611       if (param->GetFieldType() == VectorField::kFieldType) {
612         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
613       } else if (param->GetFieldType() == ArrayField::kFieldType) {
614         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
615       } else {
616         ERROR() << param << " is not supported by Pybind11";
617       }
618       s << "}";
619     } else {
620       // Serialize the parameter and pass the bytes in a RawBuilder
621       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
622       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
623       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
624       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
625       s << move_only_param_name << " = ";
626       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
627     }
628   }
629   s << "return " << name_ << "Builder::Create(";
630   std::vector<std::string> builder_vars;
631   for (const auto& param : params) {
632     std::stringstream ss;
633     auto param_type = param->GetBuilderParameterType();
634     if (param_type.empty()) {
635       continue;
636     }
637     auto param_name = param->GetName();
638     if (param->BuilderParameterMustBeMoved()) {
639       ss << "std::move(" << param_name << "_move_only)";
640     } else {
641       ss << param_name;
642     }
643     builder_vars.push_back(ss.str());
644   }
645   s << util::StringJoin(",", builder_vars) << ");}";
646   s << "))";
647 }
648 
GenBuilderParameterChecker(std::ostream & s) const649 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
650   FieldList params_to_validate = GetParametersToValidate();
651 
652   // Skip writing this function if there is nothing to validate.
653   if (params_to_validate.size() == 0) {
654     return;
655   }
656 
657   // Generate function arguments.
658   s << "void CheckParameterValues(";
659   for (std::size_t i = 0; i < params_to_validate.size(); i++) {
660     params_to_validate[i]->GenBuilderParameter(s);
661     if (i != params_to_validate.size() - 1) {
662       s << ", ";
663     }
664   }
665   s << ") {";
666 
667   // Check the parameters.
668   for (const auto& field : params_to_validate) {
669     field->GenParameterValidator(s);
670   }
671   s << "}\n";
672 }
673 
GenBuilderConstructor(std::ostream & s) const674 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
675   s << "explicit " << name_ << "Builder(";
676 
677   // Generate the constructor parameters.
678   auto params = GetParamList().GetFieldsWithoutTypes({
679       PayloadField::kFieldType,
680       BodyField::kFieldType,
681   });
682   for (std::size_t i = 0; i < params.size(); i++) {
683     params[i]->GenBuilderParameter(s);
684     if (i != params.size() - 1) {
685       s << ", ";
686     }
687   }
688   if (params.size() > 0 || parent_constraints_.size() > 0) {
689     s << ") :";
690   } else {
691     s << ")";
692   }
693 
694   // Get the list of parent params to call the parent constructor with.
695   FieldList parent_params;
696   if (parent_ != nullptr) {
697     // Pass parameters to the parent constructor
698     s << parent_->name_ << "Builder(";
699     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
700         PayloadField::kFieldType,
701         BodyField::kFieldType,
702     });
703 
704     // Go through all the fields and replace constrained fields with fixed values
705     // when calling the parent constructor.
706     for (std::size_t i = 0; i < parent_params.size(); i++) {
707       const auto& field = parent_params[i];
708       const auto& constraint = parent_constraints_.find(field->GetName());
709       if (constraint != parent_constraints_.end()) {
710         if (field->GetFieldType() == ScalarField::kFieldType) {
711           s << std::get<int64_t>(constraint->second);
712         } else if (field->GetFieldType() == EnumField::kFieldType) {
713           s << std::get<std::string>(constraint->second);
714         } else {
715           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
716         }
717 
718         s << "/* " << field->GetName() << "_ */";
719       } else {
720         s << field->GetName();
721       }
722 
723       if (i != parent_params.size() - 1) {
724         s << ", ";
725       }
726     }
727     s << ") ";
728   }
729 
730   // Build a list of parameters that excludes all parent parameters.
731   FieldList saved_params;
732   for (const auto& field : params) {
733     if (parent_params.GetField(field->GetName()) == nullptr) {
734       saved_params.AppendField(field);
735     }
736   }
737   if (parent_ != nullptr && saved_params.size() > 0) {
738     s << ",";
739   }
740   for (std::size_t i = 0; i < saved_params.size(); i++) {
741     const auto& saved_param_name = saved_params[i]->GetName();
742     if (saved_params[i]->BuilderParameterMustBeMoved()) {
743       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
744     } else {
745       s << saved_param_name << "_(" << saved_param_name << ")";
746     }
747     if (i != saved_params.size() - 1) {
748       s << ",";
749     }
750   }
751   s << " {";
752 
753   FieldList params_to_validate = GetParametersToValidate();
754 
755   if (params_to_validate.size() > 0) {
756     s << "CheckParameterValues(";
757     for (std::size_t i = 0; i < params_to_validate.size(); i++) {
758       s << params_to_validate[i]->GetName() << "_";
759       if (i != params_to_validate.size() - 1) {
760         s << ", ";
761       }
762     }
763     s << ");";
764   }
765 
766   s << "}\n";
767 }
768 
GenRustChildEnums(std::ostream & s) const769 void PacketDef::GenRustChildEnums(std::ostream& s) const {
770   if (HasChildEnums()) {
771     bool payload = fields_.HasPayload();
772     s << "#[derive(Debug)] ";
773     s << "enum " << name_ << "DataChild {";
774     for (const auto& child : children_) {
775       s << child->name_ << "(Arc<" << child->name_ << "Data>),";
776     }
777     if (payload) {
778       s << "Payload(Bytes),";
779     }
780     s << "None,";
781     s << "}\n";
782 
783     s << "impl " << name_ << "DataChild {";
784     s << "fn get_total_size(&self) -> usize {";
785     s << "match self {";
786     for (const auto& child : children_) {
787       s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
788     }
789     if (payload) {
790       s << name_ << "DataChild::Payload(p) => p.len(),";
791     }
792     s << name_ << "DataChild::None => 0,";
793     s << "}\n";
794     s << "}\n";
795     s << "}\n";
796 
797     s << "#[derive(Debug)] ";
798     s << "pub enum " << name_ << "Child {";
799     for (const auto& child : children_) {
800       s << child->name_ << "(" << child->name_ << "Packet),";
801     }
802     if (payload) {
803       s << "Payload(Bytes),";
804     }
805     s << "None,";
806     s << "}\n";
807   }
808 }
809 
GenRustStructDeclarations(std::ostream & s) const810 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
811   s << "#[derive(Debug)] ";
812   s << "struct " << name_ << "Data {";
813 
814   // Generate struct fields
815   GenRustStructFieldNameAndType(s);
816   if (HasChildEnums()) {
817     s << "child: " << name_ << "DataChild,";
818   }
819   s << "}\n";
820 
821   // Generate accessor struct
822   s << "#[derive(Debug, Clone)] ";
823   s << "pub struct " << name_ << "Packet {";
824   auto lineage = GetAncestors();
825   lineage.push_back(this);
826   for (auto it = lineage.begin(); it != lineage.end(); it++) {
827     auto def = *it;
828     s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
829   }
830   s << "}\n";
831 
832   // Generate builder struct
833   s << "#[derive(Debug)] ";
834   s << "pub struct " << name_ << "Builder {";
835   auto params = GetParamList().GetFieldsWithoutTypes({
836       PayloadField::kFieldType,
837       BodyField::kFieldType,
838   });
839   for (auto param : params) {
840     s << "pub ";
841     param->GenRustNameAndType(s);
842     s << ", ";
843   }
844   if (fields_.HasPayload()) {
845     s << "pub payload: Option<Bytes>,";
846   }
847   s << "}\n";
848 }
849 
GenRustStructFieldNameAndType(std::ostream & s) const850 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
851   auto fields = fields_.GetFieldsWithoutTypes({
852       BodyField::kFieldType,
853       CountField::kFieldType,
854       PaddingField::kFieldType,
855       ReservedField::kFieldType,
856       SizeField::kFieldType,
857       PayloadField::kFieldType,
858       FixedScalarField::kFieldType,
859   });
860   if (fields.size() == 0) {
861     return false;
862   }
863   for (const auto& field : fields) {
864     field->GenRustNameAndType(s);
865     s << ", ";
866   }
867   return true;
868 }
869 
GenRustStructFieldNames(std::ostream & s) const870 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
871   auto fields = fields_.GetFieldsWithoutTypes({
872       BodyField::kFieldType,
873       CountField::kFieldType,
874       PaddingField::kFieldType,
875       ReservedField::kFieldType,
876       SizeField::kFieldType,
877       PayloadField::kFieldType,
878       FixedScalarField::kFieldType,
879   });
880   for (const auto field : fields) {
881     s << field->GetName();
882     s << ", ";
883   }
884 }
885 
GenRustStructImpls(std::ostream & s) const886 void PacketDef::GenRustStructImpls(std::ostream& s) const {
887   auto packet_dep = PacketDependency(GetRootDef());
888 
889   s << "impl " << name_ << "Data {";
890   // conforms function
891   s << "fn conforms(bytes: &[u8]) -> bool {";
892   GenRustConformanceCheck(s);
893 
894   auto fields = fields_.GetFieldsWithTypes({
895       StructField::kFieldType,
896   });
897 
898   for (auto const& field : fields) {
899     auto start_offset = GetOffsetForField(field->GetName(), false);
900     auto end_offset = GetOffsetForField(field->GetName(), true);
901 
902     s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
903     s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
904   }
905 
906   s << " true";
907   s << "}";
908 
909   auto parse_params = packet_dep.GetDependencies(name_);
910   s << "fn parse(bytes: &[u8]";
911   for (auto field_name : parse_params) {
912     auto constraint_field = GetParamList().GetField(field_name);
913     auto constraint_type = constraint_field->GetRustDataType();
914     s << ", " << field_name << ": " << constraint_type;
915   }
916   s << ") -> Result<Self> {";
917 
918   fields = fields_.GetFieldsWithoutTypes({
919       BodyField::kFieldType,
920   });
921 
922   for (auto const& field : fields) {
923     auto start_field_offset = GetOffsetForField(field->GetName(), false);
924     auto end_field_offset = GetOffsetForField(field->GetName(), true);
925 
926     if (start_field_offset.empty() && end_field_offset.empty()) {
927       ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
928                    << "no method exists to determine field location from begin() or end().\n";
929     }
930 
931     field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
932     field->GenRustGetter(s, start_field_offset, end_field_offset, name_);
933   }
934 
935   auto payload_field = fields_.GetFieldsWithTypes({
936     PayloadField::kFieldType,
937   });
938 
939   Size payload_offset;
940 
941   if (payload_field.HasPayload()) {
942     payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
943   }
944 
945   if (children_.size() > 1) {
946     auto match_on_variables = packet_dep.GetChildrenDependencies(name_);
947     // If match_on_variables is empty, this means there are multiple abstract packets which will
948     // specialize to a child down the packet tree.
949     // In this case match variables will be the union of parent fields and parse params of children.
950     if (match_on_variables.empty()) {
951       for (auto& field : fields_) {
952         if (std::any_of(children_.begin(), children_.end(), [&](auto child) {
953               auto pass_me = packet_dep.GetDependencies(child->name_);
954               return std::find(pass_me.begin(), pass_me.end(), field->GetName()) != pass_me.end();
955             })) {
956           match_on_variables.push_back(field->GetName());
957         }
958       }
959     }
960 
961     s << "let child = match (";
962 
963     for (auto var : match_on_variables) {
964       if (var == match_on_variables[match_on_variables.size() - 1]) {
965         s << var;
966       } else {
967         s << var << ", ";
968       }
969     }
970     s << ") {";
971 
972     auto get_match_val = [&](
973         std::string& match_var,
974         std::variant<int64_t,
975         std::string> constraint) -> std::string {
976       auto constraint_field = GetParamList().GetField(match_var);
977       auto constraint_type = constraint_field->GetFieldType();
978 
979       if (constraint_type == EnumField::kFieldType) {
980         auto type = std::get<std::string>(constraint);
981         auto variant_name = type.substr(type.find("::") + 2, type.length());
982         auto enum_type = type.substr(0, type.find("::"));
983         return enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
984       }
985       if (constraint_type == ScalarField::kFieldType) {
986         return std::to_string(std::get<int64_t>(constraint));
987       }
988       return "_";
989     };
990 
991     for (auto& child : children_) {
992       s << "(";
993       for (auto var : match_on_variables) {
994         std::string match_val = "_";
995 
996         if (child->parent_constraints_.find(var) != child->parent_constraints_.end()) {
997           match_val = get_match_val(var, child->parent_constraints_[var]);
998         } else {
999           auto dcs = child->FindDescendantsWithConstraint(var);
1000           std::vector<std::string> all_match_vals;
1001           for (auto& desc : dcs) {
1002             all_match_vals.push_back(get_match_val(var, desc.second));
1003           }
1004           match_val = "";
1005           for (std::size_t i = 0; i < all_match_vals.size(); ++i) {
1006             match_val += all_match_vals[i];
1007             if (i != all_match_vals.size() - 1) {
1008               match_val += " | ";
1009             }
1010           }
1011           match_val = (match_val == "") ? "_" : match_val;
1012         }
1013 
1014         if (var == match_on_variables[match_on_variables.size() - 1]) {
1015           s << match_val << ")";
1016         } else {
1017           s << match_val << ", ";
1018         }
1019       }
1020       s << " if " << child->name_ << "Data::conforms(&bytes[..])";
1021       s << " => {";
1022       s << name_ << "DataChild::";
1023       s << child->name_ << "(Arc::new(";
1024 
1025       auto child_parse_params = packet_dep.GetDependencies(child->name_);
1026       if (child_parse_params.size() == 0) {
1027         s << child->name_ << "Data::parse(&bytes[..]";
1028       } else {
1029         s << child->name_ << "Data::parse(&bytes[..], ";
1030       }
1031 
1032       for (auto var : child_parse_params) {
1033         if (var == child_parse_params[child_parse_params.size() - 1]) {
1034           s << var;
1035         } else {
1036           s << var << ", ";
1037         }
1038       }
1039       s << ")?))";
1040       s << "}\n";
1041     }
1042 
1043     s << "(";
1044     for (int i = 1; i <= match_on_variables.size(); i++) {
1045       if (i == match_on_variables.size()) {
1046         s << "_";
1047       } else {
1048         s << "_, ";
1049       }
1050     }
1051     s << ")";
1052     s << " => return Err(Error::InvalidPacketError),";
1053     s << "};\n";
1054   } else if (children_.size() == 1) {
1055     auto child = children_.at(0);
1056     auto params = packet_dep.GetDependencies(child->name_);
1057     s << "let child = match " << child->name_ << "Data::parse(&bytes[..]";
1058     for (auto field_name : params) {
1059       s << ", " << field_name;
1060     }
1061     s << ") {";
1062     s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
1063     s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
1064     s << " },";
1065     s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
1066     s << " _ => return Err(Error::InvalidPacketError),";
1067     s << "};";
1068   } else if (fields_.HasPayload()) {
1069     s << "let child = if payload.len() > 0 {";
1070     s << name_ << "DataChild::Payload(Bytes::from(payload))";
1071     s << "} else {";
1072     s << name_ << "DataChild::None";
1073     s << "};";
1074   }
1075 
1076   s << "Ok(Self {";
1077   fields = fields_.GetFieldsWithoutTypes({
1078       BodyField::kFieldType,
1079       CountField::kFieldType,
1080       PaddingField::kFieldType,
1081       ReservedField::kFieldType,
1082       SizeField::kFieldType,
1083       PayloadField::kFieldType,
1084       FixedScalarField::kFieldType,
1085   });
1086 
1087   if (fields.size() > 0) {
1088     for (const auto& field : fields) {
1089       auto field_type = field->GetFieldType();
1090       s << field->GetName();
1091       s << ", ";
1092     }
1093   }
1094 
1095   if (HasChildEnums()) {
1096     s << "child,";
1097   }
1098   s << "})\n";
1099   s << "}\n";
1100 
1101   // write_to function
1102   s << "fn write_to(&self, buffer: &mut BytesMut) {";
1103   GenRustWriteToFields(s);
1104 
1105   if (HasChildEnums()) {
1106     s << "match &self.child {";
1107     for (const auto& child : children_) {
1108       s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1109     }
1110     if (fields_.HasPayload()) {
1111       auto offset = GetOffsetForField("payload");
1112       s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1113     }
1114     s << name_ << "DataChild::None => {}";
1115     s << "}";
1116   }
1117 
1118   s << "}\n";
1119 
1120   s << "fn get_total_size(&self) -> usize {";
1121   if (HasChildEnums()) {
1122     s << "self.get_size() + self.child.get_total_size()";
1123   } else {
1124     s << "self.get_size()";
1125   }
1126   s << "}\n";
1127 
1128   s << "fn get_size(&self) -> usize {";
1129   GenSizeRetVal(s);
1130   s << "}\n";
1131   s << "}\n";
1132 }
1133 
GenRustAccessStructImpls(std::ostream & s) const1134 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1135   if (complement_ != nullptr) {
1136     auto complement_root = complement_->GetRootDef();
1137     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1138     s << "impl CommandExpectations for " << name_ << "Packet {";
1139     s << " type ResponseType = " << complement_->name_ << "Packet;";
1140     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1141     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1142       << ".unwrap()";
1143     s << " }";
1144     s << "}";
1145   }
1146 
1147   s << "impl Packet for " << name_ << "Packet {";
1148   auto root = GetRootDef();
1149   auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1150 
1151   s << "fn to_bytes(self) -> Bytes {";
1152   s << " let mut buffer = BytesMut::new();";
1153   s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1154   s << " self." << root_accessor << ".write_to(&mut buffer);";
1155   s << " buffer.freeze()";
1156   s << "}\n";
1157 
1158   s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1159   s << "}";
1160 
1161   s << "impl From<" << name_ << "Packet"
1162     << "> for Bytes {\n";
1163   s << "fn from(packet: " << name_ << "Packet"
1164     << ") -> Self {\n";
1165   s << "packet.to_bytes()\n";
1166   s << "}\n";
1167   s << "}\n";
1168 
1169   s << "impl From<" << name_ << "Packet"
1170     << "> for Vec<u8> {\n";
1171   s << "fn from(packet: " << name_ << "Packet"
1172     << ") -> Self {\n";
1173   s << "packet.to_vec()\n";
1174   s << "}\n";
1175   s << "}\n";
1176 
1177   if (root != this) {
1178     s << "impl TryFrom<" << root->name_ << "Packet"
1179       << "> for " << name_ << "Packet {\n";
1180     s << "type Error = TryFromError;\n";
1181     s << "fn try_from(value: " << root->name_ << "Packet)"
1182       << " -> std::result::Result<Self, Self::Error> {\n";
1183     s << "Self::new(value." << root_accessor << ").map_err(TryFromError)\n", s << "}\n";
1184     s << "}\n";
1185   }
1186 
1187   s << "impl " << name_ << "Packet {";
1188   if (parent_ == nullptr) {
1189     s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1190     s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)).unwrap())";
1191     s << "}";
1192   }
1193 
1194   if (HasChildEnums()) {
1195     s << " pub fn specialize(&self) -> " << name_ << "Child {";
1196     s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1197     for (const auto& child : children_) {
1198       s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1199         << child->name_ << "Packet::new(self." << root_accessor << ".clone()).unwrap()),";
1200     }
1201     if (fields_.HasPayload()) {
1202       s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1203     }
1204     s << name_ << "DataChild::None => " << name_ << "Child::None,";
1205     s << "}}";
1206   }
1207   auto lineage = GetAncestors();
1208   lineage.push_back(this);
1209   const ParentDef* prev = nullptr;
1210 
1211   s << " fn new(root: Arc<" << root->name_ << "Data>) -> std::result::Result<Self, &'static str> {";
1212   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1213     auto def = *it;
1214     auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1215     if (prev == nullptr) {
1216       s << "let " << accessor_name << " = root;";
1217     } else {
1218       s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1219       s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1220       s << "_ => return Err(\"inconsistent state - child was not " << def->name_ << "\"),";
1221       s << "};";
1222     }
1223     prev = def;
1224   }
1225   s << "Ok(Self {";
1226   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1227     auto def = *it;
1228     s << util::CamelCaseToUnderScore(def->name_) << ",";
1229   }
1230   s << "})}";
1231 
1232   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1233     auto def = *it;
1234     auto fields = def->fields_.GetFieldsWithoutTypes({
1235         BodyField::kFieldType,
1236         CountField::kFieldType,
1237         PaddingField::kFieldType,
1238         ReservedField::kFieldType,
1239         SizeField::kFieldType,
1240         PayloadField::kFieldType,
1241         FixedScalarField::kFieldType,
1242     });
1243 
1244     for (auto const& field : fields) {
1245       if (field->GetterIsByRef()) {
1246         s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1247         s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1248         s << "}\n";
1249       } else {
1250         s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1251         s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1252         s << "}\n";
1253       }
1254     }
1255   }
1256 
1257   s << "}\n";
1258 
1259   lineage = GetAncestors();
1260   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1261     auto def = *it;
1262     s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1263     s << " fn into(self) -> " << def->name_ << "Packet {";
1264     s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"
1265       << ".unwrap()";
1266     s << " }";
1267     s << "}\n";
1268   }
1269 }
1270 
GenRustBuilderStructImpls(std::ostream & s) const1271 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1272   if (complement_ != nullptr) {
1273     auto complement_root = complement_->GetRootDef();
1274     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1275     s << "impl CommandExpectations for " << name_ << "Builder {";
1276     s << " type ResponseType = " << complement_->name_ << "Packet;";
1277     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1278     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1279       << ".unwrap()";
1280     s << " }";
1281     s << "}";
1282   }
1283 
1284   s << "impl " << name_ << "Builder {";
1285   s << "pub fn build(self) -> " << name_ << "Packet {";
1286   auto lineage = GetAncestors();
1287   lineage.push_back(this);
1288   std::reverse(lineage.begin(), lineage.end());
1289 
1290   auto all_constraints = GetAllConstraints();
1291 
1292   const ParentDef* prev = nullptr;
1293   for (auto ancestor : lineage) {
1294     auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1295         BodyField::kFieldType,
1296         CountField::kFieldType,
1297         PaddingField::kFieldType,
1298         ReservedField::kFieldType,
1299         SizeField::kFieldType,
1300         PayloadField::kFieldType,
1301         FixedScalarField::kFieldType,
1302     });
1303 
1304     auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1305     s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1306     for (auto field : fields) {
1307       auto constraint = all_constraints.find(field->GetName());
1308       s << field->GetName() << ": ";
1309       if (constraint != all_constraints.end()) {
1310         if (field->GetFieldType() == ScalarField::kFieldType) {
1311           s << std::get<int64_t>(constraint->second);
1312         } else if (field->GetFieldType() == EnumField::kFieldType) {
1313           auto value = std::get<std::string>(constraint->second);
1314           auto constant = value.substr(value.find("::") + 2, std::string::npos);
1315           s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1316           ;
1317         } else {
1318           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1319         }
1320       } else {
1321         s << "self." << field->GetName();
1322       }
1323       s << ", ";
1324     }
1325     if (ancestor->HasChildEnums()) {
1326       if (prev == nullptr) {
1327         if (ancestor->fields_.HasPayload()) {
1328           s << "child: match self.payload { ";
1329           s << "None => " << name_ << "DataChild::None,";
1330           s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1331           s << "},";
1332         } else {
1333           s << "child: " << name_ << "DataChild::None,";
1334         }
1335       } else {
1336         s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1337           << util::CamelCaseToUnderScore(prev->name_) << "),";
1338       }
1339     }
1340     s << "});";
1341     prev = ancestor;
1342   }
1343 
1344   s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ").unwrap()";
1345   s << "}\n";
1346 
1347   s << "}\n";
1348   for (const auto ancestor : GetAncestors()) {
1349     s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1350     s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1351     s << "}\n";
1352   }
1353 }
1354 
GenRustBuilderTest(std::ostream & s) const1355 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1356   auto lineage = GetAncestors();
1357   lineage.push_back(this);
1358   if (!lineage.empty() && !test_cases_.empty()) {
1359     s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1360     s << "($($name:ident: $byte_string:expr,)*) => {";
1361     s << "$(";
1362     s << "\n#[test]\n";
1363     s << "pub fn $name() { ";
1364     s << "let raw_bytes = $byte_string;";
1365     for (size_t i = 0; i < lineage.size(); i++) {
1366       s << "/* (" << i << ") */\n";
1367       if (i == 0) {
1368         s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1369         s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1370         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1371       } else if (i != lineage.size() - 1) {
1372         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1373         s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1374         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1375       } else {
1376         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1377         s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1378         FieldList params = GetParamList();
1379         if (params.HasBody()) {
1380           ERROR() << "Packets with body fields can't be auto-tested.  Test a child.";
1381         }
1382         for (const auto param : params) {
1383           s << param->GetName() << " : packet.";
1384           if (param->GetFieldType() == VectorField::kFieldType) {
1385             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1386           } else if (param->GetFieldType() == ArrayField::kFieldType) {
1387             const auto array_param = static_cast<const ArrayField*>(param);
1388             const auto element_field = array_param->GetElementField();
1389             if (element_field->GetFieldType() == StructField::kFieldType) {
1390               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1391             } else {
1392               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1393             }
1394           } else if (param->GetFieldType() == StructField::kFieldType) {
1395             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1396           } else {
1397             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1398           }
1399         }
1400         s << "};";
1401         s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1402         s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1403         s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1404         s << "}";
1405       }
1406     }
1407     for (size_t i = 1; i < lineage.size(); i++) {
1408       s << "_ => {";
1409       s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1410       s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1411       s << "}}}";
1412     }
1413 
1414     s << ",";
1415     s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1416     s << "}";
1417     s << "}";
1418     s << ")*";
1419     s << "}";
1420     s << "}";
1421 
1422     s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1423     int number = 0;
1424     for (const auto& test_case : test_cases_) {
1425       s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1426       s << std::setfill('0') << std::setw(2) << number++ << ": ";
1427       s << "b\"" << test_case << "\",";
1428     }
1429     s << "}";
1430     s << "\n";
1431   }
1432 }
1433 
GenRustDef(std::ostream & s) const1434 void PacketDef::GenRustDef(std::ostream& s) const {
1435   GenRustChildEnums(s);
1436   GenRustStructDeclarations(s);
1437   GenRustStructImpls(s);
1438   GenRustAccessStructImpls(s);
1439   GenRustBuilderStructImpls(s);
1440   GenRustBuilderTest(s);
1441 }
1442