• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "parent_def.h"

#include <utility>

#include "fields/all_fields.h"
#include "util.h"
21 
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24     : TypeDef(name), fields_(fields), parent_(parent) {}
25 
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name, std::variant<int64_t, std::string> value) {
27   // NOTE: This could end up being very slow if there are a lot of constraints.
28   const auto& parent_params = parent_->GetParamList();
29   const auto& constrained_field = parent_params.GetField(field_name);
30   if (constrained_field == nullptr) {
31     ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
32             << ", but no such field exists.";
33   }
34 
35   if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
36     if (!std::holds_alternative<int64_t>(value)) {
37       ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in " << parent_->name_;
38     }
39   } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
40     if (!std::holds_alternative<std::string>(value)) {
41       ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in " << parent_->name_;
42     }
43     const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
44     if (!enum_def.HasEntry(std::get<std::string>(value))) {
45       ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
46                                << "\" for constraint on enum in parent " << parent_->name_ << ".";
47     }
48 
49     // For enums, we have to qualify the value using the enum type name.
50     value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
51   } else {
52     ERROR(constrained_field) << "Field in parent " << parent_->name_ << " is not viable for constraining.";
53   }
54 
55   parent_constraints_.insert(std::pair(field_name, value));
56 }
57 
58 // Assign all size fields to their corresponding variable length fields.
59 // Will crash if
60 //  - there aren't any fields that don't match up to a field.
61 //  - the size field points to a fixed size field.
62 //  - if the size field comes after the variable length field.
AssignSizeFields()63 void ParentDef::AssignSizeFields() {
64   for (const auto& field : fields_) {
65     DEBUG() << "field name: " << field->GetName();
66 
67     if (field->GetFieldType() != SizeField::kFieldType && field->GetFieldType() != CountField::kFieldType) {
68       continue;
69     }
70 
71     const SizeField* size_field = static_cast<SizeField*>(field);
72     // Check to see if a corresponding field can be found.
73     const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
74     if (var_len_field == nullptr) {
75       ERROR(field) << "Could not find corresponding field for size/count field.";
76     }
77 
78     // Do the ordering check to ensure the size field comes before the
79     // variable length field.
80     for (auto it = fields_.begin(); *it != size_field; it++) {
81       DEBUG() << "field name: " << (*it)->GetName();
82       if (*it == var_len_field) {
83         ERROR(var_len_field, size_field) << "Size/count field must come before the variable length field it describes.";
84       }
85     }
86 
87     if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
88       const auto& payload_field = static_cast<PayloadField*>(var_len_field);
89       payload_field->SetSizeField(size_field);
90       continue;
91     }
92 
93     if (var_len_field->GetFieldType() == VectorField::kFieldType) {
94       const auto& vector_field = static_cast<VectorField*>(var_len_field);
95       vector_field->SetSizeField(size_field);
96       continue;
97     }
98 
99     // If we've reached this point then the field wasn't a variable length field.
100     // Check to see if the field is a variable length field
101     ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
102   }
103 }
104 
SetEndianness(bool is_little_endian)105 void ParentDef::SetEndianness(bool is_little_endian) {
106   is_little_endian_ = is_little_endian;
107 }
108 
109 // Get the size. You scan specify without_payload in order to exclude payload fields as children will be overriding it.
GetSize(bool without_payload) const110 Size ParentDef::GetSize(bool without_payload) const {
111   auto size = Size(0);
112 
113   for (const auto& field : fields_) {
114     if (without_payload &&
115         (field->GetFieldType() == PayloadField::kFieldType || field->GetFieldType() == BodyField::kFieldType)) {
116       continue;
117     }
118 
119     // The offset to the field must be passed in as an argument for dynamically sized custom fields.
120     if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
121       std::stringstream custom_field_size;
122 
123       // Custom fields are special as their size field takes an argument.
124       custom_field_size << field->GetSize().dynamic_string() << "(begin()";
125 
126       // Check if we can determine offset from begin(), otherwise error because by this point,
127       // the size of the custom field is unknown and can't be subtracted from end() to get the
128       // offset.
129       auto offset = GetOffsetForField(field->GetName(), false);
130       if (offset.empty()) {
131         ERROR(field) << "Custom Field offset can not be determined from begin().";
132       }
133 
134       if (offset.bits() % 8 != 0) {
135         ERROR(field) << "Custom fields must be byte aligned.";
136       }
137       if (offset.has_bits()) custom_field_size << " + " << offset.bits() / 8;
138       if (offset.has_dynamic()) custom_field_size << " + " << offset.dynamic_string();
139       custom_field_size << ")";
140 
141       size += custom_field_size.str();
142       continue;
143     }
144 
145     size += field->GetSize();
146   }
147 
148   if (parent_ != nullptr) {
149     size += parent_->GetSize(true);
150   }
151 
152   return size;
153 }
154 
155 // Get the offset until the field is reached, if there is no field
156 // returns an empty Size. from_end requests the offset to the field
157 // starting from the end() iterator. If there is a field with an unknown
158 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const159 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
160   // Check first if the field exists.
161   if (fields_.GetField(field_name) == nullptr) {
162     ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in " << name_;
163   }
164 
165   // We have to use a generic lambda to conditionally change iteration direction
166   // due to iterator and reverse_iterator being different types.
167   auto size_lambda = [&](auto from, auto to) -> Size {
168     auto size = Size(0);
169     for (auto it = from; it != to; it++) {
170       // We've reached the field, end the loop.
171       if ((*it)->GetName() == field_name) break;
172       const auto& field = *it;
173       // If there is a field with an unknown size before the field, return an empty Size.
174       if (field->GetSize().empty()) {
175         return Size();
176       }
177       if (field->GetFieldType() != PaddingField::kFieldType || !from_end) {
178         size += field->GetSize();
179       }
180     }
181     return size;
182   };
183 
184   // Change iteration direction based on from_end.
185   auto size = Size();
186   if (from_end)
187     size = size_lambda(fields_.rbegin(), fields_.rend());
188   else
189     size = size_lambda(fields_.begin(), fields_.end());
190   if (size.empty()) return size;
191 
192   // We need the offset until a payload or body field.
193   if (parent_ != nullptr) {
194     if (parent_->fields_.HasPayload()) {
195       auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
196       if (parent_payload_offset.empty()) {
197         ERROR() << "Empty offset for payload in " << parent_->name_ << " finding the offset for field: " << field_name;
198       }
199       size += parent_payload_offset;
200     } else {
201       auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
202       if (parent_body_offset.empty()) {
203         ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
204       }
205       size += parent_body_offset;
206     }
207   }
208 
209   return size;
210 }
211 
GetParamList() const212 FieldList ParentDef::GetParamList() const {
213   FieldList params;
214 
215   std::set<std::string> param_types = {
216       ScalarField::kFieldType,
217       EnumField::kFieldType,
218       ArrayField::kFieldType,
219       VectorField::kFieldType,
220       CustomField::kFieldType,
221       StructField::kFieldType,
222       VariableLengthStructField::kFieldType,
223       PayloadField::kFieldType,
224   };
225 
226   if (parent_ != nullptr) {
227     auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
228 
229     // Do not include constrained fields in the params
230     for (const auto& field : parent_params) {
231       if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
232         params.AppendField(field);
233       }
234     }
235   }
236   // Add our parameters.
237   return params.Merge(fields_.GetFieldsWithTypes(param_types));
238 }
239 
GenMembers(std::ostream & s) const240 void ParentDef::GenMembers(std::ostream& s) const {
241   // Add the parameter list.
242   for (int i = 0; i < fields_.size(); i++) {
243     if (fields_[i]->GenBuilderMember(s)) {
244       s << "_;";
245     }
246   }
247 }
248 
GenSize(std::ostream & s) const249 void ParentDef::GenSize(std::ostream& s) const {
250   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
251   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
252 
253   s << "protected:";
254   s << "size_t BitsOfHeader() const {";
255   s << "return 0";
256 
257   if (parent_ != nullptr) {
258     if (parent_->GetDefinitionType() == Type::PACKET) {
259       s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
260     } else {
261       s << " + " << parent_->name_ << "::BitsOfHeader() ";
262     }
263   }
264 
265   for (const auto& field : header_fields) {
266     s << " + " << field->GetBuilderSize();
267   }
268   s << ";";
269 
270   s << "}\n\n";
271 
272   s << "size_t BitsOfFooter() const {";
273   s << "return 0";
274   for (const auto& field : footer_fields) {
275     s << " + " << field->GetBuilderSize();
276   }
277 
278   if (parent_ != nullptr) {
279     if (parent_->GetDefinitionType() == Type::PACKET) {
280       s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
281     } else {
282       s << " + " << parent_->name_ << "::BitsOfFooter() ";
283     }
284   }
285   s << ";";
286   s << "}\n\n";
287 
288   if (fields_.HasPayload()) {
289     s << "size_t GetPayloadSize() const {";
290     s << "if (payload_ != nullptr) {return payload_->size();}";
291     s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
292     s << ";}\n\n";
293   }
294 
295   Size padded_size;
296   for (const auto& field : header_fields) {
297     if (field->GetFieldType() == PaddingField::kFieldType) {
298       if (!padded_size.empty()) {
299         ERROR() << "Only one padding field is allowed.  Second field: " << field->GetName();
300       }
301       padded_size = field->GetSize();
302     }
303   }
304 
305   s << "public:";
306   s << "virtual size_t size() const override {";
307   if (!padded_size.empty()) {
308     s << "return " << padded_size.bytes() << ";}";
309     s << "size_t unpadded_size() const {";
310   }
311   s << "return (BitsOfHeader() / 8)";
312   if (fields_.HasPayload()) {
313     s << "+ payload_->size()";
314   }
315   s << " + (BitsOfFooter() / 8);";
316   s << "}\n";
317 }
318 
GenSerialize(std::ostream & s) const319 void ParentDef::GenSerialize(std::ostream& s) const {
320   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
321   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
322 
323   s << "protected:";
324   s << "void SerializeHeader(BitInserter&";
325   if (parent_ != nullptr || header_fields.size() != 0) {
326     s << " i ";
327   }
328   s << ") const {";
329 
330   if (parent_ != nullptr) {
331     if (parent_->GetDefinitionType() == Type::PACKET) {
332       s << parent_->name_ << "Builder::SerializeHeader(i);";
333     } else {
334       s << parent_->name_ << "::SerializeHeader(i);";
335     }
336   }
337 
338   for (const auto& field : header_fields) {
339     if (field->GetFieldType() == SizeField::kFieldType) {
340       const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
341       const auto& sized_field = fields_.GetField(field_name);
342       if (sized_field == nullptr) {
343         ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
344       }
345       if (sized_field->GetFieldType() == PayloadField::kFieldType) {
346         s << "size_t payload_bytes = GetPayloadSize();";
347         std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
348         if (modifier != "") {
349           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
350           s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
351         }
352         s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
353         s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
354       } else {
355         if (sized_field->GetFieldType() != VectorField::kFieldType) {
356           ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
357         }
358         const auto& vector_name = field_name + "_";
359         const VectorField* vector = (VectorField*)sized_field;
360         s << "size_t " << vector_name + "bytes = 0;";
361         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
362           s << "for (auto elem : " << vector_name << ") {";
363           s << vector_name + "bytes += elem.size(); }";
364         } else {
365           s << vector_name + "bytes = ";
366           s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
367         }
368         std::string modifier = vector->GetSizeModifier();
369         if (modifier != "") {
370           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
371           s << vector_name << "bytes = ";
372           s << vector_name << "bytes + (" << modifier << ") / 8;";
373         }
374         s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
375         s << "insert(" << vector_name << "bytes, i, ";
376         s << field->GetSize().bits() << ");";
377       }
378     } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
379       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
380       const auto& started_field = fields_.GetField(field_name);
381       if (started_field == nullptr) {
382         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
383                      << ")";
384       }
385       s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
386       s << "shared_checksum_ptr->Initialize();";
387       s << "i.RegisterObserver(packet::ByteObserver(";
388       s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
389       s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
390     } else if (field->GetFieldType() == PaddingField::kFieldType) {
391       s << "ASSERT(unpadded_size() <= " << field->GetSize().bytes() << ");";
392       s << "size_t padding_bytes = ";
393       s << field->GetSize().bytes() << " - unpadded_size();";
394       s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
395     } else if (field->GetFieldType() == CountField::kFieldType) {
396       const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
397       s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
398     } else {
399       field->GenInserter(s);
400     }
401   }
402   s << "}\n\n";
403 
404   s << "void SerializeFooter(BitInserter&";
405   if (parent_ != nullptr || footer_fields.size() != 0) {
406     s << " i ";
407   }
408   s << ") const {";
409 
410   for (const auto& field : footer_fields) {
411     field->GenInserter(s);
412   }
413   if (parent_ != nullptr) {
414     if (parent_->GetDefinitionType() == Type::PACKET) {
415       s << parent_->name_ << "Builder::SerializeFooter(i);";
416     } else {
417       s << parent_->name_ << "::SerializeFooter(i);";
418     }
419   }
420   s << "}\n\n";
421 
422   s << "public:";
423   s << "virtual void Serialize(BitInserter& i) const override {";
424   s << "SerializeHeader(i);";
425   if (fields_.HasPayload()) {
426     s << "payload_->Serialize(i);";
427   }
428   s << "SerializeFooter(i);";
429 
430   s << "}\n";
431 }
432 
GenInstanceOf(std::ostream & s) const433 void ParentDef::GenInstanceOf(std::ostream& s) const {
434   if (parent_ != nullptr && parent_constraints_.size() > 0) {
435     s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
436     // Get the list of parent params.
437     FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
438         PayloadField::kFieldType,
439         BodyField::kFieldType,
440     });
441 
442     // Check if constrained parent fields are set to their correct values.
443     for (int i = 0; i < parent_params.size(); i++) {
444       const auto& field = parent_params[i];
445       const auto& constraint = parent_constraints_.find(field->GetName());
446       if (constraint != parent_constraints_.end()) {
447         s << "if (parent." << field->GetName() << "_ != ";
448         if (field->GetFieldType() == ScalarField::kFieldType) {
449           s << std::get<int64_t>(constraint->second) << ")";
450           s << "{ return false;}";
451         } else if (field->GetFieldType() == EnumField::kFieldType) {
452           s << std::get<std::string>(constraint->second) << ")";
453           s << "{ return false;}";
454         } else {
455           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
456         }
457       }
458     }
459     s << "return true;}";
460   }
461 }
462