1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <stdio.h>
18 #include <stdlib.h>
19
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/dynamic_message.h>
31 #include <google/protobuf/io/printer.h>
32 #include <google/protobuf/io/zero_copy_stream_impl.h>
33 #include <google/protobuf/util/field_comparator.h>
34 #include <google/protobuf/util/message_differencer.h>
35
36 #include "perfetto/ext/base/string_utils.h"
37
38 namespace protozero {
39 namespace {
40
41 using namespace google::protobuf;
42 using namespace google::protobuf::compiler;
43 using namespace google::protobuf::io;
44 using perfetto::base::SplitString;
45 using perfetto::base::StripChars;
46 using perfetto::base::StripSuffix;
47 using perfetto::base::ToUpper;
48
49 static constexpr auto TYPE_STRING = FieldDescriptor::TYPE_STRING;
50 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
51 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
52 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
53
54 static const char kHeader[] =
55 "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
56
57 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
58 public:
59 CppObjGenerator();
60 ~CppObjGenerator() override;
61
62 // CodeGenerator implementation
63 bool Generate(const google::protobuf::FileDescriptor* file,
64 const std::string& options,
65 GeneratorContext* context,
66 std::string* error) const override;
67
68 private:
69 std::string GetCppType(const FieldDescriptor* field, bool constref) const;
70 std::string GetProtozeroSetter(const FieldDescriptor* field) const;
71 std::string GetPackedBuffer(const FieldDescriptor* field) const;
72 std::string GetPackedWireType(const FieldDescriptor* field) const;
73
74 void GenEnum(const EnumDescriptor*, Printer*) const;
75 void GenEnumAliases(const EnumDescriptor*, Printer*) const;
76 void GenClassDecl(const Descriptor*, Printer*) const;
77 void GenClassDef(const Descriptor*, Printer*) const;
78
GetNamespaces(const FileDescriptor * file) const79 std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
80 std::string pkg = file->package() + wrapper_namespace_;
81 return SplitString(pkg, ".");
82 }
83
84 template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const85 std::string GetFullName(const T* msg, bool with_namespace = false) const {
86 std::string full_type;
87 full_type.append(msg->name());
88 for (const Descriptor* par = msg->containing_type(); par;
89 par = par->containing_type()) {
90 full_type.insert(0, par->name() + "_");
91 }
92 if (with_namespace) {
93 std::string prefix;
94 for (const std::string& ns : GetNamespaces(msg->file())) {
95 prefix += ns + "::";
96 }
97 full_type = prefix + full_type;
98 }
99 return full_type;
100 }
101
102 mutable std::string wrapper_namespace_;
103 };
104
105 CppObjGenerator::CppObjGenerator() = default;
106 CppObjGenerator::~CppObjGenerator() = default;
107
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const108 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
109 const std::string& options,
110 GeneratorContext* context,
111 std::string* error) const {
112 for (const std::string& option : SplitString(options, ",")) {
113 std::vector<std::string> option_pair = SplitString(option, "=");
114 if (option_pair[0] == "wrapper_namespace") {
115 wrapper_namespace_ =
116 option_pair.size() == 2 ? "." + option_pair[1] : std::string();
117 } else {
118 *error = "Unknown plugin option: " + option_pair[0];
119 return false;
120 }
121 }
122
123 auto get_file_name = [](const FileDescriptor* proto) {
124 return StripSuffix(proto->name(), ".proto") + ".gen";
125 };
126
127 const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
128 context->Open(get_file_name(file) + ".h"));
129 const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
130 context->Open(get_file_name(file) + ".cc"));
131
132 // Variables are delimited by $.
133 Printer h_printer(h_fstream.get(), '$');
134 Printer cc_printer(cc_fstream.get(), '$');
135
136 std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
137 include_guard = ToUpper(include_guard);
138 include_guard = StripChars(include_guard, ".-/\\", '_');
139
140 h_printer.Print(kHeader);
141 h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
142 h_printer.Print("#include <stdint.h>\n");
143 h_printer.Print("#include <bitset>\n");
144 h_printer.Print("#include <vector>\n");
145 h_printer.Print("#include <string>\n");
146 h_printer.Print("#include <type_traits>\n\n");
147 h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
148 h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
149 h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
150
151 cc_printer.Print("#include \"perfetto/protozero/gen_field_helpers.h\"\n");
152 cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
153 cc_printer.Print(
154 "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
155 cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
156 cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
157 cc_printer.Print(kHeader);
158 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
159 cc_printer.Print("#pragma GCC diagnostic push\n");
160 cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
161 cc_printer.Print("#endif\n");
162
163 // Generate includes for translated types of dependencies.
164
165 // Figure out the subset of imports that are used only for lazy fields. We
166 // won't emit a C++ #include for them. This code is overly aggressive at
167 // removing imports: it rules them out as soon as it sees one lazy field
168 // whose type is defined in that import. A 100% correct solution would require
169 // to check that *all* dependent types for a given import are lazy before
170 // excluding that. In practice we don't need that because we don't use imports
171 // for both lazy and non-lazy fields.
172 std::set<std::string> lazy_imports;
173 for (int m = 0; m < file->message_type_count(); m++) {
174 const Descriptor* msg = file->message_type(m);
175 for (int i = 0; i < msg->field_count(); i++) {
176 const FieldDescriptor* field = msg->field(i);
177 if (field->options().lazy()) {
178 lazy_imports.insert(field->message_type()->file()->name());
179 }
180 }
181 }
182
183 // Recursively traverse all imports and turn them into #include(s).
184 std::vector<const FileDescriptor*> imports_to_visit;
185 std::set<const FileDescriptor*> imports_visited;
186 imports_to_visit.push_back(file);
187
188 while (!imports_to_visit.empty()) {
189 const FileDescriptor* cur = imports_to_visit.back();
190 imports_to_visit.pop_back();
191 imports_visited.insert(cur);
192 std::string base_name = StripSuffix(cur->name(), ".proto");
193 cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
194 for (int i = 0; i < cur->dependency_count(); i++) {
195 const FileDescriptor* dep = cur->dependency(i);
196 if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
197 continue;
198 imports_to_visit.push_back(dep);
199 }
200 }
201
202 // Compute all nested types to generate forward declarations later.
203
204 std::set<const Descriptor*> all_types_seen; // All deps
205 std::set<const EnumDescriptor*> all_enums_seen;
206
207 // We track the types additionally in vectors to guarantee a stable order in
208 // the generated output.
209 std::vector<const Descriptor*> local_types; // Cur .proto file only.
210 std::vector<const Descriptor*> all_types; // All deps
211 std::vector<const EnumDescriptor*> local_enums;
212 std::vector<const EnumDescriptor*> all_enums;
213
214 auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
215 &file](const EnumDescriptor* enum_desc) {
216 if (all_enums_seen.count(enum_desc))
217 return;
218 all_enums_seen.insert(enum_desc);
219 all_enums.push_back(enum_desc);
220 if (enum_desc->file() == file)
221 local_enums.push_back(enum_desc);
222 };
223
224 for (int i = 0; i < file->enum_type_count(); i++)
225 add_enum(file->enum_type(i));
226
227 std::stack<const Descriptor*> recursion_stack;
228 for (int i = 0; i < file->message_type_count(); i++)
229 recursion_stack.push(file->message_type(i));
230
231 while (!recursion_stack.empty()) {
232 const Descriptor* msg = recursion_stack.top();
233 recursion_stack.pop();
234 if (all_types_seen.count(msg))
235 continue;
236 all_types_seen.insert(msg);
237 all_types.push_back(msg);
238 if (msg->file() == file)
239 local_types.push_back(msg);
240
241 for (int i = 0; i < msg->nested_type_count(); i++)
242 recursion_stack.push(msg->nested_type(i));
243
244 for (int i = 0; i < msg->enum_type_count(); i++)
245 add_enum(msg->enum_type(i));
246
247 for (int i = 0; i < msg->field_count(); i++) {
248 const FieldDescriptor* field = msg->field(i);
249 if (field->has_default_value()) {
250 *error = "field " + field->name() +
251 ": Explicitly declared default values are not supported";
252 return false;
253 }
254 if (field->options().lazy() &&
255 (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
256 *error = "[lazy=true] is supported only on non-repeated fields\n";
257 return false;
258 }
259
260 if (field->type() == TYPE_MESSAGE && !field->options().lazy())
261 recursion_stack.push(field->message_type());
262
263 if (field->type() == FieldDescriptor::TYPE_ENUM)
264 add_enum(field->enum_type());
265 }
266 } // while (!recursion_stack.empty())
267
268 // Generate forward declarations in the header for proto types.
269 // Note: do NOT add #includes to other generated headers (either .gen.h or
270 // .pbzero.h). Doing so is extremely hard to handle at the build-system level
271 // and requires propagating public_deps everywhere.
272 cc_printer.Print("\n");
273
274 // -- Begin of fwd declarations.
275
276 // Build up the map of forward declarations.
277 std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
278 enum FwdType { kClass, kEnum };
279 auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
280 const std::string& full_name) {
281 auto dot = full_name.rfind("::");
282 PERFETTO_CHECK(dot != std::string::npos);
283 auto package = full_name.substr(0, dot);
284 auto name = full_name.substr(dot + 2);
285 if (cpp_type == kClass) {
286 fwd_decls.emplace(package, "class " + name + ";");
287 } else {
288 PERFETTO_CHECK(cpp_type == kEnum);
289 fwd_decls.emplace(package, "enum " + name + " : int;");
290 }
291 };
292
293 add_fwd_decl(kClass, "protozero::Message");
294 for (const Descriptor* msg : all_types) {
295 add_fwd_decl(kClass, GetFullName(msg, true));
296 }
297 for (const EnumDescriptor* enm : all_enums) {
298 add_fwd_decl(kEnum, GetFullName(enm, true));
299 }
300
301 // Emit forward declarations grouping by package.
302 std::string last_package;
303 auto close_last_package = [&last_package, &h_printer] {
304 if (!last_package.empty()) {
305 for (const std::string& ns : SplitString(last_package, "::"))
306 h_printer.Print("} // namespace $ns$\n", "ns", ns);
307 h_printer.Print("\n");
308 }
309 };
310 for (const auto& kv : fwd_decls) {
311 const std::string& package = kv.first;
312 if (package != last_package) {
313 close_last_package();
314 last_package = package;
315 for (const std::string& ns : SplitString(package, "::"))
316 h_printer.Print("namespace $ns$ {\n", "ns", ns);
317 }
318 h_printer.Print("$decl$\n", "decl", kv.second);
319 }
320 close_last_package();
321
322 // -- End of fwd declarations.
323
324 for (const std::string& ns : GetNamespaces(file)) {
325 h_printer.Print("namespace $n$ {\n", "n", ns);
326 cc_printer.Print("namespace $n$ {\n", "n", ns);
327 }
328
329 // Generate declarations and definitions.
330 for (const EnumDescriptor* enm : local_enums)
331 GenEnum(enm, &h_printer);
332
333 for (const Descriptor* msg : local_types) {
334 GenClassDecl(msg, &h_printer);
335 GenClassDef(msg, &cc_printer);
336 }
337
338 for (const std::string& ns : GetNamespaces(file)) {
339 h_printer.Print("} // namespace $n$\n", "n", ns);
340 cc_printer.Print("} // namespace $n$\n", "n", ns);
341 }
342 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
343 cc_printer.Print("#pragma GCC diagnostic pop\n");
344 cc_printer.Print("#endif\n");
345
346 h_printer.Print("\n#endif // $g$\n", "g", include_guard);
347
348 return true;
349 }
350
GetCppType(const FieldDescriptor * field,bool constref) const351 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
352 bool constref) const {
353 switch (field->type()) {
354 case FieldDescriptor::TYPE_DOUBLE:
355 return "double";
356 case FieldDescriptor::TYPE_FLOAT:
357 return "float";
358 case FieldDescriptor::TYPE_FIXED32:
359 case FieldDescriptor::TYPE_UINT32:
360 return "uint32_t";
361 case FieldDescriptor::TYPE_SFIXED32:
362 case FieldDescriptor::TYPE_INT32:
363 case FieldDescriptor::TYPE_SINT32:
364 return "int32_t";
365 case FieldDescriptor::TYPE_FIXED64:
366 case FieldDescriptor::TYPE_UINT64:
367 return "uint64_t";
368 case FieldDescriptor::TYPE_SFIXED64:
369 case FieldDescriptor::TYPE_SINT64:
370 case FieldDescriptor::TYPE_INT64:
371 return "int64_t";
372 case FieldDescriptor::TYPE_BOOL:
373 return "bool";
374 case FieldDescriptor::TYPE_STRING:
375 case FieldDescriptor::TYPE_BYTES:
376 return constref ? "const std::string&" : "std::string";
377 case FieldDescriptor::TYPE_MESSAGE:
378 assert(!field->options().lazy());
379 return constref ? "const " + GetFullName(field->message_type()) + "&"
380 : GetFullName(field->message_type());
381 case FieldDescriptor::TYPE_ENUM:
382 return GetFullName(field->enum_type());
383 case FieldDescriptor::TYPE_GROUP:
384 abort();
385 }
386 abort(); // for gcc
387 }
388
GetProtozeroSetter(const FieldDescriptor * field) const389 std::string CppObjGenerator::GetProtozeroSetter(
390 const FieldDescriptor* field) const {
391 switch (field->type()) {
392 case FieldDescriptor::TYPE_BOOL:
393 return "::protozero::internal::gen_helpers::SerializeTinyVarInt";
394 case FieldDescriptor::TYPE_INT32:
395 case FieldDescriptor::TYPE_INT64:
396 case FieldDescriptor::TYPE_UINT32:
397 case FieldDescriptor::TYPE_UINT64:
398 case FieldDescriptor::TYPE_ENUM:
399 return "::protozero::internal::gen_helpers::SerializeVarInt";
400 case FieldDescriptor::TYPE_SINT32:
401 case FieldDescriptor::TYPE_SINT64:
402 return "::protozero::internal::gen_helpers::SerializeSignedVarInt";
403 case FieldDescriptor::TYPE_FIXED32:
404 case FieldDescriptor::TYPE_FIXED64:
405 case FieldDescriptor::TYPE_SFIXED32:
406 case FieldDescriptor::TYPE_SFIXED64:
407 case FieldDescriptor::TYPE_FLOAT:
408 case FieldDescriptor::TYPE_DOUBLE:
409 return "::protozero::internal::gen_helpers::SerializeFixed";
410 case FieldDescriptor::TYPE_STRING:
411 case FieldDescriptor::TYPE_BYTES:
412 return "::protozero::internal::gen_helpers::SerializeString";
413 case FieldDescriptor::TYPE_GROUP:
414 case FieldDescriptor::TYPE_MESSAGE:
415 abort();
416 }
417 abort();
418 }
419
GetPackedBuffer(const FieldDescriptor * field) const420 std::string CppObjGenerator::GetPackedBuffer(
421 const FieldDescriptor* field) const {
422 switch (field->type()) {
423 case FieldDescriptor::TYPE_FIXED32:
424 return "::protozero::PackedFixedSizeInt<uint32_t>";
425 case FieldDescriptor::TYPE_SFIXED32:
426 return "::protozero::PackedFixedSizeInt<int32_t>";
427 case FieldDescriptor::TYPE_FIXED64:
428 return "::protozero::PackedFixedSizeInt<uint64_t>";
429 case FieldDescriptor::TYPE_SFIXED64:
430 return "::protozero::PackedFixedSizeInt<int64_t>";
431 case FieldDescriptor::TYPE_DOUBLE:
432 return "::protozero::PackedFixedSizeInt<double>";
433 case FieldDescriptor::TYPE_FLOAT:
434 return "::protozero::PackedFixedSizeInt<float>";
435 case FieldDescriptor::TYPE_INT32:
436 case FieldDescriptor::TYPE_SINT32:
437 case FieldDescriptor::TYPE_UINT32:
438 case FieldDescriptor::TYPE_INT64:
439 case FieldDescriptor::TYPE_UINT64:
440 case FieldDescriptor::TYPE_SINT64:
441 case FieldDescriptor::TYPE_BOOL:
442 case FieldDescriptor::TYPE_ENUM:
443 return "::protozero::PackedVarInt";
444 case FieldDescriptor::TYPE_STRING:
445 case FieldDescriptor::TYPE_BYTES:
446 case FieldDescriptor::TYPE_MESSAGE:
447 case FieldDescriptor::TYPE_GROUP:
448 break; // Will abort()
449 }
450 abort();
451 }
452
GetPackedWireType(const FieldDescriptor * field) const453 std::string CppObjGenerator::GetPackedWireType(
454 const FieldDescriptor* field) const {
455 switch (field->type()) {
456 case FieldDescriptor::TYPE_FIXED32:
457 case FieldDescriptor::TYPE_SFIXED32:
458 case FieldDescriptor::TYPE_FLOAT:
459 return "::protozero::proto_utils::ProtoWireType::kFixed32";
460 case FieldDescriptor::TYPE_FIXED64:
461 case FieldDescriptor::TYPE_SFIXED64:
462 case FieldDescriptor::TYPE_DOUBLE:
463 return "::protozero::proto_utils::ProtoWireType::kFixed64";
464 case FieldDescriptor::TYPE_INT32:
465 case FieldDescriptor::TYPE_SINT32:
466 case FieldDescriptor::TYPE_UINT32:
467 case FieldDescriptor::TYPE_INT64:
468 case FieldDescriptor::TYPE_UINT64:
469 case FieldDescriptor::TYPE_SINT64:
470 case FieldDescriptor::TYPE_BOOL:
471 case FieldDescriptor::TYPE_ENUM:
472 return "::protozero::proto_utils::ProtoWireType::kVarInt";
473 case FieldDescriptor::TYPE_STRING:
474 case FieldDescriptor::TYPE_BYTES:
475 case FieldDescriptor::TYPE_MESSAGE:
476 case FieldDescriptor::TYPE_GROUP:
477 break; // Will abort()
478 }
479 abort();
480 }
481
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const482 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
483 Printer* p) const {
484 std::string full_name = GetFullName(enum_desc);
485
486 // When generating enums, there are two cases:
487 // 1. Enums nested in a message (most frequent case), e.g.:
488 // message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
489 // 2. Enum defined at the package level, outside of any message.
490 //
491 // In the case 1, the C++ code generated by the official protobuf library is:
492 // enum MyEnum { MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
493 // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
494 //
495 // In the case 2, the C++ code is simply:
496 // enum MyEnum { FOO=1, BAR=2 }
497 // Hence this |prefix| logic.
498 std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
499 p->Print("enum $f$ : int {\n", "f", full_name);
500 for (int e = 0; e < enum_desc->value_count(); e++) {
501 const EnumValueDescriptor* value = enum_desc->value(e);
502 p->Print(" $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
503 std::to_string(value->number()));
504 }
505 p->Print("};\n");
506 }
507
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const508 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
509 Printer* p) const {
510 int min_value = std::numeric_limits<int>::max();
511 int max_value = std::numeric_limits<int>::min();
512 std::string min_name;
513 std::string max_name;
514 std::string full_name = GetFullName(enum_desc);
515 for (int e = 0; e < enum_desc->value_count(); e++) {
516 const EnumValueDescriptor* value = enum_desc->value(e);
517 p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
518 value->name());
519 if (value->number() < min_value) {
520 min_value = value->number();
521 min_name = full_name + "_" + value->name();
522 }
523 if (value->number() > max_value) {
524 max_value = value->number();
525 max_name = full_name + "_" + value->name();
526 }
527 }
528 p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
529 "m", min_name);
530 p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
531 "m", max_name);
532 }
533
GenClassDecl(const Descriptor * msg,Printer * p) const534 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
535 std::string full_name = GetFullName(msg);
536 p->Print(
537 "\nclass PERFETTO_EXPORT_COMPONENT $n$ : public "
538 "::protozero::CppMessageObj {\n",
539 "n", full_name);
540 p->Print(" public:\n");
541 p->Indent();
542
543 // Do a first pass to generate aliases for nested types.
544 // e.g., using Foo = Parent_Foo;
545 for (int i = 0; i < msg->nested_type_count(); i++) {
546 const Descriptor* nested_msg = msg->nested_type(i);
547 p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
548 GetFullName(nested_msg));
549 }
550 for (int i = 0; i < msg->enum_type_count(); i++) {
551 const EnumDescriptor* nested_enum = msg->enum_type(i);
552 p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
553 GetFullName(nested_enum));
554 GenEnumAliases(nested_enum, p);
555 }
556
557 // Generate constants with field numbers.
558 p->Print("enum FieldNumbers {\n");
559 for (int i = 0; i < msg->field_count(); i++) {
560 const FieldDescriptor* field = msg->field(i);
561 std::string name = field->camelcase_name();
562 name[0] = perfetto::base::Uppercase(name[0]);
563 p->Print(" k$n$FieldNumber = $num$,\n", "n", name, "num",
564 std::to_string(field->number()));
565 }
566 p->Print("};\n\n");
567
568 p->Print("$n$();\n", "n", full_name);
569 p->Print("~$n$() override;\n", "n", full_name);
570 p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
571 p->Print("$n$& operator=($n$&&);\n", "n", full_name);
572 p->Print("$n$(const $n$&);\n", "n", full_name);
573 p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
574 p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
575 p->Print(
576 "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
577 "n", full_name);
578 p->Print("\n");
579
580 std::string proto_type = GetFullName(msg, true);
581 p->Print("bool ParseFromArray(const void*, size_t) override;\n");
582 p->Print("std::string SerializeAsString() const override;\n");
583 p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
584 p->Print("void Serialize(::protozero::Message*) const;\n");
585
586 // Generate accessors.
587 for (int i = 0; i < msg->field_count(); i++) {
588 const FieldDescriptor* field = msg->field(i);
589 auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
590 p->Print("\n");
591 if (field->options().lazy()) {
592 p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
593 field->lowercase_name());
594 p->Print(
595 "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
596 "n", field->lowercase_name(), "s", set_bit);
597 } else if (!field->is_repeated()) {
598 p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
599 field->lowercase_name(), "bit", std::to_string(field->number()));
600 if (field->type() == TYPE_MESSAGE) {
601 p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
602 GetCppType(field, true), "n", field->lowercase_name());
603 p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
604 GetCppType(field, false), "n", field->lowercase_name(), "s",
605 set_bit);
606 } else {
607 p->Print("$t$ $n$() const { return $n$_; }\n", "t",
608 GetCppType(field, true), "n", field->lowercase_name());
609 p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
610 GetCppType(field, true), "n", field->lowercase_name(), "s",
611 set_bit);
612 if (field->type() == FieldDescriptor::TYPE_BYTES) {
613 p->Print(
614 "void set_$n$(const void* p, size_t s) { "
615 "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
616 "n", field->lowercase_name(), "s", set_bit);
617 }
618 }
619 } else { // is_repeated()
620 p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
621 GetCppType(field, false), "n", field->lowercase_name());
622 p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
623 GetCppType(field, false), "n", field->lowercase_name());
624
625 // Generate accessors for repeated message types in the .cc file so that
626 // the header doesn't depend on the full definition of all nested types.
627 if (field->type() == TYPE_MESSAGE) {
628 p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
629 field->lowercase_name());
630 p->Print("void clear_$n$();\n", "n", field->lowercase_name());
631 p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
632 field->lowercase_name());
633 } else { // Primitive type.
634 p->Print(
635 "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
636 "t", GetCppType(field, false), "n", field->lowercase_name());
637 p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
638 field->lowercase_name());
639 p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
640 GetCppType(field, false), "n", field->lowercase_name());
641 // TODO(primiano): this should be done only for TYPE_MESSAGE.
642 // Unfortuntely we didn't realize before and now we have a bunch of code
643 // that does: *msg->add_int_value() = 42 instead of
644 // msg->add_int_value(42).
645 p->Print(
646 "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
647 "t", GetCppType(field, false), "n", field->lowercase_name());
648 }
649 }
650 }
651 p->Outdent();
652 p->Print("\n private:\n");
653 p->Indent();
654
655 // Generate fields.
656 int max_field_id = 1;
657 for (int i = 0; i < msg->field_count(); i++) {
658 const FieldDescriptor* field = msg->field(i);
659 max_field_id = std::max(max_field_id, field->number());
660 if (field->options().lazy()) {
661 p->Print("std::string $n$_; // [lazy=true]\n", "n",
662 field->lowercase_name());
663 } else if (!field->is_repeated()) {
664 std::string type = GetCppType(field, false);
665 if (field->type() == TYPE_MESSAGE) {
666 type = "::protozero::CopyablePtr<" + type + ">";
667 p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
668 } else {
669 p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
670 }
671 } else { // is_repeated()
672 p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
673 field->lowercase_name());
674 }
675 }
676 p->Print("\n");
677 p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
678 p->Print("// with future versions of .proto files.\n");
679 p->Print("std::string unknown_fields_;\n");
680
681 p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
682 std::to_string(max_field_id + 1));
683
684 p->Outdent();
685 p->Print("};\n\n");
686 }
687
GenClassDef(const Descriptor * msg,Printer * p) const688 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
689 p->Print("\n");
690 std::string full_name = GetFullName(msg);
691
692 p->Print("$n$::$n$() = default;\n", "n", full_name);
693 p->Print("$n$::~$n$() = default;\n", "n", full_name);
694 p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
695 p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
696 p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
697 p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
698
699 p->Print("\n");
700
701 // Comparison operator
702 p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
703 p->Indent();
704
705 p->Print("return unknown_fields_ == other.unknown_fields_");
706 for (int i = 0; i < msg->field_count(); i++)
707 p->Print("\n && $n$_ == other.$n$_", "n", msg->field(i)->lowercase_name());
708 p->Print(";");
709 p->Outdent();
710 p->Print("\n}\n\n");
711
712 // Accessors for repeated message fields.
713 for (int i = 0; i < msg->field_count(); i++) {
714 const FieldDescriptor* field = msg->field(i);
715 if (field->options().lazy() || !field->is_repeated() ||
716 field->type() != TYPE_MESSAGE) {
717 continue;
718 }
719 p->Print(
720 "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
721 "c", full_name, "t", GetCppType(field, false), "n",
722 field->lowercase_name());
723 p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
724 field->lowercase_name());
725 p->Print(
726 "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
727 "c", full_name, "t", GetCppType(field, false), "n",
728 field->lowercase_name());
729 }
730
731 std::string proto_type = GetFullName(msg, true);
732
733 // Generate the ParseFromArray() method definition.
734 p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
735 full_name);
736 p->Indent();
737 for (int i = 0; i < msg->field_count(); i++) {
738 const FieldDescriptor* field = msg->field(i);
739 if (field->is_repeated())
740 p->Print("$n$_.clear();\n", "n", field->lowercase_name());
741 }
742 p->Print("unknown_fields_.clear();\n");
743 p->Print("bool packed_error = false;\n");
744 p->Print("\n");
745 p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
746 p->Print("for (auto field = dec.ReadField(); field.valid(); ");
747 p->Print("field = dec.ReadField()) {\n");
748 p->Indent();
749 p->Print("if (field.id() < _has_field_.size()) {\n");
750 p->Print(" _has_field_.set(field.id());\n");
751 p->Print("}\n");
752 p->Print("switch (field.id()) {\n");
753 p->Indent();
754 for (int i = 0; i < msg->field_count(); i++) {
755 const FieldDescriptor* field = msg->field(i);
756 p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
757 "n", field->lowercase_name());
758 p->Indent();
759 if (field->options().lazy()) {
760 p->Print(
761 "::protozero::internal::gen_helpers::DeserializeString(field, "
762 "&$n$_);\n",
763 "n", field->lowercase_name());
764 } else {
765 std::string statement;
766 if (field->type() == TYPE_MESSAGE) {
767 statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
768 } else {
769 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
770 // sint32/64 fields are special and need to be zig-zag-decoded.
771 statement = "field.get_signed(&$rval$);\n";
772 } else if (field->type() == TYPE_STRING) {
773 statement =
774 "::protozero::internal::gen_helpers::DeserializeString(field, "
775 "&$rval$);\n";
776 } else {
777 statement = "field.get(&$rval$);\n";
778 }
779 }
780 if (field->is_packed()) {
781 PERFETTO_CHECK(field->is_repeated());
782 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
783 PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
784 }
785 p->Print(
786 "if "
787 "(!::protozero::internal::gen_helpers::DeserializePackedRepeated"
788 "<$w$, $c$>(field, &$n$_)) {\n",
789 "w", GetPackedWireType(field), "c", GetCppType(field, false), "n",
790 field->lowercase_name());
791 p->Print(" packed_error = true;");
792 p->Print("}\n");
793 } else if (field->is_repeated()) {
794 p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
795 p->Print(statement.c_str(), "rval",
796 field->lowercase_name() + "_.back()");
797 } else if (field->type() == TYPE_MESSAGE) {
798 p->Print(statement.c_str(), "rval",
799 "(*" + field->lowercase_name() + "_)");
800 } else {
801 p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
802 }
803 }
804 p->Print("break;\n");
805 p->Outdent();
806 } // for (field)
807 p->Print("default:\n");
808 p->Print(" field.SerializeAndAppendTo(&unknown_fields_);\n");
809 p->Print(" break;\n");
810 p->Outdent();
811 p->Print("}\n"); // switch(field.id)
812 p->Outdent();
813 p->Print("}\n"); // for(field)
814 p->Print("return !packed_error && !dec.bytes_left();\n"); // for(field)
815 p->Outdent();
816 p->Print("}\n\n");
817
818 // Generate the SerializeAsString() method definition.
819 p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
820 p->Indent();
821 p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
822 p->Print("Serialize(msg.get());\n");
823 p->Print("return msg.SerializeAsString();\n");
824 p->Outdent();
825 p->Print("}\n\n");
826
827 // Generate the SerializeAsArray() method definition.
828 p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
829 full_name);
830 p->Indent();
831 p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
832 p->Print("Serialize(msg.get());\n");
833 p->Print("return msg.SerializeAsArray();\n");
834 p->Outdent();
835 p->Print("}\n\n");
836
837 // Generate the Serialize() method that writes the fields into the passed
838 // protozero |msg| write-only interface |msg|.
839 p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
840 full_name);
841 p->Indent();
842 for (int i = 0; i < msg->field_count(); i++) {
843 const FieldDescriptor* field = msg->field(i);
844 std::map<std::string, std::string> args;
845 args["id"] = std::to_string(field->number());
846 args["n"] = field->lowercase_name();
847 p->Print(args, "// Field $id$: $n$\n");
848 if (field->is_packed()) {
849 PERFETTO_CHECK(field->is_repeated());
850 p->Print("{\n");
851 p->Indent();
852 p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
853 p->Print(args, "for (auto& it : $n$_)\n");
854 p->Print(args, " pack.Append(it);\n");
855 p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
856 p->Outdent();
857 p->Print("}\n");
858 } else {
859 if (field->is_repeated()) {
860 p->Print(args, "for (auto& it : $n$_) {\n");
861 args["lvalue"] = "it";
862 args["rvalue"] = "it";
863 } else {
864 p->Print(args, "if (_has_field_[$id$]) {\n");
865 args["lvalue"] = "(*" + field->lowercase_name() + "_)";
866 args["rvalue"] = field->lowercase_name() + "_";
867 }
868 p->Indent();
869 if (field->options().lazy()) {
870 p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
871 } else if (field->type() == TYPE_MESSAGE) {
872 p->Print(args,
873 "$lvalue$.Serialize("
874 "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
875 } else {
876 args["setter"] = GetProtozeroSetter(field);
877 p->Print(args, "$setter$($id$, $rvalue$, msg);\n");
878 }
879 p->Outdent();
880 p->Print("}\n");
881 }
882
883 p->Print("\n");
884 } // for (field)
885 p->Print(
886 "protozero::internal::gen_helpers::SerializeUnknownFields(unknown_fields_"
887 ", msg);\n");
888 p->Outdent();
889 p->Print("}\n\n");
890 }
891
892 } // namespace
893 } // namespace protozero
894
main(int argc,char ** argv)895 int main(int argc, char** argv) {
896 ::protozero::CppObjGenerator generator;
897 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
898 }
899