/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "packet_def.h"

#include <algorithm>
#include <set>
#include <sstream>
#include <string>
#include <variant>
#include <vector>

#include "fields/all_fields.h"
#include "util.h"

PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}

PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
  return nullptr;  // Packets can't be fields
}

// Emits the C++ `<Name>View` parser class: a static Create() factory, getters for
// every non-fixed field, IsValid()/ToString(), a protected constructor from the
// parent view (or from a raw PacketView for root packets) and private getters for
// the fixed scalar/enum fields.
void PacketDef::GenParserDefinition(std::ostream& s) const {
  s << "class " << name_ << "View";
  if (parent_ != nullptr) {
    s << " : public " << parent_->name_ << "View {";
  } else {
    s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
  }
  s << " public:";

  // Specialize function
  if (parent_ != nullptr) {
    s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
    s << "{ return " << name_ << "View(std::move(parent)); }";
  } else {
    s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!")
      << "kLittleEndian> packet) ";
    s << "{ return " << name_ << "View(std::move(packet)); }";
  }

  GenTestingParserFromBytes(s);

  std::set<std::string> fixed_types = {
      FixedScalarField::kFieldType,
      FixedEnumField::kFieldType,
  };

  // Print all of the public fields which are all the fields minus the fixed fields.
  const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
  bool has_fixed_fields = public_fields.size() != fields_.size();
  for (const auto& field : public_fields) {
    GenParserFieldGetter(s, field);
    s << "\n";
  }
  GenValidator(s);
  s << "\n";

  s << " public:";
  GenParserToString(s);
  s << "\n";

  s << " protected:\n";
  // Constructor from a View
  if (parent_ != nullptr) {
    s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
    s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
  } else {
    s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
    s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
  }

  // Print the private fields which are the fixed fields.
  if (has_fixed_fields) {
    const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
    s << " private:\n";
    for (const auto& field : private_fields) {
      GenParserFieldGetter(s, field);
      s << "\n";
    }
  }
  s << "};\n";
}

// Emits a testing/fuzzing-only `FromBytes()` helper. It wraps the bytes in a root
// PacketView and specializes it through every ancestor View's Create(), innermost
// (root) first, ending with this packet's Create().
void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
  s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";

  s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
  s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
  s << "return " << name_ << "View::Create(";
  auto ancestor_ptr = parent_;
  size_t parent_parens = 0;
  while (ancestor_ptr != nullptr) {
    s << ancestor_ptr->name_ << "View::Create(";
    parent_parens++;
    ancestor_ptr = ancestor_ptr->parent_;
  }
  s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
  for (size_t i = 0; i < parent_parens; i++) {
    s << ")";
  }
  s << ");";
  s << "}";

  s << "\n#endif\n";
}

// Emits the pybind11 binding for `<Name>View`: a validating py::init from the parent
// view, a py::init bound to Create(), one .def per public getter (fixed, size and
// count fields are excluded) and IsValid().
void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
  s << "py::class_<" << name_ << "View";
  if (parent_ != nullptr) {
    s << ", " << parent_->name_ << "View";
  } else {
    s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
  }
  s << ">(m, \"" << name_ << "View\")";
  if (parent_ != nullptr) {
    s << ".def(py::init([](" << parent_->name_ << "View parent) {";
  } else {
    s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
  }
  s << "auto view =" << name_ << "View::Create(std::move(parent));";
  s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
  s << "return view; }))";

  s << ".def(py::init(&" << name_ << "View::Create))";
  std::set<std::string> protected_field_types = {
      FixedScalarField::kFieldType,
      FixedEnumField::kFieldType,
      SizeField::kFieldType,
      CountField::kFieldType,
  };
  const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
  for (const auto& field : public_fields) {
    auto getter_func_name = field->GetGetterFunctionName();
    if (getter_func_name.empty()) {
      continue;
    }
    s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
  }
  s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
  s << ";\n";
}

// Emits the getter for `field`, passing both its offset from begin() and from end().
// It is an error when neither offset can be determined.
void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
  // Start field offset
  auto start_field_offset = GetOffsetForField(field->GetName(), false);
  auto end_field_offset = GetOffsetForField(field->GetName(), true);

  if (start_field_offset.empty() && end_field_offset.empty()) {
    ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
                 << "no method exists to determine field location from begin() or end().\n";
  }

  field->GenGetter(s, start_field_offset, end_field_offset);
}

TypeDef::Type PacketDef::GetDefinitionType() const {
  return TypeDef::Type::PACKET;
}

// Emits `IsValid()` (caches the result in was_validated_) and the protected
// `IsValid_()` that does the actual work, in this order:
//   1. The parent's IsValid_() (when there is a parent).
//   2. A bounds check over the parent's static size plus this packet's fixed fields.
//   3. Checksum verification for each ChecksumStartField.
//   4. A bounds check for each dynamically sized field.
//   5. The parent constraints.
//   6. Each field's own validator.
// Root packets also get the `was_validated_` member.
void PacketDef::GenValidator(std::ostream& s) const {
  // Get the static offset for all of our fields.
  int bits_size = 0;
  for (const auto& field : fields_) {
    if (field->GetFieldType() != PaddingField::kFieldType) {
      bits_size += field->GetSize().bits();
    }
  }

  // Write the function declaration.
  s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
  s << "if (was_validated_) { return true; } ";
  s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
  s << "}";

  s << "protected:";
  s << "virtual bool IsValid_() const {";

  if (parent_ != nullptr) {
    s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } ";
  }

  // Offset by the parents known size. We know that any dynamic fields can
  // already be called since the parent must have already been validated by
  // this point.
  auto parent_size = Size(0);
  if (parent_ != nullptr) {
    parent_size = parent_->GetSize(true);
  }

  s << "auto it = begin() + (" << parent_size << ") / 8;";

  // Check if you can extract the static fields.
  // At this point you know you can use the size getters without crashing
  // as long as they follow the instruction that size fields cant come before
  // their corresponding variable length field.
  s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
  s << "if (it > end()) return false;";

  // For any variable length fields, use their size check.
  for (const auto& field : fields_) {
    if (field->GetFieldType() == ChecksumStartField::kFieldType) {
      auto offset = GetOffsetForField(field->GetName(), false);
      if (!offset.empty()) {
        s << "size_t sum_index = (" << offset << ") / 8;";
      } else {
        offset = GetOffsetForField(field->GetName(), true);
        if (offset.empty()) {
          ERROR(field) << "Checksum Start Field offset can not be determined.";
        }
        s << "size_t sum_index = size() - (" << offset << ") / 8;";
      }

      const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
      const auto& started_field = fields_.GetField(field_name);
      if (started_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
                     << ")";
      }
      auto end_offset = GetOffsetForField(started_field->GetName(), false);
      if (!end_offset.empty()) {
        s << "size_t end_sum_index = (" << end_offset << ") / 8;";
      } else {
        end_offset = GetOffsetForField(started_field->GetName(), true);
        if (end_offset.empty()) {
          ERROR(started_field) << "Checksum Field end_offset can not be determined.";
        }
        s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
      }
      if (is_little_endian_) {
        s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
      } else {
        s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
      }
      s << started_field->GetDataType() << " checksum;";
      s << "checksum.Initialize();";
      s << "for (uint8_t byte : checksum_view) { ";
      s << "checksum.AddByte(byte);}";
      s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
        << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";

      continue;
    }

    auto field_size = field->GetSize();
    // Fixed size fields have already been handled.
    if (!field_size.has_dynamic()) {
      continue;
    }

    // Custom fields with dynamic size must have the offset for the field passed in as well
    // as the end iterator so that they may ensure that they don't try to read past the end.
    // Custom fields with fixed sizes will be handled in the static offset checking.
    if (field->GetFieldType() == CustomField::kFieldType) {
      // Check if we can determine offset from begin(), otherwise error because by this point,
      // the size of the custom field is unknown and can't be subtracted from end() to get the
      // offset.
      auto offset = GetOffsetForField(field->GetName(), false);
      if (offset.empty()) {
        ERROR(field) << "Custom Field offset can not be determined from begin().";
      }

      if (offset.bits() % 8 != 0) {
        ERROR(field) << "Custom fields must be byte aligned.";
      }

      // Custom fields are special as their size field takes an argument.
      const auto& custom_size_var = field->GetName() + "_size";
      s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
      s << "(begin() + (" << offset << ") / 8);";

      s << "if (!" << custom_size_var << ".has_value()) { return false; }";
      s << "it += *" << custom_size_var << ";";
      s << "if (it > end()) return false;";
    } else {
      s << "it += (" << field_size.dynamic_string() << ") / 8;";
      s << "if (it > end()) return false;";
    }
  }

  // Validate constraints after validating the size
  if (parent_constraints_.size() > 0 && parent_ == nullptr) {
    ERROR() << "Can't have a constraint on a NULL parent";
  }

  for (const auto& constraint : parent_constraints_) {
    s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
    const auto& field = parent_->GetParamList().GetField(constraint.first);
    if (field->GetFieldType() == ScalarField::kFieldType) {
      s << std::get<int64_t>(constraint.second);
    } else {
      s << std::get<std::string>(constraint.second);
    }
    s << ") return false;";
  }

  // Validate the packets fields last
  for (const auto& field : fields_) {
    field->GenValidator(s);
    s << "\n";
  }

  s << "return true;";
  s << "}\n";
  if (parent_ == nullptr) {
    s << "bool was_validated_{false};\n";
  }
}

// Emits `ToString()`, which prints "<Name> { field = value, ... }" in hex.
// Reserved, fixed-scalar and checksum-start fields are skipped.
void PacketDef::GenParserToString(std::ostream& s) const {
  s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
  s << "std::stringstream ss;";
  s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";

  if (fields_.size() > 0) {
    s << "ss << \"\" ";
    bool firstfield = true;
    for (const auto& field : fields_) {
      if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
          field->GetFieldType() == ChecksumStartField::kFieldType) {
        continue;
      }

      // The first printed field has no leading ", " separator.
      s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";

      field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");

      firstfield = false;
    }
    s << ";";
  }

  s << "ss << \" }\";";
  s << "return ss.str();";
  s << "}\n";
}

// Emits the C++ `<Name>Builder` class: Create()/FromView() when the packet has no
// body, then Serialize/Size, the protected constructor, parameter checking and
// members. The reflection-test and fuzz-test macros follow the class.
void PacketDef::GenBuilderDefinition(std::ostream& s) const {
  s << "class " << name_ << "Builder";
  if (parent_ != nullptr) {
    s << " : public " << parent_->name_ << "Builder";
  } else {
    if (is_little_endian_) {
      s << " : public PacketBuilder<kLittleEndian>";
    } else {
      s << " : public PacketBuilder<!kLittleEndian>";
    }
  }
  s << " {";
  s << " public:";
  s << " virtual ~" << name_ << "Builder() = default;";

  if (!fields_.HasBody()) {
    GenBuilderCreate(s);
    s << "\n";

    GenTestingFromView(s);
    s << "\n";
  }

  GenSerialize(s);
  s << "\n";

  GenSize(s);
  s << "\n";

  s << " protected:\n";
  GenBuilderConstructor(s);
  s << "\n";

  GenBuilderParameterChecker(s);
  s << "\n";

  GenMembers(s);
  s << "};\n";

  GenTestDefine(s);
  s << "\n";

  GenFuzzTestDefine(s);
  s << "\n";
}

// Emits a testing/fuzzing-only `FromView()` that rebuilds a builder by calling
// Create() with every non-body parameter read back from a parsed view.
void PacketDef::GenTestingFromView(std::ostream& s) const {
  s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";

  s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
  s << "return " << name_ << "Builder::Create(";
  FieldList params = GetParamList().GetFieldsWithoutTypes({
      BodyField::kFieldType,
  });
  for (std::size_t i = 0; i < params.size(); i++) {
    params[i]->GenBuilderParameterFromView(s);
    if (i != params.size() - 1) {
      s << ", ";
    }
  }
  s << ");";
  s << "}";

  s << "\n#endif\n";
}

// Emits the pybind11 binding for `<Name>Builder`, held by shared_ptr.
// It includes a py::init mirroring Create() (when there is no body) and a
// Serialize() that returns the serialized bytes.
void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
  s << "py::class_<" << name_ << "Builder";
  if (parent_ != nullptr) {
    s << ", " << parent_->name_ << "Builder";
  } else {
    if (is_little_endian_) {
      s << ", PacketBuilder<kLittleEndian>";
    } else {
      s << ", PacketBuilder<!kLittleEndian>";
    }
  }
  s << ", std::shared_ptr<" << name_ << "Builder>";
  s << ">(m, \"" << name_ << "Builder\")";
  if (!fields_.HasBody()) {
    GenBuilderCreatePybind11(s);
  }
  s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
  s << "std::vector<uint8_t> bytes;";
  s << "BitInserter bi(bytes);";
  s << "builder.Serialize(bi);";
  s << "return bytes;})";
  s << ";\n";
}

// Emits, under PACKET_TESTING, the macro DEFINE_AND_INSTANTIATE_<Name>ReflectionTest.
// It defines a parameterized test that parses the bytes, rebuilds the packet via
// FromView(), re-serializes it and asserts the bytes round-trip unchanged. One byte
// vector is emitted per entry in test_cases_, and the macro is invoked with all of them.
void PacketDef::GenTestDefine(std::ostream& s) const {
  s << "#ifdef PACKET_TESTING\n";
  s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
  s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
  s << "public: ";
  s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
  s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
  s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
  // `i` is a size_t in the generated loop, so it must be printed with %zu (not %zd).
  s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zu:%02X\", i, *(view.begin() + i)); }}";
  s << "ASSERT_TRUE(view.IsValid());";
  s << "auto packet = " << name_ << "Builder::FromView(view);";
  s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
  s << "packet_bytes->reserve(packet->size());";
  s << "BitInserter it(*packet_bytes);";
  s << "packet->Serialize(it);";
  s << "ASSERT_EQ(*packet_bytes, captured_packet);";
  s << "}";
  s << "};";
  s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
  s << "CompareBytes(GetParam());";
  s << "}";
  s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
  s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
  int i = 0;
  for (const auto& bytes : test_cases_) {
    s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
    // `- 1` drops the NUL terminator the string literal adds.
    s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
    s << name_ << "_test_bytes_" << i << ",";
    s << name_ << "_test_bytes_" << i << " + sizeof(";
    s << name_ << "_test_bytes_" << i << ") - 1);";
    i++;
  }
  if (!test_cases_.empty()) {
    s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
    for (std::size_t index = 0; index < test_cases_.size(); index++) {
      if (index > 0) {
        s << ",";
      }
      s << name_ << "_test_vec_" << index;
    }
    s << ");";
  }
  s << "\n#endif";
}

// Emits the fuzz-test macros. DEFINE_<Name>ReflectionFuzzTest() defines
// Run<Name>ReflectionFuzzTest(data, size), which parses, rebuilds and re-serializes
// the input. DEFINE_AND_REGISTER_<Name>ReflectionFuzzTest(REGISTRY) also defines a
// registrant object that pushes that function into REGISTRY.
void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
  s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
  s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
  s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
  s << "auto vec = std::vector<uint8_t>(data, data + size);";
  s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
  s << "if (!view.IsValid()) { return; }";
  s << "auto packet = " << name_ << "Builder::FromView(view);";
  s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
  s << "packet_bytes->reserve(packet->size());";
  s << "BitInserter it(*packet_bytes);";
  s << "packet->Serialize(it);";
  s << "}";
  s << "\n#endif\n";
  s << "#ifdef PACKET_FUZZ_TESTING\n";
  s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
  s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
  s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
  s << "public: ";
  // The registry holds pointers to functions with the Run...FuzzTest signature above.
  s << "explicit " << name_
    << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
  s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
  s << "}}; ";
  s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
  s << "\n#endif";
}

// Returns the builder parameters that report HasParameterValidator().
FieldList PacketDef::GetParametersToValidate() const {
  FieldList params_to_validate;
  for (const auto& field : GetParamList()) {
    if (field->HasParameterValidator()) {
      params_to_validate.AppendField(field);
    }
  }

  return params_to_validate;
}

// Emits the static `Create()` factory. It takes every parameter, forwards the
// non-payload/non-body ones to the constructor (moving move-only ones) and
// attaches the payload when the packet has one.
void PacketDef::GenBuilderCreate(std::ostream& s) const {
  s << "static std::unique_ptr<" << name_ << "Builder> Create(";

  auto params = GetParamList();
  for (std::size_t i = 0; i < params.size(); i++) {
    params[i]->GenBuilderParameter(s);
    if (i != params.size() - 1) {
      s << ", ";
    }
  }
  s << ") {";

  // Call the constructor
  s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";

  params = params.GetFieldsWithoutTypes({
      PayloadField::kFieldType,
      BodyField::kFieldType,
  });
  // Add the parameters.
  for (std::size_t i = 0; i < params.size(); i++) {
    if (params[i]->BuilderParameterMustBeMoved()) {
      s << "std::move(" << params[i]->GetName() << ")";
    } else {
      s << params[i]->GetName();
    }
    if (i != params.size() - 1) {
      s << ", ";
    }
  }

  s << "));";
  if (fields_.HasPayload()) {
    s << "builder->payload_ = std::move(payload);";
  }
  s << "return builder;";
  s << "}\n";
}

// Emits a pybind11 `py::init` that mirrors Create(). Python cannot hand over
// unique_ptr ownership, so move-only parameters are taken as shared_ptr. They are
// then serialized and reconstructed: container elements are reparsed, and other
// values are wrapped in a RawBuilder before being moved into Create().
void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
  s << ".def(py::init([](";
  auto params = GetParamList();
  std::vector<std::string> constructor_args;
  for (const auto& param : params) {
    std::stringstream ss;
    auto param_type = param->GetBuilderParameterType();
    if (param_type.empty()) {
      continue;
    }
    // Use shared_ptr instead of unique_ptr for the Python interface
    if (param->BuilderParameterMustBeMoved()) {
      param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
    }
    ss << param_type << " " << param->GetName();
    constructor_args.push_back(ss.str());
  }
  s << util::StringJoin(",", constructor_args) << "){";

  // Deal with move only args
  for (const auto& param : params) {
    auto param_type = param->GetBuilderParameterType();
    if (param_type.empty()) {
      continue;
    }
    if (!param->BuilderParameterMustBeMoved()) {
      continue;
    }
    auto move_only_param_name = param->GetName() + "_move_only";
    s << param_type << " " << move_only_param_name << ";";
    if (param->IsContainerField()) {
      // Assume single layer container and copy it
      auto struct_type = param->GetElementField()->GetDataType();
      struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
      struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
      s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
      // Serialize each struct
      s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
      s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
      s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
      s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
      // Parse it again, using the same endianness as every other view this packet emits.
      s << "auto " << param->GetName() << "_view = PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>("
        << param->GetName() << "_bytes);";
      s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
      s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
      // Push it into a new container
      if (param->GetFieldType() == VectorField::kFieldType) {
        s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
      } else if (param->GetFieldType() == ArrayField::kFieldType) {
        s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
      } else {
        // Report the field by name (streaming the pointer would only print an address).
        ERROR(param) << param->GetName() << " is not supported by Pybind11";
      }
      s << "}";
    } else {
      // Serialize the parameter and pass the bytes in a RawBuilder
      s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
      s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
      s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
      s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
      s << move_only_param_name << " = ";
      s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
    }
  }
  s << "return " << name_ << "Builder::Create(";
  std::vector<std::string> builder_vars;
  for (const auto& param : params) {
    std::stringstream ss;
    auto param_type = param->GetBuilderParameterType();
    if (param_type.empty()) {
      continue;
    }
    auto param_name = param->GetName();
    if (param->BuilderParameterMustBeMoved()) {
      ss << "std::move(" << param_name << "_move_only)";
    } else {
      ss << param_name;
    }
    builder_vars.push_back(ss.str());
  }
  s << util::StringJoin(",", builder_vars) << ");}";
  s << "))";
}

// Emits `CheckParameterValues(...)`, which runs each parameter's validator.
// Nothing is emitted when no parameter needs validation.
void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
  FieldList params_to_validate = GetParametersToValidate();

  // Skip writing this function if there is nothing to validate.
  if (params_to_validate.size() == 0) {
    return;
  }

  // Generate function arguments.
  s << "void CheckParameterValues(";
  for (std::size_t i = 0; i < params_to_validate.size(); i++) {
    params_to_validate[i]->GenBuilderParameter(s);
    if (i != params_to_validate.size() - 1) {
      s << ", ";
    }
  }
  s << ") {";

  // Check the parameters.
  for (const auto& field : params_to_validate) {
    field->GenParameterValidator(s);
  }
  s << "}\n";
}

// Emits the protected builder constructor. It calls the parent builder's
// constructor (substituting constrained parent parameters with their fixed values),
// initializes this packet's own `<param>_` members, and calls CheckParameterValues()
// when some parameters need validation.
void PacketDef::GenBuilderConstructor(std::ostream& s) const {
  s << "explicit " << name_ << "Builder(";

  // Generate the constructor parameters.
  auto params = GetParamList().GetFieldsWithoutTypes({
      PayloadField::kFieldType,
      BodyField::kFieldType,
  });
  for (std::size_t i = 0; i < params.size(); i++) {
    params[i]->GenBuilderParameter(s);
    if (i != params.size() - 1) {
      s << ", ";
    }
  }
  // A member-initializer list follows whenever a parent constructor call or a member
  // initializer will be emitted. The parent call is emitted even when it takes no
  // arguments, so the ':' must not depend on params/constraints alone.
  if (parent_ != nullptr || params.size() > 0) {
    s << ") :";
  } else {
    s << ")";
  }

  // Get the list of parent params to call the parent constructor with.
  FieldList parent_params;
  if (parent_ != nullptr) {
    // Pass parameters to the parent constructor
    s << parent_->name_ << "Builder(";
    parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
        PayloadField::kFieldType,
        BodyField::kFieldType,
    });

    // Go through all the fields and replace constrained fields with fixed values
    // when calling the parent constructor.
    for (std::size_t i = 0; i < parent_params.size(); i++) {
      const auto& field = parent_params[i];
      const auto& constraint = parent_constraints_.find(field->GetName());
      if (constraint != parent_constraints_.end()) {
        if (field->GetFieldType() == ScalarField::kFieldType) {
          s << std::get<int64_t>(constraint->second);
        } else if (field->GetFieldType() == EnumField::kFieldType) {
          s << std::get<std::string>(constraint->second);
        } else {
          ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
        }

        s << "/* " << field->GetName() << "_ */";
      } else {
        s << field->GetName();
      }

      if (i != parent_params.size() - 1) {
        s << ", ";
      }
    }
    s << ") ";
  }

  // Build a list of parameters that excludes all parent parameters.
  FieldList saved_params;
  for (const auto& field : params) {
    if (parent_params.GetField(field->GetName()) == nullptr) {
      saved_params.AppendField(field);
    }
  }
  if (parent_ != nullptr && saved_params.size() > 0) {
    s << ",";
  }
  for (std::size_t i = 0; i < saved_params.size(); i++) {
    const auto& saved_param_name = saved_params[i]->GetName();
    if (saved_params[i]->BuilderParameterMustBeMoved()) {
      s << saved_param_name << "_(std::move(" << saved_param_name << "))";
    } else {
      s << saved_param_name << "_(" << saved_param_name << ")";
    }
    if (i != saved_params.size() - 1) {
      s << ",";
    }
  }
  s << " {";

  FieldList params_to_validate = GetParametersToValidate();

  if (params_to_validate.size() > 0) {
    s << "CheckParameterValues(";
    for (std::size_t i = 0; i < params_to_validate.size(); i++) {
      s << params_to_validate[i]->GetName() << "_";
      if (i != params_to_validate.size() - 1) {
        s << ", ";
      }
    }
    s << ");";
  }

  s << "}\n";
}

// Emits the Rust child enums (`<Name>DataChild`, with get_total_size(), and the
// public `<Name>Child`). Each has one variant per child, plus Payload when the
// packet has a payload, plus None. Nothing is emitted when HasChildEnums() is false.
void PacketDef::GenRustChildEnums(std::ostream& s) const {
  if (!HasChildEnums()) {
    return;
  }

  bool payload = fields_.HasPayload();
  s << "#[derive(Debug)] ";
  s << "enum " << name_ << "DataChild {";
  for (const auto& child : children_) {
    s << child->name_ << "(Arc<" << child->name_ << "Data>),";
  }
  if (payload) {
    s << "Payload(Bytes),";
  }
  s << "None,";
  s << "}\n";

  s << "impl " << name_ << "DataChild {";
  s << "fn get_total_size(&self) -> usize {";
  s << "match self {";
  for (const auto& child : children_) {
    s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
  }
  if (payload) {
    s << name_ << "DataChild::Payload(p) => p.len(),";
  }
  s << name_ << "DataChild::None => 0,";
  s << "}\n";
  s << "}\n";
  s << "}\n";

  s << "#[derive(Debug)] ";
  s << "pub enum " << name_ << "Child {";
  for (const auto& child : children_) {
    s << child->name_ << "(" << child->name_ << "Packet),";
  }
  if (payload) {
    s << "Payload(Bytes),";
  }
  s << "None,";
  s << "}\n";
}

// Emits three Rust structs: `<Name>Data` (the stored fields plus the optional child),
// `<Name>Packet` (one Arc per definition from the root down to this packet), and
// `<Name>Builder` (public builder parameters plus an optional payload).
void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
  s << "#[derive(Debug)] ";
  s << "struct " << name_ << "Data {";

  // Generate struct fields
  GenRustStructFieldNameAndType(s);
  if (HasChildEnums()) {
    s << "child: " << name_ << "DataChild,";
  }
  s << "}\n";

  // Generate accessor struct
  s << "#[derive(Debug, Clone)] ";
  s << "pub struct " << name_ << "Packet {";
  auto lineage = GetAncestors();
  lineage.push_back(this);
  for (const auto& def : lineage) {
    s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
  }
  s << "}\n";

  // Generate builder struct
  s << "#[derive(Debug)] ";
  s << "pub struct " << name_ << "Builder {";
  auto params = GetParamList().GetFieldsWithoutTypes({
      PayloadField::kFieldType,
      BodyField::kFieldType,
  });
  for (const auto& param : params) {
    s << "pub ";
    param->GenRustNameAndType(s);
    s << ", ";
  }

  if (fields_.HasPayload()) {
    s << "pub payload: Option<Bytes>,";
  }
  s << "}\n";
}

// Emits "name: Type, " for each field stored in the Rust data struct. Computed or
// constant fields (body, count, padding, reserved, size, payload, fixed scalar) are
// excluded. Returns false when no field qualifies.
bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
  auto fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
      CountField::kFieldType,
      PaddingField::kFieldType,
      ReservedField::kFieldType,
      SizeField::kFieldType,
      PayloadField::kFieldType,
      FixedScalarField::kFieldType,
  });
  if (fields.size() == 0) {
    return false;
  }
  for (const auto& field : fields) {
    field->GenRustNameAndType(s);
    s << ", ";
  }
  return true;
}

// Emits "name, " for the same set of fields that GenRustStructFieldNameAndType covers.
void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
  auto fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
      CountField::kFieldType,
      PaddingField::kFieldType,
      ReservedField::kFieldType,
      SizeField::kFieldType,
      PayloadField::kFieldType,
      FixedScalarField::kFieldType,
  });
  for (const auto& field : fields) {
    s << field->GetName();
    s << ", ";
  }
}

// Emits `impl <Name>Data` with four functions:
//   - conforms(): the conformance check plus per-struct-field checks.
//   - parse(): bounds checks and getters for every field, then selects the child.
//     With several children it matches on the constraint field; with one child it
//     tries that child; otherwise it falls back to the payload.
//   - write_to(): writes the fields, then the child or payload.
//   - get_total_size() / get_size().
void PacketDef::GenRustStructImpls(std::ostream& s) const {
  s << "impl " << name_ << "Data {";

  // conforms function
  s << "fn conforms(bytes: &[u8]) -> bool {";
  GenRustConformanceCheck(s);

  auto fields = fields_.GetFieldsWithTypes({
      StructField::kFieldType,
  });

  for (auto const& field : fields) {
    auto start_offset = GetOffsetForField(field->GetName(), false);

    s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
    s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
  }

  s << " true";
  s << "}";

  // parse function
  if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) {
    auto constraint = FindConstraintField();
    auto constraint_field = GetParamList().GetField(constraint);
    auto constraint_type = constraint_field->GetRustDataType();
    s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {";
  } else {
    s << "fn parse(bytes: &[u8]) -> Result<Self> {";
  }
  fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
  });

  for (auto const& field : fields) {
    auto start_field_offset = GetOffsetForField(field->GetName(), false);
    auto end_field_offset = GetOffsetForField(field->GetName(), true);

    if (start_field_offset.empty() && end_field_offset.empty()) {
      ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
                   << "no method exists to determine field location from begin() or end().\n";
    }

    field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
    field->GenRustGetter(s, start_field_offset, end_field_offset);
  }

  auto constraint_name = FindConstraintField();
  auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);

  if (children_.size() > 1) {
    s << "let child = match " << constraint_name << " {";

    for (const auto& desc : constrained_descendants) {
      auto desc_path = FindPathToDescendant(desc.first->name_);
      std::reverse(desc_path.begin(), desc_path.end());
      auto constraint_field = GetParamList().GetField(constraint_name);
      auto constraint_type = constraint_field->GetFieldType();

      if (constraint_type == EnumField::kFieldType) {
        // Turn "Enum::VARIANT_NAME" into the Rust path "Enum::VariantName".
        auto type = std::get<std::string>(desc.second);
        auto variant_name = type.substr(type.find("::") + 2, type.length());
        auto enum_type = type.substr(0, type.find("::"));
        auto enum_variant = enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
        s << enum_variant;
        s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])";
        s << " => {";
        s << name_ << "DataChild::";
        s << desc_path[0]->name_ << "(Arc::new(";
        if (desc_path[0]->parent_constraints_.empty()) {
          s << desc_path[0]->name_ << "Data::parse(&bytes[..]";
          s << ", " << enum_variant << ")?))";
        } else {
          s << desc_path[0]->name_ << "Data::parse(&bytes[..])?))";
        }
      } else if (constraint_type == ScalarField::kFieldType) {
        s << std::get<int64_t>(desc.second) << " => {";
        s << "unimplemented!();";
      }
      s << "}\n";
    }

    if (!constrained_descendants.empty()) {
      s << "v => return Err(Error::ConstraintOutOfBounds{field: \"" << constraint_name
        << "\".to_string(), value: v as u64}),";
    }

    s << "};\n";
  } else if (children_.size() == 1) {
    auto child = children_.at(0);
    s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {";
    s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
    s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
    s << " },";
    s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
    s << " _ => return Err(Error::InvalidPacketError),";
    s << "};";
  } else if (fields_.HasPayload()) {
    s << "let child = if payload.len() > 0 {";
    s << name_ << "DataChild::Payload(Bytes::from(payload))";
    s << "} else {";
    s << name_ << "DataChild::None";
    s << "};";
  }
  s << "Ok(Self {";
  // Same field list as the struct declaration.
  GenRustStructFieldNames(s);
  if (HasChildEnums()) {
    s << "child,";
  }
  s << "})\n";
  s << "}\n";

  // write_to function
  s << "fn write_to(&self, buffer: &mut BytesMut) {";
  GenRustWriteToFields(s);

  if (HasChildEnums()) {
    s << "match &self.child {";
    for (const auto& child : children_) {
      s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
    }
    if (fields_.HasPayload()) {
      auto offset = GetOffsetForField("payload");
      s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
    }
    s << name_ << "DataChild::None => {}";
    s << "}";
  }

  s << "}\n";

  s << "fn get_total_size(&self) -> usize {";
  if (HasChildEnums()) {
    s << "self.get_size() + self.child.get_total_size()";
  } else {
    s << "self.get_size()";
  }
  s << "}\n";

  s << "fn get_size(&self) -> usize {";
  GenSizeRetVal(s);
  s << "}\n";
  s << "}\n";
}

void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
  if (complement_ != nullptr) {
    auto complement_root = complement_->GetRootDef();
    auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
    s << "impl CommandExpectations for " << name_ << "Packet {";
    s << " type ResponseType = " << complement_->name_ << "Packet;";
    s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
    s <<
complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"; s << " }"; s << "}"; } s << "impl Packet for " << name_ << "Packet {"; auto root = GetRootDef(); auto root_accessor = util::CamelCaseToUnderScore(root->name_); s << "fn to_bytes(self) -> Bytes {"; s << " let mut buffer = BytesMut::new();"; s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);"; s << " self." << root_accessor << ".write_to(&mut buffer);"; s << " buffer.freeze()"; s << "}\n"; s << "fn to_vec(self) -> Vec { self.to_bytes().to_vec() }\n"; s << "}"; s << "impl " << name_ << "Packet {"; if (parent_ == nullptr) { s << "pub fn parse(bytes: &[u8]) -> Result { "; s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)))"; s << "}"; } if (HasChildEnums()) { s << " pub fn specialize(&self) -> " << name_ << "Child {"; s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {"; for (const auto& child : children_) { s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "(" << child->name_ << "Packet::new(self." 
<< root_accessor << ".clone())),"; } if (fields_.HasPayload()) { s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),"; } s << name_ << "DataChild::None => " << name_ << "Child::None,"; s << "}}"; } auto lineage = GetAncestors(); lineage.push_back(this); const ParentDef* prev = nullptr; s << " fn new(root: Arc<" << root->name_ << "Data>) -> Self {"; for (auto it = lineage.begin(); it != lineage.end(); it++) { auto def = *it; auto accessor_name = util::CamelCaseToUnderScore(def->name_); if (prev == nullptr) { s << "let " << accessor_name << " = root;"; } else { s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {"; s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),"; s << "_ => panic!(\"inconsistent state - child was not " << def->name_ << "\"),"; s << "};"; } prev = def; } s << "Self {"; for (auto it = lineage.begin(); it != lineage.end(); it++) { auto def = *it; s << util::CamelCaseToUnderScore(def->name_) << ","; } s << "}}"; for (auto it = lineage.begin(); it != lineage.end(); it++) { auto def = *it; auto fields = def->fields_.GetFieldsWithoutTypes({ BodyField::kFieldType, CountField::kFieldType, PaddingField::kFieldType, ReservedField::kFieldType, SizeField::kFieldType, PayloadField::kFieldType, FixedScalarField::kFieldType, }); for (auto const& field : fields) { if (field->GetterIsByRef()) { s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{"; s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName(); s << "}\n"; } else { s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{"; s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." 
<< field->GetName(); s << "}\n"; } } } s << "}\n"; lineage = GetAncestors(); for (auto it = lineage.begin(); it != lineage.end(); it++) { auto def = *it; s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {"; s << " fn into(self) -> " << def->name_ << "Packet {"; s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"; s << " }"; s << "}\n"; } } void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const { if (complement_ != nullptr) { auto complement_root = complement_->GetRootDef(); auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_); s << "impl CommandExpectations for " << name_ << "Builder {"; s << " type ResponseType = " << complement_->name_ << "Packet;"; s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { "; s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"; s << " }"; s << "}"; } s << "impl " << name_ << "Builder {"; s << "pub fn build(self) -> " << name_ << "Packet {"; auto lineage = GetAncestors(); lineage.push_back(this); std::reverse(lineage.begin(), lineage.end()); auto all_constraints = GetAllConstraints(); const ParentDef* prev = nullptr; for (auto ancestor : lineage) { auto fields = ancestor->fields_.GetFieldsWithoutTypes({ BodyField::kFieldType, CountField::kFieldType, PaddingField::kFieldType, ReservedField::kFieldType, SizeField::kFieldType, PayloadField::kFieldType, FixedScalarField::kFieldType, }); auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_); s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {"; for (auto field : fields) { auto constraint = all_constraints.find(field->GetName()); s << field->GetName() << ": "; if (constraint != all_constraints.end()) { if (field->GetFieldType() == ScalarField::kFieldType) { s << std::get(constraint->second); } else if (field->GetFieldType() == EnumField::kFieldType) { auto value = std::get(constraint->second); auto 
constant = value.substr(value.find("::") + 2, std::string::npos); s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant); ; } else { ERROR(field) << "Constraints on non enum/scalar fields should be impossible."; } } else { s << "self." << field->GetName(); } s << ", "; } if (ancestor->HasChildEnums()) { if (prev == nullptr) { if (ancestor->fields_.HasPayload()) { s << "child: match self.payload { "; s << "None => " << name_ << "DataChild::None,"; s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),"; s << "},"; } else { s << "child: " << name_ << "DataChild::None,"; } } else { s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "(" << util::CamelCaseToUnderScore(prev->name_) << "),"; } } s << "});"; prev = ancestor; } s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ")"; s << "}\n"; s << "}\n"; for (const auto ancestor : GetAncestors()) { s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {"; s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }"; s << "}\n"; } } void PacketDef::GenRustBuilderTest(std::ostream& s) const { auto lineage = GetAncestors(); lineage.push_back(this); if (!lineage.empty() && !test_cases_.empty()) { s << "macro_rules! 
" << util::CamelCaseToUnderScore(name_) << "_builder_tests { "; s << "($($name:ident: $byte_string:expr,)*) => {"; s << "$("; s << "\n#[test]\n"; s << "pub fn $name() { "; s << "let raw_bytes = $byte_string;"; for (size_t i = 0; i < lineage.size(); i++) { s << "/* (" << i << ") */\n"; if (i == 0) { s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {"; s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {"; s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {"; } else if (i != lineage.size() - 1) { s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "("; s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {"; s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {"; } else { s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {"; s << "let rebuilder = " << lineage[i]->name_ << "Builder {"; FieldList params = GetParamList(); if (params.HasBody()) { ERROR() << "Packets with body fields can't be auto-tested. 
Test a child."; } for (const auto param : params) { s << param->GetName() << " : packet."; if (param->GetFieldType() == VectorField::kFieldType) { s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),"; } else if (param->GetFieldType() == ArrayField::kFieldType) { const auto array_param = static_cast(param); const auto element_field = array_param->GetElementField(); if (element_field->GetFieldType() == StructField::kFieldType) { s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),"; } else { s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),"; } } else if (param->GetFieldType() == StructField::kFieldType) { s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),"; } else { s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),"; } } s << "};"; s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();"; s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();"; s << "assert_eq!(rebuilder_bytes, raw_bytes);"; s << "}"; } } for (size_t i = 1; i < lineage.size(); i++) { s << "_ => {"; s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_); s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); "; s << "}}}"; } s << ","; s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),"; s << "}"; s << "}"; s << ")*"; s << "}"; s << "}"; s << util::CamelCaseToUnderScore(name_) << "_builder_tests! 
{ "; int number = 0; for (const auto& test_case : test_cases_) { s << util::CamelCaseToUnderScore(name_) << "_builder_test_"; s << std::setfill('0') << std::setw(2) << number++ << ": "; s << "b\"" << test_case << "\","; } s << "}"; s << "\n"; } } void PacketDef::GenRustDef(std::ostream& s) const { GenRustChildEnums(s); GenRustStructDeclarations(s); GenRustStructImpls(s); GenRustAccessStructImpls(s); GenRustBuilderStructImpls(s); GenRustBuilderTest(s); }