1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <stdlib.h>
18
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24
25 #include <google/protobuf/compiler/code_generator.h>
26 #include <google/protobuf/compiler/plugin.h>
27 #include <google/protobuf/descriptor.h>
28 #include <google/protobuf/descriptor.pb.h>
29 #include <google/protobuf/io/printer.h>
30 #include <google/protobuf/io/zero_copy_stream.h>
31
32 #include "perfetto/ext/base/string_utils.h"
33
34 namespace protozero {
35 namespace {
36
37 using google::protobuf::Descriptor;
38 using google::protobuf::EnumDescriptor;
39 using google::protobuf::EnumValueDescriptor;
40 using google::protobuf::FieldDescriptor;
41 using google::protobuf::FileDescriptor;
42 using google::protobuf::compiler::GeneratorContext;
43 using google::protobuf::io::Printer;
44 using google::protobuf::io::ZeroCopyOutputStream;
45 using perfetto::base::SplitString;
46 using perfetto::base::StripChars;
47 using perfetto::base::StripPrefix;
48 using perfetto::base::StripSuffix;
49 using perfetto::base::ToUpper;
50 using perfetto::base::Uppercase;
51
// Highest field id handled by the generated decoders: GenerateDecoder() leaves
// fields with a larger number out of the decoder class.
// Keep this value in sync with ProtoDecoder::kMaxDecoderFieldId. If they go
// out of sync, pbzero.h files will stop compiling, hitting the at()
// static_assert. The value is duplicated here because sharing it is not worth
// an extra dependency.
constexpr int kMaxDecoderFieldId = 999;
56
// Terminates the process through abort() when |condition| is false.
// The check is unconditional: nothing in it depends on the build type.
void Assert(bool condition) {
  if (condition)
    return;
  abort();
}
61
62 struct FileDescriptorComp {
operator ()protozero::__anon02737a120111::FileDescriptorComp63 bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
64 int comp = lhs->name().compare(rhs->name());
65 Assert(comp != 0 || lhs == rhs);
66 return comp < 0;
67 }
68 };
69
70 struct DescriptorComp {
operator ()protozero::__anon02737a120111::DescriptorComp71 bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
72 int comp = lhs->full_name().compare(rhs->full_name());
73 Assert(comp != 0 || lhs == rhs);
74 return comp < 0;
75 }
76 };
77
78 struct EnumDescriptorComp {
operator ()protozero::__anon02737a120111::EnumDescriptorComp79 bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
80 int comp = lhs->full_name().compare(rhs->full_name());
81 Assert(comp != 0 || lhs == rhs);
82 return comp < 0;
83 }
84 };
85
ProtoStubName(const FileDescriptor * proto)86 inline std::string ProtoStubName(const FileDescriptor* proto) {
87 return StripSuffix(proto->name(), ".proto") + ".pbzero";
88 }
89
90 class GeneratorJob {
91 public:
GeneratorJob(const FileDescriptor * file,Printer * stub_h_printer)92 GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
93 : source_(file), stub_h_(stub_h_printer) {}
94
GenerateStubs()95 bool GenerateStubs() {
96 Preprocess();
97 GeneratePrologue();
98 for (const EnumDescriptor* enumeration : enums_)
99 GenerateEnumDescriptor(enumeration);
100 for (const Descriptor* message : messages_)
101 GenerateMessageDescriptor(message);
102 for (const auto& key_value : extensions_)
103 GenerateExtension(key_value.first, key_value.second);
104 GenerateEpilogue();
105 return error_.empty();
106 }
107
SetOption(const std::string & name,const std::string & value)108 void SetOption(const std::string& name, const std::string& value) {
109 if (name == "wrapper_namespace") {
110 wrapper_namespace_ = value;
111 } else {
112 Abort(std::string() + "Unknown plugin option '" + name + "'.");
113 }
114 }
115
116 // If generator fails to produce stubs for a particular proto definitions
117 // it finishes with undefined output and writes the first error occured.
GetFirstError() const118 const std::string& GetFirstError() const { return error_; }
119
120 private:
121 // Only the first error will be recorded.
Abort(const std::string & reason)122 void Abort(const std::string& reason) {
123 if (error_.empty())
124 error_ = reason;
125 }
126
127 // Get full name (including outer descriptors) of proto descriptor.
128 template <class T>
GetDescriptorName(const T * descriptor)129 inline std::string GetDescriptorName(const T* descriptor) {
130 if (!package_.empty()) {
131 return StripPrefix(descriptor->full_name(), package_ + ".");
132 } else {
133 return descriptor->full_name();
134 }
135 }
136
137 // Get C++ class name corresponding to proto descriptor.
138 // Nested names are splitted by underscores. Underscores in type names aren't
139 // prohibited but not recommended in order to avoid name collisions.
140 template <class T>
GetCppClassName(const T * descriptor,bool full=false)141 inline std::string GetCppClassName(const T* descriptor, bool full = false) {
142 std::string name = StripChars(GetDescriptorName(descriptor), ".", '_');
143 if (full)
144 name = full_namespace_prefix_ + name;
145 return name;
146 }
147
GetFieldNumberConstant(const FieldDescriptor * field)148 inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
149 std::string name = field->camelcase_name();
150 if (!name.empty()) {
151 name.at(0) = Uppercase(name.at(0));
152 name = "k" + name + "FieldNumber";
153 } else {
154 // Protoc allows fields like 'bool _ = 1'.
155 Abort("Empty field name in camel case notation.");
156 }
157 return name;
158 }
159
160 // Note: intentionally avoiding depending on protozero sources, as well as
161 // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)162 const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
163 switch (proto_type) {
164 case FieldDescriptor::TYPE_INT64:
165 case FieldDescriptor::TYPE_UINT64:
166 case FieldDescriptor::TYPE_INT32:
167 case FieldDescriptor::TYPE_BOOL:
168 case FieldDescriptor::TYPE_UINT32:
169 case FieldDescriptor::TYPE_ENUM:
170 case FieldDescriptor::TYPE_SINT32:
171 case FieldDescriptor::TYPE_SINT64:
172 return "::protozero::proto_utils::ProtoWireType::kVarInt";
173
174 case FieldDescriptor::TYPE_FIXED32:
175 case FieldDescriptor::TYPE_SFIXED32:
176 case FieldDescriptor::TYPE_FLOAT:
177 return "::protozero::proto_utils::ProtoWireType::kFixed32";
178
179 case FieldDescriptor::TYPE_FIXED64:
180 case FieldDescriptor::TYPE_SFIXED64:
181 case FieldDescriptor::TYPE_DOUBLE:
182 return "::protozero::proto_utils::ProtoWireType::kFixed64";
183
184 case FieldDescriptor::TYPE_STRING:
185 case FieldDescriptor::TYPE_MESSAGE:
186 case FieldDescriptor::TYPE_BYTES:
187 return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
188
189 case FieldDescriptor::TYPE_GROUP:
190 Abort("Groups not supported.");
191 }
192 Abort("Unrecognized FieldDescriptor::Type.");
193 return "";
194 }
195
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)196 const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
197 switch (proto_type) {
198 case FieldDescriptor::TYPE_INT64:
199 case FieldDescriptor::TYPE_UINT64:
200 case FieldDescriptor::TYPE_INT32:
201 case FieldDescriptor::TYPE_BOOL:
202 case FieldDescriptor::TYPE_UINT32:
203 case FieldDescriptor::TYPE_ENUM:
204 case FieldDescriptor::TYPE_SINT32:
205 case FieldDescriptor::TYPE_SINT64:
206 return "::protozero::PackedVarInt";
207
208 case FieldDescriptor::TYPE_FIXED32:
209 return "::protozero::PackedFixedSizeInt<uint32_t>";
210 case FieldDescriptor::TYPE_SFIXED32:
211 return "::protozero::PackedFixedSizeInt<int32_t>";
212 case FieldDescriptor::TYPE_FLOAT:
213 return "::protozero::PackedFixedSizeInt<float>";
214
215 case FieldDescriptor::TYPE_FIXED64:
216 return "::protozero::PackedFixedSizeInt<uint64_t>";
217 case FieldDescriptor::TYPE_SFIXED64:
218 return "::protozero::PackedFixedSizeInt<int64_t>";
219 case FieldDescriptor::TYPE_DOUBLE:
220 return "::protozero::PackedFixedSizeInt<double>";
221
222 case FieldDescriptor::TYPE_STRING:
223 case FieldDescriptor::TYPE_MESSAGE:
224 case FieldDescriptor::TYPE_BYTES:
225 case FieldDescriptor::TYPE_GROUP:
226 Abort("Unexpected FieldDescritor::Type.");
227 }
228 Abort("Unrecognized FieldDescriptor::Type.");
229 return "";
230 }
231
FieldToProtoSchemaType(const FieldDescriptor * field)232 const char* FieldToProtoSchemaType(const FieldDescriptor* field) {
233 switch (field->type()) {
234 case FieldDescriptor::TYPE_BOOL:
235 return "kBool";
236 case FieldDescriptor::TYPE_INT32:
237 return "kInt32";
238 case FieldDescriptor::TYPE_INT64:
239 return "kInt64";
240 case FieldDescriptor::TYPE_UINT32:
241 return "kUint32";
242 case FieldDescriptor::TYPE_UINT64:
243 return "kUint64";
244 case FieldDescriptor::TYPE_SINT32:
245 return "kSint32";
246 case FieldDescriptor::TYPE_SINT64:
247 return "kSint64";
248 case FieldDescriptor::TYPE_FIXED32:
249 return "kFixed32";
250 case FieldDescriptor::TYPE_FIXED64:
251 return "kFixed64";
252 case FieldDescriptor::TYPE_SFIXED32:
253 return "kSfixed32";
254 case FieldDescriptor::TYPE_SFIXED64:
255 return "kSfixed64";
256 case FieldDescriptor::TYPE_FLOAT:
257 return "kFloat";
258 case FieldDescriptor::TYPE_DOUBLE:
259 return "kDouble";
260 case FieldDescriptor::TYPE_ENUM:
261 return "kEnum";
262 case FieldDescriptor::TYPE_STRING:
263 return "kString";
264 case FieldDescriptor::TYPE_MESSAGE:
265 return "kMessage";
266 case FieldDescriptor::TYPE_BYTES:
267 return "kBytes";
268
269 case FieldDescriptor::TYPE_GROUP:
270 Abort("Groups not supported.");
271 return "";
272 }
273 Abort("Unrecognized FieldDescriptor::Type.");
274 return "";
275 }
276
FieldToCppTypeName(const FieldDescriptor * field)277 std::string FieldToCppTypeName(const FieldDescriptor* field) {
278 switch (field->type()) {
279 case FieldDescriptor::TYPE_BOOL:
280 return "bool";
281 case FieldDescriptor::TYPE_INT32:
282 return "int32_t";
283 case FieldDescriptor::TYPE_INT64:
284 return "int64_t";
285 case FieldDescriptor::TYPE_UINT32:
286 return "uint32_t";
287 case FieldDescriptor::TYPE_UINT64:
288 return "uint64_t";
289 case FieldDescriptor::TYPE_SINT32:
290 return "int32_t";
291 case FieldDescriptor::TYPE_SINT64:
292 return "int64_t";
293 case FieldDescriptor::TYPE_FIXED32:
294 return "uint32_t";
295 case FieldDescriptor::TYPE_FIXED64:
296 return "uint64_t";
297 case FieldDescriptor::TYPE_SFIXED32:
298 return "int32_t";
299 case FieldDescriptor::TYPE_SFIXED64:
300 return "int64_t";
301 case FieldDescriptor::TYPE_FLOAT:
302 return "float";
303 case FieldDescriptor::TYPE_DOUBLE:
304 return "double";
305 case FieldDescriptor::TYPE_ENUM:
306 return GetCppClassName(field->enum_type(), true);
307 case FieldDescriptor::TYPE_STRING:
308 case FieldDescriptor::TYPE_BYTES:
309 return "std::string";
310 case FieldDescriptor::TYPE_MESSAGE:
311 return GetCppClassName(field->message_type());
312 case FieldDescriptor::TYPE_GROUP:
313 Abort("Groups not supported.");
314 return "";
315 }
316 Abort("Unrecognized FieldDescriptor::Type.");
317 return "";
318 }
319
FieldToRepetitionType(const FieldDescriptor * field)320 const char* FieldToRepetitionType(const FieldDescriptor* field) {
321 if (!field->is_repeated())
322 return "kNotRepeated";
323 if (field->is_packed())
324 return "kRepeatedPacked";
325 return "kRepeatedNotPacked";
326 }
327
CollectDescriptors()328 void CollectDescriptors() {
329 // Collect message descriptors in DFS order.
330 std::vector<const Descriptor*> stack;
331 stack.reserve(static_cast<size_t>(source_->message_type_count()));
332 for (int i = 0; i < source_->message_type_count(); ++i)
333 stack.push_back(source_->message_type(i));
334
335 while (!stack.empty()) {
336 const Descriptor* message = stack.back();
337 stack.pop_back();
338
339 if (message->extension_count() > 0) {
340 if (message->field_count() > 0 || message->nested_type_count() > 0 ||
341 message->enum_type_count() > 0) {
342 Abort("message with extend blocks shouldn't contain anything else");
343 }
344
345 // Iterate over all fields in "extend" blocks.
346 for (int i = 0; i < message->extension_count(); ++i) {
347 const FieldDescriptor* extension = message->extension(i);
348
349 // Protoc plugin API does not group fields in "extend" blocks.
350 // As the support for extensions in protozero is limited, the code
351 // assumes that extend blocks are located inside a wrapper message and
352 // name of this message is used to group them.
353 std::string extension_name = extension->extension_scope()->name();
354 extensions_[extension_name].push_back(extension);
355 }
356 } else {
357 messages_.push_back(message);
358 for (int i = 0; i < message->nested_type_count(); ++i) {
359 stack.push_back(message->nested_type(i));
360 // Emit a forward declaration of nested message types, as the outer
361 // class will refer to them when creating type aliases.
362 referenced_messages_.insert(message->nested_type(i));
363 }
364 }
365 }
366
367 // Collect enums.
368 for (int i = 0; i < source_->enum_type_count(); ++i)
369 enums_.push_back(source_->enum_type(i));
370
371 if (source_->extension_count() > 0)
372 Abort("top-level extension blocks are not supported");
373
374 for (const Descriptor* message : messages_) {
375 for (int i = 0; i < message->enum_type_count(); ++i) {
376 enums_.push_back(message->enum_type(i));
377 }
378 }
379 }
380
CollectDependencies()381 void CollectDependencies() {
382 // Public import basically means that callers only need to import this
383 // proto in order to use the stuff publicly imported by this proto.
384 for (int i = 0; i < source_->public_dependency_count(); ++i)
385 public_imports_.insert(source_->public_dependency(i));
386
387 if (source_->weak_dependency_count() > 0)
388 Abort("Weak imports are not supported.");
389
390 // Validations. Collect public imports (of collected imports) in DFS order.
391 // Visibilty for current proto:
392 // - all imports listed in current proto,
393 // - public imports of everything imported (recursive).
394 std::vector<const FileDescriptor*> stack;
395 for (int i = 0; i < source_->dependency_count(); ++i) {
396 const FileDescriptor* import = source_->dependency(i);
397 stack.push_back(import);
398 if (public_imports_.count(import) == 0) {
399 private_imports_.insert(import);
400 }
401 }
402
403 while (!stack.empty()) {
404 const FileDescriptor* import = stack.back();
405 stack.pop_back();
406 // Having imports under different packages leads to unnecessary
407 // complexity with namespaces.
408 if (import->package() != package_)
409 Abort("Imported proto must be in the same package.");
410
411 for (int i = 0; i < import->public_dependency_count(); ++i) {
412 stack.push_back(import->public_dependency(i));
413 }
414 }
415
416 // Collect descriptors of messages and enums used in current proto.
417 // It will be used to generate necessary forward declarations and
418 // check that everything lays in the same namespace.
419 for (const Descriptor* message : messages_) {
420 for (int i = 0; i < message->field_count(); ++i) {
421 const FieldDescriptor* field = message->field(i);
422
423 if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
424 if (public_imports_.count(field->message_type()->file()) == 0) {
425 // Avoid multiple forward declarations since
426 // public imports have been already included.
427 referenced_messages_.insert(field->message_type());
428 }
429 } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
430 if (public_imports_.count(field->enum_type()->file()) == 0) {
431 referenced_enums_.insert(field->enum_type());
432 }
433 }
434 }
435 }
436 }
437
Preprocess()438 void Preprocess() {
439 // Package name maps to a series of namespaces.
440 package_ = source_->package();
441 namespaces_ = SplitString(package_, ".");
442 if (!wrapper_namespace_.empty())
443 namespaces_.push_back(wrapper_namespace_);
444
445 full_namespace_prefix_ = "::";
446 for (const std::string& ns : namespaces_)
447 full_namespace_prefix_ += ns + "::";
448
449 CollectDescriptors();
450 CollectDependencies();
451 }
452
GetNamespaceNameForInnerEnum(const EnumDescriptor * enumeration)453 std::string GetNamespaceNameForInnerEnum(const EnumDescriptor* enumeration) {
454 return "perfetto_pbzero_enum_" +
455 GetCppClassName(enumeration->containing_type());
456 }
457
458 // Print top header, namespaces and forward declarations.
GeneratePrologue()459 void GeneratePrologue() {
460 std::string greeting =
461 "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
462 std::string guard = package_ + "_" + source_->name() + "_H_";
463 guard = ToUpper(guard);
464 guard = StripChars(guard, ".-/\\", '_');
465
466 stub_h_->Print(
467 "$greeting$\n"
468 "#ifndef $guard$\n"
469 "#define $guard$\n\n"
470 "#include <stddef.h>\n"
471 "#include <stdint.h>\n\n"
472 "#include \"perfetto/protozero/field_writer.h\"\n"
473 "#include \"perfetto/protozero/message.h\"\n"
474 "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
475 "#include \"perfetto/protozero/proto_decoder.h\"\n"
476 "#include \"perfetto/protozero/proto_utils.h\"\n",
477 "greeting", greeting, "guard", guard);
478
479 // Print includes for public imports.
480 for (const FileDescriptor* dependency : public_imports_) {
481 // Dependency name could contain slashes but importing from upper-level
482 // directories is not possible anyway since build system processes each
483 // proto file individually. Hence proto lookup path is always equal to the
484 // directory where particular proto file is located and protoc does not
485 // allow reference to upper directory (aka ..) in import path.
486 //
487 // Laconically said:
488 // - source_->name() may never have slashes,
489 // - dependency->name() may have slashes but always refers to inner path.
490 stub_h_->Print("#include \"$name$.h\"\n", "name",
491 ProtoStubName(dependency));
492 }
493 stub_h_->Print("\n");
494
495 // Print namespaces.
496 for (const std::string& ns : namespaces_) {
497 stub_h_->Print("namespace $ns$ {\n", "ns", ns);
498 }
499 stub_h_->Print("\n");
500
501 // Print forward declarations.
502 for (const Descriptor* message : referenced_messages_) {
503 stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
504 }
505 for (const EnumDescriptor* enumeration : referenced_enums_) {
506 if (enumeration->containing_type()) {
507 stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
508 GetNamespaceNameForInnerEnum(enumeration));
509 }
510 stub_h_->Print("enum $class$ : int32_t;\n", "class", enumeration->name());
511
512 if (enumeration->containing_type()) {
513 stub_h_->Print("} // namespace $namespace_name$\n", "namespace_name",
514 GetNamespaceNameForInnerEnum(enumeration));
515 stub_h_->Print("using $alias$ = $namespace_name$::$short_name$;\n",
516 "alias", GetCppClassName(enumeration), "namespace_name",
517 GetNamespaceNameForInnerEnum(enumeration), "short_name",
518 enumeration->name());
519 }
520 }
521 stub_h_->Print("\n");
522 }
523
GenerateEnumDescriptor(const EnumDescriptor * enumeration)524 void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
525 bool is_inner_enum = !!enumeration->containing_type();
526 if (is_inner_enum) {
527 stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
528 GetNamespaceNameForInnerEnum(enumeration));
529 }
530
531 stub_h_->Print("enum $class$ : int32_t {\n", "class", enumeration->name());
532 stub_h_->Indent();
533
534 std::string min_name, max_name;
535 int min_val = std::numeric_limits<int>::max();
536 int max_val = -1;
537 for (int i = 0; i < enumeration->value_count(); ++i) {
538 const EnumValueDescriptor* value = enumeration->value(i);
539 const std::string value_name = value->name();
540 stub_h_->Print("$name$ = $number$,\n", "name", value_name, "number",
541 std::to_string(value->number()));
542 if (value->number() < min_val) {
543 min_val = value->number();
544 min_name = value_name;
545 }
546 if (value->number() > max_val) {
547 max_val = value->number();
548 max_name = value_name;
549 }
550 }
551 stub_h_->Outdent();
552 stub_h_->Print("};\n");
553 if (is_inner_enum) {
554 const std::string namespace_name =
555 GetNamespaceNameForInnerEnum(enumeration);
556 stub_h_->Print("} // namespace $namespace_name$\n", "namespace_name",
557 namespace_name);
558 stub_h_->Print(
559 "using $full_enum_name$ = $namespace_name$::$enum_name$;\n\n",
560 "full_enum_name", GetCppClassName(enumeration), "enum_name",
561 enumeration->name(), "namespace_name", namespace_name);
562 }
563 stub_h_->Print("\n");
564 stub_h_->Print("constexpr $class$ $class$_MIN = $class$::$min$;\n", "class",
565 GetCppClassName(enumeration), "min", min_name);
566 stub_h_->Print("constexpr $class$ $class$_MAX = $class$::$max$;\n", "class",
567 GetCppClassName(enumeration), "max", max_name);
568 stub_h_->Print("\n");
569
570 GenerateEnumToStringConversion(enumeration);
571 }
572
GenerateEnumToStringConversion(const EnumDescriptor * enumeration)573 void GenerateEnumToStringConversion(const EnumDescriptor* enumeration) {
574 std::string fullClassName =
575 full_namespace_prefix_ + GetCppClassName(enumeration);
576 const char* function_header_stub = R"(
577 PERFETTO_PROTOZERO_CONSTEXPR14_OR_INLINE
578 const char* $class_name$_Name($full_class$ value) {
579 )";
580 stub_h_->Print(function_header_stub, "full_class", fullClassName,
581 "class_name", GetCppClassName(enumeration));
582 stub_h_->Indent();
583 stub_h_->Print("switch (value) {");
584 for (int index = 0; index < enumeration->value_count(); ++index) {
585 const EnumValueDescriptor* value = enumeration->value(index);
586 const char* switch_stub = R"(
587 case $full_class$::$value_name$:
588 return "$value_name$";
589 )";
590 stub_h_->Print(switch_stub, "full_class", fullClassName, "value_name",
591 value->name());
592 }
593 stub_h_->Print("}\n");
594 stub_h_->Print(R"(return "PBZERO_UNKNOWN_ENUM_VALUE";)");
595 stub_h_->Print("\n");
596 stub_h_->Outdent();
597 stub_h_->Print("}\n\n");
598 }
599
600 // Packed repeated fields are encoded as a length-delimited field on the wire,
601 // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)602 void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
603 std::map<std::string, std::string> setter;
604 setter["name"] = field->lowercase_name();
605 setter["field_metadata"] = GetFieldMetadataTypeName(field);
606 setter["action"] = "set";
607 setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
608 stub_h_->Print(
609 setter,
610 "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
611 " AppendBytes($field_metadata$::kFieldId, packed_buffer.data(),\n"
612 " packed_buffer.size());\n"
613 "}\n");
614 }
615
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)616 void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
617 std::map<std::string, std::string> setter;
618 setter["id"] = std::to_string(field->number());
619 setter["name"] = field->lowercase_name();
620 setter["field_metadata"] = GetFieldMetadataTypeName(field);
621 setter["action"] = field->is_repeated() ? "add" : "set";
622 setter["cpp_type"] = FieldToCppTypeName(field);
623 setter["proto_field_type"] = FieldToProtoSchemaType(field);
624
625 const char* code_stub =
626 "void $action$_$name$($cpp_type$ value) {\n"
627 " static constexpr uint32_t field_id = $field_metadata$::kFieldId;\n"
628 " // Call the appropriate protozero::Message::Append(field_id, ...)\n"
629 " // method based on the type of the field.\n"
630 " ::protozero::internal::FieldWriter<\n"
631 " ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$>\n"
632 " ::Append(*this, field_id, value);\n"
633 "}\n";
634
635 if (field->type() == FieldDescriptor::TYPE_STRING) {
636 // Strings and bytes should have an additional accessor which specifies
637 // the length explicitly.
638 const char* additional_method =
639 "void $action$_$name$(const char* data, size_t size) {\n"
640 " AppendBytes($field_metadata$::kFieldId, data, size);\n"
641 "}\n"
642 "void $action$_$name$(::protozero::ConstChars chars) {\n"
643 " AppendBytes($field_metadata$::kFieldId, chars.data, chars.size);\n"
644 "}\n";
645 stub_h_->Print(setter, additional_method);
646 } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
647 const char* additional_method =
648 "void $action$_$name$(const uint8_t* data, size_t size) {\n"
649 " AppendBytes($field_metadata$::kFieldId, data, size);\n"
650 "}\n"
651 "void $action$_$name$(::protozero::ConstBytes bytes) {\n"
652 " AppendBytes($field_metadata$::kFieldId, bytes.data, bytes.size);\n"
653 "}\n";
654 stub_h_->Print(setter, additional_method);
655 } else if (field->type() == FieldDescriptor::TYPE_GROUP ||
656 field->type() == FieldDescriptor::TYPE_MESSAGE) {
657 Abort("Unsupported field type.");
658 return;
659 }
660
661 stub_h_->Print(setter, code_stub);
662 }
663
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)664 void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
665 std::string action = field->is_repeated() ? "add" : "set";
666 std::string inner_class = GetCppClassName(field->message_type());
667 stub_h_->Print(
668 "template <typename T = $inner_class$> T* $action$_$name$() {\n"
669 " return BeginNestedMessage<T>($id$);\n"
670 "}\n\n",
671 "id", std::to_string(field->number()), "name", field->lowercase_name(),
672 "action", action, "inner_class", inner_class);
673 if (field->options().lazy()) {
674 stub_h_->Print(
675 "void $action$_$name$_raw(const std::string& raw) {\n"
676 " return AppendBytes($id$, raw.data(), raw.size());\n"
677 "}\n\n",
678 "id", std::to_string(field->number()), "name",
679 field->lowercase_name(), "action", action, "inner_class",
680 inner_class);
681 }
682 }
683
GenerateDecoder(const Descriptor * message)684 void GenerateDecoder(const Descriptor* message) {
685 int max_field_id = 0;
686 bool has_nonpacked_repeated_fields = false;
687 for (int i = 0; i < message->field_count(); ++i) {
688 const FieldDescriptor* field = message->field(i);
689 if (field->number() > kMaxDecoderFieldId)
690 continue;
691 max_field_id = std::max(max_field_id, field->number());
692 if (field->is_repeated() && !field->is_packed())
693 has_nonpacked_repeated_fields = true;
694 }
695
696 std::string class_name = GetCppClassName(message) + "_Decoder";
697 stub_h_->Print(
698 "class $name$ : public "
699 "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
700 "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
701 "name", class_name, "max", std::to_string(max_field_id), "rep",
702 has_nonpacked_repeated_fields ? "true" : "false");
703 stub_h_->Print(" public:\n");
704 stub_h_->Indent();
705 stub_h_->Print(
706 "$name$(const uint8_t* data, size_t len) "
707 ": TypedProtoDecoder(data, len) {}\n",
708 "name", class_name);
709 stub_h_->Print(
710 "explicit $name$(const std::string& raw) : "
711 "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
712 "raw.size()) {}\n",
713 "name", class_name);
714 stub_h_->Print(
715 "explicit $name$(const ::protozero::ConstBytes& raw) : "
716 "TypedProtoDecoder(raw.data, raw.size) {}\n",
717 "name", class_name);
718
719 for (int i = 0; i < message->field_count(); ++i) {
720 const FieldDescriptor* field = message->field(i);
721 if (field->number() > max_field_id) {
722 stub_h_->Print("// field $name$ omitted because its id is too high\n",
723 "name", field->name());
724 continue;
725 }
726 std::string getter;
727 std::string cpp_type;
728 switch (field->type()) {
729 case FieldDescriptor::TYPE_BOOL:
730 getter = "as_bool";
731 cpp_type = "bool";
732 break;
733 case FieldDescriptor::TYPE_SFIXED32:
734 case FieldDescriptor::TYPE_SINT32:
735 case FieldDescriptor::TYPE_INT32:
736 getter = "as_int32";
737 cpp_type = "int32_t";
738 break;
739 case FieldDescriptor::TYPE_SFIXED64:
740 case FieldDescriptor::TYPE_SINT64:
741 case FieldDescriptor::TYPE_INT64:
742 getter = "as_int64";
743 cpp_type = "int64_t";
744 break;
745 case FieldDescriptor::TYPE_FIXED32:
746 case FieldDescriptor::TYPE_UINT32:
747 getter = "as_uint32";
748 cpp_type = "uint32_t";
749 break;
750 case FieldDescriptor::TYPE_FIXED64:
751 case FieldDescriptor::TYPE_UINT64:
752 getter = "as_uint64";
753 cpp_type = "uint64_t";
754 break;
755 case FieldDescriptor::TYPE_FLOAT:
756 getter = "as_float";
757 cpp_type = "float";
758 break;
759 case FieldDescriptor::TYPE_DOUBLE:
760 getter = "as_double";
761 cpp_type = "double";
762 break;
763 case FieldDescriptor::TYPE_ENUM:
764 getter = "as_int32";
765 cpp_type = "int32_t";
766 break;
767 case FieldDescriptor::TYPE_STRING:
768 getter = "as_string";
769 cpp_type = "::protozero::ConstChars";
770 break;
771 case FieldDescriptor::TYPE_MESSAGE:
772 case FieldDescriptor::TYPE_BYTES:
773 getter = "as_bytes";
774 cpp_type = "::protozero::ConstBytes";
775 break;
776 case FieldDescriptor::TYPE_GROUP:
777 continue;
778 }
779
780 stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
781 "name", field->lowercase_name(), "id",
782 std::to_string(field->number()));
783
784 if (field->is_packed()) {
785 const char* protozero_wire_type =
786 FieldTypeToProtozeroWireType(field->type());
787 stub_h_->Print(
788 "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
789 "$name$(bool* parse_error_ptr) const { return "
790 "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
791 "parse_error_ptr); }\n",
792 "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
793 field->lowercase_name(), "id", std::to_string(field->number()));
794 } else if (field->is_repeated()) {
795 stub_h_->Print(
796 "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
797 "return "
798 "GetRepeated<$cpp_type$>($id$); }\n",
799 "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
800 std::to_string(field->number()));
801 } else {
802 stub_h_->Print(
803 "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
804 "name", field->lowercase_name(), "id",
805 std::to_string(field->number()), "cpp_type", cpp_type, "getter",
806 getter);
807 }
808 }
809 stub_h_->Outdent();
810 stub_h_->Print("};\n\n");
811 }
812
GenerateConstantsForMessageFields(const Descriptor * message)813 void GenerateConstantsForMessageFields(const Descriptor* message) {
814 const bool has_fields = (message->field_count() > 0);
815
816 // Field number constants.
817 if (has_fields) {
818 stub_h_->Print("enum : int32_t {\n");
819 stub_h_->Indent();
820
821 for (int i = 0; i < message->field_count(); ++i) {
822 const FieldDescriptor* field = message->field(i);
823 stub_h_->Print("$name$ = $id$,\n", "name",
824 GetFieldNumberConstant(field), "id",
825 std::to_string(field->number()));
826 }
827 stub_h_->Outdent();
828 stub_h_->Print("};\n");
829 }
830 }
831
GenerateMessageDescriptor(const Descriptor * message)832 void GenerateMessageDescriptor(const Descriptor* message) {
833 GenerateDecoder(message);
834
835 stub_h_->Print(
836 "class $name$ : public ::protozero::Message {\n"
837 " public:\n",
838 "name", GetCppClassName(message));
839 stub_h_->Indent();
840
841 stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
842 GetCppClassName(message));
843
844 GenerateConstantsForMessageFields(message);
845
846 stub_h_->Print(
847 "static constexpr const char* GetName() { return \".$name$\"; }\n\n",
848 "name", message->full_name());
849
850 // Using statements for nested messages.
851 for (int i = 0; i < message->nested_type_count(); ++i) {
852 const Descriptor* nested_message = message->nested_type(i);
853 stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
854 nested_message->name(), "global_name",
855 GetCppClassName(nested_message, true));
856 }
857
858 // Using statements for nested enums.
859 for (int i = 0; i < message->enum_type_count(); ++i) {
860 const EnumDescriptor* nested_enum = message->enum_type(i);
861 const char* stub = R"(
862 using $local_name$ = $global_name$;
863 static inline const char* $local_name$_Name($local_name$ value) {
864 return $global_name$_Name(value);
865 }
866 )";
867 stub_h_->Print(stub, "local_name", nested_enum->name(), "global_name",
868 GetCppClassName(nested_enum, true));
869 }
870
871 // Values of nested enums.
872 for (int i = 0; i < message->enum_type_count(); ++i) {
873 const EnumDescriptor* nested_enum = message->enum_type(i);
874
875 for (int j = 0; j < nested_enum->value_count(); ++j) {
876 const EnumValueDescriptor* value = nested_enum->value(j);
877 stub_h_->Print("static const $class$ $name$ = $class$::$name$;\n",
878 "class", nested_enum->name(), "name", value->name());
879 }
880 }
881
882 // Field descriptors.
883 for (int i = 0; i < message->field_count(); ++i) {
884 GenerateFieldDescriptor(GetCppClassName(message), message->field(i));
885 }
886
887 stub_h_->Outdent();
888 stub_h_->Print("};\n\n");
889 }
890
GetFieldMetadataTypeName(const FieldDescriptor * field)891 std::string GetFieldMetadataTypeName(const FieldDescriptor* field) {
892 std::string name = field->camelcase_name();
893 if (isalpha(name[0]))
894 name[0] = static_cast<char>(toupper(name[0]));
895 return "FieldMetadata_" + name;
896 }
897
GetFieldMetadataVariableName(const FieldDescriptor * field)898 std::string GetFieldMetadataVariableName(const FieldDescriptor* field) {
899 std::string name = field->camelcase_name();
900 if (isalpha(name[0]))
901 name[0] = static_cast<char>(toupper(name[0]));
902 return "k" + name;
903 }
904
GenerateFieldMetadata(const std::string & message_cpp_type,const FieldDescriptor * field)905 void GenerateFieldMetadata(const std::string& message_cpp_type,
906 const FieldDescriptor* field) {
907 const char* code_stub = R"(
908 using $field_metadata_type$ =
909 ::protozero::proto_utils::FieldMetadata<
910 $field_id$,
911 ::protozero::proto_utils::RepetitionType::$repetition_type$,
912 ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$,
913 $cpp_type$,
914 $message_cpp_type$>;
915
916 static constexpr $field_metadata_type$ $field_metadata_var${};
917 )";
918
919 stub_h_->Print(code_stub, "field_id", std::to_string(field->number()),
920 "repetition_type", FieldToRepetitionType(field),
921 "proto_field_type", FieldToProtoSchemaType(field),
922 "cpp_type", FieldToCppTypeName(field), "message_cpp_type",
923 message_cpp_type, "field_metadata_type",
924 GetFieldMetadataTypeName(field), "field_metadata_var",
925 GetFieldMetadataVariableName(field));
926 }
927
GenerateFieldDescriptor(const std::string & message_cpp_type,const FieldDescriptor * field)928 void GenerateFieldDescriptor(const std::string& message_cpp_type,
929 const FieldDescriptor* field) {
930 GenerateFieldMetadata(message_cpp_type, field);
931 if (field->is_packed()) {
932 GeneratePackedRepeatedFieldDescriptor(field);
933 } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
934 GenerateSimpleFieldDescriptor(field);
935 } else {
936 GenerateNestedMessageFieldDescriptor(field);
937 }
938 }
939
940 // Generate extension class for a group of FieldDescriptor instances
941 // representing one "extend" block in proto definition. For example:
942 //
943 // message SpecificExtension {
944 // extend GeneralThing {
945 // optional Fizz fizz = 101;
946 // optional Buzz buzz = 102;
947 // }
948 // }
949 //
950 // This is going to be passed as a vector of two elements, "fizz" and
951 // "buzz". Wrapping message is used to provide a name for generated
952 // extension class.
953 //
954 // In the example above, generated code is going to look like:
955 //
956 // class SpecificExtension : public GeneralThing {
957 // Fizz* set_fizz();
958 // Buzz* set_buzz();
959 // }
GenerateExtension(const std::string & extension_name,const std::vector<const FieldDescriptor * > & descriptors)960 void GenerateExtension(
961 const std::string& extension_name,
962 const std::vector<const FieldDescriptor*>& descriptors) {
963 // Use an arbitrary descriptor in order to get generic information not
964 // specific to any of them.
965 const FieldDescriptor* descriptor = descriptors[0];
966 const Descriptor* base_message = descriptor->containing_type();
967
968 // TODO(ddrone): ensure that this code works when containing_type located in
969 // other file or namespace.
970 stub_h_->Print("class $name$ : public $extendee$ {\n", "name",
971 extension_name, "extendee",
972 GetCppClassName(base_message, /*full=*/true));
973 stub_h_->Print(" public:\n");
974 stub_h_->Indent();
975 for (const FieldDescriptor* field : descriptors) {
976 if (field->containing_type() != base_message) {
977 Abort("one wrapper should extend only one message");
978 return;
979 }
980 GenerateFieldDescriptor(extension_name, field);
981 }
982 stub_h_->Outdent();
983 stub_h_->Print("};\n");
984 }
985
GenerateEpilogue()986 void GenerateEpilogue() {
987 for (unsigned i = 0; i < namespaces_.size(); ++i) {
988 stub_h_->Print("} // Namespace.\n");
989 }
990 stub_h_->Print("#endif // Include guard.\n");
991 }
992
993 const FileDescriptor* const source_;
994 Printer* const stub_h_;
995 std::string error_;
996
997 std::string package_;
998 std::string wrapper_namespace_;
999 std::vector<std::string> namespaces_;
1000 std::string full_namespace_prefix_;
1001 std::vector<const Descriptor*> messages_;
1002 std::vector<const EnumDescriptor*> enums_;
1003 std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;
1004
1005 // The custom *Comp comparators are to ensure determinism of the generator.
1006 std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
1007 std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
1008 std::set<const Descriptor*, DescriptorComp> referenced_messages_;
1009 std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
1010 };
1011
1012 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
1013 public:
1014 explicit ProtoZeroGenerator();
1015 ~ProtoZeroGenerator() override;
1016
1017 // CodeGenerator implementation
1018 bool Generate(const google::protobuf::FileDescriptor* file,
1019 const std::string& options,
1020 GeneratorContext* context,
1021 std::string* error) const override;
1022 };
1023
ProtoZeroGenerator()1024 ProtoZeroGenerator::ProtoZeroGenerator() {}
1025
~ProtoZeroGenerator()1026 ProtoZeroGenerator::~ProtoZeroGenerator() {}
1027
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const1028 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
1029 const std::string& options,
1030 GeneratorContext* context,
1031 std::string* error) const {
1032 const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
1033 context->Open(ProtoStubName(file) + ".h"));
1034 const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
1035 context->Open(ProtoStubName(file) + ".cc"));
1036
1037 // Variables are delimited by $.
1038 Printer stub_h_printer(stub_h_file_stream.get(), '$');
1039 GeneratorJob job(file, &stub_h_printer);
1040
1041 Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
1042 stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
1043
1044 // Parse additional options.
1045 for (const std::string& option : SplitString(options, ",")) {
1046 std::vector<std::string> option_pair = SplitString(option, "=");
1047 job.SetOption(option_pair[0], option_pair[1]);
1048 }
1049
1050 if (!job.GenerateStubs()) {
1051 *error = job.GetFirstError();
1052 return false;
1053 }
1054 return true;
1055 }
1056
1057 } // namespace
1058 } // namespace protozero
1059
main(int argc,char * argv[])1060 int main(int argc, char* argv[]) {
1061 ::protozero::ProtoZeroGenerator generator;
1062 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
1063 }
1064