• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <string>
26 #include <vector>
27 
28 #include <google/protobuf/compiler/code_generator.h>
29 #include <google/protobuf/compiler/importer.h>
30 #include <google/protobuf/compiler/plugin.h>
31 #include <google/protobuf/io/printer.h>
32 
33 #include "perfetto/ext/base/string_utils.h"
34 
35 namespace protozero {
36 namespace {
37 
38 using namespace google::protobuf;
39 using namespace google::protobuf::compiler;
40 using namespace google::protobuf::io;
41 using perfetto::base::SplitString;
42 using perfetto::base::StripChars;
43 using perfetto::base::StripSuffix;
44 using perfetto::base::ToUpper;
45 
46 static constexpr auto TYPE_STRING = FieldDescriptor::TYPE_STRING;
47 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
48 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
49 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
50 
51 static const char kHeader[] =
52     "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
53 
54 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
55  public:
56   CppObjGenerator();
57   ~CppObjGenerator() override;
58 
59   // CodeGenerator implementation
60   bool Generate(const google::protobuf::FileDescriptor* file,
61                 const std::string& options,
62                 GeneratorContext* context,
63                 std::string* error) const override;
64 
65  private:
66   std::string GetCppType(const FieldDescriptor* field, bool constref) const;
67   std::string GetProtozeroSetter(const FieldDescriptor* field) const;
68   std::string GetPackedBuffer(const FieldDescriptor* field) const;
69   std::string GetPackedWireType(const FieldDescriptor* field) const;
70 
71   void GenEnum(const EnumDescriptor*, Printer*) const;
72   void GenEnumAliases(const EnumDescriptor*, Printer*) const;
73   void GenClassDecl(const Descriptor*, Printer*) const;
74   void GenClassDef(const Descriptor*, Printer*) const;
75 
GetNamespaces(const FileDescriptor * file) const76   std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
77     std::string pkg = std::string(file->package()) + wrapper_namespace_;
78     return SplitString(pkg, ".");
79   }
80 
81   template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const82   std::string GetFullName(const T* msg, bool with_namespace = false) const {
83     std::string full_type;
84     full_type.append(msg->name());
85     for (const Descriptor* par = msg->containing_type(); par;
86          par = par->containing_type()) {
87       full_type.insert(0, std::string(par->name()) + "_");
88     }
89     if (with_namespace) {
90       std::string prefix;
91       for (const std::string& ns : GetNamespaces(msg->file())) {
92         prefix += ns + "::";
93       }
94       full_type = prefix + full_type;
95     }
96     return full_type;
97   }
98 
99   template <class T>
HasSamePackage(const T * descriptor) const100   bool HasSamePackage(const T* descriptor) const {
101     return descriptor->file()->package() == package_;
102   }
103 
104   mutable std::string wrapper_namespace_;
105   mutable std::string package_;
106 };
107 
108 CppObjGenerator::CppObjGenerator() = default;
109 CppObjGenerator::~CppObjGenerator() = default;
110 
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const111 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
112                                const std::string& options,
113                                GeneratorContext* context,
114                                std::string* error) const {
115   for (const std::string& option : SplitString(options, ",")) {
116     std::vector<std::string> option_pair = SplitString(option, "=");
117     if (option_pair[0] == "wrapper_namespace") {
118       wrapper_namespace_ =
119           option_pair.size() == 2 ? "." + option_pair[1] : std::string();
120     } else {
121       *error = "Unknown plugin option: " + option_pair[0];
122       return false;
123     }
124   }
125 
126   package_ = file->package();
127 
128   auto get_file_name = [](const FileDescriptor* proto) {
129     return StripSuffix(std::string(proto->name()), ".proto") + ".gen";
130   };
131 
132   const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
133       context->Open(get_file_name(file) + ".h"));
134   const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
135       context->Open(get_file_name(file) + ".cc"));
136 
137   // Variables are delimited by $.
138   Printer h_printer(h_fstream.get(), '$');
139   Printer cc_printer(cc_fstream.get(), '$');
140 
141   std::string include_guard = std::string(file->package()) + "_" +
142                               std::string(file->name()) + "_CPP_H_";
143   include_guard = ToUpper(include_guard);
144   include_guard = StripChars(include_guard, ".-/\\", '_');
145 
146   h_printer.Print(kHeader);
147   h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
148   h_printer.Print("#include <stdint.h>\n");
149   h_printer.Print("#include <bitset>\n");
150   h_printer.Print("#include <vector>\n");
151   h_printer.Print("#include <string>\n");
152   h_printer.Print("#include <type_traits>\n\n");
153   h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
154   h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
155   h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
156 
157   cc_printer.Print("#include \"perfetto/protozero/gen_field_helpers.h\"\n");
158   cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
159   cc_printer.Print(
160       "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
161   cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
162   cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
163   cc_printer.Print(kHeader);
164   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
165   cc_printer.Print("#pragma GCC diagnostic push\n");
166   cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
167   cc_printer.Print("#endif\n");
168 
169   // Generate includes for translated types of dependencies.
170 
171   // Figure out the subset of imports that are used only for lazy fields. We
172   // won't emit a C++ #include for them. This code is overly aggressive at
173   // removing imports: it rules them out as soon as it sees one lazy field
174   // whose type is defined in that import. A 100% correct solution would require
175   // to check that *all* dependent types for a given import are lazy before
176   // excluding that. In practice we don't need that because we don't use imports
177   // for both lazy and non-lazy fields.
178   std::set<std::string> lazy_imports;
179   for (int m = 0; m < file->message_type_count(); m++) {
180     const Descriptor* msg = file->message_type(m);
181     for (int i = 0; i < msg->field_count(); i++) {
182       const FieldDescriptor* field = msg->field(i);
183       if (field->options().lazy()) {
184         lazy_imports.insert(std::string(field->message_type()->file()->name()));
185       }
186     }
187   }
188 
189   // Recursively traverse all imports and turn them into #include(s).
190   std::vector<const FileDescriptor*> imports_to_visit;
191   std::set<const FileDescriptor*> imports_visited;
192   imports_to_visit.push_back(file);
193 
194   while (!imports_to_visit.empty()) {
195     const FileDescriptor* cur = imports_to_visit.back();
196     imports_to_visit.pop_back();
197     imports_visited.insert(cur);
198     std::string base_name = StripSuffix(std::string(cur->name()), ".proto");
199     cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
200     for (int i = 0; i < cur->dependency_count(); i++) {
201       const FileDescriptor* dep = cur->dependency(i);
202       if (imports_visited.count(dep) ||
203           lazy_imports.count(std::string(dep->name())))
204         continue;
205       imports_to_visit.push_back(dep);
206     }
207   }
208 
209   // Compute all nested types to generate forward declarations later.
210 
211   std::set<const Descriptor*> all_types_seen;  // All deps
212   std::set<const EnumDescriptor*> all_enums_seen;
213 
214   // We track the types additionally in vectors to guarantee a stable order in
215   // the generated output.
216   std::vector<const Descriptor*> local_types;  // Cur .proto file only.
217   std::vector<const Descriptor*> all_types;    // All deps
218   std::vector<const EnumDescriptor*> local_enums;
219   std::vector<const EnumDescriptor*> all_enums;
220 
221   auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
222                    &file](const EnumDescriptor* enum_desc) {
223     if (all_enums_seen.count(enum_desc))
224       return;
225     all_enums_seen.insert(enum_desc);
226     all_enums.push_back(enum_desc);
227     if (enum_desc->file() == file)
228       local_enums.push_back(enum_desc);
229   };
230 
231   for (int i = 0; i < file->enum_type_count(); i++)
232     add_enum(file->enum_type(i));
233 
234   std::stack<const Descriptor*> recursion_stack;
235   for (int i = 0; i < file->message_type_count(); i++)
236     recursion_stack.push(file->message_type(i));
237 
238   while (!recursion_stack.empty()) {
239     const Descriptor* msg = recursion_stack.top();
240     recursion_stack.pop();
241     if (all_types_seen.count(msg))
242       continue;
243     all_types_seen.insert(msg);
244     all_types.push_back(msg);
245     if (msg->file() == file)
246       local_types.push_back(msg);
247 
248     for (int i = 0; i < msg->nested_type_count(); i++)
249       recursion_stack.push(msg->nested_type(i));
250 
251     for (int i = 0; i < msg->enum_type_count(); i++)
252       add_enum(msg->enum_type(i));
253 
254     for (int i = 0; i < msg->field_count(); i++) {
255       const FieldDescriptor* field = msg->field(i);
256       if (field->has_default_value()) {
257         *error = "field " + std::string(field->name()) +
258                  ": Explicitly declared default values are not supported";
259         return false;
260       }
261       if (field->options().lazy() &&
262           (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
263         *error = "[lazy=true] is supported only on non-repeated fields\n";
264         return false;
265       }
266 
267       if (field->type() == TYPE_MESSAGE && !field->options().lazy())
268         recursion_stack.push(field->message_type());
269 
270       if (field->type() == FieldDescriptor::TYPE_ENUM)
271         add_enum(field->enum_type());
272     }
273   }  //  while (!recursion_stack.empty())
274 
275   // Generate forward declarations in the header for proto types.
276   // Note: do NOT add #includes to other generated headers (either .gen.h or
277   // .pbzero.h). Doing so is extremely hard to handle at the build-system level
278   // and requires propagating public_deps everywhere.
279   cc_printer.Print("\n");
280 
281   // -- Begin of fwd declarations.
282 
283   // Build up the map of forward declarations.
284   std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
285   enum FwdType { kClass, kEnum };
286   auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
287                                    const std::string& full_name) {
288     auto dot = full_name.rfind("::");
289     PERFETTO_CHECK(dot != std::string::npos);
290     auto package = full_name.substr(0, dot);
291     auto name = full_name.substr(dot + 2);
292     if (cpp_type == kClass) {
293       fwd_decls.emplace(package, "class " + name + ";");
294     } else {
295       PERFETTO_CHECK(cpp_type == kEnum);
296       fwd_decls.emplace(package, "enum " + name + " : int;");
297     }
298   };
299 
300   add_fwd_decl(kClass, "protozero::Message");
301   for (const Descriptor* msg : all_types) {
302     add_fwd_decl(kClass, GetFullName(msg, true));
303   }
304   for (const EnumDescriptor* enm : all_enums) {
305     add_fwd_decl(kEnum, GetFullName(enm, true));
306   }
307 
308   // Emit forward declarations grouping by package.
309   std::string last_package;
310   auto close_last_package = [&last_package, &h_printer] {
311     if (!last_package.empty()) {
312       for (const std::string& ns : SplitString(last_package, "::"))
313         h_printer.Print("}  // namespace $ns$\n", "ns", ns);
314       h_printer.Print("\n");
315     }
316   };
317   for (const auto& kv : fwd_decls) {
318     const std::string& package = kv.first;
319     if (package != last_package) {
320       close_last_package();
321       last_package = package;
322       for (const std::string& ns : SplitString(package, "::"))
323         h_printer.Print("namespace $ns$ {\n", "ns", ns);
324     }
325     h_printer.Print("$decl$\n", "decl", kv.second);
326   }
327   close_last_package();
328 
329   // -- End of fwd declarations.
330 
331   for (const std::string& ns : GetNamespaces(file)) {
332     h_printer.Print("namespace $n$ {\n", "n", ns);
333     cc_printer.Print("namespace $n$ {\n", "n", ns);
334   }
335 
336   // Generate declarations and definitions.
337   for (const EnumDescriptor* enm : local_enums)
338     GenEnum(enm, &h_printer);
339 
340   for (const Descriptor* msg : local_types) {
341     GenClassDecl(msg, &h_printer);
342     GenClassDef(msg, &cc_printer);
343   }
344 
345   for (const std::string& ns : GetNamespaces(file)) {
346     h_printer.Print("}  // namespace $n$\n", "n", ns);
347     cc_printer.Print("}  // namespace $n$\n", "n", ns);
348   }
349   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
350   cc_printer.Print("#pragma GCC diagnostic pop\n");
351   cc_printer.Print("#endif\n");
352 
353   h_printer.Print("\n#endif  // $g$\n", "g", include_guard);
354 
355   return true;
356 }
357 
GetCppType(const FieldDescriptor * field,bool constref) const358 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
359                                         bool constref) const {
360   switch (field->type()) {
361     case FieldDescriptor::TYPE_DOUBLE:
362       return "double";
363     case FieldDescriptor::TYPE_FLOAT:
364       return "float";
365     case FieldDescriptor::TYPE_FIXED32:
366     case FieldDescriptor::TYPE_UINT32:
367       return "uint32_t";
368     case FieldDescriptor::TYPE_SFIXED32:
369     case FieldDescriptor::TYPE_INT32:
370     case FieldDescriptor::TYPE_SINT32:
371       return "int32_t";
372     case FieldDescriptor::TYPE_FIXED64:
373     case FieldDescriptor::TYPE_UINT64:
374       return "uint64_t";
375     case FieldDescriptor::TYPE_SFIXED64:
376     case FieldDescriptor::TYPE_SINT64:
377     case FieldDescriptor::TYPE_INT64:
378       return "int64_t";
379     case FieldDescriptor::TYPE_BOOL:
380       return "bool";
381     case FieldDescriptor::TYPE_STRING:
382     case FieldDescriptor::TYPE_BYTES:
383       return constref ? "const std::string&" : "std::string";
384     case FieldDescriptor::TYPE_MESSAGE:
385       assert(!field->options().lazy());
386       return constref
387                  ? "const " +
388                        GetFullName(field->message_type(),
389                                    !HasSamePackage(field->message_type())) +
390                        "&"
391                  : GetFullName(field->message_type(),
392                                !HasSamePackage(field->message_type()));
393     case FieldDescriptor::TYPE_ENUM:
394       return GetFullName(field->enum_type(),
395                          !HasSamePackage(field->enum_type()));
396     case FieldDescriptor::TYPE_GROUP:
397       abort();
398   }
399   abort();  // for gcc
400 }
401 
GetProtozeroSetter(const FieldDescriptor * field) const402 std::string CppObjGenerator::GetProtozeroSetter(
403     const FieldDescriptor* field) const {
404   switch (field->type()) {
405     case FieldDescriptor::TYPE_BOOL:
406       return "::protozero::internal::gen_helpers::SerializeTinyVarInt";
407     case FieldDescriptor::TYPE_INT32:
408     case FieldDescriptor::TYPE_INT64:
409     case FieldDescriptor::TYPE_UINT32:
410     case FieldDescriptor::TYPE_UINT64:
411     case FieldDescriptor::TYPE_ENUM:
412       return "::protozero::internal::gen_helpers::SerializeVarInt";
413     case FieldDescriptor::TYPE_SINT32:
414     case FieldDescriptor::TYPE_SINT64:
415       return "::protozero::internal::gen_helpers::SerializeSignedVarInt";
416     case FieldDescriptor::TYPE_FIXED32:
417     case FieldDescriptor::TYPE_FIXED64:
418     case FieldDescriptor::TYPE_SFIXED32:
419     case FieldDescriptor::TYPE_SFIXED64:
420     case FieldDescriptor::TYPE_FLOAT:
421     case FieldDescriptor::TYPE_DOUBLE:
422       return "::protozero::internal::gen_helpers::SerializeFixed";
423     case FieldDescriptor::TYPE_STRING:
424     case FieldDescriptor::TYPE_BYTES:
425       return "::protozero::internal::gen_helpers::SerializeString";
426     case FieldDescriptor::TYPE_GROUP:
427     case FieldDescriptor::TYPE_MESSAGE:
428       abort();
429   }
430   abort();
431 }
432 
GetPackedBuffer(const FieldDescriptor * field) const433 std::string CppObjGenerator::GetPackedBuffer(
434     const FieldDescriptor* field) const {
435   switch (field->type()) {
436     case FieldDescriptor::TYPE_FIXED32:
437       return "::protozero::PackedFixedSizeInt<uint32_t>";
438     case FieldDescriptor::TYPE_SFIXED32:
439       return "::protozero::PackedFixedSizeInt<int32_t>";
440     case FieldDescriptor::TYPE_FIXED64:
441       return "::protozero::PackedFixedSizeInt<uint64_t>";
442     case FieldDescriptor::TYPE_SFIXED64:
443       return "::protozero::PackedFixedSizeInt<int64_t>";
444     case FieldDescriptor::TYPE_DOUBLE:
445       return "::protozero::PackedFixedSizeInt<double>";
446     case FieldDescriptor::TYPE_FLOAT:
447       return "::protozero::PackedFixedSizeInt<float>";
448     case FieldDescriptor::TYPE_INT32:
449     case FieldDescriptor::TYPE_SINT32:
450     case FieldDescriptor::TYPE_UINT32:
451     case FieldDescriptor::TYPE_INT64:
452     case FieldDescriptor::TYPE_UINT64:
453     case FieldDescriptor::TYPE_SINT64:
454     case FieldDescriptor::TYPE_BOOL:
455     case FieldDescriptor::TYPE_ENUM:
456       return "::protozero::PackedVarInt";
457     case FieldDescriptor::TYPE_STRING:
458     case FieldDescriptor::TYPE_BYTES:
459     case FieldDescriptor::TYPE_MESSAGE:
460     case FieldDescriptor::TYPE_GROUP:
461       break;  // Will abort()
462   }
463   abort();
464 }
465 
GetPackedWireType(const FieldDescriptor * field) const466 std::string CppObjGenerator::GetPackedWireType(
467     const FieldDescriptor* field) const {
468   switch (field->type()) {
469     case FieldDescriptor::TYPE_FIXED32:
470     case FieldDescriptor::TYPE_SFIXED32:
471     case FieldDescriptor::TYPE_FLOAT:
472       return "::protozero::proto_utils::ProtoWireType::kFixed32";
473     case FieldDescriptor::TYPE_FIXED64:
474     case FieldDescriptor::TYPE_SFIXED64:
475     case FieldDescriptor::TYPE_DOUBLE:
476       return "::protozero::proto_utils::ProtoWireType::kFixed64";
477     case FieldDescriptor::TYPE_INT32:
478     case FieldDescriptor::TYPE_SINT32:
479     case FieldDescriptor::TYPE_UINT32:
480     case FieldDescriptor::TYPE_INT64:
481     case FieldDescriptor::TYPE_UINT64:
482     case FieldDescriptor::TYPE_SINT64:
483     case FieldDescriptor::TYPE_BOOL:
484     case FieldDescriptor::TYPE_ENUM:
485       return "::protozero::proto_utils::ProtoWireType::kVarInt";
486     case FieldDescriptor::TYPE_STRING:
487     case FieldDescriptor::TYPE_BYTES:
488     case FieldDescriptor::TYPE_MESSAGE:
489     case FieldDescriptor::TYPE_GROUP:
490       break;  // Will abort()
491   }
492   abort();
493 }
494 
// Formats |number| as a C++ integer literal suitable for emitting into
// generated code. The 32-bit minimum is spelled as an expression because the
// literal "-2147483648" is unary minus applied to 2147483648, which does not
// fit in a 32-bit int.
std::string IntLiteralString(int number) {
  constexpr int kInt32Min = std::numeric_limits<int32_t>::min();
  return number != kInt32Min ? std::to_string(number) : "-2147483647 - 1";
}
503 
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const504 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
505                               Printer* p) const {
506   std::string full_name = GetFullName(enum_desc);
507 
508   // When generating enums, there are two cases:
509   // 1. Enums nested in a message (most frequent case), e.g.:
510   //    message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
511   // 2. Enum defined at the package level, outside of any message.
512   //
513   // In the case 1, the C++ code generated by the official protobuf library is:
514   // enum MyEnum {  MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
515   // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
516   //
517   // In the case 2, the C++ code is simply:
518   // enum MyEnum { FOO=1, BAR=2 }
519   // Hence this |prefix| logic.
520   std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
521   p->Print("enum $f$ : int {\n", "f", full_name);
522   for (int e = 0; e < enum_desc->value_count(); e++) {
523     const EnumValueDescriptor* value = enum_desc->value(e);
524     p->Print("  $p$$n$ = $v$,\n", "p", prefix, "n", std::string(value->name()),
525              "v", IntLiteralString(value->number()));
526   }
527   p->Print("};\n");
528 }
529 
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const530 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
531                                      Printer* p) const {
532   int min_value = std::numeric_limits<int>::max();
533   int max_value = std::numeric_limits<int>::min();
534   std::string min_name;
535   std::string max_name;
536   std::string full_name = GetFullName(enum_desc);
537   for (int e = 0; e < enum_desc->value_count(); e++) {
538     const EnumValueDescriptor* value = enum_desc->value(e);
539     std::string value_name(value->name());
540     p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
541              value_name);
542     if (value->number() < min_value) {
543       min_value = value->number();
544       min_name = full_name + "_" + value_name;
545     }
546     if (value->number() > max_value) {
547       max_value = value->number();
548       max_name = full_name + "_" + value_name;
549     }
550   }
551   p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
552            "m", min_name);
553   p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
554            "m", max_name);
555 }
556 
GenClassDecl(const Descriptor * msg,Printer * p) const557 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
558   std::string full_name = GetFullName(msg);
559   p->Print(
560       "\nclass PERFETTO_EXPORT_COMPONENT $n$ : public "
561       "::protozero::CppMessageObj {\n",
562       "n", full_name);
563   p->Print(" public:\n");
564   p->Indent();
565 
566   // Do a first pass to generate aliases for nested types.
567   // e.g., using Foo = Parent_Foo;
568   for (int i = 0; i < msg->nested_type_count(); i++) {
569     const Descriptor* nested_msg = msg->nested_type(i);
570     p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
571              GetFullName(nested_msg));
572   }
573   for (int i = 0; i < msg->enum_type_count(); i++) {
574     const EnumDescriptor* nested_enum = msg->enum_type(i);
575     p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
576              GetFullName(nested_enum));
577     GenEnumAliases(nested_enum, p);
578   }
579 
580   // Generate constants with field numbers.
581   p->Print("enum FieldNumbers {\n");
582   for (int i = 0; i < msg->field_count(); i++) {
583     const FieldDescriptor* field = msg->field(i);
584     std::string name(field->camelcase_name());
585     name[0] = perfetto::base::Uppercase(name[0]);
586     p->Print("  k$n$FieldNumber = $num$,\n", "n", name, "num",
587              std::to_string(field->number()));
588   }
589   p->Print("};\n\n");
590 
591   p->Print("$n$();\n", "n", full_name);
592   p->Print("~$n$() override;\n", "n", full_name);
593   p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
594   p->Print("$n$& operator=($n$&&);\n", "n", full_name);
595   p->Print("$n$(const $n$&);\n", "n", full_name);
596   p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
597   p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
598   p->Print(
599       "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
600       "n", full_name);
601   p->Print("\n");
602 
603   std::string proto_type = GetFullName(msg, true);
604   p->Print("bool ParseFromArray(const void*, size_t) override;\n");
605   p->Print("std::string SerializeAsString() const override;\n");
606   p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
607   p->Print("void Serialize(::protozero::Message*) const;\n");
608 
609   // Generate accessors.
610   for (int i = 0; i < msg->field_count(); i++) {
611     const FieldDescriptor* field = msg->field(i);
612     auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
613     p->Print("\n");
614     std::string lowercase_name(field->lowercase_name());
615     if (field->options().lazy()) {
616       p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
617                lowercase_name);
618       p->Print(
619           "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
620           "n", lowercase_name, "s", set_bit);
621     } else if (!field->is_repeated()) {
622       p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
623                lowercase_name, "bit", std::to_string(field->number()));
624       if (field->type() == TYPE_MESSAGE) {
625         p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
626                  GetCppType(field, true), "n", lowercase_name);
627         p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
628                  GetCppType(field, false), "n", lowercase_name, "s", set_bit);
629       } else {
630         p->Print("$t$ $n$() const { return $n$_; }\n", "t",
631                  GetCppType(field, true), "n", lowercase_name);
632         p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
633                  GetCppType(field, true), "n", lowercase_name, "s", set_bit);
634         if (field->type() == FieldDescriptor::TYPE_BYTES) {
635           p->Print(
636               "void set_$n$(const void* p, size_t s) { "
637               "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
638               "n", lowercase_name, "s", set_bit);
639         }
640       }
641     } else {  // is_repeated()
642       p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
643                GetCppType(field, false), "n", lowercase_name);
644       p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
645                GetCppType(field, false), "n", lowercase_name);
646 
647       // Generate accessors for repeated message types in the .cc file so that
648       // the header doesn't depend on the full definition of all nested types.
649       if (field->type() == TYPE_MESSAGE) {
650         p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
651                  lowercase_name);
652         p->Print("void clear_$n$();\n", "n", lowercase_name);
653         p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
654                  lowercase_name);
655       } else {  // Primitive type.
656         p->Print(
657             "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
658             "t", GetCppType(field, false), "n", lowercase_name);
659         p->Print("void clear_$n$() { $n$_.clear(); }\n", "n", lowercase_name);
660         p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
661                  GetCppType(field, false), "n", lowercase_name);
662         // TODO(primiano): this should be done only for TYPE_MESSAGE.
663         // Unfortuntely we didn't realize before and now we have a bunch of code
664         // that does: *msg->add_int_value() = 42 instead of
665         // msg->add_int_value(42).
666         p->Print(
667             "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
668             "t", GetCppType(field, false), "n", lowercase_name);
669       }
670     }
671   }
672   p->Outdent();
673   p->Print("\n private:\n");
674   p->Indent();
675 
676   // Generate fields.
677   int max_field_id = 1;
678   for (int i = 0; i < msg->field_count(); i++) {
679     const FieldDescriptor* field = msg->field(i);
680     std::string lowercase_name(field->lowercase_name());
681     max_field_id = std::max(max_field_id, field->number());
682     if (field->options().lazy()) {
683       p->Print("std::string $n$_;  // [lazy=true]\n", "n", lowercase_name);
684     } else if (!field->is_repeated()) {
685       std::string type = GetCppType(field, false);
686       if (field->type() == TYPE_MESSAGE) {
687         type = "::protozero::CopyablePtr<" + type + ">";
688         p->Print("$t$ $n$_;\n", "t", type, "n", lowercase_name);
689       } else {
690         p->Print("$t$ $n$_{};\n", "t", type, "n", lowercase_name);
691       }
692     } else {  // is_repeated()
693       p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
694                lowercase_name);
695     }
696   }
697   p->Print("\n");
698   p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
699   p->Print("// with future versions of .proto files.\n");
700   p->Print("std::string unknown_fields_;\n");
701 
702   p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
703            std::to_string(max_field_id + 1));
704 
705   p->Outdent();
706   p->Print("};\n\n");
707 }
708 
GenClassDef(const Descriptor * msg,Printer * p) const709 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
710   p->Print("\n");
711   std::string full_name = GetFullName(msg);
712 
713   p->Print("$n$::$n$() = default;\n", "n", full_name);
714   p->Print("$n$::~$n$() = default;\n", "n", full_name);
715   p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
716   p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
717   p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
718   p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
719 
720   p->Print("\n");
721 
722   // Comparison operator
723   p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
724   p->Indent();
725 
726   p->Print(
727       "return ::protozero::internal::gen_helpers::EqualsField(unknown_fields_, "
728       "other.unknown_fields_)");
729   for (int i = 0; i < msg->field_count(); i++)
730     p->Print(
731         "\n && ::protozero::internal::gen_helpers::EqualsField($n$_, "
732         "other.$n$_)",
733         "n", std::string(msg->field(i)->lowercase_name()));
734   p->Print(";");
735   p->Outdent();
736   p->Print("\n}\n\n");
737 
738   // Accessors for repeated message fields.
739   for (int i = 0; i < msg->field_count(); i++) {
740     const FieldDescriptor* field = msg->field(i);
741     if (field->options().lazy() || !field->is_repeated() ||
742         field->type() != TYPE_MESSAGE) {
743       continue;
744     }
745     std::string lowercase_name(field->lowercase_name());
746     p->Print(
747         "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
748         "c", full_name, "t", GetCppType(field, false), "n", lowercase_name);
749     p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
750              lowercase_name);
751     p->Print(
752         "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
753         "c", full_name, "t", GetCppType(field, false), "n", lowercase_name);
754   }
755 
756   std::string proto_type = GetFullName(msg, true);
757 
758   // Generate the ParseFromArray() method definition.
759   p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
760            full_name);
761   p->Indent();
762   for (int i = 0; i < msg->field_count(); i++) {
763     const FieldDescriptor* field = msg->field(i);
764     std::string lowercase_name(field->lowercase_name());
765     if (field->is_repeated())
766       p->Print("$n$_.clear();\n", "n", lowercase_name);
767   }
768   p->Print("unknown_fields_.clear();\n");
769   p->Print("bool packed_error = false;\n");
770   p->Print("\n");
771   p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
772   p->Print("for (auto field = dec.ReadField(); field.valid(); ");
773   p->Print("field = dec.ReadField()) {\n");
774   p->Indent();
775   p->Print("if (field.id() < _has_field_.size()) {\n");
776   p->Print("  _has_field_.set(field.id());\n");
777   p->Print("}\n");
778   p->Print("switch (field.id()) {\n");
779   p->Indent();
780   for (int i = 0; i < msg->field_count(); i++) {
781     const FieldDescriptor* field = msg->field(i);
782     std::string lowercase_name(field->lowercase_name());
783     p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
784              "n", lowercase_name);
785     p->Indent();
786     if (field->options().lazy()) {
787       p->Print(
788           "::protozero::internal::gen_helpers::DeserializeString(field, "
789           "&$n$_);\n",
790           "n", lowercase_name);
791     } else {
792       std::string statement;
793       if (field->type() == TYPE_MESSAGE) {
794         statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
795       } else {
796         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
797           // sint32/64 fields are special and need to be zig-zag-decoded.
798           statement = "field.get_signed(&$rval$);\n";
799         } else if (field->type() == TYPE_STRING) {
800           statement =
801               "::protozero::internal::gen_helpers::DeserializeString(field, "
802               "&$rval$);\n";
803         } else {
804           statement = "field.get(&$rval$);\n";
805         }
806       }
807       if (field->is_packed()) {
808         PERFETTO_CHECK(field->is_repeated());
809         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
810           PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
811         }
812         p->Print(
813             "if "
814             "(!::protozero::internal::gen_helpers::DeserializePackedRepeated"
815             "<$w$, $c$>(field, &$n$_)) {\n",
816             "w", GetPackedWireType(field), "c", GetCppType(field, false), "n",
817             lowercase_name);
818         p->Print("  packed_error = true;");
819         p->Print("}\n");
820       } else if (field->is_repeated()) {
821         p->Print("$n$_.emplace_back();\n", "n", lowercase_name);
822         p->Print(statement.c_str(), "rval", lowercase_name + "_.back()");
823       } else if (field->type() == TYPE_MESSAGE) {
824         p->Print(statement.c_str(), "rval", "(*" + lowercase_name + "_)");
825       } else {
826         p->Print(statement.c_str(), "rval", lowercase_name + "_");
827       }
828     }
829     p->Print("break;\n");
830     p->Outdent();
831   }  // for (field)
832   p->Print("default:\n");
833   p->Print("  field.SerializeAndAppendTo(&unknown_fields_);\n");
834   p->Print("  break;\n");
835   p->Outdent();
836   p->Print("}\n");  // switch(field.id)
837   p->Outdent();
838   p->Print("}\n");                                           // for(field)
839   p->Print("return !packed_error && !dec.bytes_left();\n");  // for(field)
840   p->Outdent();
841   p->Print("}\n\n");
842 
843   // Generate the SerializeAsString() method definition.
844   p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
845   p->Indent();
846   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
847   p->Print("Serialize(msg.get());\n");
848   p->Print("return msg.SerializeAsString();\n");
849   p->Outdent();
850   p->Print("}\n\n");
851 
852   // Generate the SerializeAsArray() method definition.
853   p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
854            full_name);
855   p->Indent();
856   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
857   p->Print("Serialize(msg.get());\n");
858   p->Print("return msg.SerializeAsArray();\n");
859   p->Outdent();
860   p->Print("}\n\n");
861 
862   // Generate the Serialize() method that writes the fields into the passed
863   // protozero |msg| write-only interface |msg|.
864   p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
865            full_name);
866   p->Indent();
867   for (int i = 0; i < msg->field_count(); i++) {
868     const FieldDescriptor* field = msg->field(i);
869     std::string lowercase_name(field->lowercase_name());
870     std::map<std::string, std::string> args;
871     args["id"] = std::to_string(field->number());
872     args["n"] = lowercase_name;
873     p->Print(args, "// Field $id$: $n$\n");
874     if (field->is_packed()) {
875       PERFETTO_CHECK(field->is_repeated());
876       p->Print("{\n");
877       p->Indent();
878       p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
879       p->Print(args, "for (auto& it : $n$_)\n");
880       p->Print(args, "  pack.Append(it);\n");
881       p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
882       p->Outdent();
883       p->Print("}\n");
884     } else {
885       if (field->is_repeated()) {
886         p->Print(args, "for (auto& it : $n$_) {\n");
887         args["lvalue"] = "it";
888         args["rvalue"] = "it";
889       } else {
890         p->Print(args, "if (_has_field_[$id$]) {\n");
891         args["lvalue"] = "(*" + lowercase_name + "_)";
892         args["rvalue"] = lowercase_name + "_";
893       }
894       p->Indent();
895       if (field->options().lazy()) {
896         p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
897       } else if (field->type() == TYPE_MESSAGE) {
898         p->Print(args,
899                  "$lvalue$.Serialize("
900                  "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
901       } else {
902         args["setter"] = GetProtozeroSetter(field);
903         p->Print(args, "$setter$($id$, $rvalue$, msg);\n");
904       }
905       p->Outdent();
906       p->Print("}\n");
907     }
908 
909     p->Print("\n");
910   }  // for (field)
911   p->Print(
912       "protozero::internal::gen_helpers::SerializeUnknownFields(unknown_fields_"
913       ", msg);\n");
914   p->Outdent();
915   p->Print("}\n\n");
916 }
917 
918 }  // namespace
919 }  // namespace protozero
920 
main(int argc,char ** argv)921 int main(int argc, char** argv) {
922   ::protozero::CppObjGenerator generator;
923   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
924 }
925