• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: kenton@google.com (Kenton Varda)
9 //  Based on original Protocol Buffers design by
10 //  Sanjay Ghemawat, Jeff Dean, and others.
11 
12 #include "google/protobuf/extension_set.h"
13 
14 #include <algorithm>
15 #include <atomic>
16 #include <cstddef>
17 #include <cstdint>
18 #include <string>
19 #include <tuple>
20 #include <type_traits>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/hash/hash.h"
25 #include "absl/log/absl_check.h"
26 #include "absl/log/absl_log.h"
27 #include "google/protobuf/arena.h"
28 #include "google/protobuf/extension_set_inl.h"
29 #include "google/protobuf/io/coded_stream.h"
30 #include "google/protobuf/message_lite.h"
31 #include "google/protobuf/metadata_lite.h"
32 #include "google/protobuf/parse_context.h"
33 #include "google/protobuf/port.h"
34 #include "google/protobuf/repeated_field.h"
35 
36 // must be last.
37 #include "google/protobuf/port_def.inc"
38 
39 namespace google {
40 namespace protobuf {
41 namespace internal {
42 namespace {
43 
real_type(FieldType type)44 inline WireFormatLite::FieldType real_type(FieldType type) {
45   ABSL_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
46   return static_cast<WireFormatLite::FieldType>(type);
47 }
48 
cpp_type(FieldType type)49 inline WireFormatLite::CppType cpp_type(FieldType type) {
50   return WireFormatLite::FieldTypeToCppType(real_type(type));
51 }
52 
53 // Registry stuff.
54 
55 struct ExtensionInfoKey {
56   const MessageLite* message;
57   int number;
58 };
59 
60 struct ExtensionEq {
61   using is_transparent = void;
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq62   bool operator()(const ExtensionInfo& lhs, const ExtensionInfo& rhs) const {
63     return lhs.message == rhs.message && lhs.number == rhs.number;
64   }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq65   bool operator()(const ExtensionInfo& lhs, const ExtensionInfoKey& rhs) const {
66     return lhs.message == rhs.message && lhs.number == rhs.number;
67   }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq68   bool operator()(const ExtensionInfoKey& lhs, const ExtensionInfo& rhs) const {
69     return lhs.message == rhs.message && lhs.number == rhs.number;
70   }
71 };
72 
73 struct ExtensionHasher {
74   using is_transparent = void;
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionHasher75   std::size_t operator()(const ExtensionInfo& info) const {
76     return absl::HashOf(info.message, info.number);
77   }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionHasher78   std::size_t operator()(const ExtensionInfoKey& info) const {
79     return absl::HashOf(info.message, info.number);
80   }
81 };
82 
83 using ExtensionRegistry =
84     absl::flat_hash_set<ExtensionInfo, ExtensionHasher, ExtensionEq>;
85 
86 static const ExtensionRegistry* global_registry = nullptr;
87 
88 // This function is only called at startup, so there is no need for thread-
89 // safety.
Register(const ExtensionInfo & info)90 void Register(const ExtensionInfo& info) {
91   static auto local_static_registry = OnShutdownDelete(new ExtensionRegistry);
92   global_registry = local_static_registry;
93   if (!local_static_registry->insert(info).second) {
94     ABSL_LOG(FATAL) << "Multiple extension registrations for type \""
95                     << info.message->GetTypeName() << "\", field number "
96                     << info.number << ".";
97   }
98 }
99 
FindRegisteredExtension(const MessageLite * extendee,int number)100 const ExtensionInfo* FindRegisteredExtension(const MessageLite* extendee,
101                                              int number) {
102   if (!global_registry) return nullptr;
103 
104   ExtensionInfoKey info;
105   info.message = extendee;
106   info.number = number;
107 
108   auto it = global_registry->find(info);
109   if (it == global_registry->end()) {
110     return nullptr;
111   } else {
112     return &*it;
113   }
114 }
115 
116 }  // namespace
117 
Find(int number,ExtensionInfo * output)118 bool GeneratedExtensionFinder::Find(int number, ExtensionInfo* output) {
119   const ExtensionInfo* extension = FindRegisteredExtension(extendee_, number);
120   if (extension == nullptr) {
121     return false;
122   } else {
123     *output = *extension;
124     return true;
125   }
126 }
127 
RegisterExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed)128 void ExtensionSet::RegisterExtension(const MessageLite* extendee, int number,
129                                      FieldType type, bool is_repeated,
130                                      bool is_packed) {
131   ABSL_CHECK_NE(type, WireFormatLite::TYPE_ENUM);
132   ABSL_CHECK_NE(type, WireFormatLite::TYPE_MESSAGE);
133   ABSL_CHECK_NE(type, WireFormatLite::TYPE_GROUP);
134   ExtensionInfo info(extendee, number, type, is_repeated, is_packed);
135   Register(info);
136 }
137 
CallNoArgValidityFunc(const void * arg,int number)138 static bool CallNoArgValidityFunc(const void* arg, int number) {
139   // Note:  Must use C-style cast here rather than reinterpret_cast because
140   //   the C++ standard at one point did not allow casts between function and
141   //   data pointers and some compilers enforce this for C++-style casts.  No
142   //   compiler enforces it for C-style casts since lots of C-style code has
143   //   relied on these kinds of casts for a long time, despite being
144   //   technically undefined.  See:
145   //     http://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#195
146   // Also note:  Some compilers do not allow function pointers to be "const".
147   //   Which makes sense, I suppose, because it's meaningless.
148   return ((EnumValidityFunc*)arg)(number);
149 }
150 
RegisterEnumExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed,EnumValidityFunc * is_valid)151 void ExtensionSet::RegisterEnumExtension(const MessageLite* extendee,
152                                          int number, FieldType type,
153                                          bool is_repeated, bool is_packed,
154                                          EnumValidityFunc* is_valid) {
155   ABSL_CHECK_EQ(type, WireFormatLite::TYPE_ENUM);
156   ExtensionInfo info(extendee, number, type, is_repeated, is_packed);
157   info.enum_validity_check.func = CallNoArgValidityFunc;
158   // See comment in CallNoArgValidityFunc() about why we use a c-style cast.
159   info.enum_validity_check.arg = (void*)is_valid;
160   Register(info);
161 }
162 
RegisterMessageExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed,const MessageLite * prototype,LazyEagerVerifyFnType verify_func,LazyAnnotation is_lazy)163 void ExtensionSet::RegisterMessageExtension(const MessageLite* extendee,
164                                             int number, FieldType type,
165                                             bool is_repeated, bool is_packed,
166                                             const MessageLite* prototype,
167                                             LazyEagerVerifyFnType verify_func,
168                                             LazyAnnotation is_lazy) {
169   ABSL_CHECK(type == WireFormatLite::TYPE_MESSAGE ||
170              type == WireFormatLite::TYPE_GROUP);
171   ExtensionInfo info(extendee, number, type, is_repeated, is_packed,
172                      verify_func, is_lazy);
173   info.message_info = {prototype,
174 #if defined(PROTOBUF_CONSTINIT_DEFAULT_INSTANCES)
175                        prototype->GetTcParseTable()
176 #else
177                        nullptr
178 #endif
179   };
180   Register(info);
181 }
182 
183 // ===================================================================
184 // Constructors and basic methods.
185 
~ExtensionSet()186 ExtensionSet::~ExtensionSet() {
187   // Deletes all allocated extensions.
188   if (arena_ == nullptr) {
189     ForEach([](int /* number */, Extension& ext) { ext.Free(); },
190             PrefetchNta{});
191     if (PROTOBUF_PREDICT_FALSE(is_large())) {
192       delete map_.large;
193     } else {
194       DeleteFlatMap(map_.flat, flat_capacity_);
195     }
196   }
197 }
198 
DeleteFlatMap(const ExtensionSet::KeyValue * flat,uint16_t flat_capacity)199 void ExtensionSet::DeleteFlatMap(const ExtensionSet::KeyValue* flat,
200                                  uint16_t flat_capacity) {
201   // Arena::CreateArray already requires a trivially destructible type, but
202   // ensure this constraint is not violated in the future.
203   static_assert(std::is_trivially_destructible<KeyValue>::value,
204                 "CreateArray requires a trivially destructible type");
205   // A const-cast is needed, but this is safe as we are about to deallocate the
206   // array.
207   internal::SizedArrayDelete(const_cast<KeyValue*>(flat),
208                              sizeof(*flat) * flat_capacity);
209 }
210 
211 // Defined in extension_set_heavy.cc.
212 // void ExtensionSet::AppendToList(const Descriptor* extendee,
213 //                                 const DescriptorPool* pool,
214 //                                 vector<const FieldDescriptor*>* output) const
215 
Has(int number) const216 bool ExtensionSet::Has(int number) const {
217   const Extension* ext = FindOrNull(number);
218   if (ext == nullptr) return false;
219   ABSL_DCHECK(!ext->is_repeated);
220   return !ext->is_cleared;
221 }
222 
HasLazy(int number) const223 bool ExtensionSet::HasLazy(int number) const {
224   return Has(number) && FindOrNull(number)->is_lazy;
225 }
226 
NumExtensions() const227 int ExtensionSet::NumExtensions() const {
228   int result = 0;
229   ForEachNoPrefetch([&result](int /* number */, const Extension& ext) {
230     if (!ext.is_cleared) {
231       ++result;
232     }
233   });
234   return result;
235 }
236 
ExtensionSize(int number) const237 int ExtensionSet::ExtensionSize(int number) const {
238   const Extension* ext = FindOrNull(number);
239   return ext == nullptr ? 0 : ext->GetSize();
240 }
241 
ExtensionType(int number) const242 FieldType ExtensionSet::ExtensionType(int number) const {
243   const Extension* ext = FindOrNull(number);
244   if (ext == nullptr) {
245     ABSL_DLOG(FATAL)
246         << "Don't lookup extension types if they aren't present (1). ";
247     return 0;
248   }
249   if (ext->is_cleared) {
250     ABSL_DLOG(FATAL)
251         << "Don't lookup extension types if they aren't present (2). ";
252   }
253   return ext->type;
254 }
255 
ClearExtension(int number)256 void ExtensionSet::ClearExtension(int number) {
257   Extension* ext = FindOrNull(number);
258   if (ext == nullptr) return;
259   ext->Clear();
260 }
261 
262 // ===================================================================
263 // Field accessors
264 
namespace {

// Field labels that ABSL_DCHECK_TYPE compares against the label derived from
// Extension::is_repeated.
enum {
  REPEATED_FIELD = 0,
  OPTIONAL_FIELD = 1,
};

}  // namespace
270 
// Debug-checks that EXTENSION has the expected label and C++ type.
//   LABEL   - REPEATED_FIELD or OPTIONAL_FIELD, compared against the label
//             derived from EXTENSION.is_repeated.
//   CPPTYPE - suffix of a WireFormatLite::CPPTYPE_* enumerator, compared
//             against cpp_type(EXTENSION.type).
// The two checks are wrapped in do { } while (0) so the macro expands to a
// single statement. Otherwise, an unbraced `if (c) ABSL_DCHECK_TYPE(...);`
// would guard only the first check. Callers still terminate it with ';'.
#define ABSL_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                           \
  do {                                                                        \
    ABSL_DCHECK_EQ((EXTENSION).is_repeated ? REPEATED_FIELD : OPTIONAL_FIELD, \
                   LABEL);                                                    \
    ABSL_DCHECK_EQ(cpp_type((EXTENSION).type),                                \
                   WireFormatLite::CPPTYPE_##CPPTYPE);                        \
  } while (0)
275 
276 // -------------------------------------------------------------------
277 // Primitives
278 
// Expands to the complete accessor set for one primitive C++ type:
//   singular: Get<T>(), GetRef<T>(), Set<T>()
//   repeated: GetRepeated<T>(), GetRefRepeated<T>(), SetRepeated<T>(), Add<T>()
// Parameters:
//   UPPERCASE - suffix of the WireFormatLite::CPPTYPE_* enumerator used by the
//               debug type checks.
//   LOWERCASE - the C++ value type; also the prefix of the Extension members
//               `<LOWERCASE>_value` and `ptr.repeated_<LOWERCASE>_value`.
//   CAMELCASE - suffix appended to each generated method name.
// Singular getters return the caller's default when the extension is absent
// or cleared. Repeated getters and setters CHECK-fail when it is absent.
// (Comments cannot live inside the body because of the line continuations.)
#define PRIMITIVE_ACCESSORS(UPPERCASE, LOWERCASE, CAMELCASE)                   \
                                                                               \
  LOWERCASE ExtensionSet::Get##CAMELCASE(int number, LOWERCASE default_value)  \
      const {                                                                  \
    const Extension* const ext = FindOrNull(number);                           \
    if (ext != nullptr && !ext->is_cleared) {                                  \
      ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, UPPERCASE);                       \
      return ext->LOWERCASE##_value;                                           \
    }                                                                          \
    return default_value;                                                      \
  }                                                                            \
                                                                               \
  const LOWERCASE& ExtensionSet::GetRef##CAMELCASE(                            \
      int number, const LOWERCASE& default_value) const {                      \
    const Extension* const ext = FindOrNull(number);                           \
    if (ext != nullptr && !ext->is_cleared) {                                  \
      ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, UPPERCASE);                       \
      return ext->LOWERCASE##_value;                                           \
    }                                                                          \
    return default_value;                                                      \
  }                                                                            \
                                                                               \
  void ExtensionSet::Set##CAMELCASE(int number, FieldType type,                \
                                    LOWERCASE value,                           \
                                    const FieldDescriptor* descriptor) {       \
    Extension* ext;                                                            \
    const bool is_new = MaybeNewExtension(number, descriptor, &ext);           \
    if (is_new) {                                                              \
      ext->type = type;                                                        \
      ABSL_DCHECK_EQ(cpp_type(ext->type),                                      \
                     WireFormatLite::CPPTYPE_##UPPERCASE);                     \
      ext->is_repeated = false;                                                \
      ext->is_pointer = false;                                                 \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, UPPERCASE);                       \
    }                                                                          \
    ext->is_cleared = false;                                                   \
    ext->LOWERCASE##_value = value;                                            \
  }                                                                            \
                                                                               \
  LOWERCASE ExtensionSet::GetRepeated##CAMELCASE(int number, int index)        \
      const {                                                                  \
    const Extension* const ext = FindOrNull(number);                           \
    ABSL_CHECK(ext != nullptr) << "Index out-of-bounds (field is empty).";     \
    ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, UPPERCASE);                         \
    return ext->ptr.repeated_##LOWERCASE##_value->Get(index);                  \
  }                                                                            \
                                                                               \
  const LOWERCASE& ExtensionSet::GetRefRepeated##CAMELCASE(int number,         \
                                                           int index) const {  \
    const Extension* const ext = FindOrNull(number);                           \
    ABSL_CHECK(ext != nullptr) << "Index out-of-bounds (field is empty).";     \
    ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, UPPERCASE);                         \
    return ext->ptr.repeated_##LOWERCASE##_value->Get(index);                  \
  }                                                                            \
                                                                               \
  void ExtensionSet::SetRepeated##CAMELCASE(int number, int index,             \
                                            LOWERCASE value) {                 \
    Extension* const ext = FindOrNull(number);                                 \
    ABSL_CHECK(ext != nullptr) << "Index out-of-bounds (field is empty).";     \
    ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, UPPERCASE);                         \
    ext->ptr.repeated_##LOWERCASE##_value->Set(index, value);                  \
  }                                                                            \
                                                                               \
  void ExtensionSet::Add##CAMELCASE(int number, FieldType type, bool packed,   \
                                    LOWERCASE value,                           \
                                    const FieldDescriptor* descriptor) {       \
    Extension* ext;                                                            \
    const bool is_new = MaybeNewExtension(number, descriptor, &ext);           \
    if (is_new) {                                                              \
      ext->type = type;                                                        \
      ABSL_DCHECK_EQ(cpp_type(ext->type),                                      \
                     WireFormatLite::CPPTYPE_##UPPERCASE);                     \
      ext->is_repeated = true;                                                 \
      ext->is_pointer = true;                                                  \
      ext->is_packed = packed;                                                 \
      ext->ptr.repeated_##LOWERCASE##_value =                                  \
          Arena::Create<RepeatedField<LOWERCASE>>(arena_);                     \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, UPPERCASE);                       \
      ABSL_DCHECK_EQ(ext->is_packed, packed);                                  \
    }                                                                          \
    ext->ptr.repeated_##LOWERCASE##_value->Add(value);                         \
  }
366 
// Generate accessors for every scalar type an extension can hold. Enums and
// strings have hand-written accessors further below.
PRIMITIVE_ACCESSORS(INT32, int32_t, Int32)
PRIMITIVE_ACCESSORS(INT64, int64_t, Int64)
PRIMITIVE_ACCESSORS(UINT32, uint32_t, UInt32)
PRIMITIVE_ACCESSORS(UINT64, uint64_t, UInt64)
PRIMITIVE_ACCESSORS(FLOAT, float, Float)
PRIMITIVE_ACCESSORS(DOUBLE, double, Double)
PRIMITIVE_ACCESSORS(BOOL, bool, Bool)

#undef PRIMITIVE_ACCESSORS
376 
377 const void* ExtensionSet::GetRawRepeatedField(int number,
378                                               const void* default_value) const {
379   const Extension* extension = FindOrNull(number);
380   if (extension == nullptr) {
381     return default_value;
382   }
383   // We assume that all the RepeatedField<>* pointers have the same
384   // size and alignment within the anonymous union in Extension.
385   return extension->ptr.repeated_int32_t_value;
386 }
387 
MutableRawRepeatedField(int number,FieldType field_type,bool packed,const FieldDescriptor * desc)388 void* ExtensionSet::MutableRawRepeatedField(int number, FieldType field_type,
389                                             bool packed,
390                                             const FieldDescriptor* desc) {
391   Extension* extension;
392 
393   // We instantiate an empty Repeated{,Ptr}Field if one doesn't exist for this
394   // extension.
395   if (MaybeNewExtension(number, desc, &extension)) {
396     extension->is_repeated = true;
397     extension->is_pointer = true;
398     extension->type = field_type;
399     extension->is_packed = packed;
400 
401     switch (WireFormatLite::FieldTypeToCppType(
402         static_cast<WireFormatLite::FieldType>(field_type))) {
403       case WireFormatLite::CPPTYPE_INT32:
404         extension->ptr.repeated_int32_t_value =
405             Arena::Create<RepeatedField<int32_t>>(arena_);
406         break;
407       case WireFormatLite::CPPTYPE_INT64:
408         extension->ptr.repeated_int64_t_value =
409             Arena::Create<RepeatedField<int64_t>>(arena_);
410         break;
411       case WireFormatLite::CPPTYPE_UINT32:
412         extension->ptr.repeated_uint32_t_value =
413             Arena::Create<RepeatedField<uint32_t>>(arena_);
414         break;
415       case WireFormatLite::CPPTYPE_UINT64:
416         extension->ptr.repeated_uint64_t_value =
417             Arena::Create<RepeatedField<uint64_t>>(arena_);
418         break;
419       case WireFormatLite::CPPTYPE_DOUBLE:
420         extension->ptr.repeated_double_value =
421             Arena::Create<RepeatedField<double>>(arena_);
422         break;
423       case WireFormatLite::CPPTYPE_FLOAT:
424         extension->ptr.repeated_float_value =
425             Arena::Create<RepeatedField<float>>(arena_);
426         break;
427       case WireFormatLite::CPPTYPE_BOOL:
428         extension->ptr.repeated_bool_value =
429             Arena::Create<RepeatedField<bool>>(arena_);
430         break;
431       case WireFormatLite::CPPTYPE_ENUM:
432         extension->ptr.repeated_enum_value =
433             Arena::Create<RepeatedField<int>>(arena_);
434         break;
435       case WireFormatLite::CPPTYPE_STRING:
436         extension->ptr.repeated_string_value =
437             Arena::Create<RepeatedPtrField<std::string>>(arena_);
438         break;
439       case WireFormatLite::CPPTYPE_MESSAGE:
440         extension->ptr.repeated_message_value =
441             Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
442         break;
443     }
444   }
445 
446   // We assume that all the RepeatedField<>* pointers have the same
447   // size and alignment within the anonymous union in Extension.
448   return extension->ptr.repeated_int32_t_value;
449 }
450 
451 // Compatible version using old call signature. Does not create extensions when
452 // the don't already exist; instead, just ABSL_CHECK-fails.
MutableRawRepeatedField(int number)453 void* ExtensionSet::MutableRawRepeatedField(int number) {
454   Extension* extension = FindOrNull(number);
455   ABSL_CHECK(extension != nullptr) << "Extension not found.";
456   // We assume that all the RepeatedField<>* pointers have the same
457   // size and alignment within the anonymous union in Extension.
458   return extension->ptr.repeated_int32_t_value;
459 }
460 
461 // -------------------------------------------------------------------
462 // Enums
463 
GetEnum(int number,int default_value) const464 int ExtensionSet::GetEnum(int number, int default_value) const {
465   const Extension* extension = FindOrNull(number);
466   if (extension == nullptr || extension->is_cleared) {
467     // Not present.  Return the default value.
468     return default_value;
469   } else {
470     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
471     return extension->enum_value;
472   }
473 }
474 
GetRefEnum(int number,const int & default_value) const475 const int& ExtensionSet::GetRefEnum(int number,
476                                     const int& default_value) const {
477   const Extension* extension = FindOrNull(number);
478   if (extension == nullptr || extension->is_cleared) {
479     // Not present.  Return the default value.
480     return default_value;
481   } else {
482     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
483     return extension->enum_value;
484   }
485 }
486 
SetEnum(int number,FieldType type,int value,const FieldDescriptor * descriptor)487 void ExtensionSet::SetEnum(int number, FieldType type, int value,
488                            const FieldDescriptor* descriptor) {
489   Extension* extension;
490   if (MaybeNewExtension(number, descriptor, &extension)) {
491     extension->type = type;
492     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
493     extension->is_repeated = false;
494     extension->is_pointer = false;
495   } else {
496     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
497   }
498   extension->is_cleared = false;
499   extension->enum_value = value;
500 }
501 
GetRepeatedEnum(int number,int index) const502 int ExtensionSet::GetRepeatedEnum(int number, int index) const {
503   const Extension* extension = FindOrNull(number);
504   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
505   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
506   return extension->ptr.repeated_enum_value->Get(index);
507 }
508 
GetRefRepeatedEnum(int number,int index) const509 const int& ExtensionSet::GetRefRepeatedEnum(int number, int index) const {
510   const Extension* extension = FindOrNull(number);
511   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
512   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
513   return extension->ptr.repeated_enum_value->Get(index);
514 }
515 
GetMessageByteSizeLong(int number) const516 size_t ExtensionSet::GetMessageByteSizeLong(int number) const {
517   const Extension* extension = FindOrNull(number);
518   ABSL_CHECK(extension != nullptr) << "not present";
519   ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
520   return extension->is_lazy ? extension->ptr.lazymessage_value->ByteSizeLong()
521                             : extension->ptr.message_value->ByteSizeLong();
522 }
523 
InternalSerializeMessage(int number,const MessageLite * prototype,uint8_t * target,io::EpsCopyOutputStream * stream) const524 uint8_t* ExtensionSet::InternalSerializeMessage(
525     int number, const MessageLite* prototype, uint8_t* target,
526     io::EpsCopyOutputStream* stream) const {
527   const Extension* extension = FindOrNull(number);
528   ABSL_CHECK(extension != nullptr) << "not present";
529   ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
530 
531   if (extension->is_lazy) {
532     return extension->ptr.lazymessage_value->WriteMessageToArray(
533         prototype, number, target, stream);
534   }
535 
536   const auto* msg = extension->ptr.message_value;
537   return WireFormatLite::InternalWriteMessage(
538       number, *msg, msg->GetCachedSize(), target, stream);
539 }
540 
SetRepeatedEnum(int number,int index,int value)541 void ExtensionSet::SetRepeatedEnum(int number, int index, int value) {
542   Extension* extension = FindOrNull(number);
543   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
544   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
545   extension->ptr.repeated_enum_value->Set(index, value);
546 }
547 
AddEnum(int number,FieldType type,bool packed,int value,const FieldDescriptor * descriptor)548 void ExtensionSet::AddEnum(int number, FieldType type, bool packed, int value,
549                            const FieldDescriptor* descriptor) {
550   Extension* extension;
551   if (MaybeNewExtension(number, descriptor, &extension)) {
552     extension->type = type;
553     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
554     extension->is_repeated = true;
555     extension->is_pointer = true;
556     extension->is_packed = packed;
557     extension->ptr.repeated_enum_value =
558         Arena::Create<RepeatedField<int>>(arena_);
559   } else {
560     ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
561     ABSL_DCHECK_EQ(extension->is_packed, packed);
562   }
563   extension->ptr.repeated_enum_value->Add(value);
564 }
565 
566 // -------------------------------------------------------------------
567 // Strings
568 
GetString(int number,const std::string & default_value) const569 const std::string& ExtensionSet::GetString(
570     int number, const std::string& default_value) const {
571   const Extension* extension = FindOrNull(number);
572   if (extension == nullptr || extension->is_cleared) {
573     // Not present.  Return the default value.
574     return default_value;
575   } else {
576     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
577     return *extension->ptr.string_value;
578   }
579 }
580 
MutableString(int number,FieldType type,const FieldDescriptor * descriptor)581 std::string* ExtensionSet::MutableString(int number, FieldType type,
582                                          const FieldDescriptor* descriptor) {
583   Extension* extension;
584   if (MaybeNewExtension(number, descriptor, &extension)) {
585     extension->type = type;
586     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
587     extension->is_repeated = false;
588     extension->is_pointer = true;
589     extension->ptr.string_value = Arena::Create<std::string>(arena_);
590   } else {
591     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
592   }
593   extension->is_cleared = false;
594   return extension->ptr.string_value;
595 }
596 
GetRepeatedString(int number,int index) const597 const std::string& ExtensionSet::GetRepeatedString(int number,
598                                                    int index) const {
599   const Extension* extension = FindOrNull(number);
600   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
601   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
602   return extension->ptr.repeated_string_value->Get(index);
603 }
604 
MutableRepeatedString(int number,int index)605 std::string* ExtensionSet::MutableRepeatedString(int number, int index) {
606   Extension* extension = FindOrNull(number);
607   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
608   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
609   return extension->ptr.repeated_string_value->Mutable(index);
610 }
611 
AddString(int number,FieldType type,const FieldDescriptor * descriptor)612 std::string* ExtensionSet::AddString(int number, FieldType type,
613                                      const FieldDescriptor* descriptor) {
614   Extension* extension;
615   if (MaybeNewExtension(number, descriptor, &extension)) {
616     extension->type = type;
617     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
618     extension->is_repeated = true;
619     extension->is_pointer = true;
620     extension->is_packed = false;
621     extension->ptr.repeated_string_value =
622         Arena::Create<RepeatedPtrField<std::string>>(arena_);
623   } else {
624     ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
625   }
626   return extension->ptr.repeated_string_value->Add();
627 }
628 
629 // -------------------------------------------------------------------
630 // Messages
631 
// Returns singular message extension `number`, or `default_value` when it is
// absent. Unlike the scalar getters, is_cleared is not consulted here. Lazily
// stored messages are resolved through the lazy wrapper.
const MessageLite& ExtensionSet::GetMessage(
    int number, const MessageLite& default_value) const {
  const Extension* const ext = FindOrNull(number);
  if (ext == nullptr) return default_value;

  ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, MESSAGE);
  if (!ext->is_lazy) return *ext->ptr.message_value;
  return ext->ptr.lazymessage_value->GetMessage(default_value, arena_);
}
648 
649 // Defined in extension_set_heavy.cc.
650 // const MessageLite& ExtensionSet::GetMessage(int number,
651 //                                             const Descriptor* message_type,
652 //                                             MessageFactory* factory) const
653 
// Returns a mutable pointer to the singular message extension `number` and
// marks it as set (not cleared). A new entry gets an eager message created
// from `prototype` on `arena_`; an existing lazy entry is materialized via
// the lazy wrapper.
MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
                                          const MessageLite& prototype,
                                          const FieldDescriptor* descriptor) {
  Extension* ext;
  if (!MaybeNewExtension(number, descriptor, &ext)) {
    ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, MESSAGE);
    ext->is_cleared = false;
    return ext->is_lazy
               ? ext->ptr.lazymessage_value->MutableMessage(prototype, arena_)
               : ext->ptr.message_value;
  }

  // Freshly inserted entry: initialize it as an eager singular message.
  ext->type = type;
  ABSL_DCHECK_EQ(cpp_type(ext->type), WireFormatLite::CPPTYPE_MESSAGE);
  ext->is_repeated = false;
  ext->is_pointer = true;
  ext->is_lazy = false;
  ext->ptr.message_value = prototype.New(arena_);
  ext->is_cleared = false;
  return ext->ptr.message_value;
}
678 
679 // Defined in extension_set_heavy.cc.
680 // MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
681 //                                           const Descriptor* message_type,
682 //                                           MessageFactory* factory)
683 
SetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)684 void ExtensionSet::SetAllocatedMessage(int number, FieldType type,
685                                        const FieldDescriptor* descriptor,
686                                        MessageLite* message) {
687   if (message == nullptr) {
688     ClearExtension(number);
689     return;
690   }
691   Arena* const arena = arena_;
692   Arena* const message_arena = message->GetArena();
693   ABSL_DCHECK(message_arena == nullptr || message_arena == arena);
694 
695   Extension* extension;
696   if (MaybeNewExtension(number, descriptor, &extension)) {
697     extension->type = type;
698     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
699     extension->is_repeated = false;
700     extension->is_pointer = true;
701     extension->is_lazy = false;
702     if (message_arena == arena) {
703       extension->ptr.message_value = message;
704     } else if (message_arena == nullptr) {
705       extension->ptr.message_value = message;
706       arena->Own(message);  // not nullptr because not equal to message_arena
707     } else {
708       extension->ptr.message_value = message->New(arena);
709       extension->ptr.message_value->CheckTypeAndMergeFrom(*message);
710     }
711   } else {
712     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
713     if (extension->is_lazy) {
714       extension->ptr.lazymessage_value->SetAllocatedMessage(message, arena);
715     } else {
716       if (arena == nullptr) {
717         delete extension->ptr.message_value;
718       }
719       if (message_arena == arena) {
720         extension->ptr.message_value = message;
721       } else if (message_arena == nullptr) {
722         extension->ptr.message_value = message;
723         arena->Own(message);  // not nullptr because not equal to message_arena
724       } else {
725         extension->ptr.message_value = message->New(arena);
726         extension->ptr.message_value->CheckTypeAndMergeFrom(*message);
727       }
728     }
729   }
730   extension->is_cleared = false;
731 }
732 
UnsafeArenaSetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)733 void ExtensionSet::UnsafeArenaSetAllocatedMessage(
734     int number, FieldType type, const FieldDescriptor* descriptor,
735     MessageLite* message) {
736   if (message == nullptr) {
737     ClearExtension(number);
738     return;
739   }
740   Extension* extension;
741   if (MaybeNewExtension(number, descriptor, &extension)) {
742     extension->type = type;
743     ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
744     extension->is_repeated = false;
745     extension->is_pointer = true;
746     extension->is_lazy = false;
747     extension->ptr.message_value = message;
748   } else {
749     ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
750     if (extension->is_lazy) {
751       extension->ptr.lazymessage_value->UnsafeArenaSetAllocatedMessage(message,
752                                                                        arena_);
753     } else {
754       if (arena_ == nullptr) {
755         delete extension->ptr.message_value;
756       }
757       extension->ptr.message_value = message;
758     }
759   }
760   extension->is_cleared = false;
761 }
762 
// Removes the singular message extension `number` from this set and returns
// it, or nullptr if absent. When this set is arena-backed, the returned
// message is a fresh New() copy rather than the arena-owned original.
MessageLite* ExtensionSet::ReleaseMessage(int number,
                                          const MessageLite& prototype) {
  Extension* const ext = FindOrNull(number);
  if (ext == nullptr) return nullptr;  // Not present.
  ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, MESSAGE);

  Arena* const arena = arena_;
  MessageLite* released;
  if (ext->is_lazy) {
    released = ext->ptr.lazymessage_value->ReleaseMessage(prototype, arena);
    // Without an arena, this set owns the lazy wrapper and must free it.
    if (arena == nullptr) delete ext->ptr.lazymessage_value;
  } else if (arena == nullptr) {
    released = ext->ptr.message_value;
  } else {
    // ReleaseMessage() always returns a heap-allocated message, and we are
    // on an arena, so we need to make a copy of this message to return.
    released = ext->ptr.message_value->New();
    released->CheckTypeAndMergeFrom(*ext->ptr.message_value);
  }
  Erase(number);
  return released;
}
792 
// Removes the singular message extension `number` and returns the stored
// pointer as-is (no copy, regardless of arena), or nullptr if absent.
MessageLite* ExtensionSet::UnsafeArenaReleaseMessage(
    int number, const MessageLite& prototype) {
  Extension* const ext = FindOrNull(number);
  if (ext == nullptr) return nullptr;  // Not present.
  ABSL_DCHECK_TYPE(*ext, OPTIONAL_FIELD, MESSAGE);

  MessageLite* released;
  if (!ext->is_lazy) {
    released = ext->ptr.message_value;
  } else {
    Arena* const arena = arena_;
    released = ext->ptr.lazymessage_value->UnsafeArenaReleaseMessage(
        prototype, arena);
    // Without an arena, this set owns the lazy wrapper and must free it.
    if (arena == nullptr) delete ext->ptr.lazymessage_value;
  }
  Erase(number);
  return released;
}
816 
817 // Defined in extension_set_heavy.cc.
818 // MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
819 //                                           MessageFactory* factory);
820 
GetRepeatedMessage(int number,int index) const821 const MessageLite& ExtensionSet::GetRepeatedMessage(int number,
822                                                     int index) const {
823   const Extension* extension = FindOrNull(number);
824   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
825   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
826   return extension->ptr.repeated_message_value->Get(index);
827 }
828 
MutableRepeatedMessage(int number,int index)829 MessageLite* ExtensionSet::MutableRepeatedMessage(int number, int index) {
830   Extension* extension = FindOrNull(number);
831   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
832   ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
833   return extension->ptr.repeated_message_value->Mutable(index);
834 }
835 
// Appends a new element to the repeated message extension `number` and
// returns it. On first use the entry is initialized as a repeated message
// field whose container is created on `arena_`.
MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
                                      const MessageLite& prototype,
                                      const FieldDescriptor* descriptor) {
  Extension* ext;
  if (!MaybeNewExtension(number, descriptor, &ext)) {
    ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, MESSAGE);
  } else {
    ext->type = type;
    ABSL_DCHECK_EQ(cpp_type(ext->type), WireFormatLite::CPPTYPE_MESSAGE);
    ext->is_repeated = true;
    ext->is_pointer = true;
    ext->ptr.repeated_message_value =
        Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
  }

  // Append through the type-erased base, passing `prototype` so the base can
  // create an element of the concrete message type.
  auto* const base = reinterpret_cast<internal::RepeatedPtrFieldBase*>(
      ext->ptr.repeated_message_value);
  return base->AddMessage(&prototype);
}
855 
856 // Defined in extension_set_heavy.cc.
857 // MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
858 //                                       const Descriptor* message_type,
859 //                                       MessageFactory* factory)
860 
861 #undef ABSL_DCHECK_TYPE
862 
RemoveLast(int number)863 void ExtensionSet::RemoveLast(int number) {
864   Extension* extension = FindOrNull(number);
865   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
866   ABSL_DCHECK(extension->is_repeated);
867 
868   switch (cpp_type(extension->type)) {
869     case WireFormatLite::CPPTYPE_INT32:
870       extension->ptr.repeated_int32_t_value->RemoveLast();
871       break;
872     case WireFormatLite::CPPTYPE_INT64:
873       extension->ptr.repeated_int64_t_value->RemoveLast();
874       break;
875     case WireFormatLite::CPPTYPE_UINT32:
876       extension->ptr.repeated_uint32_t_value->RemoveLast();
877       break;
878     case WireFormatLite::CPPTYPE_UINT64:
879       extension->ptr.repeated_uint64_t_value->RemoveLast();
880       break;
881     case WireFormatLite::CPPTYPE_FLOAT:
882       extension->ptr.repeated_float_value->RemoveLast();
883       break;
884     case WireFormatLite::CPPTYPE_DOUBLE:
885       extension->ptr.repeated_double_value->RemoveLast();
886       break;
887     case WireFormatLite::CPPTYPE_BOOL:
888       extension->ptr.repeated_bool_value->RemoveLast();
889       break;
890     case WireFormatLite::CPPTYPE_ENUM:
891       extension->ptr.repeated_enum_value->RemoveLast();
892       break;
893     case WireFormatLite::CPPTYPE_STRING:
894       extension->ptr.repeated_string_value->RemoveLast();
895       break;
896     case WireFormatLite::CPPTYPE_MESSAGE:
897       extension->ptr.repeated_message_value->RemoveLast();
898       break;
899   }
900 }
901 
ReleaseLast(int number)902 MessageLite* ExtensionSet::ReleaseLast(int number) {
903   Extension* extension = FindOrNull(number);
904   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
905   ABSL_DCHECK(extension->is_repeated);
906   ABSL_DCHECK(cpp_type(extension->type) == WireFormatLite::CPPTYPE_MESSAGE);
907   return extension->ptr.repeated_message_value->ReleaseLast();
908 }
909 
UnsafeArenaReleaseLast(int number)910 MessageLite* ExtensionSet::UnsafeArenaReleaseLast(int number) {
911   Extension* extension = FindOrNull(number);
912   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
913   ABSL_DCHECK(extension->is_repeated);
914   ABSL_DCHECK(cpp_type(extension->type) == WireFormatLite::CPPTYPE_MESSAGE);
915   return extension->ptr.repeated_message_value->UnsafeArenaReleaseLast();
916 }
917 
SwapElements(int number,int index1,int index2)918 void ExtensionSet::SwapElements(int number, int index1, int index2) {
919   Extension* extension = FindOrNull(number);
920   ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
921   ABSL_DCHECK(extension->is_repeated);
922 
923   switch (cpp_type(extension->type)) {
924     case WireFormatLite::CPPTYPE_INT32:
925       extension->ptr.repeated_int32_t_value->SwapElements(index1, index2);
926       break;
927     case WireFormatLite::CPPTYPE_INT64:
928       extension->ptr.repeated_int64_t_value->SwapElements(index1, index2);
929       break;
930     case WireFormatLite::CPPTYPE_UINT32:
931       extension->ptr.repeated_uint32_t_value->SwapElements(index1, index2);
932       break;
933     case WireFormatLite::CPPTYPE_UINT64:
934       extension->ptr.repeated_uint64_t_value->SwapElements(index1, index2);
935       break;
936     case WireFormatLite::CPPTYPE_FLOAT:
937       extension->ptr.repeated_float_value->SwapElements(index1, index2);
938       break;
939     case WireFormatLite::CPPTYPE_DOUBLE:
940       extension->ptr.repeated_double_value->SwapElements(index1, index2);
941       break;
942     case WireFormatLite::CPPTYPE_BOOL:
943       extension->ptr.repeated_bool_value->SwapElements(index1, index2);
944       break;
945     case WireFormatLite::CPPTYPE_ENUM:
946       extension->ptr.repeated_enum_value->SwapElements(index1, index2);
947       break;
948     case WireFormatLite::CPPTYPE_STRING:
949       extension->ptr.repeated_string_value->SwapElements(index1, index2);
950       break;
951     case WireFormatLite::CPPTYPE_MESSAGE:
952       extension->ptr.repeated_message_value->SwapElements(index1, index2);
953       break;
954   }
955 }
956 
957 // ===================================================================
958 
Clear()959 void ExtensionSet::Clear() {
960   ForEach([](int /* number */, Extension& ext) { ext.Clear(); }, Prefetch{});
961 }
962 
namespace {
// Counts how many entries a destination range would hold after merging a
// source range into it, without building the union. Both ranges are walked
// in key (`first`) order, so they must be sorted by key.
//
// Every destination entry is counted (it is already allocated and stays).
// Source-only entries count only when not `is_cleared`, since merging a
// cleared extension allocates nothing. Keys present in both count once.
template <typename ItX, typename ItY>
size_t SizeOfUnion(ItX it_dest, ItX end_dest, ItY it_source, ItY end_source) {
  size_t count = 0;
  while (it_dest != end_dest && it_source != end_source) {
    const auto dest_key = it_dest->first;
    const auto source_key = it_source->first;
    if (source_key < dest_key) {
      // Key appears only in the source.
      if (!it_source->second.is_cleared) ++count;
      ++it_source;
      continue;
    }
    // Key is in the destination (and possibly also the source): count once.
    ++count;
    if (source_key == dest_key) ++it_source;
    ++it_dest;
  }
  count += static_cast<size_t>(std::distance(it_dest, end_dest));
  for (; it_source != end_source; ++it_source) {
    if (!it_source->second.is_cleared) ++count;
  }
  return count;
}
}  // namespace
996 
MergeFrom(const MessageLite * extendee,const ExtensionSet & other)997 void ExtensionSet::MergeFrom(const MessageLite* extendee,
998                              const ExtensionSet& other) {
999   Prefetch5LinesFrom1Line(&other);
1000   if (PROTOBUF_PREDICT_TRUE(!is_large())) {
1001     if (PROTOBUF_PREDICT_TRUE(!other.is_large())) {
1002       GrowCapacity(SizeOfUnion(flat_begin(), flat_end(), other.flat_begin(),
1003                                other.flat_end()));
1004     } else {
1005       GrowCapacity(SizeOfUnion(flat_begin(), flat_end(),
1006                                other.map_.large->begin(),
1007                                other.map_.large->end()));
1008     }
1009   }
1010   other.ForEach(
1011       [extendee, this, &other](int number, const Extension& ext) {
1012         this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
1013       },
1014       Prefetch{});
1015 }
1016 
InternalExtensionMergeFrom(const MessageLite * extendee,int number,const Extension & other_extension,Arena * other_arena)1017 void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
1018                                               int number,
1019                                               const Extension& other_extension,
1020                                               Arena* other_arena) {
1021   if (other_extension.is_repeated) {
1022     Extension* extension;
1023     bool is_new =
1024         MaybeNewExtension(number, other_extension.descriptor, &extension);
1025     if (is_new) {
1026       // Extension did not already exist in set.
1027       extension->type = other_extension.type;
1028       extension->is_packed = other_extension.is_packed;
1029       extension->is_repeated = true;
1030       extension->is_pointer = true;
1031     } else {
1032       ABSL_DCHECK_EQ(extension->type, other_extension.type);
1033       ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
1034       ABSL_DCHECK(extension->is_repeated);
1035     }
1036 
1037     switch (cpp_type(other_extension.type)) {
1038 #define HANDLE_TYPE(UPPERCASE, LOWERCASE, REPEATED_TYPE)    \
1039   case WireFormatLite::CPPTYPE_##UPPERCASE:                 \
1040     if (is_new) {                                           \
1041       extension->ptr.repeated_##LOWERCASE##_value =         \
1042           Arena::Create<REPEATED_TYPE>(arena_);             \
1043     }                                                       \
1044     extension->ptr.repeated_##LOWERCASE##_value->MergeFrom( \
1045         *other_extension.ptr.repeated_##LOWERCASE##_value); \
1046     break;
1047 
1048       HANDLE_TYPE(INT32, int32_t, RepeatedField<int32_t>);
1049       HANDLE_TYPE(INT64, int64_t, RepeatedField<int64_t>);
1050       HANDLE_TYPE(UINT32, uint32_t, RepeatedField<uint32_t>);
1051       HANDLE_TYPE(UINT64, uint64_t, RepeatedField<uint64_t>);
1052       HANDLE_TYPE(FLOAT, float, RepeatedField<float>);
1053       HANDLE_TYPE(DOUBLE, double, RepeatedField<double>);
1054       HANDLE_TYPE(BOOL, bool, RepeatedField<bool>);
1055       HANDLE_TYPE(ENUM, enum, RepeatedField<int>);
1056       HANDLE_TYPE(STRING, string, RepeatedPtrField<std::string>);
1057       HANDLE_TYPE(MESSAGE, message, RepeatedPtrField<MessageLite>);
1058 #undef HANDLE_TYPE
1059     }
1060   } else {
1061     if (!other_extension.is_cleared) {
1062       switch (cpp_type(other_extension.type)) {
1063 #define HANDLE_TYPE(UPPERCASE, LOWERCASE, CAMELCASE)  \
1064   case WireFormatLite::CPPTYPE_##UPPERCASE:           \
1065     Set##CAMELCASE(number, other_extension.type,      \
1066                    other_extension.LOWERCASE##_value, \
1067                    other_extension.descriptor);       \
1068     break;
1069 
1070         HANDLE_TYPE(INT32, int32_t, Int32);
1071         HANDLE_TYPE(INT64, int64_t, Int64);
1072         HANDLE_TYPE(UINT32, uint32_t, UInt32);
1073         HANDLE_TYPE(UINT64, uint64_t, UInt64);
1074         HANDLE_TYPE(FLOAT, float, Float);
1075         HANDLE_TYPE(DOUBLE, double, Double);
1076         HANDLE_TYPE(BOOL, bool, Bool);
1077         HANDLE_TYPE(ENUM, enum, Enum);
1078 #undef HANDLE_TYPE
1079         case WireFormatLite::CPPTYPE_STRING:
1080           SetString(number, other_extension.type,
1081                     *other_extension.ptr.string_value,
1082                     other_extension.descriptor);
1083           break;
1084         case WireFormatLite::CPPTYPE_MESSAGE: {
1085           Arena* const arena = arena_;
1086           Extension* extension;
1087           bool is_new =
1088               MaybeNewExtension(number, other_extension.descriptor, &extension);
1089           if (is_new) {
1090             extension->type = other_extension.type;
1091             extension->is_packed = other_extension.is_packed;
1092             extension->is_repeated = false;
1093             extension->is_pointer = true;
1094             if (other_extension.is_lazy) {
1095               extension->is_lazy = true;
1096               extension->ptr.lazymessage_value =
1097                   other_extension.ptr.lazymessage_value->New(arena);
1098               extension->ptr.lazymessage_value->MergeFrom(
1099                   GetPrototypeForLazyMessage(extendee, number),
1100                   *other_extension.ptr.lazymessage_value, arena, other_arena);
1101             } else {
1102               extension->is_lazy = false;
1103               extension->ptr.message_value =
1104                   other_extension.ptr.message_value->New(arena);
1105               extension->ptr.message_value->CheckTypeAndMergeFrom(
1106                   *other_extension.ptr.message_value);
1107             }
1108           } else {
1109             ABSL_DCHECK_EQ(extension->type, other_extension.type);
1110             ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
1111             ABSL_DCHECK(!extension->is_repeated);
1112             if (other_extension.is_lazy) {
1113               if (extension->is_lazy) {
1114                 extension->ptr.lazymessage_value->MergeFrom(
1115                     GetPrototypeForLazyMessage(extendee, number),
1116                     *other_extension.ptr.lazymessage_value, arena, other_arena);
1117               } else {
1118                 extension->ptr.message_value->CheckTypeAndMergeFrom(
1119                     other_extension.ptr.lazymessage_value->GetMessage(
1120                         *extension->ptr.message_value, other_arena));
1121               }
1122             } else {
1123               if (extension->is_lazy) {
1124                 extension->ptr.lazymessage_value
1125                     ->MutableMessage(*other_extension.ptr.message_value, arena)
1126                     ->CheckTypeAndMergeFrom(*other_extension.ptr.message_value);
1127               } else {
1128                 extension->ptr.message_value->CheckTypeAndMergeFrom(
1129                     *other_extension.ptr.message_value);
1130               }
1131             }
1132           }
1133           extension->is_cleared = false;
1134           break;
1135         }
1136       }
1137     }
1138   }
1139 }
1140 
Swap(const MessageLite * extendee,ExtensionSet * other)1141 void ExtensionSet::Swap(const MessageLite* extendee, ExtensionSet* other) {
1142   if (internal::CanUseInternalSwap(arena_, other->arena_)) {
1143     InternalSwap(other);
1144   } else {
1145     // TODO: We maybe able to optimize a case where we are
1146     // swapping from heap to arena-allocated extension set, by just Own()'ing
1147     // the extensions.
1148     ExtensionSet extension_set;
1149     extension_set.MergeFrom(extendee, *other);
1150     other->Clear();
1151     other->MergeFrom(extendee, *this);
1152     Clear();
1153     MergeFrom(extendee, extension_set);
1154   }
1155 }
1156 
InternalSwap(ExtensionSet * other)1157 void ExtensionSet::InternalSwap(ExtensionSet* other) {
1158   using std::swap;
1159   swap(arena_, other->arena_);
1160   swap(flat_capacity_, other->flat_capacity_);
1161   swap(flat_size_, other->flat_size_);
1162   swap(map_, other->map_);
1163 }
1164 
SwapExtension(const MessageLite * extendee,ExtensionSet * other,int number)1165 void ExtensionSet::SwapExtension(const MessageLite* extendee,
1166                                  ExtensionSet* other, int number) {
1167   if (this == other) return;
1168 
1169   Arena* const arena = arena_;
1170   Arena* const other_arena = other->arena_;
1171   if (arena == other_arena) {
1172     UnsafeShallowSwapExtension(other, number);
1173     return;
1174   }
1175 
1176   Extension* this_ext = FindOrNull(number);
1177   Extension* other_ext = other->FindOrNull(number);
1178 
1179   if (this_ext == other_ext) return;
1180 
1181   if (this_ext != nullptr && other_ext != nullptr) {
1182     // TODO: We could further optimize these cases,
1183     // especially avoid creation of ExtensionSet, and move MergeFrom logic
1184     // into Extensions itself (which takes arena as an argument).
1185     // We do it this way to reuse the copy-across-arenas logic already
1186     // implemented in ExtensionSet's MergeFrom.
1187     ExtensionSet temp;
1188     temp.InternalExtensionMergeFrom(extendee, number, *other_ext, other_arena);
1189     Extension* temp_ext = temp.FindOrNull(number);
1190 
1191     other_ext->Clear();
1192     other->InternalExtensionMergeFrom(extendee, number, *this_ext, arena);
1193     this_ext->Clear();
1194     InternalExtensionMergeFrom(extendee, number, *temp_ext, temp.GetArena());
1195   } else if (this_ext == nullptr) {
1196     InternalExtensionMergeFrom(extendee, number, *other_ext, other_arena);
1197     if (other_arena == nullptr) other_ext->Free();
1198     other->Erase(number);
1199   } else {
1200     other->InternalExtensionMergeFrom(extendee, number, *this_ext, arena);
1201     if (arena == nullptr) this_ext->Free();
1202     Erase(number);
1203   }
1204 }
1205 
UnsafeShallowSwapExtension(ExtensionSet * other,int number)1206 void ExtensionSet::UnsafeShallowSwapExtension(ExtensionSet* other, int number) {
1207   if (this == other) return;
1208 
1209   Extension* this_ext = FindOrNull(number);
1210   Extension* other_ext = other->FindOrNull(number);
1211 
1212   if (this_ext == other_ext) return;
1213 
1214   ABSL_DCHECK_EQ(arena_, other->arena_);
1215 
1216   if (this_ext != nullptr && other_ext != nullptr) {
1217     std::swap(*this_ext, *other_ext);
1218   } else if (this_ext == nullptr) {
1219     *Insert(number).first = *other_ext;
1220     other->Erase(number);
1221   } else {
1222     *other->Insert(number).first = *this_ext;
1223     Erase(number);
1224   }
1225 }
1226 
IsInitialized(const MessageLite * extendee) const1227 bool ExtensionSet::IsInitialized(const MessageLite* extendee) const {
1228   // Extensions are never required.  However, we need to check that all
1229   // embedded messages are initialized.
1230   Arena* const arena = arena_;
1231   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1232     for (const auto& kv : *map_.large) {
1233       if (!kv.second.IsInitialized(this, extendee, kv.first, arena)) {
1234         return false;
1235       }
1236     }
1237     return true;
1238   }
1239   for (const KeyValue* it = flat_begin(); it != flat_end(); ++it) {
1240     if (!it->second.IsInitialized(this, extendee, it->first, arena)) {
1241       return false;
1242     }
1243   }
1244   return true;
1245 }
1246 
ParseField(uint64_t tag,const char * ptr,const MessageLite * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)1247 const char* ExtensionSet::ParseField(uint64_t tag, const char* ptr,
1248                                      const MessageLite* extendee,
1249                                      internal::InternalMetadata* metadata,
1250                                      internal::ParseContext* ctx) {
1251   GeneratedExtensionFinder finder(extendee);
1252   int number = tag >> 3;
1253   bool was_packed_on_wire;
1254   ExtensionInfo extension;
1255   if (!FindExtensionInfoFromFieldNumber(tag & 7, number, &finder, &extension,
1256                                         &was_packed_on_wire)) {
1257     return UnknownFieldParse(
1258         tag, metadata->mutable_unknown_fields<std::string>(), ptr, ctx);
1259   }
1260   return ParseFieldWithExtensionInfo<std::string>(
1261       number, was_packed_on_wire, extension, metadata, ptr, ctx);
1262 }
1263 
ParseMessageSetItem(const char * ptr,const MessageLite * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)1264 const char* ExtensionSet::ParseMessageSetItem(
1265     const char* ptr, const MessageLite* extendee,
1266     internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
1267   return ParseMessageSetItemTmpl<MessageLite, std::string>(ptr, extendee,
1268                                                            metadata, ctx);
1269 }
1270 
FieldTypeIsPointer(FieldType type)1271 bool ExtensionSet::FieldTypeIsPointer(FieldType type) {
1272   return type == WireFormatLite::TYPE_STRING ||
1273          type == WireFormatLite::TYPE_BYTES ||
1274          type == WireFormatLite::TYPE_GROUP ||
1275          type == WireFormatLite::TYPE_MESSAGE;
1276 }
1277 
_InternalSerializeImpl(const MessageLite * extendee,int start_field_number,int end_field_number,uint8_t * target,io::EpsCopyOutputStream * stream) const1278 uint8_t* ExtensionSet::_InternalSerializeImpl(
1279     const MessageLite* extendee, int start_field_number, int end_field_number,
1280     uint8_t* target, io::EpsCopyOutputStream* stream) const {
1281   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1282     const auto& end = map_.large->end();
1283     for (auto it = map_.large->lower_bound(start_field_number);
1284          it != end && it->first < end_field_number; ++it) {
1285       target = it->second.InternalSerializeFieldWithCachedSizesToArray(
1286           extendee, this, it->first, target, stream);
1287     }
1288     return target;
1289   }
1290   const KeyValue* end = flat_end();
1291   const KeyValue* it = flat_begin();
1292   while (it != end && it->first < start_field_number) ++it;
1293   for (; it != end && it->first < end_field_number; ++it) {
1294     target = it->second.InternalSerializeFieldWithCachedSizesToArray(
1295         extendee, this, it->first, target, stream);
1296   }
1297   return target;
1298 }
1299 
InternalSerializeMessageSetWithCachedSizesToArray(const MessageLite * extendee,uint8_t * target,io::EpsCopyOutputStream * stream) const1300 uint8_t* ExtensionSet::InternalSerializeMessageSetWithCachedSizesToArray(
1301     const MessageLite* extendee, uint8_t* target,
1302     io::EpsCopyOutputStream* stream) const {
1303   const ExtensionSet* extension_set = this;
1304   ForEach(
1305       [&target, extendee, stream, extension_set](int number,
1306                                                  const Extension& ext) {
1307         target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
1308             extendee, extension_set, number, target, stream);
1309       },
1310       Prefetch{});
1311   return target;
1312 }
1313 
ByteSize() const1314 size_t ExtensionSet::ByteSize() const {
1315   size_t total_size = 0;
1316   ForEach(
1317       [&total_size](int number, const Extension& ext) {
1318         total_size += ext.ByteSize(number);
1319       },
1320       Prefetch{});
1321   return total_size;
1322 }
1323 
1324 // Defined in extension_set_heavy.cc.
1325 // int ExtensionSet::SpaceUsedExcludingSelf() const
1326 
MaybeNewExtension(int number,const FieldDescriptor * descriptor,Extension ** result)1327 bool ExtensionSet::MaybeNewExtension(int number,
1328                                      const FieldDescriptor* descriptor,
1329                                      Extension** result) {
1330   bool extension_is_new = false;
1331   std::tie(*result, extension_is_new) = Insert(number);
1332   (*result)->descriptor = descriptor;
1333   return extension_is_new;
1334 }
1335 
1336 // ===================================================================
1337 // Methods of ExtensionSet::Extension
1338 
Clear()1339 void ExtensionSet::Extension::Clear() {
1340   if (is_repeated) {
1341     switch (cpp_type(type)) {
1342 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)      \
1343   case WireFormatLite::CPPTYPE_##UPPERCASE:    \
1344     ptr.repeated_##LOWERCASE##_value->Clear(); \
1345     break
1346 
1347       HANDLE_TYPE(INT32, int32_t);
1348       HANDLE_TYPE(INT64, int64_t);
1349       HANDLE_TYPE(UINT32, uint32_t);
1350       HANDLE_TYPE(UINT64, uint64_t);
1351       HANDLE_TYPE(FLOAT, float);
1352       HANDLE_TYPE(DOUBLE, double);
1353       HANDLE_TYPE(BOOL, bool);
1354       HANDLE_TYPE(ENUM, enum);
1355       HANDLE_TYPE(STRING, string);
1356       HANDLE_TYPE(MESSAGE, message);
1357 #undef HANDLE_TYPE
1358     }
1359   } else {
1360     if (!is_cleared) {
1361       switch (cpp_type(type)) {
1362         case WireFormatLite::CPPTYPE_STRING:
1363           ptr.string_value->clear();
1364           break;
1365         case WireFormatLite::CPPTYPE_MESSAGE:
1366           if (is_lazy) {
1367             ptr.lazymessage_value->Clear();
1368           } else {
1369             ptr.message_value->Clear();
1370           }
1371           break;
1372         default:
1373           // No need to do anything.  Get*() will return the default value
1374           // as long as is_cleared is true and Set*() will overwrite the
1375           // previous value.
1376           break;
1377       }
1378 
1379       is_cleared = true;
1380     }
1381   }
1382 }
1383 
ByteSize(int number) const1384 size_t ExtensionSet::Extension::ByteSize(int number) const {
1385   size_t result = 0;
1386 
1387   if (is_repeated) {
1388     if (is_packed) {
1389       switch (real_type(type)) {
1390 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                     \
1391   case WireFormatLite::TYPE_##UPPERCASE:                                 \
1392     for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1393       result += WireFormatLite::CAMELCASE##Size(                         \
1394           ptr.repeated_##LOWERCASE##_value->Get(i));                     \
1395     }                                                                    \
1396     break
1397 
1398         HANDLE_TYPE(INT32, Int32, int32_t);
1399         HANDLE_TYPE(INT64, Int64, int64_t);
1400         HANDLE_TYPE(UINT32, UInt32, uint32_t);
1401         HANDLE_TYPE(UINT64, UInt64, uint64_t);
1402         HANDLE_TYPE(SINT32, SInt32, int32_t);
1403         HANDLE_TYPE(SINT64, SInt64, int64_t);
1404         HANDLE_TYPE(ENUM, Enum, enum);
1405 #undef HANDLE_TYPE
1406 
1407         // Stuff with fixed size.
1408 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
1409   case WireFormatLite::TYPE_##UPPERCASE:                             \
1410     result += WireFormatLite::k##CAMELCASE##Size *                   \
1411               FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
1412     break
1413         HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1414         HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1415         HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1416         HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1417         HANDLE_TYPE(FLOAT, Float, float);
1418         HANDLE_TYPE(DOUBLE, Double, double);
1419         HANDLE_TYPE(BOOL, Bool, bool);
1420 #undef HANDLE_TYPE
1421 
1422         case WireFormatLite::TYPE_STRING:
1423         case WireFormatLite::TYPE_BYTES:
1424         case WireFormatLite::TYPE_GROUP:
1425         case WireFormatLite::TYPE_MESSAGE:
1426           ABSL_LOG(FATAL) << "Non-primitive types can't be packed.";
1427           break;
1428       }
1429 
1430       cached_size.set(ToCachedSize(result));
1431       if (result > 0) {
1432         result += io::CodedOutputStream::VarintSize32(result);
1433         result += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
1434             number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED));
1435       }
1436     } else {
1437       size_t tag_size = WireFormatLite::TagSize(number, real_type(type));
1438 
1439       switch (real_type(type)) {
1440 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                      \
1441   case WireFormatLite::TYPE_##UPPERCASE:                                  \
1442     result +=                                                             \
1443         tag_size * FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
1444     for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) {  \
1445       result += WireFormatLite::CAMELCASE##Size(                          \
1446           ptr.repeated_##LOWERCASE##_value->Get(i));                      \
1447     }                                                                     \
1448     break
1449 
1450         HANDLE_TYPE(INT32, Int32, int32_t);
1451         HANDLE_TYPE(INT64, Int64, int64_t);
1452         HANDLE_TYPE(UINT32, UInt32, uint32_t);
1453         HANDLE_TYPE(UINT64, UInt64, uint64_t);
1454         HANDLE_TYPE(SINT32, SInt32, int32_t);
1455         HANDLE_TYPE(SINT64, SInt64, int64_t);
1456         HANDLE_TYPE(STRING, String, string);
1457         HANDLE_TYPE(BYTES, Bytes, string);
1458         HANDLE_TYPE(ENUM, Enum, enum);
1459         HANDLE_TYPE(GROUP, Group, message);
1460         HANDLE_TYPE(MESSAGE, Message, message);
1461 #undef HANDLE_TYPE
1462 
1463         // Stuff with fixed size.
1464 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
1465   case WireFormatLite::TYPE_##UPPERCASE:                             \
1466     result += (tag_size + WireFormatLite::k##CAMELCASE##Size) *      \
1467               FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
1468     break
1469         HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1470         HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1471         HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1472         HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1473         HANDLE_TYPE(FLOAT, Float, float);
1474         HANDLE_TYPE(DOUBLE, Double, double);
1475         HANDLE_TYPE(BOOL, Bool, bool);
1476 #undef HANDLE_TYPE
1477       }
1478     }
1479   } else if (!is_cleared) {
1480     result += WireFormatLite::TagSize(number, real_type(type));
1481     switch (real_type(type)) {
1482 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)      \
1483   case WireFormatLite::TYPE_##UPPERCASE:                  \
1484     result += WireFormatLite::CAMELCASE##Size(LOWERCASE); \
1485     break
1486 
1487       HANDLE_TYPE(INT32, Int32, int32_t_value);
1488       HANDLE_TYPE(INT64, Int64, int64_t_value);
1489       HANDLE_TYPE(UINT32, UInt32, uint32_t_value);
1490       HANDLE_TYPE(UINT64, UInt64, uint64_t_value);
1491       HANDLE_TYPE(SINT32, SInt32, int32_t_value);
1492       HANDLE_TYPE(SINT64, SInt64, int64_t_value);
1493       HANDLE_TYPE(STRING, String, *ptr.string_value);
1494       HANDLE_TYPE(BYTES, Bytes, *ptr.string_value);
1495       HANDLE_TYPE(ENUM, Enum, enum_value);
1496       HANDLE_TYPE(GROUP, Group, *ptr.message_value);
1497 #undef HANDLE_TYPE
1498       case WireFormatLite::TYPE_MESSAGE: {
1499         result += WireFormatLite::LengthDelimitedSize(
1500             is_lazy ? ptr.lazymessage_value->ByteSizeLong()
1501                     : ptr.message_value->ByteSizeLong());
1502         break;
1503       }
1504 
1505       // Stuff with fixed size.
1506 #define HANDLE_TYPE(UPPERCASE, CAMELCASE)         \
1507   case WireFormatLite::TYPE_##UPPERCASE:          \
1508     result += WireFormatLite::k##CAMELCASE##Size; \
1509     break
1510         HANDLE_TYPE(FIXED32, Fixed32);
1511         HANDLE_TYPE(FIXED64, Fixed64);
1512         HANDLE_TYPE(SFIXED32, SFixed32);
1513         HANDLE_TYPE(SFIXED64, SFixed64);
1514         HANDLE_TYPE(FLOAT, Float);
1515         HANDLE_TYPE(DOUBLE, Double);
1516         HANDLE_TYPE(BOOL, Bool);
1517 #undef HANDLE_TYPE
1518     }
1519   }
1520 
1521   return result;
1522 }
1523 
GetSize() const1524 int ExtensionSet::Extension::GetSize() const {
1525   ABSL_DCHECK(is_repeated);
1526   switch (cpp_type(type)) {
1527 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)   \
1528   case WireFormatLite::CPPTYPE_##UPPERCASE: \
1529     return ptr.repeated_##LOWERCASE##_value->size()
1530 
1531     HANDLE_TYPE(INT32, int32_t);
1532     HANDLE_TYPE(INT64, int64_t);
1533     HANDLE_TYPE(UINT32, uint32_t);
1534     HANDLE_TYPE(UINT64, uint64_t);
1535     HANDLE_TYPE(FLOAT, float);
1536     HANDLE_TYPE(DOUBLE, double);
1537     HANDLE_TYPE(BOOL, bool);
1538     HANDLE_TYPE(ENUM, enum);
1539     HANDLE_TYPE(STRING, string);
1540     HANDLE_TYPE(MESSAGE, message);
1541 #undef HANDLE_TYPE
1542   }
1543 
1544   ABSL_LOG(FATAL) << "Can't get here.";
1545   return 0;
1546 }
1547 
1548 // This function deletes all allocated objects. This function should be only
1549 // called if the Extension was created without an arena.
Free()1550 void ExtensionSet::Extension::Free() {
1551   if (is_repeated) {
1552     switch (cpp_type(type)) {
1553 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)    \
1554   case WireFormatLite::CPPTYPE_##UPPERCASE:  \
1555     delete ptr.repeated_##LOWERCASE##_value; \
1556     break
1557 
1558       HANDLE_TYPE(INT32, int32_t);
1559       HANDLE_TYPE(INT64, int64_t);
1560       HANDLE_TYPE(UINT32, uint32_t);
1561       HANDLE_TYPE(UINT64, uint64_t);
1562       HANDLE_TYPE(FLOAT, float);
1563       HANDLE_TYPE(DOUBLE, double);
1564       HANDLE_TYPE(BOOL, bool);
1565       HANDLE_TYPE(ENUM, enum);
1566       HANDLE_TYPE(STRING, string);
1567       HANDLE_TYPE(MESSAGE, message);
1568 #undef HANDLE_TYPE
1569     }
1570   } else {
1571     switch (cpp_type(type)) {
1572       case WireFormatLite::CPPTYPE_STRING:
1573         delete ptr.string_value;
1574         break;
1575       case WireFormatLite::CPPTYPE_MESSAGE:
1576         if (is_lazy) {
1577           delete ptr.lazymessage_value;
1578         } else {
1579           delete ptr.message_value;
1580         }
1581         break;
1582       default:
1583         break;
1584     }
1585   }
1586 }
1587 
1588 // Defined in extension_set_heavy.cc.
1589 // int ExtensionSet::Extension::SpaceUsedExcludingSelf() const
1590 
IsInitialized(const ExtensionSet * ext_set,const MessageLite * extendee,int number,Arena * arena) const1591 bool ExtensionSet::Extension::IsInitialized(const ExtensionSet* ext_set,
1592                                             const MessageLite* extendee,
1593                                             int number, Arena* arena) const {
1594   if (cpp_type(type) != WireFormatLite::CPPTYPE_MESSAGE) return true;
1595 
1596   if (is_repeated) {
1597     for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1598       if (!ptr.repeated_message_value->Get(i).IsInitialized()) {
1599         return false;
1600       }
1601     }
1602     return true;
1603   }
1604 
1605   if (is_cleared) return true;
1606 
1607   if (!is_lazy) return ptr.message_value->IsInitialized();
1608 
1609   const MessageLite* prototype =
1610       ext_set->GetPrototypeForLazyMessage(extendee, number);
1611   ABSL_DCHECK_NE(prototype, nullptr)
1612       << "extendee: " << extendee->GetTypeName() << "; number: " << number;
1613   return ptr.lazymessage_value->IsInitialized(prototype, arena);
1614 }
1615 
1616 // Dummy key method to avoid weak vtable.
UnusedKeyMethod()1617 void ExtensionSet::LazyMessageExtension::UnusedKeyMethod() {}
1618 
FindOrNull(int key) const1619 const ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) const {
1620   if (flat_size_ == 0) {
1621     return nullptr;
1622   } else if (PROTOBUF_PREDICT_TRUE(!is_large())) {
1623     for (auto it = flat_begin(), end = flat_end();
1624          it != end && it->first <= key; ++it) {
1625       if (it->first == key) return &it->second;
1626     }
1627     return nullptr;
1628   } else {
1629     return FindOrNullInLargeMap(key);
1630   }
1631 }
1632 
FindOrNullInLargeMap(int key) const1633 const ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(
1634     int key) const {
1635   assert(is_large());
1636   LargeMap::const_iterator it = map_.large->find(key);
1637   if (it != map_.large->end()) {
1638     return &it->second;
1639   }
1640   return nullptr;
1641 }
1642 
FindOrNull(int key)1643 ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) {
1644   const auto* const_this = this;
1645   return const_cast<ExtensionSet::Extension*>(const_this->FindOrNull(key));
1646 }
1647 
FindOrNullInLargeMap(int key)1648 ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(int key) {
1649   const auto* const_this = this;
1650   return const_cast<ExtensionSet::Extension*>(
1651       const_this->FindOrNullInLargeMap(key));
1652 }
1653 
Insert(int key)1654 std::pair<ExtensionSet::Extension*, bool> ExtensionSet::Insert(int key) {
1655   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1656     auto maybe = map_.large->insert({key, Extension()});
1657     return {&maybe.first->second, maybe.second};
1658   }
1659   KeyValue* end = flat_end();
1660   KeyValue* it = flat_begin();
1661   for (; it != end && it->first <= key; ++it) {
1662     if (it->first == key) return {&it->second, false};
1663   }
1664   if (flat_size_ < flat_capacity_) {
1665     std::copy_backward(it, end, end + 1);
1666     ++flat_size_;
1667     it->first = key;
1668     it->second = Extension();
1669     return {&it->second, true};
1670   }
1671   GrowCapacity(flat_size_ + 1);
1672   return Insert(key);
1673 }
1674 
namespace {
// True when `n` has at most one bit set. Zero counts as a power of two, which
// lets callers accept a still-empty (zero) capacity.
constexpr bool IsPowerOfTwo(size_t n) { return !(n & (n - 1)); }
}  // namespace
1678 
GrowCapacity(size_t minimum_new_capacity)1679 void ExtensionSet::GrowCapacity(size_t minimum_new_capacity) {
1680   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1681     return;  // LargeMap does not have a "reserve" method.
1682   }
1683   if (flat_capacity_ >= minimum_new_capacity) {
1684     return;
1685   }
1686 
1687   auto new_flat_capacity = flat_capacity_;
1688   do {
1689     new_flat_capacity = new_flat_capacity == 0 ? 1 : new_flat_capacity * 4;
1690   } while (new_flat_capacity < minimum_new_capacity);
1691 
1692   KeyValue* begin = flat_begin();
1693   KeyValue* end = flat_end();
1694   AllocatedData new_map;
1695   Arena* const arena = arena_;
1696   if (new_flat_capacity > kMaximumFlatCapacity) {
1697     new_map.large = Arena::Create<LargeMap>(arena);
1698     LargeMap::iterator hint = new_map.large->begin();
1699     for (const KeyValue* it = begin; it != end; ++it) {
1700       hint = new_map.large->insert(hint, {it->first, it->second});
1701     }
1702     flat_size_ = static_cast<uint16_t>(-1);
1703     ABSL_DCHECK(is_large());
1704   } else {
1705     new_map.flat = Arena::CreateArray<KeyValue>(arena, new_flat_capacity);
1706     std::copy(begin, end, new_map.flat);
1707   }
1708 
1709   // ReturnArrayMemory is more efficient with power-of-2 bytes, and
1710   // sizeof(KeyValue) is a power-of-2 on 64-bit platforms. flat_capacity_ is
1711   // always a power-of-2.
1712   ABSL_DCHECK(IsPowerOfTwo(sizeof(KeyValue)) || sizeof(void*) != 8)
1713       << sizeof(KeyValue) << " " << sizeof(void*);
1714   ABSL_DCHECK(IsPowerOfTwo(flat_capacity_));
1715   if (flat_capacity_ > 0) {
1716     if (arena == nullptr) {
1717       DeleteFlatMap(begin, flat_capacity_);
1718     } else {
1719       arena->ReturnArrayMemory(begin, sizeof(KeyValue) * flat_capacity_);
1720     }
1721   }
1722   flat_capacity_ = new_flat_capacity;
1723   map_ = new_map;
1724 }
1725 
// Out-of-line definition of the static constexpr member. Before C++17, an
// odr-used static constexpr data member needs a namespace-scope definition.
// From C++17 on, such members are implicitly inline. Among MSVC toolchains,
// only versions with 1900 <= _MSC_VER < 1912 receive this definition.
#if __cplusplus < 201703L && \
    !(defined(_MSC_VER) && (_MSC_VER < 1900 || _MSC_VER >= 1912))
// static
constexpr uint16_t ExtensionSet::kMaximumFlatCapacity;
#endif
1732 
Erase(int key)1733 void ExtensionSet::Erase(int key) {
1734   if (PROTOBUF_PREDICT_FALSE(is_large())) {
1735     map_.large->erase(key);
1736     return;
1737   }
1738   KeyValue* end = flat_end();
1739   for (KeyValue* it = flat_begin(); it != end && it->first <= key; ++it) {
1740     if (it->first == key) {
1741       std::copy(it + 1, end, it);
1742       --flat_size_;
1743       return;
1744     }
1745   }
1746 }
1747 
1748 // ==================================================================
1749 // Default repeated field instances for iterator-compatible accessors
1750 
default_instance()1751 const RepeatedPrimitiveDefaults* RepeatedPrimitiveDefaults::default_instance() {
1752   static auto instance = OnShutdownDelete(new RepeatedPrimitiveDefaults);
1753   return instance;
1754 }
1755 
1756 const RepeatedStringTypeTraits::RepeatedFieldType*
GetDefaultRepeatedField()1757 RepeatedStringTypeTraits::GetDefaultRepeatedField() {
1758   static auto instance = OnShutdownDelete(new RepeatedFieldType);
1759   return instance;
1760 }
1761 
InternalSerializeFieldWithCachedSizesToArray(const MessageLite * extendee,const ExtensionSet * extension_set,int number,uint8_t * target,io::EpsCopyOutputStream * stream) const1762 uint8_t* ExtensionSet::Extension::InternalSerializeFieldWithCachedSizesToArray(
1763     const MessageLite* extendee, const ExtensionSet* extension_set, int number,
1764     uint8_t* target, io::EpsCopyOutputStream* stream) const {
1765   if (is_repeated) {
1766     if (is_packed) {
1767       if (cached_size() == 0) return target;
1768 
1769       target = stream->EnsureSpace(target);
1770       target = WireFormatLite::WriteTagToArray(
1771           number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
1772       target = WireFormatLite::WriteInt32NoTagToArray(cached_size(), target);
1773 
1774       switch (real_type(type)) {
1775 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                     \
1776   case WireFormatLite::TYPE_##UPPERCASE:                                 \
1777     for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1778       target = stream->EnsureSpace(target);                              \
1779       target = WireFormatLite::Write##CAMELCASE##NoTagToArray(           \
1780           ptr.repeated_##LOWERCASE##_value->Get(i), target);             \
1781     }                                                                    \
1782     break
1783 
1784         HANDLE_TYPE(INT32, Int32, int32_t);
1785         HANDLE_TYPE(INT64, Int64, int64_t);
1786         HANDLE_TYPE(UINT32, UInt32, uint32_t);
1787         HANDLE_TYPE(UINT64, UInt64, uint64_t);
1788         HANDLE_TYPE(SINT32, SInt32, int32_t);
1789         HANDLE_TYPE(SINT64, SInt64, int64_t);
1790         HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1791         HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1792         HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1793         HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1794         HANDLE_TYPE(FLOAT, Float, float);
1795         HANDLE_TYPE(DOUBLE, Double, double);
1796         HANDLE_TYPE(BOOL, Bool, bool);
1797         HANDLE_TYPE(ENUM, Enum, enum);
1798 #undef HANDLE_TYPE
1799 
1800         case WireFormatLite::TYPE_STRING:
1801         case WireFormatLite::TYPE_BYTES:
1802         case WireFormatLite::TYPE_GROUP:
1803         case WireFormatLite::TYPE_MESSAGE:
1804           ABSL_LOG(FATAL) << "Non-primitive types can't be packed.";
1805           break;
1806       }
1807     } else {
1808       switch (real_type(type)) {
1809 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                     \
1810   case WireFormatLite::TYPE_##UPPERCASE:                                 \
1811     for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1812       target = stream->EnsureSpace(target);                              \
1813       target = WireFormatLite::Write##CAMELCASE##ToArray(                \
1814           number, ptr.repeated_##LOWERCASE##_value->Get(i), target);     \
1815     }                                                                    \
1816     break
1817 
1818         HANDLE_TYPE(INT32, Int32, int32_t);
1819         HANDLE_TYPE(INT64, Int64, int64_t);
1820         HANDLE_TYPE(UINT32, UInt32, uint32_t);
1821         HANDLE_TYPE(UINT64, UInt64, uint64_t);
1822         HANDLE_TYPE(SINT32, SInt32, int32_t);
1823         HANDLE_TYPE(SINT64, SInt64, int64_t);
1824         HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1825         HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1826         HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1827         HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1828         HANDLE_TYPE(FLOAT, Float, float);
1829         HANDLE_TYPE(DOUBLE, Double, double);
1830         HANDLE_TYPE(BOOL, Bool, bool);
1831         HANDLE_TYPE(ENUM, Enum, enum);
1832 #undef HANDLE_TYPE
1833 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                     \
1834   case WireFormatLite::TYPE_##UPPERCASE:                                 \
1835     for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1836       target = stream->EnsureSpace(target);                              \
1837       target = stream->WriteString(                                      \
1838           number, ptr.repeated_##LOWERCASE##_value->Get(i), target);     \
1839     }                                                                    \
1840     break
1841         HANDLE_TYPE(STRING, String, string);
1842         HANDLE_TYPE(BYTES, Bytes, string);
1843 #undef HANDLE_TYPE
1844         case WireFormatLite::TYPE_GROUP:
1845           for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1846             target = stream->EnsureSpace(target);
1847             target = WireFormatLite::InternalWriteGroup(
1848                 number, ptr.repeated_message_value->Get(i), target, stream);
1849           }
1850           break;
1851         case WireFormatLite::TYPE_MESSAGE:
1852           for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1853             auto& msg = ptr.repeated_message_value->Get(i);
1854             target = WireFormatLite::InternalWriteMessage(
1855                 number, msg, msg.GetCachedSize(), target, stream);
1856           }
1857           break;
1858       }
1859     }
1860   } else if (!is_cleared) {
1861     switch (real_type(type)) {
1862 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)                               \
1863   case WireFormatLite::TYPE_##UPPERCASE:                                       \
1864     target = stream->EnsureSpace(target);                                      \
1865     target = WireFormatLite::Write##CAMELCASE##ToArray(number, VALUE, target); \
1866     break
1867 
1868       HANDLE_TYPE(INT32, Int32, int32_t_value);
1869       HANDLE_TYPE(INT64, Int64, int64_t_value);
1870       HANDLE_TYPE(UINT32, UInt32, uint32_t_value);
1871       HANDLE_TYPE(UINT64, UInt64, uint64_t_value);
1872       HANDLE_TYPE(SINT32, SInt32, int32_t_value);
1873       HANDLE_TYPE(SINT64, SInt64, int64_t_value);
1874       HANDLE_TYPE(FIXED32, Fixed32, uint32_t_value);
1875       HANDLE_TYPE(FIXED64, Fixed64, uint64_t_value);
1876       HANDLE_TYPE(SFIXED32, SFixed32, int32_t_value);
1877       HANDLE_TYPE(SFIXED64, SFixed64, int64_t_value);
1878       HANDLE_TYPE(FLOAT, Float, float_value);
1879       HANDLE_TYPE(DOUBLE, Double, double_value);
1880       HANDLE_TYPE(BOOL, Bool, bool_value);
1881       HANDLE_TYPE(ENUM, Enum, enum_value);
1882 #undef HANDLE_TYPE
1883 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)         \
1884   case WireFormatLite::TYPE_##UPPERCASE:                 \
1885     target = stream->EnsureSpace(target);                \
1886     target = stream->WriteString(number, VALUE, target); \
1887     break
1888       HANDLE_TYPE(STRING, String, *ptr.string_value);
1889       HANDLE_TYPE(BYTES, Bytes, *ptr.string_value);
1890 #undef HANDLE_TYPE
1891       case WireFormatLite::TYPE_GROUP:
1892         target = stream->EnsureSpace(target);
1893         target = WireFormatLite::InternalWriteGroup(number, *ptr.message_value,
1894                                                     target, stream);
1895         break;
1896       case WireFormatLite::TYPE_MESSAGE:
1897         if (is_lazy) {
1898           const auto* prototype =
1899               extension_set->GetPrototypeForLazyMessage(extendee, number);
1900           target = ptr.lazymessage_value->WriteMessageToArray(prototype, number,
1901                                                               target, stream);
1902         } else {
1903           target = WireFormatLite::InternalWriteMessage(
1904               number, *ptr.message_value, ptr.message_value->GetCachedSize(),
1905               target, stream);
1906         }
1907         break;
1908     }
1909   }
1910   return target;
1911 }
1912 
// Returns the message prototype of extension `number` on `extendee`, as
// located through a GeneratedExtensionFinder. Returns nullptr if no extension
// info is found for that field number.
const MessageLite* ExtensionSet::GetPrototypeForLazyMessage(
    const MessageLite* extendee, int number) const {
  GeneratedExtensionFinder finder(extendee);
  ExtensionInfo info;
  bool was_packed_on_wire = false;  // Required out-parameter; unused here.
  const bool found = FindExtensionInfoFromFieldNumber(
      WireFormatLite::WireType::WIRETYPE_LENGTH_DELIMITED, number, &finder,
      &info, &was_packed_on_wire);
  return found ? info.message_info.prototype : nullptr;
}
1925 
1926 uint8_t*
InternalSerializeMessageSetItemWithCachedSizesToArray(const MessageLite * extendee,const ExtensionSet * extension_set,int number,uint8_t * target,io::EpsCopyOutputStream * stream) const1927 ExtensionSet::Extension::InternalSerializeMessageSetItemWithCachedSizesToArray(
1928     const MessageLite* extendee, const ExtensionSet* extension_set, int number,
1929     uint8_t* target, io::EpsCopyOutputStream* stream) const {
1930   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
1931     // Not a valid MessageSet extension, but serialize it the normal way.
1932     ABSL_LOG(WARNING) << "Invalid message set extension.";
1933     return InternalSerializeFieldWithCachedSizesToArray(extendee, extension_set,
1934                                                         number, target, stream);
1935   }
1936 
1937   if (is_cleared) return target;
1938 
1939   target = stream->EnsureSpace(target);
1940   // Start group.
1941   target = io::CodedOutputStream::WriteTagToArray(
1942       WireFormatLite::kMessageSetItemStartTag, target);
1943   // Write type ID.
1944   target = WireFormatLite::WriteUInt32ToArray(
1945       WireFormatLite::kMessageSetTypeIdNumber, number, target);
1946   // Write message.
1947   if (is_lazy) {
1948     const auto* prototype =
1949         extension_set->GetPrototypeForLazyMessage(extendee, number);
1950     target = ptr.lazymessage_value->WriteMessageToArray(
1951         prototype, WireFormatLite::kMessageSetMessageNumber, target, stream);
1952   } else {
1953     target = WireFormatLite::InternalWriteMessage(
1954         WireFormatLite::kMessageSetMessageNumber, *ptr.message_value,
1955         ptr.message_value->GetCachedSize(), target, stream);
1956   }
1957   // End group.
1958   target = stream->EnsureSpace(target);
1959   target = io::CodedOutputStream::WriteTagToArray(
1960       WireFormatLite::kMessageSetItemEndTag, target);
1961   return target;
1962 }
1963 
MessageSetItemByteSize(int number) const1964 size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
1965   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
1966     // Not a valid MessageSet extension, but compute the byte size for it the
1967     // normal way.
1968     return ByteSize(number);
1969   }
1970 
1971   if (is_cleared) return 0;
1972 
1973   size_t our_size = WireFormatLite::kMessageSetItemTagsSize;
1974 
1975   // type_id
1976   our_size += io::CodedOutputStream::VarintSize32(number);
1977 
1978   // message
1979   our_size += WireFormatLite::LengthDelimitedSize(
1980       is_lazy ? ptr.lazymessage_value->ByteSizeLong()
1981               : ptr.message_value->ByteSizeLong());
1982 
1983   return our_size;
1984 }
1985 
MessageSetByteSize() const1986 size_t ExtensionSet::MessageSetByteSize() const {
1987   size_t total_size = 0;
1988   ForEach(
1989       [&total_size](int number, const Extension& ext) {
1990         total_size += ext.MessageSetItemByteSize(number);
1991       },
1992       Prefetch{});
1993   return total_size;
1994 }
1995 
FindExtensionLazyEagerVerifyFn(const MessageLite * extendee,int number)1996 LazyEagerVerifyFnType FindExtensionLazyEagerVerifyFn(
1997     const MessageLite* extendee, int number) {
1998   const ExtensionInfo* registered = FindRegisteredExtension(extendee, number);
1999   if (registered != nullptr) {
2000     return registered->lazy_eager_verify_func;
2001   }
2002   return nullptr;
2003 }
2004 
2005 std::atomic<ExtensionSet::LazyMessageExtension* (*)(Arena* arena)>
2006     ExtensionSet::maybe_create_lazy_extension_;
2007 
2008 }  // namespace internal
2009 }  // namespace protobuf
2010 }  // namespace google
2011 
2012 #include "google/protobuf/port_undef.inc"
2013