• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/compiler/cpp/cpp_message.h>
36 
37 #include <algorithm>
38 #include <functional>
39 #include <map>
40 #include <memory>
41 #include <unordered_map>
42 #include <utility>
43 #include <vector>
44 
45 #include <google/protobuf/compiler/cpp/cpp_enum.h>
46 #include <google/protobuf/compiler/cpp/cpp_extension.h>
47 #include <google/protobuf/compiler/cpp/cpp_field.h>
48 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
49 #include <google/protobuf/compiler/cpp/cpp_padding_optimizer.h>
50 #include <google/protobuf/descriptor.pb.h>
51 #include <google/protobuf/io/coded_stream.h>
52 #include <google/protobuf/io/printer.h>
53 #include <google/protobuf/generated_message_table_driven.h>
54 #include <google/protobuf/generated_message_util.h>
55 #include <google/protobuf/map_entry_lite.h>
56 #include <google/protobuf/wire_format.h>
57 #include <google/protobuf/stubs/strutil.h>
58 #include <google/protobuf/stubs/substitute.h>
59 #include <google/protobuf/stubs/hash.h>
60 
61 
62 namespace google {
63 namespace protobuf {
64 namespace compiler {
65 namespace cpp {
66 
67 using internal::WireFormat;
68 using internal::WireFormatLite;
69 
70 namespace {
71 
// Sentinel has-bit index recorded for fields that were not assigned a has-bit.
constexpr int kNoHasbit = -1;
73 
74 // Create an expression that evaluates to
75 //  "for all i, (_has_bits_[i] & masks[i]) == masks[i]"
76 // masks is allowed to be shorter than _has_bits_, but at least one element of
77 // masks must be non-zero.
ConditionalToCheckBitmasks(const std::vector<uint32> & masks,bool return_success=true,StringPiece has_bits_var="_has_bits_")78 std::string ConditionalToCheckBitmasks(
79     const std::vector<uint32>& masks, bool return_success = true,
80     StringPiece has_bits_var = "_has_bits_") {
81   std::vector<std::string> parts;
82   for (int i = 0; i < masks.size(); i++) {
83     if (masks[i] == 0) continue;
84     std::string m = StrCat("0x", strings::Hex(masks[i], strings::ZERO_PAD_8));
85     // Each xor evaluates to 0 if the expected bits are present.
86     parts.push_back(
87         StrCat("((", has_bits_var, "[", i, "] & ", m, ") ^ ", m, ")"));
88   }
89   GOOGLE_CHECK(!parts.empty());
90   // If we have multiple parts, each expected to be 0, then bitwise-or them.
91   std::string result =
92       parts.size() == 1
93           ? parts[0]
94           : StrCat("(", Join(parts, "\n       | "), ")");
95   return result + (return_success ? " == 0" : " != 0");
96 }
97 
PrintPresenceCheck(const Formatter & format,const FieldDescriptor * field,const std::vector<int> & has_bit_indices,io::Printer * printer,int * cached_has_word_index)98 void PrintPresenceCheck(const Formatter& format, const FieldDescriptor* field,
99                         const std::vector<int>& has_bit_indices,
100                         io::Printer* printer, int* cached_has_word_index) {
101   if (!field->options().weak()) {
102     int has_bit_index = has_bit_indices[field->index()];
103     if (*cached_has_word_index != (has_bit_index / 32)) {
104       *cached_has_word_index = (has_bit_index / 32);
105       format("cached_has_bits = _has_bits_[$1$];\n", *cached_has_word_index);
106     }
107     const std::string mask =
108         StrCat(strings::Hex(1u << (has_bit_index % 32), strings::ZERO_PAD_8));
109     format("if (cached_has_bits & 0x$1$u) {\n", mask);
110   } else {
111     format("if (has_$1$()) {\n", FieldName(field));
112   }
113   format.Indent();
114 }
115 
116 struct FieldOrderingByNumber {
operator ()google::protobuf::compiler::cpp::__anon9167795c0111::FieldOrderingByNumber117   inline bool operator()(const FieldDescriptor* a,
118                          const FieldDescriptor* b) const {
119     return a->number() < b->number();
120   }
121 };
122 
123 // Sort the fields of the given Descriptor by number into a new[]'d array
124 // and return it.
SortFieldsByNumber(const Descriptor * descriptor)125 std::vector<const FieldDescriptor*> SortFieldsByNumber(
126     const Descriptor* descriptor) {
127   std::vector<const FieldDescriptor*> fields(descriptor->field_count());
128   for (int i = 0; i < descriptor->field_count(); i++) {
129     fields[i] = descriptor->field(i);
130   }
131   std::sort(fields.begin(), fields.end(), FieldOrderingByNumber());
132   return fields;
133 }
134 
135 // Functor for sorting extension ranges by their "start" field number.
136 struct ExtensionRangeSorter {
operator ()google::protobuf::compiler::cpp::__anon9167795c0111::ExtensionRangeSorter137   bool operator()(const Descriptor::ExtensionRange* left,
138                   const Descriptor::ExtensionRange* right) const {
139     return left->start < right->start;
140   }
141 };
142 
IsPOD(const FieldDescriptor * field)143 bool IsPOD(const FieldDescriptor* field) {
144   if (field->is_repeated() || field->is_extension()) return false;
145   switch (field->cpp_type()) {
146     case FieldDescriptor::CPPTYPE_ENUM:
147     case FieldDescriptor::CPPTYPE_INT32:
148     case FieldDescriptor::CPPTYPE_INT64:
149     case FieldDescriptor::CPPTYPE_UINT32:
150     case FieldDescriptor::CPPTYPE_UINT64:
151     case FieldDescriptor::CPPTYPE_FLOAT:
152     case FieldDescriptor::CPPTYPE_DOUBLE:
153     case FieldDescriptor::CPPTYPE_BOOL:
154       return true;
155     case FieldDescriptor::CPPTYPE_STRING:
156       return false;
157     default:
158       return false;
159   }
160 }
161 
162 // Helper for the code that emits the SharedCtor() and InternalSwap() methods.
163 // Anything that is a POD or a "normal" message (represented by a pointer) can
164 // be manipulated as raw bytes.
CanBeManipulatedAsRawBytes(const FieldDescriptor * field,const Options & options)165 bool CanBeManipulatedAsRawBytes(const FieldDescriptor* field,
166                                 const Options& options) {
167   bool ret = CanInitializeByZeroing(field);
168 
169   // Non-repeated, non-lazy message fields are simply raw pointers, so we can
170   // swap them or use memset to initialize these in SharedCtor. We cannot use
171   // this in Clear, as we need to potentially delete the existing value.
172   ret = ret || (!field->is_repeated() && !IsLazy(field, options) &&
173                 field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE);
174   return ret;
175 }
176 
177 // Finds runs of fields for which `predicate` is true.
178 // RunMap maps from fields that start each run to the number of fields in that
179 // run.  This is optimized for the common case that there are very few runs in
180 // a message and that most of the eligible fields appear together.
181 using RunMap = std::unordered_map<const FieldDescriptor*, size_t>;
FindRuns(const std::vector<const FieldDescriptor * > & fields,const std::function<bool (const FieldDescriptor *)> & predicate)182 RunMap FindRuns(const std::vector<const FieldDescriptor*>& fields,
183                 const std::function<bool(const FieldDescriptor*)>& predicate) {
184   RunMap runs;
185   const FieldDescriptor* last_start = nullptr;
186 
187   for (auto field : fields) {
188     if (predicate(field)) {
189       if (last_start == nullptr) {
190         last_start = field;
191       }
192 
193       runs[last_start]++;
194     } else {
195       last_start = nullptr;
196     }
197   }
198   return runs;
199 }
200 
201 // Emits an if-statement with a condition that evaluates to true if |field| is
202 // considered non-default (will be sent over the wire), for message types
203 // without true field presence. Should only be called if
204 // !HasHasbit(field).
EmitFieldNonDefaultCondition(io::Printer * printer,const std::string & prefix,const FieldDescriptor * field)205 bool EmitFieldNonDefaultCondition(io::Printer* printer,
206                                   const std::string& prefix,
207                                   const FieldDescriptor* field) {
208   GOOGLE_CHECK(!HasHasbit(field));
209   Formatter format(printer);
210   format.Set("prefix", prefix);
211   format.Set("name", FieldName(field));
212   // Merge and serialize semantics: primitive fields are merged/serialized only
213   // if non-zero (numeric) or non-empty (string).
214   if (!field->is_repeated() && !field->containing_oneof()) {
215     if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
216       format("if ($prefix$$name$().size() > 0) {\n");
217     } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
218       // Message fields still have has_$name$() methods.
219       format("if ($prefix$has_$name$()) {\n");
220     } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_DOUBLE ||
221                field->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT) {
222       // Handle float comparison to prevent -Wfloat-equal warnings
223       format("if (!($prefix$$name$() <= 0 && $prefix$$name$() >= 0)) {\n");
224     } else {
225       format("if ($prefix$$name$() != 0) {\n");
226     }
227     format.Indent();
228     return true;
229   } else if (field->real_containing_oneof()) {
230     format("if (_internal_has_$name$()) {\n");
231     format.Indent();
232     return true;
233   }
234   return false;
235 }
236 
237 // Does the given field have a has_$name$() method?
HasHasMethod(const FieldDescriptor * field)238 bool HasHasMethod(const FieldDescriptor* field) {
239   if (HasFieldPresence(field->file())) {
240     // In proto1/proto2, every field has a has_$name$() method.
241     return true;
242   }
243   // For message types without true field presence, only fields with a message
244   // type have a has_$name$() method.
245   return field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
246          field->has_optional_keyword();
247 }
248 
249 // Collects map entry message type information.
CollectMapInfo(const Options & options,const Descriptor * descriptor,std::map<std::string,std::string> * variables)250 void CollectMapInfo(const Options& options, const Descriptor* descriptor,
251                     std::map<std::string, std::string>* variables) {
252   GOOGLE_CHECK(IsMapEntryMessage(descriptor));
253   std::map<std::string, std::string>& vars = *variables;
254   const FieldDescriptor* key = descriptor->FindFieldByName("key");
255   const FieldDescriptor* val = descriptor->FindFieldByName("value");
256   vars["key_cpp"] = PrimitiveTypeName(options, key->cpp_type());
257   switch (val->cpp_type()) {
258     case FieldDescriptor::CPPTYPE_MESSAGE:
259       vars["val_cpp"] = FieldMessageTypeName(val, options);
260       break;
261     case FieldDescriptor::CPPTYPE_ENUM:
262       vars["val_cpp"] = ClassName(val->enum_type(), true);
263       break;
264     default:
265       vars["val_cpp"] = PrimitiveTypeName(options, val->cpp_type());
266   }
267   vars["key_wire_type"] =
268       "TYPE_" + ToUpper(DeclaredTypeMethodName(key->type()));
269   vars["val_wire_type"] =
270       "TYPE_" + ToUpper(DeclaredTypeMethodName(val->type()));
271 }
272 
273 // Does the given field have a private (internal helper only) has_$name$()
274 // method?
HasPrivateHasMethod(const FieldDescriptor * field)275 bool HasPrivateHasMethod(const FieldDescriptor* field) {
276   // Only for oneofs in message types with no field presence. has_$name$(),
277   // based on the oneof case, is still useful internally for generated code.
278   return (!HasFieldPresence(field->file()) && field->real_containing_oneof());
279 }
280 
281 // TODO(ckennelly):  Cull these exclusions if/when these protos do not have
282 // their methods overridden by subclasses.
283 
// Always returns true; `descriptor` and `options` are currently not consulted.
bool ShouldMarkClassAsFinal(const Descriptor* descriptor,
                            const Options& options) {
  (void)descriptor;
  (void)options;
  return true;
}
288 
ShouldMarkClearAsFinal(const Descriptor * descriptor,const Options & options)289 bool ShouldMarkClearAsFinal(const Descriptor* descriptor,
290                             const Options& options) {
291   static std::set<std::string> exclusions{
292   };
293 
294   const std::string name = ClassName(descriptor, true);
295   return exclusions.find(name) == exclusions.end() ||
296          options.opensource_runtime;
297 }
298 
ShouldMarkIsInitializedAsFinal(const Descriptor * descriptor,const Options & options)299 bool ShouldMarkIsInitializedAsFinal(const Descriptor* descriptor,
300                                     const Options& options) {
301   static std::set<std::string> exclusions{
302   };
303 
304   const std::string name = ClassName(descriptor, true);
305   return exclusions.find(name) == exclusions.end() ||
306          options.opensource_runtime;
307 }
308 
ShouldMarkNewAsFinal(const Descriptor * descriptor,const Options & options)309 bool ShouldMarkNewAsFinal(const Descriptor* descriptor,
310                           const Options& options) {
311   static std::set<std::string> exclusions{
312   };
313 
314   const std::string name = ClassName(descriptor, true);
315   return exclusions.find(name) == exclusions.end() ||
316          options.opensource_runtime;
317 }
318 
319 // Returns true to make the message serialize in order, decided by the following
320 // factors in the order of precedence.
321 // --options().message_set_wire_format() == true
322 // --the message is in the allowlist (true)
323 // --GOOGLE_PROTOBUF_SHUFFLE_SERIALIZE is defined (false)
324 // --a ranage of message names that are allowed to stay in order (true)
ShouldSerializeInOrder(const Descriptor * descriptor,const Options & options)325 bool ShouldSerializeInOrder(const Descriptor* descriptor,
326                             const Options& options) {
327   return true;
328 }
329 
TableDrivenParsingEnabled(const Descriptor * descriptor,const Options & options)330 bool TableDrivenParsingEnabled(const Descriptor* descriptor,
331                                const Options& options) {
332   if (!options.table_driven_parsing) {
333     return false;
334   }
335 
336   // Consider table-driven parsing.  We only do this if:
337   // - We have has_bits for fields.  This avoids a check on every field we set
338   //   when are present (the common case).
339   bool has_hasbit = false;
340   for (int i = 0; i < descriptor->field_count(); i++) {
341     if (HasHasbit(descriptor->field(i))) {
342       has_hasbit = true;
343       break;
344     }
345   }
346 
347   if (!has_hasbit) return false;
348 
349   const double table_sparseness = 0.5;
350   int max_field_number = 0;
351   for (auto field : FieldRange(descriptor)) {
352     if (max_field_number < field->number()) {
353       max_field_number = field->number();
354     }
355 
356     // - There are no weak fields.
357     if (IsWeak(field, options)) {
358       return false;
359     }
360 
361     // - There are no lazy fields (they require the non-lite library).
362     if (IsLazy(field, options)) {
363       return false;
364     }
365   }
366 
367   // - There range of field numbers is "small"
368   if (max_field_number >= (2 << 14)) {
369     return false;
370   }
371 
372   // - Field numbers are relatively dense within the actual number of fields.
373   //   We check for strictly greater than in the case where there are no fields
374   //   (only extensions) so max_field_number == descriptor->field_count() == 0.
375   if (max_field_number * table_sparseness > descriptor->field_count()) {
376     return false;
377   }
378 
379   // - This is not a MapEntryMessage.
380   if (IsMapEntryMessage(descriptor)) {
381     return false;
382   }
383 
384   return true;
385 }
386 
IsCrossFileMapField(const FieldDescriptor * field)387 bool IsCrossFileMapField(const FieldDescriptor* field) {
388   if (!field->is_map()) {
389     return false;
390   }
391 
392   const Descriptor* d = field->message_type();
393   const FieldDescriptor* value = d->FindFieldByNumber(2);
394 
395   return IsCrossFileMessage(value);
396 }
397 
IsCrossFileMaybeMap(const FieldDescriptor * field)398 bool IsCrossFileMaybeMap(const FieldDescriptor* field) {
399   if (IsCrossFileMapField(field)) {
400     return true;
401   }
402 
403   return IsCrossFileMessage(field);
404 }
405 
IsRequired(const std::vector<const FieldDescriptor * > & v)406 bool IsRequired(const std::vector<const FieldDescriptor*>& v) {
407   return v.front()->is_required();
408 }
409 
410 // Collects neighboring fields based on a given criteria (equivalent predicate).
411 template <typename Predicate>
CollectFields(const std::vector<const FieldDescriptor * > & fields,const Predicate & equivalent)412 std::vector<std::vector<const FieldDescriptor*>> CollectFields(
413     const std::vector<const FieldDescriptor*>& fields,
414     const Predicate& equivalent) {
415   std::vector<std::vector<const FieldDescriptor*>> chunks;
416   for (auto field : fields) {
417     if (chunks.empty() || !equivalent(chunks.back().back(), field)) {
418       chunks.emplace_back();
419     }
420     chunks.back().push_back(field);
421   }
422   return chunks;
423 }
424 
425 // Returns a bit mask based on has_bit index of "fields" that are typically on
426 // the same chunk. It is used in a group presence check where _has_bits_ is
427 // masked to tell if any thing in "fields" is present.
GenChunkMask(const std::vector<const FieldDescriptor * > & fields,const std::vector<int> & has_bit_indices)428 uint32 GenChunkMask(const std::vector<const FieldDescriptor*>& fields,
429                     const std::vector<int>& has_bit_indices) {
430   GOOGLE_CHECK(!fields.empty());
431   int first_index_offset = has_bit_indices[fields.front()->index()] / 32;
432   uint32 chunk_mask = 0;
433   for (auto field : fields) {
434     // "index" defines where in the _has_bits_ the field appears.
435     int index = has_bit_indices[field->index()];
436     GOOGLE_CHECK_EQ(first_index_offset, index / 32);
437     chunk_mask |= static_cast<uint32>(1) << (index % 32);
438   }
439   GOOGLE_CHECK_NE(0, chunk_mask);
440   return chunk_mask;
441 }
442 
443 // Return the number of bits set in n, a non-negative integer.
popcnt(uint32 n)444 static int popcnt(uint32 n) {
445   int result = 0;
446   while (n != 0) {
447     result += (n & 1);
448     n = n / 2;
449   }
450   return result;
451 }
452 
453 // For a run of cold chunks, opens and closes an external if statement that
454 // checks multiple has_bits words to skip bulk of cold fields.
455 class ColdChunkSkipper {
456  public:
ColdChunkSkipper(const Options & options,const std::vector<std::vector<const FieldDescriptor * >> & chunks,const std::vector<int> & has_bit_indices,const double cold_threshold)457   ColdChunkSkipper(
458       const Options& options,
459       const std::vector<std::vector<const FieldDescriptor*>>& chunks,
460       const std::vector<int>& has_bit_indices, const double cold_threshold)
461       : chunks_(chunks),
462         has_bit_indices_(has_bit_indices),
463         access_info_map_(options.access_info_map),
464         cold_threshold_(cold_threshold) {
465     SetCommonVars(options, &variables_);
466   }
467 
468   // May open an external if check for a batch of cold fields. "from" is the
469   // prefix to _has_bits_ to allow MergeFrom to use "from._has_bits_".
470   // Otherwise, it should be "".
471   void OnStartChunk(int chunk, int cached_has_word_index,
472                     const std::string& from, io::Printer* printer);
473   bool OnEndChunk(int chunk, io::Printer* printer);
474 
475  private:
476   bool IsColdChunk(int chunk);
477 
HasbitWord(int chunk,int offset)478   int HasbitWord(int chunk, int offset) {
479     return has_bit_indices_[chunks_[chunk][offset]->index()] / 32;
480   }
481 
482   const std::vector<std::vector<const FieldDescriptor*>>& chunks_;
483   const std::vector<int>& has_bit_indices_;
484   const AccessInfoMap* access_info_map_;
485   const double cold_threshold_;
486   std::map<std::string, std::string> variables_;
487   int limit_chunk_ = -1;
488 };
489 
// Tuning parameter for ColdChunkSkipper.
// NOTE(review): presumably passed as its cold_threshold argument -- confirm
// against the call sites.
constexpr double kColdRatio = 0.005;
492 
IsColdChunk(int chunk)493 bool ColdChunkSkipper::IsColdChunk(int chunk) {
494   // Mark this variable as used until it is actually used
495   (void)cold_threshold_;
496   return false;
497 }
498 
499 
OnStartChunk(int chunk,int cached_has_word_index,const std::string & from,io::Printer * printer)500 void ColdChunkSkipper::OnStartChunk(int chunk, int cached_has_word_index,
501                                     const std::string& from,
502                                     io::Printer* printer) {
503   Formatter format(printer, variables_);
504   if (!access_info_map_) {
505     return;
506   } else if (chunk < limit_chunk_) {
507     // We are already inside a run of cold chunks.
508     return;
509   } else if (!IsColdChunk(chunk)) {
510     // We can't start a run of cold chunks.
511     return;
512   }
513 
514   // Find the end of consecutive cold chunks.
515   limit_chunk_ = chunk;
516   while (limit_chunk_ < chunks_.size() && IsColdChunk(limit_chunk_)) {
517     limit_chunk_++;
518   }
519 
520   if (limit_chunk_ <= chunk + 1) {
521     // Require at least two chunks to emit external has_bit checks.
522     limit_chunk_ = -1;
523     return;
524   }
525 
526   // Emit has_bit check for each has_bit_dword index.
527   format("if (PROTOBUF_PREDICT_FALSE(");
528   int first_word = HasbitWord(chunk, 0);
529   while (chunk < limit_chunk_) {
530     uint32 mask = 0;
531     int this_word = HasbitWord(chunk, 0);
532     // Generate mask for chunks on the same word.
533     for (; chunk < limit_chunk_ && HasbitWord(chunk, 0) == this_word; chunk++) {
534       for (auto field : chunks_[chunk]) {
535         int hasbit_index = has_bit_indices_[field->index()];
536         // Fields on a chunk must be in the same word.
537         GOOGLE_CHECK_EQ(this_word, hasbit_index / 32);
538         mask |= 1 << (hasbit_index % 32);
539       }
540     }
541 
542     if (this_word != first_word) {
543       format(" ||\n    ");
544     }
545     format.Set("mask", strings::Hex(mask, strings::ZERO_PAD_8));
546     if (this_word == cached_has_word_index) {
547       format("(cached_has_bits & 0x$mask$u) != 0");
548     } else {
549       format("($1$_has_bits_[$2$] & 0x$mask$u) != 0", from, this_word);
550     }
551   }
552   format(")) {\n");
553   format.Indent();
554 }
555 
OnEndChunk(int chunk,io::Printer * printer)556 bool ColdChunkSkipper::OnEndChunk(int chunk, io::Printer* printer) {
557   Formatter format(printer, variables_);
558   if (chunk != limit_chunk_ - 1) {
559     return false;
560   }
561   format.Outdent();
562   format("}\n");
563   return true;
564 }
565 
566 }  // anonymous namespace
567 
568 // ===================================================================
569 
MessageGenerator(const Descriptor * descriptor,const std::map<std::string,std::string> & vars,int index_in_file_messages,const Options & options,MessageSCCAnalyzer * scc_analyzer)570 MessageGenerator::MessageGenerator(
571     const Descriptor* descriptor,
572     const std::map<std::string, std::string>& vars, int index_in_file_messages,
573     const Options& options, MessageSCCAnalyzer* scc_analyzer)
574     : descriptor_(descriptor),
575       index_in_file_messages_(index_in_file_messages),
576       classname_(ClassName(descriptor, false)),
577       options_(options),
578       field_generators_(descriptor, options, scc_analyzer),
579       max_has_bit_index_(0),
580       num_weak_fields_(0),
581       scc_analyzer_(scc_analyzer),
582       variables_(vars) {
583   if (!message_layout_helper_) {
584     message_layout_helper_.reset(new PaddingOptimizer());
585   }
586 
587   // Variables that apply to this class
588   variables_["classname"] = classname_;
589   variables_["classtype"] = QualifiedClassName(descriptor_, options);
590   variables_["scc_info"] =
591       SccInfoSymbol(scc_analyzer_->GetSCC(descriptor_), options_);
592   variables_["full_name"] = descriptor_->full_name();
593   variables_["superclass"] = SuperClassName(descriptor_, options_);
594 
595   // Compute optimized field order to be used for layout and initialization
596   // purposes.
597   for (auto field : FieldRange(descriptor_)) {
598     if (IsFieldStripped(field, options_)) {
599       continue;
600     }
601 
602     if (IsWeak(field, options_)) {
603       num_weak_fields_++;
604     } else if (!field->real_containing_oneof()) {
605       optimized_order_.push_back(field);
606     }
607   }
608 
609   message_layout_helper_->OptimizeLayout(&optimized_order_, options_);
610 
611   // This message has hasbits iff one or more fields need one.
612   for (auto field : optimized_order_) {
613     if (HasHasbit(field)) {
614       if (has_bit_indices_.empty()) {
615         has_bit_indices_.resize(descriptor_->field_count(), kNoHasbit);
616       }
617       has_bit_indices_[field->index()] = max_has_bit_index_++;
618     }
619   }
620 
621   if (!has_bit_indices_.empty()) {
622     field_generators_.SetHasBitIndices(has_bit_indices_);
623   }
624 
625   num_required_fields_ = 0;
626   for (int i = 0; i < descriptor->field_count(); i++) {
627     if (descriptor->field(i)->is_required()) {
628       ++num_required_fields_;
629     }
630   }
631 
632   table_driven_ = TableDrivenParsingEnabled(descriptor_, options_);
633 }
634 
// The destructor is defaulted here, in the implementation file, rather than
// inside the class definition.
635 MessageGenerator::~MessageGenerator() = default;
636 
HasBitsSize() const637 size_t MessageGenerator::HasBitsSize() const {
638   return (max_has_bit_index_ + 31) / 32;
639 }
640 
HasBitIndex(const FieldDescriptor * field) const641 int MessageGenerator::HasBitIndex(const FieldDescriptor* field) const {
642   return has_bit_indices_.empty() ? kNoHasbit
643                                   : has_bit_indices_[field->index()];
644 }
645 
HasByteIndex(const FieldDescriptor * field) const646 int MessageGenerator::HasByteIndex(const FieldDescriptor* field) const {
647   int hasbit = HasBitIndex(field);
648   return hasbit == kNoHasbit ? kNoHasbit : hasbit / 8;
649 }
650 
HasWordIndex(const FieldDescriptor * field) const651 int MessageGenerator::HasWordIndex(const FieldDescriptor* field) const {
652   int hasbit = HasBitIndex(field);
653   return hasbit == kNoHasbit ? kNoHasbit : hasbit / 32;
654 }
655 
AddGenerators(std::vector<std::unique_ptr<EnumGenerator>> * enum_generators,std::vector<std::unique_ptr<ExtensionGenerator>> * extension_generators)656 void MessageGenerator::AddGenerators(
657     std::vector<std::unique_ptr<EnumGenerator>>* enum_generators,
658     std::vector<std::unique_ptr<ExtensionGenerator>>* extension_generators) {
659   for (int i = 0; i < descriptor_->enum_type_count(); i++) {
660     enum_generators->emplace_back(
661         new EnumGenerator(descriptor_->enum_type(i), variables_, options_));
662     enum_generators_.push_back(enum_generators->back().get());
663   }
664   for (int i = 0; i < descriptor_->extension_count(); i++) {
665     extension_generators->emplace_back(
666         new ExtensionGenerator(descriptor_->extension(i), options_));
667     extension_generators_.push_back(extension_generators->back().get());
668   }
669 }
670 
GenerateFieldAccessorDeclarations(io::Printer * printer)671 void MessageGenerator::GenerateFieldAccessorDeclarations(io::Printer* printer) {
672   Formatter format(printer, variables_);
673   // optimized_fields_ does not contain fields where
674   //    field->real_containing_oneof()
675   // so we need to iterate over those as well.
676   //
677   // We place the non-oneof fields in optimized_order_, as that controls the
678   // order of the _has_bits_ entries and we want GDB's pretty printers to be
679   // able to infer these indices from the k[FIELDNAME]FieldNumber order.
680   std::vector<const FieldDescriptor*> ordered_fields;
681   ordered_fields.reserve(descriptor_->field_count());
682 
683   ordered_fields.insert(ordered_fields.begin(), optimized_order_.begin(),
684                         optimized_order_.end());
685   for (auto field : FieldRange(descriptor_)) {
686     if (!field->real_containing_oneof() && !field->options().weak() &&
687         !IsFieldStripped(field, options_)) {
688       continue;
689     }
690     ordered_fields.push_back(field);
691   }
692 
693   if (!ordered_fields.empty()) {
694     format("enum : int {\n");
695     for (auto field : ordered_fields) {
696       Formatter::SaveState save(&format);
697 
698       std::map<std::string, std::string> vars;
699       SetCommonFieldVariables(field, &vars, options_);
700       format.AddMap(vars);
701       format("  ${1$$2$$}$ = $number$,\n", field, FieldConstantName(field));
702     }
703     format("};\n");
704   }
705   for (auto field : ordered_fields) {
706     PrintFieldComment(format, field);
707 
708     Formatter::SaveState save(&format);
709 
710     std::map<std::string, std::string> vars;
711     SetCommonFieldVariables(field, &vars, options_);
712     format.AddMap(vars);
713 
714     if (field->is_repeated()) {
715       format("$deprecated_attr$int ${1$$name$_size$}$() const$2$\n", field,
716              !IsFieldStripped(field, options_) ? ";" : " {__builtin_trap();}");
717       if (!IsFieldStripped(field, options_)) {
718         format(
719             "private:\n"
720             "int ${1$_internal_$name$_size$}$() const;\n"
721             "public:\n",
722             field);
723       }
724     } else if (HasHasMethod(field)) {
725       format("$deprecated_attr$bool ${1$has_$name$$}$() const$2$\n", field,
726              !IsFieldStripped(field, options_) ? ";" : " {__builtin_trap();}");
727       if (!IsFieldStripped(field, options_)) {
728         format(
729             "private:\n"
730             "bool _internal_has_$name$() const;\n"
731             "public:\n");
732       }
733     } else if (HasPrivateHasMethod(field)) {
734       if (!IsFieldStripped(field, options_)) {
735         format(
736             "private:\n"
737             "bool ${1$_internal_has_$name$$}$() const;\n"
738             "public:\n",
739             field);
740       }
741     }
742     format("$deprecated_attr$void ${1$clear_$name$$}$()$2$\n", field,
743            !IsFieldStripped(field, options_) ? ";" : "{__builtin_trap();}");
744 
745     // Generate type-specific accessor declarations.
746     field_generators_.get(field).GenerateAccessorDeclarations(printer);
747 
748     format("\n");
749   }
750 
751   if (descriptor_->extension_range_count() > 0) {
752     // Generate accessors for extensions.  We just call a macro located in
753     // extension_set.h since the accessors about 80 lines of static code.
754     format("$GOOGLE_PROTOBUF$_EXTENSION_ACCESSORS($classname$)\n");
755     // Generate MessageSet specific APIs for proto2 MessageSet.
756     // For testing purposes we don't check for bridge.MessageSet, so
757     // we don't use IsProto2MessageSet
758     if (descriptor_->options().message_set_wire_format() &&
759         !options_.opensource_runtime && !options_.lite_implicit_weak_fields) {
760       // Special-case MessageSet
761       format("GOOGLE_PROTOBUF_EXTENSION_MESSAGE_SET_ACCESSORS($classname$)\n");
762     }
763   }
764 
765   for (auto oneof : OneOfRange(descriptor_)) {
766     Formatter::SaveState saver(&format);
767     format.Set("oneof_name", oneof->name());
768     format.Set("camel_oneof_name", UnderscoresToCamelCase(oneof->name(), true));
769     format(
770         "void ${1$clear_$oneof_name$$}$();\n"
771         "$camel_oneof_name$Case $oneof_name$_case() const;\n",
772         oneof);
773   }
774 }
775 
GenerateSingularFieldHasBits(const FieldDescriptor * field,Formatter format)776 void MessageGenerator::GenerateSingularFieldHasBits(
777     const FieldDescriptor* field, Formatter format) {
778   if (IsFieldStripped(field, options_)) {
779     format(
780         "inline bool $classname$::has_$name$() const { "
781         "__builtin_trap(); }\n");
782     return;
783   }
784   if (field->options().weak()) {
785     format(
786         "inline bool $classname$::has_$name$() const {\n"
787         "$annotate_accessor$"
788         "  return _weak_field_map_.Has($number$);\n"
789         "}\n");
790     return;
791   }
792   if (HasHasbit(field)) {
793     int has_bit_index = HasBitIndex(field);
794     GOOGLE_CHECK_NE(has_bit_index, kNoHasbit);
795 
796     format.Set("has_array_index", has_bit_index / 32);
797     format.Set("has_mask",
798                strings::Hex(1u << (has_bit_index % 32), strings::ZERO_PAD_8));
799     format(
800         "inline bool $classname$::_internal_has_$name$() const {\n"
801         "  bool value = "
802         "(_has_bits_[$has_array_index$] & 0x$has_mask$u) != 0;\n");
803 
804     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
805         !IsLazy(field, options_)) {
806       // We maintain the invariant that for a submessage x, has_x() returning
807       // true implies that x_ is not null. By giving this information to the
808       // compiler, we allow it to eliminate unnecessary null checks later on.
809       format("  PROTOBUF_ASSUME(!value || $name$_ != nullptr);\n");
810     }
811 
812     format(
813         "  return value;\n"
814         "}\n"
815         "inline bool $classname$::has_$name$() const {\n"
816         "$annotate_accessor$"
817         "  return _internal_has_$name$();\n"
818         "}\n");
819   } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
820     // Message fields have a has_$name$() method.
821     if (IsLazy(field, options_)) {
822       format(
823           "inline bool $classname$::_internal_has_$name$() const {\n"
824           "  return !$name$_.IsCleared();\n"
825           "}\n");
826     } else {
827       format(
828           "inline bool $classname$::_internal_has_$name$() const {\n"
829           "  return this != internal_default_instance() "
830           "&& $name$_ != nullptr;\n"
831           "}\n");
832     }
833     format(
834         "inline bool $classname$::has_$name$() const {\n"
835         "$annotate_accessor$"
836         "  return _internal_has_$name$();\n"
837         "}\n");
838   }
839 }
840 
GenerateOneofHasBits(io::Printer * printer)841 void MessageGenerator::GenerateOneofHasBits(io::Printer* printer) {
842   Formatter format(printer, variables_);
843   for (auto oneof : OneOfRange(descriptor_)) {
844     format.Set("oneof_name", oneof->name());
845     format.Set("oneof_index", oneof->index());
846     format.Set("cap_oneof_name", ToUpper(oneof->name()));
847     format(
848         "inline bool $classname$::has_$oneof_name$() const {\n"
849         "  return $oneof_name$_case() != $cap_oneof_name$_NOT_SET;\n"
850         "}\n"
851         "inline void $classname$::clear_has_$oneof_name$() {\n"
852         "  _oneof_case_[$oneof_index$] = $cap_oneof_name$_NOT_SET;\n"
853         "}\n");
854   }
855 }
856 
GenerateOneofMemberHasBits(const FieldDescriptor * field,const Formatter & format)857 void MessageGenerator::GenerateOneofMemberHasBits(const FieldDescriptor* field,
858                                                   const Formatter& format) {
859   if (IsFieldStripped(field, options_)) {
860     if (HasHasMethod(field)) {
861       format(
862           "inline bool $classname$::has_$name$() const { "
863           "__builtin_trap(); }\n");
864     }
865     format(
866         "inline void $classname$::set_has_$name$() { __builtin_trap(); "
867         "}\n");
868     return;
869   }
870   // Singular field in a oneof
871   // N.B.: Without field presence, we do not use has-bits or generate
872   // has_$name$() methods, but oneofs still have set_has_$name$().
873   // Oneofs also have has_$name$() but only as a private helper
874   // method, so that generated code is slightly cleaner (vs.  comparing
875   // _oneof_case_[index] against a constant everywhere).
876   //
877   // If has_$name$() is private, there is no need to add an internal accessor.
878   // Only annotate public accessors.
879   if (HasHasMethod(field)) {
880     format(
881         "inline bool $classname$::_internal_has_$name$() const {\n"
882         "  return $oneof_name$_case() == k$field_name$;\n"
883         "}\n"
884         "inline bool $classname$::has_$name$() const {\n"
885         "$annotate_accessor$"
886         "  return _internal_has_$name$();\n"
887         "}\n");
888   } else if (HasPrivateHasMethod(field)) {
889     format(
890         "inline bool $classname$::_internal_has_$name$() const {\n"
891         "  return $oneof_name$_case() == k$field_name$;\n"
892         "}\n");
893   }
894   // set_has_$name$() for oneof fields is always private; hence should not be
895   // annotated.
896   format(
897       "inline void $classname$::set_has_$name$() {\n"
898       "  _oneof_case_[$oneof_index$] = k$field_name$;\n"
899       "}\n");
900 }
901 
GenerateFieldClear(const FieldDescriptor * field,bool is_inline,Formatter format)902 void MessageGenerator::GenerateFieldClear(const FieldDescriptor* field,
903                                           bool is_inline, Formatter format) {
904   if (IsFieldStripped(field, options_)) {
905     format("void $classname$::clear_$name$() { __builtin_trap(); }\n");
906     return;
907   }
908 
909   // Generate clear_$name$().
910   if (is_inline) {
911     format("inline ");
912   }
913   format(
914       "void $classname$::clear_$name$() {\n"
915       "$annotate_accessor$");
916 
917   format.Indent();
918 
919   if (field->real_containing_oneof()) {
920     // Clear this field only if it is the active field in this oneof,
921     // otherwise ignore
922     format("if (_internal_has_$name$()) {\n");
923     format.Indent();
924     field_generators_.get(field).GenerateClearingCode(format.printer());
925     format("clear_has_$oneof_name$();\n");
926     format.Outdent();
927     format("}\n");
928   } else {
929     field_generators_.get(field).GenerateClearingCode(format.printer());
930     if (HasHasbit(field)) {
931       int has_bit_index = HasBitIndex(field);
932       format.Set("has_array_index", has_bit_index / 32);
933       format.Set("has_mask",
934                  strings::Hex(1u << (has_bit_index % 32), strings::ZERO_PAD_8));
935       format("_has_bits_[$has_array_index$] &= ~0x$has_mask$u;\n");
936     }
937   }
938 
939   format.Outdent();
940   format("}\n");
941 }
942 
GenerateFieldAccessorDefinitions(io::Printer * printer)943 void MessageGenerator::GenerateFieldAccessorDefinitions(io::Printer* printer) {
944   Formatter format(printer, variables_);
945   format("// $classname$\n\n");
946 
947   for (auto field : FieldRange(descriptor_)) {
948     PrintFieldComment(format, field);
949 
950     if (IsFieldStripped(field, options_)) {
951       continue;
952     }
953 
954     std::map<std::string, std::string> vars;
955     SetCommonFieldVariables(field, &vars, options_);
956 
957     Formatter::SaveState saver(&format);
958     format.AddMap(vars);
959 
960     // Generate has_$name$() or $name$_size().
961     if (field->is_repeated()) {
962       if (IsFieldStripped(field, options_)) {
963         format(
964             "inline int $classname$::$name$_size() const { "
965             "__builtin_trap(); }\n");
966       } else {
967         format(
968             "inline int $classname$::_internal_$name$_size() const {\n"
969             "  return $name$_$1$.size();\n"
970             "}\n"
971             "inline int $classname$::$name$_size() const {\n"
972             "$annotate_accessor$"
973             "  return _internal_$name$_size();\n"
974             "}\n",
975             IsImplicitWeakField(field, options_, scc_analyzer_) &&
976                     field->message_type()
977                 ? ".weak"
978                 : "");
979       }
980     } else if (field->real_containing_oneof()) {
981       format.Set("field_name", UnderscoresToCamelCase(field->name(), true));
982       format.Set("oneof_name", field->containing_oneof()->name());
983       format.Set("oneof_index",
984                  StrCat(field->containing_oneof()->index()));
985       GenerateOneofMemberHasBits(field, format);
986     } else {
987       // Singular field.
988       GenerateSingularFieldHasBits(field, format);
989     }
990 
991     if (!IsCrossFileMaybeMap(field)) {
992       GenerateFieldClear(field, true, format);
993     }
994 
995     // Generate type-specific accessors.
996     if (!IsFieldStripped(field, options_)) {
997       field_generators_.get(field).GenerateInlineAccessorDefinitions(printer);
998     }
999 
1000     format("\n");
1001   }
1002 
1003   // Generate has_$name$() and clear_has_$name$() functions for oneofs.
1004   GenerateOneofHasBits(printer);
1005 }
1006 
GenerateClassDefinition(io::Printer * printer)1007 void MessageGenerator::GenerateClassDefinition(io::Printer* printer) {
1008   Formatter format(printer, variables_);
1009   format.Set("class_final", ShouldMarkClassAsFinal(descriptor_, options_)
1010                                 ? "PROTOBUF_FINAL"
1011                                 : "");
1012 
1013   if (IsMapEntryMessage(descriptor_)) {
1014     std::map<std::string, std::string> vars;
1015     CollectMapInfo(options_, descriptor_, &vars);
1016     vars["lite"] =
1017         HasDescriptorMethods(descriptor_->file(), options_) ? "" : "Lite";
1018     format.AddMap(vars);
1019     format(
1020         "class $classname$ : public "
1021         "::$proto_ns$::internal::MapEntry$lite$<$classname$, \n"
1022         "    $key_cpp$, $val_cpp$,\n"
1023         "    ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
1024         "    ::$proto_ns$::internal::WireFormatLite::$val_wire_type$> {\n"
1025         "public:\n"
1026         "  typedef ::$proto_ns$::internal::MapEntry$lite$<$classname$, \n"
1027         "    $key_cpp$, $val_cpp$,\n"
1028         "    ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
1029         "    ::$proto_ns$::internal::WireFormatLite::$val_wire_type$> "
1030         "SuperType;\n"
1031         "  $classname$();\n"
1032         "  explicit $classname$(::$proto_ns$::Arena* arena);\n"
1033         "  void MergeFrom(const $classname$& other);\n"
1034         "  static const $classname$* internal_default_instance() { return "
1035         "reinterpret_cast<const "
1036         "$classname$*>(&_$classname$_default_instance_); }\n");
1037     auto utf8_check = GetUtf8CheckMode(descriptor_->field(0), options_);
1038     if (descriptor_->field(0)->type() == FieldDescriptor::TYPE_STRING &&
1039         utf8_check != NONE) {
1040       if (utf8_check == STRICT) {
1041         format(
1042             "  static bool ValidateKey(std::string* s) {\n"
1043             "    return ::$proto_ns$::internal::WireFormatLite::"
1044             "VerifyUtf8String(s->data(), static_cast<int>(s->size()), "
1045             "::$proto_ns$::internal::WireFormatLite::PARSE, \"$1$\");\n"
1046             " }\n",
1047             descriptor_->field(0)->full_name());
1048       } else {
1049         GOOGLE_CHECK(utf8_check == VERIFY);
1050         format(
1051             "  static bool ValidateKey(std::string* s) {\n"
1052             "#ifndef NDEBUG\n"
1053             "    ::$proto_ns$::internal::WireFormatLite::VerifyUtf8String(\n"
1054             "       s->data(), static_cast<int>(s->size()), "
1055             "::$proto_ns$::internal::"
1056             "WireFormatLite::PARSE, \"$1$\");\n"
1057             "#else\n"
1058             "    (void) s;\n"
1059             "#endif\n"
1060             "    return true;\n"
1061             " }\n",
1062             descriptor_->field(0)->full_name());
1063       }
1064     } else {
1065       format("  static bool ValidateKey(void*) { return true; }\n");
1066     }
1067     if (descriptor_->field(1)->type() == FieldDescriptor::TYPE_STRING &&
1068         utf8_check != NONE) {
1069       if (utf8_check == STRICT) {
1070         format(
1071             "  static bool ValidateValue(std::string* s) {\n"
1072             "    return ::$proto_ns$::internal::WireFormatLite::"
1073             "VerifyUtf8String(s->data(), static_cast<int>(s->size()), "
1074             "::$proto_ns$::internal::WireFormatLite::PARSE, \"$1$\");\n"
1075             " }\n",
1076             descriptor_->field(1)->full_name());
1077       } else {
1078         GOOGLE_CHECK(utf8_check = VERIFY);
1079         format(
1080             "  static bool ValidateValue(std::string* s) {\n"
1081             "#ifndef NDEBUG\n"
1082             "    ::$proto_ns$::internal::WireFormatLite::VerifyUtf8String(\n"
1083             "       s->data(), static_cast<int>(s->size()), "
1084             "::$proto_ns$::internal::"
1085             "WireFormatLite::PARSE, \"$1$\");\n"
1086             "#else\n"
1087             "    (void) s;\n"
1088             "#endif\n"
1089             "    return true;\n"
1090             " }\n",
1091             descriptor_->field(1)->full_name());
1092       }
1093     } else {
1094       format("  static bool ValidateValue(void*) { return true; }\n");
1095     }
1096     if (HasDescriptorMethods(descriptor_->file(), options_)) {
1097       format(
1098           "  void MergeFrom(const ::$proto_ns$::Message& other) final;\n"
1099           "  ::$proto_ns$::Metadata GetMetadata() const final;\n"
1100           "  private:\n"
1101           "  static ::$proto_ns$::Metadata GetMetadataStatic() {\n"
1102           "    ::$proto_ns$::internal::AssignDescriptors(&::$desc_table$);\n"
1103           "    return ::$desc_table$.file_level_metadata[$1$];\n"
1104           "  }\n"
1105           "\n"
1106           "  public:\n"
1107           "};\n",
1108           index_in_file_messages_);
1109     } else {
1110       format("};\n");
1111     }
1112     return;
1113   }
1114 
1115   format(
1116       "class $dllexport_decl $${1$$classname$$}$$ class_final$ :\n"
1117       "    public $superclass$ /* @@protoc_insertion_point("
1118       "class_definition:$full_name$) */ {\n",
1119       descriptor_);
1120   format(" public:\n");
1121   format.Indent();
1122 
1123   format(
1124       "inline $classname$() : $classname$(nullptr) {}\n"
1125       "virtual ~$classname$();\n"
1126       "\n"
1127       "$classname$(const $classname$& from);\n"
1128       "$classname$($classname$&& from) noexcept\n"
1129       "  : $classname$() {\n"
1130       "  *this = ::std::move(from);\n"
1131       "}\n"
1132       "\n"
1133       "inline $classname$& operator=(const $classname$& from) {\n"
1134       "  CopyFrom(from);\n"
1135       "  return *this;\n"
1136       "}\n"
1137       "inline $classname$& operator=($classname$&& from) noexcept {\n"
1138       "  if (GetArena() == from.GetArena()) {\n"
1139       "    if (this != &from) InternalSwap(&from);\n"
1140       "  } else {\n"
1141       "    CopyFrom(from);\n"
1142       "  }\n"
1143       "  return *this;\n"
1144       "}\n"
1145       "\n");
1146 
1147   if (options_.table_driven_serialization) {
1148     format(
1149         "private:\n"
1150         "const void* InternalGetTable() const;\n"
1151         "public:\n"
1152         "\n");
1153   }
1154 
1155   std::map<std::string, std::string> vars;
1156   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
1157   format.AddMap(vars);
1158   if (PublicUnknownFieldsAccessors(descriptor_)) {
1159     format(
1160         "inline const $unknown_fields_type$& unknown_fields() const {\n"
1161         "  return $unknown_fields$;\n"
1162         "}\n"
1163         "inline $unknown_fields_type$* mutable_unknown_fields() {\n"
1164         "  return $mutable_unknown_fields$;\n"
1165         "}\n"
1166         "\n");
1167   }
1168 
1169   // Only generate this member if it's not disabled.
1170   if (HasDescriptorMethods(descriptor_->file(), options_) &&
1171       !descriptor_->options().no_standard_descriptor_accessor()) {
1172     format(
1173         "static const ::$proto_ns$::Descriptor* descriptor() {\n"
1174         "  return GetDescriptor();\n"
1175         "}\n");
1176   }
1177 
1178   if (HasDescriptorMethods(descriptor_->file(), options_)) {
1179     // These shadow non-static methods of the same names in Message.  We
1180     // redefine them here because calls directly on the generated class can be
1181     // statically analyzed -- we know what descriptor types are being requested.
1182     // It also avoids a vtable dispatch.
1183     //
1184     // We would eventually like to eliminate the methods in Message, and having
1185     // this separate also lets us track calls to the base class methods
1186     // separately.
1187     format(
1188         "static const ::$proto_ns$::Descriptor* GetDescriptor() {\n"
1189         "  return GetMetadataStatic().descriptor;\n"
1190         "}\n"
1191         "static const ::$proto_ns$::Reflection* GetReflection() {\n"
1192         "  return GetMetadataStatic().reflection;\n"
1193         "}\n");
1194   }
1195 
1196   format(
1197       "static const $classname$& default_instance();\n"
1198       "\n");
1199 
1200   // Generate enum values for every field in oneofs. One list is generated for
1201   // each oneof with an additional *_NOT_SET value.
1202   for (auto oneof : OneOfRange(descriptor_)) {
1203     format("enum $1$Case {\n", UnderscoresToCamelCase(oneof->name(), true));
1204     format.Indent();
1205     for (auto field : FieldRange(oneof)) {
1206       std::string oneof_enum_case_field_name =
1207           UnderscoresToCamelCase(field->name(), true);
1208       format("k$1$ = $2$,\n", oneof_enum_case_field_name,  // 1
1209              field->number());                             // 2
1210     }
1211     format("$1$_NOT_SET = 0,\n", ToUpper(oneof->name()));
1212     format.Outdent();
1213     format(
1214         "};\n"
1215         "\n");
1216   }
1217 
1218   // TODO(gerbens) make this private, while still granting other protos access.
1219   format(
1220       "static inline const $classname$* internal_default_instance() {\n"
1221       "  return reinterpret_cast<const $classname$*>(\n"
1222       "             &_$classname$_default_instance_);\n"
1223       "}\n"
1224       "static constexpr int kIndexInFileMessages =\n"
1225       "  $1$;\n"
1226       "\n",
1227       index_in_file_messages_);
1228 
1229   if (IsAnyMessage(descriptor_, options_)) {
1230     format(
1231         "// implements Any -----------------------------------------------\n"
1232         "\n");
1233     if (HasDescriptorMethods(descriptor_->file(), options_)) {
1234       format(
1235           "void PackFrom(const ::$proto_ns$::Message& message) {\n"
1236           "  _any_metadata_.PackFrom(message);\n"
1237           "}\n"
1238           "void PackFrom(const ::$proto_ns$::Message& message,\n"
1239           "              ::PROTOBUF_NAMESPACE_ID::ConstStringParam "
1240           "type_url_prefix) {\n"
1241           "  _any_metadata_.PackFrom(message, type_url_prefix);\n"
1242           "}\n"
1243           "bool UnpackTo(::$proto_ns$::Message* message) const {\n"
1244           "  return _any_metadata_.UnpackTo(message);\n"
1245           "}\n"
1246           "static bool GetAnyFieldDescriptors(\n"
1247           "    const ::$proto_ns$::Message& message,\n"
1248           "    const ::$proto_ns$::FieldDescriptor** type_url_field,\n"
1249           "    const ::$proto_ns$::FieldDescriptor** value_field);\n"
1250           "template <typename T, class = typename std::enable_if<"
1251           "!std::is_convertible<T, const ::$proto_ns$::Message&>"
1252           "::value>::type>\n"
1253           "void PackFrom(const T& message) {\n"
1254           "  _any_metadata_.PackFrom<T>(message);\n"
1255           "}\n"
1256           "template <typename T, class = typename std::enable_if<"
1257           "!std::is_convertible<T, const ::$proto_ns$::Message&>"
1258           "::value>::type>\n"
1259           "void PackFrom(const T& message,\n"
1260           "              ::PROTOBUF_NAMESPACE_ID::ConstStringParam "
1261           "type_url_prefix) {\n"
1262           "  _any_metadata_.PackFrom<T>(message, type_url_prefix);"
1263           "}\n"
1264           "template <typename T, class = typename std::enable_if<"
1265           "!std::is_convertible<T, const ::$proto_ns$::Message&>"
1266           "::value>::type>\n"
1267           "bool UnpackTo(T* message) const {\n"
1268           "  return _any_metadata_.UnpackTo<T>(message);\n"
1269           "}\n");
1270     } else {
1271       format(
1272           "template <typename T>\n"
1273           "void PackFrom(const T& message) {\n"
1274           "  _any_metadata_.PackFrom(message);\n"
1275           "}\n"
1276           "template <typename T>\n"
1277           "void PackFrom(const T& message,\n"
1278           "              ::PROTOBUF_NAMESPACE_ID::ConstStringParam "
1279           "type_url_prefix) {\n"
1280           "  _any_metadata_.PackFrom(message, type_url_prefix);\n"
1281           "}\n"
1282           "template <typename T>\n"
1283           "bool UnpackTo(T* message) const {\n"
1284           "  return _any_metadata_.UnpackTo(message);\n"
1285           "}\n");
1286     }
1287     format(
1288         "template<typename T> bool Is() const {\n"
1289         "  return _any_metadata_.Is<T>();\n"
1290         "}\n"
1291         "static bool ParseAnyTypeUrl(::PROTOBUF_NAMESPACE_ID::ConstStringParam "
1292         "type_url,\n"
1293         "                            std::string* full_type_name);\n");
1294   }
1295 
1296   format.Set("new_final",
1297              ShouldMarkNewAsFinal(descriptor_, options_) ? "final" : "");
1298 
1299   format(
1300       "friend void swap($classname$& a, $classname$& b) {\n"
1301       "  a.Swap(&b);\n"
1302       "}\n"
1303       "inline void Swap($classname$* other) {\n"
1304       "  if (other == this) return;\n"
1305       "  if (GetArena() == other->GetArena()) {\n"
1306       "    InternalSwap(other);\n"
1307       "  } else {\n"
1308       "    ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other);\n"
1309       "  }\n"
1310       "}\n"
1311       "void UnsafeArenaSwap($classname$* other) {\n"
1312       "  if (other == this) return;\n"
1313       "  $DCHK$(GetArena() == other->GetArena());\n"
1314       "  InternalSwap(other);\n"
1315       "}\n");
1316 
1317   format(
1318       "\n"
1319       "// implements Message ----------------------------------------------\n"
1320       "\n"
1321       "inline $classname$* New() const$ new_final$ {\n"
1322       "  return CreateMaybeMessage<$classname$>(nullptr);\n"
1323       "}\n"
1324       "\n"
1325       "$classname$* New(::$proto_ns$::Arena* arena) const$ new_final$ {\n"
1326       "  return CreateMaybeMessage<$classname$>(arena);\n"
1327       "}\n");
1328 
1329   // For instances that derive from Message (rather than MessageLite), some
1330   // methods are virtual and should be marked as final.
1331   format.Set("full_final", HasDescriptorMethods(descriptor_->file(), options_)
1332                                ? "final"
1333                                : "");
1334 
1335   if (HasGeneratedMethods(descriptor_->file(), options_)) {
1336     if (HasDescriptorMethods(descriptor_->file(), options_)) {
1337       format(
1338           "void CopyFrom(const ::$proto_ns$::Message& from) final;\n"
1339           "void MergeFrom(const ::$proto_ns$::Message& from) final;\n");
1340     } else {
1341       format(
1342           "void CheckTypeAndMergeFrom(const ::$proto_ns$::MessageLite& from)\n"
1343           "  final;\n");
1344     }
1345 
1346     format.Set("clear_final",
1347                ShouldMarkClearAsFinal(descriptor_, options_) ? "final" : "");
1348     format.Set(
1349         "is_initialized_final",
1350         ShouldMarkIsInitializedAsFinal(descriptor_, options_) ? "final" : "");
1351 
1352     format(
1353         "void CopyFrom(const $classname$& from);\n"
1354         "void MergeFrom(const $classname$& from);\n"
1355         "PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear()$ clear_final$;\n"
1356         "bool IsInitialized() const$ is_initialized_final$;\n"
1357         "\n"
1358         "size_t ByteSizeLong() const final;\n"
1359         "const char* _InternalParse(const char* ptr, "
1360         "::$proto_ns$::internal::ParseContext* ctx) final;\n"
1361         "$uint8$* _InternalSerialize(\n"
1362         "    $uint8$* target, ::$proto_ns$::io::EpsCopyOutputStream* stream) "
1363         "const final;\n");
1364 
1365     // DiscardUnknownFields() is implemented in message.cc using reflections. We
1366     // need to implement this function in generated code for messages.
1367     if (!UseUnknownFieldSet(descriptor_->file(), options_)) {
1368       format("void DiscardUnknownFields()$ full_final$;\n");
1369     }
1370   }
1371 
1372   format(
1373       "int GetCachedSize() const final { return _cached_size_.Get(); }"
1374       "\n\nprivate:\n"
1375       "inline void SharedCtor();\n"
1376       "inline void SharedDtor();\n"
1377       "void SetCachedSize(int size) const$ full_final$;\n"
1378       "void InternalSwap($classname$* other);\n");
1379 
1380   format(
1381       // Friend AnyMetadata so that it can call this FullMessageName() method.
1382       "friend class ::$proto_ns$::internal::AnyMetadata;\n"
1383       "static $1$ FullMessageName() {\n"
1384       "  return \"$full_name$\";\n"
1385       "}\n",
1386       options_.opensource_runtime ? "::PROTOBUF_NAMESPACE_ID::StringPiece"
1387                                   : "::StringPiece");
1388 
1389   format(
1390       // TODO(gerbens) Make this private! Currently people are deriving from
1391       // protos to give access to this constructor, breaking the invariants
1392       // we rely on.
1393       "protected:\n"
1394       "explicit $classname$(::$proto_ns$::Arena* arena);\n"
1395       "private:\n"
1396       "static void ArenaDtor(void* object);\n"
1397       "inline void RegisterArenaDtor(::$proto_ns$::Arena* arena);\n");
1398 
1399   format(
1400       "public:\n"
1401       "\n");
1402 
1403   if (HasDescriptorMethods(descriptor_->file(), options_)) {
1404     format(
1405         "::$proto_ns$::Metadata GetMetadata() const final;\n"
1406         "private:\n"
1407         "static ::$proto_ns$::Metadata GetMetadataStatic() {\n"
1408         "  ::$proto_ns$::internal::AssignDescriptors(&::$desc_table$);\n"
1409         "  return ::$desc_table$.file_level_metadata[kIndexInFileMessages];\n"
1410         "}\n"
1411         "\n"
1412         "public:\n"
1413         "\n");
1414   } else {
1415     format(
1416         "std::string GetTypeName() const final;\n"
1417         "\n");
1418   }
1419 
1420   format(
1421       "// nested types ----------------------------------------------------\n"
1422       "\n");
1423 
1424   // Import all nested message classes into this class's scope with typedefs.
1425   for (int i = 0; i < descriptor_->nested_type_count(); i++) {
1426     const Descriptor* nested_type = descriptor_->nested_type(i);
1427     if (!IsMapEntryMessage(nested_type)) {
1428       format.Set("nested_full_name", ClassName(nested_type, false));
1429       format.Set("nested_name", ResolveKeyword(nested_type->name()));
1430       format("typedef ${1$$nested_full_name$$}$ ${1$$nested_name$$}$;\n",
1431              nested_type);
1432     }
1433   }
1434 
1435   if (descriptor_->nested_type_count() > 0) {
1436     format("\n");
1437   }
1438 
1439   // Import all nested enums and their values into this class's scope with
1440   // typedefs and constants.
1441   for (int i = 0; i < descriptor_->enum_type_count(); i++) {
1442     enum_generators_[i]->GenerateSymbolImports(printer);
1443     format("\n");
1444   }
1445 
1446   format(
1447       "// accessors -------------------------------------------------------\n"
1448       "\n");
1449 
1450   // Generate accessor methods for all fields.
1451   GenerateFieldAccessorDeclarations(printer);
1452 
1453   // Declare extension identifiers.
1454   for (int i = 0; i < descriptor_->extension_count(); i++) {
1455     extension_generators_[i]->GenerateDeclaration(printer);
1456   }
1457 
1458 
1459   format("// @@protoc_insertion_point(class_scope:$full_name$)\n");
1460 
1461   // Generate private members.
1462   format.Outdent();
1463   format(" private:\n");
1464   format.Indent();
1465   // TODO(seongkim): Remove hack to track field access and remove this class.
1466   format("class _Internal;\n");
1467 
1468   for (auto field : FieldRange(descriptor_)) {
1469     // set_has_***() generated in all oneofs.
1470     if (!field->is_repeated() && !field->options().weak() &&
1471         field->real_containing_oneof()) {
1472       format("void set_has_$1$();\n", FieldName(field));
1473     }
1474   }
1475   format("\n");
1476 
1477   // Generate oneof function declarations
1478   for (auto oneof : OneOfRange(descriptor_)) {
1479     format(
1480         "inline bool has_$1$() const;\n"
1481         "inline void clear_has_$1$();\n\n",
1482         oneof->name());
1483   }
1484 
1485   if (HasGeneratedMethods(descriptor_->file(), options_) &&
1486       !descriptor_->options().message_set_wire_format() &&
1487       num_required_fields_ > 1) {
1488     format(
1489         "// helper for ByteSizeLong()\n"
1490         "size_t RequiredFieldsByteSizeFallback() const;\n\n");
1491   }
1492 
1493   // Prepare decls for _cached_size_ and _has_bits_.  Their position in the
1494   // output will be determined later.
1495 
1496   bool need_to_emit_cached_size = true;
1497   const std::string cached_size_decl =
1498       "mutable ::$proto_ns$::internal::CachedSize _cached_size_;\n";
1499 
1500   const size_t sizeof_has_bits = HasBitsSize();
1501   const std::string has_bits_decl =
1502       sizeof_has_bits == 0 ? ""
1503                            : StrCat("::$proto_ns$::internal::HasBits<",
1504                                           sizeof_has_bits, "> _has_bits_;\n");
1505 
1506   // To minimize padding, data members are divided into three sections:
1507   // (1) members assumed to align to 8 bytes
1508   // (2) members corresponding to message fields, re-ordered to optimize
1509   //     alignment.
1510   // (3) members assumed to align to 4 bytes.
1511 
1512   // Members assumed to align to 8 bytes:
1513 
1514   if (descriptor_->extension_range_count() > 0) {
1515     format(
1516         "::$proto_ns$::internal::ExtensionSet _extensions_;\n"
1517         "\n");
1518   }
1519 
1520   format(
1521       "template <typename T> friend class "
1522       "::$proto_ns$::Arena::InternalHelper;\n"
1523       "typedef void InternalArenaConstructable_;\n"
1524       "typedef void DestructorSkippable_;\n");
1525 
1526   if (!has_bit_indices_.empty()) {
1527     // _has_bits_ is frequently accessed, so to reduce code size and improve
1528     // speed, it should be close to the start of the object. Placing
1529     // _cached_size_ together with _has_bits_ improves cache locality despite
1530     // potential alignment padding.
1531     format(has_bits_decl.c_str());
1532     format(cached_size_decl.c_str());
1533     need_to_emit_cached_size = false;
1534   }
1535 
1536   // Field members:
1537 
1538   // Emit some private and static members
1539   for (auto field : optimized_order_) {
1540     const FieldGenerator& generator = field_generators_.get(field);
1541     generator.GenerateStaticMembers(printer);
1542     generator.GeneratePrivateMembers(printer);
1543   }
1544 
1545   // For each oneof generate a union
1546   for (auto oneof : OneOfRange(descriptor_)) {
1547     std::string camel_oneof_name = UnderscoresToCamelCase(oneof->name(), true);
1548     format(
1549         "union $1$Union {\n"
1550         // explicit empty constructor is needed when union contains
1551         // ArenaStringPtr members for string fields.
1552         "  $1$Union() {}\n",
1553         camel_oneof_name);
1554     format.Indent();
1555     for (auto field : FieldRange(oneof)) {
1556       if (!IsFieldStripped(field, options_)) {
1557         field_generators_.get(field).GeneratePrivateMembers(printer);
1558       }
1559     }
1560     format.Outdent();
1561     format("} $1$_;\n", oneof->name());
1562     for (auto field : FieldRange(oneof)) {
1563       if (!IsFieldStripped(field, options_)) {
1564         field_generators_.get(field).GenerateStaticMembers(printer);
1565       }
1566     }
1567   }
1568 
1569   // Members assumed to align to 4 bytes:
1570 
1571   if (need_to_emit_cached_size) {
1572     format(cached_size_decl.c_str());
1573     need_to_emit_cached_size = false;
1574   }
1575 
1576   // Generate _oneof_case_.
1577   if (descriptor_->real_oneof_decl_count() > 0) {
1578     format(
1579         "$uint32$ _oneof_case_[$1$];\n"
1580         "\n",
1581         descriptor_->real_oneof_decl_count());
1582   }
1583 
1584   if (num_weak_fields_) {
1585     format("::$proto_ns$::internal::WeakFieldMap _weak_field_map_;\n");
1586   }
1587   // Generate _any_metadata_ for the Any type.
1588   if (IsAnyMessage(descriptor_, options_)) {
1589     format("::$proto_ns$::internal::AnyMetadata _any_metadata_;\n");
1590   }
1591 
1592   // The TableStruct struct needs access to the private parts, in order to
1593   // construct the offsets of all members.
1594   format("friend struct ::$tablename$;\n");
1595 
1596   format.Outdent();
1597   format("};");
1598   GOOGLE_DCHECK(!need_to_emit_cached_size);
1599 }  // NOLINT(readability/fn_size)
1600 
GenerateInlineMethods(io::Printer * printer)1601 void MessageGenerator::GenerateInlineMethods(io::Printer* printer) {
1602   if (IsMapEntryMessage(descriptor_)) return;
1603   GenerateFieldAccessorDefinitions(printer);
1604 
1605   // Generate oneof_case() functions.
1606   for (auto oneof : OneOfRange(descriptor_)) {
1607     Formatter format(printer, variables_);
1608     format.Set("camel_oneof_name", UnderscoresToCamelCase(oneof->name(), true));
1609     format.Set("oneof_name", oneof->name());
1610     format.Set("oneof_index", oneof->index());
1611     format(
1612         "inline $classname$::$camel_oneof_name$Case $classname$::"
1613         "${1$$oneof_name$_case$}$() const {\n"
1614         "  return $classname$::$camel_oneof_name$Case("
1615         "_oneof_case_[$oneof_index$]);\n"
1616         "}\n",
1617         oneof);
1618   }
1619 }
1620 
GenerateParseTable(io::Printer * printer,size_t offset,size_t aux_offset)1621 bool MessageGenerator::GenerateParseTable(io::Printer* printer, size_t offset,
1622                                           size_t aux_offset) {
1623   Formatter format(printer, variables_);
1624 
1625   if (!table_driven_) {
1626     format("{ nullptr, nullptr, 0, -1, -1, -1, -1, nullptr, false },\n");
1627     return false;
1628   }
1629 
1630   int max_field_number = 0;
1631   for (auto field : FieldRange(descriptor_)) {
1632     if (max_field_number < field->number()) {
1633       max_field_number = field->number();
1634     }
1635   }
1636 
1637   format("{\n");
1638   format.Indent();
1639 
1640   format(
1641       "$tablename$::entries + $1$,\n"
1642       "$tablename$::aux + $2$,\n"
1643       "$3$,\n",
1644       offset, aux_offset, max_field_number);
1645 
1646   if (has_bit_indices_.empty()) {
1647     // If no fields have hasbits, then _has_bits_ does not exist.
1648     format("-1,\n");
1649   } else {
1650     format("PROTOBUF_FIELD_OFFSET($classtype$, _has_bits_),\n");
1651   }
1652 
1653   if (descriptor_->real_oneof_decl_count() > 0) {
1654     format("PROTOBUF_FIELD_OFFSET($classtype$, _oneof_case_),\n");
1655   } else {
1656     format("-1,  // no _oneof_case_\n");
1657   }
1658 
1659   if (descriptor_->extension_range_count() > 0) {
1660     format("PROTOBUF_FIELD_OFFSET($classtype$, _extensions_),\n");
1661   } else {
1662     format("-1,  // no _extensions_\n");
1663   }
1664 
1665   // TODO(ckennelly): Consolidate this with the calculation for
1666   // AuxiliaryParseTableField.
1667   format(
1668       "PROTOBUF_FIELD_OFFSET($classtype$, _internal_metadata_),\n"
1669       "&$package_ns$::_$classname$_default_instance_,\n");
1670 
1671   if (UseUnknownFieldSet(descriptor_->file(), options_)) {
1672     format("true,\n");
1673   } else {
1674     format("false,\n");
1675   }
1676 
1677   format.Outdent();
1678   format("},\n");
1679   return true;
1680 }
1681 
GenerateSchema(io::Printer * printer,int offset,int has_offset)1682 void MessageGenerator::GenerateSchema(io::Printer* printer, int offset,
1683                                       int has_offset) {
1684   Formatter format(printer, variables_);
1685   has_offset = !has_bit_indices_.empty() || IsMapEntryMessage(descriptor_)
1686                    ? offset + has_offset
1687                    : -1;
1688 
1689   format("{ $1$, $2$, sizeof($classtype$)},\n", offset, has_offset);
1690 }
1691 
1692 namespace {
1693 
1694 // We need to calculate for each field what function the table driven code
1695 // should use to serialize it. This returns the index in a lookup table.
CalcFieldNum(const FieldGenerator & generator,const FieldDescriptor * field,const Options & options)1696 uint32 CalcFieldNum(const FieldGenerator& generator,
1697                     const FieldDescriptor* field, const Options& options) {
1698   bool is_a_map = IsMapEntryMessage(field->containing_type());
1699   int type = field->type();
1700   if (type == FieldDescriptor::TYPE_STRING ||
1701       type == FieldDescriptor::TYPE_BYTES) {
1702     // string field
1703     if (IsCord(field, options)) {
1704       type = internal::FieldMetadata::kCordType;
1705     } else if (IsStringPiece(field, options)) {
1706       type = internal::FieldMetadata::kStringPieceType;
1707     }
1708   }
1709 
1710   if (field->real_containing_oneof()) {
1711     return internal::FieldMetadata::CalculateType(
1712         type, internal::FieldMetadata::kOneOf);
1713   } else if (field->is_packed()) {
1714     return internal::FieldMetadata::CalculateType(
1715         type, internal::FieldMetadata::kPacked);
1716   } else if (field->is_repeated()) {
1717     return internal::FieldMetadata::CalculateType(
1718         type, internal::FieldMetadata::kRepeated);
1719   } else if (HasHasbit(field) || field->real_containing_oneof() || is_a_map) {
1720     return internal::FieldMetadata::CalculateType(
1721         type, internal::FieldMetadata::kPresence);
1722   } else {
1723     return internal::FieldMetadata::CalculateType(
1724         type, internal::FieldMetadata::kNoPresence);
1725   }
1726 }
1727 
FindMessageIndexInFile(const Descriptor * descriptor)1728 int FindMessageIndexInFile(const Descriptor* descriptor) {
1729   std::vector<const Descriptor*> flatten =
1730       FlattenMessagesInFile(descriptor->file());
1731   return std::find(flatten.begin(), flatten.end(), descriptor) -
1732          flatten.begin();
1733 }
1734 
1735 }  // namespace
1736 
GenerateFieldMetadata(io::Printer * printer)1737 int MessageGenerator::GenerateFieldMetadata(io::Printer* printer) {
1738   Formatter format(printer, variables_);
1739   if (!options_.table_driven_serialization) {
1740     return 0;
1741   }
1742 
1743   std::vector<const FieldDescriptor*> sorted = SortFieldsByNumber(descriptor_);
1744   if (IsMapEntryMessage(descriptor_)) {
1745     for (int i = 0; i < 2; i++) {
1746       const FieldDescriptor* field = sorted[i];
1747       const FieldGenerator& generator = field_generators_.get(field);
1748 
1749       uint32 tag = internal::WireFormatLite::MakeTag(
1750           field->number(), WireFormat::WireTypeForFieldType(field->type()));
1751 
1752       std::map<std::string, std::string> vars;
1753       vars["classtype"] = QualifiedClassName(descriptor_, options_);
1754       vars["field_name"] = FieldName(field);
1755       vars["tag"] = StrCat(tag);
1756       vars["hasbit"] = StrCat(i);
1757       vars["type"] = StrCat(CalcFieldNum(generator, field, options_));
1758       vars["ptr"] = "nullptr";
1759       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1760         GOOGLE_CHECK(!IsMapEntryMessage(field->message_type()));
1761         vars["ptr"] =
1762             "::" + UniqueName("TableStruct", field->message_type(), options_) +
1763             "::serialization_table + " +
1764             StrCat(FindMessageIndexInFile(field->message_type()));
1765       }
1766       Formatter::SaveState saver(&format);
1767       format.AddMap(vars);
1768       format(
1769           "{PROTOBUF_FIELD_OFFSET("
1770           "::$proto_ns$::internal::MapEntryHelper<$classtype$::"
1771           "SuperType>, $field_name$_), $tag$,"
1772           "PROTOBUF_FIELD_OFFSET("
1773           "::$proto_ns$::internal::MapEntryHelper<$classtype$::"
1774           "SuperType>, _has_bits_) * 8 + $hasbit$, $type$, "
1775           "$ptr$},\n");
1776     }
1777     return 2;
1778   }
1779   format(
1780       "{PROTOBUF_FIELD_OFFSET($classtype$, _cached_size_),"
1781       " 0, 0, 0, nullptr},\n");
1782   std::vector<const Descriptor::ExtensionRange*> sorted_extensions;
1783   sorted_extensions.reserve(descriptor_->extension_range_count());
1784   for (int i = 0; i < descriptor_->extension_range_count(); ++i) {
1785     sorted_extensions.push_back(descriptor_->extension_range(i));
1786   }
1787   std::sort(sorted_extensions.begin(), sorted_extensions.end(),
1788             ExtensionRangeSorter());
1789   for (int i = 0, extension_idx = 0; /* no range */; i++) {
1790     for (; extension_idx < sorted_extensions.size() &&
1791            (i == sorted.size() ||
1792             sorted_extensions[extension_idx]->start < sorted[i]->number());
1793          extension_idx++) {
1794       const Descriptor::ExtensionRange* range =
1795           sorted_extensions[extension_idx];
1796       format(
1797           "{PROTOBUF_FIELD_OFFSET($classtype$, _extensions_), "
1798           "$1$, $2$, ::$proto_ns$::internal::FieldMetadata::kSpecial, "
1799           "reinterpret_cast<const "
1800           "void*>(::$proto_ns$::internal::ExtensionSerializer)},\n",
1801           range->start, range->end);
1802     }
1803     if (i == sorted.size()) break;
1804     const FieldDescriptor* field = sorted[i];
1805 
1806     uint32 tag = internal::WireFormatLite::MakeTag(
1807         field->number(), WireFormat::WireTypeForFieldType(field->type()));
1808     if (field->is_packed()) {
1809       tag = internal::WireFormatLite::MakeTag(
1810           field->number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
1811     }
1812 
1813     std::string classfieldname = FieldName(field);
1814     if (field->real_containing_oneof()) {
1815       classfieldname = field->containing_oneof()->name();
1816     }
1817     format.Set("field_name", classfieldname);
1818     std::string ptr = "nullptr";
1819     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1820       if (IsMapEntryMessage(field->message_type())) {
1821         format(
1822             "{PROTOBUF_FIELD_OFFSET($classtype$, $field_name$_), $1$, $2$, "
1823             "::$proto_ns$::internal::FieldMetadata::kSpecial, "
1824             "reinterpret_cast<const void*>(static_cast< "
1825             "::$proto_ns$::internal::SpecialSerializer>("
1826             "::$proto_ns$::internal::MapFieldSerializer< "
1827             "::$proto_ns$::internal::MapEntryToMapField<"
1828             "$3$>::MapFieldType, "
1829             "$tablename$::serialization_table>))},\n",
1830             tag, FindMessageIndexInFile(field->message_type()),
1831             QualifiedClassName(field->message_type(), options_));
1832         continue;
1833       } else if (!field->message_type()->options().message_set_wire_format()) {
1834         // message_set doesn't have the usual table and we need to
1835         // dispatch to generated serializer, hence ptr stays zero.
1836         ptr =
1837             "::" + UniqueName("TableStruct", field->message_type(), options_) +
1838             "::serialization_table + " +
1839             StrCat(FindMessageIndexInFile(field->message_type()));
1840       }
1841     }
1842 
1843     const FieldGenerator& generator = field_generators_.get(field);
1844     int type = CalcFieldNum(generator, field, options_);
1845 
1846     if (IsLazy(field, options_)) {
1847       type = internal::FieldMetadata::kSpecial;
1848       ptr = "reinterpret_cast<const void*>(::" + variables_["proto_ns"] +
1849             "::internal::LazyFieldSerializer";
1850       if (field->real_containing_oneof()) {
1851         ptr += "OneOf";
1852       } else if (!HasHasbit(field)) {
1853         ptr += "NoPresence";
1854       }
1855       ptr += ")";
1856     }
1857 
1858     if (field->options().weak()) {
1859       // TODO(gerbens) merge weak fields into ranges
1860       format(
1861           "{PROTOBUF_FIELD_OFFSET("
1862           "$classtype$, _weak_field_map_), $1$, $1$, "
1863           "::$proto_ns$::internal::FieldMetadata::kSpecial, "
1864           "reinterpret_cast<const "
1865           "void*>(::$proto_ns$::internal::WeakFieldSerializer)},\n",
1866           tag);
1867     } else if (field->real_containing_oneof()) {
1868       format.Set("oneofoffset",
1869                  sizeof(uint32) * field->containing_oneof()->index());
1870       format(
1871           "{PROTOBUF_FIELD_OFFSET($classtype$, $field_name$_), $1$,"
1872           " PROTOBUF_FIELD_OFFSET($classtype$, _oneof_case_) + "
1873           "$oneofoffset$, $2$, $3$},\n",
1874           tag, type, ptr);
1875     } else if (HasHasbit(field)) {
1876       format.Set("hasbitsoffset", has_bit_indices_[field->index()]);
1877       format(
1878           "{PROTOBUF_FIELD_OFFSET($classtype$, $field_name$_), "
1879           "$1$, PROTOBUF_FIELD_OFFSET($classtype$, _has_bits_) * 8 + "
1880           "$hasbitsoffset$, $2$, $3$},\n",
1881           tag, type, ptr);
1882     } else {
1883       format(
1884           "{PROTOBUF_FIELD_OFFSET($classtype$, $field_name$_), "
1885           "$1$, ~0u, $2$, $3$},\n",
1886           tag, type, ptr);
1887     }
1888   }
1889   int num_field_metadata = 1 + sorted.size() + sorted_extensions.size();
1890   num_field_metadata++;
1891   std::string serializer = UseUnknownFieldSet(descriptor_->file(), options_)
1892                                ? "UnknownFieldSetSerializer"
1893                                : "UnknownFieldSerializerLite";
1894   format(
1895       "{PROTOBUF_FIELD_OFFSET($classtype$, _internal_metadata_), 0, ~0u, "
1896       "::$proto_ns$::internal::FieldMetadata::kSpecial, reinterpret_cast<const "
1897       "void*>(::$proto_ns$::internal::$1$)},\n",
1898       serializer);
1899   return num_field_metadata;
1900 }
1901 
GenerateClassMethods(io::Printer * printer)1902 void MessageGenerator::GenerateClassMethods(io::Printer* printer) {
1903   Formatter format(printer, variables_);
1904   if (IsMapEntryMessage(descriptor_)) {
1905     format(
1906         "$classname$::$classname$() {}\n"
1907         "$classname$::$classname$(::$proto_ns$::Arena* arena)\n"
1908         "    : SuperType(arena) {}\n"
1909         "void $classname$::MergeFrom(const $classname$& other) {\n"
1910         "  MergeFromInternal(other);\n"
1911         "}\n");
1912     if (HasDescriptorMethods(descriptor_->file(), options_)) {
1913       format(
1914           "::$proto_ns$::Metadata $classname$::GetMetadata() const {\n"
1915           "  return GetMetadataStatic();\n"
1916           "}\n");
1917       format(
1918           "void $classname$::MergeFrom(\n"
1919           "    const ::$proto_ns$::Message& other) {\n"
1920           "  ::$proto_ns$::Message::MergeFrom(other);\n"
1921           "}\n"
1922           "\n");
1923     }
1924     return;
1925   }
1926 
1927   if (IsAnyMessage(descriptor_, options_)) {
1928     if (HasDescriptorMethods(descriptor_->file(), options_)) {
1929       format(
1930           "bool $classname$::GetAnyFieldDescriptors(\n"
1931           "    const ::$proto_ns$::Message& message,\n"
1932           "    const ::$proto_ns$::FieldDescriptor** type_url_field,\n"
1933           "    const ::$proto_ns$::FieldDescriptor** value_field) {\n"
1934           "  return ::$proto_ns$::internal::GetAnyFieldDescriptors(\n"
1935           "      message, type_url_field, value_field);\n"
1936           "}\n");
1937     }
1938     format(
1939         "bool $classname$::ParseAnyTypeUrl(\n"
1940         "    ::PROTOBUF_NAMESPACE_ID::ConstStringParam type_url,\n"
1941         "    std::string* full_type_name) {\n"
1942         "  return ::$proto_ns$::internal::ParseAnyTypeUrl(type_url,\n"
1943         "                                             full_type_name);\n"
1944         "}\n"
1945         "\n");
1946   }
1947 
1948   format(
1949       "class $classname$::_Internal {\n"
1950       " public:\n");
1951   format.Indent();
1952   if (!has_bit_indices_.empty()) {
1953     format(
1954         "using HasBits = decltype(std::declval<$classname$>()._has_bits_);\n");
1955   }
1956   for (auto field : FieldRange(descriptor_)) {
1957     field_generators_.get(field).GenerateInternalAccessorDeclarations(printer);
1958     if (IsFieldStripped(field, options_)) {
1959       continue;
1960     }
1961     if (HasHasbit(field)) {
1962       int has_bit_index = HasBitIndex(field);
1963       GOOGLE_CHECK_NE(has_bit_index, kNoHasbit) << field->full_name();
1964       format(
1965           "static void set_has_$1$(HasBits* has_bits) {\n"
1966           "  (*has_bits)[$2$] |= $3$u;\n"
1967           "}\n",
1968           FieldName(field), has_bit_index / 32, (1u << (has_bit_index % 32)));
1969     }
1970   }
1971   if (num_required_fields_ > 0) {
1972     const std::vector<uint32> masks_for_has_bits = RequiredFieldsBitMask();
1973     format(
1974         "static bool MissingRequiredFields(const HasBits& has_bits) "
1975         "{\n"
1976         "  return $1$;\n"
1977         "}\n",
1978         ConditionalToCheckBitmasks(masks_for_has_bits, false, "has_bits"));
1979   }
1980 
1981   format.Outdent();
1982   format("};\n\n");
1983   for (auto field : FieldRange(descriptor_)) {
1984     if (!IsFieldStripped(field, options_)) {
1985       field_generators_.get(field).GenerateInternalAccessorDefinitions(printer);
1986     }
1987   }
1988 
1989   // Generate non-inline field definitions.
1990   for (auto field : FieldRange(descriptor_)) {
1991     if (IsFieldStripped(field, options_)) {
1992       continue;
1993     }
1994     field_generators_.get(field).GenerateNonInlineAccessorDefinitions(printer);
1995     if (IsCrossFileMaybeMap(field)) {
1996       Formatter::SaveState saver(&format);
1997       std::map<std::string, std::string> vars;
1998       SetCommonFieldVariables(field, &vars, options_);
1999       if (field->real_containing_oneof()) {
2000         SetCommonOneofFieldVariables(field, &vars);
2001       }
2002       format.AddMap(vars);
2003       GenerateFieldClear(field, false, format);
2004     }
2005   }
2006 
2007   GenerateStructors(printer);
2008   format("\n");
2009 
2010   if (descriptor_->real_oneof_decl_count() > 0) {
2011     GenerateOneofClear(printer);
2012     format("\n");
2013   }
2014 
2015   if (HasGeneratedMethods(descriptor_->file(), options_)) {
2016     GenerateClear(printer);
2017     format("\n");
2018 
2019     GenerateMergeFromCodedStream(printer);
2020     format("\n");
2021 
2022     GenerateSerializeWithCachedSizesToArray(printer);
2023     format("\n");
2024 
2025     GenerateByteSize(printer);
2026     format("\n");
2027 
2028     GenerateMergeFrom(printer);
2029     format("\n");
2030 
2031     GenerateClassSpecificMergeFrom(printer);
2032     format("\n");
2033 
2034     GenerateCopyFrom(printer);
2035     format("\n");
2036 
2037     GenerateIsInitialized(printer);
2038     format("\n");
2039   }
2040 
2041   GenerateSwap(printer);
2042   format("\n");
2043 
2044   if (options_.table_driven_serialization) {
2045     format(
2046         "const void* $classname$::InternalGetTable() const {\n"
2047         "  return ::$tablename$::serialization_table + $1$;\n"
2048         "}\n"
2049         "\n",
2050         index_in_file_messages_);
2051   }
2052   if (HasDescriptorMethods(descriptor_->file(), options_)) {
2053     format(
2054         "::$proto_ns$::Metadata $classname$::GetMetadata() const {\n"
2055         "  return GetMetadataStatic();\n"
2056         "}\n"
2057         "\n");
2058   } else {
2059     format(
2060         "std::string $classname$::GetTypeName() const {\n"
2061         "  return \"$full_name$\";\n"
2062         "}\n"
2063         "\n");
2064   }
2065 
2066 }
2067 
GenerateParseOffsets(io::Printer * printer)2068 size_t MessageGenerator::GenerateParseOffsets(io::Printer* printer) {
2069   Formatter format(printer, variables_);
2070 
2071   if (!table_driven_) {
2072     return 0;
2073   }
2074 
2075   // Field "0" is special:  We use it in our switch statement of processing
2076   // types to handle the successful end tag case.
2077   format("{0, 0, 0, ::$proto_ns$::internal::kInvalidMask, 0, 0},\n");
2078   int last_field_number = 1;
2079 
2080   std::vector<const FieldDescriptor*> ordered_fields =
2081       SortFieldsByNumber(descriptor_);
2082 
2083   for (auto field : ordered_fields) {
2084     Formatter::SaveState saver(&format);
2085     GOOGLE_CHECK_GE(field->number(), last_field_number);
2086 
2087     for (; last_field_number < field->number(); last_field_number++) {
2088       format(
2089           "{ 0, 0, ::$proto_ns$::internal::kInvalidMask,\n"
2090           "  ::$proto_ns$::internal::kInvalidMask, 0, 0 },\n");
2091     }
2092     last_field_number++;
2093 
2094     unsigned char normal_wiretype, packed_wiretype, processing_type;
2095     normal_wiretype = WireFormat::WireTypeForFieldType(field->type());
2096 
2097     if (field->is_packable()) {
2098       packed_wiretype = WireFormatLite::WIRETYPE_LENGTH_DELIMITED;
2099     } else {
2100       packed_wiretype = internal::kNotPackedMask;
2101     }
2102 
2103     processing_type = static_cast<unsigned>(field->type());
2104     if (field->type() == FieldDescriptor::TYPE_STRING) {
2105       switch (EffectiveStringCType(field, options_)) {
2106         case FieldOptions::STRING:
2107           break;
2108         case FieldOptions::CORD:
2109           processing_type = internal::TYPE_STRING_CORD;
2110           break;
2111         case FieldOptions::STRING_PIECE:
2112           processing_type = internal::TYPE_STRING_STRING_PIECE;
2113           break;
2114       }
2115     } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
2116       switch (EffectiveStringCType(field, options_)) {
2117         case FieldOptions::STRING:
2118           break;
2119         case FieldOptions::CORD:
2120           processing_type = internal::TYPE_BYTES_CORD;
2121           break;
2122         case FieldOptions::STRING_PIECE:
2123           processing_type = internal::TYPE_BYTES_STRING_PIECE;
2124           break;
2125       }
2126     }
2127 
2128     processing_type |= static_cast<unsigned>(
2129         field->is_repeated() ? internal::kRepeatedMask : 0);
2130     processing_type |= static_cast<unsigned>(
2131         field->real_containing_oneof() ? internal::kOneofMask : 0);
2132 
2133     if (field->is_map()) {
2134       processing_type = internal::TYPE_MAP;
2135     }
2136 
2137     const unsigned char tag_size =
2138         WireFormat::TagSize(field->number(), field->type());
2139 
2140     std::map<std::string, std::string> vars;
2141     if (field->real_containing_oneof()) {
2142       vars["name"] = field->containing_oneof()->name();
2143       vars["presence"] = StrCat(field->containing_oneof()->index());
2144     } else {
2145       vars["name"] = FieldName(field);
2146       vars["presence"] = StrCat(has_bit_indices_[field->index()]);
2147     }
2148     vars["nwtype"] = StrCat(normal_wiretype);
2149     vars["pwtype"] = StrCat(packed_wiretype);
2150     vars["ptype"] = StrCat(processing_type);
2151     vars["tag_size"] = StrCat(tag_size);
2152 
2153     format.AddMap(vars);
2154 
2155     format(
2156         "{\n"
2157         "  PROTOBUF_FIELD_OFFSET($classtype$, $name$_),\n"
2158         "  static_cast<$uint32$>($presence$),\n"
2159         "  $nwtype$, $pwtype$, $ptype$, $tag_size$\n"
2160         "},\n");
2161   }
2162 
2163   return last_field_number;
2164 }
2165 
GenerateParseAuxTable(io::Printer * printer)2166 size_t MessageGenerator::GenerateParseAuxTable(io::Printer* printer) {
2167   Formatter format(printer, variables_);
2168 
2169   if (!table_driven_) {
2170     return 0;
2171   }
2172 
2173   std::vector<const FieldDescriptor*> ordered_fields =
2174       SortFieldsByNumber(descriptor_);
2175 
2176   format("::$proto_ns$::internal::AuxiliaryParseTableField(),\n");
2177   int last_field_number = 1;
2178   for (auto field : ordered_fields) {
2179     Formatter::SaveState saver(&format);
2180 
2181     GOOGLE_CHECK_GE(field->number(), last_field_number);
2182     for (; last_field_number < field->number(); last_field_number++) {
2183       format("::$proto_ns$::internal::AuxiliaryParseTableField(),\n");
2184     }
2185 
2186     std::map<std::string, std::string> vars;
2187     SetCommonFieldVariables(field, &vars, options_);
2188     format.AddMap(vars);
2189 
2190     switch (field->cpp_type()) {
2191       case FieldDescriptor::CPPTYPE_ENUM:
2192         if (HasPreservingUnknownEnumSemantics(field)) {
2193           format(
2194               "{::$proto_ns$::internal::AuxiliaryParseTableField::enum_aux{"
2195               "nullptr}},\n");
2196         } else {
2197           format(
2198               "{::$proto_ns$::internal::AuxiliaryParseTableField::enum_aux{"
2199               "$1$_IsValid}},\n",
2200               ClassName(field->enum_type(), true));
2201         }
2202         last_field_number++;
2203         break;
2204       case FieldDescriptor::CPPTYPE_MESSAGE: {
2205         if (field->is_map()) {
2206           format(
2207               "{::$proto_ns$::internal::AuxiliaryParseTableField::map_"
2208               "aux{&::$proto_ns$::internal::ParseMap<$1$>}},\n",
2209               QualifiedClassName(field->message_type(), options_));
2210           last_field_number++;
2211           break;
2212         }
2213         format.Set("field_classname", ClassName(field->message_type(), false));
2214         format.Set("default_instance", QualifiedDefaultInstanceName(
2215                                            field->message_type(), options_));
2216 
2217         format(
2218             "{::$proto_ns$::internal::AuxiliaryParseTableField::message_aux{\n"
2219             "  &$default_instance$}},\n");
2220         last_field_number++;
2221         break;
2222       }
2223       case FieldDescriptor::CPPTYPE_STRING: {
2224         std::string default_val;
2225         switch (EffectiveStringCType(field, options_)) {
2226           case FieldOptions::STRING:
2227             default_val = field->default_value_string().empty()
2228                               ? "&::" + variables_["proto_ns"] +
2229                                     "::internal::fixed_address_empty_string"
2230                               : "&" +
2231                                     QualifiedClassName(descriptor_, options_) +
2232                                     "::" + MakeDefaultName(field);
2233             break;
2234           case FieldOptions::CORD:
2235           case FieldOptions::STRING_PIECE:
2236             default_val =
2237                 "\"" + CEscape(field->default_value_string()) + "\"";
2238             break;
2239         }
2240         format(
2241             "{::$proto_ns$::internal::AuxiliaryParseTableField::string_aux{\n"
2242             "  $1$,\n"
2243             "  \"$2$\"\n"
2244             "}},\n",
2245             default_val, field->full_name());
2246         last_field_number++;
2247         break;
2248       }
2249       default:
2250         break;
2251     }
2252   }
2253 
2254   return last_field_number;
2255 }
2256 
GenerateOffsets(io::Printer * printer)2257 std::pair<size_t, size_t> MessageGenerator::GenerateOffsets(
2258     io::Printer* printer) {
2259   Formatter format(printer, variables_);
2260 
2261   if (!has_bit_indices_.empty() || IsMapEntryMessage(descriptor_)) {
2262     format("PROTOBUF_FIELD_OFFSET($classtype$, _has_bits_),\n");
2263   } else {
2264     format("~0u,  // no _has_bits_\n");
2265   }
2266   format("PROTOBUF_FIELD_OFFSET($classtype$, _internal_metadata_),\n");
2267   if (descriptor_->extension_range_count() > 0) {
2268     format("PROTOBUF_FIELD_OFFSET($classtype$, _extensions_),\n");
2269   } else {
2270     format("~0u,  // no _extensions_\n");
2271   }
2272   if (descriptor_->real_oneof_decl_count() > 0) {
2273     format("PROTOBUF_FIELD_OFFSET($classtype$, _oneof_case_[0]),\n");
2274   } else {
2275     format("~0u,  // no _oneof_case_\n");
2276   }
2277   if (num_weak_fields_ > 0) {
2278     format("PROTOBUF_FIELD_OFFSET($classtype$, _weak_field_map_),\n");
2279   } else {
2280     format("~0u,  // no _weak_field_map_\n");
2281   }
2282   const int kNumGenericOffsets = 5;  // the number of fixed offsets above
2283   const size_t offsets = kNumGenericOffsets + descriptor_->field_count() +
2284                          descriptor_->real_oneof_decl_count();
2285   size_t entries = offsets;
2286   for (auto field : FieldRange(descriptor_)) {
2287     if (IsFieldStripped(field, options_)) {
2288       format("~0u,  // stripped\n");
2289       continue;
2290     }
2291     // TODO(sbenza): We should not have an entry in the offset table for fields
2292     // that do not use them.
2293     if (field->options().weak() || field->real_containing_oneof()) {
2294       // Mark the field to prevent unintentional access through reflection.
2295       // Don't use the top bit because that is for unused fields.
2296       format("::$proto_ns$::internal::kInvalidFieldOffsetTag");
2297     } else {
2298       format("PROTOBUF_FIELD_OFFSET($classtype$, $1$_)", FieldName(field));
2299     }
2300 
2301     if (!IsFieldUsed(field, options_)) {
2302       format(" | 0x80000000u, // unused\n");
2303     } else {
2304       format(",\n");
2305     }
2306   }
2307 
2308   int count = 0;
2309   for (auto oneof : OneOfRange(descriptor_)) {
2310     format("PROTOBUF_FIELD_OFFSET($classtype$, $1$_),\n", oneof->name());
2311     count++;
2312   }
2313   GOOGLE_CHECK_EQ(count, descriptor_->real_oneof_decl_count());
2314 
2315   if (IsMapEntryMessage(descriptor_)) {
2316     entries += 2;
2317     format(
2318         "0,\n"
2319         "1,\n");
2320   } else if (!has_bit_indices_.empty()) {
2321     entries += has_bit_indices_.size();
2322     for (int i = 0; i < has_bit_indices_.size(); i++) {
2323       const std::string index =
2324           has_bit_indices_[i] >= 0 ? StrCat(has_bit_indices_[i]) : "~0u";
2325       format("$1$,\n", index);
2326     }
2327   }
2328 
2329   return std::make_pair(entries, offsets);
2330 }
2331 
GenerateSharedConstructorCode(io::Printer * printer)2332 void MessageGenerator::GenerateSharedConstructorCode(io::Printer* printer) {
2333   Formatter format(printer, variables_);
2334 
2335   format("void $classname$::SharedCtor() {\n");
2336   if (scc_analyzer_->GetSCCAnalysis(scc_analyzer_->GetSCC(descriptor_))
2337           .constructor_requires_initialization) {
2338     format("  ::$proto_ns$::internal::InitSCC(&$scc_info$.base);\n");
2339   }
2340 
2341   format.Indent();
2342 
2343   std::vector<bool> processed(optimized_order_.size(), false);
2344   GenerateConstructorBody(printer, processed, false);
2345 
2346   for (auto oneof : OneOfRange(descriptor_)) {
2347     format("clear_has_$1$();\n", oneof->name());
2348   }
2349 
2350   format.Outdent();
2351   format("}\n\n");
2352 }
2353 
GenerateSharedDestructorCode(io::Printer * printer)2354 void MessageGenerator::GenerateSharedDestructorCode(io::Printer* printer) {
2355   Formatter format(printer, variables_);
2356 
2357   format("void $classname$::SharedDtor() {\n");
2358   format.Indent();
2359   format("$DCHK$(GetArena() == nullptr);\n");
2360   // Write the destructors for each field except oneof members.
2361   // optimized_order_ does not contain oneof fields.
2362   for (auto field : optimized_order_) {
2363     field_generators_.get(field).GenerateDestructorCode(printer);
2364   }
2365 
2366   // Generate code to destruct oneofs. Clearing should do the work.
2367   for (auto oneof : OneOfRange(descriptor_)) {
2368     format(
2369         "if (has_$1$()) {\n"
2370         "  clear_$1$();\n"
2371         "}\n",
2372         oneof->name());
2373   }
2374 
2375   if (num_weak_fields_) {
2376     format("_weak_field_map_.ClearAll();\n");
2377   }
2378   format.Outdent();
2379   format(
2380       "}\n"
2381       "\n");
2382 }
2383 
GenerateArenaDestructorCode(io::Printer * printer)2384 void MessageGenerator::GenerateArenaDestructorCode(io::Printer* printer) {
2385   Formatter format(printer, variables_);
2386 
2387   // Generate the ArenaDtor() method. Track whether any fields actually produced
2388   // code that needs to be called.
2389   format("void $classname$::ArenaDtor(void* object) {\n");
2390   format.Indent();
2391 
2392   // This code is placed inside a static method, rather than an ordinary one,
2393   // since that simplifies Arena's destructor list (ordinary function pointers
2394   // rather than member function pointers). _this is the object being
2395   // destructed.
2396   format(
2397       "$classname$* _this = reinterpret_cast< $classname$* >(object);\n"
2398       // avoid an "unused variable" warning in case no fields have dtor code.
2399       "(void)_this;\n");
2400 
2401   bool need_registration = false;
2402   // Process non-oneof fields first.
2403   for (auto field : optimized_order_) {
2404     if (field_generators_.get(field).GenerateArenaDestructorCode(printer)) {
2405       need_registration = true;
2406     }
2407   }
2408 
2409   // Process oneof fields.
2410   //
2411   // Note:  As of 10/5/2016, GenerateArenaDestructorCode does not emit anything
2412   // and returns false for oneof fields.
2413   for (auto oneof : OneOfRange(descriptor_)) {
2414     for (auto field : FieldRange(oneof)) {
2415       if (!IsFieldStripped(field, options_) &&
2416           field_generators_.get(field).GenerateArenaDestructorCode(printer)) {
2417         need_registration = true;
2418       }
2419     }
2420   }
2421   if (num_weak_fields_) {
2422     // _this is the object being destructed (we are inside a static method
2423     // here).
2424     format("_this->_weak_field_map_.ClearAll();\n");
2425     need_registration = true;
2426   }
2427 
2428   format.Outdent();
2429   format("}\n");
2430 
2431   if (need_registration) {
2432     format(
2433         "inline void $classname$::RegisterArenaDtor(::$proto_ns$::Arena* "
2434         "arena) {\n"
2435         "  if (arena != nullptr) {\n"
2436         "    arena->OwnCustomDestructor(this, &$classname$::ArenaDtor);\n"
2437         "  }\n"
2438         "}\n");
2439   } else {
2440     format(
2441         "void $classname$::RegisterArenaDtor(::$proto_ns$::Arena*) {\n"
2442         "}\n");
2443   }
2444 }
2445 
GenerateConstructorBody(io::Printer * printer,std::vector<bool> processed,bool copy_constructor) const2446 void MessageGenerator::GenerateConstructorBody(io::Printer* printer,
2447                                                std::vector<bool> processed,
2448                                                bool copy_constructor) const {
2449   Formatter format(printer, variables_);
2450 
2451   const RunMap runs = FindRuns(
2452       optimized_order_, [copy_constructor, this](const FieldDescriptor* field) {
2453         return (copy_constructor && IsPOD(field)) ||
2454                (!copy_constructor &&
2455                 CanBeManipulatedAsRawBytes(field, options_));
2456       });
2457 
2458   std::string pod_template;
2459   if (copy_constructor) {
2460     pod_template =
2461         "::memcpy(&$first$_, &from.$first$_,\n"
2462         "  static_cast<size_t>(reinterpret_cast<char*>(&$last$_) -\n"
2463         "  reinterpret_cast<char*>(&$first$_)) + sizeof($last$_));\n";
2464   } else {
2465     pod_template =
2466         "::memset(reinterpret_cast<char*>(this) + static_cast<size_t>(\n"
2467         "    reinterpret_cast<char*>(&$first$_) - reinterpret_cast<char*>(this)),\n"
2468         "    0, static_cast<size_t>(reinterpret_cast<char*>(&$last$_) -\n"
2469         "    reinterpret_cast<char*>(&$first$_)) + sizeof($last$_));\n";
2470   }
2471 
2472   for (int i = 0; i < optimized_order_.size(); ++i) {
2473     if (processed[i]) {
2474       continue;
2475     }
2476 
2477     const FieldDescriptor* field = optimized_order_[i];
2478     const auto it = runs.find(field);
2479 
2480     // We only apply the memset technique to runs of more than one field, as
2481     // assignment is better than memset for generated code clarity.
2482     if (it != runs.end() && it->second > 1) {
2483       // Use a memset, then skip run_length fields.
2484       const size_t run_length = it->second;
2485       const std::string first_field_name = FieldName(field);
2486       const std::string last_field_name =
2487           FieldName(optimized_order_[i + run_length - 1]);
2488 
2489       format.Set("first", first_field_name);
2490       format.Set("last", last_field_name);
2491 
2492       format(pod_template.c_str());
2493 
2494       i += run_length - 1;
2495       // ++i at the top of the loop.
2496     } else {
2497       if (copy_constructor) {
2498         field_generators_.get(field).GenerateCopyConstructorCode(printer);
2499       } else {
2500         field_generators_.get(field).GenerateConstructorCode(printer);
2501       }
2502     }
2503   }
2504 }
2505 
GenerateStructors(io::Printer * printer)2506 void MessageGenerator::GenerateStructors(io::Printer* printer) {
2507   Formatter format(printer, variables_);
2508 
2509   std::string superclass;
2510   superclass = SuperClassName(descriptor_, options_);
2511   std::string initializer_with_arena = superclass + "(arena)";
2512 
2513   if (descriptor_->extension_range_count() > 0) {
2514     initializer_with_arena += ",\n  _extensions_(arena)";
2515   }
2516 
2517   // Initialize member variables with arena constructor.
2518   for (auto field : optimized_order_) {
2519     GOOGLE_DCHECK(!IsFieldStripped(field, options_));
2520     bool has_arena_constructor = field->is_repeated();
2521     if (!field->real_containing_oneof() &&
2522         (IsLazy(field, options_) || IsStringPiece(field, options_))) {
2523       has_arena_constructor = true;
2524     }
2525     if (has_arena_constructor) {
2526       initializer_with_arena +=
2527           std::string(",\n  ") + FieldName(field) + std::string("_(arena)");
2528     }
2529   }
2530 
2531   if (IsAnyMessage(descriptor_, options_)) {
2532     initializer_with_arena += ",\n  _any_metadata_(&type_url_, &value_)";
2533   }
2534   if (num_weak_fields_ > 0) {
2535     initializer_with_arena += ", _weak_field_map_(arena)";
2536   }
2537 
2538   std::string initializer_null = superclass + "()";
2539   if (IsAnyMessage(descriptor_, options_)) {
2540     initializer_null += ", _any_metadata_(&type_url_, &value_)";
2541   }
2542   if (num_weak_fields_ > 0) {
2543     initializer_null += ", _weak_field_map_(nullptr)";
2544   }
2545 
2546   format(
2547       "$classname$::$classname$(::$proto_ns$::Arena* arena)\n"
2548       "  : $1$ {\n"
2549       "  SharedCtor();\n"
2550       "  RegisterArenaDtor(arena);\n"
2551       "  // @@protoc_insertion_point(arena_constructor:$full_name$)\n"
2552       "}\n",
2553       initializer_with_arena);
2554 
2555   std::map<std::string, std::string> vars;
2556   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
2557   format.AddMap(vars);
2558 
2559   // Generate the copy constructor.
2560   if (UsingImplicitWeakFields(descriptor_->file(), options_)) {
2561     // If we are in lite mode and using implicit weak fields, we generate a
2562     // one-liner copy constructor that delegates to MergeFrom. This saves some
2563     // code size and also cuts down on the complexity of implicit weak fields.
2564     // We might eventually want to do this for all lite protos.
2565     format(
2566         "$classname$::$classname$(const $classname$& from)\n"
2567         "  : $classname$() {\n"
2568         "  MergeFrom(from);\n"
2569         "}\n");
2570   } else {
2571     format(
2572         "$classname$::$classname$(const $classname$& from)\n"
2573         "  : $superclass$()");
2574     format.Indent();
2575     format.Indent();
2576     format.Indent();
2577 
2578     if (!has_bit_indices_.empty()) {
2579       format(",\n_has_bits_(from._has_bits_)");
2580     }
2581 
2582     std::vector<bool> processed(optimized_order_.size(), false);
2583     for (int i = 0; i < optimized_order_.size(); i++) {
2584       auto field = optimized_order_[i];
2585       if (!(field->is_repeated() && !(field->is_map())) &&
2586           !IsCord(field, options_)) {
2587         continue;
2588       }
2589 
2590       processed[i] = true;
2591       format(",\n$1$_(from.$1$_)", FieldName(field));
2592     }
2593 
2594     if (IsAnyMessage(descriptor_, options_)) {
2595       format(",\n_any_metadata_(&type_url_, &value_)");
2596     }
2597     if (num_weak_fields_ > 0) {
2598       format(",\n_weak_field_map_(from._weak_field_map_)");
2599     }
2600 
2601     format.Outdent();
2602     format.Outdent();
2603     format(" {\n");
2604 
2605     format(
2606         "_internal_metadata_.MergeFrom<$unknown_fields_type$>(from._internal_"
2607         "metadata_);\n");
2608 
2609     if (descriptor_->extension_range_count() > 0) {
2610       format("_extensions_.MergeFrom(from._extensions_);\n");
2611     }
2612 
2613     GenerateConstructorBody(printer, processed, true);
2614 
2615     // Copy oneof fields. Oneof field requires oneof case check.
2616     for (auto oneof : OneOfRange(descriptor_)) {
2617       format(
2618           "clear_has_$1$();\n"
2619           "switch (from.$1$_case()) {\n",
2620           oneof->name());
2621       format.Indent();
2622       for (auto field : FieldRange(oneof)) {
2623         format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
2624         format.Indent();
2625         if (!IsFieldStripped(field, options_)) {
2626           field_generators_.get(field).GenerateMergingCode(printer);
2627         }
2628         format("break;\n");
2629         format.Outdent();
2630         format("}\n");
2631       }
2632       format(
2633           "case $1$_NOT_SET: {\n"
2634           "  break;\n"
2635           "}\n",
2636           ToUpper(oneof->name()));
2637       format.Outdent();
2638       format("}\n");
2639     }
2640 
2641     format.Outdent();
2642     format(
2643         "  // @@protoc_insertion_point(copy_constructor:$full_name$)\n"
2644         "}\n"
2645         "\n");
2646   }
2647 
2648   // Generate the shared constructor code.
2649   GenerateSharedConstructorCode(printer);
2650 
2651   // Generate the destructor.
2652   format(
2653       "$classname$::~$classname$() {\n"
2654       "  // @@protoc_insertion_point(destructor:$full_name$)\n"
2655       "  SharedDtor();\n"
2656       "  _internal_metadata_.Delete<$unknown_fields_type$>();\n"
2657       "}\n"
2658       "\n");
2659 
2660   // Generate the shared destructor code.
2661   GenerateSharedDestructorCode(printer);
2662 
2663   // Generate the arena-specific destructor code.
2664   GenerateArenaDestructorCode(printer);
2665 
2666   // Generate SetCachedSize.
2667   format(
2668       "void $classname$::SetCachedSize(int size) const {\n"
2669       "  _cached_size_.Set(size);\n"
2670       "}\n");
2671 
2672   format(
2673       "const $classname$& $classname$::default_instance() {\n"
2674       "  "
2675       "::$proto_ns$::internal::InitSCC(&::$scc_info$.base)"
2676       ";\n"
2677       "  return *internal_default_instance();\n"
2678       "}\n\n");
2679 }
2680 
GenerateSourceInProto2Namespace(io::Printer * printer)2681 void MessageGenerator::GenerateSourceInProto2Namespace(io::Printer* printer) {
2682   Formatter format(printer, variables_);
2683   format(
2684       "template<> "
2685       "PROTOBUF_NOINLINE "
2686       "$classtype$* Arena::CreateMaybeMessage< $classtype$ >(Arena* arena) {\n"
2687       "  return Arena::CreateMessageInternal< $classtype$ >(arena);\n"
2688       "}\n");
2689 }
2690 
GenerateClear(io::Printer * printer)2691 void MessageGenerator::GenerateClear(io::Printer* printer) {
2692   Formatter format(printer, variables_);
2693 
2694   // The maximum number of bytes we will memset to zero without checking their
2695   // hasbit to see if a zero-init is necessary.
2696   const int kMaxUnconditionalPrimitiveBytesClear = 4;
2697 
2698   format(
2699       "void $classname$::Clear() {\n"
2700       "// @@protoc_insertion_point(message_clear_start:$full_name$)\n");
2701   format.Indent();
2702 
2703   format(
2704       // TODO(jwb): It would be better to avoid emitting this if it is not used,
2705       // rather than emitting a workaround for the resulting warning.
2706       "$uint32$ cached_has_bits = 0;\n"
2707       "// Prevent compiler warnings about cached_has_bits being unused\n"
2708       "(void) cached_has_bits;\n\n");
2709 
2710   if (descriptor_->extension_range_count() > 0) {
2711     format("_extensions_.Clear();\n");
2712   }
2713 
2714   // Collect fields into chunks. Each chunk may have an if() condition that
2715   // checks all hasbits in the chunk and skips it if none are set.
2716   int zero_init_bytes = 0;
2717   for (const auto& field : optimized_order_) {
2718     if (CanInitializeByZeroing(field)) {
2719       zero_init_bytes += EstimateAlignmentSize(field);
2720     }
2721   }
2722   bool merge_zero_init = zero_init_bytes > kMaxUnconditionalPrimitiveBytesClear;
2723   int chunk_count = 0;
2724 
2725   std::vector<std::vector<const FieldDescriptor*>> chunks = CollectFields(
2726       optimized_order_,
2727       [&](const FieldDescriptor* a, const FieldDescriptor* b) -> bool {
2728         chunk_count++;
2729         // This predicate guarantees that there is only a single zero-init
2730         // (memset) per chunk, and if present it will be at the beginning.
2731         bool same = HasByteIndex(a) == HasByteIndex(b) &&
2732                     a->is_repeated() == b->is_repeated() &&
2733                     (CanInitializeByZeroing(a) == CanInitializeByZeroing(b) ||
2734                      (CanInitializeByZeroing(a) &&
2735                       (chunk_count == 1 || merge_zero_init)));
2736         if (!same) chunk_count = 0;
2737         return same;
2738       });
2739 
2740   ColdChunkSkipper cold_skipper(options_, chunks, has_bit_indices_, kColdRatio);
2741   int cached_has_word_index = -1;
2742 
2743   for (int chunk_index = 0; chunk_index < chunks.size(); chunk_index++) {
2744     std::vector<const FieldDescriptor*>& chunk = chunks[chunk_index];
2745     cold_skipper.OnStartChunk(chunk_index, cached_has_word_index, "", printer);
2746 
2747     const FieldDescriptor* memset_start = nullptr;
2748     const FieldDescriptor* memset_end = nullptr;
2749     bool saw_non_zero_init = false;
2750 
2751     for (const auto& field : chunk) {
2752       if (CanInitializeByZeroing(field)) {
2753         GOOGLE_CHECK(!saw_non_zero_init);
2754         if (!memset_start) memset_start = field;
2755         memset_end = field;
2756       } else {
2757         saw_non_zero_init = true;
2758       }
2759     }
2760 
2761     // Whether we wrap this chunk in:
2762     //   if (cached_has_bits & <chunk hasbits) { /* chunk. */ }
2763     // We can omit the if() for chunk size 1, or if our fields do not have
2764     // hasbits. I don't understand the rationale for the last part of the
2765     // condition, but it matches the old logic.
2766     const bool have_outer_if = HasBitIndex(chunk.front()) != kNoHasbit &&
2767                                chunk.size() > 1 &&
2768                                (memset_end != chunk.back() || merge_zero_init);
2769 
2770     if (have_outer_if) {
2771       // Emit an if() that will let us skip the whole chunk if none are set.
2772       uint32 chunk_mask = GenChunkMask(chunk, has_bit_indices_);
2773       std::string chunk_mask_str =
2774           StrCat(strings::Hex(chunk_mask, strings::ZERO_PAD_8));
2775 
2776       // Check (up to) 8 has_bits at a time if we have more than one field in
2777       // this chunk.  Due to field layout ordering, we may check
2778       // _has_bits_[last_chunk * 8 / 32] multiple times.
2779       GOOGLE_DCHECK_LE(2, popcnt(chunk_mask));
2780       GOOGLE_DCHECK_GE(8, popcnt(chunk_mask));
2781 
2782       if (cached_has_word_index != HasWordIndex(chunk.front())) {
2783         cached_has_word_index = HasWordIndex(chunk.front());
2784         format("cached_has_bits = _has_bits_[$1$];\n", cached_has_word_index);
2785       }
2786       format("if (cached_has_bits & 0x$1$u) {\n", chunk_mask_str);
2787       format.Indent();
2788     }
2789 
2790     if (memset_start) {
2791       if (memset_start == memset_end) {
2792         // For clarity, do not memset a single field.
2793         field_generators_.get(memset_start)
2794             .GenerateMessageClearingCode(printer);
2795       } else {
2796         format(
2797             "::memset(&$1$_, 0, static_cast<size_t>(\n"
2798             "    reinterpret_cast<char*>(&$2$_) -\n"
2799             "    reinterpret_cast<char*>(&$1$_)) + sizeof($2$_));\n",
2800             FieldName(memset_start), FieldName(memset_end));
2801       }
2802     }
2803 
2804     // Clear all non-zero-initializable fields in the chunk.
2805     for (const auto& field : chunk) {
2806       if (CanInitializeByZeroing(field)) continue;
2807       // It's faster to just overwrite primitive types, but we should only
2808       // clear strings and messages if they were set.
2809       //
2810       // TODO(kenton):  Let the CppFieldGenerator decide this somehow.
2811       bool have_enclosing_if =
2812           HasBitIndex(field) != kNoHasbit &&
2813           (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
2814            field->cpp_type() == FieldDescriptor::CPPTYPE_STRING);
2815 
2816       if (have_enclosing_if) {
2817         PrintPresenceCheck(format, field, has_bit_indices_, printer,
2818                            &cached_has_word_index);
2819       }
2820 
2821       field_generators_.get(field).GenerateMessageClearingCode(printer);
2822 
2823       if (have_enclosing_if) {
2824         format.Outdent();
2825         format("}\n");
2826       }
2827     }
2828 
2829     if (have_outer_if) {
2830       format.Outdent();
2831       format("}\n");
2832     }
2833 
2834     if (cold_skipper.OnEndChunk(chunk_index, printer)) {
2835       // Reset here as it may have been updated in just closed if statement.
2836       cached_has_word_index = -1;
2837     }
2838   }
2839 
2840   // Step 4: Unions.
2841   for (auto oneof : OneOfRange(descriptor_)) {
2842     format("clear_$1$();\n", oneof->name());
2843   }
2844 
2845   if (num_weak_fields_) {
2846     format("_weak_field_map_.ClearAll();\n");
2847   }
2848 
2849   if (!has_bit_indices_.empty()) {
2850     // Step 5: Everything else.
2851     format("_has_bits_.Clear();\n");
2852   }
2853 
2854   std::map<std::string, std::string> vars;
2855   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
2856   format.AddMap(vars);
2857   format("_internal_metadata_.Clear<$unknown_fields_type$>();\n");
2858 
2859   format.Outdent();
2860   format("}\n");
2861 }
2862 
GenerateOneofClear(io::Printer * printer)2863 void MessageGenerator::GenerateOneofClear(io::Printer* printer) {
2864   // Generated function clears the active field and union case (e.g. foo_case_).
2865   int i = 0;
2866   for (auto oneof : OneOfRange(descriptor_)) {
2867     Formatter format(printer, variables_);
2868     format.Set("oneofname", oneof->name());
2869 
2870     format(
2871         "void $classname$::clear_$oneofname$() {\n"
2872         "// @@protoc_insertion_point(one_of_clear_start:$full_name$)\n");
2873     format.Indent();
2874     format("switch ($oneofname$_case()) {\n");
2875     format.Indent();
2876     for (auto field : FieldRange(oneof)) {
2877       format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
2878       format.Indent();
2879       // We clear only allocated objects in oneofs
2880       if (!IsStringOrMessage(field) || IsFieldStripped(field, options_)) {
2881         format("// No need to clear\n");
2882       } else {
2883         field_generators_.get(field).GenerateClearingCode(printer);
2884       }
2885       format("break;\n");
2886       format.Outdent();
2887       format("}\n");
2888     }
2889     format(
2890         "case $1$_NOT_SET: {\n"
2891         "  break;\n"
2892         "}\n",
2893         ToUpper(oneof->name()));
2894     format.Outdent();
2895     format(
2896         "}\n"
2897         "_oneof_case_[$1$] = $2$_NOT_SET;\n",
2898         i, ToUpper(oneof->name()));
2899     format.Outdent();
2900     format(
2901         "}\n"
2902         "\n");
2903     i++;
2904   }
2905 }
2906 
GenerateSwap(io::Printer * printer)2907 void MessageGenerator::GenerateSwap(io::Printer* printer) {
2908   Formatter format(printer, variables_);
2909 
2910   format("void $classname$::InternalSwap($classname$* other) {\n");
2911   format.Indent();
2912   format("using std::swap;\n");
2913 
2914   if (HasGeneratedMethods(descriptor_->file(), options_)) {
2915     if (descriptor_->extension_range_count() > 0) {
2916       format("_extensions_.Swap(&other->_extensions_);\n");
2917     }
2918 
2919     std::map<std::string, std::string> vars;
2920     SetUnknkownFieldsVariable(descriptor_, options_, &vars);
2921     format.AddMap(vars);
2922     format(
2923         "_internal_metadata_.Swap<$unknown_fields_type$>(&other->_internal_"
2924         "metadata_);\n");
2925 
2926     if (!has_bit_indices_.empty()) {
2927       for (int i = 0; i < HasBitsSize(); ++i) {
2928         format("swap(_has_bits_[$1$], other->_has_bits_[$1$]);\n", i);
2929       }
2930     }
2931 
2932     // If possible, we swap several fields at once, including padding.
2933     const RunMap runs =
2934         FindRuns(optimized_order_, [this](const FieldDescriptor* field) {
2935           return CanBeManipulatedAsRawBytes(field, options_);
2936         });
2937 
2938     for (int i = 0; i < optimized_order_.size(); ++i) {
2939       const FieldDescriptor* field = optimized_order_[i];
2940       const auto it = runs.find(field);
2941 
2942       // We only apply the memswap technique to runs of more than one field, as
2943       // `swap(field_, other.field_)` is better than
2944       // `memswap<...>(&field_, &other.field_)` for generated code readability.
2945       if (it != runs.end() && it->second > 1) {
2946         // Use a memswap, then skip run_length fields.
2947         const size_t run_length = it->second;
2948         const std::string first_field_name = FieldName(field);
2949         const std::string last_field_name =
2950             FieldName(optimized_order_[i + run_length - 1]);
2951 
2952         format.Set("first", first_field_name);
2953         format.Set("last", last_field_name);
2954 
2955         format(
2956             "::PROTOBUF_NAMESPACE_ID::internal::memswap<\n"
2957             "    PROTOBUF_FIELD_OFFSET($classname$, $last$_)\n"
2958             "    + sizeof($classname$::$last$_)\n"
2959             "    - PROTOBUF_FIELD_OFFSET($classname$, $first$_)>(\n"
2960             "        reinterpret_cast<char*>(&$first$_),\n"
2961             "        reinterpret_cast<char*>(&other->$first$_));\n");
2962 
2963         i += run_length - 1;
2964         // ++i at the top of the loop.
2965       } else {
2966         field_generators_.get(field).GenerateSwappingCode(printer);
2967       }
2968     }
2969 
2970     for (auto oneof : OneOfRange(descriptor_)) {
2971       format("swap($1$_, other->$1$_);\n", oneof->name());
2972     }
2973 
2974     for (int i = 0; i < descriptor_->real_oneof_decl_count(); i++) {
2975       format("swap(_oneof_case_[$1$], other->_oneof_case_[$1$]);\n", i);
2976     }
2977 
2978     if (num_weak_fields_) {
2979       format("_weak_field_map_.UnsafeArenaSwap(&other->_weak_field_map_);\n");
2980     }
2981   } else {
2982     format("GetReflection()->Swap(this, other);");
2983   }
2984 
2985   format.Outdent();
2986   format("}\n");
2987 }
2988 
GenerateMergeFrom(io::Printer * printer)2989 void MessageGenerator::GenerateMergeFrom(io::Printer* printer) {
2990   Formatter format(printer, variables_);
2991   if (HasDescriptorMethods(descriptor_->file(), options_)) {
2992     // Generate the generalized MergeFrom (aka that which takes in the Message
2993     // base class as a parameter).
2994     format(
2995         "void $classname$::MergeFrom(const ::$proto_ns$::Message& from) {\n"
2996         "// @@protoc_insertion_point(generalized_merge_from_start:"
2997         "$full_name$)\n"
2998         "  $DCHK$_NE(&from, this);\n");
2999     format.Indent();
3000 
3001     // Cast the message to the proper type. If we find that the message is
3002     // *not* of the proper type, we can still call Merge via the reflection
3003     // system, as the GOOGLE_CHECK above ensured that we have the same descriptor
3004     // for each message.
3005     format(
3006         "const $classname$* source =\n"
3007         "    ::$proto_ns$::DynamicCastToGenerated<$classname$>(\n"
3008         "        &from);\n"
3009         "if (source == nullptr) {\n"
3010         "// @@protoc_insertion_point(generalized_merge_from_cast_fail:"
3011         "$full_name$)\n"
3012         "  ::$proto_ns$::internal::ReflectionOps::Merge(from, this);\n"
3013         "} else {\n"
3014         "// @@protoc_insertion_point(generalized_merge_from_cast_success:"
3015         "$full_name$)\n"
3016         "  MergeFrom(*source);\n"
3017         "}\n");
3018 
3019     format.Outdent();
3020     format("}\n");
3021   } else {
3022     // Generate CheckTypeAndMergeFrom().
3023     format(
3024         "void $classname$::CheckTypeAndMergeFrom(\n"
3025         "    const ::$proto_ns$::MessageLite& from) {\n"
3026         "  MergeFrom(*::$proto_ns$::internal::DownCast<const $classname$*>(\n"
3027         "      &from));\n"
3028         "}\n");
3029   }
3030 }
3031 
GenerateClassSpecificMergeFrom(io::Printer * printer)3032 void MessageGenerator::GenerateClassSpecificMergeFrom(io::Printer* printer) {
3033   // Generate the class-specific MergeFrom, which avoids the GOOGLE_CHECK and cast.
3034   Formatter format(printer, variables_);
3035   format(
3036       "void $classname$::MergeFrom(const $classname$& from) {\n"
3037       "// @@protoc_insertion_point(class_specific_merge_from_start:"
3038       "$full_name$)\n"
3039       "  $DCHK$_NE(&from, this);\n");
3040   format.Indent();
3041 
3042   if (descriptor_->extension_range_count() > 0) {
3043     format("_extensions_.MergeFrom(from._extensions_);\n");
3044   }
3045   std::map<std::string, std::string> vars;
3046   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3047   format.AddMap(vars);
3048   format(
3049       "_internal_metadata_.MergeFrom<$unknown_fields_type$>(from._internal_"
3050       "metadata_);\n"
3051       "$uint32$ cached_has_bits = 0;\n"
3052       "(void) cached_has_bits;\n\n");
3053 
3054   std::vector<std::vector<const FieldDescriptor*>> chunks = CollectFields(
3055       optimized_order_,
3056       [&](const FieldDescriptor* a, const FieldDescriptor* b) -> bool {
3057         return HasByteIndex(a) == HasByteIndex(b);
3058       });
3059 
3060   ColdChunkSkipper cold_skipper(options_, chunks, has_bit_indices_, kColdRatio);
3061 
3062   // cached_has_word_index maintains that:
3063   //   cached_has_bits = from._has_bits_[cached_has_word_index]
3064   // for cached_has_word_index >= 0
3065   int cached_has_word_index = -1;
3066 
3067   for (int chunk_index = 0; chunk_index < chunks.size(); chunk_index++) {
3068     const std::vector<const FieldDescriptor*>& chunk = chunks[chunk_index];
3069     bool have_outer_if =
3070         chunk.size() > 1 && HasByteIndex(chunk.front()) != kNoHasbit;
3071     cold_skipper.OnStartChunk(chunk_index, cached_has_word_index, "from.",
3072                               printer);
3073 
3074     if (have_outer_if) {
3075       // Emit an if() that will let us skip the whole chunk if none are set.
3076       uint32 chunk_mask = GenChunkMask(chunk, has_bit_indices_);
3077       std::string chunk_mask_str =
3078           StrCat(strings::Hex(chunk_mask, strings::ZERO_PAD_8));
3079 
3080       // Check (up to) 8 has_bits at a time if we have more than one field in
3081       // this chunk.  Due to field layout ordering, we may check
3082       // _has_bits_[last_chunk * 8 / 32] multiple times.
3083       GOOGLE_DCHECK_LE(2, popcnt(chunk_mask));
3084       GOOGLE_DCHECK_GE(8, popcnt(chunk_mask));
3085 
3086       if (cached_has_word_index != HasWordIndex(chunk.front())) {
3087         cached_has_word_index = HasWordIndex(chunk.front());
3088         format("cached_has_bits = from._has_bits_[$1$];\n",
3089                cached_has_word_index);
3090       }
3091 
3092       format("if (cached_has_bits & 0x$1$u) {\n", chunk_mask_str);
3093       format.Indent();
3094     }
3095 
3096     // Go back and emit merging code for each of the fields we processed.
3097     bool deferred_has_bit_changes = false;
3098     for (const auto field : chunk) {
3099       const FieldGenerator& generator = field_generators_.get(field);
3100 
3101       if (field->is_repeated()) {
3102         generator.GenerateMergingCode(printer);
3103       } else if (field->is_optional() && !HasHasbit(field)) {
3104         // Merge semantics without true field presence: primitive fields are
3105         // merged only if non-zero (numeric) or non-empty (string).
3106         bool have_enclosing_if =
3107             EmitFieldNonDefaultCondition(printer, "from.", field);
3108         generator.GenerateMergingCode(printer);
3109         if (have_enclosing_if) {
3110           format.Outdent();
3111           format("}\n");
3112         }
3113       } else if (field->options().weak() ||
3114                  cached_has_word_index != HasWordIndex(field)) {
3115         // Check hasbit, not using cached bits.
3116         GOOGLE_CHECK(HasHasbit(field));
3117         format("if (from._internal_has_$1$()) {\n", FieldName(field));
3118         format.Indent();
3119         generator.GenerateMergingCode(printer);
3120         format.Outdent();
3121         format("}\n");
3122       } else {
3123         // Check hasbit, using cached bits.
3124         GOOGLE_CHECK(HasHasbit(field));
3125         int has_bit_index = has_bit_indices_[field->index()];
3126         const std::string mask = StrCat(
3127             strings::Hex(1u << (has_bit_index % 32), strings::ZERO_PAD_8));
3128         format("if (cached_has_bits & 0x$1$u) {\n", mask);
3129         format.Indent();
3130 
3131         if (have_outer_if && IsPOD(field)) {
3132           // Defer hasbit modification until the end of chunk.
3133           // This can reduce the number of loads/stores by up to 7 per 8 fields.
3134           deferred_has_bit_changes = true;
3135           generator.GenerateCopyConstructorCode(printer);
3136         } else {
3137           generator.GenerateMergingCode(printer);
3138         }
3139 
3140         format.Outdent();
3141         format("}\n");
3142       }
3143     }
3144 
3145     if (have_outer_if) {
3146       if (deferred_has_bit_changes) {
3147         // Flush the has bits for the primitives we deferred.
3148         GOOGLE_CHECK_LE(0, cached_has_word_index);
3149         format("_has_bits_[$1$] |= cached_has_bits;\n", cached_has_word_index);
3150       }
3151 
3152       format.Outdent();
3153       format("}\n");
3154     }
3155 
3156     if (cold_skipper.OnEndChunk(chunk_index, printer)) {
3157       // Reset here as it may have been updated in just closed if statement.
3158       cached_has_word_index = -1;
3159     }
3160   }
3161 
3162   // Merge oneof fields. Oneof field requires oneof case check.
3163   for (auto oneof : OneOfRange(descriptor_)) {
3164     format("switch (from.$1$_case()) {\n", oneof->name());
3165     format.Indent();
3166     for (auto field : FieldRange(oneof)) {
3167       format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
3168       format.Indent();
3169       if (!IsFieldStripped(field, options_)) {
3170         field_generators_.get(field).GenerateMergingCode(printer);
3171       }
3172       format("break;\n");
3173       format.Outdent();
3174       format("}\n");
3175     }
3176     format(
3177         "case $1$_NOT_SET: {\n"
3178         "  break;\n"
3179         "}\n",
3180         ToUpper(oneof->name()));
3181     format.Outdent();
3182     format("}\n");
3183   }
3184   if (num_weak_fields_) {
3185     format("_weak_field_map_.MergeFrom(from._weak_field_map_);\n");
3186   }
3187 
3188   format.Outdent();
3189   format("}\n");
3190 }
3191 
GenerateCopyFrom(io::Printer * printer)3192 void MessageGenerator::GenerateCopyFrom(io::Printer* printer) {
3193   Formatter format(printer, variables_);
3194   if (HasDescriptorMethods(descriptor_->file(), options_)) {
3195     // Generate the generalized CopyFrom (aka that which takes in the Message
3196     // base class as a parameter).
3197     format(
3198         "void $classname$::CopyFrom(const ::$proto_ns$::Message& from) {\n"
3199         "// @@protoc_insertion_point(generalized_copy_from_start:"
3200         "$full_name$)\n");
3201     format.Indent();
3202 
3203     format("if (&from == this) return;\n");
3204 
3205     if (!options_.opensource_runtime) {
3206       // This check is disabled in the opensource release because we're
3207       // concerned that many users do not define NDEBUG in their release
3208       // builds.
3209       format(
3210           "#ifndef NDEBUG\n"
3211           "size_t from_size = from.ByteSizeLong();\n"
3212           "#endif\n"
3213           "Clear();\n"
3214           "#ifndef NDEBUG\n"
3215           "$CHK$_EQ(from_size, from.ByteSizeLong())\n"
3216           "  << \"Source of CopyFrom changed when clearing target.  Either \"\n"
3217           "  << \"source is a nested message in target (not allowed), or \"\n"
3218           "  << \"another thread is modifying the source.\";\n"
3219           "#endif\n");
3220     } else {
3221       format("Clear();\n");
3222     }
3223     format("MergeFrom(from);\n");
3224 
3225     format.Outdent();
3226     format("}\n\n");
3227   }
3228 
3229   // Generate the class-specific CopyFrom.
3230   format(
3231       "void $classname$::CopyFrom(const $classname$& from) {\n"
3232       "// @@protoc_insertion_point(class_specific_copy_from_start:"
3233       "$full_name$)\n");
3234   format.Indent();
3235 
3236   format("if (&from == this) return;\n");
3237 
3238   if (!options_.opensource_runtime) {
3239     // This check is disabled in the opensource release because we're
3240     // concerned that many users do not define NDEBUG in their release builds.
3241     format(
3242         "#ifndef NDEBUG\n"
3243         "size_t from_size = from.ByteSizeLong();\n"
3244         "#endif\n"
3245         "Clear();\n"
3246         "#ifndef NDEBUG\n"
3247         "$CHK$_EQ(from_size, from.ByteSizeLong())\n"
3248         "  << \"Source of CopyFrom changed when clearing target.  Either \"\n"
3249         "  << \"source is a nested message in target (not allowed), or \"\n"
3250         "  << \"another thread is modifying the source.\";\n"
3251         "#endif\n");
3252   } else {
3253     format("Clear();\n");
3254   }
3255   format("MergeFrom(from);\n");
3256 
3257   format.Outdent();
3258   format("}\n");
3259 }
3260 
GenerateMergeFromCodedStream(io::Printer * printer)3261 void MessageGenerator::GenerateMergeFromCodedStream(io::Printer* printer) {
3262   std::map<std::string, std::string> vars = variables_;
3263   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3264   Formatter format(printer, vars);
3265   if (descriptor_->options().message_set_wire_format()) {
3266     // Special-case MessageSet.
3267     format(
3268         "const char* $classname$::_InternalParse(const char* ptr,\n"
3269         "                  ::$proto_ns$::internal::ParseContext* ctx) {\n"
3270         "  return _extensions_.ParseMessageSet(ptr, \n"
3271         "      internal_default_instance(), &_internal_metadata_, ctx);\n"
3272         "}\n");
3273     return;
3274   }
3275   GenerateParserLoop(descriptor_, max_has_bit_index_, options_, scc_analyzer_,
3276                      printer);
3277 }
3278 
GenerateSerializeOneofFields(io::Printer * printer,const std::vector<const FieldDescriptor * > & fields)3279 void MessageGenerator::GenerateSerializeOneofFields(
3280     io::Printer* printer, const std::vector<const FieldDescriptor*>& fields) {
3281   Formatter format(printer, variables_);
3282   GOOGLE_CHECK(!fields.empty());
3283   if (fields.size() == 1) {
3284     GenerateSerializeOneField(printer, fields[0], -1);
3285     return;
3286   }
3287   // We have multiple mutually exclusive choices.  Emit a switch statement.
3288   const OneofDescriptor* oneof = fields[0]->containing_oneof();
3289   format("switch ($1$_case()) {\n", oneof->name());
3290   format.Indent();
3291   for (auto field : fields) {
3292     format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
3293     format.Indent();
3294     field_generators_.get(field).GenerateSerializeWithCachedSizesToArray(
3295         printer);
3296     format("break;\n");
3297     format.Outdent();
3298     format("}\n");
3299   }
3300   format.Outdent();
3301   // Doing nothing is an option.
3302   format(
3303       "  default: ;\n"
3304       "}\n");
3305 }
3306 
GenerateSerializeOneField(io::Printer * printer,const FieldDescriptor * field,int cached_has_bits_index)3307 void MessageGenerator::GenerateSerializeOneField(io::Printer* printer,
3308                                                  const FieldDescriptor* field,
3309                                                  int cached_has_bits_index) {
3310   Formatter format(printer, variables_);
3311   if (!field->options().weak()) {
3312     // For weakfields, PrintFieldComment is called during iteration.
3313     PrintFieldComment(format, field);
3314   }
3315 
3316   bool have_enclosing_if = false;
3317   if (field->options().weak()) {
3318   } else if (HasHasbit(field)) {
3319     // Attempt to use the state of cached_has_bits, if possible.
3320     int has_bit_index = HasBitIndex(field);
3321     if (cached_has_bits_index == has_bit_index / 32) {
3322       const std::string mask =
3323           StrCat(strings::Hex(1u << (has_bit_index % 32), strings::ZERO_PAD_8));
3324 
3325       format("if (cached_has_bits & 0x$1$u) {\n", mask);
3326     } else {
3327       format("if (_internal_has_$1$()) {\n", FieldName(field));
3328     }
3329 
3330     format.Indent();
3331     have_enclosing_if = true;
3332   } else if (field->is_optional() && !HasHasbit(field)) {
3333     have_enclosing_if = EmitFieldNonDefaultCondition(printer, "this->", field);
3334   }
3335 
3336   field_generators_.get(field).GenerateSerializeWithCachedSizesToArray(printer);
3337 
3338   if (have_enclosing_if) {
3339     format.Outdent();
3340     format("}\n");
3341   }
3342   format("\n");
3343 }
3344 
GenerateSerializeOneExtensionRange(io::Printer * printer,const Descriptor::ExtensionRange * range)3345 void MessageGenerator::GenerateSerializeOneExtensionRange(
3346     io::Printer* printer, const Descriptor::ExtensionRange* range) {
3347   std::map<std::string, std::string> vars = variables_;
3348   vars["start"] = StrCat(range->start);
3349   vars["end"] = StrCat(range->end);
3350   Formatter format(printer, vars);
3351   format("// Extension range [$start$, $end$)\n");
3352   format(
3353       "target = _extensions_._InternalSerialize(\n"
3354       "    $start$, $end$, target, stream);\n\n");
3355 }
3356 
GenerateSerializeWithCachedSizesToArray(io::Printer * printer)3357 void MessageGenerator::GenerateSerializeWithCachedSizesToArray(
3358     io::Printer* printer) {
3359   Formatter format(printer, variables_);
3360   if (descriptor_->options().message_set_wire_format()) {
3361     // Special-case MessageSet.
3362     format(
3363         "$uint8$* $classname$::_InternalSerialize(\n"
3364         "    $uint8$* target, ::$proto_ns$::io::EpsCopyOutputStream* stream) "
3365         "const {\n"
3366         "  target = _extensions_."
3367         "InternalSerializeMessageSetWithCachedSizesToArray(target, stream);\n");
3368     std::map<std::string, std::string> vars;
3369     SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3370     format.AddMap(vars);
3371     format(
3372         "  target = ::$proto_ns$::internal::"
3373         "InternalSerializeUnknownMessageSetItemsToArray(\n"
3374         "               $unknown_fields$, target, stream);\n");
3375     format(
3376         "  return target;\n"
3377         "}\n");
3378     return;
3379   }
3380 
3381   format(
3382       "$uint8$* $classname$::_InternalSerialize(\n"
3383       "    $uint8$* target, ::$proto_ns$::io::EpsCopyOutputStream* stream) "
3384       "const {\n");
3385   format.Indent();
3386 
3387   format("// @@protoc_insertion_point(serialize_to_array_start:$full_name$)\n");
3388 
3389   if (!ShouldSerializeInOrder(descriptor_, options_)) {
3390     format.Outdent();
3391     format("#ifdef NDEBUG\n");
3392     format.Indent();
3393   }
3394 
3395   GenerateSerializeWithCachedSizesBody(printer);
3396 
3397   if (!ShouldSerializeInOrder(descriptor_, options_)) {
3398     format.Outdent();
3399     format("#else  // NDEBUG\n");
3400     format.Indent();
3401 
3402     GenerateSerializeWithCachedSizesBodyShuffled(printer);
3403 
3404     format.Outdent();
3405     format("#endif  // !NDEBUG\n");
3406     format.Indent();
3407   }
3408 
3409   format("// @@protoc_insertion_point(serialize_to_array_end:$full_name$)\n");
3410 
3411   format.Outdent();
3412   format(
3413       "  return target;\n"
3414       "}\n");
3415 }
3416 
GenerateSerializeWithCachedSizesBody(io::Printer * printer)3417 void MessageGenerator::GenerateSerializeWithCachedSizesBody(
3418     io::Printer* printer) {
3419   Formatter format(printer, variables_);
3420   // If there are multiple fields in a row from the same oneof then we
3421   // coalesce them and emit a switch statement.  This is more efficient
3422   // because it lets the C++ compiler know this is a "at most one can happen"
3423   // situation. If we emitted "if (has_x()) ...; if (has_y()) ..." the C++
3424   // compiler's emitted code might check has_y() even when has_x() is true.
3425   class LazySerializerEmitter {
3426    public:
3427     LazySerializerEmitter(MessageGenerator* mg, io::Printer* printer)
3428         : mg_(mg),
3429           format_(printer),
3430           eager_(!HasFieldPresence(mg->descriptor_->file())),
3431           cached_has_bit_index_(kNoHasbit) {}
3432 
3433     ~LazySerializerEmitter() { Flush(); }
3434 
3435     // If conditions allow, try to accumulate a run of fields from the same
3436     // oneof, and handle them at the next Flush().
3437     void Emit(const FieldDescriptor* field) {
3438       if (eager_ || MustFlush(field)) {
3439         Flush();
3440       }
3441       if (!field->real_containing_oneof()) {
3442         // TODO(ckennelly): Defer non-oneof fields similarly to oneof fields.
3443 
3444         if (!field->options().weak() && !field->is_repeated() && !eager_) {
3445           // We speculatively load the entire _has_bits_[index] contents, even
3446           // if it is for only one field.  Deferring non-oneof emitting would
3447           // allow us to determine whether this is going to be useful.
3448           int has_bit_index = mg_->has_bit_indices_[field->index()];
3449           if (cached_has_bit_index_ != has_bit_index / 32) {
3450             // Reload.
3451             int new_index = has_bit_index / 32;
3452 
3453             format_("cached_has_bits = _has_bits_[$1$];\n", new_index);
3454 
3455             cached_has_bit_index_ = new_index;
3456           }
3457         }
3458 
3459         mg_->GenerateSerializeOneField(format_.printer(), field,
3460                                        cached_has_bit_index_);
3461       } else {
3462         v_.push_back(field);
3463       }
3464     }
3465 
3466     void Flush() {
3467       if (!v_.empty()) {
3468         mg_->GenerateSerializeOneofFields(format_.printer(), v_);
3469         v_.clear();
3470       }
3471     }
3472 
3473    private:
3474     // If we have multiple fields in v_ then they all must be from the same
3475     // oneof.  Would adding field to v_ break that invariant?
3476     bool MustFlush(const FieldDescriptor* field) {
3477       return !v_.empty() &&
3478              v_[0]->containing_oneof() != field->containing_oneof();
3479     }
3480 
3481     MessageGenerator* mg_;
3482     Formatter format_;
3483     const bool eager_;
3484     std::vector<const FieldDescriptor*> v_;
3485 
3486     // cached_has_bit_index_ maintains that:
3487     //   cached_has_bits = from._has_bits_[cached_has_bit_index_]
3488     // for cached_has_bit_index_ >= 0
3489     int cached_has_bit_index_;
3490   };
3491 
3492   std::vector<const FieldDescriptor*> ordered_fields =
3493       SortFieldsByNumber(descriptor_);
3494 
3495   std::vector<const Descriptor::ExtensionRange*> sorted_extensions;
3496   sorted_extensions.reserve(descriptor_->extension_range_count());
3497   for (int i = 0; i < descriptor_->extension_range_count(); ++i) {
3498     sorted_extensions.push_back(descriptor_->extension_range(i));
3499   }
3500   std::sort(sorted_extensions.begin(), sorted_extensions.end(),
3501             ExtensionRangeSorter());
3502   if (num_weak_fields_) {
3503     format(
3504         "::$proto_ns$::internal::WeakFieldMap::FieldWriter field_writer("
3505         "_weak_field_map_);\n");
3506   }
3507 
3508   format(
3509       "$uint32$ cached_has_bits = 0;\n"
3510       "(void) cached_has_bits;\n\n");
3511 
3512   // Merge the fields and the extension ranges, both sorted by field number.
3513   {
3514     LazySerializerEmitter e(this, printer);
3515     const FieldDescriptor* last_weak_field = nullptr;
3516     int i, j;
3517     for (i = 0, j = 0;
3518          i < ordered_fields.size() || j < sorted_extensions.size();) {
3519       if ((j == sorted_extensions.size()) ||
3520           (i < descriptor_->field_count() &&
3521            ordered_fields[i]->number() < sorted_extensions[j]->start)) {
3522         const FieldDescriptor* field = ordered_fields[i++];
3523         if (IsFieldStripped(field, options_)) {
3524           continue;
3525         }
3526         if (field->options().weak()) {
3527           if (last_weak_field == nullptr ||
3528               last_weak_field->number() < field->number()) {
3529             last_weak_field = field;
3530           }
3531           PrintFieldComment(format, field);
3532         } else {
3533           if (last_weak_field != nullptr) {
3534             e.Emit(last_weak_field);
3535             last_weak_field = nullptr;
3536           }
3537           e.Emit(field);
3538         }
3539       } else {
3540         if (last_weak_field != nullptr) {
3541           e.Emit(last_weak_field);
3542           last_weak_field = nullptr;
3543         }
3544         e.Flush();
3545         GenerateSerializeOneExtensionRange(printer, sorted_extensions[j++]);
3546       }
3547     }
3548     if (last_weak_field != nullptr) {
3549       e.Emit(last_weak_field);
3550     }
3551   }
3552 
3553   std::map<std::string, std::string> vars;
3554   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3555   format.AddMap(vars);
3556   format("if (PROTOBUF_PREDICT_FALSE($have_unknown_fields$)) {\n");
3557   format.Indent();
3558   if (UseUnknownFieldSet(descriptor_->file(), options_)) {
3559     format(
3560         "target = "
3561         "::$proto_ns$::internal::WireFormat::"
3562         "InternalSerializeUnknownFieldsToArray(\n"
3563         "    $unknown_fields$, target, stream);\n");
3564   } else {
3565     format(
3566         "target = stream->WriteRaw($unknown_fields$.data(),\n"
3567         "    static_cast<int>($unknown_fields$.size()), target);\n");
3568   }
3569   format.Outdent();
3570   format("}\n");
3571 }
3572 
GenerateSerializeWithCachedSizesBodyShuffled(io::Printer * printer)3573 void MessageGenerator::GenerateSerializeWithCachedSizesBodyShuffled(
3574     io::Printer* printer) {
3575   Formatter format(printer, variables_);
3576 
3577   std::vector<const FieldDescriptor*> ordered_fields =
3578       SortFieldsByNumber(descriptor_);
3579   ordered_fields.erase(
3580       std::remove_if(ordered_fields.begin(), ordered_fields.end(),
3581                      [this](const FieldDescriptor* f) {
3582                        return !IsFieldUsed(f, options_);
3583                      }),
3584       ordered_fields.end());
3585 
3586   std::vector<const Descriptor::ExtensionRange*> sorted_extensions;
3587   sorted_extensions.reserve(descriptor_->extension_range_count());
3588   for (int i = 0; i < descriptor_->extension_range_count(); ++i) {
3589     sorted_extensions.push_back(descriptor_->extension_range(i));
3590   }
3591   std::sort(sorted_extensions.begin(), sorted_extensions.end(),
3592             ExtensionRangeSorter());
3593 
3594   int num_fields = ordered_fields.size() + sorted_extensions.size();
3595   constexpr int kLargePrime = 1000003;
3596   GOOGLE_CHECK_LT(num_fields, kLargePrime)
3597       << "Prime offset must be greater than the number of fields to ensure "
3598          "those are coprime.";
3599 
3600   if (num_weak_fields_) {
3601     format(
3602         "::$proto_ns$::internal::WeakFieldMap::FieldWriter field_writer("
3603         "_weak_field_map_);\n");
3604   }
3605 
3606   format(
3607       "static const int kStart = GetInvariantPerBuild($1$UL) % $2$;\n"
3608       "bool first_pass = true;\n"
3609       "for (int i = kStart; i != kStart || first_pass; i = ((i + $3$) % $2$)) "
3610       "{\n",
3611       0,
3612       num_fields, kLargePrime);
3613 
3614   format.Indent();
3615   format("switch(i) {\n");
3616   format.Indent();
3617 
3618   bool first_pass_set = false;
3619   int index = 0;
3620   for (const auto* f : ordered_fields) {
3621     format("case $1$: {\n", index++);
3622     format.Indent();
3623 
3624     if (!first_pass_set) {
3625       first_pass_set = true;
3626       format("first_pass = false;\n");
3627     }
3628 
3629     GenerateSerializeOneField(printer, f, -1);
3630 
3631     format("break;\n");
3632     format.Outdent();
3633     format("}\n");
3634   }
3635 
3636   for (const auto* r : sorted_extensions) {
3637     format("case $1$: {\n", index++);
3638     format.Indent();
3639 
3640     if (!first_pass_set) {
3641       first_pass_set = true;
3642       format("first_pass = false;\n");
3643     }
3644 
3645     GenerateSerializeOneExtensionRange(printer, r);
3646 
3647     format("break;\n");
3648     format.Outdent();
3649     format("}\n");
3650   }
3651 
3652   format(
3653       "default: {\n"
3654       "  $DCHK$(false) << \"Unexpected index: \" << i;\n"
3655       "}\n");
3656   format.Outdent();
3657   format("}\n");
3658 
3659   format.Outdent();
3660   format("}\n");
3661 
3662   std::map<std::string, std::string> vars;
3663   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3664   format.AddMap(vars);
3665   format("if (PROTOBUF_PREDICT_FALSE($have_unknown_fields$)) {\n");
3666   format.Indent();
3667   if (UseUnknownFieldSet(descriptor_->file(), options_)) {
3668     format(
3669         "target = "
3670         "::$proto_ns$::internal::WireFormat::"
3671         "InternalSerializeUnknownFieldsToArray(\n"
3672         "    $unknown_fields$, target, stream);\n");
3673   } else {
3674     format(
3675         "target = stream->WriteRaw($unknown_fields$.data(),\n"
3676         "    static_cast<int>($unknown_fields$.size()), target);\n");
3677   }
3678   format.Outdent();
3679   format("}\n");
3680 }
3681 
RequiredFieldsBitMask() const3682 std::vector<uint32> MessageGenerator::RequiredFieldsBitMask() const {
3683   const int array_size = HasBitsSize();
3684   std::vector<uint32> masks(array_size, 0);
3685 
3686   for (auto field : FieldRange(descriptor_)) {
3687     if (!field->is_required()) {
3688       continue;
3689     }
3690 
3691     const int has_bit_index = has_bit_indices_[field->index()];
3692     masks[has_bit_index / 32] |= static_cast<uint32>(1) << (has_bit_index % 32);
3693   }
3694   return masks;
3695 }
3696 
GenerateByteSize(io::Printer * printer)3697 void MessageGenerator::GenerateByteSize(io::Printer* printer) {
3698   Formatter format(printer, variables_);
3699 
3700   if (descriptor_->options().message_set_wire_format()) {
3701     // Special-case MessageSet.
3702     std::map<std::string, std::string> vars;
3703     SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3704     format.AddMap(vars);
3705     format(
3706         "size_t $classname$::ByteSizeLong() const {\n"
3707         "// @@protoc_insertion_point(message_set_byte_size_start:$full_name$)\n"
3708         "  size_t total_size = _extensions_.MessageSetByteSize();\n"
3709         "  if ($have_unknown_fields$) {\n"
3710         "    total_size += ::$proto_ns$::internal::\n"
3711         "        ComputeUnknownMessageSetItemsSize($unknown_fields$);\n"
3712         "  }\n"
3713         "  int cached_size = "
3714         "::$proto_ns$::internal::ToCachedSize(total_size);\n"
3715         "  SetCachedSize(cached_size);\n"
3716         "  return total_size;\n"
3717         "}\n");
3718     return;
3719   }
3720 
3721   if (num_required_fields_ > 1) {
3722     // Emit a function (rarely used, we hope) that handles the required fields
3723     // by checking for each one individually.
3724     format(
3725         "size_t $classname$::RequiredFieldsByteSizeFallback() const {\n"
3726         "// @@protoc_insertion_point(required_fields_byte_size_fallback_start:"
3727         "$full_name$)\n");
3728     format.Indent();
3729     format("size_t total_size = 0;\n");
3730     for (auto field : optimized_order_) {
3731       if (field->is_required()) {
3732         format(
3733             "\n"
3734             "if (_internal_has_$1$()) {\n",
3735             FieldName(field));
3736         format.Indent();
3737         PrintFieldComment(format, field);
3738         field_generators_.get(field).GenerateByteSize(printer);
3739         format.Outdent();
3740         format("}\n");
3741       }
3742     }
3743     format(
3744         "\n"
3745         "return total_size;\n");
3746     format.Outdent();
3747     format("}\n");
3748   }
3749 
3750   format(
3751       "size_t $classname$::ByteSizeLong() const {\n"
3752       "// @@protoc_insertion_point(message_byte_size_start:$full_name$)\n");
3753   format.Indent();
3754   format(
3755       "size_t total_size = 0;\n"
3756       "\n");
3757 
3758   if (descriptor_->extension_range_count() > 0) {
3759     format(
3760         "total_size += _extensions_.ByteSize();\n"
3761         "\n");
3762   }
3763 
3764   std::map<std::string, std::string> vars;
3765   SetUnknkownFieldsVariable(descriptor_, options_, &vars);
3766   format.AddMap(vars);
3767 
3768   // Handle required fields (if any).  We expect all of them to be
3769   // present, so emit one conditional that checks for that.  If they are all
3770   // present then the fast path executes; otherwise the slow path executes.
3771   if (num_required_fields_ > 1) {
3772     // The fast path works if all required fields are present.
3773     const std::vector<uint32> masks_for_has_bits = RequiredFieldsBitMask();
3774     format("if ($1$) {  // All required fields are present.\n",
3775            ConditionalToCheckBitmasks(masks_for_has_bits));
3776     format.Indent();
3777     // Oneof fields cannot be required, so optimized_order_ contains all of the
3778     // fields that we need to potentially emit.
3779     for (auto field : optimized_order_) {
3780       if (!field->is_required()) continue;
3781       PrintFieldComment(format, field);
3782       field_generators_.get(field).GenerateByteSize(printer);
3783       format("\n");
3784     }
3785     format.Outdent();
3786     format(
3787         "} else {\n"  // the slow path
3788         "  total_size += RequiredFieldsByteSizeFallback();\n"
3789         "}\n");
3790   } else {
3791     // num_required_fields_ <= 1: no need to be tricky
3792     for (auto field : optimized_order_) {
3793       if (!field->is_required()) continue;
3794       PrintFieldComment(format, field);
3795       format("if (_internal_has_$1$()) {\n", FieldName(field));
3796       format.Indent();
3797       field_generators_.get(field).GenerateByteSize(printer);
3798       format.Outdent();
3799       format("}\n");
3800     }
3801   }
3802 
3803   std::vector<std::vector<const FieldDescriptor*>> chunks = CollectFields(
3804       optimized_order_,
3805       [&](const FieldDescriptor* a, const FieldDescriptor* b) -> bool {
3806         return a->label() == b->label() && HasByteIndex(a) == HasByteIndex(b);
3807       });
3808 
3809   // Remove chunks with required fields.
3810   chunks.erase(std::remove_if(chunks.begin(), chunks.end(), IsRequired),
3811                chunks.end());
3812 
3813   ColdChunkSkipper cold_skipper(options_, chunks, has_bit_indices_, kColdRatio);
3814   int cached_has_word_index = -1;
3815 
3816   format(
3817       "$uint32$ cached_has_bits = 0;\n"
3818       "// Prevent compiler warnings about cached_has_bits being unused\n"
3819       "(void) cached_has_bits;\n\n");
3820 
3821   for (int chunk_index = 0; chunk_index < chunks.size(); chunk_index++) {
3822     const std::vector<const FieldDescriptor*>& chunk = chunks[chunk_index];
3823     const bool have_outer_if =
3824         chunk.size() > 1 && HasWordIndex(chunk[0]) != kNoHasbit;
3825     cold_skipper.OnStartChunk(chunk_index, cached_has_word_index, "", printer);
3826 
3827     if (have_outer_if) {
3828       // Emit an if() that will let us skip the whole chunk if none are set.
3829       uint32 chunk_mask = GenChunkMask(chunk, has_bit_indices_);
3830       std::string chunk_mask_str =
3831           StrCat(strings::Hex(chunk_mask, strings::ZERO_PAD_8));
3832 
3833       // Check (up to) 8 has_bits at a time if we have more than one field in
3834       // this chunk.  Due to field layout ordering, we may check
3835       // _has_bits_[last_chunk * 8 / 32] multiple times.
3836       GOOGLE_DCHECK_LE(2, popcnt(chunk_mask));
3837       GOOGLE_DCHECK_GE(8, popcnt(chunk_mask));
3838 
3839       if (cached_has_word_index != HasWordIndex(chunk.front())) {
3840         cached_has_word_index = HasWordIndex(chunk.front());
3841         format("cached_has_bits = _has_bits_[$1$];\n", cached_has_word_index);
3842       }
3843       format("if (cached_has_bits & 0x$1$u) {\n", chunk_mask_str);
3844       format.Indent();
3845     }
3846 
3847     // Go back and emit checks for each of the fields we processed.
3848     for (int j = 0; j < chunk.size(); j++) {
3849       const FieldDescriptor* field = chunk[j];
3850       const FieldGenerator& generator = field_generators_.get(field);
3851       bool have_enclosing_if = false;
3852       bool need_extra_newline = false;
3853 
3854       PrintFieldComment(format, field);
3855 
3856       if (field->is_repeated()) {
3857         // No presence check is required.
3858         need_extra_newline = true;
3859       } else if (HasHasbit(field)) {
3860         PrintPresenceCheck(format, field, has_bit_indices_, printer,
3861                            &cached_has_word_index);
3862         have_enclosing_if = true;
3863       } else {
3864         // Without field presence: field is serialized only if it has a
3865         // non-default value.
3866         have_enclosing_if =
3867             EmitFieldNonDefaultCondition(printer, "this->", field);
3868       }
3869 
3870       generator.GenerateByteSize(printer);
3871 
3872       if (have_enclosing_if) {
3873         format.Outdent();
3874         format(
3875             "}\n"
3876             "\n");
3877       }
3878       if (need_extra_newline) {
3879         format("\n");
3880       }
3881     }
3882 
3883     if (have_outer_if) {
3884       format.Outdent();
3885       format("}\n");
3886     }
3887 
3888     if (cold_skipper.OnEndChunk(chunk_index, printer)) {
3889       // Reset here as it may have been updated in just closed if statement.
3890       cached_has_word_index = -1;
3891     }
3892   }
3893 
3894   // Fields inside a oneof don't use _has_bits_ so we count them in a separate
3895   // pass.
3896   for (auto oneof : OneOfRange(descriptor_)) {
3897     format("switch ($1$_case()) {\n", oneof->name());
3898     format.Indent();
3899     for (auto field : FieldRange(oneof)) {
3900       PrintFieldComment(format, field);
3901       format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
3902       format.Indent();
3903       if (!IsFieldStripped(field, options_)) {
3904         field_generators_.get(field).GenerateByteSize(printer);
3905       }
3906       format("break;\n");
3907       format.Outdent();
3908       format("}\n");
3909     }
3910     format(
3911         "case $1$_NOT_SET: {\n"
3912         "  break;\n"
3913         "}\n",
3914         ToUpper(oneof->name()));
3915     format.Outdent();
3916     format("}\n");
3917   }
3918 
3919   if (num_weak_fields_) {
3920     // TagSize + MessageSize
3921     format("total_size += _weak_field_map_.ByteSizeLong();\n");
3922   }
3923 
3924   format("if (PROTOBUF_PREDICT_FALSE($have_unknown_fields$)) {\n");
3925   if (UseUnknownFieldSet(descriptor_->file(), options_)) {
3926     // We go out of our way to put the computation of the uncommon path of
3927     // unknown fields in tail position. This allows for better code generation
3928     // of this function for simple protos.
3929     format(
3930         "  return ::$proto_ns$::internal::ComputeUnknownFieldsSize(\n"
3931         "      _internal_metadata_, total_size, &_cached_size_);\n");
3932   } else {
3933     format("  total_size += $unknown_fields$.size();\n");
3934   }
3935   format("}\n");
3936 
3937   // We update _cached_size_ even though this is a const method.  Because
3938   // const methods might be called concurrently this needs to be atomic
3939   // operations or the program is undefined.  In practice, since any concurrent
3940   // writes will be writing the exact same value, normal writes will work on
3941   // all common processors. We use a dedicated wrapper class to abstract away
3942   // the underlying atomic. This makes it easier on platforms where even relaxed
3943   // memory order might have perf impact to replace it with ordinary loads and
3944   // stores.
3945   format(
3946       "int cached_size = ::$proto_ns$::internal::ToCachedSize(total_size);\n"
3947       "SetCachedSize(cached_size);\n"
3948       "return total_size;\n");
3949 
3950   format.Outdent();
3951   format("}\n");
3952 }
3953 
GenerateIsInitialized(io::Printer * printer)3954 void MessageGenerator::GenerateIsInitialized(io::Printer* printer) {
3955   Formatter format(printer, variables_);
3956   format("bool $classname$::IsInitialized() const {\n");
3957   format.Indent();
3958 
3959   if (descriptor_->extension_range_count() > 0) {
3960     format(
3961         "if (!_extensions_.IsInitialized()) {\n"
3962         "  return false;\n"
3963         "}\n\n");
3964   }
3965 
3966   if (num_required_fields_ > 0) {
3967     format(
3968         "if (_Internal::MissingRequiredFields(_has_bits_))"
3969         " return false;\n");
3970   }
3971 
3972   // Now check that all non-oneof embedded messages are initialized.
3973   for (auto field : optimized_order_) {
3974     // TODO(ckennelly): Push this down into a generator?
3975     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
3976         !ShouldIgnoreRequiredFieldCheck(field, options_) &&
3977         scc_analyzer_->HasRequiredFields(field->message_type())) {
3978       if (field->is_repeated()) {
3979         if (IsImplicitWeakField(field, options_, scc_analyzer_)) {
3980           format(
3981               "if "
3982               "(!::$proto_ns$::internal::AllAreInitializedWeak($1$_.weak)"
3983               ")"
3984               " return false;\n",
3985               FieldName(field));
3986         } else {
3987           format(
3988               "if (!::$proto_ns$::internal::AllAreInitialized($1$_))"
3989               " return false;\n",
3990               FieldName(field));
3991         }
3992       } else if (field->options().weak()) {
3993         continue;
3994       } else {
3995         GOOGLE_CHECK(!field->real_containing_oneof());
3996         format(
3997             "if (_internal_has_$1$()) {\n"
3998             "  if (!$1$_->IsInitialized()) return false;\n"
3999             "}\n",
4000             FieldName(field));
4001       }
4002     }
4003   }
4004   if (num_weak_fields_) {
4005     // For Weak fields.
4006     format("if (!_weak_field_map_.IsInitialized()) return false;\n");
4007   }
4008   // Go through the oneof fields, emitting a switch if any might have required
4009   // fields.
4010   for (auto oneof : OneOfRange(descriptor_)) {
4011     bool has_required_fields = false;
4012     for (auto field : FieldRange(oneof)) {
4013       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
4014           !ShouldIgnoreRequiredFieldCheck(field, options_) &&
4015           scc_analyzer_->HasRequiredFields(field->message_type())) {
4016         has_required_fields = true;
4017         break;
4018       }
4019     }
4020 
4021     if (!has_required_fields) {
4022       continue;
4023     }
4024 
4025     format("switch ($1$_case()) {\n", oneof->name());
4026     format.Indent();
4027     for (auto field : FieldRange(oneof)) {
4028       format("case k$1$: {\n", UnderscoresToCamelCase(field->name(), true));
4029       format.Indent();
4030 
4031       if (!IsFieldStripped(field, options_) &&
4032           field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
4033           !ShouldIgnoreRequiredFieldCheck(field, options_) &&
4034           scc_analyzer_->HasRequiredFields(field->message_type())) {
4035         GOOGLE_CHECK(!(field->options().weak() || !field->real_containing_oneof()));
4036         if (field->options().weak()) {
4037           // Just skip.
4038         } else {
4039           format(
4040               "if (has_$1$()) {\n"
4041               "  if (!this->$1$().IsInitialized()) return false;\n"
4042               "}\n",
4043               FieldName(field));
4044         }
4045       }
4046 
4047       format("break;\n");
4048       format.Outdent();
4049       format("}\n");
4050     }
4051     format(
4052         "case $1$_NOT_SET: {\n"
4053         "  break;\n"
4054         "}\n",
4055         ToUpper(oneof->name()));
4056     format.Outdent();
4057     format("}\n");
4058   }
4059 
4060   format.Outdent();
4061   format(
4062       "  return true;\n"
4063       "}\n");
4064 }
4065 
4066 }  // namespace cpp
4067 }  // namespace compiler
4068 }  // namespace protobuf
4069 }  // namespace google
4070