1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "packet_def.h"

#include <list>
#include <set>
#include <utility>

#include "fields/all_fields.h"
#include "util.h"
24
PacketDef(std::string name,FieldList fields)25 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)26 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
27
GetNewField(const std::string &,ParseLocation) const28 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
29 return nullptr; // Packets can't be fields
30 }
31
GenParserDefinition(std::ostream & s) const32 void PacketDef::GenParserDefinition(std::ostream& s) const {
33 s << "class " << name_ << "View";
34 if (parent_ != nullptr) {
35 s << " : public " << parent_->name_ << "View {";
36 } else {
37 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
38 }
39 s << " public:";
40
41 // Specialize function
42 if (parent_ != nullptr) {
43 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
44 s << "{ return " << name_ << "View(parent); }";
45 } else {
46 s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
47 s << "{ return " << name_ << "View(packet); }";
48 }
49
50 std::set<std::string> fixed_types = {
51 FixedScalarField::kFieldType,
52 FixedEnumField::kFieldType,
53 };
54
55 // Print all of the public fields which are all the fields minus the fixed fields.
56 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
57 bool has_fixed_fields = public_fields.size() != fields_.size();
58 for (const auto& field : public_fields) {
59 GenParserFieldGetter(s, field);
60 s << "\n";
61 }
62 GenValidator(s);
63 s << "\n";
64
65 s << " protected:\n";
66 // Constructor from a View
67 if (parent_ != nullptr) {
68 s << name_ << "View(" << parent_->name_ << "View parent)";
69 s << " : " << parent_->name_ << "View(parent) { was_validated_ = false; }";
70 } else {
71 s << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
72 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
73 }
74
75 // Print the private fields which are the fixed fields.
76 if (has_fixed_fields) {
77 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
78 s << " private:\n";
79 for (const auto& field : private_fields) {
80 GenParserFieldGetter(s, field);
81 s << "\n";
82 }
83 }
84 s << "};\n";
85 }
86
GenParserDefinitionPybind11(std::ostream & s) const87 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
88 s << "py::class_<" << name_ << "View";
89 if (parent_ != nullptr) {
90 s << ", " << parent_->name_ << "View";
91 } else {
92 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
93 }
94 s << ">(m, \"" << name_ << "View\")";
95 if (parent_ != nullptr) {
96 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
97 } else {
98 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
99 }
100 s << "auto view =" << name_ << "View::Create(std::move(parent));";
101 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
102 s << "return view; }))";
103
104 s << ".def(py::init(&" << name_ << "View::Create))";
105 std::set<std::string> protected_field_types = {
106 FixedScalarField::kFieldType,
107 FixedEnumField::kFieldType,
108 SizeField::kFieldType,
109 CountField::kFieldType,
110 };
111 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
112 for (const auto& field : public_fields) {
113 auto getter_func_name = field->GetGetterFunctionName();
114 if (getter_func_name.empty()) {
115 continue;
116 }
117 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
118 }
119 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
120 s << ";\n";
121 }
122
GenParserFieldGetter(std::ostream & s,const PacketField * field) const123 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
124 // Start field offset
125 auto start_field_offset = GetOffsetForField(field->GetName(), false);
126 auto end_field_offset = GetOffsetForField(field->GetName(), true);
127
128 if (start_field_offset.empty() && end_field_offset.empty()) {
129 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
130 << "no method exists to determine field location from begin() or end().\n";
131 }
132
133 field->GenGetter(s, start_field_offset, end_field_offset);
134 }
135
GetDefinitionType() const136 TypeDef::Type PacketDef::GetDefinitionType() const {
137 return TypeDef::Type::PACKET;
138 }
139
GenValidator(std::ostream & s) const140 void PacketDef::GenValidator(std::ostream& s) const {
141 // Get the static offset for all of our fields.
142 int bits_size = 0;
143 for (const auto& field : fields_) {
144 if (field->GetFieldType() != PaddingField::kFieldType) {
145 bits_size += field->GetSize().bits();
146 }
147 }
148
149 // Write the function declaration.
150 s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
151 s << "if (was_validated_) { return true; } ";
152 s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
153 s << "}";
154
155 s << "protected:";
156 s << "virtual bool IsValid_() const {";
157
158 // Offset by the parents known size. We know that any dynamic fields can
159 // already be called since the parent must have already been validated by
160 // this point.
161 auto parent_size = Size(0);
162 if (parent_ != nullptr) {
163 parent_size = parent_->GetSize(true);
164 }
165
166 s << "auto it = begin() + (" << parent_size << ") / 8;";
167
168 // Check if you can extract the static fields.
169 // At this point you know you can use the size getters without crashing
170 // as long as they follow the instruction that size fields cant come before
171 // their corrisponding variable length field.
172 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
173 s << "if (it > end()) return false;";
174
175 // For any variable length fields, use their size check.
176 for (const auto& field : fields_) {
177 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
178 auto offset = GetOffsetForField(field->GetName(), false);
179 if (!offset.empty()) {
180 s << "size_t sum_index = (" << offset << ") / 8;";
181 } else {
182 offset = GetOffsetForField(field->GetName(), true);
183 if (offset.empty()) {
184 ERROR(field) << "Checksum Start Field offset can not be determined.";
185 }
186 s << "size_t sum_index = size() - (" << offset << ") / 8;";
187 }
188
189 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
190 const auto& started_field = fields_.GetField(field_name);
191 if (started_field == nullptr) {
192 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
193 << ")";
194 }
195 auto end_offset = GetOffsetForField(started_field->GetName(), false);
196 if (!end_offset.empty()) {
197 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
198 } else {
199 end_offset = GetOffsetForField(started_field->GetName(), true);
200 if (end_offset.empty()) {
201 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
202 }
203 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
204 }
205 if (is_little_endian_) {
206 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
207 } else {
208 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
209 }
210 s << started_field->GetDataType() << " checksum;";
211 s << "checksum.Initialize();";
212 s << "for (uint8_t byte : checksum_view) { ";
213 s << "checksum.AddByte(byte);}";
214 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
215 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
216
217 continue;
218 }
219
220 auto field_size = field->GetSize();
221 // Fixed size fields have already been handled.
222 if (!field_size.has_dynamic()) {
223 continue;
224 }
225
226 // Custom fields with dynamic size must have the offset for the field passed in as well
227 // as the end iterator so that they may ensure that they don't try to read past the end.
228 // Custom fields with fixed sizes will be handled in the static offset checking.
229 if (field->GetFieldType() == CustomField::kFieldType) {
230 // Check if we can determine offset from begin(), otherwise error because by this point,
231 // the size of the custom field is unknown and can't be subtracted from end() to get the
232 // offset.
233 auto offset = GetOffsetForField(field->GetName(), false);
234 if (offset.empty()) {
235 ERROR(field) << "Custom Field offset can not be determined from begin().";
236 }
237
238 if (offset.bits() % 8 != 0) {
239 ERROR(field) << "Custom fields must be byte aligned.";
240 }
241
242 // Custom fields are special as their size field takes an argument.
243 const auto& custom_size_var = field->GetName() + "_size";
244 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
245 s << "(begin() + (" << offset << ") / 8);";
246
247 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
248 s << "it += *" << custom_size_var << ";";
249 s << "if (it > end()) return false;";
250 continue;
251 } else {
252 s << "it += (" << field_size.dynamic_string() << ") / 8;";
253 s << "if (it > end()) return false;";
254 }
255 }
256
257 // Validate constraints after validating the size
258 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
259 ERROR() << "Can't have a constraint on a NULL parent";
260 }
261
262 for (const auto& constraint : parent_constraints_) {
263 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
264 const auto& field = parent_->GetParamList().GetField(constraint.first);
265 if (field->GetFieldType() == ScalarField::kFieldType) {
266 s << std::get<int64_t>(constraint.second);
267 } else {
268 s << std::get<std::string>(constraint.second);
269 }
270 s << ") return false;";
271 }
272
273 // Validate the packets fields last
274 for (const auto& field : fields_) {
275 field->GenValidator(s);
276 s << "\n";
277 }
278
279 s << "return true;";
280 s << "}\n";
281 if (parent_ == nullptr) {
282 s << "bool was_validated_{false};\n";
283 }
284 }
285
GenBuilderDefinition(std::ostream & s) const286 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
287 s << "class " << name_ << "Builder";
288 if (parent_ != nullptr) {
289 s << " : public " << parent_->name_ << "Builder";
290 } else {
291 if (is_little_endian_) {
292 s << " : public PacketBuilder<kLittleEndian>";
293 } else {
294 s << " : public PacketBuilder<!kLittleEndian>";
295 }
296 }
297 s << " {";
298 s << " public:";
299 s << " virtual ~" << name_ << "Builder()" << (parent_ != nullptr ? " override" : "") << " = default;";
300
301 if (!fields_.HasBody()) {
302 GenBuilderCreate(s);
303 s << "\n";
304 }
305
306 GenSerialize(s);
307 s << "\n";
308
309 GenSize(s);
310 s << "\n";
311
312 s << " protected:\n";
313 GenBuilderConstructor(s);
314 s << "\n";
315
316 GenBuilderParameterChecker(s);
317 s << "\n";
318
319 GenMembers(s);
320 s << "};\n";
321
322 GenTestDefine(s);
323 s << "\n";
324
325 GenFuzzTestDefine(s);
326 s << "\n";
327 }
328
GenBuilderDefinitionPybind11(std::ostream & s) const329 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
330 s << "py::class_<" << name_ << "Builder";
331 if (parent_ != nullptr) {
332 s << ", " << parent_->name_ << "Builder";
333 } else {
334 if (is_little_endian_) {
335 s << ", PacketBuilder<kLittleEndian>";
336 } else {
337 s << ", PacketBuilder<!kLittleEndian>";
338 }
339 }
340 s << ", std::shared_ptr<" << name_ << "Builder>";
341 s << ">(m, \"" << name_ << "Builder\")";
342 if (!fields_.HasBody()) {
343 GenBuilderCreatePybind11(s);
344 }
345 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
346 s << "std::vector<uint8_t> bytes;";
347 s << "BitInserter bi(bytes);";
348 s << "builder.Serialize(bi);";
349 s << "return bytes;})";
350 s << ";\n";
351 }
352
GenTestDefine(std::ostream & s) const353 void PacketDef::GenTestDefine(std::ostream& s) const {
354 s << "#ifdef PACKET_TESTING\n";
355 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
356 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
357 s << "public: ";
358 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
359 s << "auto vec = std::make_shared<std::vector<uint8_t>>(captured_packet.begin(), captured_packet.end());";
360 s << name_ << "View view = " << name_ << "View::Create(";
361 auto ancestor_ptr = parent_;
362 size_t parent_parens = 0;
363 while (ancestor_ptr != nullptr) {
364 s << ancestor_ptr->name_ << "View::Create(";
365 parent_parens++;
366 ancestor_ptr = ancestor_ptr->parent_;
367 }
368 s << "vec";
369 for (size_t i = 0; i < parent_parens; i++) {
370 s << ")";
371 }
372 s << ");";
373 s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
374 s << "for (size_t i = 0; i < view.size(); i++) { LOG_DEBUG(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
375 s << "ASSERT_TRUE(view.IsValid());";
376 s << "auto packet = " << name_ << "Builder::Create(";
377 FieldList params = GetParamList().GetFieldsWithoutTypes({
378 BodyField::kFieldType,
379 });
380 for (int i = 0; i < params.size(); i++) {
381 params[i]->GenBuilderParameterFromView(s);
382 if (i != params.size() - 1) {
383 s << ", ";
384 }
385 }
386 s << ");";
387 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
388 s << "packet_bytes->reserve(packet->size());";
389 s << "BitInserter it(*packet_bytes);";
390 s << "packet->Serialize(it);";
391 s << "ASSERT_EQ(*packet_bytes, *vec);";
392 s << "}";
393 s << "};";
394 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
395 s << "CompareBytes(GetParam());";
396 s << "}";
397 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
398 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
399 s << "\n#endif";
400 }
401
GenFuzzTestDefine(std::ostream & s) const402 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
403 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
404 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
405 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
406 s << "auto vec = std::make_shared<std::vector<uint8_t>>(data, data + size);";
407 s << name_ << "View view = " << name_ << "View::Create(";
408 auto ancestor_ptr = parent_;
409 size_t parent_parens = 0;
410 while (ancestor_ptr != nullptr) {
411 s << ancestor_ptr->name_ << "View::Create(";
412 parent_parens++;
413 ancestor_ptr = ancestor_ptr->parent_;
414 }
415 s << "vec";
416 for (size_t i = 0; i < parent_parens; i++) {
417 s << ")";
418 }
419 s << ");";
420 s << "if (!view.IsValid()) { return; }";
421 s << "auto packet = " << name_ << "Builder::Create(";
422 FieldList params = GetParamList().GetFieldsWithoutTypes({
423 BodyField::kFieldType,
424 });
425 for (int i = 0; i < params.size(); i++) {
426 params[i]->GenBuilderParameterFromView(s);
427 if (i != params.size() - 1) {
428 s << ", ";
429 }
430 }
431 s << ");";
432 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
433 s << "packet_bytes->reserve(packet->size());";
434 s << "BitInserter it(*packet_bytes);";
435 s << "packet->Serialize(it);";
436 s << "}";
437 s << "\n#endif\n";
438 s << "#ifdef PACKET_FUZZ_TESTING\n";
439 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
440 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
441 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
442 s << "public: ";
443 s << "explicit " << name_
444 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
445 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
446 s << "}}; ";
447 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
448 s << "\n#endif";
449 }
450
GetParametersToValidate() const451 FieldList PacketDef::GetParametersToValidate() const {
452 FieldList params_to_validate;
453 for (const auto& field : GetParamList()) {
454 if (field->HasParameterValidator()) {
455 params_to_validate.AppendField(field);
456 }
457 }
458 return params_to_validate;
459 }
460
GenBuilderCreate(std::ostream & s) const461 void PacketDef::GenBuilderCreate(std::ostream& s) const {
462 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
463
464 auto params = GetParamList();
465 for (int i = 0; i < params.size(); i++) {
466 params[i]->GenBuilderParameter(s);
467 if (i != params.size() - 1) {
468 s << ", ";
469 }
470 }
471 s << ") {";
472
473 // Call the constructor
474 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
475
476 params = params.GetFieldsWithoutTypes({
477 PayloadField::kFieldType,
478 BodyField::kFieldType,
479 });
480 // Add the parameters.
481 for (int i = 0; i < params.size(); i++) {
482 if (params[i]->BuilderParameterMustBeMoved()) {
483 s << "std::move(" << params[i]->GetName() << ")";
484 } else {
485 s << params[i]->GetName();
486 }
487 if (i != params.size() - 1) {
488 s << ", ";
489 }
490 }
491
492 s << "));";
493 if (fields_.HasPayload()) {
494 s << "builder->payload_ = std::move(payload);";
495 }
496 s << "return builder;";
497 s << "}\n";
498 }
499
GenBuilderCreatePybind11(std::ostream & s) const500 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
501 s << ".def(py::init([](";
502 auto params = GetParamList();
503 std::vector<std::string> constructor_args;
504 int i = 1;
505 for (const auto& param : params) {
506 i++;
507 std::stringstream ss;
508 auto param_type = param->GetBuilderParameterType();
509 if (param_type.empty()) {
510 continue;
511 }
512 // Use shared_ptr instead of unique_ptr for the Python interface
513 if (param->BuilderParameterMustBeMoved()) {
514 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
515 }
516 ss << param_type << " " << param->GetName();
517 constructor_args.push_back(ss.str());
518 }
519 s << util::StringJoin(",", constructor_args) << "){";
520
521 // Deal with move only args
522 for (const auto& param : params) {
523 std::stringstream ss;
524 auto param_type = param->GetBuilderParameterType();
525 if (param_type.empty()) {
526 continue;
527 }
528 if (!param->BuilderParameterMustBeMoved()) {
529 continue;
530 }
531 auto move_only_param_name = param->GetName() + "_move_only";
532 s << param_type << " " << move_only_param_name << ";";
533 if (param->IsContainerField()) {
534 // Assume single layer container and copy it
535 auto struct_type = param->GetElementField()->GetDataType();
536 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
537 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
538 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
539 // Serialize each struct
540 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
541 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
542 s << "auto " << param->GetName() + "_reparsed = std::make_unique<" << struct_type << ">();";
543 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
544 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
545 // Parse it again
546 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
547 s << "auto result = Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
548 // Push it into a new container
549 if (param->GetFieldType() == VectorField::kFieldType) {
550 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
551 } else if (param->GetFieldType() == ArrayField::kFieldType) {
552 s << move_only_param_name << "[i] = " << param->GetName() << "_reparsed;";
553 } else {
554 ERROR() << param << " is not supported by Pybind11";
555 }
556 s << "}";
557 } else {
558 // Serialize the parameter and pass the bytes in a RawBuilder
559 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
560 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
561 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
562 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
563 s << move_only_param_name << " = ";
564 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
565 }
566 }
567 s << "return " << name_ << "Builder::Create(";
568 std::vector<std::string> builder_vars;
569 for (const auto& param : params) {
570 std::stringstream ss;
571 auto param_type = param->GetBuilderParameterType();
572 if (param_type.empty()) {
573 continue;
574 }
575 auto param_name = param->GetName();
576 if (param->BuilderParameterMustBeMoved()) {
577 ss << "std::move(" << param_name << "_move_only)";
578 } else {
579 ss << param_name;
580 }
581 builder_vars.push_back(ss.str());
582 }
583 s << util::StringJoin(",", builder_vars) << ");}";
584 s << "))";
585 }
586
GenBuilderParameterChecker(std::ostream & s) const587 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
588 FieldList params_to_validate = GetParametersToValidate();
589
590 // Skip writing this function if there is nothing to validate.
591 if (params_to_validate.size() == 0) {
592 return;
593 }
594
595 // Generate function arguments.
596 s << "void CheckParameterValues(";
597 for (int i = 0; i < params_to_validate.size(); i++) {
598 params_to_validate[i]->GenBuilderParameter(s);
599 if (i != params_to_validate.size() - 1) {
600 s << ", ";
601 }
602 }
603 s << ") {";
604
605 // Check the parameters.
606 for (const auto& field : params_to_validate) {
607 field->GenParameterValidator(s);
608 }
609 s << "}\n";
610 }
611
GenBuilderConstructor(std::ostream & s) const612 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
613 s << name_ << "Builder(";
614
615 // Generate the constructor parameters.
616 auto params = GetParamList().GetFieldsWithoutTypes({
617 PayloadField::kFieldType,
618 BodyField::kFieldType,
619 });
620 for (int i = 0; i < params.size(); i++) {
621 params[i]->GenBuilderParameter(s);
622 if (i != params.size() - 1) {
623 s << ", ";
624 }
625 }
626 if (params.size() > 0 || parent_constraints_.size() > 0) {
627 s << ") :";
628 } else {
629 s << ")";
630 }
631
632 // Get the list of parent params to call the parent constructor with.
633 FieldList parent_params;
634 if (parent_ != nullptr) {
635 // Pass parameters to the parent constructor
636 s << parent_->name_ << "Builder(";
637 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
638 PayloadField::kFieldType,
639 BodyField::kFieldType,
640 });
641
642 // Go through all the fields and replace constrained fields with fixed values
643 // when calling the parent constructor.
644 for (int i = 0; i < parent_params.size(); i++) {
645 const auto& field = parent_params[i];
646 const auto& constraint = parent_constraints_.find(field->GetName());
647 if (constraint != parent_constraints_.end()) {
648 if (field->GetFieldType() == ScalarField::kFieldType) {
649 s << std::get<int64_t>(constraint->second);
650 } else if (field->GetFieldType() == EnumField::kFieldType) {
651 s << std::get<std::string>(constraint->second);
652 } else {
653 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
654 }
655
656 s << "/* " << field->GetName() << "_ */";
657 } else {
658 s << field->GetName();
659 }
660
661 if (i != parent_params.size() - 1) {
662 s << ", ";
663 }
664 }
665 s << ") ";
666 }
667
668 // Build a list of parameters that excludes all parent parameters.
669 FieldList saved_params;
670 for (const auto& field : params) {
671 if (parent_params.GetField(field->GetName()) == nullptr) {
672 saved_params.AppendField(field);
673 }
674 }
675 if (parent_ != nullptr && saved_params.size() > 0) {
676 s << ",";
677 }
678 for (int i = 0; i < saved_params.size(); i++) {
679 const auto& saved_param_name = saved_params[i]->GetName();
680 if (saved_params[i]->BuilderParameterMustBeMoved()) {
681 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
682 } else {
683 s << saved_param_name << "_(" << saved_param_name << ")";
684 }
685 if (i != saved_params.size() - 1) {
686 s << ",";
687 }
688 }
689 s << " {";
690
691 FieldList params_to_validate = GetParametersToValidate();
692
693 if (params_to_validate.size() > 0) {
694 s << "CheckParameterValues(";
695 for (int i = 0; i < params_to_validate.size(); i++) {
696 s << params_to_validate[i]->GetName() << "_";
697 if (i != params_to_validate.size() - 1) {
698 s << ", ";
699 }
700 }
701 s << ");";
702 }
703
704 s << "}\n";
705 }
706