• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdlib.h>
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 
25 #include <google/protobuf/compiler/code_generator.h>
26 #include <google/protobuf/compiler/plugin.h>
27 #include <google/protobuf/descriptor.h>
28 #include <google/protobuf/descriptor.pb.h>
29 #include <google/protobuf/io/printer.h>
30 #include <google/protobuf/io/zero_copy_stream.h>
31 
32 #include "perfetto/ext/base/string_utils.h"
33 
34 namespace protozero {
35 namespace {
36 
37 using google::protobuf::Descriptor;
38 using google::protobuf::EnumDescriptor;
39 using google::protobuf::EnumValueDescriptor;
40 using google::protobuf::FieldDescriptor;
41 using google::protobuf::FileDescriptor;
42 using google::protobuf::compiler::GeneratorContext;
43 using google::protobuf::io::Printer;
44 using google::protobuf::io::ZeroCopyOutputStream;
45 using perfetto::base::SplitString;
46 using perfetto::base::StripChars;
47 using perfetto::base::StripPrefix;
48 using perfetto::base::StripSuffix;
49 using perfetto::base::ToUpper;
50 using perfetto::base::Uppercase;
51 
// Highest field id that gets typed accessors in the generated *_Decoder
// classes; GenerateDecoder() ignores fields numbered above it when computing
// the decoder's MAX_FIELD_ID. It must stay equal to
// ProtoDecoder::kMaxDecoderFieldId: if the two diverge, the generated pbzero.h
// files stop compiling because they hit the static_assert inside at().
// Duplicated on purpose rather than adding a dependency on protozero.
constexpr int kMaxDecoderFieldId{999};
56 
// Terminates the plugin immediately via abort() when |condition| is false.
// Unlike GeneratorJob::Abort(), which only records an error message and lets
// generation continue, this is for invariants whose violation leaves no sane
// way to proceed.
void Assert(bool condition) {
  if (condition)
    return;
  abort();
}
61 
62 struct FileDescriptorComp {
operator ()protozero::__anon02737a120111::FileDescriptorComp63   bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
64     int comp = lhs->name().compare(rhs->name());
65     Assert(comp != 0 || lhs == rhs);
66     return comp < 0;
67   }
68 };
69 
70 struct DescriptorComp {
operator ()protozero::__anon02737a120111::DescriptorComp71   bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
72     int comp = lhs->full_name().compare(rhs->full_name());
73     Assert(comp != 0 || lhs == rhs);
74     return comp < 0;
75   }
76 };
77 
78 struct EnumDescriptorComp {
operator ()protozero::__anon02737a120111::EnumDescriptorComp79   bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
80     int comp = lhs->full_name().compare(rhs->full_name());
81     Assert(comp != 0 || lhs == rhs);
82     return comp < 0;
83   }
84 };
85 
ProtoStubName(const FileDescriptor * proto)86 inline std::string ProtoStubName(const FileDescriptor* proto) {
87   return StripSuffix(proto->name(), ".proto") + ".pbzero";
88 }
89 
90 class GeneratorJob {
91  public:
GeneratorJob(const FileDescriptor * file,Printer * stub_h_printer)92   GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
93       : source_(file), stub_h_(stub_h_printer) {}
94 
GenerateStubs()95   bool GenerateStubs() {
96     Preprocess();
97     GeneratePrologue();
98     for (const EnumDescriptor* enumeration : enums_)
99       GenerateEnumDescriptor(enumeration);
100     for (const Descriptor* message : messages_)
101       GenerateMessageDescriptor(message);
102     for (const auto& key_value : extensions_)
103       GenerateExtension(key_value.first, key_value.second);
104     GenerateEpilogue();
105     return error_.empty();
106   }
107 
SetOption(const std::string & name,const std::string & value)108   void SetOption(const std::string& name, const std::string& value) {
109     if (name == "wrapper_namespace") {
110       wrapper_namespace_ = value;
111     } else {
112       Abort(std::string() + "Unknown plugin option '" + name + "'.");
113     }
114   }
115 
116   // If generator fails to produce stubs for a particular proto definitions
117   // it finishes with undefined output and writes the first error occured.
GetFirstError() const118   const std::string& GetFirstError() const { return error_; }
119 
120  private:
121   // Only the first error will be recorded.
Abort(const std::string & reason)122   void Abort(const std::string& reason) {
123     if (error_.empty())
124       error_ = reason;
125   }
126 
127   // Get full name (including outer descriptors) of proto descriptor.
128   template <class T>
GetDescriptorName(const T * descriptor)129   inline std::string GetDescriptorName(const T* descriptor) {
130     if (!package_.empty()) {
131       return StripPrefix(descriptor->full_name(), package_ + ".");
132     } else {
133       return descriptor->full_name();
134     }
135   }
136 
137   // Get C++ class name corresponding to proto descriptor.
138   // Nested names are splitted by underscores. Underscores in type names aren't
139   // prohibited but not recommended in order to avoid name collisions.
140   template <class T>
GetCppClassName(const T * descriptor,bool full=false)141   inline std::string GetCppClassName(const T* descriptor, bool full = false) {
142     std::string name = StripChars(GetDescriptorName(descriptor), ".", '_');
143     if (full)
144       name = full_namespace_prefix_ + name;
145     return name;
146   }
147 
GetFieldNumberConstant(const FieldDescriptor * field)148   inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
149     std::string name = field->camelcase_name();
150     if (!name.empty()) {
151       name.at(0) = Uppercase(name.at(0));
152       name = "k" + name + "FieldNumber";
153     } else {
154       // Protoc allows fields like 'bool _ = 1'.
155       Abort("Empty field name in camel case notation.");
156     }
157     return name;
158   }
159 
160   // Note: intentionally avoiding depending on protozero sources, as well as
161   // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)162   const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
163     switch (proto_type) {
164       case FieldDescriptor::TYPE_INT64:
165       case FieldDescriptor::TYPE_UINT64:
166       case FieldDescriptor::TYPE_INT32:
167       case FieldDescriptor::TYPE_BOOL:
168       case FieldDescriptor::TYPE_UINT32:
169       case FieldDescriptor::TYPE_ENUM:
170       case FieldDescriptor::TYPE_SINT32:
171       case FieldDescriptor::TYPE_SINT64:
172         return "::protozero::proto_utils::ProtoWireType::kVarInt";
173 
174       case FieldDescriptor::TYPE_FIXED32:
175       case FieldDescriptor::TYPE_SFIXED32:
176       case FieldDescriptor::TYPE_FLOAT:
177         return "::protozero::proto_utils::ProtoWireType::kFixed32";
178 
179       case FieldDescriptor::TYPE_FIXED64:
180       case FieldDescriptor::TYPE_SFIXED64:
181       case FieldDescriptor::TYPE_DOUBLE:
182         return "::protozero::proto_utils::ProtoWireType::kFixed64";
183 
184       case FieldDescriptor::TYPE_STRING:
185       case FieldDescriptor::TYPE_MESSAGE:
186       case FieldDescriptor::TYPE_BYTES:
187         return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
188 
189       case FieldDescriptor::TYPE_GROUP:
190         Abort("Groups not supported.");
191     }
192     Abort("Unrecognized FieldDescriptor::Type.");
193     return "";
194   }
195 
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)196   const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
197     switch (proto_type) {
198       case FieldDescriptor::TYPE_INT64:
199       case FieldDescriptor::TYPE_UINT64:
200       case FieldDescriptor::TYPE_INT32:
201       case FieldDescriptor::TYPE_BOOL:
202       case FieldDescriptor::TYPE_UINT32:
203       case FieldDescriptor::TYPE_ENUM:
204       case FieldDescriptor::TYPE_SINT32:
205       case FieldDescriptor::TYPE_SINT64:
206         return "::protozero::PackedVarInt";
207 
208       case FieldDescriptor::TYPE_FIXED32:
209         return "::protozero::PackedFixedSizeInt<uint32_t>";
210       case FieldDescriptor::TYPE_SFIXED32:
211         return "::protozero::PackedFixedSizeInt<int32_t>";
212       case FieldDescriptor::TYPE_FLOAT:
213         return "::protozero::PackedFixedSizeInt<float>";
214 
215       case FieldDescriptor::TYPE_FIXED64:
216         return "::protozero::PackedFixedSizeInt<uint64_t>";
217       case FieldDescriptor::TYPE_SFIXED64:
218         return "::protozero::PackedFixedSizeInt<int64_t>";
219       case FieldDescriptor::TYPE_DOUBLE:
220         return "::protozero::PackedFixedSizeInt<double>";
221 
222       case FieldDescriptor::TYPE_STRING:
223       case FieldDescriptor::TYPE_MESSAGE:
224       case FieldDescriptor::TYPE_BYTES:
225       case FieldDescriptor::TYPE_GROUP:
226         Abort("Unexpected FieldDescritor::Type.");
227     }
228     Abort("Unrecognized FieldDescriptor::Type.");
229     return "";
230   }
231 
FieldToProtoSchemaType(const FieldDescriptor * field)232   const char* FieldToProtoSchemaType(const FieldDescriptor* field) {
233     switch (field->type()) {
234       case FieldDescriptor::TYPE_BOOL:
235         return "kBool";
236       case FieldDescriptor::TYPE_INT32:
237         return "kInt32";
238       case FieldDescriptor::TYPE_INT64:
239         return "kInt64";
240       case FieldDescriptor::TYPE_UINT32:
241         return "kUint32";
242       case FieldDescriptor::TYPE_UINT64:
243         return "kUint64";
244       case FieldDescriptor::TYPE_SINT32:
245         return "kSint32";
246       case FieldDescriptor::TYPE_SINT64:
247         return "kSint64";
248       case FieldDescriptor::TYPE_FIXED32:
249         return "kFixed32";
250       case FieldDescriptor::TYPE_FIXED64:
251         return "kFixed64";
252       case FieldDescriptor::TYPE_SFIXED32:
253         return "kSfixed32";
254       case FieldDescriptor::TYPE_SFIXED64:
255         return "kSfixed64";
256       case FieldDescriptor::TYPE_FLOAT:
257         return "kFloat";
258       case FieldDescriptor::TYPE_DOUBLE:
259         return "kDouble";
260       case FieldDescriptor::TYPE_ENUM:
261         return "kEnum";
262       case FieldDescriptor::TYPE_STRING:
263         return "kString";
264       case FieldDescriptor::TYPE_MESSAGE:
265         return "kMessage";
266       case FieldDescriptor::TYPE_BYTES:
267         return "kBytes";
268 
269       case FieldDescriptor::TYPE_GROUP:
270         Abort("Groups not supported.");
271         return "";
272     }
273     Abort("Unrecognized FieldDescriptor::Type.");
274     return "";
275   }
276 
FieldToCppTypeName(const FieldDescriptor * field)277   std::string FieldToCppTypeName(const FieldDescriptor* field) {
278     switch (field->type()) {
279       case FieldDescriptor::TYPE_BOOL:
280         return "bool";
281       case FieldDescriptor::TYPE_INT32:
282         return "int32_t";
283       case FieldDescriptor::TYPE_INT64:
284         return "int64_t";
285       case FieldDescriptor::TYPE_UINT32:
286         return "uint32_t";
287       case FieldDescriptor::TYPE_UINT64:
288         return "uint64_t";
289       case FieldDescriptor::TYPE_SINT32:
290         return "int32_t";
291       case FieldDescriptor::TYPE_SINT64:
292         return "int64_t";
293       case FieldDescriptor::TYPE_FIXED32:
294         return "uint32_t";
295       case FieldDescriptor::TYPE_FIXED64:
296         return "uint64_t";
297       case FieldDescriptor::TYPE_SFIXED32:
298         return "int32_t";
299       case FieldDescriptor::TYPE_SFIXED64:
300         return "int64_t";
301       case FieldDescriptor::TYPE_FLOAT:
302         return "float";
303       case FieldDescriptor::TYPE_DOUBLE:
304         return "double";
305       case FieldDescriptor::TYPE_ENUM:
306         return GetCppClassName(field->enum_type(), true);
307       case FieldDescriptor::TYPE_STRING:
308       case FieldDescriptor::TYPE_BYTES:
309         return "std::string";
310       case FieldDescriptor::TYPE_MESSAGE:
311         return GetCppClassName(field->message_type());
312       case FieldDescriptor::TYPE_GROUP:
313         Abort("Groups not supported.");
314         return "";
315     }
316     Abort("Unrecognized FieldDescriptor::Type.");
317     return "";
318   }
319 
FieldToRepetitionType(const FieldDescriptor * field)320   const char* FieldToRepetitionType(const FieldDescriptor* field) {
321     if (!field->is_repeated())
322       return "kNotRepeated";
323     if (field->is_packed())
324       return "kRepeatedPacked";
325     return "kRepeatedNotPacked";
326   }
327 
CollectDescriptors()328   void CollectDescriptors() {
329     // Collect message descriptors in DFS order.
330     std::vector<const Descriptor*> stack;
331     stack.reserve(static_cast<size_t>(source_->message_type_count()));
332     for (int i = 0; i < source_->message_type_count(); ++i)
333       stack.push_back(source_->message_type(i));
334 
335     while (!stack.empty()) {
336       const Descriptor* message = stack.back();
337       stack.pop_back();
338 
339       if (message->extension_count() > 0) {
340         if (message->field_count() > 0 || message->nested_type_count() > 0 ||
341             message->enum_type_count() > 0) {
342           Abort("message with extend blocks shouldn't contain anything else");
343         }
344 
345         // Iterate over all fields in "extend" blocks.
346         for (int i = 0; i < message->extension_count(); ++i) {
347           const FieldDescriptor* extension = message->extension(i);
348 
349           // Protoc plugin API does not group fields in "extend" blocks.
350           // As the support for extensions in protozero is limited, the code
351           // assumes that extend blocks are located inside a wrapper message and
352           // name of this message is used to group them.
353           std::string extension_name = extension->extension_scope()->name();
354           extensions_[extension_name].push_back(extension);
355         }
356       } else {
357         messages_.push_back(message);
358         for (int i = 0; i < message->nested_type_count(); ++i) {
359           stack.push_back(message->nested_type(i));
360           // Emit a forward declaration of nested message types, as the outer
361           // class will refer to them when creating type aliases.
362           referenced_messages_.insert(message->nested_type(i));
363         }
364       }
365     }
366 
367     // Collect enums.
368     for (int i = 0; i < source_->enum_type_count(); ++i)
369       enums_.push_back(source_->enum_type(i));
370 
371     if (source_->extension_count() > 0)
372       Abort("top-level extension blocks are not supported");
373 
374     for (const Descriptor* message : messages_) {
375       for (int i = 0; i < message->enum_type_count(); ++i) {
376         enums_.push_back(message->enum_type(i));
377       }
378     }
379   }
380 
CollectDependencies()381   void CollectDependencies() {
382     // Public import basically means that callers only need to import this
383     // proto in order to use the stuff publicly imported by this proto.
384     for (int i = 0; i < source_->public_dependency_count(); ++i)
385       public_imports_.insert(source_->public_dependency(i));
386 
387     if (source_->weak_dependency_count() > 0)
388       Abort("Weak imports are not supported.");
389 
390     // Validations. Collect public imports (of collected imports) in DFS order.
391     // Visibilty for current proto:
392     // - all imports listed in current proto,
393     // - public imports of everything imported (recursive).
394     std::vector<const FileDescriptor*> stack;
395     for (int i = 0; i < source_->dependency_count(); ++i) {
396       const FileDescriptor* import = source_->dependency(i);
397       stack.push_back(import);
398       if (public_imports_.count(import) == 0) {
399         private_imports_.insert(import);
400       }
401     }
402 
403     while (!stack.empty()) {
404       const FileDescriptor* import = stack.back();
405       stack.pop_back();
406       // Having imports under different packages leads to unnecessary
407       // complexity with namespaces.
408       if (import->package() != package_)
409         Abort("Imported proto must be in the same package.");
410 
411       for (int i = 0; i < import->public_dependency_count(); ++i) {
412         stack.push_back(import->public_dependency(i));
413       }
414     }
415 
416     // Collect descriptors of messages and enums used in current proto.
417     // It will be used to generate necessary forward declarations and
418     // check that everything lays in the same namespace.
419     for (const Descriptor* message : messages_) {
420       for (int i = 0; i < message->field_count(); ++i) {
421         const FieldDescriptor* field = message->field(i);
422 
423         if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
424           if (public_imports_.count(field->message_type()->file()) == 0) {
425             // Avoid multiple forward declarations since
426             // public imports have been already included.
427             referenced_messages_.insert(field->message_type());
428           }
429         } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
430           if (public_imports_.count(field->enum_type()->file()) == 0) {
431             referenced_enums_.insert(field->enum_type());
432           }
433         }
434       }
435     }
436   }
437 
Preprocess()438   void Preprocess() {
439     // Package name maps to a series of namespaces.
440     package_ = source_->package();
441     namespaces_ = SplitString(package_, ".");
442     if (!wrapper_namespace_.empty())
443       namespaces_.push_back(wrapper_namespace_);
444 
445     full_namespace_prefix_ = "::";
446     for (const std::string& ns : namespaces_)
447       full_namespace_prefix_ += ns + "::";
448 
449     CollectDescriptors();
450     CollectDependencies();
451   }
452 
GetNamespaceNameForInnerEnum(const EnumDescriptor * enumeration)453   std::string GetNamespaceNameForInnerEnum(const EnumDescriptor* enumeration) {
454     return "perfetto_pbzero_enum_" +
455            GetCppClassName(enumeration->containing_type());
456   }
457 
458   // Print top header, namespaces and forward declarations.
GeneratePrologue()459   void GeneratePrologue() {
460     std::string greeting =
461         "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
462     std::string guard = package_ + "_" + source_->name() + "_H_";
463     guard = ToUpper(guard);
464     guard = StripChars(guard, ".-/\\", '_');
465 
466     stub_h_->Print(
467         "$greeting$\n"
468         "#ifndef $guard$\n"
469         "#define $guard$\n\n"
470         "#include <stddef.h>\n"
471         "#include <stdint.h>\n\n"
472         "#include \"perfetto/protozero/field_writer.h\"\n"
473         "#include \"perfetto/protozero/message.h\"\n"
474         "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
475         "#include \"perfetto/protozero/proto_decoder.h\"\n"
476         "#include \"perfetto/protozero/proto_utils.h\"\n",
477         "greeting", greeting, "guard", guard);
478 
479     // Print includes for public imports.
480     for (const FileDescriptor* dependency : public_imports_) {
481       // Dependency name could contain slashes but importing from upper-level
482       // directories is not possible anyway since build system processes each
483       // proto file individually. Hence proto lookup path is always equal to the
484       // directory where particular proto file is located and protoc does not
485       // allow reference to upper directory (aka ..) in import path.
486       //
487       // Laconically said:
488       // - source_->name() may never have slashes,
489       // - dependency->name() may have slashes but always refers to inner path.
490       stub_h_->Print("#include \"$name$.h\"\n", "name",
491                      ProtoStubName(dependency));
492     }
493     stub_h_->Print("\n");
494 
495     // Print namespaces.
496     for (const std::string& ns : namespaces_) {
497       stub_h_->Print("namespace $ns$ {\n", "ns", ns);
498     }
499     stub_h_->Print("\n");
500 
501     // Print forward declarations.
502     for (const Descriptor* message : referenced_messages_) {
503       stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
504     }
505     for (const EnumDescriptor* enumeration : referenced_enums_) {
506       if (enumeration->containing_type()) {
507         stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
508                        GetNamespaceNameForInnerEnum(enumeration));
509       }
510       stub_h_->Print("enum $class$ : int32_t;\n", "class", enumeration->name());
511 
512       if (enumeration->containing_type()) {
513         stub_h_->Print("}  // namespace $namespace_name$\n", "namespace_name",
514                        GetNamespaceNameForInnerEnum(enumeration));
515         stub_h_->Print("using $alias$ = $namespace_name$::$short_name$;\n",
516                        "alias", GetCppClassName(enumeration), "namespace_name",
517                        GetNamespaceNameForInnerEnum(enumeration), "short_name",
518                        enumeration->name());
519       }
520     }
521     stub_h_->Print("\n");
522   }
523 
GenerateEnumDescriptor(const EnumDescriptor * enumeration)524   void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
525     bool is_inner_enum = !!enumeration->containing_type();
526     if (is_inner_enum) {
527       stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
528                      GetNamespaceNameForInnerEnum(enumeration));
529     }
530 
531     stub_h_->Print("enum $class$ : int32_t {\n", "class", enumeration->name());
532     stub_h_->Indent();
533 
534     std::string min_name, max_name;
535     int min_val = std::numeric_limits<int>::max();
536     int max_val = -1;
537     for (int i = 0; i < enumeration->value_count(); ++i) {
538       const EnumValueDescriptor* value = enumeration->value(i);
539       const std::string value_name = value->name();
540       stub_h_->Print("$name$ = $number$,\n", "name", value_name, "number",
541                      std::to_string(value->number()));
542       if (value->number() < min_val) {
543         min_val = value->number();
544         min_name = value_name;
545       }
546       if (value->number() > max_val) {
547         max_val = value->number();
548         max_name = value_name;
549       }
550     }
551     stub_h_->Outdent();
552     stub_h_->Print("};\n");
553     if (is_inner_enum) {
554       const std::string namespace_name =
555           GetNamespaceNameForInnerEnum(enumeration);
556       stub_h_->Print("} // namespace $namespace_name$\n", "namespace_name",
557                      namespace_name);
558       stub_h_->Print(
559           "using $full_enum_name$ = $namespace_name$::$enum_name$;\n\n",
560           "full_enum_name", GetCppClassName(enumeration), "enum_name",
561           enumeration->name(), "namespace_name", namespace_name);
562     }
563     stub_h_->Print("\n");
564     stub_h_->Print("constexpr $class$ $class$_MIN = $class$::$min$;\n", "class",
565                    GetCppClassName(enumeration), "min", min_name);
566     stub_h_->Print("constexpr $class$ $class$_MAX = $class$::$max$;\n", "class",
567                    GetCppClassName(enumeration), "max", max_name);
568     stub_h_->Print("\n");
569 
570     GenerateEnumToStringConversion(enumeration);
571   }
572 
GenerateEnumToStringConversion(const EnumDescriptor * enumeration)573   void GenerateEnumToStringConversion(const EnumDescriptor* enumeration) {
574     std::string fullClassName =
575         full_namespace_prefix_ + GetCppClassName(enumeration);
576     const char* function_header_stub = R"(
577 PERFETTO_PROTOZERO_CONSTEXPR14_OR_INLINE
578 const char* $class_name$_Name($full_class$ value) {
579 )";
580     stub_h_->Print(function_header_stub, "full_class", fullClassName,
581                    "class_name", GetCppClassName(enumeration));
582     stub_h_->Indent();
583     stub_h_->Print("switch (value) {");
584     for (int index = 0; index < enumeration->value_count(); ++index) {
585       const EnumValueDescriptor* value = enumeration->value(index);
586       const char* switch_stub = R"(
587 case $full_class$::$value_name$:
588   return "$value_name$";
589 )";
590       stub_h_->Print(switch_stub, "full_class", fullClassName, "value_name",
591                      value->name());
592     }
593     stub_h_->Print("}\n");
594     stub_h_->Print(R"(return "PBZERO_UNKNOWN_ENUM_VALUE";)");
595     stub_h_->Print("\n");
596     stub_h_->Outdent();
597     stub_h_->Print("}\n\n");
598   }
599 
600   // Packed repeated fields are encoded as a length-delimited field on the wire,
601   // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)602   void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
603     std::map<std::string, std::string> setter;
604     setter["name"] = field->lowercase_name();
605     setter["field_metadata"] = GetFieldMetadataTypeName(field);
606     setter["action"] = "set";
607     setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
608     stub_h_->Print(
609         setter,
610         "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
611         "  AppendBytes($field_metadata$::kFieldId, packed_buffer.data(),\n"
612         "              packed_buffer.size());\n"
613         "}\n");
614   }
615 
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)616   void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
617     std::map<std::string, std::string> setter;
618     setter["id"] = std::to_string(field->number());
619     setter["name"] = field->lowercase_name();
620     setter["field_metadata"] = GetFieldMetadataTypeName(field);
621     setter["action"] = field->is_repeated() ? "add" : "set";
622     setter["cpp_type"] = FieldToCppTypeName(field);
623     setter["proto_field_type"] = FieldToProtoSchemaType(field);
624 
625     const char* code_stub =
626         "void $action$_$name$($cpp_type$ value) {\n"
627         "  static constexpr uint32_t field_id = $field_metadata$::kFieldId;\n"
628         "  // Call the appropriate protozero::Message::Append(field_id, ...)\n"
629         "  // method based on the type of the field.\n"
630         "  ::protozero::internal::FieldWriter<\n"
631         "    ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$>\n"
632         "      ::Append(*this, field_id, value);\n"
633         "}\n";
634 
635     if (field->type() == FieldDescriptor::TYPE_STRING) {
636       // Strings and bytes should have an additional accessor which specifies
637       // the length explicitly.
638       const char* additional_method =
639           "void $action$_$name$(const char* data, size_t size) {\n"
640           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
641           "}\n"
642           "void $action$_$name$(::protozero::ConstChars chars) {\n"
643           "  AppendBytes($field_metadata$::kFieldId, chars.data, chars.size);\n"
644           "}\n";
645       stub_h_->Print(setter, additional_method);
646     } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
647       const char* additional_method =
648           "void $action$_$name$(const uint8_t* data, size_t size) {\n"
649           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
650           "}\n"
651           "void $action$_$name$(::protozero::ConstBytes bytes) {\n"
652           "  AppendBytes($field_metadata$::kFieldId, bytes.data, bytes.size);\n"
653           "}\n";
654       stub_h_->Print(setter, additional_method);
655     } else if (field->type() == FieldDescriptor::TYPE_GROUP ||
656                field->type() == FieldDescriptor::TYPE_MESSAGE) {
657       Abort("Unsupported field type.");
658       return;
659     }
660 
661     stub_h_->Print(setter, code_stub);
662   }
663 
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)664   void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
665     std::string action = field->is_repeated() ? "add" : "set";
666     std::string inner_class = GetCppClassName(field->message_type());
667     stub_h_->Print(
668         "template <typename T = $inner_class$> T* $action$_$name$() {\n"
669         "  return BeginNestedMessage<T>($id$);\n"
670         "}\n\n",
671         "id", std::to_string(field->number()), "name", field->lowercase_name(),
672         "action", action, "inner_class", inner_class);
673     if (field->options().lazy()) {
674       stub_h_->Print(
675           "void $action$_$name$_raw(const std::string& raw) {\n"
676           "  return AppendBytes($id$, raw.data(), raw.size());\n"
677           "}\n\n",
678           "id", std::to_string(field->number()), "name",
679           field->lowercase_name(), "action", action, "inner_class",
680           inner_class);
681     }
682   }
683 
GenerateDecoder(const Descriptor * message)684   void GenerateDecoder(const Descriptor* message) {
685     int max_field_id = 0;
686     bool has_nonpacked_repeated_fields = false;
687     for (int i = 0; i < message->field_count(); ++i) {
688       const FieldDescriptor* field = message->field(i);
689       if (field->number() > kMaxDecoderFieldId)
690         continue;
691       max_field_id = std::max(max_field_id, field->number());
692       if (field->is_repeated() && !field->is_packed())
693         has_nonpacked_repeated_fields = true;
694     }
695 
696     std::string class_name = GetCppClassName(message) + "_Decoder";
697     stub_h_->Print(
698         "class $name$ : public "
699         "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
700         "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
701         "name", class_name, "max", std::to_string(max_field_id), "rep",
702         has_nonpacked_repeated_fields ? "true" : "false");
703     stub_h_->Print(" public:\n");
704     stub_h_->Indent();
705     stub_h_->Print(
706         "$name$(const uint8_t* data, size_t len) "
707         ": TypedProtoDecoder(data, len) {}\n",
708         "name", class_name);
709     stub_h_->Print(
710         "explicit $name$(const std::string& raw) : "
711         "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
712         "raw.size()) {}\n",
713         "name", class_name);
714     stub_h_->Print(
715         "explicit $name$(const ::protozero::ConstBytes& raw) : "
716         "TypedProtoDecoder(raw.data, raw.size) {}\n",
717         "name", class_name);
718 
719     for (int i = 0; i < message->field_count(); ++i) {
720       const FieldDescriptor* field = message->field(i);
721       if (field->number() > max_field_id) {
722         stub_h_->Print("// field $name$ omitted because its id is too high\n",
723                        "name", field->name());
724         continue;
725       }
726       std::string getter;
727       std::string cpp_type;
728       switch (field->type()) {
729         case FieldDescriptor::TYPE_BOOL:
730           getter = "as_bool";
731           cpp_type = "bool";
732           break;
733         case FieldDescriptor::TYPE_SFIXED32:
734         case FieldDescriptor::TYPE_SINT32:
735         case FieldDescriptor::TYPE_INT32:
736           getter = "as_int32";
737           cpp_type = "int32_t";
738           break;
739         case FieldDescriptor::TYPE_SFIXED64:
740         case FieldDescriptor::TYPE_SINT64:
741         case FieldDescriptor::TYPE_INT64:
742           getter = "as_int64";
743           cpp_type = "int64_t";
744           break;
745         case FieldDescriptor::TYPE_FIXED32:
746         case FieldDescriptor::TYPE_UINT32:
747           getter = "as_uint32";
748           cpp_type = "uint32_t";
749           break;
750         case FieldDescriptor::TYPE_FIXED64:
751         case FieldDescriptor::TYPE_UINT64:
752           getter = "as_uint64";
753           cpp_type = "uint64_t";
754           break;
755         case FieldDescriptor::TYPE_FLOAT:
756           getter = "as_float";
757           cpp_type = "float";
758           break;
759         case FieldDescriptor::TYPE_DOUBLE:
760           getter = "as_double";
761           cpp_type = "double";
762           break;
763         case FieldDescriptor::TYPE_ENUM:
764           getter = "as_int32";
765           cpp_type = "int32_t";
766           break;
767         case FieldDescriptor::TYPE_STRING:
768           getter = "as_string";
769           cpp_type = "::protozero::ConstChars";
770           break;
771         case FieldDescriptor::TYPE_MESSAGE:
772         case FieldDescriptor::TYPE_BYTES:
773           getter = "as_bytes";
774           cpp_type = "::protozero::ConstBytes";
775           break;
776         case FieldDescriptor::TYPE_GROUP:
777           continue;
778       }
779 
780       stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
781                      "name", field->lowercase_name(), "id",
782                      std::to_string(field->number()));
783 
784       if (field->is_packed()) {
785         const char* protozero_wire_type =
786             FieldTypeToProtozeroWireType(field->type());
787         stub_h_->Print(
788             "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
789             "$name$(bool* parse_error_ptr) const { return "
790             "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
791             "parse_error_ptr); }\n",
792             "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
793             field->lowercase_name(), "id", std::to_string(field->number()));
794       } else if (field->is_repeated()) {
795         stub_h_->Print(
796             "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
797             "return "
798             "GetRepeated<$cpp_type$>($id$); }\n",
799             "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
800             std::to_string(field->number()));
801       } else {
802         stub_h_->Print(
803             "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
804             "name", field->lowercase_name(), "id",
805             std::to_string(field->number()), "cpp_type", cpp_type, "getter",
806             getter);
807       }
808     }
809     stub_h_->Outdent();
810     stub_h_->Print("};\n\n");
811   }
812 
GenerateConstantsForMessageFields(const Descriptor * message)813   void GenerateConstantsForMessageFields(const Descriptor* message) {
814     const bool has_fields = (message->field_count() > 0);
815 
816     // Field number constants.
817     if (has_fields) {
818       stub_h_->Print("enum : int32_t {\n");
819       stub_h_->Indent();
820 
821       for (int i = 0; i < message->field_count(); ++i) {
822         const FieldDescriptor* field = message->field(i);
823         stub_h_->Print("$name$ = $id$,\n", "name",
824                        GetFieldNumberConstant(field), "id",
825                        std::to_string(field->number()));
826       }
827       stub_h_->Outdent();
828       stub_h_->Print("};\n");
829     }
830   }
831 
GenerateMessageDescriptor(const Descriptor * message)832   void GenerateMessageDescriptor(const Descriptor* message) {
833     GenerateDecoder(message);
834 
835     stub_h_->Print(
836         "class $name$ : public ::protozero::Message {\n"
837         " public:\n",
838         "name", GetCppClassName(message));
839     stub_h_->Indent();
840 
841     stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
842                    GetCppClassName(message));
843 
844     GenerateConstantsForMessageFields(message);
845 
846     stub_h_->Print(
847         "static constexpr const char* GetName() { return \".$name$\"; }\n\n",
848         "name", message->full_name());
849 
850     // Using statements for nested messages.
851     for (int i = 0; i < message->nested_type_count(); ++i) {
852       const Descriptor* nested_message = message->nested_type(i);
853       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
854                      nested_message->name(), "global_name",
855                      GetCppClassName(nested_message, true));
856     }
857 
858     // Using statements for nested enums.
859     for (int i = 0; i < message->enum_type_count(); ++i) {
860       const EnumDescriptor* nested_enum = message->enum_type(i);
861       const char* stub = R"(
862 using $local_name$ = $global_name$;
863 static inline const char* $local_name$_Name($local_name$ value) {
864   return $global_name$_Name(value);
865 }
866 )";
867       stub_h_->Print(stub, "local_name", nested_enum->name(), "global_name",
868                      GetCppClassName(nested_enum, true));
869     }
870 
871     // Values of nested enums.
872     for (int i = 0; i < message->enum_type_count(); ++i) {
873       const EnumDescriptor* nested_enum = message->enum_type(i);
874 
875       for (int j = 0; j < nested_enum->value_count(); ++j) {
876         const EnumValueDescriptor* value = nested_enum->value(j);
877         stub_h_->Print("static const $class$ $name$ = $class$::$name$;\n",
878                        "class", nested_enum->name(), "name", value->name());
879       }
880     }
881 
882     // Field descriptors.
883     for (int i = 0; i < message->field_count(); ++i) {
884       GenerateFieldDescriptor(GetCppClassName(message), message->field(i));
885     }
886 
887     stub_h_->Outdent();
888     stub_h_->Print("};\n\n");
889   }
890 
GetFieldMetadataTypeName(const FieldDescriptor * field)891   std::string GetFieldMetadataTypeName(const FieldDescriptor* field) {
892     std::string name = field->camelcase_name();
893     if (isalpha(name[0]))
894       name[0] = static_cast<char>(toupper(name[0]));
895     return "FieldMetadata_" + name;
896   }
897 
GetFieldMetadataVariableName(const FieldDescriptor * field)898   std::string GetFieldMetadataVariableName(const FieldDescriptor* field) {
899     std::string name = field->camelcase_name();
900     if (isalpha(name[0]))
901       name[0] = static_cast<char>(toupper(name[0]));
902     return "k" + name;
903   }
904 
GenerateFieldMetadata(const std::string & message_cpp_type,const FieldDescriptor * field)905   void GenerateFieldMetadata(const std::string& message_cpp_type,
906                              const FieldDescriptor* field) {
907     const char* code_stub = R"(
908 using $field_metadata_type$ =
909   ::protozero::proto_utils::FieldMetadata<
910     $field_id$,
911     ::protozero::proto_utils::RepetitionType::$repetition_type$,
912     ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$,
913     $cpp_type$,
914     $message_cpp_type$>;
915 
916 static constexpr $field_metadata_type$ $field_metadata_var${};
917 )";
918 
919     stub_h_->Print(code_stub, "field_id", std::to_string(field->number()),
920                    "repetition_type", FieldToRepetitionType(field),
921                    "proto_field_type", FieldToProtoSchemaType(field),
922                    "cpp_type", FieldToCppTypeName(field), "message_cpp_type",
923                    message_cpp_type, "field_metadata_type",
924                    GetFieldMetadataTypeName(field), "field_metadata_var",
925                    GetFieldMetadataVariableName(field));
926   }
927 
GenerateFieldDescriptor(const std::string & message_cpp_type,const FieldDescriptor * field)928   void GenerateFieldDescriptor(const std::string& message_cpp_type,
929                                const FieldDescriptor* field) {
930     GenerateFieldMetadata(message_cpp_type, field);
931     if (field->is_packed()) {
932       GeneratePackedRepeatedFieldDescriptor(field);
933     } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
934       GenerateSimpleFieldDescriptor(field);
935     } else {
936       GenerateNestedMessageFieldDescriptor(field);
937     }
938   }
939 
940   // Generate extension class for a group of FieldDescriptor instances
941   // representing one "extend" block in proto definition. For example:
942   //
943   //   message SpecificExtension {
944   //     extend GeneralThing {
945   //       optional Fizz fizz = 101;
946   //       optional Buzz buzz = 102;
947   //     }
948   //   }
949   //
950   // This is going to be passed as a vector of two elements, "fizz" and
951   // "buzz". Wrapping message is used to provide a name for generated
952   // extension class.
953   //
954   // In the example above, generated code is going to look like:
955   //
956   //   class SpecificExtension : public GeneralThing {
957   //     Fizz* set_fizz();
958   //     Buzz* set_buzz();
959   //   }
GenerateExtension(const std::string & extension_name,const std::vector<const FieldDescriptor * > & descriptors)960   void GenerateExtension(
961       const std::string& extension_name,
962       const std::vector<const FieldDescriptor*>& descriptors) {
963     // Use an arbitrary descriptor in order to get generic information not
964     // specific to any of them.
965     const FieldDescriptor* descriptor = descriptors[0];
966     const Descriptor* base_message = descriptor->containing_type();
967 
968     // TODO(ddrone): ensure that this code works when containing_type located in
969     // other file or namespace.
970     stub_h_->Print("class $name$ : public $extendee$ {\n", "name",
971                    extension_name, "extendee",
972                    GetCppClassName(base_message, /*full=*/true));
973     stub_h_->Print(" public:\n");
974     stub_h_->Indent();
975     for (const FieldDescriptor* field : descriptors) {
976       if (field->containing_type() != base_message) {
977         Abort("one wrapper should extend only one message");
978         return;
979       }
980       GenerateFieldDescriptor(extension_name, field);
981     }
982     stub_h_->Outdent();
983     stub_h_->Print("};\n");
984   }
985 
GenerateEpilogue()986   void GenerateEpilogue() {
987     for (unsigned i = 0; i < namespaces_.size(); ++i) {
988       stub_h_->Print("} // Namespace.\n");
989     }
990     stub_h_->Print("#endif  // Include guard.\n");
991   }
992 
993   const FileDescriptor* const source_;
994   Printer* const stub_h_;
995   std::string error_;
996 
997   std::string package_;
998   std::string wrapper_namespace_;
999   std::vector<std::string> namespaces_;
1000   std::string full_namespace_prefix_;
1001   std::vector<const Descriptor*> messages_;
1002   std::vector<const EnumDescriptor*> enums_;
1003   std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;
1004 
1005   // The custom *Comp comparators are to ensure determinism of the generator.
1006   std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
1007   std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
1008   std::set<const Descriptor*, DescriptorComp> referenced_messages_;
1009   std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
1010 };
1011 
1012 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
1013  public:
1014   explicit ProtoZeroGenerator();
1015   ~ProtoZeroGenerator() override;
1016 
1017   // CodeGenerator implementation
1018   bool Generate(const google::protobuf::FileDescriptor* file,
1019                 const std::string& options,
1020                 GeneratorContext* context,
1021                 std::string* error) const override;
1022 };
1023 
ProtoZeroGenerator()1024 ProtoZeroGenerator::ProtoZeroGenerator() {}
1025 
~ProtoZeroGenerator()1026 ProtoZeroGenerator::~ProtoZeroGenerator() {}
1027 
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const1028 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
1029                                   const std::string& options,
1030                                   GeneratorContext* context,
1031                                   std::string* error) const {
1032   const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
1033       context->Open(ProtoStubName(file) + ".h"));
1034   const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
1035       context->Open(ProtoStubName(file) + ".cc"));
1036 
1037   // Variables are delimited by $.
1038   Printer stub_h_printer(stub_h_file_stream.get(), '$');
1039   GeneratorJob job(file, &stub_h_printer);
1040 
1041   Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
1042   stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
1043 
1044   // Parse additional options.
1045   for (const std::string& option : SplitString(options, ",")) {
1046     std::vector<std::string> option_pair = SplitString(option, "=");
1047     job.SetOption(option_pair[0], option_pair[1]);
1048   }
1049 
1050   if (!job.GenerateStubs()) {
1051     *error = job.GetFirstError();
1052     return false;
1053   }
1054   return true;
1055 }
1056 
1057 }  // namespace
1058 }  // namespace protozero
1059 
main(int argc,char * argv[])1060 int main(int argc, char* argv[]) {
1061   ::protozero::ProtoZeroGenerator generator;
1062   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
1063 }
1064