1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26 
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/dynamic_message.h>
31 #include <google/protobuf/io/printer.h>
32 #include <google/protobuf/io/zero_copy_stream_impl.h>
33 #include <google/protobuf/util/field_comparator.h>
34 #include <google/protobuf/util/message_differencer.h>
35 
36 #include "perfetto/ext/base/string_utils.h"
37 
38 namespace protozero {
39 namespace {
40 
41 using namespace google::protobuf;
42 using namespace google::protobuf::compiler;
43 using namespace google::protobuf::io;
44 using perfetto::base::SplitString;
45 using perfetto::base::StripChars;
46 using perfetto::base::StripSuffix;
47 using perfetto::base::ToUpper;
48 
// Short aliases for the FieldDescriptor wire types that the generator code in
// this file refers to without the FieldDescriptor:: qualifier.
static constexpr auto TYPE_STRING = FieldDescriptor::TYPE_STRING;
static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;

// Banner comment line that Generate() prints into both generated files (.gen.h
// and .gen.cc) to mark them as autogenerated.
static const char kHeader[] =
    "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
56 
57 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
58  public:
59   CppObjGenerator();
60   ~CppObjGenerator() override;
61 
62   // CodeGenerator implementation
63   bool Generate(const google::protobuf::FileDescriptor* file,
64                 const std::string& options,
65                 GeneratorContext* context,
66                 std::string* error) const override;
67 
68  private:
69   std::string GetCppType(const FieldDescriptor* field, bool constref) const;
70   std::string GetProtozeroSetter(const FieldDescriptor* field) const;
71   std::string GetPackedBuffer(const FieldDescriptor* field) const;
72   std::string GetPackedWireType(const FieldDescriptor* field) const;
73 
74   void GenEnum(const EnumDescriptor*, Printer*) const;
75   void GenEnumAliases(const EnumDescriptor*, Printer*) const;
76   void GenClassDecl(const Descriptor*, Printer*) const;
77   void GenClassDef(const Descriptor*, Printer*) const;
78 
GetNamespaces(const FileDescriptor * file) const79   std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
80     std::string pkg = file->package() + wrapper_namespace_;
81     return SplitString(pkg, ".");
82   }
83 
84   template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const85   std::string GetFullName(const T* msg, bool with_namespace = false) const {
86     std::string full_type;
87     full_type.append(msg->name());
88     for (const Descriptor* par = msg->containing_type(); par;
89          par = par->containing_type()) {
90       full_type.insert(0, par->name() + "_");
91     }
92     if (with_namespace) {
93       std::string prefix;
94       for (const std::string& ns : GetNamespaces(msg->file())) {
95         prefix += ns + "::";
96       }
97       full_type = prefix + full_type;
98     }
99     return full_type;
100   }
101 
102   mutable std::string wrapper_namespace_;
103 };
104 
// The constructor and destructor are declared in the class and defaulted here,
// out of line.
CppObjGenerator::CppObjGenerator() = default;
CppObjGenerator::~CppObjGenerator() = default;
107 
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const108 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
109                                const std::string& options,
110                                GeneratorContext* context,
111                                std::string* error) const {
112   for (const std::string& option : SplitString(options, ",")) {
113     std::vector<std::string> option_pair = SplitString(option, "=");
114     if (option_pair[0] == "wrapper_namespace") {
115       wrapper_namespace_ =
116           option_pair.size() == 2 ? "." + option_pair[1] : std::string();
117     } else {
118       *error = "Unknown plugin option: " + option_pair[0];
119       return false;
120     }
121   }
122 
123   auto get_file_name = [](const FileDescriptor* proto) {
124     return StripSuffix(proto->name(), ".proto") + ".gen";
125   };
126 
127   const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
128       context->Open(get_file_name(file) + ".h"));
129   const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
130       context->Open(get_file_name(file) + ".cc"));
131 
132   // Variables are delimited by $.
133   Printer h_printer(h_fstream.get(), '$');
134   Printer cc_printer(cc_fstream.get(), '$');
135 
136   std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
137   include_guard = ToUpper(include_guard);
138   include_guard = StripChars(include_guard, ".-/\\", '_');
139 
140   h_printer.Print(kHeader);
141   h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
142   h_printer.Print("#include <stdint.h>\n");
143   h_printer.Print("#include <bitset>\n");
144   h_printer.Print("#include <vector>\n");
145   h_printer.Print("#include <string>\n");
146   h_printer.Print("#include <type_traits>\n\n");
147   h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
148   h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
149   h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
150 
151   cc_printer.Print("#include \"perfetto/protozero/gen_field_helpers.h\"\n");
152   cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
153   cc_printer.Print(
154       "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
155   cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
156   cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
157   cc_printer.Print(kHeader);
158   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
159   cc_printer.Print("#pragma GCC diagnostic push\n");
160   cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
161   cc_printer.Print("#endif\n");
162 
163   // Generate includes for translated types of dependencies.
164 
165   // Figure out the subset of imports that are used only for lazy fields. We
166   // won't emit a C++ #include for them. This code is overly aggressive at
167   // removing imports: it rules them out as soon as it sees one lazy field
168   // whose type is defined in that import. A 100% correct solution would require
169   // to check that *all* dependent types for a given import are lazy before
170   // excluding that. In practice we don't need that because we don't use imports
171   // for both lazy and non-lazy fields.
172   std::set<std::string> lazy_imports;
173   for (int m = 0; m < file->message_type_count(); m++) {
174     const Descriptor* msg = file->message_type(m);
175     for (int i = 0; i < msg->field_count(); i++) {
176       const FieldDescriptor* field = msg->field(i);
177       if (field->options().lazy()) {
178         lazy_imports.insert(field->message_type()->file()->name());
179       }
180     }
181   }
182 
183   // Recursively traverse all imports and turn them into #include(s).
184   std::vector<const FileDescriptor*> imports_to_visit;
185   std::set<const FileDescriptor*> imports_visited;
186   imports_to_visit.push_back(file);
187 
188   while (!imports_to_visit.empty()) {
189     const FileDescriptor* cur = imports_to_visit.back();
190     imports_to_visit.pop_back();
191     imports_visited.insert(cur);
192     std::string base_name = StripSuffix(cur->name(), ".proto");
193     cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
194     for (int i = 0; i < cur->dependency_count(); i++) {
195       const FileDescriptor* dep = cur->dependency(i);
196       if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
197         continue;
198       imports_to_visit.push_back(dep);
199     }
200   }
201 
202   // Compute all nested types to generate forward declarations later.
203 
204   std::set<const Descriptor*> all_types_seen;  // All deps
205   std::set<const EnumDescriptor*> all_enums_seen;
206 
207   // We track the types additionally in vectors to guarantee a stable order in
208   // the generated output.
209   std::vector<const Descriptor*> local_types;  // Cur .proto file only.
210   std::vector<const Descriptor*> all_types;    // All deps
211   std::vector<const EnumDescriptor*> local_enums;
212   std::vector<const EnumDescriptor*> all_enums;
213 
214   auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
215                    &file](const EnumDescriptor* enum_desc) {
216     if (all_enums_seen.count(enum_desc))
217       return;
218     all_enums_seen.insert(enum_desc);
219     all_enums.push_back(enum_desc);
220     if (enum_desc->file() == file)
221       local_enums.push_back(enum_desc);
222   };
223 
224   for (int i = 0; i < file->enum_type_count(); i++)
225     add_enum(file->enum_type(i));
226 
227   std::stack<const Descriptor*> recursion_stack;
228   for (int i = 0; i < file->message_type_count(); i++)
229     recursion_stack.push(file->message_type(i));
230 
231   while (!recursion_stack.empty()) {
232     const Descriptor* msg = recursion_stack.top();
233     recursion_stack.pop();
234     if (all_types_seen.count(msg))
235       continue;
236     all_types_seen.insert(msg);
237     all_types.push_back(msg);
238     if (msg->file() == file)
239       local_types.push_back(msg);
240 
241     for (int i = 0; i < msg->nested_type_count(); i++)
242       recursion_stack.push(msg->nested_type(i));
243 
244     for (int i = 0; i < msg->enum_type_count(); i++)
245       add_enum(msg->enum_type(i));
246 
247     for (int i = 0; i < msg->field_count(); i++) {
248       const FieldDescriptor* field = msg->field(i);
249       if (field->has_default_value()) {
250         *error = "field " + field->name() +
251                  ": Explicitly declared default values are not supported";
252         return false;
253       }
254       if (field->options().lazy() &&
255           (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
256         *error = "[lazy=true] is supported only on non-repeated fields\n";
257         return false;
258       }
259 
260       if (field->type() == TYPE_MESSAGE && !field->options().lazy())
261         recursion_stack.push(field->message_type());
262 
263       if (field->type() == FieldDescriptor::TYPE_ENUM)
264         add_enum(field->enum_type());
265     }
266   }  //  while (!recursion_stack.empty())
267 
268   // Generate forward declarations in the header for proto types.
269   // Note: do NOT add #includes to other generated headers (either .gen.h or
270   // .pbzero.h). Doing so is extremely hard to handle at the build-system level
271   // and requires propagating public_deps everywhere.
272   cc_printer.Print("\n");
273 
274   // -- Begin of fwd declarations.
275 
276   // Build up the map of forward declarations.
277   std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
278   enum FwdType { kClass, kEnum };
279   auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
280                                    const std::string& full_name) {
281     auto dot = full_name.rfind("::");
282     PERFETTO_CHECK(dot != std::string::npos);
283     auto package = full_name.substr(0, dot);
284     auto name = full_name.substr(dot + 2);
285     if (cpp_type == kClass) {
286       fwd_decls.emplace(package, "class " + name + ";");
287     } else {
288       PERFETTO_CHECK(cpp_type == kEnum);
289       fwd_decls.emplace(package, "enum " + name + " : int;");
290     }
291   };
292 
293   add_fwd_decl(kClass, "protozero::Message");
294   for (const Descriptor* msg : all_types) {
295     add_fwd_decl(kClass, GetFullName(msg, true));
296   }
297   for (const EnumDescriptor* enm : all_enums) {
298     add_fwd_decl(kEnum, GetFullName(enm, true));
299   }
300 
301   // Emit forward declarations grouping by package.
302   std::string last_package;
303   auto close_last_package = [&last_package, &h_printer] {
304     if (!last_package.empty()) {
305       for (const std::string& ns : SplitString(last_package, "::"))
306         h_printer.Print("}  // namespace $ns$\n", "ns", ns);
307       h_printer.Print("\n");
308     }
309   };
310   for (const auto& kv : fwd_decls) {
311     const std::string& package = kv.first;
312     if (package != last_package) {
313       close_last_package();
314       last_package = package;
315       for (const std::string& ns : SplitString(package, "::"))
316         h_printer.Print("namespace $ns$ {\n", "ns", ns);
317     }
318     h_printer.Print("$decl$\n", "decl", kv.second);
319   }
320   close_last_package();
321 
322   // -- End of fwd declarations.
323 
324   for (const std::string& ns : GetNamespaces(file)) {
325     h_printer.Print("namespace $n$ {\n", "n", ns);
326     cc_printer.Print("namespace $n$ {\n", "n", ns);
327   }
328 
329   // Generate declarations and definitions.
330   for (const EnumDescriptor* enm : local_enums)
331     GenEnum(enm, &h_printer);
332 
333   for (const Descriptor* msg : local_types) {
334     GenClassDecl(msg, &h_printer);
335     GenClassDef(msg, &cc_printer);
336   }
337 
338   for (const std::string& ns : GetNamespaces(file)) {
339     h_printer.Print("}  // namespace $n$\n", "n", ns);
340     cc_printer.Print("}  // namespace $n$\n", "n", ns);
341   }
342   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
343   cc_printer.Print("#pragma GCC diagnostic pop\n");
344   cc_printer.Print("#endif\n");
345 
346   h_printer.Print("\n#endif  // $g$\n", "g", include_guard);
347 
348   return true;
349 }
350 
GetCppType(const FieldDescriptor * field,bool constref) const351 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
352                                         bool constref) const {
353   switch (field->type()) {
354     case FieldDescriptor::TYPE_DOUBLE:
355       return "double";
356     case FieldDescriptor::TYPE_FLOAT:
357       return "float";
358     case FieldDescriptor::TYPE_FIXED32:
359     case FieldDescriptor::TYPE_UINT32:
360       return "uint32_t";
361     case FieldDescriptor::TYPE_SFIXED32:
362     case FieldDescriptor::TYPE_INT32:
363     case FieldDescriptor::TYPE_SINT32:
364       return "int32_t";
365     case FieldDescriptor::TYPE_FIXED64:
366     case FieldDescriptor::TYPE_UINT64:
367       return "uint64_t";
368     case FieldDescriptor::TYPE_SFIXED64:
369     case FieldDescriptor::TYPE_SINT64:
370     case FieldDescriptor::TYPE_INT64:
371       return "int64_t";
372     case FieldDescriptor::TYPE_BOOL:
373       return "bool";
374     case FieldDescriptor::TYPE_STRING:
375     case FieldDescriptor::TYPE_BYTES:
376       return constref ? "const std::string&" : "std::string";
377     case FieldDescriptor::TYPE_MESSAGE:
378       assert(!field->options().lazy());
379       return constref ? "const " + GetFullName(field->message_type()) + "&"
380                       : GetFullName(field->message_type());
381     case FieldDescriptor::TYPE_ENUM:
382       return GetFullName(field->enum_type());
383     case FieldDescriptor::TYPE_GROUP:
384       abort();
385   }
386   abort();  // for gcc
387 }
388 
GetProtozeroSetter(const FieldDescriptor * field) const389 std::string CppObjGenerator::GetProtozeroSetter(
390     const FieldDescriptor* field) const {
391   switch (field->type()) {
392     case FieldDescriptor::TYPE_BOOL:
393       return "::protozero::internal::gen_helpers::SerializeTinyVarInt";
394     case FieldDescriptor::TYPE_INT32:
395     case FieldDescriptor::TYPE_INT64:
396     case FieldDescriptor::TYPE_UINT32:
397     case FieldDescriptor::TYPE_UINT64:
398     case FieldDescriptor::TYPE_ENUM:
399       return "::protozero::internal::gen_helpers::SerializeVarInt";
400     case FieldDescriptor::TYPE_SINT32:
401     case FieldDescriptor::TYPE_SINT64:
402       return "::protozero::internal::gen_helpers::SerializeSignedVarInt";
403     case FieldDescriptor::TYPE_FIXED32:
404     case FieldDescriptor::TYPE_FIXED64:
405     case FieldDescriptor::TYPE_SFIXED32:
406     case FieldDescriptor::TYPE_SFIXED64:
407     case FieldDescriptor::TYPE_FLOAT:
408     case FieldDescriptor::TYPE_DOUBLE:
409       return "::protozero::internal::gen_helpers::SerializeFixed";
410     case FieldDescriptor::TYPE_STRING:
411     case FieldDescriptor::TYPE_BYTES:
412       return "::protozero::internal::gen_helpers::SerializeString";
413     case FieldDescriptor::TYPE_GROUP:
414     case FieldDescriptor::TYPE_MESSAGE:
415       abort();
416   }
417   abort();
418 }
419 
GetPackedBuffer(const FieldDescriptor * field) const420 std::string CppObjGenerator::GetPackedBuffer(
421     const FieldDescriptor* field) const {
422   switch (field->type()) {
423     case FieldDescriptor::TYPE_FIXED32:
424       return "::protozero::PackedFixedSizeInt<uint32_t>";
425     case FieldDescriptor::TYPE_SFIXED32:
426       return "::protozero::PackedFixedSizeInt<int32_t>";
427     case FieldDescriptor::TYPE_FIXED64:
428       return "::protozero::PackedFixedSizeInt<uint64_t>";
429     case FieldDescriptor::TYPE_SFIXED64:
430       return "::protozero::PackedFixedSizeInt<int64_t>";
431     case FieldDescriptor::TYPE_DOUBLE:
432       return "::protozero::PackedFixedSizeInt<double>";
433     case FieldDescriptor::TYPE_FLOAT:
434       return "::protozero::PackedFixedSizeInt<float>";
435     case FieldDescriptor::TYPE_INT32:
436     case FieldDescriptor::TYPE_SINT32:
437     case FieldDescriptor::TYPE_UINT32:
438     case FieldDescriptor::TYPE_INT64:
439     case FieldDescriptor::TYPE_UINT64:
440     case FieldDescriptor::TYPE_SINT64:
441     case FieldDescriptor::TYPE_BOOL:
442     case FieldDescriptor::TYPE_ENUM:
443       return "::protozero::PackedVarInt";
444     case FieldDescriptor::TYPE_STRING:
445     case FieldDescriptor::TYPE_BYTES:
446     case FieldDescriptor::TYPE_MESSAGE:
447     case FieldDescriptor::TYPE_GROUP:
448       break;  // Will abort()
449   }
450   abort();
451 }
452 
GetPackedWireType(const FieldDescriptor * field) const453 std::string CppObjGenerator::GetPackedWireType(
454     const FieldDescriptor* field) const {
455   switch (field->type()) {
456     case FieldDescriptor::TYPE_FIXED32:
457     case FieldDescriptor::TYPE_SFIXED32:
458     case FieldDescriptor::TYPE_FLOAT:
459       return "::protozero::proto_utils::ProtoWireType::kFixed32";
460     case FieldDescriptor::TYPE_FIXED64:
461     case FieldDescriptor::TYPE_SFIXED64:
462     case FieldDescriptor::TYPE_DOUBLE:
463       return "::protozero::proto_utils::ProtoWireType::kFixed64";
464     case FieldDescriptor::TYPE_INT32:
465     case FieldDescriptor::TYPE_SINT32:
466     case FieldDescriptor::TYPE_UINT32:
467     case FieldDescriptor::TYPE_INT64:
468     case FieldDescriptor::TYPE_UINT64:
469     case FieldDescriptor::TYPE_SINT64:
470     case FieldDescriptor::TYPE_BOOL:
471     case FieldDescriptor::TYPE_ENUM:
472       return "::protozero::proto_utils::ProtoWireType::kVarInt";
473     case FieldDescriptor::TYPE_STRING:
474     case FieldDescriptor::TYPE_BYTES:
475     case FieldDescriptor::TYPE_MESSAGE:
476     case FieldDescriptor::TYPE_GROUP:
477       break;  // Will abort()
478   }
479   abort();
480 }
481 
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const482 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
483                               Printer* p) const {
484   std::string full_name = GetFullName(enum_desc);
485 
486   // When generating enums, there are two cases:
487   // 1. Enums nested in a message (most frequent case), e.g.:
488   //    message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
489   // 2. Enum defined at the package level, outside of any message.
490   //
491   // In the case 1, the C++ code generated by the official protobuf library is:
492   // enum MyEnum {  MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
493   // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
494   //
495   // In the case 2, the C++ code is simply:
496   // enum MyEnum { FOO=1, BAR=2 }
497   // Hence this |prefix| logic.
498   std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
499   p->Print("enum $f$ : int {\n", "f", full_name);
500   for (int e = 0; e < enum_desc->value_count(); e++) {
501     const EnumValueDescriptor* value = enum_desc->value(e);
502     p->Print("  $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
503              std::to_string(value->number()));
504   }
505   p->Print("};\n");
506 }
507 
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const508 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
509                                      Printer* p) const {
510   int min_value = std::numeric_limits<int>::max();
511   int max_value = std::numeric_limits<int>::min();
512   std::string min_name;
513   std::string max_name;
514   std::string full_name = GetFullName(enum_desc);
515   for (int e = 0; e < enum_desc->value_count(); e++) {
516     const EnumValueDescriptor* value = enum_desc->value(e);
517     p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
518              value->name());
519     if (value->number() < min_value) {
520       min_value = value->number();
521       min_name = full_name + "_" + value->name();
522     }
523     if (value->number() > max_value) {
524       max_value = value->number();
525       max_name = full_name + "_" + value->name();
526     }
527   }
528   p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
529            "m", min_name);
530   p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
531            "m", max_name);
532 }
533 
GenClassDecl(const Descriptor * msg,Printer * p) const534 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
535   std::string full_name = GetFullName(msg);
536   p->Print(
537       "\nclass PERFETTO_EXPORT_COMPONENT $n$ : public "
538       "::protozero::CppMessageObj {\n",
539       "n", full_name);
540   p->Print(" public:\n");
541   p->Indent();
542 
543   // Do a first pass to generate aliases for nested types.
544   // e.g., using Foo = Parent_Foo;
545   for (int i = 0; i < msg->nested_type_count(); i++) {
546     const Descriptor* nested_msg = msg->nested_type(i);
547     p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
548              GetFullName(nested_msg));
549   }
550   for (int i = 0; i < msg->enum_type_count(); i++) {
551     const EnumDescriptor* nested_enum = msg->enum_type(i);
552     p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
553              GetFullName(nested_enum));
554     GenEnumAliases(nested_enum, p);
555   }
556 
557   // Generate constants with field numbers.
558   p->Print("enum FieldNumbers {\n");
559   for (int i = 0; i < msg->field_count(); i++) {
560     const FieldDescriptor* field = msg->field(i);
561     std::string name = field->camelcase_name();
562     name[0] = perfetto::base::Uppercase(name[0]);
563     p->Print("  k$n$FieldNumber = $num$,\n", "n", name, "num",
564              std::to_string(field->number()));
565   }
566   p->Print("};\n\n");
567 
568   p->Print("$n$();\n", "n", full_name);
569   p->Print("~$n$() override;\n", "n", full_name);
570   p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
571   p->Print("$n$& operator=($n$&&);\n", "n", full_name);
572   p->Print("$n$(const $n$&);\n", "n", full_name);
573   p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
574   p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
575   p->Print(
576       "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
577       "n", full_name);
578   p->Print("\n");
579 
580   std::string proto_type = GetFullName(msg, true);
581   p->Print("bool ParseFromArray(const void*, size_t) override;\n");
582   p->Print("std::string SerializeAsString() const override;\n");
583   p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
584   p->Print("void Serialize(::protozero::Message*) const;\n");
585 
586   // Generate accessors.
587   for (int i = 0; i < msg->field_count(); i++) {
588     const FieldDescriptor* field = msg->field(i);
589     auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
590     p->Print("\n");
591     if (field->options().lazy()) {
592       p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
593                field->lowercase_name());
594       p->Print(
595           "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
596           "n", field->lowercase_name(), "s", set_bit);
597     } else if (!field->is_repeated()) {
598       p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
599                field->lowercase_name(), "bit", std::to_string(field->number()));
600       if (field->type() == TYPE_MESSAGE) {
601         p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
602                  GetCppType(field, true), "n", field->lowercase_name());
603         p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
604                  GetCppType(field, false), "n", field->lowercase_name(), "s",
605                  set_bit);
606       } else {
607         p->Print("$t$ $n$() const { return $n$_; }\n", "t",
608                  GetCppType(field, true), "n", field->lowercase_name());
609         p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
610                  GetCppType(field, true), "n", field->lowercase_name(), "s",
611                  set_bit);
612         if (field->type() == FieldDescriptor::TYPE_BYTES) {
613           p->Print(
614               "void set_$n$(const void* p, size_t s) { "
615               "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
616               "n", field->lowercase_name(), "s", set_bit);
617         }
618       }
619     } else {  // is_repeated()
620       p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
621                GetCppType(field, false), "n", field->lowercase_name());
622       p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
623                GetCppType(field, false), "n", field->lowercase_name());
624 
625       // Generate accessors for repeated message types in the .cc file so that
626       // the header doesn't depend on the full definition of all nested types.
627       if (field->type() == TYPE_MESSAGE) {
628         p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
629                  field->lowercase_name());
630         p->Print("void clear_$n$();\n", "n", field->lowercase_name());
631         p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
632                  field->lowercase_name());
633       } else {  // Primitive type.
634         p->Print(
635             "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
636             "t", GetCppType(field, false), "n", field->lowercase_name());
637         p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
638                  field->lowercase_name());
639         p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
640                  GetCppType(field, false), "n", field->lowercase_name());
641         // TODO(primiano): this should be done only for TYPE_MESSAGE.
642         // Unfortuntely we didn't realize before and now we have a bunch of code
643         // that does: *msg->add_int_value() = 42 instead of
644         // msg->add_int_value(42).
645         p->Print(
646             "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
647             "t", GetCppType(field, false), "n", field->lowercase_name());
648       }
649     }
650   }
651   p->Outdent();
652   p->Print("\n private:\n");
653   p->Indent();
654 
655   // Generate fields.
656   int max_field_id = 1;
657   for (int i = 0; i < msg->field_count(); i++) {
658     const FieldDescriptor* field = msg->field(i);
659     max_field_id = std::max(max_field_id, field->number());
660     if (field->options().lazy()) {
661       p->Print("std::string $n$_;  // [lazy=true]\n", "n",
662                field->lowercase_name());
663     } else if (!field->is_repeated()) {
664       std::string type = GetCppType(field, false);
665       if (field->type() == TYPE_MESSAGE) {
666         type = "::protozero::CopyablePtr<" + type + ">";
667         p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
668       } else {
669         p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
670       }
671     } else {  // is_repeated()
672       p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
673                field->lowercase_name());
674     }
675   }
676   p->Print("\n");
677   p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
678   p->Print("// with future versions of .proto files.\n");
679   p->Print("std::string unknown_fields_;\n");
680 
681   p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
682            std::to_string(max_field_id + 1));
683 
684   p->Outdent();
685   p->Print("};\n\n");
686 }
687 
GenClassDef(const Descriptor * msg,Printer * p) const688 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
689   p->Print("\n");
690   std::string full_name = GetFullName(msg);
691 
692   p->Print("$n$::$n$() = default;\n", "n", full_name);
693   p->Print("$n$::~$n$() = default;\n", "n", full_name);
694   p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
695   p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
696   p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
697   p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
698 
699   p->Print("\n");
700 
701   // Comparison operator
702   p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
703   p->Indent();
704 
705   p->Print("return unknown_fields_ == other.unknown_fields_");
706   for (int i = 0; i < msg->field_count(); i++)
707     p->Print("\n && $n$_ == other.$n$_", "n", msg->field(i)->lowercase_name());
708   p->Print(";");
709   p->Outdent();
710   p->Print("\n}\n\n");
711 
712   // Accessors for repeated message fields.
713   for (int i = 0; i < msg->field_count(); i++) {
714     const FieldDescriptor* field = msg->field(i);
715     if (field->options().lazy() || !field->is_repeated() ||
716         field->type() != TYPE_MESSAGE) {
717       continue;
718     }
719     p->Print(
720         "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
721         "c", full_name, "t", GetCppType(field, false), "n",
722         field->lowercase_name());
723     p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
724              field->lowercase_name());
725     p->Print(
726         "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
727         "c", full_name, "t", GetCppType(field, false), "n",
728         field->lowercase_name());
729   }
730 
731   std::string proto_type = GetFullName(msg, true);
732 
733   // Generate the ParseFromArray() method definition.
734   p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
735            full_name);
736   p->Indent();
737   for (int i = 0; i < msg->field_count(); i++) {
738     const FieldDescriptor* field = msg->field(i);
739     if (field->is_repeated())
740       p->Print("$n$_.clear();\n", "n", field->lowercase_name());
741   }
742   p->Print("unknown_fields_.clear();\n");
743   p->Print("bool packed_error = false;\n");
744   p->Print("\n");
745   p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
746   p->Print("for (auto field = dec.ReadField(); field.valid(); ");
747   p->Print("field = dec.ReadField()) {\n");
748   p->Indent();
749   p->Print("if (field.id() < _has_field_.size()) {\n");
750   p->Print("  _has_field_.set(field.id());\n");
751   p->Print("}\n");
752   p->Print("switch (field.id()) {\n");
753   p->Indent();
754   for (int i = 0; i < msg->field_count(); i++) {
755     const FieldDescriptor* field = msg->field(i);
756     p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
757              "n", field->lowercase_name());
758     p->Indent();
759     if (field->options().lazy()) {
760       p->Print(
761           "::protozero::internal::gen_helpers::DeserializeString(field, "
762           "&$n$_);\n",
763           "n", field->lowercase_name());
764     } else {
765       std::string statement;
766       if (field->type() == TYPE_MESSAGE) {
767         statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
768       } else {
769         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
770           // sint32/64 fields are special and need to be zig-zag-decoded.
771           statement = "field.get_signed(&$rval$);\n";
772         } else if (field->type() == TYPE_STRING) {
773           statement =
774               "::protozero::internal::gen_helpers::DeserializeString(field, "
775               "&$rval$);\n";
776         } else {
777           statement = "field.get(&$rval$);\n";
778         }
779       }
780       if (field->is_packed()) {
781         PERFETTO_CHECK(field->is_repeated());
782         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
783           PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
784         }
785         p->Print(
786             "if "
787             "(!::protozero::internal::gen_helpers::DeserializePackedRepeated"
788             "<$w$, $c$>(field, &$n$_)) {\n",
789             "w", GetPackedWireType(field), "c", GetCppType(field, false), "n",
790             field->lowercase_name());
791         p->Print("  packed_error = true;");
792         p->Print("}\n");
793       } else if (field->is_repeated()) {
794         p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
795         p->Print(statement.c_str(), "rval",
796                  field->lowercase_name() + "_.back()");
797       } else if (field->type() == TYPE_MESSAGE) {
798         p->Print(statement.c_str(), "rval",
799                  "(*" + field->lowercase_name() + "_)");
800       } else {
801         p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
802       }
803     }
804     p->Print("break;\n");
805     p->Outdent();
806   }  // for (field)
807   p->Print("default:\n");
808   p->Print("  field.SerializeAndAppendTo(&unknown_fields_);\n");
809   p->Print("  break;\n");
810   p->Outdent();
811   p->Print("}\n");  // switch(field.id)
812   p->Outdent();
813   p->Print("}\n");                                           // for(field)
814   p->Print("return !packed_error && !dec.bytes_left();\n");  // for(field)
815   p->Outdent();
816   p->Print("}\n\n");
817 
818   // Generate the SerializeAsString() method definition.
819   p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
820   p->Indent();
821   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
822   p->Print("Serialize(msg.get());\n");
823   p->Print("return msg.SerializeAsString();\n");
824   p->Outdent();
825   p->Print("}\n\n");
826 
827   // Generate the SerializeAsArray() method definition.
828   p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
829            full_name);
830   p->Indent();
831   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
832   p->Print("Serialize(msg.get());\n");
833   p->Print("return msg.SerializeAsArray();\n");
834   p->Outdent();
835   p->Print("}\n\n");
836 
837   // Generate the Serialize() method that writes the fields into the passed
838   // protozero |msg| write-only interface |msg|.
839   p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
840            full_name);
841   p->Indent();
842   for (int i = 0; i < msg->field_count(); i++) {
843     const FieldDescriptor* field = msg->field(i);
844     std::map<std::string, std::string> args;
845     args["id"] = std::to_string(field->number());
846     args["n"] = field->lowercase_name();
847     p->Print(args, "// Field $id$: $n$\n");
848     if (field->is_packed()) {
849       PERFETTO_CHECK(field->is_repeated());
850       p->Print("{\n");
851       p->Indent();
852       p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
853       p->Print(args, "for (auto& it : $n$_)\n");
854       p->Print(args, "  pack.Append(it);\n");
855       p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
856       p->Outdent();
857       p->Print("}\n");
858     } else {
859       if (field->is_repeated()) {
860         p->Print(args, "for (auto& it : $n$_) {\n");
861         args["lvalue"] = "it";
862         args["rvalue"] = "it";
863       } else {
864         p->Print(args, "if (_has_field_[$id$]) {\n");
865         args["lvalue"] = "(*" + field->lowercase_name() + "_)";
866         args["rvalue"] = field->lowercase_name() + "_";
867       }
868       p->Indent();
869       if (field->options().lazy()) {
870         p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
871       } else if (field->type() == TYPE_MESSAGE) {
872         p->Print(args,
873                  "$lvalue$.Serialize("
874                  "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
875       } else {
876         args["setter"] = GetProtozeroSetter(field);
877         p->Print(args, "$setter$($id$, $rvalue$, msg);\n");
878       }
879       p->Outdent();
880       p->Print("}\n");
881     }
882 
883     p->Print("\n");
884   }  // for (field)
885   p->Print(
886       "protozero::internal::gen_helpers::SerializeUnknownFields(unknown_fields_"
887       ", msg);\n");
888   p->Outdent();
889   p->Print("}\n\n");
890 }
891 
892 }  // namespace
893 }  // namespace protozero
894 
main(int argc,char ** argv)895 int main(int argc, char** argv) {
896   ::protozero::CppObjGenerator generator;
897   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
898 }
899