• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <cstdint>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>

#include "perfetto/ext/base/string_utils.h"
31 
32 namespace protozero {
33 namespace {
34 
35 using google::protobuf::Descriptor;
36 using google::protobuf::EnumDescriptor;
37 using google::protobuf::EnumValueDescriptor;
38 using google::protobuf::FieldDescriptor;
39 using google::protobuf::FileDescriptor;
40 using google::protobuf::compiler::GeneratorContext;
41 using google::protobuf::io::Printer;
42 using google::protobuf::io::ZeroCopyOutputStream;
43 using perfetto::base::SplitString;
44 using perfetto::base::StripChars;
45 using perfetto::base::StripPrefix;
46 using perfetto::base::StripSuffix;
47 using perfetto::base::ToUpper;
48 using perfetto::base::Uppercase;
49 
50 // Keep this value in sync with ProtoDecoder::kMaxDecoderFieldId. If they go out
51 // of sync pbzero.h files will stop compiling, hitting the at() static_assert.
52 // Not worth an extra dependency.
53 constexpr int kMaxDecoderFieldId = 999;
54 
// Internal invariant check: returns normally when |condition| holds and
// otherwise stops the generator immediately via __builtin_trap().
void Assert(bool condition) {
  if (condition)
    return;
  __builtin_trap();
}
59 
60 struct FileDescriptorComp {
operator ()protozero::__anon42b1c0d80111::FileDescriptorComp61   bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
62     int comp = lhs->name().compare(rhs->name());
63     Assert(comp != 0 || lhs == rhs);
64     return comp < 0;
65   }
66 };
67 
68 struct DescriptorComp {
operator ()protozero::__anon42b1c0d80111::DescriptorComp69   bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
70     int comp = lhs->full_name().compare(rhs->full_name());
71     Assert(comp != 0 || lhs == rhs);
72     return comp < 0;
73   }
74 };
75 
76 struct EnumDescriptorComp {
operator ()protozero::__anon42b1c0d80111::EnumDescriptorComp77   bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
78     int comp = lhs->full_name().compare(rhs->full_name());
79     Assert(comp != 0 || lhs == rhs);
80     return comp < 0;
81   }
82 };
83 
ProtoStubName(const FileDescriptor * proto)84 inline std::string ProtoStubName(const FileDescriptor* proto) {
85   return StripSuffix(proto->name(), ".proto") + ".pbzero";
86 }
87 
88 class GeneratorJob {
89  public:
GeneratorJob(const FileDescriptor * file,Printer * stub_h_printer)90   GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
91       : source_(file), stub_h_(stub_h_printer) {}
92 
GenerateStubs()93   bool GenerateStubs() {
94     Preprocess();
95     GeneratePrologue();
96     for (const EnumDescriptor* enumeration : enums_)
97       GenerateEnumDescriptor(enumeration);
98     for (const Descriptor* message : messages_)
99       GenerateMessageDescriptor(message);
100     GenerateEpilogue();
101     return error_.empty();
102   }
103 
SetOption(const std::string & name,const std::string & value)104   void SetOption(const std::string& name, const std::string& value) {
105     if (name == "wrapper_namespace") {
106       wrapper_namespace_ = value;
107     } else {
108       Abort(std::string() + "Unknown plugin option '" + name + "'.");
109     }
110   }
111 
112   // If generator fails to produce stubs for a particular proto definitions
113   // it finishes with undefined output and writes the first error occured.
GetFirstError() const114   const std::string& GetFirstError() const { return error_; }
115 
116  private:
117   // Only the first error will be recorded.
Abort(const std::string & reason)118   void Abort(const std::string& reason) {
119     if (error_.empty())
120       error_ = reason;
121   }
122 
123   // Get full name (including outer descriptors) of proto descriptor.
124   template <class T>
GetDescriptorName(const T * descriptor)125   inline std::string GetDescriptorName(const T* descriptor) {
126     if (!package_.empty()) {
127       return StripPrefix(descriptor->full_name(), package_ + ".");
128     } else {
129       return descriptor->full_name();
130     }
131   }
132 
133   // Get C++ class name corresponding to proto descriptor.
134   // Nested names are splitted by underscores. Underscores in type names aren't
135   // prohibited but not recommended in order to avoid name collisions.
136   template <class T>
GetCppClassName(const T * descriptor,bool full=false)137   inline std::string GetCppClassName(const T* descriptor, bool full = false) {
138     std::string name = StripChars(GetDescriptorName(descriptor), ".", '_');
139     if (full)
140       name = full_namespace_prefix_ + name;
141     return name;
142   }
143 
GetFieldNumberConstant(const FieldDescriptor * field)144   inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
145     std::string name = field->camelcase_name();
146     if (!name.empty()) {
147       name.at(0) = Uppercase(name.at(0));
148       name = "k" + name + "FieldNumber";
149     } else {
150       // Protoc allows fields like 'bool _ = 1'.
151       Abort("Empty field name in camel case notation.");
152     }
153     return name;
154   }
155 
156   // Small enums can be written faster without involving VarInt encoder.
IsTinyEnumField(const FieldDescriptor * field)157   inline bool IsTinyEnumField(const FieldDescriptor* field) {
158     if (field->type() != FieldDescriptor::TYPE_ENUM)
159       return false;
160     const EnumDescriptor* enumeration = field->enum_type();
161 
162     for (int i = 0; i < enumeration->value_count(); ++i) {
163       int32_t value = enumeration->value(i)->number();
164       if (value < 0 || value > 0x7F)
165         return false;
166     }
167     return true;
168   }
169 
170   // Note: intentionally avoiding depending on protozero sources, as well as
171   // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)172   const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
173     switch (proto_type) {
174       case FieldDescriptor::TYPE_INT64:
175       case FieldDescriptor::TYPE_UINT64:
176       case FieldDescriptor::TYPE_INT32:
177       case FieldDescriptor::TYPE_BOOL:
178       case FieldDescriptor::TYPE_UINT32:
179       case FieldDescriptor::TYPE_ENUM:
180       case FieldDescriptor::TYPE_SINT32:
181       case FieldDescriptor::TYPE_SINT64:
182         return "::protozero::proto_utils::ProtoWireType::kVarInt";
183 
184       case FieldDescriptor::TYPE_FIXED32:
185       case FieldDescriptor::TYPE_SFIXED32:
186       case FieldDescriptor::TYPE_FLOAT:
187         return "::protozero::proto_utils::ProtoWireType::kFixed32";
188 
189       case FieldDescriptor::TYPE_FIXED64:
190       case FieldDescriptor::TYPE_SFIXED64:
191       case FieldDescriptor::TYPE_DOUBLE:
192         return "::protozero::proto_utils::ProtoWireType::kFixed64";
193 
194       case FieldDescriptor::TYPE_STRING:
195       case FieldDescriptor::TYPE_MESSAGE:
196       case FieldDescriptor::TYPE_BYTES:
197         return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
198 
199       case FieldDescriptor::TYPE_GROUP:
200         Abort("Groups not supported.");
201     }
202     Abort("Unrecognized FieldDescriptor::Type.");
203     return "";
204   }
205 
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)206   const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
207     switch (proto_type) {
208       case FieldDescriptor::TYPE_INT64:
209       case FieldDescriptor::TYPE_UINT64:
210       case FieldDescriptor::TYPE_INT32:
211       case FieldDescriptor::TYPE_BOOL:
212       case FieldDescriptor::TYPE_UINT32:
213       case FieldDescriptor::TYPE_ENUM:
214       case FieldDescriptor::TYPE_SINT32:
215       case FieldDescriptor::TYPE_SINT64:
216         return "::protozero::PackedVarInt";
217 
218       case FieldDescriptor::TYPE_FIXED32:
219         return "::protozero::PackedFixedSizeInt<uint32_t>";
220       case FieldDescriptor::TYPE_SFIXED32:
221         return "::protozero::PackedFixedSizeInt<int32_t>";
222       case FieldDescriptor::TYPE_FLOAT:
223         return "::protozero::PackedFixedSizeInt<float>";
224 
225       case FieldDescriptor::TYPE_FIXED64:
226         return "::protozero::PackedFixedSizeInt<uint64_t>";
227       case FieldDescriptor::TYPE_SFIXED64:
228         return "::protozero::PackedFixedSizeInt<int64_t>";
229       case FieldDescriptor::TYPE_DOUBLE:
230         return "::protozero::PackedFixedSizeInt<double>";
231 
232       case FieldDescriptor::TYPE_STRING:
233       case FieldDescriptor::TYPE_MESSAGE:
234       case FieldDescriptor::TYPE_BYTES:
235       case FieldDescriptor::TYPE_GROUP:
236         Abort("Unexpected FieldDescritor::Type.");
237     }
238     Abort("Unrecognized FieldDescriptor::Type.");
239     return "";
240   }
241 
CollectDescriptors()242   void CollectDescriptors() {
243     // Collect message descriptors in DFS order.
244     std::vector<const Descriptor*> stack;
245     for (int i = 0; i < source_->message_type_count(); ++i)
246       stack.push_back(source_->message_type(i));
247 
248     while (!stack.empty()) {
249       const Descriptor* message = stack.back();
250       stack.pop_back();
251       messages_.push_back(message);
252       for (int i = 0; i < message->nested_type_count(); ++i) {
253         stack.push_back(message->nested_type(i));
254       }
255     }
256 
257     // Collect enums.
258     for (int i = 0; i < source_->enum_type_count(); ++i)
259       enums_.push_back(source_->enum_type(i));
260 
261     for (const Descriptor* message : messages_) {
262       for (int i = 0; i < message->enum_type_count(); ++i) {
263         enums_.push_back(message->enum_type(i));
264       }
265     }
266   }
267 
CollectDependencies()268   void CollectDependencies() {
269     // Public import basically means that callers only need to import this
270     // proto in order to use the stuff publicly imported by this proto.
271     for (int i = 0; i < source_->public_dependency_count(); ++i)
272       public_imports_.insert(source_->public_dependency(i));
273 
274     if (source_->weak_dependency_count() > 0)
275       Abort("Weak imports are not supported.");
276 
277     // Sanity check. Collect public imports (of collected imports) in DFS order.
278     // Visibilty for current proto:
279     // - all imports listed in current proto,
280     // - public imports of everything imported (recursive).
281     std::vector<const FileDescriptor*> stack;
282     for (int i = 0; i < source_->dependency_count(); ++i) {
283       const FileDescriptor* import = source_->dependency(i);
284       stack.push_back(import);
285       if (public_imports_.count(import) == 0) {
286         private_imports_.insert(import);
287       }
288     }
289 
290     while (!stack.empty()) {
291       const FileDescriptor* import = stack.back();
292       stack.pop_back();
293       // Having imports under different packages leads to unnecessary
294       // complexity with namespaces.
295       if (import->package() != package_)
296         Abort("Imported proto must be in the same package.");
297 
298       for (int i = 0; i < import->public_dependency_count(); ++i) {
299         stack.push_back(import->public_dependency(i));
300       }
301     }
302 
303     // Collect descriptors of messages and enums used in current proto.
304     // It will be used to generate necessary forward declarations and performed
305     // sanity check guarantees that everything lays in the same namespace.
306     for (const Descriptor* message : messages_) {
307       for (int i = 0; i < message->field_count(); ++i) {
308         const FieldDescriptor* field = message->field(i);
309 
310         if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
311           if (public_imports_.count(field->message_type()->file()) == 0) {
312             // Avoid multiple forward declarations since
313             // public imports have been already included.
314             referenced_messages_.insert(field->message_type());
315           }
316         } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
317           if (public_imports_.count(field->enum_type()->file()) == 0) {
318             referenced_enums_.insert(field->enum_type());
319           }
320         }
321       }
322     }
323   }
324 
Preprocess()325   void Preprocess() {
326     // Package name maps to a series of namespaces.
327     package_ = source_->package();
328     namespaces_ = SplitString(package_, ".");
329     if (!wrapper_namespace_.empty())
330       namespaces_.push_back(wrapper_namespace_);
331 
332     full_namespace_prefix_ = "::";
333     for (const std::string& ns : namespaces_)
334       full_namespace_prefix_ += ns + "::";
335 
336     CollectDescriptors();
337     CollectDependencies();
338   }
339 
340   // Print top header, namespaces and forward declarations.
GeneratePrologue()341   void GeneratePrologue() {
342     std::string greeting =
343         "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
344     std::string guard = package_ + "_" + source_->name() + "_H_";
345     guard = ToUpper(guard);
346     guard = StripChars(guard, ".-/\\", '_');
347 
348     stub_h_->Print(
349         "$greeting$\n"
350         "#ifndef $guard$\n"
351         "#define $guard$\n\n"
352         "#include <stddef.h>\n"
353         "#include <stdint.h>\n\n"
354         "#include \"perfetto/protozero/message.h\"\n"
355         "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
356         "#include \"perfetto/protozero/proto_decoder.h\"\n"
357         "#include \"perfetto/protozero/proto_utils.h\"\n",
358         "greeting", greeting, "guard", guard);
359 
360     // Print includes for public imports.
361     for (const FileDescriptor* dependency : public_imports_) {
362       // Dependency name could contain slashes but importing from upper-level
363       // directories is not possible anyway since build system processes each
364       // proto file individually. Hence proto lookup path is always equal to the
365       // directory where particular proto file is located and protoc does not
366       // allow reference to upper directory (aka ..) in import path.
367       //
368       // Laconically said:
369       // - source_->name() may never have slashes,
370       // - dependency->name() may have slashes but always refers to inner path.
371       stub_h_->Print("#include \"$name$.h\"\n", "name",
372                      ProtoStubName(dependency));
373     }
374     stub_h_->Print("\n");
375 
376     // Print namespaces.
377     for (const std::string& ns : namespaces_) {
378       stub_h_->Print("namespace $ns$ {\n", "ns", ns);
379     }
380     stub_h_->Print("\n");
381 
382     // Print forward declarations.
383     for (const Descriptor* message : referenced_messages_) {
384       stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
385     }
386     for (const EnumDescriptor* enumeration : referenced_enums_) {
387       stub_h_->Print("enum $class$ : int32_t;\n", "class",
388                      GetCppClassName(enumeration));
389     }
390     stub_h_->Print("\n");
391   }
392 
GenerateEnumDescriptor(const EnumDescriptor * enumeration)393   void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
394     stub_h_->Print("enum $class$ : int32_t {\n", "class",
395                    GetCppClassName(enumeration));
396     stub_h_->Indent();
397 
398     std::string value_name_prefix;
399     if (enumeration->containing_type() != nullptr)
400       value_name_prefix = GetCppClassName(enumeration) + "_";
401 
402     std::string min_name, max_name;
403     int min_val = std::numeric_limits<int>::max();
404     int max_val = -1;
405     for (int i = 0; i < enumeration->value_count(); ++i) {
406       const EnumValueDescriptor* value = enumeration->value(i);
407       stub_h_->Print("$name$ = $number$,\n", "name",
408                      value_name_prefix + value->name(), "number",
409                      std::to_string(value->number()));
410       if (value->number() < min_val) {
411         min_val = value->number();
412         min_name = value_name_prefix + value->name();
413       }
414       if (value->number() > max_val) {
415         max_val = value->number();
416         max_name = value_name_prefix + value->name();
417       }
418     }
419     stub_h_->Outdent();
420     stub_h_->Print("};\n\n");
421     stub_h_->Print("const $class$ $class$_MIN = $min$;\n", "class",
422                    GetCppClassName(enumeration), "min", min_name);
423     stub_h_->Print("const $class$ $class$_MAX = $max$;\n", "class",
424                    GetCppClassName(enumeration), "max", max_name);
425     stub_h_->Print("\n");
426   }
427 
428   // Packed repeated fields are encoded as a length-delimited field on the wire,
429   // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)430   void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
431     std::map<std::string, std::string> setter;
432     setter["id"] = std::to_string(field->number());
433     setter["name"] = field->lowercase_name();
434     setter["action"] = "set";
435     setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
436     stub_h_->Print(
437         setter,
438         "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
439         "  AppendBytes($id$, packed_buffer.data(), packed_buffer.size());\n"
440         "}\n");
441   }
442 
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)443   void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
444     std::map<std::string, std::string> setter;
445     setter["id"] = std::to_string(field->number());
446     setter["name"] = field->lowercase_name();
447     setter["action"] = field->is_repeated() ? "add" : "set";
448 
449     std::string appender;
450     std::string cpp_type;
451     const char* code_stub =
452         "void $action$_$name$($cpp_type$ value) {\n"
453         "  $appender$($id$, value);\n"
454         "}\n";
455 
456     switch (field->type()) {
457       case FieldDescriptor::TYPE_BOOL: {
458         appender = "AppendTinyVarInt";
459         cpp_type = "bool";
460         break;
461       }
462       case FieldDescriptor::TYPE_INT32: {
463         appender = "AppendVarInt";
464         cpp_type = "int32_t";
465         break;
466       }
467       case FieldDescriptor::TYPE_INT64: {
468         appender = "AppendVarInt";
469         cpp_type = "int64_t";
470         break;
471       }
472       case FieldDescriptor::TYPE_UINT32: {
473         appender = "AppendVarInt";
474         cpp_type = "uint32_t";
475         break;
476       }
477       case FieldDescriptor::TYPE_UINT64: {
478         appender = "AppendVarInt";
479         cpp_type = "uint64_t";
480         break;
481       }
482       case FieldDescriptor::TYPE_SINT32: {
483         appender = "AppendSignedVarInt";
484         cpp_type = "int32_t";
485         break;
486       }
487       case FieldDescriptor::TYPE_SINT64: {
488         appender = "AppendSignedVarInt";
489         cpp_type = "int64_t";
490         break;
491       }
492       case FieldDescriptor::TYPE_FIXED32: {
493         appender = "AppendFixed";
494         cpp_type = "uint32_t";
495         break;
496       }
497       case FieldDescriptor::TYPE_FIXED64: {
498         appender = "AppendFixed";
499         cpp_type = "uint64_t";
500         break;
501       }
502       case FieldDescriptor::TYPE_SFIXED32: {
503         appender = "AppendFixed";
504         cpp_type = "int32_t";
505         break;
506       }
507       case FieldDescriptor::TYPE_SFIXED64: {
508         appender = "AppendFixed";
509         cpp_type = "int64_t";
510         break;
511       }
512       case FieldDescriptor::TYPE_FLOAT: {
513         appender = "AppendFixed";
514         cpp_type = "float";
515         break;
516       }
517       case FieldDescriptor::TYPE_DOUBLE: {
518         appender = "AppendFixed";
519         cpp_type = "double";
520         break;
521       }
522       case FieldDescriptor::TYPE_ENUM: {
523         appender = IsTinyEnumField(field) ? "AppendTinyVarInt" : "AppendVarInt";
524         cpp_type = GetCppClassName(field->enum_type(), true);
525         break;
526       }
527       case FieldDescriptor::TYPE_STRING:
528       case FieldDescriptor::TYPE_BYTES: {
529         if (field->type() == FieldDescriptor::TYPE_STRING) {
530           cpp_type = "const char*";
531         } else {
532           cpp_type = "const uint8_t*";
533         }
534         code_stub =
535             "void $action$_$name$(const std::string& value) {\n"
536             "  AppendBytes($id$, value.data(), value.size());\n"
537             "}\n"
538             "void $action$_$name$($cpp_type$ data, size_t size) {\n"
539             "  AppendBytes($id$, data, size);\n"
540             "}\n";
541         break;
542       }
543       case FieldDescriptor::TYPE_GROUP:
544       case FieldDescriptor::TYPE_MESSAGE: {
545         Abort("Unsupported field type.");
546         return;
547       }
548     }
549     setter["appender"] = appender;
550     setter["cpp_type"] = cpp_type;
551     stub_h_->Print(setter, code_stub);
552   }
553 
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)554   void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
555     std::string action = field->is_repeated() ? "add" : "set";
556     std::string inner_class = GetCppClassName(field->message_type());
557     stub_h_->Print(
558         "template <typename T = $inner_class$> T* $action$_$name$() {\n"
559         "  return BeginNestedMessage<T>($id$);\n"
560         "}\n\n",
561         "id", std::to_string(field->number()), "name", field->lowercase_name(),
562         "action", action, "inner_class", inner_class);
563     if (field->options().lazy()) {
564       stub_h_->Print(
565           "void $action$_$name$_raw(const std::string& raw) {\n"
566           "  return AppendBytes($id$, raw.data(), raw.size());\n"
567           "}\n\n",
568           "id", std::to_string(field->number()), "name",
569           field->lowercase_name(), "action", action, "inner_class",
570           inner_class);
571     }
572   }
573 
GenerateDecoder(const Descriptor * message)574   void GenerateDecoder(const Descriptor* message) {
575     int max_field_id = 0;
576     bool has_nonpacked_repeated_fields = false;
577     for (int i = 0; i < message->field_count(); ++i) {
578       const FieldDescriptor* field = message->field(i);
579       if (field->number() > kMaxDecoderFieldId)
580         continue;
581       max_field_id = std::max(max_field_id, field->number());
582       if (field->is_repeated() && !field->is_packed())
583         has_nonpacked_repeated_fields = true;
584     }
585 
586     std::string class_name = GetCppClassName(message) + "_Decoder";
587     stub_h_->Print(
588         "class $name$ : public "
589         "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
590         "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
591         "name", class_name, "max", std::to_string(max_field_id), "rep",
592         has_nonpacked_repeated_fields ? "true" : "false");
593     stub_h_->Print(" public:\n");
594     stub_h_->Indent();
595     stub_h_->Print(
596         "$name$(const uint8_t* data, size_t len) "
597         ": TypedProtoDecoder(data, len) {}\n",
598         "name", class_name);
599     stub_h_->Print(
600         "explicit $name$(const std::string& raw) : "
601         "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
602         "raw.size()) {}\n",
603         "name", class_name);
604     stub_h_->Print(
605         "explicit $name$(const ::protozero::ConstBytes& raw) : "
606         "TypedProtoDecoder(raw.data, raw.size) {}\n",
607         "name", class_name);
608 
609     for (int i = 0; i < message->field_count(); ++i) {
610       const FieldDescriptor* field = message->field(i);
611       if (field->number() > max_field_id) {
612         stub_h_->Print("// field $name$ omitted because its id is too high\n",
613                        "name", field->name());
614         continue;
615       }
616       std::string getter;
617       std::string cpp_type;
618       switch (field->type()) {
619         case FieldDescriptor::TYPE_BOOL:
620           getter = "as_bool";
621           cpp_type = "bool";
622           break;
623         case FieldDescriptor::TYPE_SFIXED32:
624         case FieldDescriptor::TYPE_SINT32:
625         case FieldDescriptor::TYPE_INT32:
626           getter = "as_int32";
627           cpp_type = "int32_t";
628           break;
629         case FieldDescriptor::TYPE_SFIXED64:
630         case FieldDescriptor::TYPE_SINT64:
631         case FieldDescriptor::TYPE_INT64:
632           getter = "as_int64";
633           cpp_type = "int64_t";
634           break;
635         case FieldDescriptor::TYPE_FIXED32:
636         case FieldDescriptor::TYPE_UINT32:
637           getter = "as_uint32";
638           cpp_type = "uint32_t";
639           break;
640         case FieldDescriptor::TYPE_FIXED64:
641         case FieldDescriptor::TYPE_UINT64:
642           getter = "as_uint64";
643           cpp_type = "uint64_t";
644           break;
645         case FieldDescriptor::TYPE_FLOAT:
646           getter = "as_float";
647           cpp_type = "float";
648           break;
649         case FieldDescriptor::TYPE_DOUBLE:
650           getter = "as_double";
651           cpp_type = "double";
652           break;
653         case FieldDescriptor::TYPE_ENUM:
654           getter = "as_int32";
655           cpp_type = "int32_t";
656           break;
657         case FieldDescriptor::TYPE_STRING:
658           getter = "as_string";
659           cpp_type = "::protozero::ConstChars";
660           break;
661         case FieldDescriptor::TYPE_MESSAGE:
662         case FieldDescriptor::TYPE_BYTES:
663           getter = "as_bytes";
664           cpp_type = "::protozero::ConstBytes";
665           break;
666         case FieldDescriptor::TYPE_GROUP:
667           continue;
668       }
669 
670       stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
671                      "name", field->lowercase_name(), "id",
672                      std::to_string(field->number()));
673 
674       if (field->is_packed()) {
675         const char* protozero_wire_type =
676             FieldTypeToProtozeroWireType(field->type());
677         stub_h_->Print(
678             "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
679             "$name$(bool* parse_error_ptr) const { return "
680             "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
681             "parse_error_ptr); }\n",
682             "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
683             field->lowercase_name(), "id", std::to_string(field->number()));
684       } else if (field->is_repeated()) {
685         stub_h_->Print(
686             "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
687             "return "
688             "GetRepeated<$cpp_type$>($id$); }\n",
689             "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
690             std::to_string(field->number()));
691       } else {
692         stub_h_->Print(
693             "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
694             "name", field->lowercase_name(), "id",
695             std::to_string(field->number()), "cpp_type", cpp_type, "getter",
696             getter);
697       }
698     }
699     stub_h_->Outdent();
700     stub_h_->Print("};\n\n");
701   }
702 
GenerateConstantsForMessageFields(const Descriptor * message)703   void GenerateConstantsForMessageFields(const Descriptor* message) {
704     const bool has_fields = (message->field_count() > 0);
705 
706     // Field number constants.
707     if (has_fields) {
708       stub_h_->Print("enum : int32_t {\n");
709       stub_h_->Indent();
710 
711       for (int i = 0; i < message->field_count(); ++i) {
712         const FieldDescriptor* field = message->field(i);
713         stub_h_->Print("$name$ = $id$,\n", "name",
714                        GetFieldNumberConstant(field), "id",
715                        std::to_string(field->number()));
716       }
717       stub_h_->Outdent();
718       stub_h_->Print("};\n");
719     }
720   }
721 
GenerateMessageDescriptor(const Descriptor * message)722   void GenerateMessageDescriptor(const Descriptor* message) {
723     GenerateDecoder(message);
724 
725     stub_h_->Print(
726         "class $name$ : public ::protozero::Message {\n"
727         " public:\n",
728         "name", GetCppClassName(message));
729     stub_h_->Indent();
730 
731     stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
732                    GetCppClassName(message));
733 
734     GenerateConstantsForMessageFields(message);
735 
736     // Using statements for nested messages.
737     for (int i = 0; i < message->nested_type_count(); ++i) {
738       const Descriptor* nested_message = message->nested_type(i);
739       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
740                      nested_message->name(), "global_name",
741                      GetCppClassName(nested_message, true));
742     }
743 
744     // Using statements for nested enums.
745     for (int i = 0; i < message->enum_type_count(); ++i) {
746       const EnumDescriptor* nested_enum = message->enum_type(i);
747       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
748                      nested_enum->name(), "global_name",
749                      GetCppClassName(nested_enum, true));
750     }
751 
752     // Values of nested enums.
753     for (int i = 0; i < message->enum_type_count(); ++i) {
754       const EnumDescriptor* nested_enum = message->enum_type(i);
755       std::string value_name_prefix = GetCppClassName(nested_enum) + "_";
756 
757       for (int j = 0; j < nested_enum->value_count(); ++j) {
758         const EnumValueDescriptor* value = nested_enum->value(j);
759         stub_h_->Print("static const $class$ $name$ = $full_name$;\n", "class",
760                        nested_enum->name(), "name", value->name(), "full_name",
761                        value_name_prefix + value->name());
762       }
763     }
764 
765     // Field descriptors.
766     for (int i = 0; i < message->field_count(); ++i) {
767       const FieldDescriptor* field = message->field(i);
768       if (field->is_packed()) {
769         GeneratePackedRepeatedFieldDescriptor(field);
770       } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
771         GenerateSimpleFieldDescriptor(field);
772       } else {
773         GenerateNestedMessageFieldDescriptor(field);
774       }
775     }
776 
777     stub_h_->Outdent();
778     stub_h_->Print("};\n\n");
779   }
780 
GenerateEpilogue()781   void GenerateEpilogue() {
782     for (unsigned i = 0; i < namespaces_.size(); ++i) {
783       stub_h_->Print("} // Namespace.\n");
784     }
785     stub_h_->Print("#endif  // Include guard.\n");
786   }
787 
788   const FileDescriptor* const source_;
789   Printer* const stub_h_;
790   std::string error_;
791 
792   std::string package_;
793   std::string wrapper_namespace_;
794   std::vector<std::string> namespaces_;
795   std::string full_namespace_prefix_;
796   std::vector<const Descriptor*> messages_;
797   std::vector<const EnumDescriptor*> enums_;
798 
799   // The custom *Comp comparators are to ensure determinism of the generator.
800   std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
801   std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
802   std::set<const Descriptor*, DescriptorComp> referenced_messages_;
803   std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
804 };
805 
806 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
807  public:
808   explicit ProtoZeroGenerator();
809   ~ProtoZeroGenerator() override;
810 
811   // CodeGenerator implementation
812   bool Generate(const google::protobuf::FileDescriptor* file,
813                 const std::string& options,
814                 GeneratorContext* context,
815                 std::string* error) const override;
816 };
817 
ProtoZeroGenerator()818 ProtoZeroGenerator::ProtoZeroGenerator() {}
819 
~ProtoZeroGenerator()820 ProtoZeroGenerator::~ProtoZeroGenerator() {}
821 
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const822 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
823                                   const std::string& options,
824                                   GeneratorContext* context,
825                                   std::string* error) const {
826   const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
827       context->Open(ProtoStubName(file) + ".h"));
828   const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
829       context->Open(ProtoStubName(file) + ".cc"));
830 
831   // Variables are delimited by $.
832   Printer stub_h_printer(stub_h_file_stream.get(), '$');
833   GeneratorJob job(file, &stub_h_printer);
834 
835   Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
836   stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
837 
838   // Parse additional options.
839   for (const std::string& option : SplitString(options, ",")) {
840     std::vector<std::string> option_pair = SplitString(option, "=");
841     job.SetOption(option_pair[0], option_pair[1]);
842   }
843 
844   if (!job.GenerateStubs()) {
845     *error = job.GetFirstError();
846     return false;
847   }
848   return true;
849 }
850 
851 }  // namespace
852 }  // namespace protozero
853 
main(int argc,char * argv[])854 int main(int argc, char* argv[]) {
855   ::protozero::ProtoZeroGenerator generator;
856   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
857 }
858