• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: kenton@google.com (Kenton Varda)
9 //  Based on original Protocol Buffers design by
10 //  Sanjay Ghemawat, Jeff Dean, and others.
11 
12 #include "google/protobuf/wire_format.h"
13 
14 #include <algorithm>
15 #include <cstddef>
16 #include <cstdint>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/base/attributes.h"
22 #include "absl/log/absl_check.h"
23 #include "absl/log/absl_log.h"
24 #include "absl/strings/cord.h"
25 #include "google/protobuf/descriptor.h"
26 #include "google/protobuf/descriptor.pb.h"
27 #include "google/protobuf/dynamic_message.h"
28 #include "google/protobuf/io/coded_stream.h"
29 #include "google/protobuf/map_field.h"
30 #include "google/protobuf/message.h"
31 #include "google/protobuf/message_lite.h"
32 #include "google/protobuf/parse_context.h"
33 #include "google/protobuf/unknown_field_set.h"
34 #include "google/protobuf/wire_format_lite.h"
35 
36 
37 // Must be included last.
38 #include "google/protobuf/port_def.inc"
39 
// Byte-size constant with the fixed value 2.
// NOTE(review): judging by the name this is the combined size of the tags of
// a map entry; its users are outside the visible part of the file — confirm.
constexpr size_t kMapEntryTagByteSize = 2;
41 
42 namespace google {
43 namespace protobuf {
44 namespace internal {
45 
46 // Forward declaration of a file-local (static) helper, so that code placed
// ahead of its definition can call it.
// NOTE(review): the definition is outside the visible part of the file.
47 static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
48                                           const MapValueConstRef& value);
49 
50 // ===================================================================
51 
SkipField(io::CodedInputStream * input,uint32_t tag)52 bool UnknownFieldSetFieldSkipper::SkipField(io::CodedInputStream* input,
53                                             uint32_t tag) {
54   return WireFormat::SkipField(input, tag, unknown_fields_);
55 }
56 
SkipMessage(io::CodedInputStream * input)57 bool UnknownFieldSetFieldSkipper::SkipMessage(io::CodedInputStream* input) {
58   return WireFormat::SkipMessage(input, unknown_fields_);
59 }
60 
SkipUnknownEnum(int field_number,int value)61 void UnknownFieldSetFieldSkipper::SkipUnknownEnum(int field_number, int value) {
62   unknown_fields_->AddVarint(field_number, value);
63 }
64 
SkipField(io::CodedInputStream * input,uint32_t tag,UnknownFieldSet * unknown_fields)65 bool WireFormat::SkipField(io::CodedInputStream* input, uint32_t tag,
66                            UnknownFieldSet* unknown_fields) {
67   int number = WireFormatLite::GetTagFieldNumber(tag);
68   // Field number 0 is illegal.
69   if (number == 0) return false;
70 
71   switch (WireFormatLite::GetTagWireType(tag)) {
72     case WireFormatLite::WIRETYPE_VARINT: {
73       uint64_t value;
74       if (!input->ReadVarint64(&value)) return false;
75       if (unknown_fields != nullptr) unknown_fields->AddVarint(number, value);
76       return true;
77     }
78     case WireFormatLite::WIRETYPE_FIXED64: {
79       uint64_t value;
80       if (!input->ReadLittleEndian64(&value)) return false;
81       if (unknown_fields != nullptr) unknown_fields->AddFixed64(number, value);
82       return true;
83     }
84     case WireFormatLite::WIRETYPE_LENGTH_DELIMITED: {
85       uint32_t length;
86       if (!input->ReadVarint32(&length)) return false;
87       if (unknown_fields == nullptr) {
88         if (!input->Skip(length)) return false;
89       } else {
90         if (!input->ReadString(unknown_fields->AddLengthDelimited(number),
91                                length)) {
92           return false;
93         }
94       }
95       return true;
96     }
97     case WireFormatLite::WIRETYPE_START_GROUP: {
98       if (!input->IncrementRecursionDepth()) return false;
99       if (!SkipMessage(input, (unknown_fields == nullptr)
100                                   ? nullptr
101                                   : unknown_fields->AddGroup(number))) {
102         return false;
103       }
104       input->DecrementRecursionDepth();
105       // Check that the ending tag matched the starting tag.
106       if (!input->LastTagWas(
107               WireFormatLite::MakeTag(WireFormatLite::GetTagFieldNumber(tag),
108                                       WireFormatLite::WIRETYPE_END_GROUP))) {
109         return false;
110       }
111       return true;
112     }
113     case WireFormatLite::WIRETYPE_END_GROUP: {
114       return false;
115     }
116     case WireFormatLite::WIRETYPE_FIXED32: {
117       uint32_t value;
118       if (!input->ReadLittleEndian32(&value)) return false;
119       if (unknown_fields != nullptr) unknown_fields->AddFixed32(number, value);
120       return true;
121     }
122     default: {
123       return false;
124     }
125   }
126 }
127 
SkipMessage(io::CodedInputStream * input,UnknownFieldSet * unknown_fields)128 bool WireFormat::SkipMessage(io::CodedInputStream* input,
129                              UnknownFieldSet* unknown_fields) {
130   while (true) {
131     uint32_t tag = input->ReadTag();
132     if (tag == 0) {
133       // End of input.  This is a valid place to end, so return true.
134       return true;
135     }
136 
137     WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
138 
139     if (wire_type == WireFormatLite::WIRETYPE_END_GROUP) {
140       // Must be the end of the message.
141       return true;
142     }
143 
144     if (!SkipField(input, tag, unknown_fields)) return false;
145   }
146 }
147 
ReadPackedEnumPreserveUnknowns(io::CodedInputStream * input,uint32_t field_number,bool (* is_valid)(int),UnknownFieldSet * unknown_fields,RepeatedField<int> * values)148 bool WireFormat::ReadPackedEnumPreserveUnknowns(io::CodedInputStream* input,
149                                                 uint32_t field_number,
150                                                 bool (*is_valid)(int),
151                                                 UnknownFieldSet* unknown_fields,
152                                                 RepeatedField<int>* values) {
153   uint32_t length;
154   if (!input->ReadVarint32(&length)) return false;
155   io::CodedInputStream::Limit limit = input->PushLimit(length);
156   while (input->BytesUntilLimit() > 0) {
157     int value;
158     if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
159             input, &value)) {
160       return false;
161     }
162     if (is_valid == nullptr || is_valid(value)) {
163       values->Add(value);
164     } else {
165       unknown_fields->AddVarint(field_number, value);
166     }
167   }
168   input->PopLimit(limit);
169   return true;
170 }
171 
InternalSerializeUnknownFieldsToArray(const UnknownFieldSet & unknown_fields,uint8_t * target,io::EpsCopyOutputStream * stream)172 uint8_t* WireFormat::InternalSerializeUnknownFieldsToArray(
173     const UnknownFieldSet& unknown_fields, uint8_t* target,
174     io::EpsCopyOutputStream* stream) {
175   for (int i = 0; i < unknown_fields.field_count(); i++) {
176     const UnknownField& field = unknown_fields.field(i);
177 
178     target = stream->EnsureSpace(target);
179     switch (field.type()) {
180       case UnknownField::TYPE_VARINT:
181         target = WireFormatLite::WriteUInt64ToArray(field.number(),
182                                                     field.varint(), target);
183         break;
184       case UnknownField::TYPE_FIXED32:
185         target = WireFormatLite::WriteFixed32ToArray(field.number(),
186                                                      field.fixed32(), target);
187         break;
188       case UnknownField::TYPE_FIXED64:
189         target = WireFormatLite::WriteFixed64ToArray(field.number(),
190                                                      field.fixed64(), target);
191         break;
192       case UnknownField::TYPE_LENGTH_DELIMITED:
193         target = stream->WriteString(field.number(), field.length_delimited(),
194                                      target);
195         break;
196       case UnknownField::TYPE_GROUP:
197         target = WireFormatLite::WriteTagToArray(
198             field.number(), WireFormatLite::WIRETYPE_START_GROUP, target);
199         target = InternalSerializeUnknownFieldsToArray(field.group(), target,
200                                                        stream);
201         target = stream->EnsureSpace(target);
202         target = WireFormatLite::WriteTagToArray(
203             field.number(), WireFormatLite::WIRETYPE_END_GROUP, target);
204         break;
205     }
206   }
207   return target;
208 }
209 
InternalSerializeUnknownMessageSetItemsToArray(const UnknownFieldSet & unknown_fields,uint8_t * target,io::EpsCopyOutputStream * stream)210 uint8_t* WireFormat::InternalSerializeUnknownMessageSetItemsToArray(
211     const UnknownFieldSet& unknown_fields, uint8_t* target,
212     io::EpsCopyOutputStream* stream) {
213   for (int i = 0; i < unknown_fields.field_count(); i++) {
214     const UnknownField& field = unknown_fields.field(i);
215 
216     // The only unknown fields that are allowed to exist in a MessageSet are
217     // messages, which are length-delimited.
218     if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
219       target = stream->EnsureSpace(target);
220       // Start group.
221       target = io::CodedOutputStream::WriteTagToArray(
222           WireFormatLite::kMessageSetItemStartTag, target);
223 
224       // Write type ID.
225       target = io::CodedOutputStream::WriteTagToArray(
226           WireFormatLite::kMessageSetTypeIdTag, target);
227       target =
228           io::CodedOutputStream::WriteVarint32ToArray(field.number(), target);
229 
230       // Write message.
231       target = io::CodedOutputStream::WriteTagToArray(
232           WireFormatLite::kMessageSetMessageTag, target);
233 
234       target = field.InternalSerializeLengthDelimitedNoTag(target, stream);
235 
236       target = stream->EnsureSpace(target);
237       // End group.
238       target = io::CodedOutputStream::WriteTagToArray(
239           WireFormatLite::kMessageSetItemEndTag, target);
240     }
241   }
242 
243   return target;
244 }
245 
ComputeUnknownFieldsSize(const UnknownFieldSet & unknown_fields)246 size_t WireFormat::ComputeUnknownFieldsSize(
247     const UnknownFieldSet& unknown_fields) {
248   size_t size = 0;
249   for (int i = 0; i < unknown_fields.field_count(); i++) {
250     const UnknownField& field = unknown_fields.field(i);
251 
252     switch (field.type()) {
253       case UnknownField::TYPE_VARINT:
254         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
255             field.number(), WireFormatLite::WIRETYPE_VARINT));
256         size += io::CodedOutputStream::VarintSize64(field.varint());
257         break;
258       case UnknownField::TYPE_FIXED32:
259         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
260             field.number(), WireFormatLite::WIRETYPE_FIXED32));
261         size += sizeof(int32_t);
262         break;
263       case UnknownField::TYPE_FIXED64:
264         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
265             field.number(), WireFormatLite::WIRETYPE_FIXED64));
266         size += sizeof(int64_t);
267         break;
268       case UnknownField::TYPE_LENGTH_DELIMITED:
269         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
270             field.number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED));
271         size += io::CodedOutputStream::VarintSize32(
272             field.length_delimited().size());
273         size += field.length_delimited().size();
274         break;
275       case UnknownField::TYPE_GROUP:
276         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
277             field.number(), WireFormatLite::WIRETYPE_START_GROUP));
278         size += ComputeUnknownFieldsSize(field.group());
279         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
280             field.number(), WireFormatLite::WIRETYPE_END_GROUP));
281         break;
282     }
283   }
284 
285   return size;
286 }
287 
ComputeUnknownMessageSetItemsSize(const UnknownFieldSet & unknown_fields)288 size_t WireFormat::ComputeUnknownMessageSetItemsSize(
289     const UnknownFieldSet& unknown_fields) {
290   size_t size = 0;
291   for (int i = 0; i < unknown_fields.field_count(); i++) {
292     const UnknownField& field = unknown_fields.field(i);
293 
294     // The only unknown fields that are allowed to exist in a MessageSet are
295     // messages, which are length-delimited.
296     if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
297       size += WireFormatLite::kMessageSetItemTagsSize;
298       size += io::CodedOutputStream::VarintSize32(field.number());
299 
300       int field_size = field.GetLengthDelimitedSize();
301       size += io::CodedOutputStream::VarintSize32(field_size);
302       size += field_size;
303     }
304   }
305 
306   return size;
307 }
308 
309 // ===================================================================
310 
ParseAndMergePartial(io::CodedInputStream * input,Message * message)311 bool WireFormat::ParseAndMergePartial(io::CodedInputStream* input,
312                                       Message* message) {
313   const Descriptor* descriptor = message->GetDescriptor();
314   const Reflection* message_reflection = message->GetReflection();
315 
316   while (true) {
317     uint32_t tag = input->ReadTag();
318     if (tag == 0) {
319       // End of input.  This is a valid place to end, so return true.
320       return true;
321     }
322 
323     if (WireFormatLite::GetTagWireType(tag) ==
324         WireFormatLite::WIRETYPE_END_GROUP) {
325       // Must be the end of the message.
326       return true;
327     }
328 
329     const FieldDescriptor* field = nullptr;
330 
331     if (descriptor != nullptr) {
332       int field_number = WireFormatLite::GetTagFieldNumber(tag);
333       field = descriptor->FindFieldByNumber(field_number);
334 
335       // If that failed, check if the field is an extension.
336       if (field == nullptr && descriptor->IsExtensionNumber(field_number)) {
337         if (input->GetExtensionPool() == nullptr) {
338           field = message_reflection->FindKnownExtensionByNumber(field_number);
339         } else {
340           field = input->GetExtensionPool()->FindExtensionByNumber(
341               descriptor, field_number);
342         }
343       }
344 
345       // If that failed, but we're a MessageSet, and this is the tag for a
346       // MessageSet item, then parse that.
347       if (field == nullptr && descriptor->options().message_set_wire_format() &&
348           tag == WireFormatLite::kMessageSetItemStartTag) {
349         if (!ParseAndMergeMessageSetItem(input, message)) {
350           return false;
351         }
352         continue;  // Skip ParseAndMergeField(); already taken care of.
353       }
354     }
355 
356     if (!ParseAndMergeField(tag, field, message, input)) {
357       return false;
358     }
359   }
360 }
361 
SkipMessageSetField(io::CodedInputStream * input,uint32_t field_number,UnknownFieldSet * unknown_fields)362 bool WireFormat::SkipMessageSetField(io::CodedInputStream* input,
363                                      uint32_t field_number,
364                                      UnknownFieldSet* unknown_fields) {
365   uint32_t length;
366   if (!input->ReadVarint32(&length)) return false;
367   return input->ReadString(unknown_fields->AddLengthDelimited(field_number),
368                            length);
369 }
370 
ParseAndMergeMessageSetField(uint32_t field_number,const FieldDescriptor * field,Message * message,io::CodedInputStream * input)371 bool WireFormat::ParseAndMergeMessageSetField(uint32_t field_number,
372                                               const FieldDescriptor* field,
373                                               Message* message,
374                                               io::CodedInputStream* input) {
375   const Reflection* message_reflection = message->GetReflection();
376   if (field == nullptr) {
377     // We store unknown MessageSet extensions as groups.
378     return SkipMessageSetField(
379         input, field_number, message_reflection->MutableUnknownFields(message));
380   } else if (field->is_repeated() ||
381              field->type() != FieldDescriptor::TYPE_MESSAGE) {
382     // This shouldn't happen as we only allow optional message extensions to
383     // MessageSet.
384     ABSL_LOG(ERROR) << "Extensions of MessageSets must be optional messages.";
385     return false;
386   } else {
387     Message* sub_message = message_reflection->MutableMessage(
388         message, field, input->GetExtensionFactory());
389     return WireFormatLite::ReadMessage(input, sub_message);
390   }
391 }
392 
ParseAndMergeField(uint32_t tag,const FieldDescriptor * field,Message * message,io::CodedInputStream * input)393 bool WireFormat::ParseAndMergeField(
394     uint32_t tag,
395     const FieldDescriptor* field,  // May be nullptr for unknown
396     Message* message, io::CodedInputStream* input) {
397   const Reflection* message_reflection = message->GetReflection();
398 
399   enum { UNKNOWN, NORMAL_FORMAT, PACKED_FORMAT } value_format;
400 
401   if (field == nullptr) {
402     value_format = UNKNOWN;
403   } else if (WireFormatLite::GetTagWireType(tag) ==
404              WireTypeForFieldType(field->type())) {
405     value_format = NORMAL_FORMAT;
406   } else if (field->is_packable() &&
407              WireFormatLite::GetTagWireType(tag) ==
408                  WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
409     value_format = PACKED_FORMAT;
410   } else {
411     // We don't recognize this field. Either the field number is unknown
412     // or the wire type doesn't match. Put it in our unknown field set.
413     value_format = UNKNOWN;
414   }
415 
416   if (value_format == UNKNOWN) {
417     return SkipField(input, tag,
418                      message_reflection->MutableUnknownFields(message));
419   } else if (value_format == PACKED_FORMAT) {
420     uint32_t length;
421     if (!input->ReadVarint32(&length)) return false;
422     io::CodedInputStream::Limit limit = input->PushLimit(length);
423 
424     switch (field->type()) {
425 #define HANDLE_PACKED_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)                      \
426   case FieldDescriptor::TYPE_##TYPE: {                                         \
427     while (input->BytesUntilLimit() > 0) {                                     \
428       CPPTYPE value;                                                           \
429       if (!WireFormatLite::ReadPrimitive<CPPTYPE,                              \
430                                          WireFormatLite::TYPE_##TYPE>(input,   \
431                                                                       &value)) \
432         return false;                                                          \
433       message_reflection->Add##CPPTYPE_METHOD(message, field, value);          \
434     }                                                                          \
435     break;                                                                     \
436   }
437 
438       HANDLE_PACKED_TYPE(INT32, int32_t, Int32)
439       HANDLE_PACKED_TYPE(INT64, int64_t, Int64)
440       HANDLE_PACKED_TYPE(SINT32, int32_t, Int32)
441       HANDLE_PACKED_TYPE(SINT64, int64_t, Int64)
442       HANDLE_PACKED_TYPE(UINT32, uint32_t, UInt32)
443       HANDLE_PACKED_TYPE(UINT64, uint64_t, UInt64)
444 
445       HANDLE_PACKED_TYPE(FIXED32, uint32_t, UInt32)
446       HANDLE_PACKED_TYPE(FIXED64, uint64_t, UInt64)
447       HANDLE_PACKED_TYPE(SFIXED32, int32_t, Int32)
448       HANDLE_PACKED_TYPE(SFIXED64, int64_t, Int64)
449 
450       HANDLE_PACKED_TYPE(FLOAT, float, Float)
451       HANDLE_PACKED_TYPE(DOUBLE, double, Double)
452 
453       HANDLE_PACKED_TYPE(BOOL, bool, Bool)
454 #undef HANDLE_PACKED_TYPE
455 
456       case FieldDescriptor::TYPE_ENUM: {
457         while (input->BytesUntilLimit() > 0) {
458           int value;
459           if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
460                   input, &value))
461             return false;
462           if (!field->legacy_enum_field_treated_as_closed()) {
463             message_reflection->AddEnumValue(message, field, value);
464           } else {
465             const EnumValueDescriptor* enum_value =
466                 field->enum_type()->FindValueByNumber(value);
467             if (enum_value != nullptr) {
468               message_reflection->AddEnum(message, field, enum_value);
469             } else {
470               // The enum value is not one of the known values.  Add it to the
471               // UnknownFieldSet.
472               int64_t sign_extended_value = static_cast<int64_t>(value);
473               message_reflection->MutableUnknownFields(message)->AddVarint(
474                   WireFormatLite::GetTagFieldNumber(tag), sign_extended_value);
475             }
476           }
477         }
478 
479         break;
480       }
481 
482       case FieldDescriptor::TYPE_STRING:
483       case FieldDescriptor::TYPE_GROUP:
484       case FieldDescriptor::TYPE_MESSAGE:
485       case FieldDescriptor::TYPE_BYTES:
486         // Can't have packed fields of these types: these should be caught by
487         // the protocol compiler.
488         return false;
489         break;
490     }
491 
492     input->PopLimit(limit);
493   } else {
494     // Non-packed value (value_format == NORMAL_FORMAT)
495     switch (field->type()) {
496 #define HANDLE_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)                            \
497   case FieldDescriptor::TYPE_##TYPE: {                                        \
498     CPPTYPE value;                                                            \
499     if (!WireFormatLite::ReadPrimitive<CPPTYPE, WireFormatLite::TYPE_##TYPE>( \
500             input, &value))                                                   \
501       return false;                                                           \
502     if (field->is_repeated()) {                                               \
503       message_reflection->Add##CPPTYPE_METHOD(message, field, value);         \
504     } else {                                                                  \
505       message_reflection->Set##CPPTYPE_METHOD(message, field, value);         \
506     }                                                                         \
507     break;                                                                    \
508   }
509 
510       HANDLE_TYPE(INT32, int32_t, Int32)
511       HANDLE_TYPE(INT64, int64_t, Int64)
512       HANDLE_TYPE(SINT32, int32_t, Int32)
513       HANDLE_TYPE(SINT64, int64_t, Int64)
514       HANDLE_TYPE(UINT32, uint32_t, UInt32)
515       HANDLE_TYPE(UINT64, uint64_t, UInt64)
516 
517       HANDLE_TYPE(FIXED32, uint32_t, UInt32)
518       HANDLE_TYPE(FIXED64, uint64_t, UInt64)
519       HANDLE_TYPE(SFIXED32, int32_t, Int32)
520       HANDLE_TYPE(SFIXED64, int64_t, Int64)
521 
522       HANDLE_TYPE(FLOAT, float, Float)
523       HANDLE_TYPE(DOUBLE, double, Double)
524 
525       HANDLE_TYPE(BOOL, bool, Bool)
526 #undef HANDLE_TYPE
527 
528       case FieldDescriptor::TYPE_ENUM: {
529         int value;
530         if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
531                 input, &value))
532           return false;
533         if (field->is_repeated()) {
534           message_reflection->AddEnumValue(message, field, value);
535         } else {
536           message_reflection->SetEnumValue(message, field, value);
537         }
538         break;
539       }
540 
541       // Handle strings separately so that we can optimize the ctype=CORD case.
542       case FieldDescriptor::TYPE_STRING: {
543         bool strict_utf8_check = field->requires_utf8_validation();
544         std::string value;
545         if (!WireFormatLite::ReadString(input, &value)) return false;
546         if (strict_utf8_check) {
547           if (!WireFormatLite::VerifyUtf8String(value.data(), value.length(),
548                                                 WireFormatLite::PARSE,
549                                                 field->full_name())) {
550             return false;
551           }
552         } else {
553           VerifyUTF8StringNamedField(value.data(), value.length(), PARSE,
554                                      field->full_name());
555         }
556         if (field->is_repeated()) {
557           message_reflection->AddString(message, field, value);
558         } else {
559           message_reflection->SetString(message, field, value);
560         }
561         break;
562       }
563 
564       case FieldDescriptor::TYPE_BYTES: {
565         if (field->cpp_string_type() == FieldDescriptor::CppStringType::kCord) {
566           absl::Cord value;
567           if (!WireFormatLite::ReadBytes(input, &value)) return false;
568           message_reflection->SetString(message, field, value);
569           break;
570         }
571         std::string value;
572         if (!WireFormatLite::ReadBytes(input, &value)) return false;
573         if (field->is_repeated()) {
574           message_reflection->AddString(message, field, value);
575         } else {
576           message_reflection->SetString(message, field, value);
577         }
578         break;
579       }
580 
581       case FieldDescriptor::TYPE_GROUP: {
582         Message* sub_message;
583         if (field->is_repeated()) {
584           sub_message = message_reflection->AddMessage(
585               message, field, input->GetExtensionFactory());
586         } else {
587           sub_message = message_reflection->MutableMessage(
588               message, field, input->GetExtensionFactory());
589         }
590 
591         if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
592                                        input, sub_message))
593           return false;
594         break;
595       }
596 
597       case FieldDescriptor::TYPE_MESSAGE: {
598         Message* sub_message;
599         if (field->is_repeated()) {
600           sub_message = message_reflection->AddMessage(
601               message, field, input->GetExtensionFactory());
602         } else {
603           sub_message = message_reflection->MutableMessage(
604               message, field, input->GetExtensionFactory());
605         }
606 
607         if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
608         break;
609       }
610     }
611   }
612 
613   return true;
614 }
615 
ParseAndMergeMessageSetItem(io::CodedInputStream * input,Message * message)616 bool WireFormat::ParseAndMergeMessageSetItem(io::CodedInputStream* input,
617                                              Message* message) {
618   struct MSReflective {
619     bool ParseField(int type_id, io::CodedInputStream* input) {
620       const FieldDescriptor* field =
621           message_reflection->FindKnownExtensionByNumber(type_id);
622       return ParseAndMergeMessageSetField(type_id, field, message, input);
623     }
624 
625     bool SkipField(uint32_t tag, io::CodedInputStream* input) {
626       return WireFormat::SkipField(input, tag, nullptr);
627     }
628 
629     const Reflection* message_reflection;
630     Message* message;
631   };
632 
633   return ParseMessageSetItemImpl(
634       input, MSReflective{message->GetReflection(), message});
635 }
636 
637 struct WireFormat::MessageSetParser {
ParseElementgoogle::protobuf::internal::WireFormat::MessageSetParser638   const char* ParseElement(const char* ptr, internal::ParseContext* ctx) {
639     // Parse a MessageSetItem
640     auto metadata = reflection->MutableInternalMetadata(msg);
641     enum class State { kNoTag, kHasType, kHasPayload, kDone };
642     State state = State::kNoTag;
643 
644     std::string payload;
645     uint32_t type_id = 0;
646     while (!ctx->Done(&ptr)) {
647       // We use 64 bit tags in order to allow typeid's that span the whole
648       // range of 32 bit numbers.
649       uint32_t tag = static_cast<uint8_t>(*ptr++);
650       if (tag == WireFormatLite::kMessageSetTypeIdTag) {
651         uint64_t tmp;
652         ptr = ParseBigVarint(ptr, &tmp);
653         // We should fail parsing if type id is 0 after cast to uint32.
654         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr &&
655                                        static_cast<uint32_t>(tmp) != 0);
656         if (state == State::kNoTag) {
657           type_id = static_cast<uint32_t>(tmp);
658           state = State::kHasType;
659         } else if (state == State::kHasPayload) {
660           type_id = static_cast<uint32_t>(tmp);
661           const FieldDescriptor* field;
662           if (ctx->data().pool == nullptr) {
663             field = reflection->FindKnownExtensionByNumber(type_id);
664           } else {
665             field =
666                 ctx->data().pool->FindExtensionByNumber(descriptor, type_id);
667           }
668           if (field == nullptr || field->message_type() == nullptr) {
669             WriteLengthDelimited(
670                 type_id, payload,
671                 metadata->mutable_unknown_fields<UnknownFieldSet>());
672           } else {
673             Message* value =
674                 field->is_repeated()
675                     ? reflection->AddMessage(msg, field, ctx->data().factory)
676                     : reflection->MutableMessage(msg, field,
677                                                  ctx->data().factory);
678             const char* p;
679             // We can't use regular parse from string as we have to track
680             // proper recursion depth and descriptor pools. Spawn a new
681             // ParseContext inheriting those attributes.
682             ParseContext tmp_ctx(ParseContext::kSpawn, *ctx, &p, payload);
683             GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
684                                            tmp_ctx.EndedAtLimit());
685           }
686           state = State::kDone;
687         }
688         continue;
689       } else if (tag == WireFormatLite::kMessageSetMessageTag) {
690         if (state == State::kNoTag) {
691           int32_t size = ReadSize(&ptr);
692           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
693           ptr = ctx->ReadString(ptr, size, &payload);
694           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
695           state = State::kHasPayload;
696         } else if (state == State::kHasType) {
697           // We're now parsing the payload
698           const FieldDescriptor* field = nullptr;
699           if (descriptor->IsExtensionNumber(type_id)) {
700             if (ctx->data().pool == nullptr) {
701               field = reflection->FindKnownExtensionByNumber(type_id);
702             } else {
703               field =
704                   ctx->data().pool->FindExtensionByNumber(descriptor, type_id);
705             }
706           }
707           ptr = WireFormat::_InternalParseAndMergeField(
708               msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
709               field);
710           state = State::kDone;
711         } else {
712           int32_t size = ReadSize(&ptr);
713           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
714           ptr = ctx->Skip(ptr, size);
715           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
716         }
717       } else {
718         // An unknown field in MessageSetItem.
719         ptr = ReadTag(ptr - 1, &tag);
720         if (tag == 0 || (tag & 7) == WireFormatLite::WIRETYPE_END_GROUP) {
721           ctx->SetLastTag(tag);
722           return ptr;
723         }
724         // Skip field.
725         ptr = internal::UnknownFieldParse(
726             tag, static_cast<std::string*>(nullptr), ptr, ctx);
727       }
728       GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
729     }
730     return ptr;
731   }
732 
ParseMessageSetgoogle::protobuf::internal::WireFormat::MessageSetParser733   const char* ParseMessageSet(const char* ptr, internal::ParseContext* ctx) {
734     while (!ctx->Done(&ptr)) {
735       uint32_t tag;
736       ptr = ReadTag(ptr, &tag);
737       if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) return nullptr;
738       if (tag == 0 || (tag & 7) == WireFormatLite::WIRETYPE_END_GROUP) {
739         ctx->SetLastTag(tag);
740         break;
741       }
742       if (tag == WireFormatLite::kMessageSetItemStartTag) {
743         // A message set item starts
744         ptr = ctx->ParseGroupInlined(
745             ptr, tag, [&](const char* ptr) { return ParseElement(ptr, ctx); });
746       } else {
747         // Parse other fields as normal extensions.
748         int field_number = WireFormatLite::GetTagFieldNumber(tag);
749         const FieldDescriptor* field = nullptr;
750         if (descriptor->IsExtensionNumber(field_number)) {
751           if (ctx->data().pool == nullptr) {
752             field = reflection->FindKnownExtensionByNumber(field_number);
753           } else {
754             field = ctx->data().pool->FindExtensionByNumber(descriptor,
755                                                             field_number);
756           }
757         }
758         ptr = WireFormat::_InternalParseAndMergeField(msg, ptr, ctx, tag,
759                                                       reflection, field);
760       }
761       if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) return nullptr;
762     }
763     return ptr;
764   }
765 
766   Message* msg;
767   const Descriptor* descriptor;
768   const Reflection* reflection;
769 };
770 
HandleMessage(Message * msg,const char * ptr,internal::ParseContext * ctx,uint64_t tag,const Reflection * reflection,const FieldDescriptor * field)771 static const char* HandleMessage(Message* msg, const char* ptr,
772                                  internal::ParseContext* ctx, uint64_t tag,
773                                  const Reflection* reflection,
774                                  const FieldDescriptor* field) {
775   Message* sub_message;
776   if (field->is_repeated()) {
777     sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
778   } else {
779     sub_message = reflection->MutableMessage(msg, field, ctx->data().factory);
780   }
781 
782   if (WireFormatLite::GetTagWireType(tag) ==
783       WireFormatLite::WIRETYPE_START_GROUP) {
784     return ctx->ParseGroup(sub_message, ptr, tag);
785   } else {
786     ABSL_DCHECK(WireFormatLite::GetTagWireType(tag) ==
787                 WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
788   }
789 
790   ptr = ctx->ParseMessage(sub_message, ptr);
791 
792   // For map entries, if the value is an unknown enum we have to push it
793   // into the unknown field set and remove it from the list.
794   if (ptr != nullptr && field->is_map()) {
795     auto* value_field = field->message_type()->map_value();
796     auto* enum_type = value_field->enum_type();
797     if (enum_type != nullptr &&
798         !internal::cpp::HasPreservingUnknownEnumSemantics(value_field) &&
799         enum_type->FindValueByNumber(sub_message->GetReflection()->GetEnumValue(
800             *sub_message, value_field)) == nullptr) {
801       reflection->MutableUnknownFields(msg)->AddLengthDelimited(
802           field->number(), sub_message->SerializeAsString());
803       reflection->RemoveLast(msg, field);
804     }
805   }
806   return ptr;
807 }
808 
_InternalParse(Message * msg,const char * ptr,internal::ParseContext * ctx)809 const char* WireFormat::_InternalParse(Message* msg, const char* ptr,
810                                        internal::ParseContext* ctx) {
811   const Descriptor* descriptor = msg->GetDescriptor();
812   const Reflection* reflection = msg->GetReflection();
813   ABSL_DCHECK(descriptor);
814   ABSL_DCHECK(reflection);
815   if (descriptor->options().message_set_wire_format()) {
816     MessageSetParser message_set{msg, descriptor, reflection};
817     return message_set.ParseMessageSet(ptr, ctx);
818   }
819   while (!ctx->Done(&ptr)) {
820     uint32_t tag;
821     ptr = ReadTag(ptr, &tag);
822     if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) return nullptr;
823     if (tag == 0 || (tag & 7) == WireFormatLite::WIRETYPE_END_GROUP) {
824       ctx->SetLastTag(tag);
825       break;
826     }
827     const FieldDescriptor* field = nullptr;
828 
829     int field_number = WireFormatLite::GetTagFieldNumber(tag);
830     field = descriptor->FindFieldByNumber(field_number);
831 
832     // If that failed, check if the field is an extension.
833     if (field == nullptr && descriptor->IsExtensionNumber(field_number)) {
834       if (ctx->data().pool == nullptr) {
835         field = reflection->FindKnownExtensionByNumber(field_number);
836       } else {
837         field =
838             ctx->data().pool->FindExtensionByNumber(descriptor, field_number);
839       }
840     }
841 
842     ptr = _InternalParseAndMergeField(msg, ptr, ctx, tag, reflection, field);
843     if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) return nullptr;
844   }
845   return ptr;
846 }
847 
_InternalParseAndMergeField(Message * msg,const char * ptr,internal::ParseContext * ctx,uint64_t tag,const Reflection * reflection,const FieldDescriptor * field)848 const char* WireFormat::_InternalParseAndMergeField(
849     Message* msg, const char* ptr, internal::ParseContext* ctx, uint64_t tag,
850     const Reflection* reflection, const FieldDescriptor* field) {
851   if (field == nullptr) {
852     // unknown field set parser takes 64bit tags, because message set type ids
853     // span the full 32 bit range making the tag span [0, 2^35) range.
854     return internal::UnknownFieldParse(
855         tag, reflection->MutableUnknownFields(msg), ptr, ctx);
856   }
857   if (WireFormatLite::GetTagWireType(tag) !=
858       WireTypeForFieldType(field->type())) {
859     if (field->is_packable() && WireFormatLite::GetTagWireType(tag) ==
860                                     WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
861       switch (field->type()) {
862 #define HANDLE_PACKED_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)                   \
863   case FieldDescriptor::TYPE_##TYPE: {                                      \
864     ptr = internal::Packed##CPPTYPE_METHOD##Parser(                         \
865         reflection->MutableRepeatedFieldInternal<CPPTYPE>(msg, field), ptr, \
866         ctx);                                                               \
867     return ptr;                                                             \
868   }
869 
870         HANDLE_PACKED_TYPE(INT32, int32_t, Int32)
871         HANDLE_PACKED_TYPE(INT64, int64_t, Int64)
872         HANDLE_PACKED_TYPE(SINT32, int32_t, SInt32)
873         HANDLE_PACKED_TYPE(SINT64, int64_t, SInt64)
874         HANDLE_PACKED_TYPE(UINT32, uint32_t, UInt32)
875         HANDLE_PACKED_TYPE(UINT64, uint64_t, UInt64)
876 
877         HANDLE_PACKED_TYPE(FIXED32, uint32_t, Fixed32)
878         HANDLE_PACKED_TYPE(FIXED64, uint64_t, Fixed64)
879         HANDLE_PACKED_TYPE(SFIXED32, int32_t, SFixed32)
880         HANDLE_PACKED_TYPE(SFIXED64, int64_t, SFixed64)
881 
882         HANDLE_PACKED_TYPE(FLOAT, float, Float)
883         HANDLE_PACKED_TYPE(DOUBLE, double, Double)
884 
885         HANDLE_PACKED_TYPE(BOOL, bool, Bool)
886 #undef HANDLE_PACKED_TYPE
887 
888         case FieldDescriptor::TYPE_ENUM: {
889           auto rep_enum =
890               reflection->MutableRepeatedFieldInternal<int>(msg, field);
891           if (!field->legacy_enum_field_treated_as_closed()) {
892             ptr = internal::PackedEnumParser(rep_enum, ptr, ctx);
893           } else {
894             return ctx->ReadPackedVarint(
895                 ptr, [rep_enum, field, reflection, msg](int32_t val) {
896                   if (field->enum_type()->FindValueByNumber(val) != nullptr) {
897                     rep_enum->Add(val);
898                   } else {
899                     WriteVarint(field->number(), val,
900                                 reflection->MutableUnknownFields(msg));
901                   }
902                 });
903           }
904           return ptr;
905         }
906 
907         case FieldDescriptor::TYPE_STRING:
908         case FieldDescriptor::TYPE_GROUP:
909         case FieldDescriptor::TYPE_MESSAGE:
910         case FieldDescriptor::TYPE_BYTES:
911           ABSL_LOG(FATAL) << "Can't reach";
912           return nullptr;
913       }
914     } else {
915       // mismatched wiretype;
916       return internal::UnknownFieldParse(
917           tag, reflection->MutableUnknownFields(msg), ptr, ctx);
918     }
919   }
920 
921   // Non-packed value
922   bool utf8_check = false;
923   bool strict_utf8_check = false;
924   switch (field->type()) {
925 #define HANDLE_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)        \
926   case FieldDescriptor::TYPE_##TYPE: {                    \
927     CPPTYPE value;                                        \
928     ptr = VarintParse(ptr, &value);                       \
929     if (ptr == nullptr) return nullptr;                   \
930     if (field->is_repeated()) {                           \
931       reflection->Add##CPPTYPE_METHOD(msg, field, value); \
932     } else {                                              \
933       reflection->Set##CPPTYPE_METHOD(msg, field, value); \
934     }                                                     \
935     return ptr;                                           \
936   }
937 
938     HANDLE_TYPE(BOOL, uint64_t, Bool)
939     HANDLE_TYPE(INT32, uint32_t, Int32)
940     HANDLE_TYPE(INT64, uint64_t, Int64)
941     HANDLE_TYPE(UINT32, uint32_t, UInt32)
942     HANDLE_TYPE(UINT64, uint64_t, UInt64)
943 
944     case FieldDescriptor::TYPE_SINT32: {
945       int32_t value = ReadVarintZigZag32(&ptr);
946       if (ptr == nullptr) return nullptr;
947       if (field->is_repeated()) {
948         reflection->AddInt32(msg, field, value);
949       } else {
950         reflection->SetInt32(msg, field, value);
951       }
952       return ptr;
953     }
954     case FieldDescriptor::TYPE_SINT64: {
955       int64_t value = ReadVarintZigZag64(&ptr);
956       if (ptr == nullptr) return nullptr;
957       if (field->is_repeated()) {
958         reflection->AddInt64(msg, field, value);
959       } else {
960         reflection->SetInt64(msg, field, value);
961       }
962       return ptr;
963     }
964 #undef HANDLE_TYPE
965 #define HANDLE_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)        \
966   case FieldDescriptor::TYPE_##TYPE: {                    \
967     CPPTYPE value;                                        \
968     value = UnalignedLoad<CPPTYPE>(ptr);                  \
969     ptr += sizeof(CPPTYPE);                               \
970     if (field->is_repeated()) {                           \
971       reflection->Add##CPPTYPE_METHOD(msg, field, value); \
972     } else {                                              \
973       reflection->Set##CPPTYPE_METHOD(msg, field, value); \
974     }                                                     \
975     return ptr;                                           \
976   }
977 
978       HANDLE_TYPE(FIXED32, uint32_t, UInt32)
979       HANDLE_TYPE(FIXED64, uint64_t, UInt64)
980       HANDLE_TYPE(SFIXED32, int32_t, Int32)
981       HANDLE_TYPE(SFIXED64, int64_t, Int64)
982 
983       HANDLE_TYPE(FLOAT, float, Float)
984       HANDLE_TYPE(DOUBLE, double, Double)
985 
986 #undef HANDLE_TYPE
987 
988     case FieldDescriptor::TYPE_ENUM: {
989       uint32_t value;
990       ptr = VarintParse(ptr, &value);
991       if (ptr == nullptr) return nullptr;
992       if (field->is_repeated()) {
993         reflection->AddEnumValue(msg, field, value);
994       } else {
995         reflection->SetEnumValue(msg, field, value);
996       }
997       return ptr;
998     }
999 
1000     // Handle strings separately so that we can optimize the ctype=CORD case.
1001     case FieldDescriptor::TYPE_STRING:
1002       utf8_check = true;
1003       strict_utf8_check = field->requires_utf8_validation();
1004       ABSL_FALLTHROUGH_INTENDED;
1005     case FieldDescriptor::TYPE_BYTES: {
1006       int size = ReadSize(&ptr);
1007       if (ptr == nullptr) return nullptr;
1008       if (field->cpp_string_type() == FieldDescriptor::CppStringType::kCord) {
1009         absl::Cord value;
1010         ptr = ctx->ReadCord(ptr, size, &value);
1011         if (ptr == nullptr) return nullptr;
1012         reflection->SetString(msg, field, value);
1013         return ptr;
1014       }
1015       std::string value;
1016       ptr = ctx->ReadString(ptr, size, &value);
1017       if (ptr == nullptr) return nullptr;
1018       if (utf8_check) {
1019         if (strict_utf8_check) {
1020           if (!WireFormatLite::VerifyUtf8String(value.data(), value.length(),
1021                                                 WireFormatLite::PARSE,
1022                                                 field->full_name())) {
1023             return nullptr;
1024           }
1025         } else {
1026           VerifyUTF8StringNamedField(value.data(), value.length(), PARSE,
1027                                      field->full_name());
1028         }
1029       }
1030       if (field->is_repeated()) {
1031         reflection->AddString(msg, field, std::move(value));
1032       } else {
1033         reflection->SetString(msg, field, std::move(value));
1034       }
1035       return ptr;
1036     }
1037 
1038     case FieldDescriptor::TYPE_MESSAGE:
1039     case FieldDescriptor::TYPE_GROUP:
1040       return HandleMessage(msg, ptr, ctx, tag, reflection, field);
1041   }
1042 
1043   // GCC 8 complains about control reaching end of non-void function here.
1044   // Let's keep it happy by returning a nullptr.
1045   return nullptr;
1046 }
1047 
1048 // ===================================================================
1049 
_InternalSerialize(const Message & message,uint8_t * target,io::EpsCopyOutputStream * stream)1050 uint8_t* WireFormat::_InternalSerialize(const Message& message, uint8_t* target,
1051                                         io::EpsCopyOutputStream* stream) {
1052   const Descriptor* descriptor = message.GetDescriptor();
1053   const Reflection* message_reflection = message.GetReflection();
1054 
1055   std::vector<const FieldDescriptor*> fields;
1056 
1057   // Fields of map entry should always be serialized.
1058   if (descriptor->options().map_entry()) {
1059     for (int i = 0; i < descriptor->field_count(); i++) {
1060       fields.push_back(descriptor->field(i));
1061     }
1062   } else {
1063     message_reflection->ListFields(message, &fields);
1064   }
1065 
1066   for (auto field : fields) {
1067     target = InternalSerializeField(field, message, target, stream);
1068   }
1069 
1070   if (descriptor->options().message_set_wire_format()) {
1071     return InternalSerializeUnknownMessageSetItemsToArray(
1072         message_reflection->GetUnknownFields(message), target, stream);
1073   } else {
1074     return InternalSerializeUnknownFieldsToArray(
1075         message_reflection->GetUnknownFields(message), target, stream);
1076   }
1077 }
1078 
SerializeMapKeyWithCachedSizes(const FieldDescriptor * field,const MapKey & value,uint8_t * target,io::EpsCopyOutputStream * stream)1079 uint8_t* SerializeMapKeyWithCachedSizes(const FieldDescriptor* field,
1080                                         const MapKey& value, uint8_t* target,
1081                                         io::EpsCopyOutputStream* stream) {
1082   target = stream->EnsureSpace(target);
1083   switch (field->type()) {
1084     case FieldDescriptor::TYPE_DOUBLE:
1085     case FieldDescriptor::TYPE_FLOAT:
1086     case FieldDescriptor::TYPE_GROUP:
1087     case FieldDescriptor::TYPE_MESSAGE:
1088     case FieldDescriptor::TYPE_BYTES:
1089     case FieldDescriptor::TYPE_ENUM:
1090       ABSL_LOG(FATAL) << "Unsupported";
1091       break;
1092 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)   \
1093   case FieldDescriptor::TYPE_##FieldType:                    \
1094     target = WireFormatLite::Write##CamelFieldType##ToArray( \
1095         1, value.Get##CamelCppType##Value(), target);        \
1096     break;
1097       CASE_TYPE(INT64, Int64, Int64)
1098       CASE_TYPE(UINT64, UInt64, UInt64)
1099       CASE_TYPE(INT32, Int32, Int32)
1100       CASE_TYPE(FIXED64, Fixed64, UInt64)
1101       CASE_TYPE(FIXED32, Fixed32, UInt32)
1102       CASE_TYPE(BOOL, Bool, Bool)
1103       CASE_TYPE(UINT32, UInt32, UInt32)
1104       CASE_TYPE(SFIXED32, SFixed32, Int32)
1105       CASE_TYPE(SFIXED64, SFixed64, Int64)
1106       CASE_TYPE(SINT32, SInt32, Int32)
1107       CASE_TYPE(SINT64, SInt64, Int64)
1108 #undef CASE_TYPE
1109     case FieldDescriptor::TYPE_STRING:
1110       target = stream->WriteString(1, value.GetStringValue(), target);
1111       break;
1112   }
1113   return target;
1114 }
1115 
SerializeMapValueRefWithCachedSizes(const FieldDescriptor * field,const MapValueConstRef & value,uint8_t * target,io::EpsCopyOutputStream * stream)1116 static uint8_t* SerializeMapValueRefWithCachedSizes(
1117     const FieldDescriptor* field, const MapValueConstRef& value,
1118     uint8_t* target, io::EpsCopyOutputStream* stream) {
1119   target = stream->EnsureSpace(target);
1120   switch (field->type()) {
1121 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)   \
1122   case FieldDescriptor::TYPE_##FieldType:                    \
1123     target = WireFormatLite::Write##CamelFieldType##ToArray( \
1124         2, value.Get##CamelCppType##Value(), target);        \
1125     break;
1126     CASE_TYPE(INT64, Int64, Int64)
1127     CASE_TYPE(UINT64, UInt64, UInt64)
1128     CASE_TYPE(INT32, Int32, Int32)
1129     CASE_TYPE(FIXED64, Fixed64, UInt64)
1130     CASE_TYPE(FIXED32, Fixed32, UInt32)
1131     CASE_TYPE(BOOL, Bool, Bool)
1132     CASE_TYPE(UINT32, UInt32, UInt32)
1133     CASE_TYPE(SFIXED32, SFixed32, Int32)
1134     CASE_TYPE(SFIXED64, SFixed64, Int64)
1135     CASE_TYPE(SINT32, SInt32, Int32)
1136     CASE_TYPE(SINT64, SInt64, Int64)
1137     CASE_TYPE(ENUM, Enum, Enum)
1138     CASE_TYPE(DOUBLE, Double, Double)
1139     CASE_TYPE(FLOAT, Float, Float)
1140 #undef CASE_TYPE
1141     case FieldDescriptor::TYPE_STRING:
1142     case FieldDescriptor::TYPE_BYTES:
1143       target = stream->WriteString(2, value.GetStringValue(), target);
1144       break;
1145     case FieldDescriptor::TYPE_MESSAGE: {
1146       auto& msg = value.GetMessageValue();
1147       target = WireFormatLite::InternalWriteMessage(2, msg, msg.GetCachedSize(),
1148                                                     target, stream);
1149     } break;
1150     case FieldDescriptor::TYPE_GROUP:
1151       target = WireFormatLite::InternalWriteGroup(2, value.GetMessageValue(),
1152                                                   target, stream);
1153       break;
1154   }
1155   return target;
1156 }
1157 
1158 class MapKeySorter {
1159  public:
SortKey(const Message & message,const Reflection * reflection,const FieldDescriptor * field)1160   static std::vector<MapKey> SortKey(const Message& message,
1161                                      const Reflection* reflection,
1162                                      const FieldDescriptor* field) {
1163     std::vector<MapKey> sorted_key_list;
1164     for (MapIterator it =
1165              reflection->MapBegin(const_cast<Message*>(&message), field);
1166          it != reflection->MapEnd(const_cast<Message*>(&message), field);
1167          ++it) {
1168       sorted_key_list.push_back(it.GetKey());
1169     }
1170     MapKeyComparator comparator;
1171     std::sort(sorted_key_list.begin(), sorted_key_list.end(), comparator);
1172     return sorted_key_list;
1173   }
1174 
1175  private:
1176   class MapKeyComparator {
1177    public:
operator ()(const MapKey & a,const MapKey & b) const1178     bool operator()(const MapKey& a, const MapKey& b) const {
1179       ABSL_DCHECK(a.type() == b.type());
1180       switch (a.type()) {
1181 #define CASE_TYPE(CppType, CamelCppType)                                \
1182   case FieldDescriptor::CPPTYPE_##CppType: {                            \
1183     return a.Get##CamelCppType##Value() < b.Get##CamelCppType##Value(); \
1184   }
1185         CASE_TYPE(STRING, String)
1186         CASE_TYPE(INT64, Int64)
1187         CASE_TYPE(INT32, Int32)
1188         CASE_TYPE(UINT64, UInt64)
1189         CASE_TYPE(UINT32, UInt32)
1190         CASE_TYPE(BOOL, Bool)
1191 #undef CASE_TYPE
1192 
1193         default:
1194           ABSL_DLOG(FATAL) << "Invalid key for map field.";
1195           return true;
1196       }
1197     }
1198   };
1199 };
1200 
InternalSerializeMapEntry(const FieldDescriptor * field,const MapKey & key,const MapValueConstRef & value,uint8_t * target,io::EpsCopyOutputStream * stream)1201 static uint8_t* InternalSerializeMapEntry(const FieldDescriptor* field,
1202                                           const MapKey& key,
1203                                           const MapValueConstRef& value,
1204                                           uint8_t* target,
1205                                           io::EpsCopyOutputStream* stream) {
1206   const FieldDescriptor* key_field = field->message_type()->field(0);
1207   const FieldDescriptor* value_field = field->message_type()->field(1);
1208 
1209   size_t size = kMapEntryTagByteSize;
1210   size += MapKeyDataOnlyByteSize(key_field, key);
1211   size += MapValueRefDataOnlyByteSize(value_field, value);
1212   target = stream->EnsureSpace(target);
1213   target = WireFormatLite::WriteTagToArray(
1214       field->number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
1215   target = io::CodedOutputStream::WriteVarint32ToArray(size, target);
1216   target = SerializeMapKeyWithCachedSizes(key_field, key, target, stream);
1217   target =
1218       SerializeMapValueRefWithCachedSizes(value_field, value, target, stream);
1219   return target;
1220 }
1221 
InternalSerializeField(const FieldDescriptor * field,const Message & message,uint8_t * target,io::EpsCopyOutputStream * stream)1222 uint8_t* WireFormat::InternalSerializeField(const FieldDescriptor* field,
1223                                             const Message& message,
1224                                             uint8_t* target,
1225                                             io::EpsCopyOutputStream* stream) {
1226   const Reflection* message_reflection = message.GetReflection();
1227 
1228   if (field->is_extension() &&
1229       field->containing_type()->options().message_set_wire_format() &&
1230       field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
1231       !field->is_repeated()) {
1232     return InternalSerializeMessageSetItem(field, message, target, stream);
1233   }
1234 
1235 
1236   // For map fields, we can use either repeated field reflection or map
1237   // reflection.  Our choice has some subtle effects.  If we use repeated field
1238   // reflection here, then the repeated field representation becomes
1239   // authoritative for this field: any existing references that came from map
1240   // reflection remain valid for reading, but mutations to them are lost and
1241   // will be overwritten next time we call map reflection!
1242   //
1243   // So far this mainly affects Python, which keeps long-term references to map
1244   // values around, and always uses map reflection.  See: b/35918691
1245   //
1246   // Here we choose to use map reflection API as long as the internal
1247   // map is valid. In this way, the serialization doesn't change map field's
1248   // internal state and existing references that came from map reflection remain
1249   // valid for both reading and writing.
1250   if (field->is_map()) {
1251     const MapFieldBase* map_field =
1252         message_reflection->GetMapData(message, field);
1253     if (map_field->IsMapValid()) {
1254       if (stream->IsSerializationDeterministic()) {
1255         std::vector<MapKey> sorted_key_list =
1256             MapKeySorter::SortKey(message, message_reflection, field);
1257         for (std::vector<MapKey>::iterator it = sorted_key_list.begin();
1258              it != sorted_key_list.end(); ++it) {
1259           MapValueConstRef map_value;
1260           message_reflection->LookupMapValue(message, field, *it, &map_value);
1261           target =
1262               InternalSerializeMapEntry(field, *it, map_value, target, stream);
1263         }
1264       } else {
1265         for (MapIterator it = message_reflection->MapBegin(
1266                  const_cast<Message*>(&message), field);
1267              it !=
1268              message_reflection->MapEnd(const_cast<Message*>(&message), field);
1269              ++it) {
1270           target = InternalSerializeMapEntry(field, it.GetKey(),
1271                                              it.GetValueRef(), target, stream);
1272         }
1273       }
1274 
1275       return target;
1276     }
1277   }
1278   int count = 0;
1279 
1280   if (field->is_repeated()) {
1281     count = message_reflection->FieldSize(message, field);
1282   } else if (field->containing_type()->options().map_entry()) {
1283     // Map entry fields always need to be serialized.
1284     count = 1;
1285   } else if (message_reflection->HasField(message, field)) {
1286     count = 1;
1287   }
1288 
1289   // map_entries is for maps that'll be deterministically serialized.
1290   std::vector<const Message*> map_entries;
1291   if (count > 1 && field->is_map() && stream->IsSerializationDeterministic()) {
1292     map_entries =
1293         DynamicMapSorter::Sort(message, count, message_reflection, field);
1294   }
1295 
1296   if (field->is_packed()) {
1297     if (count == 0) return target;
1298     target = stream->EnsureSpace(target);
1299     switch (field->type()) {
1300 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)      \
1301   case FieldDescriptor::TYPE_##TYPE: {                                         \
1302     auto r =                                                                   \
1303         message_reflection->GetRepeatedFieldInternal<CPPTYPE>(message, field); \
1304     target = stream->Write##TYPE_METHOD##Packed(                               \
1305         field->number(), r, FieldDataOnlyByteSize(field, message), target);    \
1306     break;                                                                     \
1307   }
1308 
1309       HANDLE_PRIMITIVE_TYPE(INT32, int32_t, Int32, Int32)
1310       HANDLE_PRIMITIVE_TYPE(INT64, int64_t, Int64, Int64)
1311       HANDLE_PRIMITIVE_TYPE(SINT32, int32_t, SInt32, Int32)
1312       HANDLE_PRIMITIVE_TYPE(SINT64, int64_t, SInt64, Int64)
1313       HANDLE_PRIMITIVE_TYPE(UINT32, uint32_t, UInt32, UInt32)
1314       HANDLE_PRIMITIVE_TYPE(UINT64, uint64_t, UInt64, UInt64)
1315       HANDLE_PRIMITIVE_TYPE(ENUM, int, Enum, Enum)
1316 
1317 #undef HANDLE_PRIMITIVE_TYPE
1318 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)      \
1319   case FieldDescriptor::TYPE_##TYPE: {                                         \
1320     auto r =                                                                   \
1321         message_reflection->GetRepeatedFieldInternal<CPPTYPE>(message, field); \
1322     target = stream->WriteFixedPacked(field->number(), r, target);             \
1323     break;                                                                     \
1324   }
1325 
1326       HANDLE_PRIMITIVE_TYPE(FIXED32, uint32_t, Fixed32, UInt32)
1327       HANDLE_PRIMITIVE_TYPE(FIXED64, uint64_t, Fixed64, UInt64)
1328       HANDLE_PRIMITIVE_TYPE(SFIXED32, int32_t, SFixed32, Int32)
1329       HANDLE_PRIMITIVE_TYPE(SFIXED64, int64_t, SFixed64, Int64)
1330 
1331       HANDLE_PRIMITIVE_TYPE(FLOAT, float, Float, Float)
1332       HANDLE_PRIMITIVE_TYPE(DOUBLE, double, Double, Double)
1333 
1334       HANDLE_PRIMITIVE_TYPE(BOOL, bool, Bool, Bool)
1335 #undef HANDLE_PRIMITIVE_TYPE
1336       default:
1337         ABSL_LOG(FATAL) << "Invalid descriptor";
1338     }
1339     return target;
1340   }
1341 
1342   auto get_message_from_field = [&message, &map_entries, message_reflection](
1343                                     const FieldDescriptor* field, int j) {
1344     if (!field->is_repeated()) {
1345       return &message_reflection->GetMessage(message, field);
1346     }
1347     if (!map_entries.empty()) {
1348       return map_entries[j];
1349     }
1350     return &message_reflection->GetRepeatedMessage(message, field, j);
1351   };
1352   for (int j = 0; j < count; j++) {
1353     target = stream->EnsureSpace(target);
1354     switch (field->type()) {
1355 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)     \
1356   case FieldDescriptor::TYPE_##TYPE: {                                        \
1357     const CPPTYPE value =                                                     \
1358         field->is_repeated()                                                  \
1359             ? message_reflection->GetRepeated##CPPTYPE_METHOD(message, field, \
1360                                                               j)              \
1361             : message_reflection->Get##CPPTYPE_METHOD(message, field);        \
1362     target = WireFormatLite::Write##TYPE_METHOD##ToArray(field->number(),     \
1363                                                          value, target);      \
1364     break;                                                                    \
1365   }
1366 
1367       HANDLE_PRIMITIVE_TYPE(INT32, int32_t, Int32, Int32)
1368       HANDLE_PRIMITIVE_TYPE(INT64, int64_t, Int64, Int64)
1369       HANDLE_PRIMITIVE_TYPE(SINT32, int32_t, SInt32, Int32)
1370       HANDLE_PRIMITIVE_TYPE(SINT64, int64_t, SInt64, Int64)
1371       HANDLE_PRIMITIVE_TYPE(UINT32, uint32_t, UInt32, UInt32)
1372       HANDLE_PRIMITIVE_TYPE(UINT64, uint64_t, UInt64, UInt64)
1373 
1374       HANDLE_PRIMITIVE_TYPE(FIXED32, uint32_t, Fixed32, UInt32)
1375       HANDLE_PRIMITIVE_TYPE(FIXED64, uint64_t, Fixed64, UInt64)
1376       HANDLE_PRIMITIVE_TYPE(SFIXED32, int32_t, SFixed32, Int32)
1377       HANDLE_PRIMITIVE_TYPE(SFIXED64, int64_t, SFixed64, Int64)
1378 
1379       HANDLE_PRIMITIVE_TYPE(FLOAT, float, Float, Float)
1380       HANDLE_PRIMITIVE_TYPE(DOUBLE, double, Double, Double)
1381 
1382       HANDLE_PRIMITIVE_TYPE(BOOL, bool, Bool, Bool)
1383 #undef HANDLE_PRIMITIVE_TYPE
1384 
1385       case FieldDescriptor::TYPE_GROUP: {
1386         auto* msg = get_message_from_field(field, j);
1387         target = WireFormatLite::InternalWriteGroup(field->number(), *msg,
1388                                                     target, stream);
1389       } break;
1390 
1391       case FieldDescriptor::TYPE_MESSAGE: {
1392         auto* msg = get_message_from_field(field, j);
1393         target = WireFormatLite::InternalWriteMessage(
1394             field->number(), *msg, msg->GetCachedSize(), target, stream);
1395       } break;
1396 
1397       case FieldDescriptor::TYPE_ENUM: {
1398         const EnumValueDescriptor* value =
1399             field->is_repeated()
1400                 ? message_reflection->GetRepeatedEnum(message, field, j)
1401                 : message_reflection->GetEnum(message, field);
1402         target = WireFormatLite::WriteEnumToArray(field->number(),
1403                                                   value->number(), target);
1404         break;
1405       }
1406 
1407       // Handle strings separately so that we can get string references
1408       // instead of copying.
1409       case FieldDescriptor::TYPE_STRING: {
1410         bool strict_utf8_check = field->requires_utf8_validation();
1411         std::string scratch;
1412         const std::string& value =
1413             field->is_repeated()
1414                 ? message_reflection->GetRepeatedStringReference(message, field,
1415                                                                  j, &scratch)
1416                 : message_reflection->GetStringReference(message, field,
1417                                                          &scratch);
1418         if (strict_utf8_check) {
1419           WireFormatLite::VerifyUtf8String(value.data(), value.length(),
1420                                            WireFormatLite::SERIALIZE,
1421                                            field->full_name());
1422         } else {
1423           VerifyUTF8StringNamedField(value.data(), value.length(), SERIALIZE,
1424                                      field->full_name());
1425         }
1426         target = stream->WriteString(field->number(), value, target);
1427         break;
1428       }
1429 
1430       case FieldDescriptor::TYPE_BYTES: {
1431         if (field->cpp_string_type() == FieldDescriptor::CppStringType::kCord) {
1432           absl::Cord value = message_reflection->GetCord(message, field);
1433           target = stream->WriteString(field->number(), value, target);
1434           break;
1435         }
1436         std::string scratch;
1437         const std::string& value =
1438             field->is_repeated()
1439                 ? message_reflection->GetRepeatedStringReference(message, field,
1440                                                                  j, &scratch)
1441                 : message_reflection->GetStringReference(message, field,
1442                                                          &scratch);
1443         target = stream->WriteString(field->number(), value, target);
1444         break;
1445       }
1446     }
1447   }
1448   return target;
1449 }
1450 
InternalSerializeMessageSetItem(const FieldDescriptor * field,const Message & message,uint8_t * target,io::EpsCopyOutputStream * stream)1451 uint8_t* WireFormat::InternalSerializeMessageSetItem(
1452     const FieldDescriptor* field, const Message& message, uint8_t* target,
1453     io::EpsCopyOutputStream* stream) {
1454   const Reflection* message_reflection = message.GetReflection();
1455 
1456   target = stream->EnsureSpace(target);
1457   // Start group.
1458   target = io::CodedOutputStream::WriteTagToArray(
1459       WireFormatLite::kMessageSetItemStartTag, target);
1460   // Write type ID.
1461   target = WireFormatLite::WriteUInt32ToArray(
1462       WireFormatLite::kMessageSetTypeIdNumber, field->number(), target);
1463     // Write message.
1464     auto& msg = message_reflection->GetMessage(message, field);
1465     target = WireFormatLite::InternalWriteMessage(
1466         WireFormatLite::kMessageSetMessageNumber, msg, msg.GetCachedSize(),
1467         target, stream);
1468   // End group.
1469   target = stream->EnsureSpace(target);
1470   target = io::CodedOutputStream::WriteTagToArray(
1471       WireFormatLite::kMessageSetItemEndTag, target);
1472   return target;
1473 }
1474 
1475 // ===================================================================
1476 
ByteSize(const Message & message)1477 size_t WireFormat::ByteSize(const Message& message) {
1478   const Descriptor* descriptor = message.GetDescriptor();
1479   const Reflection* message_reflection = message.GetReflection();
1480 
1481   size_t our_size = 0;
1482 
1483   std::vector<const FieldDescriptor*> fields;
1484 
1485   // Fields of map entry should always be serialized.
1486   if (descriptor->options().map_entry()) {
1487     for (int i = 0; i < descriptor->field_count(); i++) {
1488       fields.push_back(descriptor->field(i));
1489     }
1490   } else {
1491     message_reflection->ListFields(message, &fields);
1492   }
1493 
1494   for (const FieldDescriptor* field : fields) {
1495     our_size += FieldByteSize(field, message);
1496   }
1497 
1498   if (descriptor->options().message_set_wire_format()) {
1499     our_size += ComputeUnknownMessageSetItemsSize(
1500         message_reflection->GetUnknownFields(message));
1501   } else {
1502     our_size +=
1503         ComputeUnknownFieldsSize(message_reflection->GetUnknownFields(message));
1504   }
1505 
1506   return our_size;
1507 }
1508 
FieldByteSize(const FieldDescriptor * field,const Message & message)1509 size_t WireFormat::FieldByteSize(const FieldDescriptor* field,
1510                                  const Message& message) {
1511   const Reflection* message_reflection = message.GetReflection();
1512 
1513   if (field->is_extension() &&
1514       field->containing_type()->options().message_set_wire_format() &&
1515       field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
1516       !field->is_repeated()) {
1517     return MessageSetItemByteSize(field, message);
1518   }
1519 
1520   size_t count = 0;
1521   if (field->is_repeated()) {
1522     if (field->is_map()) {
1523       const MapFieldBase* map_field =
1524           message_reflection->GetMapData(message, field);
1525       if (map_field->IsMapValid()) {
1526         count = FromIntSize(map_field->size());
1527       } else {
1528         count = FromIntSize(message_reflection->FieldSize(message, field));
1529       }
1530     } else {
1531       count = FromIntSize(message_reflection->FieldSize(message, field));
1532     }
1533   } else if (field->containing_type()->options().map_entry()) {
1534     // Map entry fields always need to be serialized.
1535     count = 1;
1536   } else if (message_reflection->HasField(message, field)) {
1537     count = 1;
1538   }
1539 
1540   const size_t data_size = FieldDataOnlyByteSize(field, message);
1541   size_t our_size = data_size;
1542   if (field->is_packed()) {
1543     if (data_size > 0) {
1544       // Packed fields get serialized like a string, not their native type.
1545       // Technically this doesn't really matter; the size only changes if it's
1546       // a GROUP
1547       our_size += TagSize(field->number(), FieldDescriptor::TYPE_STRING);
1548       our_size += io::CodedOutputStream::VarintSize32(data_size);
1549     }
1550   } else {
1551     our_size += count * TagSize(field->number(), field->type());
1552   }
1553   return our_size;
1554 }
1555 
MapKeyDataOnlyByteSize(const FieldDescriptor * field,const MapKey & value)1556 size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field,
1557                               const MapKey& value) {
1558   ABSL_DCHECK_EQ(FieldDescriptor::TypeToCppType(field->type()), value.type());
1559   switch (field->type()) {
1560     case FieldDescriptor::TYPE_DOUBLE:
1561     case FieldDescriptor::TYPE_FLOAT:
1562     case FieldDescriptor::TYPE_GROUP:
1563     case FieldDescriptor::TYPE_MESSAGE:
1564     case FieldDescriptor::TYPE_BYTES:
1565     case FieldDescriptor::TYPE_ENUM:
1566       ABSL_LOG(FATAL) << "Unsupported";
1567       return 0;
1568 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
1569   case FieldDescriptor::TYPE_##FieldType:                  \
1570     return WireFormatLite::CamelFieldType##Size(           \
1571         value.Get##CamelCppType##Value());
1572 
1573 #define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
1574   case FieldDescriptor::TYPE_##FieldType:          \
1575     return WireFormatLite::k##CamelFieldType##Size;
1576 
1577       CASE_TYPE(INT32, Int32, Int32);
1578       CASE_TYPE(INT64, Int64, Int64);
1579       CASE_TYPE(UINT32, UInt32, UInt32);
1580       CASE_TYPE(UINT64, UInt64, UInt64);
1581       CASE_TYPE(SINT32, SInt32, Int32);
1582       CASE_TYPE(SINT64, SInt64, Int64);
1583       CASE_TYPE(STRING, String, String);
1584       FIXED_CASE_TYPE(FIXED32, Fixed32);
1585       FIXED_CASE_TYPE(FIXED64, Fixed64);
1586       FIXED_CASE_TYPE(SFIXED32, SFixed32);
1587       FIXED_CASE_TYPE(SFIXED64, SFixed64);
1588       FIXED_CASE_TYPE(BOOL, Bool);
1589 
1590 #undef CASE_TYPE
1591 #undef FIXED_CASE_TYPE
1592   }
1593   ABSL_LOG(FATAL) << "Cannot get here";
1594   return 0;
1595 }
1596 
MapValueRefDataOnlyByteSize(const FieldDescriptor * field,const MapValueConstRef & value)1597 static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
1598                                           const MapValueConstRef& value) {
1599   switch (field->type()) {
1600     case FieldDescriptor::TYPE_GROUP:
1601       ABSL_LOG(FATAL) << "Unsupported";
1602       return 0;
1603 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
1604   case FieldDescriptor::TYPE_##FieldType:                  \
1605     return WireFormatLite::CamelFieldType##Size(           \
1606         value.Get##CamelCppType##Value());
1607 
1608 #define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
1609   case FieldDescriptor::TYPE_##FieldType:          \
1610     return WireFormatLite::k##CamelFieldType##Size;
1611 
1612       CASE_TYPE(INT32, Int32, Int32);
1613       CASE_TYPE(INT64, Int64, Int64);
1614       CASE_TYPE(UINT32, UInt32, UInt32);
1615       CASE_TYPE(UINT64, UInt64, UInt64);
1616       CASE_TYPE(SINT32, SInt32, Int32);
1617       CASE_TYPE(SINT64, SInt64, Int64);
1618       CASE_TYPE(STRING, String, String);
1619       CASE_TYPE(BYTES, Bytes, String);
1620       CASE_TYPE(ENUM, Enum, Enum);
1621       CASE_TYPE(MESSAGE, Message, Message);
1622       FIXED_CASE_TYPE(FIXED32, Fixed32);
1623       FIXED_CASE_TYPE(FIXED64, Fixed64);
1624       FIXED_CASE_TYPE(SFIXED32, SFixed32);
1625       FIXED_CASE_TYPE(SFIXED64, SFixed64);
1626       FIXED_CASE_TYPE(DOUBLE, Double);
1627       FIXED_CASE_TYPE(FLOAT, Float);
1628       FIXED_CASE_TYPE(BOOL, Bool);
1629 
1630 #undef CASE_TYPE
1631 #undef FIXED_CASE_TYPE
1632   }
1633   ABSL_LOG(FATAL) << "Cannot get here";
1634   return 0;
1635 }
1636 
FieldDataOnlyByteSize(const FieldDescriptor * field,const Message & message)1637 size_t WireFormat::FieldDataOnlyByteSize(const FieldDescriptor* field,
1638                                          const Message& message) {
1639   const Reflection* message_reflection = message.GetReflection();
1640 
1641   size_t data_size = 0;
1642 
1643   if (field->is_map()) {
1644     const MapFieldBase* map_field =
1645         message_reflection->GetMapData(message, field);
1646     if (map_field->IsMapValid()) {
1647       MapIterator iter(const_cast<Message*>(&message), field);
1648       MapIterator end(const_cast<Message*>(&message), field);
1649       const FieldDescriptor* key_field = field->message_type()->field(0);
1650       const FieldDescriptor* value_field = field->message_type()->field(1);
1651       for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end;
1652            ++iter) {
1653         size_t size = kMapEntryTagByteSize;
1654         size += MapKeyDataOnlyByteSize(key_field, iter.GetKey());
1655         size += MapValueRefDataOnlyByteSize(value_field, iter.GetValueRef());
1656         data_size += WireFormatLite::LengthDelimitedSize(size);
1657       }
1658       return data_size;
1659     }
1660   }
1661 
1662   size_t count = 0;
1663   if (field->is_repeated()) {
1664     count =
1665         internal::FromIntSize(message_reflection->FieldSize(message, field));
1666   } else if (field->containing_type()->options().map_entry()) {
1667     // Map entry fields always need to be serialized.
1668     count = 1;
1669   } else if (message_reflection->HasField(message, field)) {
1670     count = 1;
1671   }
1672 
1673   switch (field->type()) {
1674 #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD)                      \
1675   case FieldDescriptor::TYPE_##TYPE:                                        \
1676     if (field->is_repeated()) {                                             \
1677       for (size_t j = 0; j < count; j++) {                                  \
1678         data_size += WireFormatLite::TYPE_METHOD##Size(                     \
1679             message_reflection->GetRepeated##CPPTYPE_METHOD(message, field, \
1680                                                             j));            \
1681       }                                                                     \
1682     } else {                                                                \
1683       data_size += WireFormatLite::TYPE_METHOD##Size(                       \
1684           message_reflection->Get##CPPTYPE_METHOD(message, field));         \
1685     }                                                                       \
1686     break;
1687 
1688 #define HANDLE_FIXED_TYPE(TYPE, TYPE_METHOD)                   \
1689   case FieldDescriptor::TYPE_##TYPE:                           \
1690     data_size += count * WireFormatLite::k##TYPE_METHOD##Size; \
1691     break;
1692 
1693     HANDLE_TYPE(INT32, Int32, Int32)
1694     HANDLE_TYPE(INT64, Int64, Int64)
1695     HANDLE_TYPE(SINT32, SInt32, Int32)
1696     HANDLE_TYPE(SINT64, SInt64, Int64)
1697     HANDLE_TYPE(UINT32, UInt32, UInt32)
1698     HANDLE_TYPE(UINT64, UInt64, UInt64)
1699 
1700     HANDLE_FIXED_TYPE(FIXED32, Fixed32)
1701     HANDLE_FIXED_TYPE(FIXED64, Fixed64)
1702     HANDLE_FIXED_TYPE(SFIXED32, SFixed32)
1703     HANDLE_FIXED_TYPE(SFIXED64, SFixed64)
1704 
1705     HANDLE_FIXED_TYPE(FLOAT, Float)
1706     HANDLE_FIXED_TYPE(DOUBLE, Double)
1707 
1708     HANDLE_FIXED_TYPE(BOOL, Bool)
1709 
1710     HANDLE_TYPE(GROUP, Group, Message)
1711 
1712     case FieldDescriptor::TYPE_MESSAGE: {
1713       if (field->is_repeated()) {
1714         for (size_t j = 0; j < count; ++j) {
1715           data_size += WireFormatLite::MessageSize(
1716               message_reflection->GetRepeatedMessage(message, field, j));
1717         }
1718         break;
1719       }
1720       if (field->is_extension()) {
1721         data_size += WireFormatLite::LengthDelimitedSize(
1722             message_reflection->GetExtensionSet(message).GetMessageByteSizeLong(
1723                 field->number()));
1724         break;
1725       }
1726       data_size += WireFormatLite::MessageSize(
1727           message_reflection->GetMessage(message, field));
1728       break;
1729     }
1730 
1731 #undef HANDLE_TYPE
1732 #undef HANDLE_FIXED_TYPE
1733 
1734     case FieldDescriptor::TYPE_ENUM: {
1735       if (field->is_repeated()) {
1736         for (size_t j = 0; j < count; j++) {
1737           data_size += WireFormatLite::EnumSize(
1738               message_reflection->GetRepeatedEnum(message, field, j)->number());
1739         }
1740       } else {
1741         data_size += WireFormatLite::EnumSize(
1742             message_reflection->GetEnum(message, field)->number());
1743       }
1744       break;
1745     }
1746 
1747     // Handle strings separately so that we can get string references
1748     // instead of copying.
1749     case FieldDescriptor::TYPE_STRING:
1750     case FieldDescriptor::TYPE_BYTES: {
1751       if (field->cpp_string_type() == FieldDescriptor::CppStringType::kCord) {
1752         for (size_t j = 0; j < count; j++) {
1753           absl::Cord value = message_reflection->GetCord(message, field);
1754           data_size += WireFormatLite::StringSize(value);
1755         }
1756         break;
1757       }
1758       for (size_t j = 0; j < count; j++) {
1759         std::string scratch;
1760         const std::string& value =
1761             field->is_repeated()
1762                 ? message_reflection->GetRepeatedStringReference(message, field,
1763                                                                  j, &scratch)
1764                 : message_reflection->GetStringReference(message, field,
1765                                                          &scratch);
1766         data_size += WireFormatLite::StringSize(value);
1767       }
1768       break;
1769     }
1770   }
1771   return data_size;
1772 }
1773 
MessageSetItemByteSize(const FieldDescriptor * field,const Message & message)1774 size_t WireFormat::MessageSetItemByteSize(const FieldDescriptor* field,
1775                                           const Message& message) {
1776   const Reflection* message_reflection = message.GetReflection();
1777 
1778   size_t our_size = WireFormatLite::kMessageSetItemTagsSize;
1779 
1780   // type_id
1781   our_size += io::CodedOutputStream::VarintSize32(field->number());
1782 
1783   // message
1784   size_t message_size;
1785     const Message& sub_message = message_reflection->GetMessage(message, field);
1786     message_size = sub_message.ByteSizeLong();
1787 
1788   our_size += io::CodedOutputStream::VarintSize32(message_size);
1789   our_size += message_size;
1790 
1791   return our_size;
1792 }
1793 
1794 // Compute the size of the UnknownFieldSet on the wire.
ComputeUnknownFieldsSize(const InternalMetadata & metadata,size_t total_size,CachedSize * cached_size)1795 size_t ComputeUnknownFieldsSize(const InternalMetadata& metadata,
1796                                 size_t total_size, CachedSize* cached_size) {
1797   total_size += WireFormat::ComputeUnknownFieldsSize(
1798       metadata.unknown_fields<UnknownFieldSet>(
1799           UnknownFieldSet::default_instance));
1800   cached_size->Set(ToCachedSize(total_size));
1801   return total_size;
1802 }
1803 
1804 }  // namespace internal
1805 }  // namespace protobuf
1806 }  // namespace google
1807 
1808 #include "google/protobuf/port_undef.inc"
1809