• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/extension_set.h>
36 
37 #include <tuple>
38 #include <unordered_map>
39 #include <utility>
40 #include <google/protobuf/stubs/common.h>
41 #include <google/protobuf/extension_set_inl.h>
42 #include <google/protobuf/parse_context.h>
43 #include <google/protobuf/io/coded_stream.h>
44 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
45 #include <google/protobuf/message_lite.h>
46 #include <google/protobuf/metadata_lite.h>
47 #include <google/protobuf/repeated_field.h>
48 #include <google/protobuf/stubs/map_util.h>
49 #include <google/protobuf/stubs/hash.h>
50 
51 #include <google/protobuf/port_def.inc>
52 
53 namespace google {
54 namespace protobuf {
55 namespace internal {
56 
57 namespace {
58 
real_type(FieldType type)59 inline WireFormatLite::FieldType real_type(FieldType type) {
60   GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
61   return static_cast<WireFormatLite::FieldType>(type);
62 }
63 
cpp_type(FieldType type)64 inline WireFormatLite::CppType cpp_type(FieldType type) {
65   return WireFormatLite::FieldTypeToCppType(real_type(type));
66 }
67 
is_packable(WireFormatLite::WireType type)68 inline bool is_packable(WireFormatLite::WireType type) {
69   switch (type) {
70     case WireFormatLite::WIRETYPE_VARINT:
71     case WireFormatLite::WIRETYPE_FIXED64:
72     case WireFormatLite::WIRETYPE_FIXED32:
73       return true;
74     case WireFormatLite::WIRETYPE_LENGTH_DELIMITED:
75     case WireFormatLite::WIRETYPE_START_GROUP:
76     case WireFormatLite::WIRETYPE_END_GROUP:
77       return false;
78 
79       // Do not add a default statement. Let the compiler complain when someone
80       // adds a new wire type.
81   }
82   GOOGLE_LOG(FATAL) << "can't reach here.";
83   return false;
84 }
85 
86 // Registry stuff.
87 struct ExtensionHasher {
operator ()google::protobuf::internal::__anon812ce0970111::ExtensionHasher88   std::size_t operator()(const std::pair<const MessageLite*, int>& p) const {
89     return std::hash<const MessageLite*>{}(p.first) ^
90            std::hash<int>{}(p.second);
91   }
92 };
93 
94 typedef std::unordered_map<std::pair<const MessageLite*, int>, ExtensionInfo,
95                            ExtensionHasher>
96     ExtensionRegistry;
97 
98 static const ExtensionRegistry* global_registry = nullptr;
99 
100 // This function is only called at startup, so there is no need for thread-
101 // safety.
Register(const MessageLite * containing_type,int number,ExtensionInfo info)102 void Register(const MessageLite* containing_type, int number,
103               ExtensionInfo info) {
104   static auto local_static_registry = OnShutdownDelete(new ExtensionRegistry);
105   global_registry = local_static_registry;
106   if (!InsertIfNotPresent(local_static_registry,
107                                std::make_pair(containing_type, number), info)) {
108     GOOGLE_LOG(FATAL) << "Multiple extension registrations for type \""
109                << containing_type->GetTypeName() << "\", field number "
110                << number << ".";
111   }
112 }
113 
FindRegisteredExtension(const MessageLite * containing_type,int number)114 const ExtensionInfo* FindRegisteredExtension(const MessageLite* containing_type,
115                                              int number) {
116   return global_registry == nullptr
117              ? nullptr
118              : FindOrNull(*global_registry,
119                                std::make_pair(containing_type, number));
120 }
121 
122 }  // namespace
123 
~ExtensionFinder()124 ExtensionFinder::~ExtensionFinder() {}
125 
Find(int number,ExtensionInfo * output)126 bool GeneratedExtensionFinder::Find(int number, ExtensionInfo* output) {
127   const ExtensionInfo* extension =
128       FindRegisteredExtension(containing_type_, number);
129   if (extension == NULL) {
130     return false;
131   } else {
132     *output = *extension;
133     return true;
134   }
135 }
136 
RegisterExtension(const MessageLite * containing_type,int number,FieldType type,bool is_repeated,bool is_packed)137 void ExtensionSet::RegisterExtension(const MessageLite* containing_type,
138                                      int number, FieldType type,
139                                      bool is_repeated, bool is_packed) {
140   GOOGLE_CHECK_NE(type, WireFormatLite::TYPE_ENUM);
141   GOOGLE_CHECK_NE(type, WireFormatLite::TYPE_MESSAGE);
142   GOOGLE_CHECK_NE(type, WireFormatLite::TYPE_GROUP);
143   ExtensionInfo info(type, is_repeated, is_packed);
144   Register(containing_type, number, info);
145 }
146 
CallNoArgValidityFunc(const void * arg,int number)147 static bool CallNoArgValidityFunc(const void* arg, int number) {
148   // Note:  Must use C-style cast here rather than reinterpret_cast because
149   //   the C++ standard at one point did not allow casts between function and
150   //   data pointers and some compilers enforce this for C++-style casts.  No
151   //   compiler enforces it for C-style casts since lots of C-style code has
152   //   relied on these kinds of casts for a long time, despite being
153   //   technically undefined.  See:
154   //     http://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#195
155   // Also note:  Some compilers do not allow function pointers to be "const".
156   //   Which makes sense, I suppose, because it's meaningless.
157   return ((EnumValidityFunc*)arg)(number);
158 }
159 
RegisterEnumExtension(const MessageLite * containing_type,int number,FieldType type,bool is_repeated,bool is_packed,EnumValidityFunc * is_valid)160 void ExtensionSet::RegisterEnumExtension(const MessageLite* containing_type,
161                                          int number, FieldType type,
162                                          bool is_repeated, bool is_packed,
163                                          EnumValidityFunc* is_valid) {
164   GOOGLE_CHECK_EQ(type, WireFormatLite::TYPE_ENUM);
165   ExtensionInfo info(type, is_repeated, is_packed);
166   info.enum_validity_check.func = CallNoArgValidityFunc;
167   // See comment in CallNoArgValidityFunc() about why we use a c-style cast.
168   info.enum_validity_check.arg = (void*)is_valid;
169   Register(containing_type, number, info);
170 }
171 
RegisterMessageExtension(const MessageLite * containing_type,int number,FieldType type,bool is_repeated,bool is_packed,const MessageLite * prototype)172 void ExtensionSet::RegisterMessageExtension(const MessageLite* containing_type,
173                                             int number, FieldType type,
174                                             bool is_repeated, bool is_packed,
175                                             const MessageLite* prototype) {
176   GOOGLE_CHECK(type == WireFormatLite::TYPE_MESSAGE ||
177         type == WireFormatLite::TYPE_GROUP);
178   ExtensionInfo info(type, is_repeated, is_packed);
179   info.message_info = {prototype};
180   Register(containing_type, number, info);
181 }
182 
183 
184 // ===================================================================
185 // Constructors and basic methods.
186 
ExtensionSet(Arena * arena)187 ExtensionSet::ExtensionSet(Arena* arena)
188     : arena_(arena),
189       flat_capacity_(0),
190       flat_size_(0),
191       map_{flat_capacity_ == 0
192                ? NULL
193                : Arena::CreateArray<KeyValue>(arena_, flat_capacity_)} {}
194 
ExtensionSet()195 ExtensionSet::ExtensionSet()
196     : arena_(NULL),
197       flat_capacity_(0),
198       flat_size_(0),
199       map_{flat_capacity_ == 0
200                ? NULL
201                : Arena::CreateArray<KeyValue>(arena_, flat_capacity_)} {}
202 
~ExtensionSet()203 ExtensionSet::~ExtensionSet() {
204   // Deletes all allocated extensions.
205   if (arena_ == NULL) {
206     ForEach([](int /* number */, Extension& ext) { ext.Free(); });
207     if (PROTOBUF_PREDICT_FALSE(is_large())) {
208       delete map_.large;
209     } else {
210       DeleteFlatMap(map_.flat, flat_capacity_);
211     }
212   }
213 }
214 
DeleteFlatMap(const ExtensionSet::KeyValue * flat,uint16 flat_capacity)215 void ExtensionSet::DeleteFlatMap(const ExtensionSet::KeyValue* flat,
216                                  uint16 flat_capacity) {
217 #ifdef __cpp_sized_deallocation
218   // Arena::CreateArray already requires a trivially destructible type, but
219   // ensure this constraint is not violated in the future.
220   static_assert(std::is_trivially_destructible<KeyValue>::value,
221                 "CreateArray requires a trivially destructible type");
222   // A const-cast is needed, but this is safe as we are about to deallocate the
223   // array.
224   ::operator delete[](const_cast<ExtensionSet::KeyValue*>(flat),
225                       sizeof(*flat) * flat_capacity);
226 #else   // !__cpp_sized_deallocation
227   delete[] flat;
228 #endif  // !__cpp_sized_deallocation
229 }
230 
231 // Defined in extension_set_heavy.cc.
232 // void ExtensionSet::AppendToList(const Descriptor* containing_type,
233 //                                 const DescriptorPool* pool,
234 //                                 vector<const FieldDescriptor*>* output) const
235 
Has(int number) const236 bool ExtensionSet::Has(int number) const {
237   const Extension* ext = FindOrNull(number);
238   if (ext == NULL) return false;
239   GOOGLE_DCHECK(!ext->is_repeated);
240   return !ext->is_cleared;
241 }
242 
NumExtensions() const243 int ExtensionSet::NumExtensions() const {
244   int result = 0;
245   ForEach([&result](int /* number */, const Extension& ext) {
246     if (!ext.is_cleared) {
247       ++result;
248     }
249   });
250   return result;
251 }
252 
ExtensionSize(int number) const253 int ExtensionSet::ExtensionSize(int number) const {
254   const Extension* ext = FindOrNull(number);
255   return ext == NULL ? 0 : ext->GetSize();
256 }
257 
ExtensionType(int number) const258 FieldType ExtensionSet::ExtensionType(int number) const {
259   const Extension* ext = FindOrNull(number);
260   if (ext == NULL) {
261     GOOGLE_LOG(DFATAL) << "Don't lookup extension types if they aren't present (1). ";
262     return 0;
263   }
264   if (ext->is_cleared) {
265     GOOGLE_LOG(DFATAL) << "Don't lookup extension types if they aren't present (2). ";
266   }
267   return ext->type;
268 }
269 
ClearExtension(int number)270 void ExtensionSet::ClearExtension(int number) {
271   Extension* ext = FindOrNull(number);
272   if (ext == NULL) return;
273   ext->Clear();
274 }
275 
276 // ===================================================================
277 // Field accessors
278 
namespace {

// Cardinality labels that GOOGLE_DCHECK_TYPE compares against
// Extension::is_repeated (true -> REPEATED_FIELD, false -> OPTIONAL_FIELD).
enum { REPEATED_FIELD = 0, OPTIONAL_FIELD = 1 };

}  // namespace
284 
// Checks via GOOGLE_DCHECK_EQ that EXTENSION has the expected cardinality
// (LABEL is REPEATED_FIELD or OPTIONAL_FIELD) and the expected C++ type
// (WireFormatLite::CPPTYPE_##CPPTYPE).
// The two checks are wrapped in do { } while (0) so the macro expands to a
// single statement. Without it, after an unbraced `if` only the first check
// would be conditional. Callers still terminate it with ';'.
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                          \
  do {                                                                        \
    GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? REPEATED_FIELD : OPTIONAL_FIELD, \
                     LABEL);                                                  \
    GOOGLE_DCHECK_EQ(cpp_type((EXTENSION).type),                              \
                     WireFormatLite::CPPTYPE_##CPPTYPE);                      \
  } while (0)
288 
289 // -------------------------------------------------------------------
290 // Primitives
291 
292 #define PRIMITIVE_ACCESSORS(UPPERCASE, LOWERCASE, CAMELCASE)                  \
293                                                                               \
294   LOWERCASE ExtensionSet::Get##CAMELCASE(int number, LOWERCASE default_value) \
295       const {                                                                 \
296     const Extension* extension = FindOrNull(number);                          \
297     if (extension == NULL || extension->is_cleared) {                         \
298       return default_value;                                                   \
299     } else {                                                                  \
300       GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE);                     \
301       return extension->LOWERCASE##_value;                                    \
302     }                                                                         \
303   }                                                                           \
304                                                                               \
305   void ExtensionSet::Set##CAMELCASE(int number, FieldType type,               \
306                                     LOWERCASE value,                          \
307                                     const FieldDescriptor* descriptor) {      \
308     Extension* extension;                                                     \
309     if (MaybeNewExtension(number, descriptor, &extension)) {                  \
310       extension->type = type;                                                 \
311       GOOGLE_DCHECK_EQ(cpp_type(extension->type),                                    \
312                 WireFormatLite::CPPTYPE_##UPPERCASE);                         \
313       extension->is_repeated = false;                                         \
314     } else {                                                                  \
315       GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE);                     \
316     }                                                                         \
317     extension->is_cleared = false;                                            \
318     extension->LOWERCASE##_value = value;                                     \
319   }                                                                           \
320                                                                               \
321   LOWERCASE ExtensionSet::GetRepeated##CAMELCASE(int number, int index)       \
322       const {                                                                 \
323     const Extension* extension = FindOrNull(number);                          \
324     GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";      \
325     GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                       \
326     return extension->repeated_##LOWERCASE##_value->Get(index);               \
327   }                                                                           \
328                                                                               \
329   void ExtensionSet::SetRepeated##CAMELCASE(int number, int index,            \
330                                             LOWERCASE value) {                \
331     Extension* extension = FindOrNull(number);                                \
332     GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";      \
333     GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                       \
334     extension->repeated_##LOWERCASE##_value->Set(index, value);               \
335   }                                                                           \
336                                                                               \
337   void ExtensionSet::Add##CAMELCASE(int number, FieldType type, bool packed,  \
338                                     LOWERCASE value,                          \
339                                     const FieldDescriptor* descriptor) {      \
340     Extension* extension;                                                     \
341     if (MaybeNewExtension(number, descriptor, &extension)) {                  \
342       extension->type = type;                                                 \
343       GOOGLE_DCHECK_EQ(cpp_type(extension->type),                                    \
344                 WireFormatLite::CPPTYPE_##UPPERCASE);                         \
345       extension->is_repeated = true;                                          \
346       extension->is_packed = packed;                                          \
347       extension->repeated_##LOWERCASE##_value =                               \
348           Arena::CreateMessage<RepeatedField<LOWERCASE>>(arena_);             \
349     } else {                                                                  \
350       GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                     \
351       GOOGLE_DCHECK_EQ(extension->is_packed, packed);                                \
352     }                                                                         \
353     extension->repeated_##LOWERCASE##_value->Add(value);                      \
354   }
355 
PRIMITIVE_ACCESSORS(INT32,int32,Int32)356 PRIMITIVE_ACCESSORS(INT32, int32, Int32)
357 PRIMITIVE_ACCESSORS(INT64, int64, Int64)
358 PRIMITIVE_ACCESSORS(UINT32, uint32, UInt32)
359 PRIMITIVE_ACCESSORS(UINT64, uint64, UInt64)
360 PRIMITIVE_ACCESSORS(FLOAT, float, Float)
361 PRIMITIVE_ACCESSORS(DOUBLE, double, Double)
362 PRIMITIVE_ACCESSORS(BOOL, bool, Bool)
363 
364 #undef PRIMITIVE_ACCESSORS
365 
366 const void* ExtensionSet::GetRawRepeatedField(int number,
367                                               const void* default_value) const {
368   const Extension* extension = FindOrNull(number);
369   if (extension == NULL) {
370     return default_value;
371   }
372   // We assume that all the RepeatedField<>* pointers have the same
373   // size and alignment within the anonymous union in Extension.
374   return extension->repeated_int32_value;
375 }
376 
MutableRawRepeatedField(int number,FieldType field_type,bool packed,const FieldDescriptor * desc)377 void* ExtensionSet::MutableRawRepeatedField(int number, FieldType field_type,
378                                             bool packed,
379                                             const FieldDescriptor* desc) {
380   Extension* extension;
381 
382   // We instantiate an empty Repeated{,Ptr}Field if one doesn't exist for this
383   // extension.
384   if (MaybeNewExtension(number, desc, &extension)) {
385     extension->is_repeated = true;
386     extension->type = field_type;
387     extension->is_packed = packed;
388 
389     switch (WireFormatLite::FieldTypeToCppType(
390         static_cast<WireFormatLite::FieldType>(field_type))) {
391       case WireFormatLite::CPPTYPE_INT32:
392         extension->repeated_int32_value =
393             Arena::CreateMessage<RepeatedField<int32>>(arena_);
394         break;
395       case WireFormatLite::CPPTYPE_INT64:
396         extension->repeated_int64_value =
397             Arena::CreateMessage<RepeatedField<int64>>(arena_);
398         break;
399       case WireFormatLite::CPPTYPE_UINT32:
400         extension->repeated_uint32_value =
401             Arena::CreateMessage<RepeatedField<uint32>>(arena_);
402         break;
403       case WireFormatLite::CPPTYPE_UINT64:
404         extension->repeated_uint64_value =
405             Arena::CreateMessage<RepeatedField<uint64>>(arena_);
406         break;
407       case WireFormatLite::CPPTYPE_DOUBLE:
408         extension->repeated_double_value =
409             Arena::CreateMessage<RepeatedField<double>>(arena_);
410         break;
411       case WireFormatLite::CPPTYPE_FLOAT:
412         extension->repeated_float_value =
413             Arena::CreateMessage<RepeatedField<float>>(arena_);
414         break;
415       case WireFormatLite::CPPTYPE_BOOL:
416         extension->repeated_bool_value =
417             Arena::CreateMessage<RepeatedField<bool>>(arena_);
418         break;
419       case WireFormatLite::CPPTYPE_ENUM:
420         extension->repeated_enum_value =
421             Arena::CreateMessage<RepeatedField<int>>(arena_);
422         break;
423       case WireFormatLite::CPPTYPE_STRING:
424         extension->repeated_string_value =
425             Arena::CreateMessage<RepeatedPtrField<std::string>>(arena_);
426         break;
427       case WireFormatLite::CPPTYPE_MESSAGE:
428         extension->repeated_message_value =
429             Arena::CreateMessage<RepeatedPtrField<MessageLite>>(arena_);
430         break;
431     }
432   }
433 
434   // We assume that all the RepeatedField<>* pointers have the same
435   // size and alignment within the anonymous union in Extension.
436   return extension->repeated_int32_value;
437 }
438 
439 // Compatible version using old call signature. Does not create extensions when
440 // the don't already exist; instead, just GOOGLE_CHECK-fails.
MutableRawRepeatedField(int number)441 void* ExtensionSet::MutableRawRepeatedField(int number) {
442   Extension* extension = FindOrNull(number);
443   GOOGLE_CHECK(extension != NULL) << "Extension not found.";
444   // We assume that all the RepeatedField<>* pointers have the same
445   // size and alignment within the anonymous union in Extension.
446   return extension->repeated_int32_value;
447 }
448 
449 // -------------------------------------------------------------------
450 // Enums
451 
GetEnum(int number,int default_value) const452 int ExtensionSet::GetEnum(int number, int default_value) const {
453   const Extension* extension = FindOrNull(number);
454   if (extension == NULL || extension->is_cleared) {
455     // Not present.  Return the default value.
456     return default_value;
457   } else {
458     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
459     return extension->enum_value;
460   }
461 }
462 
SetEnum(int number,FieldType type,int value,const FieldDescriptor * descriptor)463 void ExtensionSet::SetEnum(int number, FieldType type, int value,
464                            const FieldDescriptor* descriptor) {
465   Extension* extension;
466   if (MaybeNewExtension(number, descriptor, &extension)) {
467     extension->type = type;
468     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
469     extension->is_repeated = false;
470   } else {
471     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
472   }
473   extension->is_cleared = false;
474   extension->enum_value = value;
475 }
476 
GetRepeatedEnum(int number,int index) const477 int ExtensionSet::GetRepeatedEnum(int number, int index) const {
478   const Extension* extension = FindOrNull(number);
479   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
480   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
481   return extension->repeated_enum_value->Get(index);
482 }
483 
SetRepeatedEnum(int number,int index,int value)484 void ExtensionSet::SetRepeatedEnum(int number, int index, int value) {
485   Extension* extension = FindOrNull(number);
486   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
487   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
488   extension->repeated_enum_value->Set(index, value);
489 }
490 
AddEnum(int number,FieldType type,bool packed,int value,const FieldDescriptor * descriptor)491 void ExtensionSet::AddEnum(int number, FieldType type, bool packed, int value,
492                            const FieldDescriptor* descriptor) {
493   Extension* extension;
494   if (MaybeNewExtension(number, descriptor, &extension)) {
495     extension->type = type;
496     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
497     extension->is_repeated = true;
498     extension->is_packed = packed;
499     extension->repeated_enum_value =
500         Arena::CreateMessage<RepeatedField<int>>(arena_);
501   } else {
502     GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
503     GOOGLE_DCHECK_EQ(extension->is_packed, packed);
504   }
505   extension->repeated_enum_value->Add(value);
506 }
507 
508 // -------------------------------------------------------------------
509 // Strings
510 
GetString(int number,const std::string & default_value) const511 const std::string& ExtensionSet::GetString(
512     int number, const std::string& default_value) const {
513   const Extension* extension = FindOrNull(number);
514   if (extension == NULL || extension->is_cleared) {
515     // Not present.  Return the default value.
516     return default_value;
517   } else {
518     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
519     return *extension->string_value;
520   }
521 }
522 
MutableString(int number,FieldType type,const FieldDescriptor * descriptor)523 std::string* ExtensionSet::MutableString(int number, FieldType type,
524                                          const FieldDescriptor* descriptor) {
525   Extension* extension;
526   if (MaybeNewExtension(number, descriptor, &extension)) {
527     extension->type = type;
528     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
529     extension->is_repeated = false;
530     extension->string_value = Arena::Create<std::string>(arena_);
531   } else {
532     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
533   }
534   extension->is_cleared = false;
535   return extension->string_value;
536 }
537 
GetRepeatedString(int number,int index) const538 const std::string& ExtensionSet::GetRepeatedString(int number,
539                                                    int index) const {
540   const Extension* extension = FindOrNull(number);
541   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
542   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
543   return extension->repeated_string_value->Get(index);
544 }
545 
MutableRepeatedString(int number,int index)546 std::string* ExtensionSet::MutableRepeatedString(int number, int index) {
547   Extension* extension = FindOrNull(number);
548   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
549   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
550   return extension->repeated_string_value->Mutable(index);
551 }
552 
AddString(int number,FieldType type,const FieldDescriptor * descriptor)553 std::string* ExtensionSet::AddString(int number, FieldType type,
554                                      const FieldDescriptor* descriptor) {
555   Extension* extension;
556   if (MaybeNewExtension(number, descriptor, &extension)) {
557     extension->type = type;
558     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
559     extension->is_repeated = true;
560     extension->is_packed = false;
561     extension->repeated_string_value =
562         Arena::CreateMessage<RepeatedPtrField<std::string>>(arena_);
563   } else {
564     GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
565   }
566   return extension->repeated_string_value->Add();
567 }
568 
569 // -------------------------------------------------------------------
570 // Messages
571 
// Returns singular message extension `number`, or `default_value` when it
// does not exist. Unlike the scalar and string getters, this does not consult
// is_cleared: a cleared message extension is still returned.
const MessageLite& ExtensionSet::GetMessage(
    int number, const MessageLite& default_value) const {
  const Extension* extension = FindOrNull(number);
  if (extension == NULL) return default_value;

  GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  if (!extension->is_lazy) {
    return *extension->message_value;
  }
  return extension->lazymessage_value->GetMessage(default_value);
}
587 
588 // Defined in extension_set_heavy.cc.
589 // const MessageLite& ExtensionSet::GetMessage(int number,
590 //                                             const Descriptor* message_type,
591 //                                             MessageFactory* factory) const
592 
// Returns a mutable pointer to singular message extension `number` and marks
// it as not cleared. If MaybeNewExtension() creates the extension, it is
// filled with prototype.New(arena_). Lazy extensions delegate to
// lazymessage_value.
MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
                                          const MessageLite& prototype,
                                          const FieldDescriptor* descriptor) {
  Extension* extension;
  if (!MaybeNewExtension(number, descriptor, &extension)) {
    GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
    extension->is_cleared = false;
    if (!extension->is_lazy) {
      return extension->message_value;
    }
    return extension->lazymessage_value->MutableMessage(prototype);
  }

  extension->type = type;
  GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
  extension->is_repeated = false;
  extension->is_lazy = false;
  extension->message_value = prototype.New(arena_);
  extension->is_cleared = false;
  return extension->message_value;
}
615 
616 // Defined in extension_set_heavy.cc.
617 // MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
618 //                                           const Descriptor* message_type,
619 //                                           MessageFactory* factory)
620 
SetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)621 void ExtensionSet::SetAllocatedMessage(int number, FieldType type,
622                                        const FieldDescriptor* descriptor,
623                                        MessageLite* message) {
624   if (message == NULL) {
625     ClearExtension(number);
626     return;
627   }
628   Arena* message_arena = message->GetArena();
629   Extension* extension;
630   if (MaybeNewExtension(number, descriptor, &extension)) {
631     extension->type = type;
632     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
633     extension->is_repeated = false;
634     extension->is_lazy = false;
635     if (message_arena == arena_) {
636       extension->message_value = message;
637     } else if (message_arena == NULL) {
638       extension->message_value = message;
639       arena_->Own(message);  // not NULL because not equal to message_arena
640     } else {
641       extension->message_value = message->New(arena_);
642       extension->message_value->CheckTypeAndMergeFrom(*message);
643     }
644   } else {
645     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
646     if (extension->is_lazy) {
647       extension->lazymessage_value->SetAllocatedMessage(message);
648     } else {
649       if (arena_ == NULL) {
650         delete extension->message_value;
651       }
652       if (message_arena == arena_) {
653         extension->message_value = message;
654       } else if (message_arena == NULL) {
655         extension->message_value = message;
656         arena_->Own(message);  // not NULL because not equal to message_arena
657       } else {
658         extension->message_value = message->New(arena_);
659         extension->message_value->CheckTypeAndMergeFrom(*message);
660       }
661     }
662   }
663   extension->is_cleared = false;
664 }
665 
UnsafeArenaSetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)666 void ExtensionSet::UnsafeArenaSetAllocatedMessage(
667     int number, FieldType type, const FieldDescriptor* descriptor,
668     MessageLite* message) {
669   if (message == NULL) {
670     ClearExtension(number);
671     return;
672   }
673   Extension* extension;
674   if (MaybeNewExtension(number, descriptor, &extension)) {
675     extension->type = type;
676     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
677     extension->is_repeated = false;
678     extension->is_lazy = false;
679     extension->message_value = message;
680   } else {
681     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
682     if (extension->is_lazy) {
683       extension->lazymessage_value->UnsafeArenaSetAllocatedMessage(message);
684     } else {
685       if (arena_ == NULL) {
686         delete extension->message_value;
687       }
688       extension->message_value = message;
689     }
690   }
691   extension->is_cleared = false;
692 }
693 
// Removes singular message extension `number` from this set and returns its
// value, which the caller then owns. Returns NULL if the extension is absent.
// When this set is on an arena, a non-lazy value is returned as a fresh
// heap copy (New() with no arena) rather than the arena-owned original.
MessageLite* ExtensionSet::ReleaseMessage(int number,
                                          const MessageLite& prototype) {
  Extension* const extension = FindOrNull(number);
  if (extension == NULL) return NULL;  // Not present.

  GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  MessageLite* released;
  if (extension->is_lazy) {
    released = extension->lazymessage_value->ReleaseMessage(prototype);
    // Without an arena the lazy wrapper itself is ours to delete.
    if (arena_ == NULL) delete extension->lazymessage_value;
  } else if (arena_ != NULL) {
    // ReleaseMessage() always hands back a heap-allocated message, but ours
    // lives on the arena, so return a heap copy instead.
    released = extension->message_value->New();
    released->CheckTypeAndMergeFrom(*extension->message_value);
  } else {
    released = extension->message_value;
  }
  Erase(number);
  return released;
}
722 
// Removes singular message extension `number` and returns its stored
// pointer as-is, without copying it off the arena. Returns NULL if the
// extension is absent. Ownership concerns are left to the caller.
MessageLite* ExtensionSet::UnsafeArenaReleaseMessage(
    int number, const MessageLite& prototype) {
  Extension* const extension = FindOrNull(number);
  if (extension == NULL) return NULL;  // Not present.

  GOOGLE_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  MessageLite* released = NULL;
  if (!extension->is_lazy) {
    released = extension->message_value;
  } else {
    released =
        extension->lazymessage_value->UnsafeArenaReleaseMessage(prototype);
    // Without an arena the lazy wrapper itself is ours to delete.
    if (arena_ == NULL) delete extension->lazymessage_value;
  }
  Erase(number);
  return released;
}
744 
745 // Defined in extension_set_heavy.cc.
746 // MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
747 //                                           MessageFactory* factory);
748 
GetRepeatedMessage(int number,int index) const749 const MessageLite& ExtensionSet::GetRepeatedMessage(int number,
750                                                     int index) const {
751   const Extension* extension = FindOrNull(number);
752   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
753   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
754   return extension->repeated_message_value->Get(index);
755 }
756 
MutableRepeatedMessage(int number,int index)757 MessageLite* ExtensionSet::MutableRepeatedMessage(int number, int index) {
758   Extension* extension = FindOrNull(number);
759   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
760   GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
761   return extension->repeated_message_value->Mutable(index);
762 }
763 
AddMessage(int number,FieldType type,const MessageLite & prototype,const FieldDescriptor * descriptor)764 MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
765                                       const MessageLite& prototype,
766                                       const FieldDescriptor* descriptor) {
767   Extension* extension;
768   if (MaybeNewExtension(number, descriptor, &extension)) {
769     extension->type = type;
770     GOOGLE_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
771     extension->is_repeated = true;
772     extension->repeated_message_value =
773         Arena::CreateMessage<RepeatedPtrField<MessageLite>>(arena_);
774   } else {
775     GOOGLE_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
776   }
777 
778   // RepeatedPtrField<MessageLite> does not know how to Add() since it cannot
779   // allocate an abstract object, so we have to be tricky.
780   MessageLite* result = reinterpret_cast<internal::RepeatedPtrFieldBase*>(
781                             extension->repeated_message_value)
782                             ->AddFromCleared<GenericTypeHandler<MessageLite>>();
783   if (result == NULL) {
784     result = prototype.New(arena_);
785     extension->repeated_message_value->AddAllocated(result);
786   }
787   return result;
788 }
789 
790 // Defined in extension_set_heavy.cc.
791 // MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
792 //                                       const Descriptor* message_type,
793 //                                       MessageFactory* factory)
794 
795 #undef GOOGLE_DCHECK_TYPE
796 
RemoveLast(int number)797 void ExtensionSet::RemoveLast(int number) {
798   Extension* extension = FindOrNull(number);
799   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
800   GOOGLE_DCHECK(extension->is_repeated);
801 
802   switch (cpp_type(extension->type)) {
803     case WireFormatLite::CPPTYPE_INT32:
804       extension->repeated_int32_value->RemoveLast();
805       break;
806     case WireFormatLite::CPPTYPE_INT64:
807       extension->repeated_int64_value->RemoveLast();
808       break;
809     case WireFormatLite::CPPTYPE_UINT32:
810       extension->repeated_uint32_value->RemoveLast();
811       break;
812     case WireFormatLite::CPPTYPE_UINT64:
813       extension->repeated_uint64_value->RemoveLast();
814       break;
815     case WireFormatLite::CPPTYPE_FLOAT:
816       extension->repeated_float_value->RemoveLast();
817       break;
818     case WireFormatLite::CPPTYPE_DOUBLE:
819       extension->repeated_double_value->RemoveLast();
820       break;
821     case WireFormatLite::CPPTYPE_BOOL:
822       extension->repeated_bool_value->RemoveLast();
823       break;
824     case WireFormatLite::CPPTYPE_ENUM:
825       extension->repeated_enum_value->RemoveLast();
826       break;
827     case WireFormatLite::CPPTYPE_STRING:
828       extension->repeated_string_value->RemoveLast();
829       break;
830     case WireFormatLite::CPPTYPE_MESSAGE:
831       extension->repeated_message_value->RemoveLast();
832       break;
833   }
834 }
835 
ReleaseLast(int number)836 MessageLite* ExtensionSet::ReleaseLast(int number) {
837   Extension* extension = FindOrNull(number);
838   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
839   GOOGLE_DCHECK(extension->is_repeated);
840   GOOGLE_DCHECK(cpp_type(extension->type) == WireFormatLite::CPPTYPE_MESSAGE);
841   return extension->repeated_message_value->ReleaseLast();
842 }
843 
SwapElements(int number,int index1,int index2)844 void ExtensionSet::SwapElements(int number, int index1, int index2) {
845   Extension* extension = FindOrNull(number);
846   GOOGLE_CHECK(extension != NULL) << "Index out-of-bounds (field is empty).";
847   GOOGLE_DCHECK(extension->is_repeated);
848 
849   switch (cpp_type(extension->type)) {
850     case WireFormatLite::CPPTYPE_INT32:
851       extension->repeated_int32_value->SwapElements(index1, index2);
852       break;
853     case WireFormatLite::CPPTYPE_INT64:
854       extension->repeated_int64_value->SwapElements(index1, index2);
855       break;
856     case WireFormatLite::CPPTYPE_UINT32:
857       extension->repeated_uint32_value->SwapElements(index1, index2);
858       break;
859     case WireFormatLite::CPPTYPE_UINT64:
860       extension->repeated_uint64_value->SwapElements(index1, index2);
861       break;
862     case WireFormatLite::CPPTYPE_FLOAT:
863       extension->repeated_float_value->SwapElements(index1, index2);
864       break;
865     case WireFormatLite::CPPTYPE_DOUBLE:
866       extension->repeated_double_value->SwapElements(index1, index2);
867       break;
868     case WireFormatLite::CPPTYPE_BOOL:
869       extension->repeated_bool_value->SwapElements(index1, index2);
870       break;
871     case WireFormatLite::CPPTYPE_ENUM:
872       extension->repeated_enum_value->SwapElements(index1, index2);
873       break;
874     case WireFormatLite::CPPTYPE_STRING:
875       extension->repeated_string_value->SwapElements(index1, index2);
876       break;
877     case WireFormatLite::CPPTYPE_MESSAGE:
878       extension->repeated_message_value->SwapElements(index1, index2);
879       break;
880   }
881 }
882 
883 // ===================================================================
884 
Clear()885 void ExtensionSet::Clear() {
886   ForEach([](int /* number */, Extension& ext) { ext.Clear(); });
887 }
888 
namespace {
// Returns how many distinct keys std::set_union would produce from two ranges
// sorted by their elements' `first` member, without building the union.
// Runs in a single simultaneous pass over both ranges.
template <typename ItX, typename ItY>
size_t SizeOfUnion(ItX it_xs, ItX end_xs, ItY it_ys, ItY end_ys) {
  size_t count = 0;
  while (it_xs != end_xs && it_ys != end_ys) {
    // Advance whichever side holds the smaller key; both on a tie, so that a
    // shared key is counted once.
    const bool advance_x = !(it_ys->first < it_xs->first);
    const bool advance_y = !(it_xs->first < it_ys->first);
    if (advance_x) ++it_xs;
    if (advance_y) ++it_ys;
    ++count;
  }
  // Whatever remains on either side is disjoint from the other range.
  return count + std::distance(it_xs, end_xs) + std::distance(it_ys, end_ys);
}
}  // namespace
910 
MergeFrom(const ExtensionSet & other)911 void ExtensionSet::MergeFrom(const ExtensionSet& other) {
912   if (PROTOBUF_PREDICT_TRUE(!is_large())) {
913     if (PROTOBUF_PREDICT_TRUE(!other.is_large())) {
914       GrowCapacity(SizeOfUnion(flat_begin(), flat_end(), other.flat_begin(),
915                                other.flat_end()));
916     } else {
917       GrowCapacity(SizeOfUnion(flat_begin(), flat_end(),
918                                other.map_.large->begin(),
919                                other.map_.large->end()));
920     }
921   }
922   other.ForEach([this](int number, const Extension& ext) {
923     this->InternalExtensionMergeFrom(number, ext);
924   });
925 }
926 
InternalExtensionMergeFrom(int number,const Extension & other_extension)927 void ExtensionSet::InternalExtensionMergeFrom(
928     int number, const Extension& other_extension) {
929   if (other_extension.is_repeated) {
930     Extension* extension;
931     bool is_new =
932         MaybeNewExtension(number, other_extension.descriptor, &extension);
933     if (is_new) {
934       // Extension did not already exist in set.
935       extension->type = other_extension.type;
936       extension->is_packed = other_extension.is_packed;
937       extension->is_repeated = true;
938     } else {
939       GOOGLE_DCHECK_EQ(extension->type, other_extension.type);
940       GOOGLE_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
941       GOOGLE_DCHECK(extension->is_repeated);
942     }
943 
944     switch (cpp_type(other_extension.type)) {
945 #define HANDLE_TYPE(UPPERCASE, LOWERCASE, REPEATED_TYPE) \
946   case WireFormatLite::CPPTYPE_##UPPERCASE:              \
947     if (is_new) {                                        \
948       extension->repeated_##LOWERCASE##_value =          \
949           Arena::CreateMessage<REPEATED_TYPE>(arena_);   \
950     }                                                    \
951     extension->repeated_##LOWERCASE##_value->MergeFrom(  \
952         *other_extension.repeated_##LOWERCASE##_value);  \
953     break;
954 
955       HANDLE_TYPE(INT32, int32, RepeatedField<int32>);
956       HANDLE_TYPE(INT64, int64, RepeatedField<int64>);
957       HANDLE_TYPE(UINT32, uint32, RepeatedField<uint32>);
958       HANDLE_TYPE(UINT64, uint64, RepeatedField<uint64>);
959       HANDLE_TYPE(FLOAT, float, RepeatedField<float>);
960       HANDLE_TYPE(DOUBLE, double, RepeatedField<double>);
961       HANDLE_TYPE(BOOL, bool, RepeatedField<bool>);
962       HANDLE_TYPE(ENUM, enum, RepeatedField<int>);
963       HANDLE_TYPE(STRING, string, RepeatedPtrField<std::string>);
964 #undef HANDLE_TYPE
965 
966       case WireFormatLite::CPPTYPE_MESSAGE:
967         if (is_new) {
968           extension->repeated_message_value =
969               Arena::CreateMessage<RepeatedPtrField<MessageLite>>(arena_);
970         }
971         // We can't call RepeatedPtrField<MessageLite>::MergeFrom() because
972         // it would attempt to allocate new objects.
973         RepeatedPtrField<MessageLite>* other_repeated_message =
974             other_extension.repeated_message_value;
975         for (int i = 0; i < other_repeated_message->size(); i++) {
976           const MessageLite& other_message = other_repeated_message->Get(i);
977           MessageLite* target =
978               reinterpret_cast<internal::RepeatedPtrFieldBase*>(
979                   extension->repeated_message_value)
980                   ->AddFromCleared<GenericTypeHandler<MessageLite>>();
981           if (target == NULL) {
982             target = other_message.New(arena_);
983             extension->repeated_message_value->AddAllocated(target);
984           }
985           target->CheckTypeAndMergeFrom(other_message);
986         }
987         break;
988     }
989   } else {
990     if (!other_extension.is_cleared) {
991       switch (cpp_type(other_extension.type)) {
992 #define HANDLE_TYPE(UPPERCASE, LOWERCASE, CAMELCASE)  \
993   case WireFormatLite::CPPTYPE_##UPPERCASE:           \
994     Set##CAMELCASE(number, other_extension.type,      \
995                    other_extension.LOWERCASE##_value, \
996                    other_extension.descriptor);       \
997     break;
998 
999         HANDLE_TYPE(INT32, int32, Int32);
1000         HANDLE_TYPE(INT64, int64, Int64);
1001         HANDLE_TYPE(UINT32, uint32, UInt32);
1002         HANDLE_TYPE(UINT64, uint64, UInt64);
1003         HANDLE_TYPE(FLOAT, float, Float);
1004         HANDLE_TYPE(DOUBLE, double, Double);
1005         HANDLE_TYPE(BOOL, bool, Bool);
1006         HANDLE_TYPE(ENUM, enum, Enum);
1007 #undef HANDLE_TYPE
1008         case WireFormatLite::CPPTYPE_STRING:
1009           SetString(number, other_extension.type, *other_extension.string_value,
1010                     other_extension.descriptor);
1011           break;
1012         case WireFormatLite::CPPTYPE_MESSAGE: {
1013           Extension* extension;
1014           bool is_new =
1015               MaybeNewExtension(number, other_extension.descriptor, &extension);
1016           if (is_new) {
1017             extension->type = other_extension.type;
1018             extension->is_packed = other_extension.is_packed;
1019             extension->is_repeated = false;
1020             if (other_extension.is_lazy) {
1021               extension->is_lazy = true;
1022               extension->lazymessage_value =
1023                   other_extension.lazymessage_value->New(arena_);
1024               extension->lazymessage_value->MergeFrom(
1025                   *other_extension.lazymessage_value);
1026             } else {
1027               extension->is_lazy = false;
1028               extension->message_value =
1029                   other_extension.message_value->New(arena_);
1030               extension->message_value->CheckTypeAndMergeFrom(
1031                   *other_extension.message_value);
1032             }
1033           } else {
1034             GOOGLE_DCHECK_EQ(extension->type, other_extension.type);
1035             GOOGLE_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
1036             GOOGLE_DCHECK(!extension->is_repeated);
1037             if (other_extension.is_lazy) {
1038               if (extension->is_lazy) {
1039                 extension->lazymessage_value->MergeFrom(
1040                     *other_extension.lazymessage_value);
1041               } else {
1042                 extension->message_value->CheckTypeAndMergeFrom(
1043                     other_extension.lazymessage_value->GetMessage(
1044                         *extension->message_value));
1045               }
1046             } else {
1047               if (extension->is_lazy) {
1048                 extension->lazymessage_value
1049                     ->MutableMessage(*other_extension.message_value)
1050                     ->CheckTypeAndMergeFrom(*other_extension.message_value);
1051               } else {
1052                 extension->message_value->CheckTypeAndMergeFrom(
1053                     *other_extension.message_value);
1054               }
1055             }
1056           }
1057           extension->is_cleared = false;
1058           break;
1059         }
1060       }
1061     }
1062   }
1063 }
1064 
Swap(ExtensionSet * x)1065 void ExtensionSet::Swap(ExtensionSet* x) {
1066   if (GetArenaNoVirtual() == x->GetArenaNoVirtual()) {
1067     using std::swap;
1068     swap(flat_capacity_, x->flat_capacity_);
1069     swap(flat_size_, x->flat_size_);
1070     swap(map_, x->map_);
1071   } else {
1072     // TODO(cfallin, rohananil): We maybe able to optimize a case where we are
1073     // swapping from heap to arena-allocated extension set, by just Own()'ing
1074     // the extensions.
1075     ExtensionSet extension_set;
1076     extension_set.MergeFrom(*x);
1077     x->Clear();
1078     x->MergeFrom(*this);
1079     Clear();
1080     MergeFrom(extension_set);
1081   }
1082 }
1083 
SwapExtension(ExtensionSet * other,int number)1084 void ExtensionSet::SwapExtension(ExtensionSet* other, int number) {
1085   if (this == other) return;
1086   Extension* this_ext = FindOrNull(number);
1087   Extension* other_ext = other->FindOrNull(number);
1088 
1089   if (this_ext == NULL && other_ext == NULL) {
1090     return;
1091   }
1092 
1093   if (this_ext != NULL && other_ext != NULL) {
1094     if (GetArenaNoVirtual() == other->GetArenaNoVirtual()) {
1095       using std::swap;
1096       swap(*this_ext, *other_ext);
1097     } else {
1098       // TODO(cfallin, rohananil): We could further optimize these cases,
1099       // especially avoid creation of ExtensionSet, and move MergeFrom logic
1100       // into Extensions itself (which takes arena as an argument).
1101       // We do it this way to reuse the copy-across-arenas logic already
1102       // implemented in ExtensionSet's MergeFrom.
1103       ExtensionSet temp;
1104       temp.InternalExtensionMergeFrom(number, *other_ext);
1105       Extension* temp_ext = temp.FindOrNull(number);
1106       other_ext->Clear();
1107       other->InternalExtensionMergeFrom(number, *this_ext);
1108       this_ext->Clear();
1109       InternalExtensionMergeFrom(number, *temp_ext);
1110     }
1111     return;
1112   }
1113 
1114   if (this_ext == NULL) {
1115     if (GetArenaNoVirtual() == other->GetArenaNoVirtual()) {
1116       *Insert(number).first = *other_ext;
1117     } else {
1118       InternalExtensionMergeFrom(number, *other_ext);
1119     }
1120     other->Erase(number);
1121     return;
1122   }
1123 
1124   if (other_ext == NULL) {
1125     if (GetArenaNoVirtual() == other->GetArenaNoVirtual()) {
1126       *other->Insert(number).first = *this_ext;
1127     } else {
1128       other->InternalExtensionMergeFrom(number, *this_ext);
1129     }
1130     Erase(number);
1131     return;
1132   }
1133 }
1134 
IsInitialized() const1135 bool ExtensionSet::IsInitialized() const {
1136   // Extensions are never required.  However, we need to check that all
1137   // embedded messages are initialized.
1138   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1139     for (const auto& kv : *map_.large) {
1140       if (!kv.second.IsInitialized()) return false;
1141     }
1142     return true;
1143   }
1144   for (const KeyValue* it = flat_begin(); it != flat_end(); ++it) {
1145     if (!it->second.IsInitialized()) return false;
1146   }
1147   return true;
1148 }
1149 
FindExtensionInfoFromTag(uint32 tag,ExtensionFinder * extension_finder,int * field_number,ExtensionInfo * extension,bool * was_packed_on_wire)1150 bool ExtensionSet::FindExtensionInfoFromTag(uint32 tag,
1151                                             ExtensionFinder* extension_finder,
1152                                             int* field_number,
1153                                             ExtensionInfo* extension,
1154                                             bool* was_packed_on_wire) {
1155   *field_number = WireFormatLite::GetTagFieldNumber(tag);
1156   WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
1157   return FindExtensionInfoFromFieldNumber(wire_type, *field_number,
1158                                           extension_finder, extension,
1159                                           was_packed_on_wire);
1160 }
1161 
FindExtensionInfoFromFieldNumber(int wire_type,int field_number,ExtensionFinder * extension_finder,ExtensionInfo * extension,bool * was_packed_on_wire)1162 bool ExtensionSet::FindExtensionInfoFromFieldNumber(
1163     int wire_type, int field_number, ExtensionFinder* extension_finder,
1164     ExtensionInfo* extension, bool* was_packed_on_wire) {
1165   if (!extension_finder->Find(field_number, extension)) {
1166     return false;
1167   }
1168 
1169   WireFormatLite::WireType expected_wire_type =
1170       WireFormatLite::WireTypeForFieldType(real_type(extension->type));
1171 
1172   // Check if this is a packed field.
1173   *was_packed_on_wire = false;
1174   if (extension->is_repeated &&
1175       wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
1176       is_packable(expected_wire_type)) {
1177     *was_packed_on_wire = true;
1178     return true;
1179   }
1180   // Otherwise the wire type must match.
1181   return expected_wire_type == wire_type;
1182 }
1183 
ParseField(uint32 tag,io::CodedInputStream * input,ExtensionFinder * extension_finder,FieldSkipper * field_skipper)1184 bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
1185                               ExtensionFinder* extension_finder,
1186                               FieldSkipper* field_skipper) {
1187   int number;
1188   bool was_packed_on_wire;
1189   ExtensionInfo extension;
1190   if (!FindExtensionInfoFromTag(tag, extension_finder, &number, &extension,
1191                                 &was_packed_on_wire)) {
1192     return field_skipper->SkipField(input, tag);
1193   } else {
1194     return ParseFieldWithExtensionInfo(number, was_packed_on_wire, extension,
1195                                        input, field_skipper);
1196   }
1197 }
1198 
#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER
// Experimental-parser entry point: parses one extension field starting at
// `ptr`. Unrecognized fields (or wire-type mismatches) are routed to
// UnknownFieldParse() and stored in metadata's unknown fields.
const char* ExtensionSet::ParseField(
    uint64 tag, const char* ptr, const MessageLite* containing_type,
    internal::InternalMetadataWithArenaLite* metadata,
    internal::ParseContext* ctx) {
  GeneratedExtensionFinder finder(containing_type);
  // The low three bits of the tag are the wire type; the rest is the number.
  const int number = tag >> 3;
  const int wire_type = tag & 7;
  bool was_packed_on_wire;
  ExtensionInfo extension;
  if (FindExtensionInfoFromFieldNumber(wire_type, number, &finder, &extension,
                                       &was_packed_on_wire)) {
    return ParseFieldWithExtensionInfo(number, was_packed_on_wire, extension,
                                       metadata, ptr, ctx);
  }
  return UnknownFieldParse(tag, metadata->mutable_unknown_fields(), ptr, ctx);
}

// Experimental-parser entry point for one MessageSet item; forwards to
// ParseMessageSetItemTmpl().
const char* ExtensionSet::ParseMessageSetItem(
    const char* ptr, const MessageLite* containing_type,
    internal::InternalMetadataWithArenaLite* metadata,
    internal::ParseContext* ctx) {
  return ParseMessageSetItemTmpl(ptr, containing_type, metadata, ctx);
}

#endif
1224 
ParseFieldWithExtensionInfo(int number,bool was_packed_on_wire,const ExtensionInfo & extension,io::CodedInputStream * input,FieldSkipper * field_skipper)1225 bool ExtensionSet::ParseFieldWithExtensionInfo(int number,
1226                                                bool was_packed_on_wire,
1227                                                const ExtensionInfo& extension,
1228                                                io::CodedInputStream* input,
1229                                                FieldSkipper* field_skipper) {
1230   // Explicitly not read extension.is_packed, instead check whether the field
1231   // was encoded in packed form on the wire.
1232   if (was_packed_on_wire) {
1233     uint32 size;
1234     if (!input->ReadVarint32(&size)) return false;
1235     io::CodedInputStream::Limit limit = input->PushLimit(size);
1236 
1237     switch (extension.type) {
1238 #define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE, CPP_LOWERCASE)                \
1239   case WireFormatLite::TYPE_##UPPERCASE:                                    \
1240     while (input->BytesUntilLimit() > 0) {                                  \
1241       CPP_LOWERCASE value;                                                  \
1242       if (!WireFormatLite::ReadPrimitive<CPP_LOWERCASE,                     \
1243                                          WireFormatLite::TYPE_##UPPERCASE>( \
1244               input, &value))                                               \
1245         return false;                                                       \
1246       Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE,          \
1247                          extension.is_packed, value, extension.descriptor); \
1248     }                                                                       \
1249     break
1250 
1251       HANDLE_TYPE(INT32, Int32, int32);
1252       HANDLE_TYPE(INT64, Int64, int64);
1253       HANDLE_TYPE(UINT32, UInt32, uint32);
1254       HANDLE_TYPE(UINT64, UInt64, uint64);
1255       HANDLE_TYPE(SINT32, Int32, int32);
1256       HANDLE_TYPE(SINT64, Int64, int64);
1257       HANDLE_TYPE(FIXED32, UInt32, uint32);
1258       HANDLE_TYPE(FIXED64, UInt64, uint64);
1259       HANDLE_TYPE(SFIXED32, Int32, int32);
1260       HANDLE_TYPE(SFIXED64, Int64, int64);
1261       HANDLE_TYPE(FLOAT, Float, float);
1262       HANDLE_TYPE(DOUBLE, Double, double);
1263       HANDLE_TYPE(BOOL, Bool, bool);
1264 #undef HANDLE_TYPE
1265 
1266       case WireFormatLite::TYPE_ENUM:
1267         while (input->BytesUntilLimit() > 0) {
1268           int value;
1269           if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
1270                   input, &value))
1271             return false;
1272           if (extension.enum_validity_check.func(
1273                   extension.enum_validity_check.arg, value)) {
1274             AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed,
1275                     value, extension.descriptor);
1276           } else {
1277             // Invalid value.  Treat as unknown.
1278             field_skipper->SkipUnknownEnum(number, value);
1279           }
1280         }
1281         break;
1282 
1283       case WireFormatLite::TYPE_STRING:
1284       case WireFormatLite::TYPE_BYTES:
1285       case WireFormatLite::TYPE_GROUP:
1286       case WireFormatLite::TYPE_MESSAGE:
1287         GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
1288         break;
1289     }
1290 
1291     input->PopLimit(limit);
1292   } else {
1293     switch (extension.type) {
1294 #define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE, CPP_LOWERCASE)                \
1295   case WireFormatLite::TYPE_##UPPERCASE: {                                  \
1296     CPP_LOWERCASE value;                                                    \
1297     if (!WireFormatLite::ReadPrimitive<CPP_LOWERCASE,                       \
1298                                        WireFormatLite::TYPE_##UPPERCASE>(   \
1299             input, &value))                                                 \
1300       return false;                                                         \
1301     if (extension.is_repeated) {                                            \
1302       Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE,          \
1303                          extension.is_packed, value, extension.descriptor); \
1304     } else {                                                                \
1305       Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value,   \
1306                          extension.descriptor);                             \
1307     }                                                                       \
1308   } break
1309 
1310       HANDLE_TYPE(INT32, Int32, int32);
1311       HANDLE_TYPE(INT64, Int64, int64);
1312       HANDLE_TYPE(UINT32, UInt32, uint32);
1313       HANDLE_TYPE(UINT64, UInt64, uint64);
1314       HANDLE_TYPE(SINT32, Int32, int32);
1315       HANDLE_TYPE(SINT64, Int64, int64);
1316       HANDLE_TYPE(FIXED32, UInt32, uint32);
1317       HANDLE_TYPE(FIXED64, UInt64, uint64);
1318       HANDLE_TYPE(SFIXED32, Int32, int32);
1319       HANDLE_TYPE(SFIXED64, Int64, int64);
1320       HANDLE_TYPE(FLOAT, Float, float);
1321       HANDLE_TYPE(DOUBLE, Double, double);
1322       HANDLE_TYPE(BOOL, Bool, bool);
1323 #undef HANDLE_TYPE
1324 
1325       case WireFormatLite::TYPE_ENUM: {
1326         int value;
1327         if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
1328                 input, &value))
1329           return false;
1330 
1331         if (!extension.enum_validity_check.func(
1332                 extension.enum_validity_check.arg, value)) {
1333           // Invalid value.  Treat as unknown.
1334           field_skipper->SkipUnknownEnum(number, value);
1335         } else if (extension.is_repeated) {
1336           AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value,
1337                   extension.descriptor);
1338         } else {
1339           SetEnum(number, WireFormatLite::TYPE_ENUM, value,
1340                   extension.descriptor);
1341         }
1342         break;
1343       }
1344 
1345       case WireFormatLite::TYPE_STRING: {
1346         std::string* value =
1347             extension.is_repeated
1348                 ? AddString(number, WireFormatLite::TYPE_STRING,
1349                             extension.descriptor)
1350                 : MutableString(number, WireFormatLite::TYPE_STRING,
1351                                 extension.descriptor);
1352         if (!WireFormatLite::ReadString(input, value)) return false;
1353         break;
1354       }
1355 
1356       case WireFormatLite::TYPE_BYTES: {
1357         std::string* value =
1358             extension.is_repeated
1359                 ? AddString(number, WireFormatLite::TYPE_BYTES,
1360                             extension.descriptor)
1361                 : MutableString(number, WireFormatLite::TYPE_BYTES,
1362                                 extension.descriptor);
1363         if (!WireFormatLite::ReadBytes(input, value)) return false;
1364         break;
1365       }
1366 
1367       case WireFormatLite::TYPE_GROUP: {
1368         MessageLite* value =
1369             extension.is_repeated
1370                 ? AddMessage(number, WireFormatLite::TYPE_GROUP,
1371                              *extension.message_info.prototype,
1372                              extension.descriptor)
1373                 : MutableMessage(number, WireFormatLite::TYPE_GROUP,
1374                                  *extension.message_info.prototype,
1375                                  extension.descriptor);
1376         if (!WireFormatLite::ReadGroup(number, input, value)) return false;
1377         break;
1378       }
1379 
1380       case WireFormatLite::TYPE_MESSAGE: {
1381         MessageLite* value =
1382             extension.is_repeated
1383                 ? AddMessage(number, WireFormatLite::TYPE_MESSAGE,
1384                              *extension.message_info.prototype,
1385                              extension.descriptor)
1386                 : MutableMessage(number, WireFormatLite::TYPE_MESSAGE,
1387                                  *extension.message_info.prototype,
1388                                  extension.descriptor);
1389         if (!WireFormatLite::ReadMessage(input, value)) return false;
1390         break;
1391       }
1392     }
1393   }
1394 
1395   return true;
1396 }
1397 
ParseField(uint32 tag,io::CodedInputStream * input,const MessageLite * containing_type)1398 bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
1399                               const MessageLite* containing_type) {
1400   FieldSkipper skipper;
1401   GeneratedExtensionFinder finder(containing_type);
1402   return ParseField(tag, input, &finder, &skipper);
1403 }
1404 
ParseField(uint32 tag,io::CodedInputStream * input,const MessageLite * containing_type,io::CodedOutputStream * unknown_fields)1405 bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
1406                               const MessageLite* containing_type,
1407                               io::CodedOutputStream* unknown_fields) {
1408   CodedOutputStreamFieldSkipper skipper(unknown_fields);
1409   GeneratedExtensionFinder finder(containing_type);
1410   return ParseField(tag, input, &finder, &skipper);
1411 }
1412 
ParseMessageSetLite(io::CodedInputStream * input,ExtensionFinder * extension_finder,FieldSkipper * field_skipper)1413 bool ExtensionSet::ParseMessageSetLite(io::CodedInputStream* input,
1414                                        ExtensionFinder* extension_finder,
1415                                        FieldSkipper* field_skipper) {
1416   while (true) {
1417     const uint32 tag = input->ReadTag();
1418     switch (tag) {
1419       case 0:
1420         return true;
1421       case WireFormatLite::kMessageSetItemStartTag:
1422         if (!ParseMessageSetItemLite(input, extension_finder, field_skipper)) {
1423           return false;
1424         }
1425         break;
1426       default:
1427         if (!ParseField(tag, input, extension_finder, field_skipper)) {
1428           return false;
1429         }
1430         break;
1431     }
1432   }
1433 }
1434 
ParseMessageSetItemLite(io::CodedInputStream * input,ExtensionFinder * extension_finder,FieldSkipper * field_skipper)1435 bool ExtensionSet::ParseMessageSetItemLite(io::CodedInputStream* input,
1436                                            ExtensionFinder* extension_finder,
1437                                            FieldSkipper* field_skipper) {
1438   struct MSLite {
1439     bool ParseField(int type_id, io::CodedInputStream* input) {
1440       return me->ParseField(
1441           WireFormatLite::WIRETYPE_LENGTH_DELIMITED + 8 * type_id, input,
1442           extension_finder, field_skipper);
1443     }
1444 
1445     bool SkipField(uint32 tag, io::CodedInputStream* input) {
1446       return field_skipper->SkipField(input, tag);
1447     }
1448 
1449     ExtensionSet* me;
1450     ExtensionFinder* extension_finder;
1451     FieldSkipper* field_skipper;
1452   };
1453 
1454   return ParseMessageSetItemImpl(input,
1455                                  MSLite{this, extension_finder, field_skipper});
1456 }
1457 
ParseMessageSet(io::CodedInputStream * input,const MessageLite * containing_type,std::string * unknown_fields)1458 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
1459                                    const MessageLite* containing_type,
1460                                    std::string* unknown_fields) {
1461   io::StringOutputStream zcis(unknown_fields);
1462   io::CodedOutputStream output(&zcis);
1463   CodedOutputStreamFieldSkipper skipper(&output);
1464   GeneratedExtensionFinder finder(containing_type);
1465   return ParseMessageSetLite(input, &finder, &skipper);
1466 }
1467 
SerializeWithCachedSizes(int start_field_number,int end_field_number,io::CodedOutputStream * output) const1468 void ExtensionSet::SerializeWithCachedSizes(
1469     int start_field_number, int end_field_number,
1470     io::CodedOutputStream* output) const {
1471   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1472     const auto& end = map_.large->end();
1473     for (auto it = map_.large->lower_bound(start_field_number);
1474          it != end && it->first < end_field_number; ++it) {
1475       it->second.SerializeFieldWithCachedSizes(it->first, output);
1476     }
1477     return;
1478   }
1479   const KeyValue* end = flat_end();
1480   for (const KeyValue* it = std::lower_bound(
1481            flat_begin(), end, start_field_number, KeyValue::FirstComparator());
1482        it != end && it->first < end_field_number; ++it) {
1483     it->second.SerializeFieldWithCachedSizes(it->first, output);
1484   }
1485 }
1486 
ByteSize() const1487 size_t ExtensionSet::ByteSize() const {
1488   size_t total_size = 0;
1489   ForEach([&total_size](int number, const Extension& ext) {
1490     total_size += ext.ByteSize(number);
1491   });
1492   return total_size;
1493 }
1494 
1495 // Defined in extension_set_heavy.cc.
1496 // int ExtensionSet::SpaceUsedExcludingSelf() const
1497 
MaybeNewExtension(int number,const FieldDescriptor * descriptor,Extension ** result)1498 bool ExtensionSet::MaybeNewExtension(int number,
1499                                      const FieldDescriptor* descriptor,
1500                                      Extension** result) {
1501   bool extension_is_new = false;
1502   std::tie(*result, extension_is_new) = Insert(number);
1503   (*result)->descriptor = descriptor;
1504   return extension_is_new;
1505 }
1506 
1507 // ===================================================================
1508 // Methods of ExtensionSet::Extension
1509 
Clear()1510 void ExtensionSet::Extension::Clear() {
1511   if (is_repeated) {
1512     switch (cpp_type(type)) {
1513 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)   \
1514   case WireFormatLite::CPPTYPE_##UPPERCASE: \
1515     repeated_##LOWERCASE##_value->Clear();  \
1516     break
1517 
1518       HANDLE_TYPE(INT32, int32);
1519       HANDLE_TYPE(INT64, int64);
1520       HANDLE_TYPE(UINT32, uint32);
1521       HANDLE_TYPE(UINT64, uint64);
1522       HANDLE_TYPE(FLOAT, float);
1523       HANDLE_TYPE(DOUBLE, double);
1524       HANDLE_TYPE(BOOL, bool);
1525       HANDLE_TYPE(ENUM, enum);
1526       HANDLE_TYPE(STRING, string);
1527       HANDLE_TYPE(MESSAGE, message);
1528 #undef HANDLE_TYPE
1529     }
1530   } else {
1531     if (!is_cleared) {
1532       switch (cpp_type(type)) {
1533         case WireFormatLite::CPPTYPE_STRING:
1534           string_value->clear();
1535           break;
1536         case WireFormatLite::CPPTYPE_MESSAGE:
1537           if (is_lazy) {
1538             lazymessage_value->Clear();
1539           } else {
1540             message_value->Clear();
1541           }
1542           break;
1543         default:
1544           // No need to do anything.  Get*() will return the default value
1545           // as long as is_cleared is true and Set*() will overwrite the
1546           // previous value.
1547           break;
1548       }
1549 
1550       is_cleared = true;
1551     }
1552   }
1553 }
1554 
SerializeFieldWithCachedSizes(int number,io::CodedOutputStream * output) const1555 void ExtensionSet::Extension::SerializeFieldWithCachedSizes(
1556     int number, io::CodedOutputStream* output) const {
1557   if (is_repeated) {
1558     if (is_packed) {
1559       if (cached_size == 0) return;
1560 
1561       WireFormatLite::WriteTag(
1562           number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
1563       output->WriteVarint32(cached_size);
1564 
1565       switch (real_type(type)) {
1566 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
1567   case WireFormatLite::TYPE_##UPPERCASE:                             \
1568     for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \
1569       WireFormatLite::Write##CAMELCASE##NoTag(                       \
1570           repeated_##LOWERCASE##_value->Get(i), output);             \
1571     }                                                                \
1572     break
1573 
1574         HANDLE_TYPE(INT32, Int32, int32);
1575         HANDLE_TYPE(INT64, Int64, int64);
1576         HANDLE_TYPE(UINT32, UInt32, uint32);
1577         HANDLE_TYPE(UINT64, UInt64, uint64);
1578         HANDLE_TYPE(SINT32, SInt32, int32);
1579         HANDLE_TYPE(SINT64, SInt64, int64);
1580         HANDLE_TYPE(FIXED32, Fixed32, uint32);
1581         HANDLE_TYPE(FIXED64, Fixed64, uint64);
1582         HANDLE_TYPE(SFIXED32, SFixed32, int32);
1583         HANDLE_TYPE(SFIXED64, SFixed64, int64);
1584         HANDLE_TYPE(FLOAT, Float, float);
1585         HANDLE_TYPE(DOUBLE, Double, double);
1586         HANDLE_TYPE(BOOL, Bool, bool);
1587         HANDLE_TYPE(ENUM, Enum, enum);
1588 #undef HANDLE_TYPE
1589 
1590         case WireFormatLite::TYPE_STRING:
1591         case WireFormatLite::TYPE_BYTES:
1592         case WireFormatLite::TYPE_GROUP:
1593         case WireFormatLite::TYPE_MESSAGE:
1594           GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
1595           break;
1596       }
1597     } else {
1598       switch (real_type(type)) {
1599 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
1600   case WireFormatLite::TYPE_##UPPERCASE:                             \
1601     for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \
1602       WireFormatLite::Write##CAMELCASE(                              \
1603           number, repeated_##LOWERCASE##_value->Get(i), output);     \
1604     }                                                                \
1605     break
1606 
1607         HANDLE_TYPE(INT32, Int32, int32);
1608         HANDLE_TYPE(INT64, Int64, int64);
1609         HANDLE_TYPE(UINT32, UInt32, uint32);
1610         HANDLE_TYPE(UINT64, UInt64, uint64);
1611         HANDLE_TYPE(SINT32, SInt32, int32);
1612         HANDLE_TYPE(SINT64, SInt64, int64);
1613         HANDLE_TYPE(FIXED32, Fixed32, uint32);
1614         HANDLE_TYPE(FIXED64, Fixed64, uint64);
1615         HANDLE_TYPE(SFIXED32, SFixed32, int32);
1616         HANDLE_TYPE(SFIXED64, SFixed64, int64);
1617         HANDLE_TYPE(FLOAT, Float, float);
1618         HANDLE_TYPE(DOUBLE, Double, double);
1619         HANDLE_TYPE(BOOL, Bool, bool);
1620         HANDLE_TYPE(STRING, String, string);
1621         HANDLE_TYPE(BYTES, Bytes, string);
1622         HANDLE_TYPE(ENUM, Enum, enum);
1623         HANDLE_TYPE(GROUP, Group, message);
1624         HANDLE_TYPE(MESSAGE, Message, message);
1625 #undef HANDLE_TYPE
1626       }
1627     }
1628   } else if (!is_cleared) {
1629     switch (real_type(type)) {
1630 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)             \
1631   case WireFormatLite::TYPE_##UPPERCASE:                     \
1632     WireFormatLite::Write##CAMELCASE(number, VALUE, output); \
1633     break
1634 
1635       HANDLE_TYPE(INT32, Int32, int32_value);
1636       HANDLE_TYPE(INT64, Int64, int64_value);
1637       HANDLE_TYPE(UINT32, UInt32, uint32_value);
1638       HANDLE_TYPE(UINT64, UInt64, uint64_value);
1639       HANDLE_TYPE(SINT32, SInt32, int32_value);
1640       HANDLE_TYPE(SINT64, SInt64, int64_value);
1641       HANDLE_TYPE(FIXED32, Fixed32, uint32_value);
1642       HANDLE_TYPE(FIXED64, Fixed64, uint64_value);
1643       HANDLE_TYPE(SFIXED32, SFixed32, int32_value);
1644       HANDLE_TYPE(SFIXED64, SFixed64, int64_value);
1645       HANDLE_TYPE(FLOAT, Float, float_value);
1646       HANDLE_TYPE(DOUBLE, Double, double_value);
1647       HANDLE_TYPE(BOOL, Bool, bool_value);
1648       HANDLE_TYPE(STRING, String, *string_value);
1649       HANDLE_TYPE(BYTES, Bytes, *string_value);
1650       HANDLE_TYPE(ENUM, Enum, enum_value);
1651       HANDLE_TYPE(GROUP, Group, *message_value);
1652 #undef HANDLE_TYPE
1653       case WireFormatLite::TYPE_MESSAGE:
1654         if (is_lazy) {
1655           lazymessage_value->WriteMessage(number, output);
1656         } else {
1657           WireFormatLite::WriteMessage(number, *message_value, output);
1658         }
1659         break;
1660     }
1661   }
1662 }
1663 
ByteSize(int number) const1664 size_t ExtensionSet::Extension::ByteSize(int number) const {
1665   size_t result = 0;
1666 
1667   if (is_repeated) {
1668     if (is_packed) {
1669       switch (real_type(type)) {
1670 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
1671   case WireFormatLite::TYPE_##UPPERCASE:                             \
1672     for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \
1673       result += WireFormatLite::CAMELCASE##Size(                     \
1674           repeated_##LOWERCASE##_value->Get(i));                     \
1675     }                                                                \
1676     break
1677 
1678         HANDLE_TYPE(INT32, Int32, int32);
1679         HANDLE_TYPE(INT64, Int64, int64);
1680         HANDLE_TYPE(UINT32, UInt32, uint32);
1681         HANDLE_TYPE(UINT64, UInt64, uint64);
1682         HANDLE_TYPE(SINT32, SInt32, int32);
1683         HANDLE_TYPE(SINT64, SInt64, int64);
1684         HANDLE_TYPE(ENUM, Enum, enum);
1685 #undef HANDLE_TYPE
1686 
1687         // Stuff with fixed size.
1688 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)             \
1689   case WireFormatLite::TYPE_##UPPERCASE:                         \
1690     result += WireFormatLite::k##CAMELCASE##Size *               \
1691               FromIntSize(repeated_##LOWERCASE##_value->size()); \
1692     break
1693         HANDLE_TYPE(FIXED32, Fixed32, uint32);
1694         HANDLE_TYPE(FIXED64, Fixed64, uint64);
1695         HANDLE_TYPE(SFIXED32, SFixed32, int32);
1696         HANDLE_TYPE(SFIXED64, SFixed64, int64);
1697         HANDLE_TYPE(FLOAT, Float, float);
1698         HANDLE_TYPE(DOUBLE, Double, double);
1699         HANDLE_TYPE(BOOL, Bool, bool);
1700 #undef HANDLE_TYPE
1701 
1702         case WireFormatLite::TYPE_STRING:
1703         case WireFormatLite::TYPE_BYTES:
1704         case WireFormatLite::TYPE_GROUP:
1705         case WireFormatLite::TYPE_MESSAGE:
1706           GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
1707           break;
1708       }
1709 
1710       cached_size = ToCachedSize(result);
1711       if (result > 0) {
1712         result += io::CodedOutputStream::VarintSize32(result);
1713         result += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
1714             number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED));
1715       }
1716     } else {
1717       size_t tag_size = WireFormatLite::TagSize(number, real_type(type));
1718 
1719       switch (real_type(type)) {
1720 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
1721   case WireFormatLite::TYPE_##UPPERCASE:                                    \
1722     result += tag_size * FromIntSize(repeated_##LOWERCASE##_value->size()); \
1723     for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {        \
1724       result += WireFormatLite::CAMELCASE##Size(                            \
1725           repeated_##LOWERCASE##_value->Get(i));                            \
1726     }                                                                       \
1727     break
1728 
1729         HANDLE_TYPE(INT32, Int32, int32);
1730         HANDLE_TYPE(INT64, Int64, int64);
1731         HANDLE_TYPE(UINT32, UInt32, uint32);
1732         HANDLE_TYPE(UINT64, UInt64, uint64);
1733         HANDLE_TYPE(SINT32, SInt32, int32);
1734         HANDLE_TYPE(SINT64, SInt64, int64);
1735         HANDLE_TYPE(STRING, String, string);
1736         HANDLE_TYPE(BYTES, Bytes, string);
1737         HANDLE_TYPE(ENUM, Enum, enum);
1738         HANDLE_TYPE(GROUP, Group, message);
1739         HANDLE_TYPE(MESSAGE, Message, message);
1740 #undef HANDLE_TYPE
1741 
1742         // Stuff with fixed size.
1743 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)             \
1744   case WireFormatLite::TYPE_##UPPERCASE:                         \
1745     result += (tag_size + WireFormatLite::k##CAMELCASE##Size) *  \
1746               FromIntSize(repeated_##LOWERCASE##_value->size()); \
1747     break
1748         HANDLE_TYPE(FIXED32, Fixed32, uint32);
1749         HANDLE_TYPE(FIXED64, Fixed64, uint64);
1750         HANDLE_TYPE(SFIXED32, SFixed32, int32);
1751         HANDLE_TYPE(SFIXED64, SFixed64, int64);
1752         HANDLE_TYPE(FLOAT, Float, float);
1753         HANDLE_TYPE(DOUBLE, Double, double);
1754         HANDLE_TYPE(BOOL, Bool, bool);
1755 #undef HANDLE_TYPE
1756       }
1757     }
1758   } else if (!is_cleared) {
1759     result += WireFormatLite::TagSize(number, real_type(type));
1760     switch (real_type(type)) {
1761 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)      \
1762   case WireFormatLite::TYPE_##UPPERCASE:                  \
1763     result += WireFormatLite::CAMELCASE##Size(LOWERCASE); \
1764     break
1765 
1766       HANDLE_TYPE(INT32, Int32, int32_value);
1767       HANDLE_TYPE(INT64, Int64, int64_value);
1768       HANDLE_TYPE(UINT32, UInt32, uint32_value);
1769       HANDLE_TYPE(UINT64, UInt64, uint64_value);
1770       HANDLE_TYPE(SINT32, SInt32, int32_value);
1771       HANDLE_TYPE(SINT64, SInt64, int64_value);
1772       HANDLE_TYPE(STRING, String, *string_value);
1773       HANDLE_TYPE(BYTES, Bytes, *string_value);
1774       HANDLE_TYPE(ENUM, Enum, enum_value);
1775       HANDLE_TYPE(GROUP, Group, *message_value);
1776 #undef HANDLE_TYPE
1777       case WireFormatLite::TYPE_MESSAGE: {
1778         if (is_lazy) {
1779           size_t size = lazymessage_value->ByteSize();
1780           result += io::CodedOutputStream::VarintSize32(size) + size;
1781         } else {
1782           result += WireFormatLite::MessageSize(*message_value);
1783         }
1784         break;
1785       }
1786 
1787       // Stuff with fixed size.
1788 #define HANDLE_TYPE(UPPERCASE, CAMELCASE)         \
1789   case WireFormatLite::TYPE_##UPPERCASE:          \
1790     result += WireFormatLite::k##CAMELCASE##Size; \
1791     break
1792         HANDLE_TYPE(FIXED32, Fixed32);
1793         HANDLE_TYPE(FIXED64, Fixed64);
1794         HANDLE_TYPE(SFIXED32, SFixed32);
1795         HANDLE_TYPE(SFIXED64, SFixed64);
1796         HANDLE_TYPE(FLOAT, Float);
1797         HANDLE_TYPE(DOUBLE, Double);
1798         HANDLE_TYPE(BOOL, Bool);
1799 #undef HANDLE_TYPE
1800     }
1801   }
1802 
1803   return result;
1804 }
1805 
GetSize() const1806 int ExtensionSet::Extension::GetSize() const {
1807   GOOGLE_DCHECK(is_repeated);
1808   switch (cpp_type(type)) {
1809 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)   \
1810   case WireFormatLite::CPPTYPE_##UPPERCASE: \
1811     return repeated_##LOWERCASE##_value->size()
1812 
1813     HANDLE_TYPE(INT32, int32);
1814     HANDLE_TYPE(INT64, int64);
1815     HANDLE_TYPE(UINT32, uint32);
1816     HANDLE_TYPE(UINT64, uint64);
1817     HANDLE_TYPE(FLOAT, float);
1818     HANDLE_TYPE(DOUBLE, double);
1819     HANDLE_TYPE(BOOL, bool);
1820     HANDLE_TYPE(ENUM, enum);
1821     HANDLE_TYPE(STRING, string);
1822     HANDLE_TYPE(MESSAGE, message);
1823 #undef HANDLE_TYPE
1824   }
1825 
1826   GOOGLE_LOG(FATAL) << "Can't get here.";
1827   return 0;
1828 }
1829 
// This function deletes all allocated objects using plain `delete`, so it
// should only be called if the Extension was created without an arena.
Free()1832 void ExtensionSet::Extension::Free() {
1833   if (is_repeated) {
1834     switch (cpp_type(type)) {
1835 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)   \
1836   case WireFormatLite::CPPTYPE_##UPPERCASE: \
1837     delete repeated_##LOWERCASE##_value;    \
1838     break
1839 
1840       HANDLE_TYPE(INT32, int32);
1841       HANDLE_TYPE(INT64, int64);
1842       HANDLE_TYPE(UINT32, uint32);
1843       HANDLE_TYPE(UINT64, uint64);
1844       HANDLE_TYPE(FLOAT, float);
1845       HANDLE_TYPE(DOUBLE, double);
1846       HANDLE_TYPE(BOOL, bool);
1847       HANDLE_TYPE(ENUM, enum);
1848       HANDLE_TYPE(STRING, string);
1849       HANDLE_TYPE(MESSAGE, message);
1850 #undef HANDLE_TYPE
1851     }
1852   } else {
1853     switch (cpp_type(type)) {
1854       case WireFormatLite::CPPTYPE_STRING:
1855         delete string_value;
1856         break;
1857       case WireFormatLite::CPPTYPE_MESSAGE:
1858         if (is_lazy) {
1859           delete lazymessage_value;
1860         } else {
1861           delete message_value;
1862         }
1863         break;
1864       default:
1865         break;
1866     }
1867   }
1868 }
1869 
1870 // Defined in extension_set_heavy.cc.
1871 // int ExtensionSet::Extension::SpaceUsedExcludingSelf() const
1872 
IsInitialized() const1873 bool ExtensionSet::Extension::IsInitialized() const {
1874   if (cpp_type(type) == WireFormatLite::CPPTYPE_MESSAGE) {
1875     if (is_repeated) {
1876       for (int i = 0; i < repeated_message_value->size(); i++) {
1877         if (!repeated_message_value->Get(i).IsInitialized()) {
1878           return false;
1879         }
1880       }
1881     } else {
1882       if (!is_cleared) {
1883         if (is_lazy) {
1884           if (!lazymessage_value->IsInitialized()) return false;
1885         } else {
1886           if (!message_value->IsInitialized()) return false;
1887         }
1888       }
1889     }
1890   }
1891   return true;
1892 }
1893 
1894 // Dummy key method to avoid weak vtable.
UnusedKeyMethod()1895 void ExtensionSet::LazyMessageExtension::UnusedKeyMethod() {}
1896 
FindOrNull(int key) const1897 const ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) const {
1898   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1899     return FindOrNullInLargeMap(key);
1900   }
1901   const KeyValue* end = flat_end();
1902   const KeyValue* it =
1903       std::lower_bound(flat_begin(), end, key, KeyValue::FirstComparator());
1904   if (it != end && it->first == key) {
1905     return &it->second;
1906   }
1907   return NULL;
1908 }
1909 
FindOrNullInLargeMap(int key) const1910 const ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(
1911     int key) const {
1912   assert(is_large());
1913   LargeMap::const_iterator it = map_.large->find(key);
1914   if (it != map_.large->end()) {
1915     return &it->second;
1916   }
1917   return NULL;
1918 }
1919 
FindOrNull(int key)1920 ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) {
1921   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1922     return FindOrNullInLargeMap(key);
1923   }
1924   KeyValue* end = flat_end();
1925   KeyValue* it =
1926       std::lower_bound(flat_begin(), end, key, KeyValue::FirstComparator());
1927   if (it != end && it->first == key) {
1928     return &it->second;
1929   }
1930   return NULL;
1931 }
1932 
FindOrNullInLargeMap(int key)1933 ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(int key) {
1934   assert(is_large());
1935   LargeMap::iterator it = map_.large->find(key);
1936   if (it != map_.large->end()) {
1937     return &it->second;
1938   }
1939   return NULL;
1940 }
1941 
Insert(int key)1942 std::pair<ExtensionSet::Extension*, bool> ExtensionSet::Insert(int key) {
1943   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1944     auto maybe = map_.large->insert({key, Extension()});
1945     return {&maybe.first->second, maybe.second};
1946   }
1947   KeyValue* end = flat_end();
1948   KeyValue* it =
1949       std::lower_bound(flat_begin(), end, key, KeyValue::FirstComparator());
1950   if (it != end && it->first == key) {
1951     return {&it->second, false};
1952   }
1953   if (flat_size_ < flat_capacity_) {
1954     std::copy_backward(it, end, end + 1);
1955     ++flat_size_;
1956     it->first = key;
1957     it->second = Extension();
1958     return {&it->second, true};
1959   }
1960   GrowCapacity(flat_size_ + 1);
1961   return Insert(key);
1962 }
1963 
GrowCapacity(size_t minimum_new_capacity)1964 void ExtensionSet::GrowCapacity(size_t minimum_new_capacity) {
1965   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1966     return;  // LargeMap does not have a "reserve" method.
1967   }
1968   if (flat_capacity_ >= minimum_new_capacity) {
1969     return;
1970   }
1971 
1972   const auto old_flat_capacity = flat_capacity_;
1973 
1974   do {
1975     flat_capacity_ = flat_capacity_ == 0 ? 1 : flat_capacity_ * 4;
1976   } while (flat_capacity_ < minimum_new_capacity);
1977 
1978   const KeyValue* begin = flat_begin();
1979   const KeyValue* end = flat_end();
1980   if (flat_capacity_ > kMaximumFlatCapacity) {
1981     // Switch to LargeMap
1982     map_.large = Arena::Create<LargeMap>(arena_);
1983     LargeMap::iterator hint = map_.large->begin();
1984     for (const KeyValue* it = begin; it != end; ++it) {
1985       hint = map_.large->insert(hint, {it->first, it->second});
1986     }
1987     flat_size_ = 0;
1988   } else {
1989     map_.flat = Arena::CreateArray<KeyValue>(arena_, flat_capacity_);
1990     std::copy(begin, end, map_.flat);
1991   }
1992   if (arena_ == nullptr) {
1993     DeleteFlatMap(begin, old_flat_capacity);
1994   }
1995 }
1996 
1997 // static
1998 constexpr uint16 ExtensionSet::kMaximumFlatCapacity;
1999 
Erase(int key)2000 void ExtensionSet::Erase(int key) {
2001   if (PROTOBUF_PREDICT_FALSE(is_large())) {
2002     map_.large->erase(key);
2003     return;
2004   }
2005   KeyValue* end = flat_end();
2006   KeyValue* it =
2007       std::lower_bound(flat_begin(), end, key, KeyValue::FirstComparator());
2008   if (it != end && it->first == key) {
2009     std::copy(it + 1, end, it);
2010     --flat_size_;
2011   }
2012 }
2013 
2014 // ==================================================================
2015 // Default repeated field instances for iterator-compatible accessors
2016 
default_instance()2017 const RepeatedPrimitiveDefaults* RepeatedPrimitiveDefaults::default_instance() {
2018   static auto instance = OnShutdownDelete(new RepeatedPrimitiveDefaults);
2019   return instance;
2020 }
2021 
2022 const RepeatedStringTypeTraits::RepeatedFieldType*
GetDefaultRepeatedField()2023 RepeatedStringTypeTraits::GetDefaultRepeatedField() {
2024   static auto instance = OnShutdownDelete(new RepeatedFieldType);
2025   return instance;
2026 }
2027 
SerializeMessageSetItemWithCachedSizes(int number,io::CodedOutputStream * output) const2028 void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes(
2029     int number, io::CodedOutputStream* output) const {
2030   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
2031     // Not a valid MessageSet extension, but serialize it the normal way.
2032     SerializeFieldWithCachedSizes(number, output);
2033     return;
2034   }
2035 
2036   if (is_cleared) return;
2037 
2038   // Start group.
2039   output->WriteTag(WireFormatLite::kMessageSetItemStartTag);
2040 
2041   // Write type ID.
2042   WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, number,
2043                               output);
2044   // Write message.
2045   if (is_lazy) {
2046     lazymessage_value->WriteMessage(WireFormatLite::kMessageSetMessageNumber,
2047                                     output);
2048   } else {
2049     WireFormatLite::WriteMessageMaybeToArray(
2050         WireFormatLite::kMessageSetMessageNumber, *message_value, output);
2051   }
2052 
2053   // End group.
2054   output->WriteTag(WireFormatLite::kMessageSetItemEndTag);
2055 }
2056 
MessageSetItemByteSize(int number) const2057 size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
2058   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
2059     // Not a valid MessageSet extension, but compute the byte size for it the
2060     // normal way.
2061     return ByteSize(number);
2062   }
2063 
2064   if (is_cleared) return 0;
2065 
2066   size_t our_size = WireFormatLite::kMessageSetItemTagsSize;
2067 
2068   // type_id
2069   our_size += io::CodedOutputStream::VarintSize32(number);
2070 
2071   // message
2072   size_t message_size = 0;
2073   if (is_lazy) {
2074     message_size = lazymessage_value->ByteSizeLong();
2075   } else {
2076     message_size = message_value->ByteSizeLong();
2077   }
2078 
2079   our_size += io::CodedOutputStream::VarintSize32(message_size);
2080   our_size += message_size;
2081 
2082   return our_size;
2083 }
2084 
SerializeMessageSetWithCachedSizes(io::CodedOutputStream * output) const2085 void ExtensionSet::SerializeMessageSetWithCachedSizes(
2086     io::CodedOutputStream* output) const {
2087   ForEach([output](int number, const Extension& ext) {
2088     ext.SerializeMessageSetItemWithCachedSizes(number, output);
2089   });
2090 }
2091 
MessageSetByteSize() const2092 size_t ExtensionSet::MessageSetByteSize() const {
2093   size_t total_size = 0;
2094   ForEach([&total_size](int number, const Extension& ext) {
2095     total_size += ext.MessageSetItemByteSize(number);
2096   });
2097   return total_size;
2098 }
2099 
2100 }  // namespace internal
2101 }  // namespace protobuf
2102 }  // namespace google
2103