1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <stdio.h>
18 #include <stdlib.h>
19
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/dynamic_message.h>
31 #include <google/protobuf/io/printer.h>
32 #include <google/protobuf/io/zero_copy_stream_impl.h>
33 #include <google/protobuf/util/field_comparator.h>
34 #include <google/protobuf/util/message_differencer.h>
35
36 #include "perfetto/ext/base/string_utils.h"
37
38 namespace protozero {
39 namespace {
40
41 using namespace google::protobuf;
42 using namespace google::protobuf::compiler;
43 using namespace google::protobuf::io;
44 using perfetto::base::SplitString;
45 using perfetto::base::StripChars;
46 using perfetto::base::StripSuffix;
47 using perfetto::base::ToUpper;
48
49 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
50 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
51 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
52
53 static const char kHeader[] =
54 "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
55
56 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
57 public:
58 CppObjGenerator();
59 ~CppObjGenerator() override;
60
61 // CodeGenerator implementation
62 bool Generate(const google::protobuf::FileDescriptor* file,
63 const std::string& options,
64 GeneratorContext* context,
65 std::string* error) const override;
66
67 private:
68 std::string GetCppType(const FieldDescriptor* field, bool constref) const;
69 std::string GetProtozeroSetter(const FieldDescriptor* field) const;
70 std::string GetPackedBuffer(const FieldDescriptor* field) const;
71 std::string GetPackedWireType(const FieldDescriptor* field) const;
72
73 void GenEnum(const EnumDescriptor*, Printer*) const;
74 void GenEnumAliases(const EnumDescriptor*, Printer*) const;
75 void GenClassDecl(const Descriptor*, Printer*) const;
76 void GenClassDef(const Descriptor*, Printer*) const;
77
GetNamespaces(const FileDescriptor * file) const78 std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
79 std::string pkg = file->package() + wrapper_namespace_;
80 return SplitString(pkg, ".");
81 }
82
83 template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const84 std::string GetFullName(const T* msg, bool with_namespace = false) const {
85 std::string full_type;
86 full_type.append(msg->name());
87 for (const Descriptor* par = msg->containing_type(); par;
88 par = par->containing_type()) {
89 full_type.insert(0, par->name() + "_");
90 }
91 if (with_namespace) {
92 std::string prefix;
93 for (const std::string& ns : GetNamespaces(msg->file())) {
94 prefix += ns + "::";
95 }
96 full_type = prefix + full_type;
97 }
98 return full_type;
99 }
100
101 mutable std::string wrapper_namespace_;
102 };
103
104 CppObjGenerator::CppObjGenerator() = default;
105 CppObjGenerator::~CppObjGenerator() = default;
106
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const107 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
108 const std::string& options,
109 GeneratorContext* context,
110 std::string* error) const {
111 for (const std::string& option : SplitString(options, ",")) {
112 std::vector<std::string> option_pair = SplitString(option, "=");
113 if (option_pair[0] == "wrapper_namespace") {
114 wrapper_namespace_ =
115 option_pair.size() == 2 ? "." + option_pair[1] : std::string();
116 } else {
117 *error = "Unknown plugin option: " + option_pair[0];
118 return false;
119 }
120 }
121
122 auto get_file_name = [](const FileDescriptor* proto) {
123 return StripSuffix(proto->name(), ".proto") + ".gen";
124 };
125
126 const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
127 context->Open(get_file_name(file) + ".h"));
128 const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
129 context->Open(get_file_name(file) + ".cc"));
130
131 // Variables are delimited by $.
132 Printer h_printer(h_fstream.get(), '$');
133 Printer cc_printer(cc_fstream.get(), '$');
134
135 std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
136 include_guard = ToUpper(include_guard);
137 include_guard = StripChars(include_guard, ".-/\\", '_');
138
139 h_printer.Print(kHeader);
140 h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
141 h_printer.Print("#include <stdint.h>\n");
142 h_printer.Print("#include <bitset>\n");
143 h_printer.Print("#include <vector>\n");
144 h_printer.Print("#include <string>\n");
145 h_printer.Print("#include <type_traits>\n\n");
146 h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
147 h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
148 h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
149
150 cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
151 cc_printer.Print(
152 "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
153 cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
154 cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
155 cc_printer.Print(kHeader);
156 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
157 cc_printer.Print("#pragma GCC diagnostic push\n");
158 cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
159 cc_printer.Print("#endif\n");
160
161 // Generate includes for translated types of dependencies.
162
163 // Figure out the subset of imports that are used only for lazy fields. We
164 // won't emit a C++ #include for them. This code is overly aggressive at
165 // removing imports: it rules them out as soon as it sees one lazy field
166 // whose type is defined in that import. A 100% correct solution would require
167 // to check that *all* dependent types for a given import are lazy before
168 // excluding that. In practice we don't need that because we don't use imports
169 // for both lazy and non-lazy fields.
170 std::set<std::string> lazy_imports;
171 for (int m = 0; m < file->message_type_count(); m++) {
172 const Descriptor* msg = file->message_type(m);
173 for (int i = 0; i < msg->field_count(); i++) {
174 const FieldDescriptor* field = msg->field(i);
175 if (field->options().lazy()) {
176 lazy_imports.insert(field->message_type()->file()->name());
177 }
178 }
179 }
180
181 // Recursively traverse all imports and turn them into #include(s).
182 std::vector<const FileDescriptor*> imports_to_visit;
183 std::set<const FileDescriptor*> imports_visited;
184 imports_to_visit.push_back(file);
185
186 while (!imports_to_visit.empty()) {
187 const FileDescriptor* cur = imports_to_visit.back();
188 imports_to_visit.pop_back();
189 imports_visited.insert(cur);
190 std::string base_name = StripSuffix(cur->name(), ".proto");
191 cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
192 for (int i = 0; i < cur->dependency_count(); i++) {
193 const FileDescriptor* dep = cur->dependency(i);
194 if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
195 continue;
196 imports_to_visit.push_back(dep);
197 }
198 }
199
200 // Compute all nested types to generate forward declarations later.
201
202 std::set<const Descriptor*> all_types_seen; // All deps
203 std::set<const EnumDescriptor*> all_enums_seen;
204
205 // We track the types additionally in vectors to guarantee a stable order in
206 // the generated output.
207 std::vector<const Descriptor*> local_types; // Cur .proto file only.
208 std::vector<const Descriptor*> all_types; // All deps
209 std::vector<const EnumDescriptor*> local_enums;
210 std::vector<const EnumDescriptor*> all_enums;
211
212 auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
213 &file](const EnumDescriptor* enum_desc) {
214 if (all_enums_seen.count(enum_desc))
215 return;
216 all_enums_seen.insert(enum_desc);
217 all_enums.push_back(enum_desc);
218 if (enum_desc->file() == file)
219 local_enums.push_back(enum_desc);
220 };
221
222 for (int i = 0; i < file->enum_type_count(); i++)
223 add_enum(file->enum_type(i));
224
225 std::stack<const Descriptor*> recursion_stack;
226 for (int i = 0; i < file->message_type_count(); i++)
227 recursion_stack.push(file->message_type(i));
228
229 while (!recursion_stack.empty()) {
230 const Descriptor* msg = recursion_stack.top();
231 recursion_stack.pop();
232 if (all_types_seen.count(msg))
233 continue;
234 all_types_seen.insert(msg);
235 all_types.push_back(msg);
236 if (msg->file() == file)
237 local_types.push_back(msg);
238
239 for (int i = 0; i < msg->nested_type_count(); i++)
240 recursion_stack.push(msg->nested_type(i));
241
242 for (int i = 0; i < msg->enum_type_count(); i++)
243 add_enum(msg->enum_type(i));
244
245 for (int i = 0; i < msg->field_count(); i++) {
246 const FieldDescriptor* field = msg->field(i);
247 if (field->has_default_value()) {
248 *error = "field " + field->name() +
249 ": Explicitly declared default values are not supported";
250 return false;
251 }
252 if (field->options().lazy() &&
253 (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
254 *error = "[lazy=true] is supported only on non-repeated fields\n";
255 return false;
256 }
257
258 if (field->type() == TYPE_MESSAGE && !field->options().lazy())
259 recursion_stack.push(field->message_type());
260
261 if (field->type() == FieldDescriptor::TYPE_ENUM)
262 add_enum(field->enum_type());
263 }
264 } // while (!recursion_stack.empty())
265
266 // Generate forward declarations in the header for proto types.
267 // Note: do NOT add #includes to other generated headers (either .gen.h or
268 // .pbzero.h). Doing so is extremely hard to handle at the build-system level
269 // and requires propagating public_deps everywhere.
270 cc_printer.Print("\n");
271
272 // -- Begin of fwd declarations.
273
274 // Build up the map of forward declarations.
275 std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
276 enum FwdType { kClass, kEnum };
277 auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
278 const std::string& full_name) {
279 auto dot = full_name.rfind("::");
280 PERFETTO_CHECK(dot != std::string::npos);
281 auto package = full_name.substr(0, dot);
282 auto name = full_name.substr(dot + 2);
283 if (cpp_type == kClass) {
284 fwd_decls.emplace(package, "class " + name + ";");
285 } else {
286 PERFETTO_CHECK(cpp_type == kEnum);
287 fwd_decls.emplace(package, "enum " + name + " : int;");
288 }
289 };
290
291 add_fwd_decl(kClass, "protozero::Message");
292 for (const Descriptor* msg : all_types) {
293 add_fwd_decl(kClass, GetFullName(msg, true));
294 }
295 for (const EnumDescriptor* enm : all_enums) {
296 add_fwd_decl(kEnum, GetFullName(enm, true));
297 }
298
299 // Emit forward declarations grouping by package.
300 std::string last_package;
301 auto close_last_package = [&last_package, &h_printer] {
302 if (!last_package.empty()) {
303 for (const std::string& ns : SplitString(last_package, "::"))
304 h_printer.Print("} // namespace $ns$\n", "ns", ns);
305 h_printer.Print("\n");
306 }
307 };
308 for (const auto& kv : fwd_decls) {
309 const std::string& package = kv.first;
310 if (package != last_package) {
311 close_last_package();
312 last_package = package;
313 for (const std::string& ns : SplitString(package, "::"))
314 h_printer.Print("namespace $ns$ {\n", "ns", ns);
315 }
316 h_printer.Print("$decl$\n", "decl", kv.second);
317 }
318 close_last_package();
319
320 // -- End of fwd declarations.
321
322 for (const std::string& ns : GetNamespaces(file)) {
323 h_printer.Print("namespace $n$ {\n", "n", ns);
324 cc_printer.Print("namespace $n$ {\n", "n", ns);
325 }
326
327 // Generate declarations and definitions.
328 for (const EnumDescriptor* enm : local_enums)
329 GenEnum(enm, &h_printer);
330
331 for (const Descriptor* msg : local_types) {
332 GenClassDecl(msg, &h_printer);
333 GenClassDef(msg, &cc_printer);
334 }
335
336 for (const std::string& ns : GetNamespaces(file)) {
337 h_printer.Print("} // namespace $n$\n", "n", ns);
338 cc_printer.Print("} // namespace $n$\n", "n", ns);
339 }
340 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
341 cc_printer.Print("#pragma GCC diagnostic pop\n");
342 cc_printer.Print("#endif\n");
343
344 h_printer.Print("\n#endif // $g$\n", "g", include_guard);
345
346 return true;
347 }
348
GetCppType(const FieldDescriptor * field,bool constref) const349 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
350 bool constref) const {
351 switch (field->type()) {
352 case FieldDescriptor::TYPE_DOUBLE:
353 return "double";
354 case FieldDescriptor::TYPE_FLOAT:
355 return "float";
356 case FieldDescriptor::TYPE_FIXED32:
357 case FieldDescriptor::TYPE_UINT32:
358 return "uint32_t";
359 case FieldDescriptor::TYPE_SFIXED32:
360 case FieldDescriptor::TYPE_INT32:
361 case FieldDescriptor::TYPE_SINT32:
362 return "int32_t";
363 case FieldDescriptor::TYPE_FIXED64:
364 case FieldDescriptor::TYPE_UINT64:
365 return "uint64_t";
366 case FieldDescriptor::TYPE_SFIXED64:
367 case FieldDescriptor::TYPE_SINT64:
368 case FieldDescriptor::TYPE_INT64:
369 return "int64_t";
370 case FieldDescriptor::TYPE_BOOL:
371 return "bool";
372 case FieldDescriptor::TYPE_STRING:
373 case FieldDescriptor::TYPE_BYTES:
374 return constref ? "const std::string&" : "std::string";
375 case FieldDescriptor::TYPE_MESSAGE:
376 assert(!field->options().lazy());
377 return constref ? "const " + GetFullName(field->message_type()) + "&"
378 : GetFullName(field->message_type());
379 case FieldDescriptor::TYPE_ENUM:
380 return GetFullName(field->enum_type());
381 case FieldDescriptor::TYPE_GROUP:
382 abort();
383 }
384 abort(); // for gcc
385 }
386
GetProtozeroSetter(const FieldDescriptor * field) const387 std::string CppObjGenerator::GetProtozeroSetter(
388 const FieldDescriptor* field) const {
389 switch (field->type()) {
390 case FieldDescriptor::TYPE_BOOL:
391 return "AppendTinyVarInt";
392 case FieldDescriptor::TYPE_INT32:
393 case FieldDescriptor::TYPE_INT64:
394 case FieldDescriptor::TYPE_UINT32:
395 case FieldDescriptor::TYPE_UINT64:
396 case FieldDescriptor::TYPE_ENUM:
397 return "AppendVarInt";
398 case FieldDescriptor::TYPE_SINT32:
399 case FieldDescriptor::TYPE_SINT64:
400 return "AppendSignedVarInt";
401 case FieldDescriptor::TYPE_FIXED32:
402 case FieldDescriptor::TYPE_FIXED64:
403 case FieldDescriptor::TYPE_SFIXED32:
404 case FieldDescriptor::TYPE_SFIXED64:
405 case FieldDescriptor::TYPE_FLOAT:
406 case FieldDescriptor::TYPE_DOUBLE:
407 return "AppendFixed";
408 case FieldDescriptor::TYPE_STRING:
409 case FieldDescriptor::TYPE_BYTES:
410 return "AppendString";
411 case FieldDescriptor::TYPE_GROUP:
412 case FieldDescriptor::TYPE_MESSAGE:
413 abort();
414 }
415 abort();
416 }
417
GetPackedBuffer(const FieldDescriptor * field) const418 std::string CppObjGenerator::GetPackedBuffer(
419 const FieldDescriptor* field) const {
420 switch (field->type()) {
421 case FieldDescriptor::TYPE_FIXED32:
422 return "::protozero::PackedFixedSizeInt<uint32_t>";
423 case FieldDescriptor::TYPE_SFIXED32:
424 return "::protozero::PackedFixedSizeInt<int32_t>";
425 case FieldDescriptor::TYPE_FIXED64:
426 return "::protozero::PackedFixedSizeInt<uint64_t>";
427 case FieldDescriptor::TYPE_SFIXED64:
428 return "::protozero::PackedFixedSizeInt<int64_t>";
429 case FieldDescriptor::TYPE_DOUBLE:
430 return "::protozero::PackedFixedSizeInt<double>";
431 case FieldDescriptor::TYPE_FLOAT:
432 return "::protozero::PackedFixedSizeInt<float>";
433 case FieldDescriptor::TYPE_INT32:
434 case FieldDescriptor::TYPE_SINT32:
435 case FieldDescriptor::TYPE_UINT32:
436 case FieldDescriptor::TYPE_INT64:
437 case FieldDescriptor::TYPE_UINT64:
438 case FieldDescriptor::TYPE_SINT64:
439 case FieldDescriptor::TYPE_BOOL:
440 return "::protozero::PackedVarInt";
441 case FieldDescriptor::TYPE_STRING:
442 case FieldDescriptor::TYPE_BYTES:
443 case FieldDescriptor::TYPE_MESSAGE:
444 case FieldDescriptor::TYPE_ENUM:
445 case FieldDescriptor::TYPE_GROUP:
446 break; // Will abort()
447 }
448 abort();
449 }
450
GetPackedWireType(const FieldDescriptor * field) const451 std::string CppObjGenerator::GetPackedWireType(
452 const FieldDescriptor* field) const {
453 switch (field->type()) {
454 case FieldDescriptor::TYPE_FIXED32:
455 case FieldDescriptor::TYPE_SFIXED32:
456 case FieldDescriptor::TYPE_FLOAT:
457 return "::protozero::proto_utils::ProtoWireType::kFixed32";
458 case FieldDescriptor::TYPE_FIXED64:
459 case FieldDescriptor::TYPE_SFIXED64:
460 case FieldDescriptor::TYPE_DOUBLE:
461 return "::protozero::proto_utils::ProtoWireType::kFixed64";
462 case FieldDescriptor::TYPE_INT32:
463 case FieldDescriptor::TYPE_SINT32:
464 case FieldDescriptor::TYPE_UINT32:
465 case FieldDescriptor::TYPE_INT64:
466 case FieldDescriptor::TYPE_UINT64:
467 case FieldDescriptor::TYPE_SINT64:
468 case FieldDescriptor::TYPE_BOOL:
469 return "::protozero::proto_utils::ProtoWireType::kVarInt";
470 case FieldDescriptor::TYPE_STRING:
471 case FieldDescriptor::TYPE_BYTES:
472 case FieldDescriptor::TYPE_MESSAGE:
473 case FieldDescriptor::TYPE_ENUM:
474 case FieldDescriptor::TYPE_GROUP:
475 break; // Will abort()
476 }
477 abort();
478 }
479
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const480 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
481 Printer* p) const {
482 std::string full_name = GetFullName(enum_desc);
483
484 // When generating enums, there are two cases:
485 // 1. Enums nested in a message (most frequent case), e.g.:
486 // message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
487 // 2. Enum defined at the package level, outside of any message.
488 //
489 // In the case 1, the C++ code generated by the official protobuf library is:
490 // enum MyEnum { MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
491 // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
492 //
493 // In the case 2, the C++ code is simply:
494 // enum MyEnum { FOO=1, BAR=2 }
495 // Hence this |prefix| logic.
496 std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
497 p->Print("enum $f$ : int {\n", "f", full_name);
498 for (int e = 0; e < enum_desc->value_count(); e++) {
499 const EnumValueDescriptor* value = enum_desc->value(e);
500 p->Print(" $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
501 std::to_string(value->number()));
502 }
503 p->Print("};\n");
504 }
505
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const506 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
507 Printer* p) const {
508 int min_value = std::numeric_limits<int>::max();
509 int max_value = std::numeric_limits<int>::min();
510 std::string min_name;
511 std::string max_name;
512 std::string full_name = GetFullName(enum_desc);
513 for (int e = 0; e < enum_desc->value_count(); e++) {
514 const EnumValueDescriptor* value = enum_desc->value(e);
515 p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
516 value->name());
517 if (value->number() < min_value) {
518 min_value = value->number();
519 min_name = full_name + "_" + value->name();
520 }
521 if (value->number() > max_value) {
522 max_value = value->number();
523 max_name = full_name + "_" + value->name();
524 }
525 }
526 p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
527 "m", min_name);
528 p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
529 "m", max_name);
530 }
531
GenClassDecl(const Descriptor * msg,Printer * p) const532 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
533 std::string full_name = GetFullName(msg);
534 p->Print(
535 "\nclass PERFETTO_EXPORT $n$ : public ::protozero::CppMessageObj {\n",
536 "n", full_name);
537 p->Print(" public:\n");
538 p->Indent();
539
540 // Do a first pass to generate aliases for nested types.
541 // e.g., using Foo = Parent_Foo;
542 for (int i = 0; i < msg->nested_type_count(); i++) {
543 const Descriptor* nested_msg = msg->nested_type(i);
544 p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
545 GetFullName(nested_msg));
546 }
547 for (int i = 0; i < msg->enum_type_count(); i++) {
548 const EnumDescriptor* nested_enum = msg->enum_type(i);
549 p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
550 GetFullName(nested_enum));
551 GenEnumAliases(nested_enum, p);
552 }
553
554 // Generate constants with field numbers.
555 p->Print("enum FieldNumbers {\n");
556 for (int i = 0; i < msg->field_count(); i++) {
557 const FieldDescriptor* field = msg->field(i);
558 std::string name = field->camelcase_name();
559 name[0] = perfetto::base::Uppercase(name[0]);
560 p->Print(" k$n$FieldNumber = $num$,\n", "n", name, "num",
561 std::to_string(field->number()));
562 }
563 p->Print("};\n\n");
564
565 p->Print("$n$();\n", "n", full_name);
566 p->Print("~$n$() override;\n", "n", full_name);
567 p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
568 p->Print("$n$& operator=($n$&&);\n", "n", full_name);
569 p->Print("$n$(const $n$&);\n", "n", full_name);
570 p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
571 p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
572 p->Print(
573 "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
574 "n", full_name);
575 p->Print("\n");
576
577 std::string proto_type = GetFullName(msg, true);
578 p->Print("bool ParseFromArray(const void*, size_t) override;\n");
579 p->Print("std::string SerializeAsString() const override;\n");
580 p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
581 p->Print("void Serialize(::protozero::Message*) const;\n");
582
583 // Generate accessors.
584 for (int i = 0; i < msg->field_count(); i++) {
585 const FieldDescriptor* field = msg->field(i);
586 auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
587 p->Print("\n");
588 if (field->options().lazy()) {
589 p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
590 field->lowercase_name());
591 p->Print(
592 "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
593 "n", field->lowercase_name(), "s", set_bit);
594 } else if (!field->is_repeated()) {
595 p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
596 field->lowercase_name(), "bit", std::to_string(field->number()));
597 if (field->type() == TYPE_MESSAGE) {
598 p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
599 GetCppType(field, true), "n", field->lowercase_name());
600 p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
601 GetCppType(field, false), "n", field->lowercase_name(), "s",
602 set_bit);
603 } else {
604 p->Print("$t$ $n$() const { return $n$_; }\n", "t",
605 GetCppType(field, true), "n", field->lowercase_name());
606 p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
607 GetCppType(field, true), "n", field->lowercase_name(), "s",
608 set_bit);
609 if (field->type() == FieldDescriptor::TYPE_BYTES) {
610 p->Print(
611 "void set_$n$(const void* p, size_t s) { "
612 "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
613 "n", field->lowercase_name(), "s", set_bit);
614 }
615 }
616 } else { // is_repeated()
617 p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
618 GetCppType(field, false), "n", field->lowercase_name());
619 p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
620 GetCppType(field, false), "n", field->lowercase_name());
621
622 // Generate accessors for repeated message types in the .cc file so that
623 // the header doesn't depend on the full definition of all nested types.
624 if (field->type() == TYPE_MESSAGE) {
625 p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
626 field->lowercase_name());
627 p->Print("void clear_$n$();\n", "n", field->lowercase_name());
628 p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
629 field->lowercase_name());
630 } else { // Primitive type.
631 p->Print(
632 "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
633 "t", GetCppType(field, false), "n", field->lowercase_name());
634 p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
635 field->lowercase_name());
636 p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
637 GetCppType(field, false), "n", field->lowercase_name());
638 // TODO(primiano): this should be done only for TYPE_MESSAGE.
639 // Unfortuntely we didn't realize before and now we have a bunch of code
640 // that does: *msg->add_int_value() = 42 instead of
641 // msg->add_int_value(42).
642 p->Print(
643 "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
644 "t", GetCppType(field, false), "n", field->lowercase_name());
645 }
646 }
647 }
648 p->Outdent();
649 p->Print("\n private:\n");
650 p->Indent();
651
652 // Generate fields.
653 int max_field_id = 1;
654 for (int i = 0; i < msg->field_count(); i++) {
655 const FieldDescriptor* field = msg->field(i);
656 max_field_id = std::max(max_field_id, field->number());
657 if (field->options().lazy()) {
658 p->Print("std::string $n$_; // [lazy=true]\n", "n",
659 field->lowercase_name());
660 } else if (!field->is_repeated()) {
661 std::string type = GetCppType(field, false);
662 if (field->type() == TYPE_MESSAGE) {
663 type = "::protozero::CopyablePtr<" + type + ">";
664 p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
665 } else {
666 p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
667 }
668 } else { // is_repeated()
669 p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
670 field->lowercase_name());
671 }
672 }
673 p->Print("\n");
674 p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
675 p->Print("// with future versions of .proto files.\n");
676 p->Print("std::string unknown_fields_;\n");
677
678 p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
679 std::to_string(max_field_id + 1));
680
681 p->Outdent();
682 p->Print("};\n\n");
683 }
684
GenClassDef(const Descriptor * msg,Printer * p) const685 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
686 p->Print("\n");
687 std::string full_name = GetFullName(msg);
688
689 p->Print("$n$::$n$() = default;\n", "n", full_name);
690 p->Print("$n$::~$n$() = default;\n", "n", full_name);
691 p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
692 p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
693 p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
694 p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
695
696 p->Print("\n");
697
698 // Comparison operator
699 p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
700 p->Indent();
701
702 p->Print("return unknown_fields_ == other.unknown_fields_");
703 for (int i = 0; i < msg->field_count(); i++)
704 p->Print("\n && $n$_ == other.$n$_", "n", msg->field(i)->lowercase_name());
705 p->Print(";");
706 p->Outdent();
707 p->Print("\n}\n\n");
708
709 // Accessors for repeated message fields.
710 for (int i = 0; i < msg->field_count(); i++) {
711 const FieldDescriptor* field = msg->field(i);
712 if (field->options().lazy() || !field->is_repeated() ||
713 field->type() != TYPE_MESSAGE) {
714 continue;
715 }
716 p->Print(
717 "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
718 "c", full_name, "t", GetCppType(field, false), "n",
719 field->lowercase_name());
720 p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
721 field->lowercase_name());
722 p->Print(
723 "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
724 "c", full_name, "t", GetCppType(field, false), "n",
725 field->lowercase_name());
726 }
727
728 std::string proto_type = GetFullName(msg, true);
729
730 // Generate the ParseFromArray() method definition.
731 p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
732 full_name);
733 p->Indent();
734 for (int i = 0; i < msg->field_count(); i++) {
735 const FieldDescriptor* field = msg->field(i);
736 if (field->is_repeated())
737 p->Print("$n$_.clear();\n", "n", field->lowercase_name());
738 }
739 p->Print("unknown_fields_.clear();\n");
740 p->Print("bool packed_error = false;\n");
741 p->Print("\n");
742 p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
743 p->Print("for (auto field = dec.ReadField(); field.valid(); ");
744 p->Print("field = dec.ReadField()) {\n");
745 p->Indent();
746 p->Print("if (field.id() < _has_field_.size()) {\n");
747 p->Print(" _has_field_.set(field.id());\n");
748 p->Print("}\n");
749 p->Print("switch (field.id()) {\n");
750 p->Indent();
751 for (int i = 0; i < msg->field_count(); i++) {
752 const FieldDescriptor* field = msg->field(i);
753 p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
754 "n", field->lowercase_name());
755 p->Indent();
756 if (field->options().lazy()) {
757 p->Print("$n$_ = field.as_std_string();\n", "n", field->lowercase_name());
758 } else {
759 std::string statement;
760 if (field->type() == TYPE_MESSAGE) {
761 statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
762 } else {
763 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
764 // sint32/64 fields are special and need to be zig-zag-decoded.
765 statement = "field.get_signed(&$rval$);\n";
766 } else {
767 statement = "field.get(&$rval$);\n";
768 }
769 }
770 if (field->is_packed()) {
771 PERFETTO_CHECK(field->is_repeated());
772 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
773 PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
774 }
775 p->Print(
776 "for (::protozero::PackedRepeatedFieldIterator<$w$, $c$> "
777 "rep(field.data(), field.size(), &packed_error); rep; ++rep) {\n",
778 "w", GetPackedWireType(field), "c", GetCppType(field, false));
779 p->Print(" $n$_.emplace_back(*rep);\n", "n", field->lowercase_name());
780 p->Print("}\n");
781 } else if (field->is_repeated()) {
782 p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
783 p->Print(statement.c_str(), "rval",
784 field->lowercase_name() + "_.back()");
785 } else if (field->type() == TYPE_MESSAGE) {
786 p->Print(statement.c_str(), "rval",
787 "(*" + field->lowercase_name() + "_)");
788 } else {
789 p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
790 }
791 }
792 p->Print("break;\n");
793 p->Outdent();
794 } // for (field)
795 p->Print("default:\n");
796 p->Print(" field.SerializeAndAppendTo(&unknown_fields_);\n");
797 p->Print(" break;\n");
798 p->Outdent();
799 p->Print("}\n"); // switch(field.id)
800 p->Outdent();
801 p->Print("}\n"); // for(field)
802 p->Print("return !packed_error && !dec.bytes_left();\n"); // for(field)
803 p->Outdent();
804 p->Print("}\n\n");
805
806 // Generate the SerializeAsString() method definition.
807 p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
808 p->Indent();
809 p->Print("::protozero::HeapBuffered<::protozero::Message> msg;\n");
810 p->Print("Serialize(msg.get());\n");
811 p->Print("return msg.SerializeAsString();\n");
812 p->Outdent();
813 p->Print("}\n\n");
814
815 // Generate the SerializeAsArray() method definition.
816 p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
817 full_name);
818 p->Indent();
819 p->Print("::protozero::HeapBuffered<::protozero::Message> msg;\n");
820 p->Print("Serialize(msg.get());\n");
821 p->Print("return msg.SerializeAsArray();\n");
822 p->Outdent();
823 p->Print("}\n\n");
824
825 // Generate the Serialize() method that writes the fields into the passed
826 // protozero |msg| write-only interface |msg|.
827 p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
828 full_name);
829 p->Indent();
830 for (int i = 0; i < msg->field_count(); i++) {
831 const FieldDescriptor* field = msg->field(i);
832 std::map<std::string, std::string> args;
833 args["id"] = std::to_string(field->number());
834 args["n"] = field->lowercase_name();
835 p->Print(args, "// Field $id$: $n$\n");
836 if (field->is_packed()) {
837 PERFETTO_CHECK(field->is_repeated());
838 p->Print("{\n");
839 p->Indent();
840 p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
841 p->Print(args, "for (auto& it : $n$_)\n");
842 p->Print(args, " pack.Append(it);\n");
843 p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
844 p->Outdent();
845 p->Print("}\n");
846 } else {
847 if (field->is_repeated()) {
848 p->Print(args, "for (auto& it : $n$_) {\n");
849 args["lvalue"] = "it";
850 args["rvalue"] = "it";
851 } else {
852 p->Print(args, "if (_has_field_[$id$]) {\n");
853 args["lvalue"] = "(*" + field->lowercase_name() + "_)";
854 args["rvalue"] = field->lowercase_name() + "_";
855 }
856 p->Indent();
857 if (field->options().lazy()) {
858 p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
859 } else if (field->type() == TYPE_MESSAGE) {
860 p->Print(args,
861 "$lvalue$.Serialize("
862 "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
863 } else {
864 args["setter"] = GetProtozeroSetter(field);
865 p->Print(args, "msg->$setter$($id$, $rvalue$);\n");
866 }
867 p->Outdent();
868 p->Print("}\n");
869 }
870
871 p->Print("\n");
872 } // for (field)
873 p->Print(
874 "msg->AppendRawProtoBytes(unknown_fields_.data(), "
875 "unknown_fields_.size());\n");
876 p->Outdent();
877 p->Print("}\n\n");
878 }
879
880 } // namespace
881 } // namespace protozero
882
main(int argc,char ** argv)883 int main(int argc, char** argv) {
884 ::protozero::CppObjGenerator generator;
885 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
886 }
887