1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "packet_def.h"

#include <iomanip>
#include <list>
#include <set>
#include <string>
#include <utility>

#include "fields/all_fields.h"
#include "util.h"
25
PacketDef(std::string name,FieldList fields)26 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)27 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
28
GetNewField(const std::string &,ParseLocation) const29 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
30 return nullptr; // Packets can't be fields
31 }
32
GenParserDefinition(std::ostream & s) const33 void PacketDef::GenParserDefinition(std::ostream& s) const {
34 s << "class " << name_ << "View";
35 if (parent_ != nullptr) {
36 s << " : public " << parent_->name_ << "View {";
37 } else {
38 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
39 }
40 s << " public:";
41
42 // Specialize function
43 if (parent_ != nullptr) {
44 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
45 s << "{ return " << name_ << "View(std::move(parent)); }";
46 } else {
47 s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
48 s << "{ return " << name_ << "View(std::move(packet)); }";
49 }
50
51 GenTestingParserFromBytes(s);
52
53 std::set<std::string> fixed_types = {
54 FixedScalarField::kFieldType,
55 FixedEnumField::kFieldType,
56 };
57
58 // Print all of the public fields which are all the fields minus the fixed fields.
59 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
60 bool has_fixed_fields = public_fields.size() != fields_.size();
61 for (const auto& field : public_fields) {
62 GenParserFieldGetter(s, field);
63 s << "\n";
64 }
65 GenValidator(s);
66 s << "\n";
67
68 s << " public:";
69 GenParserToString(s);
70 s << "\n";
71
72 s << " protected:\n";
73 // Constructor from a View
74 if (parent_ != nullptr) {
75 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
76 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
77 } else {
78 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
79 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
80 }
81
82 // Print the private fields which are the fixed fields.
83 if (has_fixed_fields) {
84 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
85 s << " private:\n";
86 for (const auto& field : private_fields) {
87 GenParserFieldGetter(s, field);
88 s << "\n";
89 }
90 }
91 s << "};\n";
92 }
93
GenTestingParserFromBytes(std::ostream & s) const94 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
95 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
96
97 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
98 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
99 s << "return " << name_ << "View::Create(";
100 auto ancestor_ptr = parent_;
101 size_t parent_parens = 0;
102 while (ancestor_ptr != nullptr) {
103 s << ancestor_ptr->name_ << "View::Create(";
104 parent_parens++;
105 ancestor_ptr = ancestor_ptr->parent_;
106 }
107 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
108 for (size_t i = 0; i < parent_parens; i++) {
109 s << ")";
110 }
111 s << ");";
112 s << "}";
113
114 s << "\n#endif\n";
115 }
116
GenParserDefinitionPybind11(std::ostream & s) const117 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
118 s << "py::class_<" << name_ << "View";
119 if (parent_ != nullptr) {
120 s << ", " << parent_->name_ << "View";
121 } else {
122 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
123 }
124 s << ">(m, \"" << name_ << "View\")";
125 if (parent_ != nullptr) {
126 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
127 } else {
128 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
129 }
130 s << "auto view =" << name_ << "View::Create(std::move(parent));";
131 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
132 s << "return view; }))";
133
134 s << ".def(py::init(&" << name_ << "View::Create))";
135 std::set<std::string> protected_field_types = {
136 FixedScalarField::kFieldType,
137 FixedEnumField::kFieldType,
138 SizeField::kFieldType,
139 CountField::kFieldType,
140 };
141 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
142 for (const auto& field : public_fields) {
143 auto getter_func_name = field->GetGetterFunctionName();
144 if (getter_func_name.empty()) {
145 continue;
146 }
147 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
148 }
149 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
150 s << ";\n";
151 }
152
GenParserFieldGetter(std::ostream & s,const PacketField * field) const153 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
154 // Start field offset
155 auto start_field_offset = GetOffsetForField(field->GetName(), false);
156 auto end_field_offset = GetOffsetForField(field->GetName(), true);
157
158 if (start_field_offset.empty() && end_field_offset.empty()) {
159 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
160 << "no method exists to determine field location from begin() or end().\n";
161 }
162
163 field->GenGetter(s, start_field_offset, end_field_offset);
164 }
165
GetDefinitionType() const166 TypeDef::Type PacketDef::GetDefinitionType() const {
167 return TypeDef::Type::PACKET;
168 }
169
GenValidator(std::ostream & s) const170 void PacketDef::GenValidator(std::ostream& s) const {
171 // Get the static offset for all of our fields.
172 int bits_size = 0;
173 for (const auto& field : fields_) {
174 if (field->GetFieldType() != PaddingField::kFieldType) {
175 bits_size += field->GetSize().bits();
176 }
177 }
178
179 // Write the function declaration.
180 s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
181 s << "if (was_validated_) { return true; } ";
182 s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
183 s << "}";
184
185 s << "protected:";
186 s << "virtual bool IsValid_() const {";
187
188 if (parent_ != nullptr) {
189 s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } ";
190 }
191
192 // Offset by the parents known size. We know that any dynamic fields can
193 // already be called since the parent must have already been validated by
194 // this point.
195 auto parent_size = Size(0);
196 if (parent_ != nullptr) {
197 parent_size = parent_->GetSize(true);
198 }
199
200 s << "auto it = begin() + (" << parent_size << ") / 8;";
201
202 // Check if you can extract the static fields.
203 // At this point you know you can use the size getters without crashing
204 // as long as they follow the instruction that size fields cant come before
205 // their corrisponding variable length field.
206 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
207 s << "if (it > end()) return false;";
208
209 // For any variable length fields, use their size check.
210 for (const auto& field : fields_) {
211 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
212 auto offset = GetOffsetForField(field->GetName(), false);
213 if (!offset.empty()) {
214 s << "size_t sum_index = (" << offset << ") / 8;";
215 } else {
216 offset = GetOffsetForField(field->GetName(), true);
217 if (offset.empty()) {
218 ERROR(field) << "Checksum Start Field offset can not be determined.";
219 }
220 s << "size_t sum_index = size() - (" << offset << ") / 8;";
221 }
222
223 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
224 const auto& started_field = fields_.GetField(field_name);
225 if (started_field == nullptr) {
226 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
227 << ")";
228 }
229 auto end_offset = GetOffsetForField(started_field->GetName(), false);
230 if (!end_offset.empty()) {
231 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
232 } else {
233 end_offset = GetOffsetForField(started_field->GetName(), true);
234 if (end_offset.empty()) {
235 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
236 }
237 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
238 }
239 if (is_little_endian_) {
240 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
241 } else {
242 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
243 }
244 s << started_field->GetDataType() << " checksum;";
245 s << "checksum.Initialize();";
246 s << "for (uint8_t byte : checksum_view) { ";
247 s << "checksum.AddByte(byte);}";
248 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
249 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
250
251 continue;
252 }
253
254 auto field_size = field->GetSize();
255 // Fixed size fields have already been handled.
256 if (!field_size.has_dynamic()) {
257 continue;
258 }
259
260 // Custom fields with dynamic size must have the offset for the field passed in as well
261 // as the end iterator so that they may ensure that they don't try to read past the end.
262 // Custom fields with fixed sizes will be handled in the static offset checking.
263 if (field->GetFieldType() == CustomField::kFieldType) {
264 // Check if we can determine offset from begin(), otherwise error because by this point,
265 // the size of the custom field is unknown and can't be subtracted from end() to get the
266 // offset.
267 auto offset = GetOffsetForField(field->GetName(), false);
268 if (offset.empty()) {
269 ERROR(field) << "Custom Field offset can not be determined from begin().";
270 }
271
272 if (offset.bits() % 8 != 0) {
273 ERROR(field) << "Custom fields must be byte aligned.";
274 }
275
276 // Custom fields are special as their size field takes an argument.
277 const auto& custom_size_var = field->GetName() + "_size";
278 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
279 s << "(begin() + (" << offset << ") / 8);";
280
281 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
282 s << "it += *" << custom_size_var << ";";
283 s << "if (it > end()) return false;";
284 continue;
285 } else {
286 s << "it += (" << field_size.dynamic_string() << ") / 8;";
287 s << "if (it > end()) return false;";
288 }
289 }
290
291 // Validate constraints after validating the size
292 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
293 ERROR() << "Can't have a constraint on a NULL parent";
294 }
295
296 for (const auto& constraint : parent_constraints_) {
297 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
298 const auto& field = parent_->GetParamList().GetField(constraint.first);
299 if (field->GetFieldType() == ScalarField::kFieldType) {
300 s << std::get<int64_t>(constraint.second);
301 } else {
302 s << std::get<std::string>(constraint.second);
303 }
304 s << ") return false;";
305 }
306
307 // Validate the packets fields last
308 for (const auto& field : fields_) {
309 field->GenValidator(s);
310 s << "\n";
311 }
312
313 s << "return true;";
314 s << "}\n";
315 if (parent_ == nullptr) {
316 s << "bool was_validated_{false};\n";
317 }
318 }
319
GenParserToString(std::ostream & s) const320 void PacketDef::GenParserToString(std::ostream& s) const {
321 s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
322 s << "std::stringstream ss;";
323 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
324
325 if (fields_.size() > 0) {
326 s << "ss << \"\" ";
327 bool firstfield = true;
328 for (const auto& field : fields_) {
329 if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
330 field->GetFieldType() == ChecksumStartField::kFieldType)
331 continue;
332
333 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
334
335 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
336
337 if (firstfield) {
338 firstfield = false;
339 }
340 }
341 s << ";";
342 }
343
344 s << "ss << \" }\";";
345 s << "return ss.str();";
346 s << "}\n";
347 }
348
GenBuilderDefinition(std::ostream & s) const349 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
350 s << "class " << name_ << "Builder";
351 if (parent_ != nullptr) {
352 s << " : public " << parent_->name_ << "Builder";
353 } else {
354 if (is_little_endian_) {
355 s << " : public PacketBuilder<kLittleEndian>";
356 } else {
357 s << " : public PacketBuilder<!kLittleEndian>";
358 }
359 }
360 s << " {";
361 s << " public:";
362 s << " virtual ~" << name_ << "Builder() = default;";
363
364 if (!fields_.HasBody()) {
365 GenBuilderCreate(s);
366 s << "\n";
367
368 GenTestingFromView(s);
369 s << "\n";
370 }
371
372 GenSerialize(s);
373 s << "\n";
374
375 GenSize(s);
376 s << "\n";
377
378 s << " protected:\n";
379 GenBuilderConstructor(s);
380 s << "\n";
381
382 GenBuilderParameterChecker(s);
383 s << "\n";
384
385 GenMembers(s);
386 s << "};\n";
387
388 GenTestDefine(s);
389 s << "\n";
390
391 GenFuzzTestDefine(s);
392 s << "\n";
393 }
394
GenTestingFromView(std::ostream & s) const395 void PacketDef::GenTestingFromView(std::ostream& s) const {
396 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
397
398 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
399 s << "return " << name_ << "Builder::Create(";
400 FieldList params = GetParamList().GetFieldsWithoutTypes({
401 BodyField::kFieldType,
402 });
403 for (std::size_t i = 0; i < params.size(); i++) {
404 params[i]->GenBuilderParameterFromView(s);
405 if (i != params.size() - 1) {
406 s << ", ";
407 }
408 }
409 s << ");";
410 s << "}";
411
412 s << "\n#endif\n";
413 }
414
GenBuilderDefinitionPybind11(std::ostream & s) const415 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
416 s << "py::class_<" << name_ << "Builder";
417 if (parent_ != nullptr) {
418 s << ", " << parent_->name_ << "Builder";
419 } else {
420 if (is_little_endian_) {
421 s << ", PacketBuilder<kLittleEndian>";
422 } else {
423 s << ", PacketBuilder<!kLittleEndian>";
424 }
425 }
426 s << ", std::shared_ptr<" << name_ << "Builder>";
427 s << ">(m, \"" << name_ << "Builder\")";
428 if (!fields_.HasBody()) {
429 GenBuilderCreatePybind11(s);
430 }
431 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
432 s << "std::vector<uint8_t> bytes;";
433 s << "BitInserter bi(bytes);";
434 s << "builder.Serialize(bi);";
435 s << "return bytes;})";
436 s << ";\n";
437 }
438
GenTestDefine(std::ostream & s) const439 void PacketDef::GenTestDefine(std::ostream& s) const {
440 s << "#ifdef PACKET_TESTING\n";
441 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
442 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
443 s << "public: ";
444 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
445 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
446 s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
447 s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
448 s << "ASSERT_TRUE(view.IsValid());";
449 s << "auto packet = " << name_ << "Builder::FromView(view);";
450 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
451 s << "packet_bytes->reserve(packet->size());";
452 s << "BitInserter it(*packet_bytes);";
453 s << "packet->Serialize(it);";
454 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
455 s << "}";
456 s << "};";
457 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
458 s << "CompareBytes(GetParam());";
459 s << "}";
460 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
461 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
462 int i = 0;
463 for (const auto& bytes : test_cases_) {
464 s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
465 s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
466 s << name_ << "_test_bytes_" << i << ",";
467 s << name_ << "_test_bytes_" << i << " + sizeof(";
468 s << name_ << "_test_bytes_" << i << ") - 1);";
469 i++;
470 }
471 if (!test_cases_.empty()) {
472 i = 0;
473 s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
474 for (auto bytes : test_cases_) {
475 if (i > 0) {
476 s << ",";
477 }
478 s << name_ << "_test_vec_" << i++;
479 }
480 s << ");";
481 }
482 s << "\n#endif";
483 }
484
GenFuzzTestDefine(std::ostream & s) const485 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
486 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
487 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
488 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
489 s << "auto vec = std::vector<uint8_t>(data, data + size);";
490 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
491 s << "if (!view.IsValid()) { return; }";
492 s << "auto packet = " << name_ << "Builder::FromView(view);";
493 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
494 s << "packet_bytes->reserve(packet->size());";
495 s << "BitInserter it(*packet_bytes);";
496 s << "packet->Serialize(it);";
497 s << "}";
498 s << "\n#endif\n";
499 s << "#ifdef PACKET_FUZZ_TESTING\n";
500 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
501 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
502 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
503 s << "public: ";
504 s << "explicit " << name_
505 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
506 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
507 s << "}}; ";
508 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
509 s << "\n#endif";
510 }
511
GetParametersToValidate() const512 FieldList PacketDef::GetParametersToValidate() const {
513 FieldList params_to_validate;
514 for (const auto& field : GetParamList()) {
515 if (field->HasParameterValidator()) {
516 params_to_validate.AppendField(field);
517 }
518 }
519 return params_to_validate;
520 }
521
GenBuilderCreate(std::ostream & s) const522 void PacketDef::GenBuilderCreate(std::ostream& s) const {
523 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
524
525 auto params = GetParamList();
526 for (std::size_t i = 0; i < params.size(); i++) {
527 params[i]->GenBuilderParameter(s);
528 if (i != params.size() - 1) {
529 s << ", ";
530 }
531 }
532 s << ") {";
533
534 // Call the constructor
535 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
536
537 params = params.GetFieldsWithoutTypes({
538 PayloadField::kFieldType,
539 BodyField::kFieldType,
540 });
541 // Add the parameters.
542 for (std::size_t i = 0; i < params.size(); i++) {
543 if (params[i]->BuilderParameterMustBeMoved()) {
544 s << "std::move(" << params[i]->GetName() << ")";
545 } else {
546 s << params[i]->GetName();
547 }
548 if (i != params.size() - 1) {
549 s << ", ";
550 }
551 }
552
553 s << "));";
554 if (fields_.HasPayload()) {
555 s << "builder->payload_ = std::move(payload);";
556 }
557 s << "return builder;";
558 s << "}\n";
559 }
560
GenBuilderCreatePybind11(std::ostream & s) const561 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
562 s << ".def(py::init([](";
563 auto params = GetParamList();
564 std::vector<std::string> constructor_args;
565 int i = 1;
566 for (const auto& param : params) {
567 i++;
568 std::stringstream ss;
569 auto param_type = param->GetBuilderParameterType();
570 if (param_type.empty()) {
571 continue;
572 }
573 // Use shared_ptr instead of unique_ptr for the Python interface
574 if (param->BuilderParameterMustBeMoved()) {
575 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
576 }
577 ss << param_type << " " << param->GetName();
578 constructor_args.push_back(ss.str());
579 }
580 s << util::StringJoin(",", constructor_args) << "){";
581
582 // Deal with move only args
583 for (const auto& param : params) {
584 std::stringstream ss;
585 auto param_type = param->GetBuilderParameterType();
586 if (param_type.empty()) {
587 continue;
588 }
589 if (!param->BuilderParameterMustBeMoved()) {
590 continue;
591 }
592 auto move_only_param_name = param->GetName() + "_move_only";
593 s << param_type << " " << move_only_param_name << ";";
594 if (param->IsContainerField()) {
595 // Assume single layer container and copy it
596 auto struct_type = param->GetElementField()->GetDataType();
597 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
598 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
599 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
600 // Serialize each struct
601 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
602 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
603 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
604 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
605 // Parse it again
606 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
607 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
608 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
609 // Push it into a new container
610 if (param->GetFieldType() == VectorField::kFieldType) {
611 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
612 } else if (param->GetFieldType() == ArrayField::kFieldType) {
613 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
614 } else {
615 ERROR() << param << " is not supported by Pybind11";
616 }
617 s << "}";
618 } else {
619 // Serialize the parameter and pass the bytes in a RawBuilder
620 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
621 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
622 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
623 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
624 s << move_only_param_name << " = ";
625 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
626 }
627 }
628 s << "return " << name_ << "Builder::Create(";
629 std::vector<std::string> builder_vars;
630 for (const auto& param : params) {
631 std::stringstream ss;
632 auto param_type = param->GetBuilderParameterType();
633 if (param_type.empty()) {
634 continue;
635 }
636 auto param_name = param->GetName();
637 if (param->BuilderParameterMustBeMoved()) {
638 ss << "std::move(" << param_name << "_move_only)";
639 } else {
640 ss << param_name;
641 }
642 builder_vars.push_back(ss.str());
643 }
644 s << util::StringJoin(",", builder_vars) << ");}";
645 s << "))";
646 }
647
GenBuilderParameterChecker(std::ostream & s) const648 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
649 FieldList params_to_validate = GetParametersToValidate();
650
651 // Skip writing this function if there is nothing to validate.
652 if (params_to_validate.size() == 0) {
653 return;
654 }
655
656 // Generate function arguments.
657 s << "void CheckParameterValues(";
658 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
659 params_to_validate[i]->GenBuilderParameter(s);
660 if (i != params_to_validate.size() - 1) {
661 s << ", ";
662 }
663 }
664 s << ") {";
665
666 // Check the parameters.
667 for (const auto& field : params_to_validate) {
668 field->GenParameterValidator(s);
669 }
670 s << "}\n";
671 }
672
GenBuilderConstructor(std::ostream & s) const673 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
674 s << "explicit " << name_ << "Builder(";
675
676 // Generate the constructor parameters.
677 auto params = GetParamList().GetFieldsWithoutTypes({
678 PayloadField::kFieldType,
679 BodyField::kFieldType,
680 });
681 for (std::size_t i = 0; i < params.size(); i++) {
682 params[i]->GenBuilderParameter(s);
683 if (i != params.size() - 1) {
684 s << ", ";
685 }
686 }
687 if (params.size() > 0 || parent_constraints_.size() > 0) {
688 s << ") :";
689 } else {
690 s << ")";
691 }
692
693 // Get the list of parent params to call the parent constructor with.
694 FieldList parent_params;
695 if (parent_ != nullptr) {
696 // Pass parameters to the parent constructor
697 s << parent_->name_ << "Builder(";
698 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
699 PayloadField::kFieldType,
700 BodyField::kFieldType,
701 });
702
703 // Go through all the fields and replace constrained fields with fixed values
704 // when calling the parent constructor.
705 for (std::size_t i = 0; i < parent_params.size(); i++) {
706 const auto& field = parent_params[i];
707 const auto& constraint = parent_constraints_.find(field->GetName());
708 if (constraint != parent_constraints_.end()) {
709 if (field->GetFieldType() == ScalarField::kFieldType) {
710 s << std::get<int64_t>(constraint->second);
711 } else if (field->GetFieldType() == EnumField::kFieldType) {
712 s << std::get<std::string>(constraint->second);
713 } else {
714 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
715 }
716
717 s << "/* " << field->GetName() << "_ */";
718 } else {
719 s << field->GetName();
720 }
721
722 if (i != parent_params.size() - 1) {
723 s << ", ";
724 }
725 }
726 s << ") ";
727 }
728
729 // Build a list of parameters that excludes all parent parameters.
730 FieldList saved_params;
731 for (const auto& field : params) {
732 if (parent_params.GetField(field->GetName()) == nullptr) {
733 saved_params.AppendField(field);
734 }
735 }
736 if (parent_ != nullptr && saved_params.size() > 0) {
737 s << ",";
738 }
739 for (std::size_t i = 0; i < saved_params.size(); i++) {
740 const auto& saved_param_name = saved_params[i]->GetName();
741 if (saved_params[i]->BuilderParameterMustBeMoved()) {
742 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
743 } else {
744 s << saved_param_name << "_(" << saved_param_name << ")";
745 }
746 if (i != saved_params.size() - 1) {
747 s << ",";
748 }
749 }
750 s << " {";
751
752 FieldList params_to_validate = GetParametersToValidate();
753
754 if (params_to_validate.size() > 0) {
755 s << "CheckParameterValues(";
756 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
757 s << params_to_validate[i]->GetName() << "_";
758 if (i != params_to_validate.size() - 1) {
759 s << ", ";
760 }
761 }
762 s << ");";
763 }
764
765 s << "}\n";
766 }
767
GenRustChildEnums(std::ostream & s) const768 void PacketDef::GenRustChildEnums(std::ostream& s) const {
769 if (HasChildEnums()) {
770 bool payload = fields_.HasPayload();
771 s << "#[derive(Debug)] ";
772 s << "enum " << name_ << "DataChild {";
773 for (const auto& child : children_) {
774 s << child->name_ << "(Arc<" << child->name_ << "Data>),";
775 }
776 if (payload) {
777 s << "Payload(Bytes),";
778 }
779 s << "None,";
780 s << "}\n";
781
782 s << "impl " << name_ << "DataChild {";
783 s << "fn get_total_size(&self) -> usize {";
784 s << "match self {";
785 for (const auto& child : children_) {
786 s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
787 }
788 if (payload) {
789 s << name_ << "DataChild::Payload(p) => p.len(),";
790 }
791 s << name_ << "DataChild::None => 0,";
792 s << "}\n";
793 s << "}\n";
794 s << "}\n";
795
796 s << "#[derive(Debug)] ";
797 s << "pub enum " << name_ << "Child {";
798 for (const auto& child : children_) {
799 s << child->name_ << "(" << child->name_ << "Packet),";
800 }
801 if (payload) {
802 s << "Payload(Bytes),";
803 }
804 s << "None,";
805 s << "}\n";
806 }
807 }
808
GenRustStructDeclarations(std::ostream & s) const809 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
810 s << "#[derive(Debug)] ";
811 s << "struct " << name_ << "Data {";
812
813 // Generate struct fields
814 GenRustStructFieldNameAndType(s);
815 if (HasChildEnums()) {
816 s << "child: " << name_ << "DataChild,";
817 }
818 s << "}\n";
819
820 // Generate accessor struct
821 s << "#[derive(Debug, Clone)] ";
822 s << "pub struct " << name_ << "Packet {";
823 auto lineage = GetAncestors();
824 lineage.push_back(this);
825 for (auto it = lineage.begin(); it != lineage.end(); it++) {
826 auto def = *it;
827 s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
828 }
829 s << "}\n";
830
831 // Generate builder struct
832 s << "#[derive(Debug)] ";
833 s << "pub struct " << name_ << "Builder {";
834 auto params = GetParamList().GetFieldsWithoutTypes({
835 PayloadField::kFieldType,
836 BodyField::kFieldType,
837 });
838 for (auto param : params) {
839 s << "pub ";
840 param->GenRustNameAndType(s);
841 s << ", ";
842 }
843 if (fields_.HasPayload()) {
844 s << "pub payload: Option<Bytes>,";
845 }
846 s << "}\n";
847 }
848
GenRustStructFieldNameAndType(std::ostream & s) const849 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
850 auto fields = fields_.GetFieldsWithoutTypes({
851 BodyField::kFieldType,
852 CountField::kFieldType,
853 PaddingField::kFieldType,
854 ReservedField::kFieldType,
855 SizeField::kFieldType,
856 PayloadField::kFieldType,
857 FixedScalarField::kFieldType,
858 });
859 if (fields.size() == 0) {
860 return false;
861 }
862 for (const auto& field : fields) {
863 field->GenRustNameAndType(s);
864 s << ", ";
865 }
866 return true;
867 }
868
GenRustStructFieldNames(std::ostream & s) const869 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
870 auto fields = fields_.GetFieldsWithoutTypes({
871 BodyField::kFieldType,
872 CountField::kFieldType,
873 PaddingField::kFieldType,
874 ReservedField::kFieldType,
875 SizeField::kFieldType,
876 PayloadField::kFieldType,
877 FixedScalarField::kFieldType,
878 });
879 for (const auto field : fields) {
880 s << field->GetName();
881 s << ", ";
882 }
883 }
884
GenRustStructImpls(std::ostream & s) const885 void PacketDef::GenRustStructImpls(std::ostream& s) const {
886 s << "impl " << name_ << "Data {";
887
888 // conforms function
889 s << "fn conforms(bytes: &[u8]) -> bool {";
890 GenRustConformanceCheck(s);
891
892 auto fields = fields_.GetFieldsWithTypes({
893 StructField::kFieldType,
894 });
895
896 for (auto const& field : fields) {
897 auto start_offset = GetOffsetForField(field->GetName(), false);
898 auto end_offset = GetOffsetForField(field->GetName(), true);
899
900 s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
901 s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
902 }
903
904 s << " true";
905 s << "}";
906
907 // parse function
908 if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) {
909 auto constraint = FindConstraintField();
910 auto constraint_field = GetParamList().GetField(constraint);
911 auto constraint_type = constraint_field->GetRustDataType();
912 s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {";
913 } else {
914 s << "fn parse(bytes: &[u8]) -> Result<Self> {";
915 }
916 fields = fields_.GetFieldsWithoutTypes({
917 BodyField::kFieldType,
918 });
919
920 for (auto const& field : fields) {
921 auto start_field_offset = GetOffsetForField(field->GetName(), false);
922 auto end_field_offset = GetOffsetForField(field->GetName(), true);
923
924 if (start_field_offset.empty() && end_field_offset.empty()) {
925 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
926 << "no method exists to determine field location from begin() or end().\n";
927 }
928
929 field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
930 field->GenRustGetter(s, start_field_offset, end_field_offset);
931 }
932
933 auto payload_field = fields_.GetFieldsWithTypes({
934 PayloadField::kFieldType,
935 });
936
937 Size payload_offset;
938
939 if (payload_field.HasPayload()) {
940 payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
941 }
942
943 auto constraint_name = FindConstraintField();
944 auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);
945
946 if (children_.size() > 1) {
947 s << "let child = match " << constraint_name << " {";
948
949 for (const auto& desc : constrained_descendants) {
950 auto desc_path = FindPathToDescendant(desc.first->name_);
951 std::reverse(desc_path.begin(), desc_path.end());
952 auto constraint_field = GetParamList().GetField(constraint_name);
953 auto constraint_type = constraint_field->GetFieldType();
954
955 if (constraint_type == EnumField::kFieldType) {
956 auto type = std::get<std::string>(desc.second);
957 auto variant_name = type.substr(type.find("::") + 2, type.length());
958 auto enum_type = type.substr(0, type.find("::"));
959 auto enum_variant = enum_type + "::"
960 + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
961 s << enum_variant;
962 s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])";
963 s << " => {";
964 s << name_ << "DataChild::";
965 s << desc_path[0]->name_ << "(Arc::new(";
966 if (desc_path[0]->parent_constraints_.empty()) {
967 s << desc_path[0]->name_ << "Data::parse(&bytes[..]";
968 s << ", " << enum_variant << ")?))";
969 } else {
970 s << desc_path[0]->name_ << "Data::parse(&bytes[..])?))";
971 }
972 } else if (constraint_type == ScalarField::kFieldType) {
973 s << std::get<int64_t>(desc.second) << " => {";
974 s << "unimplemented!();";
975 }
976 s << "}\n";
977 }
978
979 if (!constrained_descendants.empty()) {
980 s << "v => return Err(Error::ConstraintOutOfBounds{field: \"" << constraint_name
981 << "\".to_string(), value: v as u64}),";
982 }
983
984 s << "};\n";
985 } else if (children_.size() == 1) {
986 auto child = children_.at(0);
987 s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {";
988 s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
989 s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
990 s << " },";
991 s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
992 s << " _ => return Err(Error::InvalidPacketError),";
993 s << "};";
994 } else if (fields_.HasPayload()) {
995 s << "let child = if payload.len() > 0 {";
996 s << name_ << "DataChild::Payload(Bytes::from(payload))";
997 s << "} else {";
998 s << name_ << "DataChild::None";
999 s << "};";
1000 }
1001
1002 s << "Ok(Self {";
1003 fields = fields_.GetFieldsWithoutTypes({
1004 BodyField::kFieldType,
1005 CountField::kFieldType,
1006 PaddingField::kFieldType,
1007 ReservedField::kFieldType,
1008 SizeField::kFieldType,
1009 PayloadField::kFieldType,
1010 FixedScalarField::kFieldType,
1011 });
1012
1013 if (fields.size() > 0) {
1014 for (const auto& field : fields) {
1015 auto field_type = field->GetFieldType();
1016 s << field->GetName();
1017 s << ", ";
1018 }
1019 }
1020
1021 if (HasChildEnums()) {
1022 s << "child,";
1023 }
1024 s << "})\n";
1025 s << "}\n";
1026
1027 // write_to function
1028 s << "fn write_to(&self, buffer: &mut BytesMut) {";
1029 GenRustWriteToFields(s);
1030
1031 if (HasChildEnums()) {
1032 s << "match &self.child {";
1033 for (const auto& child : children_) {
1034 s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1035 }
1036 if (fields_.HasPayload()) {
1037 auto offset = GetOffsetForField("payload");
1038 s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1039 }
1040 s << name_ << "DataChild::None => {}";
1041 s << "}";
1042 }
1043
1044 s << "}\n";
1045
1046 s << "fn get_total_size(&self) -> usize {";
1047 if (HasChildEnums()) {
1048 s << "self.get_size() + self.child.get_total_size()";
1049 } else {
1050 s << "self.get_size()";
1051 }
1052 s << "}\n";
1053
1054 s << "fn get_size(&self) -> usize {";
1055 GenSizeRetVal(s);
1056 s << "}\n";
1057 s << "}\n";
1058 }
1059
GenRustAccessStructImpls(std::ostream & s) const1060 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1061 if (complement_ != nullptr) {
1062 auto complement_root = complement_->GetRootDef();
1063 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1064 s << "impl CommandExpectations for " << name_ << "Packet {";
1065 s << " type ResponseType = " << complement_->name_ << "Packet;";
1066 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1067 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())";
1068 s << " }";
1069 s << "}";
1070 }
1071
1072 s << "impl Packet for " << name_ << "Packet {";
1073 auto root = GetRootDef();
1074 auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1075
1076 s << "fn to_bytes(self) -> Bytes {";
1077 s << " let mut buffer = BytesMut::new();";
1078 s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1079 s << " self." << root_accessor << ".write_to(&mut buffer);";
1080 s << " buffer.freeze()";
1081 s << "}\n";
1082
1083 s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1084 s << "}";
1085
1086 s << "impl " << name_ << "Packet {";
1087 if (parent_ == nullptr) {
1088 s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1089 s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)))";
1090 s << "}";
1091 }
1092
1093 if (HasChildEnums()) {
1094 s << " pub fn specialize(&self) -> " << name_ << "Child {";
1095 s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1096 for (const auto& child : children_) {
1097 s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1098 << child->name_ << "Packet::new(self." << root_accessor << ".clone())),";
1099 }
1100 if (fields_.HasPayload()) {
1101 s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1102 }
1103 s << name_ << "DataChild::None => " << name_ << "Child::None,";
1104 s << "}}";
1105 }
1106 auto lineage = GetAncestors();
1107 lineage.push_back(this);
1108 const ParentDef* prev = nullptr;
1109
1110 s << " fn new(root: Arc<" << root->name_ << "Data>) -> Self {";
1111 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1112 auto def = *it;
1113 auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1114 if (prev == nullptr) {
1115 s << "let " << accessor_name << " = root;";
1116 } else {
1117 s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1118 s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1119 s << "_ => panic!(\"inconsistent state - child was not " << def->name_ << "\"),";
1120 s << "};";
1121 }
1122 prev = def;
1123 }
1124 s << "Self {";
1125 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1126 auto def = *it;
1127 s << util::CamelCaseToUnderScore(def->name_) << ",";
1128 }
1129 s << "}}";
1130
1131 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1132 auto def = *it;
1133 auto fields = def->fields_.GetFieldsWithoutTypes({
1134 BodyField::kFieldType,
1135 CountField::kFieldType,
1136 PaddingField::kFieldType,
1137 ReservedField::kFieldType,
1138 SizeField::kFieldType,
1139 PayloadField::kFieldType,
1140 FixedScalarField::kFieldType,
1141 });
1142
1143 for (auto const& field : fields) {
1144 if (field->GetterIsByRef()) {
1145 s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1146 s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1147 s << "}\n";
1148 } else {
1149 s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1150 s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1151 s << "}\n";
1152 }
1153 }
1154 }
1155
1156 s << "}\n";
1157
1158 lineage = GetAncestors();
1159 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1160 auto def = *it;
1161 s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1162 s << " fn into(self) -> " << def->name_ << "Packet {";
1163 s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")";
1164 s << " }";
1165 s << "}\n";
1166 }
1167 }
1168
GenRustBuilderStructImpls(std::ostream & s) const1169 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1170 if (complement_ != nullptr) {
1171 auto complement_root = complement_->GetRootDef();
1172 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1173 s << "impl CommandExpectations for " << name_ << "Builder {";
1174 s << " type ResponseType = " << complement_->name_ << "Packet;";
1175 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1176 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())";
1177 s << " }";
1178 s << "}";
1179 }
1180
1181 s << "impl " << name_ << "Builder {";
1182 s << "pub fn build(self) -> " << name_ << "Packet {";
1183 auto lineage = GetAncestors();
1184 lineage.push_back(this);
1185 std::reverse(lineage.begin(), lineage.end());
1186
1187 auto all_constraints = GetAllConstraints();
1188
1189 const ParentDef* prev = nullptr;
1190 for (auto ancestor : lineage) {
1191 auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1192 BodyField::kFieldType,
1193 CountField::kFieldType,
1194 PaddingField::kFieldType,
1195 ReservedField::kFieldType,
1196 SizeField::kFieldType,
1197 PayloadField::kFieldType,
1198 FixedScalarField::kFieldType,
1199 });
1200
1201 auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1202 s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1203 for (auto field : fields) {
1204 auto constraint = all_constraints.find(field->GetName());
1205 s << field->GetName() << ": ";
1206 if (constraint != all_constraints.end()) {
1207 if (field->GetFieldType() == ScalarField::kFieldType) {
1208 s << std::get<int64_t>(constraint->second);
1209 } else if (field->GetFieldType() == EnumField::kFieldType) {
1210 auto value = std::get<std::string>(constraint->second);
1211 auto constant = value.substr(value.find("::") + 2, std::string::npos);
1212 s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1213 ;
1214 } else {
1215 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1216 }
1217 } else {
1218 s << "self." << field->GetName();
1219 }
1220 s << ", ";
1221 }
1222 if (ancestor->HasChildEnums()) {
1223 if (prev == nullptr) {
1224 if (ancestor->fields_.HasPayload()) {
1225 s << "child: match self.payload { ";
1226 s << "None => " << name_ << "DataChild::None,";
1227 s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1228 s << "},";
1229 } else {
1230 s << "child: " << name_ << "DataChild::None,";
1231 }
1232 } else {
1233 s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1234 << util::CamelCaseToUnderScore(prev->name_) << "),";
1235 }
1236 }
1237 s << "});";
1238 prev = ancestor;
1239 }
1240
1241 s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ")";
1242 s << "}\n";
1243
1244 s << "}\n";
1245 for (const auto ancestor : GetAncestors()) {
1246 s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1247 s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1248 s << "}\n";
1249 }
1250 }
1251
GenRustBuilderTest(std::ostream & s) const1252 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1253 auto lineage = GetAncestors();
1254 lineage.push_back(this);
1255 if (!lineage.empty() && !test_cases_.empty()) {
1256 s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1257 s << "($($name:ident: $byte_string:expr,)*) => {";
1258 s << "$(";
1259 s << "\n#[test]\n";
1260 s << "pub fn $name() { ";
1261 s << "let raw_bytes = $byte_string;";
1262 for (size_t i = 0; i < lineage.size(); i++) {
1263 s << "/* (" << i << ") */\n";
1264 if (i == 0) {
1265 s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1266 s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1267 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1268 } else if (i != lineage.size() - 1) {
1269 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1270 s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1271 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1272 } else {
1273 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1274 s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1275 FieldList params = GetParamList();
1276 if (params.HasBody()) {
1277 ERROR() << "Packets with body fields can't be auto-tested. Test a child.";
1278 }
1279 for (const auto param : params) {
1280 s << param->GetName() << " : packet.";
1281 if (param->GetFieldType() == VectorField::kFieldType) {
1282 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1283 } else if (param->GetFieldType() == ArrayField::kFieldType) {
1284 const auto array_param = static_cast<const ArrayField*>(param);
1285 const auto element_field = array_param->GetElementField();
1286 if (element_field->GetFieldType() == StructField::kFieldType) {
1287 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1288 } else {
1289 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1290 }
1291 } else if (param->GetFieldType() == StructField::kFieldType) {
1292 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1293 } else {
1294 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1295 }
1296 }
1297 s << "};";
1298 s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1299 s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1300 s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1301 s << "}";
1302 }
1303 }
1304 for (size_t i = 1; i < lineage.size(); i++) {
1305 s << "_ => {";
1306 s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1307 s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1308 s << "}}}";
1309 }
1310
1311 s << ",";
1312 s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1313 s << "}";
1314 s << "}";
1315 s << ")*";
1316 s << "}";
1317 s << "}";
1318
1319 s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1320 int number = 0;
1321 for (const auto& test_case : test_cases_) {
1322 s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1323 s << std::setfill('0') << std::setw(2) << number++ << ": ";
1324 s << "b\"" << test_case << "\",";
1325 }
1326 s << "}";
1327 s << "\n";
1328 }
1329 }
1330
GenRustDef(std::ostream & s) const1331 void PacketDef::GenRustDef(std::ostream& s) const {
1332 GenRustChildEnums(s);
1333 GenRustStructDeclarations(s);
1334 GenRustStructImpls(s);
1335 GenRustAccessStructImpls(s);
1336 GenRustBuilderStructImpls(s);
1337 GenRustBuilderTest(s);
1338 }
1339