1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "packet_def.h"
18
19 #include <iomanip>
20 #include <list>
21 #include <set>
22
23 #include "fields/all_fields.h"
24 #include "packet_dependency.h"
25 #include "util.h"
26
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31 return nullptr; // Packets can't be fields
32 }
33
GenParserDefinition(std::ostream & s) const34 void PacketDef::GenParserDefinition(std::ostream& s) const {
35 s << "class " << name_ << "View";
36 if (parent_ != nullptr) {
37 s << " : public " << parent_->name_ << "View {";
38 } else {
39 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40 }
41 s << " public:";
42
43 // Specialize function
44 if (parent_ != nullptr) {
45 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46 s << "{ return " << name_ << "View(std::move(parent)); }";
47 } else {
48 s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
49 s << "{ return " << name_ << "View(std::move(packet)); }";
50 }
51
52 GenTestingParserFromBytes(s);
53
54 std::set<std::string> fixed_types = {
55 FixedScalarField::kFieldType,
56 FixedEnumField::kFieldType,
57 };
58
59 // Print all of the public fields which are all the fields minus the fixed fields.
60 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
61 bool has_fixed_fields = public_fields.size() != fields_.size();
62 for (const auto& field : public_fields) {
63 GenParserFieldGetter(s, field);
64 s << "\n";
65 }
66 GenValidator(s);
67 s << "\n";
68
69 s << " public:";
70 GenParserToString(s);
71 s << "\n";
72
73 s << " protected:\n";
74 // Constructor from a View
75 if (parent_ != nullptr) {
76 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
77 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
78 } else {
79 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
80 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
81 }
82
83 // Print the private fields which are the fixed fields.
84 if (has_fixed_fields) {
85 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
86 s << " private:\n";
87 for (const auto& field : private_fields) {
88 GenParserFieldGetter(s, field);
89 s << "\n";
90 }
91 }
92 s << "};\n";
93 }
94
GenTestingParserFromBytes(std::ostream & s) const95 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
96 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
97
98 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
99 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
100 s << "return " << name_ << "View::Create(";
101 auto ancestor_ptr = parent_;
102 size_t parent_parens = 0;
103 while (ancestor_ptr != nullptr) {
104 s << ancestor_ptr->name_ << "View::Create(";
105 parent_parens++;
106 ancestor_ptr = ancestor_ptr->parent_;
107 }
108 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
109 for (size_t i = 0; i < parent_parens; i++) {
110 s << ")";
111 }
112 s << ");";
113 s << "}";
114
115 s << "\n#endif\n";
116 }
117
GenParserDefinitionPybind11(std::ostream & s) const118 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
119 s << "py::class_<" << name_ << "View";
120 if (parent_ != nullptr) {
121 s << ", " << parent_->name_ << "View";
122 } else {
123 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
124 }
125 s << ">(m, \"" << name_ << "View\")";
126 if (parent_ != nullptr) {
127 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
128 } else {
129 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
130 }
131 s << "auto view =" << name_ << "View::Create(std::move(parent));";
132 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
133 s << "return view; }))";
134
135 s << ".def(py::init(&" << name_ << "View::Create))";
136 std::set<std::string> protected_field_types = {
137 FixedScalarField::kFieldType,
138 FixedEnumField::kFieldType,
139 SizeField::kFieldType,
140 CountField::kFieldType,
141 };
142 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
143 for (const auto& field : public_fields) {
144 auto getter_func_name = field->GetGetterFunctionName();
145 if (getter_func_name.empty()) {
146 continue;
147 }
148 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
149 }
150 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
151 s << ";\n";
152 }
153
GenParserFieldGetter(std::ostream & s,const PacketField * field) const154 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
155 // Start field offset
156 auto start_field_offset = GetOffsetForField(field->GetName(), false);
157 auto end_field_offset = GetOffsetForField(field->GetName(), true);
158
159 if (start_field_offset.empty() && end_field_offset.empty()) {
160 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
161 << "no method exists to determine field location from begin() or end().\n";
162 }
163
164 field->GenGetter(s, start_field_offset, end_field_offset);
165 }
166
GetDefinitionType() const167 TypeDef::Type PacketDef::GetDefinitionType() const {
168 return TypeDef::Type::PACKET;
169 }
170
GenValidator(std::ostream & s) const171 void PacketDef::GenValidator(std::ostream& s) const {
172 // Get the static offset for all of our fields.
173 int bits_size = 0;
174 for (const auto& field : fields_) {
175 if (field->GetFieldType() != PaddingField::kFieldType) {
176 bits_size += field->GetSize().bits();
177 }
178 }
179
180 // Write the function declaration.
181 s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
182 s << "if (was_validated_) { return true; } ";
183 s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
184 s << "}";
185
186 s << "protected:";
187 s << "virtual bool IsValid_() const {";
188
189 if (parent_ != nullptr) {
190 s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } ";
191 }
192
193 // Offset by the parents known size. We know that any dynamic fields can
194 // already be called since the parent must have already been validated by
195 // this point.
196 auto parent_size = Size(0);
197 if (parent_ != nullptr) {
198 parent_size = parent_->GetSize(true);
199 }
200
201 s << "auto it = begin() + (" << parent_size << ") / 8;";
202
203 // Check if you can extract the static fields.
204 // At this point you know you can use the size getters without crashing
205 // as long as they follow the instruction that size fields cant come before
206 // their corrisponding variable length field.
207 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
208 s << "if (it > end()) return false;";
209
210 // For any variable length fields, use their size check.
211 for (const auto& field : fields_) {
212 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
213 auto offset = GetOffsetForField(field->GetName(), false);
214 if (!offset.empty()) {
215 s << "size_t sum_index = (" << offset << ") / 8;";
216 } else {
217 offset = GetOffsetForField(field->GetName(), true);
218 if (offset.empty()) {
219 ERROR(field) << "Checksum Start Field offset can not be determined.";
220 }
221 s << "size_t sum_index = size() - (" << offset << ") / 8;";
222 }
223
224 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
225 const auto& started_field = fields_.GetField(field_name);
226 if (started_field == nullptr) {
227 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
228 << ")";
229 }
230 auto end_offset = GetOffsetForField(started_field->GetName(), false);
231 if (!end_offset.empty()) {
232 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
233 } else {
234 end_offset = GetOffsetForField(started_field->GetName(), true);
235 if (end_offset.empty()) {
236 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
237 }
238 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
239 }
240 if (is_little_endian_) {
241 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
242 } else {
243 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
244 }
245 s << started_field->GetDataType() << " checksum;";
246 s << "checksum.Initialize();";
247 s << "for (uint8_t byte : checksum_view) { ";
248 s << "checksum.AddByte(byte);}";
249 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
250 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
251
252 continue;
253 }
254
255 auto field_size = field->GetSize();
256 // Fixed size fields have already been handled.
257 if (!field_size.has_dynamic()) {
258 continue;
259 }
260
261 // Custom fields with dynamic size must have the offset for the field passed in as well
262 // as the end iterator so that they may ensure that they don't try to read past the end.
263 // Custom fields with fixed sizes will be handled in the static offset checking.
264 if (field->GetFieldType() == CustomField::kFieldType) {
265 // Check if we can determine offset from begin(), otherwise error because by this point,
266 // the size of the custom field is unknown and can't be subtracted from end() to get the
267 // offset.
268 auto offset = GetOffsetForField(field->GetName(), false);
269 if (offset.empty()) {
270 ERROR(field) << "Custom Field offset can not be determined from begin().";
271 }
272
273 if (offset.bits() % 8 != 0) {
274 ERROR(field) << "Custom fields must be byte aligned.";
275 }
276
277 // Custom fields are special as their size field takes an argument.
278 const auto& custom_size_var = field->GetName() + "_size";
279 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
280 s << "(begin() + (" << offset << ") / 8);";
281
282 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
283 s << "it += *" << custom_size_var << ";";
284 s << "if (it > end()) return false;";
285 continue;
286 } else {
287 s << "it += (" << field_size.dynamic_string() << ") / 8;";
288 s << "if (it > end()) return false;";
289 }
290 }
291
292 // Validate constraints after validating the size
293 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
294 ERROR() << "Can't have a constraint on a NULL parent";
295 }
296
297 for (const auto& constraint : parent_constraints_) {
298 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
299 const auto& field = parent_->GetParamList().GetField(constraint.first);
300 if (field->GetFieldType() == ScalarField::kFieldType) {
301 s << std::get<int64_t>(constraint.second);
302 } else {
303 s << std::get<std::string>(constraint.second);
304 }
305 s << ") return false;";
306 }
307
308 // Validate the packets fields last
309 for (const auto& field : fields_) {
310 field->GenValidator(s);
311 s << "\n";
312 }
313
314 s << "return true;";
315 s << "}\n";
316 if (parent_ == nullptr) {
317 s << "bool was_validated_{false};\n";
318 }
319 }
320
GenParserToString(std::ostream & s) const321 void PacketDef::GenParserToString(std::ostream& s) const {
322 s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
323 s << "std::stringstream ss;";
324 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
325
326 if (fields_.size() > 0) {
327 s << "ss << \"\" ";
328 bool firstfield = true;
329 for (const auto& field : fields_) {
330 if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
331 field->GetFieldType() == ChecksumStartField::kFieldType)
332 continue;
333
334 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
335
336 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
337
338 if (firstfield) {
339 firstfield = false;
340 }
341 }
342 s << ";";
343 }
344
345 s << "ss << \" }\";";
346 s << "return ss.str();";
347 s << "}\n";
348 }
349
GenBuilderDefinition(std::ostream & s) const350 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
351 s << "class " << name_ << "Builder";
352 if (parent_ != nullptr) {
353 s << " : public " << parent_->name_ << "Builder";
354 } else {
355 if (is_little_endian_) {
356 s << " : public PacketBuilder<kLittleEndian>";
357 } else {
358 s << " : public PacketBuilder<!kLittleEndian>";
359 }
360 }
361 s << " {";
362 s << " public:";
363 s << " virtual ~" << name_ << "Builder() = default;";
364
365 if (!fields_.HasBody()) {
366 GenBuilderCreate(s);
367 s << "\n";
368
369 GenTestingFromView(s);
370 s << "\n";
371 }
372
373 GenSerialize(s);
374 s << "\n";
375
376 GenSize(s);
377 s << "\n";
378
379 s << " protected:\n";
380 GenBuilderConstructor(s);
381 s << "\n";
382
383 GenBuilderParameterChecker(s);
384 s << "\n";
385
386 GenMembers(s);
387 s << "};\n";
388
389 GenTestDefine(s);
390 s << "\n";
391
392 GenFuzzTestDefine(s);
393 s << "\n";
394 }
395
GenTestingFromView(std::ostream & s) const396 void PacketDef::GenTestingFromView(std::ostream& s) const {
397 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
398
399 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
400 s << "return " << name_ << "Builder::Create(";
401 FieldList params = GetParamList().GetFieldsWithoutTypes({
402 BodyField::kFieldType,
403 });
404 for (std::size_t i = 0; i < params.size(); i++) {
405 params[i]->GenBuilderParameterFromView(s);
406 if (i != params.size() - 1) {
407 s << ", ";
408 }
409 }
410 s << ");";
411 s << "}";
412
413 s << "\n#endif\n";
414 }
415
GenBuilderDefinitionPybind11(std::ostream & s) const416 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
417 s << "py::class_<" << name_ << "Builder";
418 if (parent_ != nullptr) {
419 s << ", " << parent_->name_ << "Builder";
420 } else {
421 if (is_little_endian_) {
422 s << ", PacketBuilder<kLittleEndian>";
423 } else {
424 s << ", PacketBuilder<!kLittleEndian>";
425 }
426 }
427 s << ", std::shared_ptr<" << name_ << "Builder>";
428 s << ">(m, \"" << name_ << "Builder\")";
429 if (!fields_.HasBody()) {
430 GenBuilderCreatePybind11(s);
431 }
432 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
433 s << "std::vector<uint8_t> bytes;";
434 s << "BitInserter bi(bytes);";
435 s << "builder.Serialize(bi);";
436 s << "return bytes;})";
437 s << ";\n";
438 }
439
GenTestDefine(std::ostream & s) const440 void PacketDef::GenTestDefine(std::ostream& s) const {
441 s << "#ifdef PACKET_TESTING\n";
442 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
443 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
444 s << "public: ";
445 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
446 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
447 s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
448 s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
449 s << "ASSERT_TRUE(view.IsValid());";
450 s << "auto packet = " << name_ << "Builder::FromView(view);";
451 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
452 s << "packet_bytes->reserve(packet->size());";
453 s << "BitInserter it(*packet_bytes);";
454 s << "packet->Serialize(it);";
455 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
456 s << "}";
457 s << "};";
458 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
459 s << "CompareBytes(GetParam());";
460 s << "}";
461 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
462 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
463 int i = 0;
464 for (const auto& bytes : test_cases_) {
465 s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
466 s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
467 s << name_ << "_test_bytes_" << i << ",";
468 s << name_ << "_test_bytes_" << i << " + sizeof(";
469 s << name_ << "_test_bytes_" << i << ") - 1);";
470 i++;
471 }
472 if (!test_cases_.empty()) {
473 i = 0;
474 s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
475 for (auto bytes : test_cases_) {
476 if (i > 0) {
477 s << ",";
478 }
479 s << name_ << "_test_vec_" << i++;
480 }
481 s << ");";
482 }
483 s << "\n#endif";
484 }
485
GenFuzzTestDefine(std::ostream & s) const486 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
487 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
488 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
489 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
490 s << "auto vec = std::vector<uint8_t>(data, data + size);";
491 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
492 s << "if (!view.IsValid()) { return; }";
493 s << "auto packet = " << name_ << "Builder::FromView(view);";
494 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
495 s << "packet_bytes->reserve(packet->size());";
496 s << "BitInserter it(*packet_bytes);";
497 s << "packet->Serialize(it);";
498 s << "}";
499 s << "\n#endif\n";
500 s << "#ifdef PACKET_FUZZ_TESTING\n";
501 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
502 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
503 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
504 s << "public: ";
505 s << "explicit " << name_
506 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
507 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
508 s << "}}; ";
509 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
510 s << "\n#endif";
511 }
512
GetParametersToValidate() const513 FieldList PacketDef::GetParametersToValidate() const {
514 FieldList params_to_validate;
515 for (const auto& field : GetParamList()) {
516 if (field->HasParameterValidator()) {
517 params_to_validate.AppendField(field);
518 }
519 }
520 return params_to_validate;
521 }
522
GenBuilderCreate(std::ostream & s) const523 void PacketDef::GenBuilderCreate(std::ostream& s) const {
524 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
525
526 auto params = GetParamList();
527 for (std::size_t i = 0; i < params.size(); i++) {
528 params[i]->GenBuilderParameter(s);
529 if (i != params.size() - 1) {
530 s << ", ";
531 }
532 }
533 s << ") {";
534
535 // Call the constructor
536 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
537
538 params = params.GetFieldsWithoutTypes({
539 PayloadField::kFieldType,
540 BodyField::kFieldType,
541 });
542 // Add the parameters.
543 for (std::size_t i = 0; i < params.size(); i++) {
544 if (params[i]->BuilderParameterMustBeMoved()) {
545 s << "std::move(" << params[i]->GetName() << ")";
546 } else {
547 s << params[i]->GetName();
548 }
549 if (i != params.size() - 1) {
550 s << ", ";
551 }
552 }
553
554 s << "));";
555 if (fields_.HasPayload()) {
556 s << "builder->payload_ = std::move(payload);";
557 }
558 s << "return builder;";
559 s << "}\n";
560 }
561
GenBuilderCreatePybind11(std::ostream & s) const562 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
563 s << ".def(py::init([](";
564 auto params = GetParamList();
565 std::vector<std::string> constructor_args;
566 int i = 1;
567 for (const auto& param : params) {
568 i++;
569 std::stringstream ss;
570 auto param_type = param->GetBuilderParameterType();
571 if (param_type.empty()) {
572 continue;
573 }
574 // Use shared_ptr instead of unique_ptr for the Python interface
575 if (param->BuilderParameterMustBeMoved()) {
576 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
577 }
578 ss << param_type << " " << param->GetName();
579 constructor_args.push_back(ss.str());
580 }
581 s << util::StringJoin(",", constructor_args) << "){";
582
583 // Deal with move only args
584 for (const auto& param : params) {
585 std::stringstream ss;
586 auto param_type = param->GetBuilderParameterType();
587 if (param_type.empty()) {
588 continue;
589 }
590 if (!param->BuilderParameterMustBeMoved()) {
591 continue;
592 }
593 auto move_only_param_name = param->GetName() + "_move_only";
594 s << param_type << " " << move_only_param_name << ";";
595 if (param->IsContainerField()) {
596 // Assume single layer container and copy it
597 auto struct_type = param->GetElementField()->GetDataType();
598 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
599 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
600 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
601 // Serialize each struct
602 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
603 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
604 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
605 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
606 // Parse it again
607 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
608 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
609 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
610 // Push it into a new container
611 if (param->GetFieldType() == VectorField::kFieldType) {
612 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
613 } else if (param->GetFieldType() == ArrayField::kFieldType) {
614 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
615 } else {
616 ERROR() << param << " is not supported by Pybind11";
617 }
618 s << "}";
619 } else {
620 // Serialize the parameter and pass the bytes in a RawBuilder
621 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
622 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
623 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
624 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
625 s << move_only_param_name << " = ";
626 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
627 }
628 }
629 s << "return " << name_ << "Builder::Create(";
630 std::vector<std::string> builder_vars;
631 for (const auto& param : params) {
632 std::stringstream ss;
633 auto param_type = param->GetBuilderParameterType();
634 if (param_type.empty()) {
635 continue;
636 }
637 auto param_name = param->GetName();
638 if (param->BuilderParameterMustBeMoved()) {
639 ss << "std::move(" << param_name << "_move_only)";
640 } else {
641 ss << param_name;
642 }
643 builder_vars.push_back(ss.str());
644 }
645 s << util::StringJoin(",", builder_vars) << ");}";
646 s << "))";
647 }
648
GenBuilderParameterChecker(std::ostream & s) const649 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
650 FieldList params_to_validate = GetParametersToValidate();
651
652 // Skip writing this function if there is nothing to validate.
653 if (params_to_validate.size() == 0) {
654 return;
655 }
656
657 // Generate function arguments.
658 s << "void CheckParameterValues(";
659 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
660 params_to_validate[i]->GenBuilderParameter(s);
661 if (i != params_to_validate.size() - 1) {
662 s << ", ";
663 }
664 }
665 s << ") {";
666
667 // Check the parameters.
668 for (const auto& field : params_to_validate) {
669 field->GenParameterValidator(s);
670 }
671 s << "}\n";
672 }
673
GenBuilderConstructor(std::ostream & s) const674 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
675 s << "explicit " << name_ << "Builder(";
676
677 // Generate the constructor parameters.
678 auto params = GetParamList().GetFieldsWithoutTypes({
679 PayloadField::kFieldType,
680 BodyField::kFieldType,
681 });
682 for (std::size_t i = 0; i < params.size(); i++) {
683 params[i]->GenBuilderParameter(s);
684 if (i != params.size() - 1) {
685 s << ", ";
686 }
687 }
688 if (params.size() > 0 || parent_constraints_.size() > 0) {
689 s << ") :";
690 } else {
691 s << ")";
692 }
693
694 // Get the list of parent params to call the parent constructor with.
695 FieldList parent_params;
696 if (parent_ != nullptr) {
697 // Pass parameters to the parent constructor
698 s << parent_->name_ << "Builder(";
699 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
700 PayloadField::kFieldType,
701 BodyField::kFieldType,
702 });
703
704 // Go through all the fields and replace constrained fields with fixed values
705 // when calling the parent constructor.
706 for (std::size_t i = 0; i < parent_params.size(); i++) {
707 const auto& field = parent_params[i];
708 const auto& constraint = parent_constraints_.find(field->GetName());
709 if (constraint != parent_constraints_.end()) {
710 if (field->GetFieldType() == ScalarField::kFieldType) {
711 s << std::get<int64_t>(constraint->second);
712 } else if (field->GetFieldType() == EnumField::kFieldType) {
713 s << std::get<std::string>(constraint->second);
714 } else {
715 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
716 }
717
718 s << "/* " << field->GetName() << "_ */";
719 } else {
720 s << field->GetName();
721 }
722
723 if (i != parent_params.size() - 1) {
724 s << ", ";
725 }
726 }
727 s << ") ";
728 }
729
730 // Build a list of parameters that excludes all parent parameters.
731 FieldList saved_params;
732 for (const auto& field : params) {
733 if (parent_params.GetField(field->GetName()) == nullptr) {
734 saved_params.AppendField(field);
735 }
736 }
737 if (parent_ != nullptr && saved_params.size() > 0) {
738 s << ",";
739 }
740 for (std::size_t i = 0; i < saved_params.size(); i++) {
741 const auto& saved_param_name = saved_params[i]->GetName();
742 if (saved_params[i]->BuilderParameterMustBeMoved()) {
743 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
744 } else {
745 s << saved_param_name << "_(" << saved_param_name << ")";
746 }
747 if (i != saved_params.size() - 1) {
748 s << ",";
749 }
750 }
751 s << " {";
752
753 FieldList params_to_validate = GetParametersToValidate();
754
755 if (params_to_validate.size() > 0) {
756 s << "CheckParameterValues(";
757 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
758 s << params_to_validate[i]->GetName() << "_";
759 if (i != params_to_validate.size() - 1) {
760 s << ", ";
761 }
762 }
763 s << ");";
764 }
765
766 s << "}\n";
767 }
768
GenRustChildEnums(std::ostream & s) const769 void PacketDef::GenRustChildEnums(std::ostream& s) const {
770 if (HasChildEnums()) {
771 bool payload = fields_.HasPayload();
772 s << "#[derive(Debug)] ";
773 s << "enum " << name_ << "DataChild {";
774 for (const auto& child : children_) {
775 s << child->name_ << "(Arc<" << child->name_ << "Data>),";
776 }
777 if (payload) {
778 s << "Payload(Bytes),";
779 }
780 s << "None,";
781 s << "}\n";
782
783 s << "impl " << name_ << "DataChild {";
784 s << "fn get_total_size(&self) -> usize {";
785 s << "match self {";
786 for (const auto& child : children_) {
787 s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
788 }
789 if (payload) {
790 s << name_ << "DataChild::Payload(p) => p.len(),";
791 }
792 s << name_ << "DataChild::None => 0,";
793 s << "}\n";
794 s << "}\n";
795 s << "}\n";
796
797 s << "#[derive(Debug)] ";
798 s << "pub enum " << name_ << "Child {";
799 for (const auto& child : children_) {
800 s << child->name_ << "(" << child->name_ << "Packet),";
801 }
802 if (payload) {
803 s << "Payload(Bytes),";
804 }
805 s << "None,";
806 s << "}\n";
807 }
808 }
809
GenRustStructDeclarations(std::ostream & s) const810 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
811 s << "#[derive(Debug)] ";
812 s << "struct " << name_ << "Data {";
813
814 // Generate struct fields
815 GenRustStructFieldNameAndType(s);
816 if (HasChildEnums()) {
817 s << "child: " << name_ << "DataChild,";
818 }
819 s << "}\n";
820
821 // Generate accessor struct
822 s << "#[derive(Debug, Clone)] ";
823 s << "pub struct " << name_ << "Packet {";
824 auto lineage = GetAncestors();
825 lineage.push_back(this);
826 for (auto it = lineage.begin(); it != lineage.end(); it++) {
827 auto def = *it;
828 s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
829 }
830 s << "}\n";
831
832 // Generate builder struct
833 s << "#[derive(Debug)] ";
834 s << "pub struct " << name_ << "Builder {";
835 auto params = GetParamList().GetFieldsWithoutTypes({
836 PayloadField::kFieldType,
837 BodyField::kFieldType,
838 });
839 for (auto param : params) {
840 s << "pub ";
841 param->GenRustNameAndType(s);
842 s << ", ";
843 }
844 if (fields_.HasPayload()) {
845 s << "pub payload: Option<Bytes>,";
846 }
847 s << "}\n";
848 }
849
GenRustStructFieldNameAndType(std::ostream & s) const850 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
851 auto fields = fields_.GetFieldsWithoutTypes({
852 BodyField::kFieldType,
853 CountField::kFieldType,
854 PaddingField::kFieldType,
855 ReservedField::kFieldType,
856 SizeField::kFieldType,
857 PayloadField::kFieldType,
858 FixedScalarField::kFieldType,
859 });
860 if (fields.size() == 0) {
861 return false;
862 }
863 for (const auto& field : fields) {
864 field->GenRustNameAndType(s);
865 s << ", ";
866 }
867 return true;
868 }
869
GenRustStructFieldNames(std::ostream & s) const870 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
871 auto fields = fields_.GetFieldsWithoutTypes({
872 BodyField::kFieldType,
873 CountField::kFieldType,
874 PaddingField::kFieldType,
875 ReservedField::kFieldType,
876 SizeField::kFieldType,
877 PayloadField::kFieldType,
878 FixedScalarField::kFieldType,
879 });
880 for (const auto field : fields) {
881 s << field->GetName();
882 s << ", ";
883 }
884 }
885
GenRustStructImpls(std::ostream & s) const886 void PacketDef::GenRustStructImpls(std::ostream& s) const {
887 auto packet_dep = PacketDependency(GetRootDef());
888
889 s << "impl " << name_ << "Data {";
890 // conforms function
891 s << "fn conforms(bytes: &[u8]) -> bool {";
892 GenRustConformanceCheck(s);
893
894 auto fields = fields_.GetFieldsWithTypes({
895 StructField::kFieldType,
896 });
897
898 for (auto const& field : fields) {
899 auto start_offset = GetOffsetForField(field->GetName(), false);
900 auto end_offset = GetOffsetForField(field->GetName(), true);
901
902 s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
903 s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
904 }
905
906 s << " true";
907 s << "}";
908
909 auto parse_params = packet_dep.GetDependencies(name_);
910 s << "fn parse(bytes: &[u8]";
911 for (auto field_name : parse_params) {
912 auto constraint_field = GetParamList().GetField(field_name);
913 auto constraint_type = constraint_field->GetRustDataType();
914 s << ", " << field_name << ": " << constraint_type;
915 }
916 s << ") -> Result<Self> {";
917
918 fields = fields_.GetFieldsWithoutTypes({
919 BodyField::kFieldType,
920 });
921
922 for (auto const& field : fields) {
923 auto start_field_offset = GetOffsetForField(field->GetName(), false);
924 auto end_field_offset = GetOffsetForField(field->GetName(), true);
925
926 if (start_field_offset.empty() && end_field_offset.empty()) {
927 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
928 << "no method exists to determine field location from begin() or end().\n";
929 }
930
931 field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
932 field->GenRustGetter(s, start_field_offset, end_field_offset, name_);
933 }
934
935 auto payload_field = fields_.GetFieldsWithTypes({
936 PayloadField::kFieldType,
937 });
938
939 Size payload_offset;
940
941 if (payload_field.HasPayload()) {
942 payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
943 }
944
945 if (children_.size() > 1) {
946 auto match_on_variables = packet_dep.GetChildrenDependencies(name_);
947 // If match_on_variables is empty, this means there are multiple abstract packets which will
948 // specialize to a child down the packet tree.
949 // In this case match variables will be the union of parent fields and parse params of children.
950 if (match_on_variables.empty()) {
951 for (auto& field : fields_) {
952 if (std::any_of(children_.begin(), children_.end(), [&](auto child) {
953 auto pass_me = packet_dep.GetDependencies(child->name_);
954 return std::find(pass_me.begin(), pass_me.end(), field->GetName()) != pass_me.end();
955 })) {
956 match_on_variables.push_back(field->GetName());
957 }
958 }
959 }
960
961 s << "let child = match (";
962
963 for (auto var : match_on_variables) {
964 if (var == match_on_variables[match_on_variables.size() - 1]) {
965 s << var;
966 } else {
967 s << var << ", ";
968 }
969 }
970 s << ") {";
971
972 auto get_match_val = [&](
973 std::string& match_var,
974 std::variant<int64_t,
975 std::string> constraint) -> std::string {
976 auto constraint_field = GetParamList().GetField(match_var);
977 auto constraint_type = constraint_field->GetFieldType();
978
979 if (constraint_type == EnumField::kFieldType) {
980 auto type = std::get<std::string>(constraint);
981 auto variant_name = type.substr(type.find("::") + 2, type.length());
982 auto enum_type = type.substr(0, type.find("::"));
983 return enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
984 }
985 if (constraint_type == ScalarField::kFieldType) {
986 return std::to_string(std::get<int64_t>(constraint));
987 }
988 return "_";
989 };
990
991 for (auto& child : children_) {
992 s << "(";
993 for (auto var : match_on_variables) {
994 std::string match_val = "_";
995
996 if (child->parent_constraints_.find(var) != child->parent_constraints_.end()) {
997 match_val = get_match_val(var, child->parent_constraints_[var]);
998 } else {
999 auto dcs = child->FindDescendantsWithConstraint(var);
1000 std::vector<std::string> all_match_vals;
1001 for (auto& desc : dcs) {
1002 all_match_vals.push_back(get_match_val(var, desc.second));
1003 }
1004 match_val = "";
1005 for (std::size_t i = 0; i < all_match_vals.size(); ++i) {
1006 match_val += all_match_vals[i];
1007 if (i != all_match_vals.size() - 1) {
1008 match_val += " | ";
1009 }
1010 }
1011 match_val = (match_val == "") ? "_" : match_val;
1012 }
1013
1014 if (var == match_on_variables[match_on_variables.size() - 1]) {
1015 s << match_val << ")";
1016 } else {
1017 s << match_val << ", ";
1018 }
1019 }
1020 s << " if " << child->name_ << "Data::conforms(&bytes[..])";
1021 s << " => {";
1022 s << name_ << "DataChild::";
1023 s << child->name_ << "(Arc::new(";
1024
1025 auto child_parse_params = packet_dep.GetDependencies(child->name_);
1026 if (child_parse_params.size() == 0) {
1027 s << child->name_ << "Data::parse(&bytes[..]";
1028 } else {
1029 s << child->name_ << "Data::parse(&bytes[..], ";
1030 }
1031
1032 for (auto var : child_parse_params) {
1033 if (var == child_parse_params[child_parse_params.size() - 1]) {
1034 s << var;
1035 } else {
1036 s << var << ", ";
1037 }
1038 }
1039 s << ")?))";
1040 s << "}\n";
1041 }
1042
1043 s << "(";
1044 for (int i = 1; i <= match_on_variables.size(); i++) {
1045 if (i == match_on_variables.size()) {
1046 s << "_";
1047 } else {
1048 s << "_, ";
1049 }
1050 }
1051 s << ")";
1052 s << " => return Err(Error::InvalidPacketError),";
1053 s << "};\n";
1054 } else if (children_.size() == 1) {
1055 auto child = children_.at(0);
1056 auto params = packet_dep.GetDependencies(child->name_);
1057 s << "let child = match " << child->name_ << "Data::parse(&bytes[..]";
1058 for (auto field_name : params) {
1059 s << ", " << field_name;
1060 }
1061 s << ") {";
1062 s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
1063 s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
1064 s << " },";
1065 s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
1066 s << " _ => return Err(Error::InvalidPacketError),";
1067 s << "};";
1068 } else if (fields_.HasPayload()) {
1069 s << "let child = if payload.len() > 0 {";
1070 s << name_ << "DataChild::Payload(Bytes::from(payload))";
1071 s << "} else {";
1072 s << name_ << "DataChild::None";
1073 s << "};";
1074 }
1075
1076 s << "Ok(Self {";
1077 fields = fields_.GetFieldsWithoutTypes({
1078 BodyField::kFieldType,
1079 CountField::kFieldType,
1080 PaddingField::kFieldType,
1081 ReservedField::kFieldType,
1082 SizeField::kFieldType,
1083 PayloadField::kFieldType,
1084 FixedScalarField::kFieldType,
1085 });
1086
1087 if (fields.size() > 0) {
1088 for (const auto& field : fields) {
1089 auto field_type = field->GetFieldType();
1090 s << field->GetName();
1091 s << ", ";
1092 }
1093 }
1094
1095 if (HasChildEnums()) {
1096 s << "child,";
1097 }
1098 s << "})\n";
1099 s << "}\n";
1100
1101 // write_to function
1102 s << "fn write_to(&self, buffer: &mut BytesMut) {";
1103 GenRustWriteToFields(s);
1104
1105 if (HasChildEnums()) {
1106 s << "match &self.child {";
1107 for (const auto& child : children_) {
1108 s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1109 }
1110 if (fields_.HasPayload()) {
1111 auto offset = GetOffsetForField("payload");
1112 s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1113 }
1114 s << name_ << "DataChild::None => {}";
1115 s << "}";
1116 }
1117
1118 s << "}\n";
1119
1120 s << "fn get_total_size(&self) -> usize {";
1121 if (HasChildEnums()) {
1122 s << "self.get_size() + self.child.get_total_size()";
1123 } else {
1124 s << "self.get_size()";
1125 }
1126 s << "}\n";
1127
1128 s << "fn get_size(&self) -> usize {";
1129 GenSizeRetVal(s);
1130 s << "}\n";
1131 s << "}\n";
1132 }
1133
GenRustAccessStructImpls(std::ostream & s) const1134 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1135 if (complement_ != nullptr) {
1136 auto complement_root = complement_->GetRootDef();
1137 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1138 s << "impl CommandExpectations for " << name_ << "Packet {";
1139 s << " type ResponseType = " << complement_->name_ << "Packet;";
1140 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1141 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1142 << ".unwrap()";
1143 s << " }";
1144 s << "}";
1145 }
1146
1147 s << "impl Packet for " << name_ << "Packet {";
1148 auto root = GetRootDef();
1149 auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1150
1151 s << "fn to_bytes(self) -> Bytes {";
1152 s << " let mut buffer = BytesMut::new();";
1153 s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1154 s << " self." << root_accessor << ".write_to(&mut buffer);";
1155 s << " buffer.freeze()";
1156 s << "}\n";
1157
1158 s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1159 s << "}";
1160
1161 s << "impl From<" << name_ << "Packet"
1162 << "> for Bytes {\n";
1163 s << "fn from(packet: " << name_ << "Packet"
1164 << ") -> Self {\n";
1165 s << "packet.to_bytes()\n";
1166 s << "}\n";
1167 s << "}\n";
1168
1169 s << "impl From<" << name_ << "Packet"
1170 << "> for Vec<u8> {\n";
1171 s << "fn from(packet: " << name_ << "Packet"
1172 << ") -> Self {\n";
1173 s << "packet.to_vec()\n";
1174 s << "}\n";
1175 s << "}\n";
1176
1177 if (root != this) {
1178 s << "impl TryFrom<" << root->name_ << "Packet"
1179 << "> for " << name_ << "Packet {\n";
1180 s << "type Error = TryFromError;\n";
1181 s << "fn try_from(value: " << root->name_ << "Packet)"
1182 << " -> std::result::Result<Self, Self::Error> {\n";
1183 s << "Self::new(value." << root_accessor << ").map_err(TryFromError)\n", s << "}\n";
1184 s << "}\n";
1185 }
1186
1187 s << "impl " << name_ << "Packet {";
1188 if (parent_ == nullptr) {
1189 s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1190 s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)).unwrap())";
1191 s << "}";
1192 }
1193
1194 if (HasChildEnums()) {
1195 s << " pub fn specialize(&self) -> " << name_ << "Child {";
1196 s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1197 for (const auto& child : children_) {
1198 s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1199 << child->name_ << "Packet::new(self." << root_accessor << ".clone()).unwrap()),";
1200 }
1201 if (fields_.HasPayload()) {
1202 s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1203 }
1204 s << name_ << "DataChild::None => " << name_ << "Child::None,";
1205 s << "}}";
1206 }
1207 auto lineage = GetAncestors();
1208 lineage.push_back(this);
1209 const ParentDef* prev = nullptr;
1210
1211 s << " fn new(root: Arc<" << root->name_ << "Data>) -> std::result::Result<Self, &'static str> {";
1212 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1213 auto def = *it;
1214 auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1215 if (prev == nullptr) {
1216 s << "let " << accessor_name << " = root;";
1217 } else {
1218 s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1219 s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1220 s << "_ => return Err(\"inconsistent state - child was not " << def->name_ << "\"),";
1221 s << "};";
1222 }
1223 prev = def;
1224 }
1225 s << "Ok(Self {";
1226 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1227 auto def = *it;
1228 s << util::CamelCaseToUnderScore(def->name_) << ",";
1229 }
1230 s << "})}";
1231
1232 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1233 auto def = *it;
1234 auto fields = def->fields_.GetFieldsWithoutTypes({
1235 BodyField::kFieldType,
1236 CountField::kFieldType,
1237 PaddingField::kFieldType,
1238 ReservedField::kFieldType,
1239 SizeField::kFieldType,
1240 PayloadField::kFieldType,
1241 FixedScalarField::kFieldType,
1242 });
1243
1244 for (auto const& field : fields) {
1245 if (field->GetterIsByRef()) {
1246 s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1247 s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1248 s << "}\n";
1249 } else {
1250 s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1251 s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1252 s << "}\n";
1253 }
1254 }
1255 }
1256
1257 s << "}\n";
1258
1259 lineage = GetAncestors();
1260 for (auto it = lineage.begin(); it != lineage.end(); it++) {
1261 auto def = *it;
1262 s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1263 s << " fn into(self) -> " << def->name_ << "Packet {";
1264 s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"
1265 << ".unwrap()";
1266 s << " }";
1267 s << "}\n";
1268 }
1269 }
1270
GenRustBuilderStructImpls(std::ostream & s) const1271 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1272 if (complement_ != nullptr) {
1273 auto complement_root = complement_->GetRootDef();
1274 auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1275 s << "impl CommandExpectations for " << name_ << "Builder {";
1276 s << " type ResponseType = " << complement_->name_ << "Packet;";
1277 s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1278 s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"
1279 << ".unwrap()";
1280 s << " }";
1281 s << "}";
1282 }
1283
1284 s << "impl " << name_ << "Builder {";
1285 s << "pub fn build(self) -> " << name_ << "Packet {";
1286 auto lineage = GetAncestors();
1287 lineage.push_back(this);
1288 std::reverse(lineage.begin(), lineage.end());
1289
1290 auto all_constraints = GetAllConstraints();
1291
1292 const ParentDef* prev = nullptr;
1293 for (auto ancestor : lineage) {
1294 auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1295 BodyField::kFieldType,
1296 CountField::kFieldType,
1297 PaddingField::kFieldType,
1298 ReservedField::kFieldType,
1299 SizeField::kFieldType,
1300 PayloadField::kFieldType,
1301 FixedScalarField::kFieldType,
1302 });
1303
1304 auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1305 s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1306 for (auto field : fields) {
1307 auto constraint = all_constraints.find(field->GetName());
1308 s << field->GetName() << ": ";
1309 if (constraint != all_constraints.end()) {
1310 if (field->GetFieldType() == ScalarField::kFieldType) {
1311 s << std::get<int64_t>(constraint->second);
1312 } else if (field->GetFieldType() == EnumField::kFieldType) {
1313 auto value = std::get<std::string>(constraint->second);
1314 auto constant = value.substr(value.find("::") + 2, std::string::npos);
1315 s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1316 ;
1317 } else {
1318 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1319 }
1320 } else {
1321 s << "self." << field->GetName();
1322 }
1323 s << ", ";
1324 }
1325 if (ancestor->HasChildEnums()) {
1326 if (prev == nullptr) {
1327 if (ancestor->fields_.HasPayload()) {
1328 s << "child: match self.payload { ";
1329 s << "None => " << name_ << "DataChild::None,";
1330 s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1331 s << "},";
1332 } else {
1333 s << "child: " << name_ << "DataChild::None,";
1334 }
1335 } else {
1336 s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1337 << util::CamelCaseToUnderScore(prev->name_) << "),";
1338 }
1339 }
1340 s << "});";
1341 prev = ancestor;
1342 }
1343
1344 s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ").unwrap()";
1345 s << "}\n";
1346
1347 s << "}\n";
1348 for (const auto ancestor : GetAncestors()) {
1349 s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1350 s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1351 s << "}\n";
1352 }
1353 }
1354
GenRustBuilderTest(std::ostream & s) const1355 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1356 auto lineage = GetAncestors();
1357 lineage.push_back(this);
1358 if (!lineage.empty() && !test_cases_.empty()) {
1359 s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1360 s << "($($name:ident: $byte_string:expr,)*) => {";
1361 s << "$(";
1362 s << "\n#[test]\n";
1363 s << "pub fn $name() { ";
1364 s << "let raw_bytes = $byte_string;";
1365 for (size_t i = 0; i < lineage.size(); i++) {
1366 s << "/* (" << i << ") */\n";
1367 if (i == 0) {
1368 s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1369 s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1370 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1371 } else if (i != lineage.size() - 1) {
1372 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1373 s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1374 s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1375 } else {
1376 s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1377 s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1378 FieldList params = GetParamList();
1379 if (params.HasBody()) {
1380 ERROR() << "Packets with body fields can't be auto-tested. Test a child.";
1381 }
1382 for (const auto param : params) {
1383 s << param->GetName() << " : packet.";
1384 if (param->GetFieldType() == VectorField::kFieldType) {
1385 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1386 } else if (param->GetFieldType() == ArrayField::kFieldType) {
1387 const auto array_param = static_cast<const ArrayField*>(param);
1388 const auto element_field = array_param->GetElementField();
1389 if (element_field->GetFieldType() == StructField::kFieldType) {
1390 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1391 } else {
1392 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1393 }
1394 } else if (param->GetFieldType() == StructField::kFieldType) {
1395 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1396 } else {
1397 s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1398 }
1399 }
1400 s << "};";
1401 s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1402 s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1403 s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1404 s << "}";
1405 }
1406 }
1407 for (size_t i = 1; i < lineage.size(); i++) {
1408 s << "_ => {";
1409 s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1410 s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1411 s << "}}}";
1412 }
1413
1414 s << ",";
1415 s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1416 s << "}";
1417 s << "}";
1418 s << ")*";
1419 s << "}";
1420 s << "}";
1421
1422 s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1423 int number = 0;
1424 for (const auto& test_case : test_cases_) {
1425 s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1426 s << std::setfill('0') << std::setw(2) << number++ << ": ";
1427 s << "b\"" << test_case << "\",";
1428 }
1429 s << "}";
1430 s << "\n";
1431 }
1432 }
1433
GenRustDef(std::ostream & s) const1434 void PacketDef::GenRustDef(std::ostream& s) const {
1435 GenRustChildEnums(s);
1436 GenRustStructDeclarations(s);
1437 GenRustStructImpls(s);
1438 GenRustAccessStructImpls(s);
1439 GenRustBuilderStructImpls(s);
1440 GenRustBuilderTest(s);
1441 }
1442