1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: kenton@google.com (Kenton Varda)
9 // Based on original Protocol Buffers design by
10 // Sanjay Ghemawat, Jeff Dean, and others.
11
12 #include "google/protobuf/extension_set.h"
13
14 #include <algorithm>
15 #include <atomic>
16 #include <cstddef>
17 #include <cstdint>
18 #include <string>
19 #include <tuple>
20 #include <type_traits>
21 #include <utility>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/hash/hash.h"
25 #include "absl/log/absl_check.h"
26 #include "absl/log/absl_log.h"
27 #include "google/protobuf/arena.h"
28 #include "google/protobuf/extension_set_inl.h"
29 #include "google/protobuf/io/coded_stream.h"
30 #include "google/protobuf/message_lite.h"
31 #include "google/protobuf/metadata_lite.h"
32 #include "google/protobuf/parse_context.h"
33 #include "google/protobuf/port.h"
34 #include "google/protobuf/repeated_field.h"
35
36 // must be last.
37 #include "google/protobuf/port_def.inc"
38
39 namespace google {
40 namespace protobuf {
41 namespace internal {
42 namespace {
43
real_type(FieldType type)44 inline WireFormatLite::FieldType real_type(FieldType type) {
45 ABSL_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
46 return static_cast<WireFormatLite::FieldType>(type);
47 }
48
cpp_type(FieldType type)49 inline WireFormatLite::CppType cpp_type(FieldType type) {
50 return WireFormatLite::FieldTypeToCppType(real_type(type));
51 }
52
// Extension registry, keyed by (extendee prototype, field number).
54
55 struct ExtensionInfoKey {
56 const MessageLite* message;
57 int number;
58 };
59
60 struct ExtensionEq {
61 using is_transparent = void;
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq62 bool operator()(const ExtensionInfo& lhs, const ExtensionInfo& rhs) const {
63 return lhs.message == rhs.message && lhs.number == rhs.number;
64 }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq65 bool operator()(const ExtensionInfo& lhs, const ExtensionInfoKey& rhs) const {
66 return lhs.message == rhs.message && lhs.number == rhs.number;
67 }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionEq68 bool operator()(const ExtensionInfoKey& lhs, const ExtensionInfo& rhs) const {
69 return lhs.message == rhs.message && lhs.number == rhs.number;
70 }
71 };
72
73 struct ExtensionHasher {
74 using is_transparent = void;
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionHasher75 std::size_t operator()(const ExtensionInfo& info) const {
76 return absl::HashOf(info.message, info.number);
77 }
operator ()google::protobuf::internal::__anon936a36e30111::ExtensionHasher78 std::size_t operator()(const ExtensionInfoKey& info) const {
79 return absl::HashOf(info.message, info.number);
80 }
81 };
82
83 using ExtensionRegistry =
84 absl::flat_hash_set<ExtensionInfo, ExtensionHasher, ExtensionEq>;
85
86 static const ExtensionRegistry* global_registry = nullptr;
87
88 // This function is only called at startup, so there is no need for thread-
89 // safety.
Register(const ExtensionInfo & info)90 void Register(const ExtensionInfo& info) {
91 static auto local_static_registry = OnShutdownDelete(new ExtensionRegistry);
92 global_registry = local_static_registry;
93 if (!local_static_registry->insert(info).second) {
94 ABSL_LOG(FATAL) << "Multiple extension registrations for type \""
95 << info.message->GetTypeName() << "\", field number "
96 << info.number << ".";
97 }
98 }
99
FindRegisteredExtension(const MessageLite * extendee,int number)100 const ExtensionInfo* FindRegisteredExtension(const MessageLite* extendee,
101 int number) {
102 if (!global_registry) return nullptr;
103
104 ExtensionInfoKey info;
105 info.message = extendee;
106 info.number = number;
107
108 auto it = global_registry->find(info);
109 if (it == global_registry->end()) {
110 return nullptr;
111 } else {
112 return &*it;
113 }
114 }
115
116 } // namespace
117
Find(int number,ExtensionInfo * output)118 bool GeneratedExtensionFinder::Find(int number, ExtensionInfo* output) {
119 const ExtensionInfo* extension = FindRegisteredExtension(extendee_, number);
120 if (extension == nullptr) {
121 return false;
122 } else {
123 *output = *extension;
124 return true;
125 }
126 }
127
RegisterExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed)128 void ExtensionSet::RegisterExtension(const MessageLite* extendee, int number,
129 FieldType type, bool is_repeated,
130 bool is_packed) {
131 ABSL_CHECK_NE(type, WireFormatLite::TYPE_ENUM);
132 ABSL_CHECK_NE(type, WireFormatLite::TYPE_MESSAGE);
133 ABSL_CHECK_NE(type, WireFormatLite::TYPE_GROUP);
134 ExtensionInfo info(extendee, number, type, is_repeated, is_packed);
135 Register(info);
136 }
137
CallNoArgValidityFunc(const void * arg,int number)138 static bool CallNoArgValidityFunc(const void* arg, int number) {
139 // Note: Must use C-style cast here rather than reinterpret_cast because
140 // the C++ standard at one point did not allow casts between function and
141 // data pointers and some compilers enforce this for C++-style casts. No
142 // compiler enforces it for C-style casts since lots of C-style code has
143 // relied on these kinds of casts for a long time, despite being
144 // technically undefined. See:
145 // http://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#195
146 // Also note: Some compilers do not allow function pointers to be "const".
147 // Which makes sense, I suppose, because it's meaningless.
148 return ((EnumValidityFunc*)arg)(number);
149 }
150
RegisterEnumExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed,EnumValidityFunc * is_valid)151 void ExtensionSet::RegisterEnumExtension(const MessageLite* extendee,
152 int number, FieldType type,
153 bool is_repeated, bool is_packed,
154 EnumValidityFunc* is_valid) {
155 ABSL_CHECK_EQ(type, WireFormatLite::TYPE_ENUM);
156 ExtensionInfo info(extendee, number, type, is_repeated, is_packed);
157 info.enum_validity_check.func = CallNoArgValidityFunc;
158 // See comment in CallNoArgValidityFunc() about why we use a c-style cast.
159 info.enum_validity_check.arg = (void*)is_valid;
160 Register(info);
161 }
162
RegisterMessageExtension(const MessageLite * extendee,int number,FieldType type,bool is_repeated,bool is_packed,const MessageLite * prototype,LazyEagerVerifyFnType verify_func,LazyAnnotation is_lazy)163 void ExtensionSet::RegisterMessageExtension(const MessageLite* extendee,
164 int number, FieldType type,
165 bool is_repeated, bool is_packed,
166 const MessageLite* prototype,
167 LazyEagerVerifyFnType verify_func,
168 LazyAnnotation is_lazy) {
169 ABSL_CHECK(type == WireFormatLite::TYPE_MESSAGE ||
170 type == WireFormatLite::TYPE_GROUP);
171 ExtensionInfo info(extendee, number, type, is_repeated, is_packed,
172 verify_func, is_lazy);
173 info.message_info = {prototype,
174 #if defined(PROTOBUF_CONSTINIT_DEFAULT_INSTANCES)
175 prototype->GetTcParseTable()
176 #else
177 nullptr
178 #endif
179 };
180 Register(info);
181 }
182
183 // ===================================================================
184 // Constructors and basic methods.
185
~ExtensionSet()186 ExtensionSet::~ExtensionSet() {
187 // Deletes all allocated extensions.
188 if (arena_ == nullptr) {
189 ForEach([](int /* number */, Extension& ext) { ext.Free(); },
190 PrefetchNta{});
191 if (PROTOBUF_PREDICT_FALSE(is_large())) {
192 delete map_.large;
193 } else {
194 DeleteFlatMap(map_.flat, flat_capacity_);
195 }
196 }
197 }
198
DeleteFlatMap(const ExtensionSet::KeyValue * flat,uint16_t flat_capacity)199 void ExtensionSet::DeleteFlatMap(const ExtensionSet::KeyValue* flat,
200 uint16_t flat_capacity) {
201 // Arena::CreateArray already requires a trivially destructible type, but
202 // ensure this constraint is not violated in the future.
203 static_assert(std::is_trivially_destructible<KeyValue>::value,
204 "CreateArray requires a trivially destructible type");
205 // A const-cast is needed, but this is safe as we are about to deallocate the
206 // array.
207 internal::SizedArrayDelete(const_cast<KeyValue*>(flat),
208 sizeof(*flat) * flat_capacity);
209 }
210
211 // Defined in extension_set_heavy.cc.
212 // void ExtensionSet::AppendToList(const Descriptor* extendee,
213 // const DescriptorPool* pool,
214 // vector<const FieldDescriptor*>* output) const
215
Has(int number) const216 bool ExtensionSet::Has(int number) const {
217 const Extension* ext = FindOrNull(number);
218 if (ext == nullptr) return false;
219 ABSL_DCHECK(!ext->is_repeated);
220 return !ext->is_cleared;
221 }
222
HasLazy(int number) const223 bool ExtensionSet::HasLazy(int number) const {
224 return Has(number) && FindOrNull(number)->is_lazy;
225 }
226
NumExtensions() const227 int ExtensionSet::NumExtensions() const {
228 int result = 0;
229 ForEachNoPrefetch([&result](int /* number */, const Extension& ext) {
230 if (!ext.is_cleared) {
231 ++result;
232 }
233 });
234 return result;
235 }
236
ExtensionSize(int number) const237 int ExtensionSet::ExtensionSize(int number) const {
238 const Extension* ext = FindOrNull(number);
239 return ext == nullptr ? 0 : ext->GetSize();
240 }
241
ExtensionType(int number) const242 FieldType ExtensionSet::ExtensionType(int number) const {
243 const Extension* ext = FindOrNull(number);
244 if (ext == nullptr) {
245 ABSL_DLOG(FATAL)
246 << "Don't lookup extension types if they aren't present (1). ";
247 return 0;
248 }
249 if (ext->is_cleared) {
250 ABSL_DLOG(FATAL)
251 << "Don't lookup extension types if they aren't present (2). ";
252 }
253 return ext->type;
254 }
255
ClearExtension(int number)256 void ExtensionSet::ClearExtension(int number) {
257 Extension* ext = FindOrNull(number);
258 if (ext == nullptr) return;
259 ext->Clear();
260 }
261
262 // ===================================================================
263 // Field accessors
264
namespace {

// Labels compared by the ABSL_DCHECK_TYPE debug checks: whether an Extension is
// expected to be repeated or singular.
enum { REPEATED_FIELD = 0, OPTIONAL_FIELD = 1 };

}  // namespace
270
// Debug-checks that EXTENSION has the expected label (REPEATED_FIELD or
// OPTIONAL_FIELD) and the C++ type named by CPPTYPE (the suffix of
// WireFormatLite::CPPTYPE_*). The two checks are wrapped in
// do { } while (false) so the macro behaves as a single statement, including
// under an unbraced if/else.
#define ABSL_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                          \
  do {                                                                        \
    ABSL_DCHECK_EQ((EXTENSION).is_repeated ? REPEATED_FIELD : OPTIONAL_FIELD, \
                   LABEL);                                                    \
    ABSL_DCHECK_EQ(cpp_type((EXTENSION).type),                                \
                   WireFormatLite::CPPTYPE_##CPPTYPE);                        \
  } while (false)
275
276 // -------------------------------------------------------------------
277 // Primitives
278
// PRIMITIVE_ACCESSORS(UPPERCASE, LOWERCASE, CAMELCASE) defines the ExtensionSet
// accessors for one primitive C++ type:
//   UPPERCASE - suffix of WireFormatLite::CPPTYPE_* used by the debug checks.
//   LOWERCASE - the C++ value type. It also selects the Extension members
//               LOWERCASE##_value and ptr.repeated_##LOWERCASE##_value.
//   CAMELCASE - suffix appended to the generated method names.
// Generated methods:
//   Get / GetRef   - the singular value, or `default_value` when the extension
//                    is absent or cleared.
//   Set            - on first use, initializes a singular, non-pointer
//                    extension; then stores `value` and marks it present.
//   GetRepeated / GetRefRepeated / SetRepeated
//                  - indexed access into the RepeatedField; CHECK-fails when
//                    the extension does not exist.
//   Add            - on first use, allocates the RepeatedField with
//                    Arena::Create(arena_); then appends `value`.
#define PRIMITIVE_ACCESSORS(UPPERCASE, LOWERCASE, CAMELCASE)                   \
                                                                               \
  LOWERCASE ExtensionSet::Get##CAMELCASE(int number, LOWERCASE default_value)  \
      const {                                                                  \
    const Extension* extension = FindOrNull(number);                           \
    if (extension == nullptr || extension->is_cleared) {                       \
      return default_value;                                                    \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE);                 \
      return extension->LOWERCASE##_value;                                     \
    }                                                                          \
  }                                                                            \
                                                                               \
  const LOWERCASE& ExtensionSet::GetRef##CAMELCASE(                            \
      int number, const LOWERCASE& default_value) const {                      \
    const Extension* extension = FindOrNull(number);                           \
    if (extension == nullptr || extension->is_cleared) {                       \
      return default_value;                                                    \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE);                 \
      return extension->LOWERCASE##_value;                                     \
    }                                                                          \
  }                                                                            \
                                                                               \
  void ExtensionSet::Set##CAMELCASE(int number, FieldType type,                \
                                    LOWERCASE value,                           \
                                    const FieldDescriptor* descriptor) {       \
    Extension* extension;                                                      \
    if (MaybeNewExtension(number, descriptor, &extension)) {                   \
      extension->type = type;                                                  \
      ABSL_DCHECK_EQ(cpp_type(extension->type),                                \
                     WireFormatLite::CPPTYPE_##UPPERCASE);                     \
      extension->is_repeated = false;                                          \
      extension->is_pointer = false;                                           \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE);                 \
    }                                                                          \
    extension->is_cleared = false;                                             \
    extension->LOWERCASE##_value = value;                                      \
  }                                                                            \
                                                                               \
  LOWERCASE ExtensionSet::GetRepeated##CAMELCASE(int number, int index)        \
      const {                                                                  \
    const Extension* extension = FindOrNull(number);                           \
    ABSL_CHECK(extension != nullptr)                                           \
        << "Index out-of-bounds (field is empty).";                            \
    ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                   \
    return extension->ptr.repeated_##LOWERCASE##_value->Get(index);            \
  }                                                                            \
                                                                               \
  const LOWERCASE& ExtensionSet::GetRefRepeated##CAMELCASE(int number,         \
                                                           int index) const {  \
    const Extension* extension = FindOrNull(number);                           \
    ABSL_CHECK(extension != nullptr)                                           \
        << "Index out-of-bounds (field is empty).";                            \
    ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                   \
    return extension->ptr.repeated_##LOWERCASE##_value->Get(index);            \
  }                                                                            \
                                                                               \
  void ExtensionSet::SetRepeated##CAMELCASE(int number, int index,             \
                                            LOWERCASE value) {                 \
    Extension* extension = FindOrNull(number);                                 \
    ABSL_CHECK(extension != nullptr)                                           \
        << "Index out-of-bounds (field is empty).";                            \
    ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                   \
    extension->ptr.repeated_##LOWERCASE##_value->Set(index, value);            \
  }                                                                            \
                                                                               \
  void ExtensionSet::Add##CAMELCASE(int number, FieldType type, bool packed,   \
                                    LOWERCASE value,                           \
                                    const FieldDescriptor* descriptor) {       \
    Extension* extension;                                                      \
    if (MaybeNewExtension(number, descriptor, &extension)) {                   \
      extension->type = type;                                                  \
      ABSL_DCHECK_EQ(cpp_type(extension->type),                                \
                     WireFormatLite::CPPTYPE_##UPPERCASE);                     \
      extension->is_repeated = true;                                           \
      extension->is_pointer = true;                                            \
      extension->is_packed = packed;                                           \
      extension->ptr.repeated_##LOWERCASE##_value =                            \
          Arena::Create<RepeatedField<LOWERCASE>>(arena_);                     \
    } else {                                                                   \
      ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, UPPERCASE);                 \
      ABSL_DCHECK_EQ(extension->is_packed, packed);                            \
    }                                                                          \
    extension->ptr.repeated_##LOWERCASE##_value->Add(value);                   \
  }

// Instantiate the accessors for every primitive extension type.
PRIMITIVE_ACCESSORS(INT32, int32_t, Int32)
PRIMITIVE_ACCESSORS(INT64, int64_t, Int64)
PRIMITIVE_ACCESSORS(UINT32, uint32_t, UInt32)
PRIMITIVE_ACCESSORS(UINT64, uint64_t, UInt64)
PRIMITIVE_ACCESSORS(FLOAT, float, Float)
PRIMITIVE_ACCESSORS(DOUBLE, double, Double)
PRIMITIVE_ACCESSORS(BOOL, bool, Bool)

#undef PRIMITIVE_ACCESSORS
376
377 const void* ExtensionSet::GetRawRepeatedField(int number,
378 const void* default_value) const {
379 const Extension* extension = FindOrNull(number);
380 if (extension == nullptr) {
381 return default_value;
382 }
383 // We assume that all the RepeatedField<>* pointers have the same
384 // size and alignment within the anonymous union in Extension.
385 return extension->ptr.repeated_int32_t_value;
386 }
387
MutableRawRepeatedField(int number,FieldType field_type,bool packed,const FieldDescriptor * desc)388 void* ExtensionSet::MutableRawRepeatedField(int number, FieldType field_type,
389 bool packed,
390 const FieldDescriptor* desc) {
391 Extension* extension;
392
393 // We instantiate an empty Repeated{,Ptr}Field if one doesn't exist for this
394 // extension.
395 if (MaybeNewExtension(number, desc, &extension)) {
396 extension->is_repeated = true;
397 extension->is_pointer = true;
398 extension->type = field_type;
399 extension->is_packed = packed;
400
401 switch (WireFormatLite::FieldTypeToCppType(
402 static_cast<WireFormatLite::FieldType>(field_type))) {
403 case WireFormatLite::CPPTYPE_INT32:
404 extension->ptr.repeated_int32_t_value =
405 Arena::Create<RepeatedField<int32_t>>(arena_);
406 break;
407 case WireFormatLite::CPPTYPE_INT64:
408 extension->ptr.repeated_int64_t_value =
409 Arena::Create<RepeatedField<int64_t>>(arena_);
410 break;
411 case WireFormatLite::CPPTYPE_UINT32:
412 extension->ptr.repeated_uint32_t_value =
413 Arena::Create<RepeatedField<uint32_t>>(arena_);
414 break;
415 case WireFormatLite::CPPTYPE_UINT64:
416 extension->ptr.repeated_uint64_t_value =
417 Arena::Create<RepeatedField<uint64_t>>(arena_);
418 break;
419 case WireFormatLite::CPPTYPE_DOUBLE:
420 extension->ptr.repeated_double_value =
421 Arena::Create<RepeatedField<double>>(arena_);
422 break;
423 case WireFormatLite::CPPTYPE_FLOAT:
424 extension->ptr.repeated_float_value =
425 Arena::Create<RepeatedField<float>>(arena_);
426 break;
427 case WireFormatLite::CPPTYPE_BOOL:
428 extension->ptr.repeated_bool_value =
429 Arena::Create<RepeatedField<bool>>(arena_);
430 break;
431 case WireFormatLite::CPPTYPE_ENUM:
432 extension->ptr.repeated_enum_value =
433 Arena::Create<RepeatedField<int>>(arena_);
434 break;
435 case WireFormatLite::CPPTYPE_STRING:
436 extension->ptr.repeated_string_value =
437 Arena::Create<RepeatedPtrField<std::string>>(arena_);
438 break;
439 case WireFormatLite::CPPTYPE_MESSAGE:
440 extension->ptr.repeated_message_value =
441 Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
442 break;
443 }
444 }
445
446 // We assume that all the RepeatedField<>* pointers have the same
447 // size and alignment within the anonymous union in Extension.
448 return extension->ptr.repeated_int32_t_value;
449 }
450
// Compatible version using the old call signature. It does not create the
// extension when it doesn't already exist; instead, it ABSL_CHECK-fails.
MutableRawRepeatedField(int number)453 void* ExtensionSet::MutableRawRepeatedField(int number) {
454 Extension* extension = FindOrNull(number);
455 ABSL_CHECK(extension != nullptr) << "Extension not found.";
456 // We assume that all the RepeatedField<>* pointers have the same
457 // size and alignment within the anonymous union in Extension.
458 return extension->ptr.repeated_int32_t_value;
459 }
460
461 // -------------------------------------------------------------------
462 // Enums
463
GetEnum(int number,int default_value) const464 int ExtensionSet::GetEnum(int number, int default_value) const {
465 const Extension* extension = FindOrNull(number);
466 if (extension == nullptr || extension->is_cleared) {
467 // Not present. Return the default value.
468 return default_value;
469 } else {
470 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
471 return extension->enum_value;
472 }
473 }
474
GetRefEnum(int number,const int & default_value) const475 const int& ExtensionSet::GetRefEnum(int number,
476 const int& default_value) const {
477 const Extension* extension = FindOrNull(number);
478 if (extension == nullptr || extension->is_cleared) {
479 // Not present. Return the default value.
480 return default_value;
481 } else {
482 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
483 return extension->enum_value;
484 }
485 }
486
SetEnum(int number,FieldType type,int value,const FieldDescriptor * descriptor)487 void ExtensionSet::SetEnum(int number, FieldType type, int value,
488 const FieldDescriptor* descriptor) {
489 Extension* extension;
490 if (MaybeNewExtension(number, descriptor, &extension)) {
491 extension->type = type;
492 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
493 extension->is_repeated = false;
494 extension->is_pointer = false;
495 } else {
496 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
497 }
498 extension->is_cleared = false;
499 extension->enum_value = value;
500 }
501
GetRepeatedEnum(int number,int index) const502 int ExtensionSet::GetRepeatedEnum(int number, int index) const {
503 const Extension* extension = FindOrNull(number);
504 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
505 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
506 return extension->ptr.repeated_enum_value->Get(index);
507 }
508
GetRefRepeatedEnum(int number,int index) const509 const int& ExtensionSet::GetRefRepeatedEnum(int number, int index) const {
510 const Extension* extension = FindOrNull(number);
511 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
512 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
513 return extension->ptr.repeated_enum_value->Get(index);
514 }
515
GetMessageByteSizeLong(int number) const516 size_t ExtensionSet::GetMessageByteSizeLong(int number) const {
517 const Extension* extension = FindOrNull(number);
518 ABSL_CHECK(extension != nullptr) << "not present";
519 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
520 return extension->is_lazy ? extension->ptr.lazymessage_value->ByteSizeLong()
521 : extension->ptr.message_value->ByteSizeLong();
522 }
523
InternalSerializeMessage(int number,const MessageLite * prototype,uint8_t * target,io::EpsCopyOutputStream * stream) const524 uint8_t* ExtensionSet::InternalSerializeMessage(
525 int number, const MessageLite* prototype, uint8_t* target,
526 io::EpsCopyOutputStream* stream) const {
527 const Extension* extension = FindOrNull(number);
528 ABSL_CHECK(extension != nullptr) << "not present";
529 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
530
531 if (extension->is_lazy) {
532 return extension->ptr.lazymessage_value->WriteMessageToArray(
533 prototype, number, target, stream);
534 }
535
536 const auto* msg = extension->ptr.message_value;
537 return WireFormatLite::InternalWriteMessage(
538 number, *msg, msg->GetCachedSize(), target, stream);
539 }
540
SetRepeatedEnum(int number,int index,int value)541 void ExtensionSet::SetRepeatedEnum(int number, int index, int value) {
542 Extension* extension = FindOrNull(number);
543 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
544 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
545 extension->ptr.repeated_enum_value->Set(index, value);
546 }
547
AddEnum(int number,FieldType type,bool packed,int value,const FieldDescriptor * descriptor)548 void ExtensionSet::AddEnum(int number, FieldType type, bool packed, int value,
549 const FieldDescriptor* descriptor) {
550 Extension* extension;
551 if (MaybeNewExtension(number, descriptor, &extension)) {
552 extension->type = type;
553 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
554 extension->is_repeated = true;
555 extension->is_pointer = true;
556 extension->is_packed = packed;
557 extension->ptr.repeated_enum_value =
558 Arena::Create<RepeatedField<int>>(arena_);
559 } else {
560 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, ENUM);
561 ABSL_DCHECK_EQ(extension->is_packed, packed);
562 }
563 extension->ptr.repeated_enum_value->Add(value);
564 }
565
566 // -------------------------------------------------------------------
567 // Strings
568
GetString(int number,const std::string & default_value) const569 const std::string& ExtensionSet::GetString(
570 int number, const std::string& default_value) const {
571 const Extension* extension = FindOrNull(number);
572 if (extension == nullptr || extension->is_cleared) {
573 // Not present. Return the default value.
574 return default_value;
575 } else {
576 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
577 return *extension->ptr.string_value;
578 }
579 }
580
MutableString(int number,FieldType type,const FieldDescriptor * descriptor)581 std::string* ExtensionSet::MutableString(int number, FieldType type,
582 const FieldDescriptor* descriptor) {
583 Extension* extension;
584 if (MaybeNewExtension(number, descriptor, &extension)) {
585 extension->type = type;
586 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
587 extension->is_repeated = false;
588 extension->is_pointer = true;
589 extension->ptr.string_value = Arena::Create<std::string>(arena_);
590 } else {
591 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
592 }
593 extension->is_cleared = false;
594 return extension->ptr.string_value;
595 }
596
GetRepeatedString(int number,int index) const597 const std::string& ExtensionSet::GetRepeatedString(int number,
598 int index) const {
599 const Extension* extension = FindOrNull(number);
600 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
601 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
602 return extension->ptr.repeated_string_value->Get(index);
603 }
604
MutableRepeatedString(int number,int index)605 std::string* ExtensionSet::MutableRepeatedString(int number, int index) {
606 Extension* extension = FindOrNull(number);
607 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
608 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
609 return extension->ptr.repeated_string_value->Mutable(index);
610 }
611
AddString(int number,FieldType type,const FieldDescriptor * descriptor)612 std::string* ExtensionSet::AddString(int number, FieldType type,
613 const FieldDescriptor* descriptor) {
614 Extension* extension;
615 if (MaybeNewExtension(number, descriptor, &extension)) {
616 extension->type = type;
617 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
618 extension->is_repeated = true;
619 extension->is_pointer = true;
620 extension->is_packed = false;
621 extension->ptr.repeated_string_value =
622 Arena::Create<RepeatedPtrField<std::string>>(arena_);
623 } else {
624 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, STRING);
625 }
626 return extension->ptr.repeated_string_value->Add();
627 }
628
629 // -------------------------------------------------------------------
630 // Messages
631
// Returns the singular message extension, or `default_value` when it does not
// exist. Unlike the scalar getters, this does not consult is_cleared.
const MessageLite& ExtensionSet::GetMessage(
    int number, const MessageLite& default_value) const {
  const Extension* extension = FindOrNull(number);
  if (extension == nullptr) return default_value;

  ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  if (!extension->is_lazy) {
    return *extension->ptr.message_value;
  }
  return extension->ptr.lazymessage_value->GetMessage(default_value, arena_);
}
648
649 // Defined in extension_set_heavy.cc.
650 // const MessageLite& ExtensionSet::GetMessage(int number,
651 // const Descriptor* message_type,
652 // MessageFactory* factory) const
653
// Returns a mutable singular message extension. On first use it is created
// from `prototype` on arena_. The extension is always marked present.
MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
                                          const MessageLite& prototype,
                                          const FieldDescriptor* descriptor) {
  Extension* extension;
  if (!MaybeNewExtension(number, descriptor, &extension)) {
    ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
    extension->is_cleared = false;
    if (!extension->is_lazy) {
      return extension->ptr.message_value;
    }
    return extension->ptr.lazymessage_value->MutableMessage(prototype,
                                                            arena_);
  }

  extension->type = type;
  ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
  extension->is_repeated = false;
  extension->is_pointer = true;
  extension->is_lazy = false;
  extension->ptr.message_value = prototype.New(arena_);
  extension->is_cleared = false;
  return extension->ptr.message_value;
}
678
679 // Defined in extension_set_heavy.cc.
680 // MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
681 // const Descriptor* message_type,
682 // MessageFactory* factory)
683
SetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)684 void ExtensionSet::SetAllocatedMessage(int number, FieldType type,
685 const FieldDescriptor* descriptor,
686 MessageLite* message) {
687 if (message == nullptr) {
688 ClearExtension(number);
689 return;
690 }
691 Arena* const arena = arena_;
692 Arena* const message_arena = message->GetArena();
693 ABSL_DCHECK(message_arena == nullptr || message_arena == arena);
694
695 Extension* extension;
696 if (MaybeNewExtension(number, descriptor, &extension)) {
697 extension->type = type;
698 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
699 extension->is_repeated = false;
700 extension->is_pointer = true;
701 extension->is_lazy = false;
702 if (message_arena == arena) {
703 extension->ptr.message_value = message;
704 } else if (message_arena == nullptr) {
705 extension->ptr.message_value = message;
706 arena->Own(message); // not nullptr because not equal to message_arena
707 } else {
708 extension->ptr.message_value = message->New(arena);
709 extension->ptr.message_value->CheckTypeAndMergeFrom(*message);
710 }
711 } else {
712 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
713 if (extension->is_lazy) {
714 extension->ptr.lazymessage_value->SetAllocatedMessage(message, arena);
715 } else {
716 if (arena == nullptr) {
717 delete extension->ptr.message_value;
718 }
719 if (message_arena == arena) {
720 extension->ptr.message_value = message;
721 } else if (message_arena == nullptr) {
722 extension->ptr.message_value = message;
723 arena->Own(message); // not nullptr because not equal to message_arena
724 } else {
725 extension->ptr.message_value = message->New(arena);
726 extension->ptr.message_value->CheckTypeAndMergeFrom(*message);
727 }
728 }
729 }
730 extension->is_cleared = false;
731 }
732
UnsafeArenaSetAllocatedMessage(int number,FieldType type,const FieldDescriptor * descriptor,MessageLite * message)733 void ExtensionSet::UnsafeArenaSetAllocatedMessage(
734 int number, FieldType type, const FieldDescriptor* descriptor,
735 MessageLite* message) {
736 if (message == nullptr) {
737 ClearExtension(number);
738 return;
739 }
740 Extension* extension;
741 if (MaybeNewExtension(number, descriptor, &extension)) {
742 extension->type = type;
743 ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
744 extension->is_repeated = false;
745 extension->is_pointer = true;
746 extension->is_lazy = false;
747 extension->ptr.message_value = message;
748 } else {
749 ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
750 if (extension->is_lazy) {
751 extension->ptr.lazymessage_value->UnsafeArenaSetAllocatedMessage(message,
752 arena_);
753 } else {
754 if (arena_ == nullptr) {
755 delete extension->ptr.message_value;
756 }
757 extension->ptr.message_value = message;
758 }
759 }
760 extension->is_cleared = false;
761 }
762
// Removes singular message extension `number` and hands its message to the
// caller, or returns nullptr when it does not exist. For an eager message on
// an arena, a heap copy is returned.
MessageLite* ExtensionSet::ReleaseMessage(int number,
                                          const MessageLite& prototype) {
  Extension* extension = FindOrNull(number);
  if (extension == nullptr) return nullptr;

  ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  MessageLite* released = nullptr;
  if (extension->is_lazy) {
    Arena* const arena = arena_;
    released =
        extension->ptr.lazymessage_value->ReleaseMessage(prototype, arena);
    if (arena == nullptr) delete extension->ptr.lazymessage_value;
  } else if (arena_ != nullptr) {
    // ReleaseMessage() always returns a heap-allocated message, and this set
    // is on an arena, so the stored message has to be copied out.
    released = extension->ptr.message_value->New();
    released->CheckTypeAndMergeFrom(*extension->ptr.message_value);
  } else {
    released = extension->ptr.message_value;
  }
  Erase(number);
  return released;
}
792
// Like ReleaseMessage(), but returns the stored eager message pointer without
// copying, even when it lives on an arena. Returns nullptr when absent.
MessageLite* ExtensionSet::UnsafeArenaReleaseMessage(
    int number, const MessageLite& prototype) {
  Extension* extension = FindOrNull(number);
  if (extension == nullptr) return nullptr;

  ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
  MessageLite* released = extension->ptr.message_value;
  if (extension->is_lazy) {
    Arena* const arena = arena_;
    released = extension->ptr.lazymessage_value->UnsafeArenaReleaseMessage(
        prototype, arena);
    if (arena == nullptr) delete extension->ptr.lazymessage_value;
  }
  Erase(number);
  return released;
}
816
817 // Defined in extension_set_heavy.cc.
818 // MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
819 // MessageFactory* factory);
820
GetRepeatedMessage(int number,int index) const821 const MessageLite& ExtensionSet::GetRepeatedMessage(int number,
822 int index) const {
823 const Extension* extension = FindOrNull(number);
824 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
825 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
826 return extension->ptr.repeated_message_value->Get(index);
827 }
828
MutableRepeatedMessage(int number,int index)829 MessageLite* ExtensionSet::MutableRepeatedMessage(int number, int index) {
830 Extension* extension = FindOrNull(number);
831 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
832 ABSL_DCHECK_TYPE(*extension, REPEATED_FIELD, MESSAGE);
833 return extension->ptr.repeated_message_value->Mutable(index);
834 }
835
// Appends a new element to repeated message extension `number` and returns it.
// If the extension does not exist yet, it is created with type `type` and an
// empty RepeatedPtrField<MessageLite> allocated on arena_.
MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
                                      const MessageLite& prototype,
                                      const FieldDescriptor* descriptor) {
  Extension* ext;
  const bool created = MaybeNewExtension(number, descriptor, &ext);
  if (!created) {
    ABSL_DCHECK_TYPE(*ext, REPEATED_FIELD, MESSAGE);
  } else {
    ext->type = type;
    ABSL_DCHECK_EQ(cpp_type(ext->type), WireFormatLite::CPPTYPE_MESSAGE);
    ext->is_repeated = true;
    ext->is_pointer = true;
    ext->ptr.repeated_message_value =
        Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
  }

  // The element is added through RepeatedPtrFieldBase::AddMessage(), which is
  // handed the prototype.
  auto* const base = reinterpret_cast<internal::RepeatedPtrFieldBase*>(
      ext->ptr.repeated_message_value);
  return base->AddMessage(&prototype);
}
855
856 // Defined in extension_set_heavy.cc.
857 // MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
858 // const Descriptor* message_type,
859 // MessageFactory* factory)
860
// Undefine the local ABSL_DCHECK_TYPE helper macro; the accessors above were
// its last users, and the code below checks types with ABSL_DCHECK directly.
#undef ABSL_DCHECK_TYPE
862
RemoveLast(int number)863 void ExtensionSet::RemoveLast(int number) {
864 Extension* extension = FindOrNull(number);
865 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
866 ABSL_DCHECK(extension->is_repeated);
867
868 switch (cpp_type(extension->type)) {
869 case WireFormatLite::CPPTYPE_INT32:
870 extension->ptr.repeated_int32_t_value->RemoveLast();
871 break;
872 case WireFormatLite::CPPTYPE_INT64:
873 extension->ptr.repeated_int64_t_value->RemoveLast();
874 break;
875 case WireFormatLite::CPPTYPE_UINT32:
876 extension->ptr.repeated_uint32_t_value->RemoveLast();
877 break;
878 case WireFormatLite::CPPTYPE_UINT64:
879 extension->ptr.repeated_uint64_t_value->RemoveLast();
880 break;
881 case WireFormatLite::CPPTYPE_FLOAT:
882 extension->ptr.repeated_float_value->RemoveLast();
883 break;
884 case WireFormatLite::CPPTYPE_DOUBLE:
885 extension->ptr.repeated_double_value->RemoveLast();
886 break;
887 case WireFormatLite::CPPTYPE_BOOL:
888 extension->ptr.repeated_bool_value->RemoveLast();
889 break;
890 case WireFormatLite::CPPTYPE_ENUM:
891 extension->ptr.repeated_enum_value->RemoveLast();
892 break;
893 case WireFormatLite::CPPTYPE_STRING:
894 extension->ptr.repeated_string_value->RemoveLast();
895 break;
896 case WireFormatLite::CPPTYPE_MESSAGE:
897 extension->ptr.repeated_message_value->RemoveLast();
898 break;
899 }
900 }
901
ReleaseLast(int number)902 MessageLite* ExtensionSet::ReleaseLast(int number) {
903 Extension* extension = FindOrNull(number);
904 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
905 ABSL_DCHECK(extension->is_repeated);
906 ABSL_DCHECK(cpp_type(extension->type) == WireFormatLite::CPPTYPE_MESSAGE);
907 return extension->ptr.repeated_message_value->ReleaseLast();
908 }
909
UnsafeArenaReleaseLast(int number)910 MessageLite* ExtensionSet::UnsafeArenaReleaseLast(int number) {
911 Extension* extension = FindOrNull(number);
912 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
913 ABSL_DCHECK(extension->is_repeated);
914 ABSL_DCHECK(cpp_type(extension->type) == WireFormatLite::CPPTYPE_MESSAGE);
915 return extension->ptr.repeated_message_value->UnsafeArenaReleaseLast();
916 }
917
SwapElements(int number,int index1,int index2)918 void ExtensionSet::SwapElements(int number, int index1, int index2) {
919 Extension* extension = FindOrNull(number);
920 ABSL_CHECK(extension != nullptr) << "Index out-of-bounds (field is empty).";
921 ABSL_DCHECK(extension->is_repeated);
922
923 switch (cpp_type(extension->type)) {
924 case WireFormatLite::CPPTYPE_INT32:
925 extension->ptr.repeated_int32_t_value->SwapElements(index1, index2);
926 break;
927 case WireFormatLite::CPPTYPE_INT64:
928 extension->ptr.repeated_int64_t_value->SwapElements(index1, index2);
929 break;
930 case WireFormatLite::CPPTYPE_UINT32:
931 extension->ptr.repeated_uint32_t_value->SwapElements(index1, index2);
932 break;
933 case WireFormatLite::CPPTYPE_UINT64:
934 extension->ptr.repeated_uint64_t_value->SwapElements(index1, index2);
935 break;
936 case WireFormatLite::CPPTYPE_FLOAT:
937 extension->ptr.repeated_float_value->SwapElements(index1, index2);
938 break;
939 case WireFormatLite::CPPTYPE_DOUBLE:
940 extension->ptr.repeated_double_value->SwapElements(index1, index2);
941 break;
942 case WireFormatLite::CPPTYPE_BOOL:
943 extension->ptr.repeated_bool_value->SwapElements(index1, index2);
944 break;
945 case WireFormatLite::CPPTYPE_ENUM:
946 extension->ptr.repeated_enum_value->SwapElements(index1, index2);
947 break;
948 case WireFormatLite::CPPTYPE_STRING:
949 extension->ptr.repeated_string_value->SwapElements(index1, index2);
950 break;
951 case WireFormatLite::CPPTYPE_MESSAGE:
952 extension->ptr.repeated_message_value->SwapElements(index1, index2);
953 break;
954 }
955 }
956
957 // ===================================================================
958
Clear()959 void ExtensionSet::Clear() {
960 ForEach([](int /* number */, Extension& ext) { ext.Clear(); }, Prefetch{});
961 }
962
namespace {
// Returns how many entries the destination would hold after merging the
// source into it, without building the union. Both ranges must be sorted by
// key (`->first`).
//
// Every destination entry counts, including cleared ones, because those are
// already allocated and will not go away. A source entry whose key is absent
// from the destination counts only if it is not cleared, because a cleared
// source entry needs no space.
template <typename ItX, typename ItY>
size_t SizeOfUnion(ItX it_dest, ItX end_dest, ItY it_source, ItY end_source) {
  size_t count = 0;
  for (;;) {
    if (it_dest == end_dest) {
      // Only source entries remain; count the live ones.
      for (; it_source != end_source; ++it_source) {
        if (!it_source->second.is_cleared) ++count;
      }
      return count;
    }
    if (it_source == end_source) {
      // Only destination entries remain; all of them count.
      return count + std::distance(it_dest, end_dest);
    }
    const auto dest_key = it_dest->first;
    const auto source_key = it_source->first;
    if (source_key < dest_key) {
      // Key exists only in the source.
      if (!it_source->second.is_cleared) ++count;
      ++it_source;
      continue;
    }
    // Key exists in the destination (and possibly also in the source).
    ++count;
    ++it_dest;
    if (source_key == dest_key) ++it_source;
  }
}
}  // namespace
996
MergeFrom(const MessageLite * extendee,const ExtensionSet & other)997 void ExtensionSet::MergeFrom(const MessageLite* extendee,
998 const ExtensionSet& other) {
999 Prefetch5LinesFrom1Line(&other);
1000 if (PROTOBUF_PREDICT_TRUE(!is_large())) {
1001 if (PROTOBUF_PREDICT_TRUE(!other.is_large())) {
1002 GrowCapacity(SizeOfUnion(flat_begin(), flat_end(), other.flat_begin(),
1003 other.flat_end()));
1004 } else {
1005 GrowCapacity(SizeOfUnion(flat_begin(), flat_end(),
1006 other.map_.large->begin(),
1007 other.map_.large->end()));
1008 }
1009 }
1010 other.ForEach(
1011 [extendee, this, &other](int number, const Extension& ext) {
1012 this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
1013 },
1014 Prefetch{});
1015 }
1016
// Merges a single extension, `other_extension`, into this set under field
// `number`. `other_extension` belongs to a set whose arena is `other_arena`.
//
// - Repeated: if the extension is new here, a container is created on this
//   set's arena; the source elements are then appended with the container's
//   MergeFrom().
// - Singular, cleared in the source: nothing happens, and no entry is
//   inserted into this set.
// - Singular scalar/string: copied with the matching Set*() call.
// - Singular message: a new extension keeps the source's lazy/eager form;
//   an existing one has the source merged into its current form.
void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
                                              int number,
                                              const Extension& other_extension,
                                              Arena* other_arena) {
  if (other_extension.is_repeated) {
    Extension* extension;
    bool is_new =
        MaybeNewExtension(number, other_extension.descriptor, &extension);
    if (is_new) {
      // Extension did not already exist in set.
      extension->type = other_extension.type;
      extension->is_packed = other_extension.is_packed;
      extension->is_repeated = true;
      extension->is_pointer = true;
    } else {
      // An existing entry must agree with the source's declaration.
      ABSL_DCHECK_EQ(extension->type, other_extension.type);
      ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
      ABSL_DCHECK(extension->is_repeated);
    }

    // The container is allocated on this set's arena only when the entry is
    // new, and is then appended to in every case.
    switch (cpp_type(other_extension.type)) {
#define HANDLE_TYPE(UPPERCASE, LOWERCASE, REPEATED_TYPE)    \
  case WireFormatLite::CPPTYPE_##UPPERCASE:                 \
    if (is_new) {                                           \
      extension->ptr.repeated_##LOWERCASE##_value =         \
          Arena::Create<REPEATED_TYPE>(arena_);             \
    }                                                       \
    extension->ptr.repeated_##LOWERCASE##_value->MergeFrom( \
        *other_extension.ptr.repeated_##LOWERCASE##_value); \
    break;

      HANDLE_TYPE(INT32, int32_t, RepeatedField<int32_t>);
      HANDLE_TYPE(INT64, int64_t, RepeatedField<int64_t>);
      HANDLE_TYPE(UINT32, uint32_t, RepeatedField<uint32_t>);
      HANDLE_TYPE(UINT64, uint64_t, RepeatedField<uint64_t>);
      HANDLE_TYPE(FLOAT, float, RepeatedField<float>);
      HANDLE_TYPE(DOUBLE, double, RepeatedField<double>);
      HANDLE_TYPE(BOOL, bool, RepeatedField<bool>);
      HANDLE_TYPE(ENUM, enum, RepeatedField<int>);
      HANDLE_TYPE(STRING, string, RepeatedPtrField<std::string>);
      HANDLE_TYPE(MESSAGE, message, RepeatedPtrField<MessageLite>);
#undef HANDLE_TYPE
    }
  } else {
    // Cleared singular source extensions contribute nothing.
    if (!other_extension.is_cleared) {
      switch (cpp_type(other_extension.type)) {
#define HANDLE_TYPE(UPPERCASE, LOWERCASE, CAMELCASE)  \
  case WireFormatLite::CPPTYPE_##UPPERCASE:           \
    Set##CAMELCASE(number, other_extension.type,      \
                   other_extension.LOWERCASE##_value, \
                   other_extension.descriptor);       \
    break;

        HANDLE_TYPE(INT32, int32_t, Int32);
        HANDLE_TYPE(INT64, int64_t, Int64);
        HANDLE_TYPE(UINT32, uint32_t, UInt32);
        HANDLE_TYPE(UINT64, uint64_t, UInt64);
        HANDLE_TYPE(FLOAT, float, Float);
        HANDLE_TYPE(DOUBLE, double, Double);
        HANDLE_TYPE(BOOL, bool, Bool);
        HANDLE_TYPE(ENUM, enum, Enum);
#undef HANDLE_TYPE
        case WireFormatLite::CPPTYPE_STRING:
          SetString(number, other_extension.type,
                    *other_extension.ptr.string_value,
                    other_extension.descriptor);
          break;
        case WireFormatLite::CPPTYPE_MESSAGE: {
          Arena* const arena = arena_;
          Extension* extension;
          bool is_new =
              MaybeNewExtension(number, other_extension.descriptor, &extension);
          if (is_new) {
            extension->type = other_extension.type;
            extension->is_packed = other_extension.is_packed;
            extension->is_repeated = false;
            extension->is_pointer = true;
            // A new entry mirrors the source's representation (lazy vs
            // eager); the copy is allocated on this set's arena.
            if (other_extension.is_lazy) {
              extension->is_lazy = true;
              extension->ptr.lazymessage_value =
                  other_extension.ptr.lazymessage_value->New(arena);
              extension->ptr.lazymessage_value->MergeFrom(
                  GetPrototypeForLazyMessage(extendee, number),
                  *other_extension.ptr.lazymessage_value, arena, other_arena);
            } else {
              extension->is_lazy = false;
              extension->ptr.message_value =
                  other_extension.ptr.message_value->New(arena);
              extension->ptr.message_value->CheckTypeAndMergeFrom(
                  *other_extension.ptr.message_value);
            }
          } else {
            ABSL_DCHECK_EQ(extension->type, other_extension.type);
            ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
            ABSL_DCHECK(!extension->is_repeated);
            // Four combinations of (source lazy?, destination lazy?); the
            // destination keeps its current representation.
            if (other_extension.is_lazy) {
              if (extension->is_lazy) {
                extension->ptr.lazymessage_value->MergeFrom(
                    GetPrototypeForLazyMessage(extendee, number),
                    *other_extension.ptr.lazymessage_value, arena, other_arena);
              } else {
                // Obtain the source as a MessageLite via GetMessage(), which
                // is given the destination message (presumably as the
                // prototype).
                extension->ptr.message_value->CheckTypeAndMergeFrom(
                    other_extension.ptr.lazymessage_value->GetMessage(
                        *extension->ptr.message_value, other_arena));
              }
            } else {
              if (extension->is_lazy) {
                // Get a mutable message out of the destination lazy wrapper,
                // then merge the eager source into it.
                extension->ptr.lazymessage_value
                    ->MutableMessage(*other_extension.ptr.message_value, arena)
                    ->CheckTypeAndMergeFrom(*other_extension.ptr.message_value);
              } else {
                extension->ptr.message_value->CheckTypeAndMergeFrom(
                    *other_extension.ptr.message_value);
              }
            }
          }
          extension->is_cleared = false;
          break;
        }
      }
    }
  }
}
1140
Swap(const MessageLite * extendee,ExtensionSet * other)1141 void ExtensionSet::Swap(const MessageLite* extendee, ExtensionSet* other) {
1142 if (internal::CanUseInternalSwap(arena_, other->arena_)) {
1143 InternalSwap(other);
1144 } else {
1145 // TODO: We maybe able to optimize a case where we are
1146 // swapping from heap to arena-allocated extension set, by just Own()'ing
1147 // the extensions.
1148 ExtensionSet extension_set;
1149 extension_set.MergeFrom(extendee, *other);
1150 other->Clear();
1151 other->MergeFrom(extendee, *this);
1152 Clear();
1153 MergeFrom(extendee, extension_set);
1154 }
1155 }
1156
InternalSwap(ExtensionSet * other)1157 void ExtensionSet::InternalSwap(ExtensionSet* other) {
1158 using std::swap;
1159 swap(arena_, other->arena_);
1160 swap(flat_capacity_, other->flat_capacity_);
1161 swap(flat_size_, other->flat_size_);
1162 swap(map_, other->map_);
1163 }
1164
SwapExtension(const MessageLite * extendee,ExtensionSet * other,int number)1165 void ExtensionSet::SwapExtension(const MessageLite* extendee,
1166 ExtensionSet* other, int number) {
1167 if (this == other) return;
1168
1169 Arena* const arena = arena_;
1170 Arena* const other_arena = other->arena_;
1171 if (arena == other_arena) {
1172 UnsafeShallowSwapExtension(other, number);
1173 return;
1174 }
1175
1176 Extension* this_ext = FindOrNull(number);
1177 Extension* other_ext = other->FindOrNull(number);
1178
1179 if (this_ext == other_ext) return;
1180
1181 if (this_ext != nullptr && other_ext != nullptr) {
1182 // TODO: We could further optimize these cases,
1183 // especially avoid creation of ExtensionSet, and move MergeFrom logic
1184 // into Extensions itself (which takes arena as an argument).
1185 // We do it this way to reuse the copy-across-arenas logic already
1186 // implemented in ExtensionSet's MergeFrom.
1187 ExtensionSet temp;
1188 temp.InternalExtensionMergeFrom(extendee, number, *other_ext, other_arena);
1189 Extension* temp_ext = temp.FindOrNull(number);
1190
1191 other_ext->Clear();
1192 other->InternalExtensionMergeFrom(extendee, number, *this_ext, arena);
1193 this_ext->Clear();
1194 InternalExtensionMergeFrom(extendee, number, *temp_ext, temp.GetArena());
1195 } else if (this_ext == nullptr) {
1196 InternalExtensionMergeFrom(extendee, number, *other_ext, other_arena);
1197 if (other_arena == nullptr) other_ext->Free();
1198 other->Erase(number);
1199 } else {
1200 other->InternalExtensionMergeFrom(extendee, number, *this_ext, arena);
1201 if (arena == nullptr) this_ext->Free();
1202 Erase(number);
1203 }
1204 }
1205
UnsafeShallowSwapExtension(ExtensionSet * other,int number)1206 void ExtensionSet::UnsafeShallowSwapExtension(ExtensionSet* other, int number) {
1207 if (this == other) return;
1208
1209 Extension* this_ext = FindOrNull(number);
1210 Extension* other_ext = other->FindOrNull(number);
1211
1212 if (this_ext == other_ext) return;
1213
1214 ABSL_DCHECK_EQ(arena_, other->arena_);
1215
1216 if (this_ext != nullptr && other_ext != nullptr) {
1217 std::swap(*this_ext, *other_ext);
1218 } else if (this_ext == nullptr) {
1219 *Insert(number).first = *other_ext;
1220 other->Erase(number);
1221 } else {
1222 *other->Insert(number).first = *this_ext;
1223 Erase(number);
1224 }
1225 }
1226
IsInitialized(const MessageLite * extendee) const1227 bool ExtensionSet::IsInitialized(const MessageLite* extendee) const {
1228 // Extensions are never required. However, we need to check that all
1229 // embedded messages are initialized.
1230 Arena* const arena = arena_;
1231 if (PROTOBUF_PREDICT_FALSE(is_large())) {
1232 for (const auto& kv : *map_.large) {
1233 if (!kv.second.IsInitialized(this, extendee, kv.first, arena)) {
1234 return false;
1235 }
1236 }
1237 return true;
1238 }
1239 for (const KeyValue* it = flat_begin(); it != flat_end(); ++it) {
1240 if (!it->second.IsInitialized(this, extendee, it->first, arena)) {
1241 return false;
1242 }
1243 }
1244 return true;
1245 }
1246
ParseField(uint64_t tag,const char * ptr,const MessageLite * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)1247 const char* ExtensionSet::ParseField(uint64_t tag, const char* ptr,
1248 const MessageLite* extendee,
1249 internal::InternalMetadata* metadata,
1250 internal::ParseContext* ctx) {
1251 GeneratedExtensionFinder finder(extendee);
1252 int number = tag >> 3;
1253 bool was_packed_on_wire;
1254 ExtensionInfo extension;
1255 if (!FindExtensionInfoFromFieldNumber(tag & 7, number, &finder, &extension,
1256 &was_packed_on_wire)) {
1257 return UnknownFieldParse(
1258 tag, metadata->mutable_unknown_fields<std::string>(), ptr, ctx);
1259 }
1260 return ParseFieldWithExtensionInfo<std::string>(
1261 number, was_packed_on_wire, extension, metadata, ptr, ctx);
1262 }
1263
ParseMessageSetItem(const char * ptr,const MessageLite * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)1264 const char* ExtensionSet::ParseMessageSetItem(
1265 const char* ptr, const MessageLite* extendee,
1266 internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
1267 return ParseMessageSetItemTmpl<MessageLite, std::string>(ptr, extendee,
1268 metadata, ctx);
1269 }
1270
FieldTypeIsPointer(FieldType type)1271 bool ExtensionSet::FieldTypeIsPointer(FieldType type) {
1272 return type == WireFormatLite::TYPE_STRING ||
1273 type == WireFormatLite::TYPE_BYTES ||
1274 type == WireFormatLite::TYPE_GROUP ||
1275 type == WireFormatLite::TYPE_MESSAGE;
1276 }
1277
_InternalSerializeImpl(const MessageLite * extendee,int start_field_number,int end_field_number,uint8_t * target,io::EpsCopyOutputStream * stream) const1278 uint8_t* ExtensionSet::_InternalSerializeImpl(
1279 const MessageLite* extendee, int start_field_number, int end_field_number,
1280 uint8_t* target, io::EpsCopyOutputStream* stream) const {
1281 if (PROTOBUF_PREDICT_FALSE(is_large())) {
1282 const auto& end = map_.large->end();
1283 for (auto it = map_.large->lower_bound(start_field_number);
1284 it != end && it->first < end_field_number; ++it) {
1285 target = it->second.InternalSerializeFieldWithCachedSizesToArray(
1286 extendee, this, it->first, target, stream);
1287 }
1288 return target;
1289 }
1290 const KeyValue* end = flat_end();
1291 const KeyValue* it = flat_begin();
1292 while (it != end && it->first < start_field_number) ++it;
1293 for (; it != end && it->first < end_field_number; ++it) {
1294 target = it->second.InternalSerializeFieldWithCachedSizesToArray(
1295 extendee, this, it->first, target, stream);
1296 }
1297 return target;
1298 }
1299
InternalSerializeMessageSetWithCachedSizesToArray(const MessageLite * extendee,uint8_t * target,io::EpsCopyOutputStream * stream) const1300 uint8_t* ExtensionSet::InternalSerializeMessageSetWithCachedSizesToArray(
1301 const MessageLite* extendee, uint8_t* target,
1302 io::EpsCopyOutputStream* stream) const {
1303 const ExtensionSet* extension_set = this;
1304 ForEach(
1305 [&target, extendee, stream, extension_set](int number,
1306 const Extension& ext) {
1307 target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
1308 extendee, extension_set, number, target, stream);
1309 },
1310 Prefetch{});
1311 return target;
1312 }
1313
ByteSize() const1314 size_t ExtensionSet::ByteSize() const {
1315 size_t total_size = 0;
1316 ForEach(
1317 [&total_size](int number, const Extension& ext) {
1318 total_size += ext.ByteSize(number);
1319 },
1320 Prefetch{});
1321 return total_size;
1322 }
1323
1324 // Defined in extension_set_heavy.cc.
1325 // int ExtensionSet::SpaceUsedExcludingSelf() const
1326
MaybeNewExtension(int number,const FieldDescriptor * descriptor,Extension ** result)1327 bool ExtensionSet::MaybeNewExtension(int number,
1328 const FieldDescriptor* descriptor,
1329 Extension** result) {
1330 bool extension_is_new = false;
1331 std::tie(*result, extension_is_new) = Insert(number);
1332 (*result)->descriptor = descriptor;
1333 return extension_is_new;
1334 }
1335
1336 // ===================================================================
1337 // Methods of ExtensionSet::Extension
1338
Clear()1339 void ExtensionSet::Extension::Clear() {
1340 if (is_repeated) {
1341 switch (cpp_type(type)) {
1342 #define HANDLE_TYPE(UPPERCASE, LOWERCASE) \
1343 case WireFormatLite::CPPTYPE_##UPPERCASE: \
1344 ptr.repeated_##LOWERCASE##_value->Clear(); \
1345 break
1346
1347 HANDLE_TYPE(INT32, int32_t);
1348 HANDLE_TYPE(INT64, int64_t);
1349 HANDLE_TYPE(UINT32, uint32_t);
1350 HANDLE_TYPE(UINT64, uint64_t);
1351 HANDLE_TYPE(FLOAT, float);
1352 HANDLE_TYPE(DOUBLE, double);
1353 HANDLE_TYPE(BOOL, bool);
1354 HANDLE_TYPE(ENUM, enum);
1355 HANDLE_TYPE(STRING, string);
1356 HANDLE_TYPE(MESSAGE, message);
1357 #undef HANDLE_TYPE
1358 }
1359 } else {
1360 if (!is_cleared) {
1361 switch (cpp_type(type)) {
1362 case WireFormatLite::CPPTYPE_STRING:
1363 ptr.string_value->clear();
1364 break;
1365 case WireFormatLite::CPPTYPE_MESSAGE:
1366 if (is_lazy) {
1367 ptr.lazymessage_value->Clear();
1368 } else {
1369 ptr.message_value->Clear();
1370 }
1371 break;
1372 default:
1373 // No need to do anything. Get*() will return the default value
1374 // as long as is_cleared is true and Set*() will overwrite the
1375 // previous value.
1376 break;
1377 }
1378
1379 is_cleared = true;
1380 }
1381 }
1382 }
1383
// Returns the number of bytes extension `number` occupies on the wire,
// including tags. Cleared singular extensions contribute 0.
//
// - Packed repeated: returns the payload size plus the length prefix and one
//   length-delimited tag, or 0 if the payload is empty. The payload size alone
//   is also recorded in `cached_size` (presumably read back when the packed
//   field is serialized).
// - Unpacked repeated: returns one tag per element plus the element sizes.
size_t ExtensionSet::Extension::ByteSize(int number) const {
  size_t result = 0;

  if (is_repeated) {
    if (is_packed) {
      switch (real_type(type)) {
        // Variable-length encodings: sum the per-element sizes.
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                     \
  case WireFormatLite::TYPE_##UPPERCASE:                                 \
    for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
      result += WireFormatLite::CAMELCASE##Size(                         \
          ptr.repeated_##LOWERCASE##_value->Get(i));                     \
    }                                                                    \
    break

        HANDLE_TYPE(INT32, Int32, int32_t);
        HANDLE_TYPE(INT64, Int64, int64_t);
        HANDLE_TYPE(UINT32, UInt32, uint32_t);
        HANDLE_TYPE(UINT64, UInt64, uint64_t);
        HANDLE_TYPE(SINT32, SInt32, int32_t);
        HANDLE_TYPE(SINT64, SInt64, int64_t);
        HANDLE_TYPE(ENUM, Enum, enum);
#undef HANDLE_TYPE

        // Stuff with fixed size.
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
  case WireFormatLite::TYPE_##UPPERCASE:                             \
    result += WireFormatLite::k##CAMELCASE##Size *                   \
              FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
    break
        HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
        HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
        HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
        HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
        HANDLE_TYPE(FLOAT, Float, float);
        HANDLE_TYPE(DOUBLE, Double, double);
        HANDLE_TYPE(BOOL, Bool, bool);
#undef HANDLE_TYPE

        case WireFormatLite::TYPE_STRING:
        case WireFormatLite::TYPE_BYTES:
        case WireFormatLite::TYPE_GROUP:
        case WireFormatLite::TYPE_MESSAGE:
          ABSL_LOG(FATAL) << "Non-primitive types can't be packed.";
          break;
      }

      // `result` is the payload size only at this point; store it before
      // adding the length prefix and the tag.
      cached_size.set(ToCachedSize(result));
      if (result > 0) {
        result += io::CodedOutputStream::VarintSize32(result);
        result += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
            number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED));
      }
    } else {
      // Unpacked: every element carries its own tag.
      size_t tag_size = WireFormatLite::TagSize(number, real_type(type));

      switch (real_type(type)) {
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                      \
  case WireFormatLite::TYPE_##UPPERCASE:                                  \
    result +=                                                             \
        tag_size * FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
    for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) {  \
      result += WireFormatLite::CAMELCASE##Size(                          \
          ptr.repeated_##LOWERCASE##_value->Get(i));                      \
    }                                                                     \
    break

        HANDLE_TYPE(INT32, Int32, int32_t);
        HANDLE_TYPE(INT64, Int64, int64_t);
        HANDLE_TYPE(UINT32, UInt32, uint32_t);
        HANDLE_TYPE(UINT64, UInt64, uint64_t);
        HANDLE_TYPE(SINT32, SInt32, int32_t);
        HANDLE_TYPE(SINT64, SInt64, int64_t);
        HANDLE_TYPE(STRING, String, string);
        HANDLE_TYPE(BYTES, Bytes, string);
        HANDLE_TYPE(ENUM, Enum, enum);
        HANDLE_TYPE(GROUP, Group, message);
        HANDLE_TYPE(MESSAGE, Message, message);
#undef HANDLE_TYPE

        // Stuff with fixed size.
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                 \
  case WireFormatLite::TYPE_##UPPERCASE:                             \
    result += (tag_size + WireFormatLite::k##CAMELCASE##Size) *      \
              FromIntSize(ptr.repeated_##LOWERCASE##_value->size()); \
    break
        HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
        HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
        HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
        HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
        HANDLE_TYPE(FLOAT, Float, float);
        HANDLE_TYPE(DOUBLE, Double, double);
        HANDLE_TYPE(BOOL, Bool, bool);
#undef HANDLE_TYPE
      }
    }
  } else if (!is_cleared) {
    // Singular: one tag plus the encoded value.
    result += WireFormatLite::TagSize(number, real_type(type));
    switch (real_type(type)) {
#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)      \
  case WireFormatLite::TYPE_##UPPERCASE:                  \
    result += WireFormatLite::CAMELCASE##Size(LOWERCASE); \
    break

      HANDLE_TYPE(INT32, Int32, int32_t_value);
      HANDLE_TYPE(INT64, Int64, int64_t_value);
      HANDLE_TYPE(UINT32, UInt32, uint32_t_value);
      HANDLE_TYPE(UINT64, UInt64, uint64_t_value);
      HANDLE_TYPE(SINT32, SInt32, int32_t_value);
      HANDLE_TYPE(SINT64, SInt64, int64_t_value);
      HANDLE_TYPE(STRING, String, *ptr.string_value);
      HANDLE_TYPE(BYTES, Bytes, *ptr.string_value);
      HANDLE_TYPE(ENUM, Enum, enum_value);
      HANDLE_TYPE(GROUP, Group, *ptr.message_value);
#undef HANDLE_TYPE
      case WireFormatLite::TYPE_MESSAGE: {
        // Messages may be stored lazily or eagerly; ask whichever is present.
        result += WireFormatLite::LengthDelimitedSize(
            is_lazy ? ptr.lazymessage_value->ByteSizeLong()
                    : ptr.message_value->ByteSizeLong());
        break;
      }

      // Stuff with fixed size.
#define HANDLE_TYPE(UPPERCASE, CAMELCASE)         \
  case WireFormatLite::TYPE_##UPPERCASE:          \
    result += WireFormatLite::k##CAMELCASE##Size; \
    break
      HANDLE_TYPE(FIXED32, Fixed32);
      HANDLE_TYPE(FIXED64, Fixed64);
      HANDLE_TYPE(SFIXED32, SFixed32);
      HANDLE_TYPE(SFIXED64, SFixed64);
      HANDLE_TYPE(FLOAT, Float);
      HANDLE_TYPE(DOUBLE, Double);
      HANDLE_TYPE(BOOL, Bool);
#undef HANDLE_TYPE
    }
  }

  return result;
}
1523
GetSize() const1524 int ExtensionSet::Extension::GetSize() const {
1525 ABSL_DCHECK(is_repeated);
1526 switch (cpp_type(type)) {
1527 #define HANDLE_TYPE(UPPERCASE, LOWERCASE) \
1528 case WireFormatLite::CPPTYPE_##UPPERCASE: \
1529 return ptr.repeated_##LOWERCASE##_value->size()
1530
1531 HANDLE_TYPE(INT32, int32_t);
1532 HANDLE_TYPE(INT64, int64_t);
1533 HANDLE_TYPE(UINT32, uint32_t);
1534 HANDLE_TYPE(UINT64, uint64_t);
1535 HANDLE_TYPE(FLOAT, float);
1536 HANDLE_TYPE(DOUBLE, double);
1537 HANDLE_TYPE(BOOL, bool);
1538 HANDLE_TYPE(ENUM, enum);
1539 HANDLE_TYPE(STRING, string);
1540 HANDLE_TYPE(MESSAGE, message);
1541 #undef HANDLE_TYPE
1542 }
1543
1544 ABSL_LOG(FATAL) << "Can't get here.";
1545 return 0;
1546 }
1547
1548 // This function deletes all allocated objects. This function should be only
1549 // called if the Extension was created without an arena.
Free()1550 void ExtensionSet::Extension::Free() {
1551 if (is_repeated) {
1552 switch (cpp_type(type)) {
1553 #define HANDLE_TYPE(UPPERCASE, LOWERCASE) \
1554 case WireFormatLite::CPPTYPE_##UPPERCASE: \
1555 delete ptr.repeated_##LOWERCASE##_value; \
1556 break
1557
1558 HANDLE_TYPE(INT32, int32_t);
1559 HANDLE_TYPE(INT64, int64_t);
1560 HANDLE_TYPE(UINT32, uint32_t);
1561 HANDLE_TYPE(UINT64, uint64_t);
1562 HANDLE_TYPE(FLOAT, float);
1563 HANDLE_TYPE(DOUBLE, double);
1564 HANDLE_TYPE(BOOL, bool);
1565 HANDLE_TYPE(ENUM, enum);
1566 HANDLE_TYPE(STRING, string);
1567 HANDLE_TYPE(MESSAGE, message);
1568 #undef HANDLE_TYPE
1569 }
1570 } else {
1571 switch (cpp_type(type)) {
1572 case WireFormatLite::CPPTYPE_STRING:
1573 delete ptr.string_value;
1574 break;
1575 case WireFormatLite::CPPTYPE_MESSAGE:
1576 if (is_lazy) {
1577 delete ptr.lazymessage_value;
1578 } else {
1579 delete ptr.message_value;
1580 }
1581 break;
1582 default:
1583 break;
1584 }
1585 }
1586 }
1587
1588 // Defined in extension_set_heavy.cc.
1589 // int ExtensionSet::Extension::SpaceUsedExcludingSelf() const
1590
IsInitialized(const ExtensionSet * ext_set,const MessageLite * extendee,int number,Arena * arena) const1591 bool ExtensionSet::Extension::IsInitialized(const ExtensionSet* ext_set,
1592 const MessageLite* extendee,
1593 int number, Arena* arena) const {
1594 if (cpp_type(type) != WireFormatLite::CPPTYPE_MESSAGE) return true;
1595
1596 if (is_repeated) {
1597 for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1598 if (!ptr.repeated_message_value->Get(i).IsInitialized()) {
1599 return false;
1600 }
1601 }
1602 return true;
1603 }
1604
1605 if (is_cleared) return true;
1606
1607 if (!is_lazy) return ptr.message_value->IsInitialized();
1608
1609 const MessageLite* prototype =
1610 ext_set->GetPrototypeForLazyMessage(extendee, number);
1611 ABSL_DCHECK_NE(prototype, nullptr)
1612 << "extendee: " << extendee->GetTypeName() << "; number: " << number;
1613 return ptr.lazymessage_value->IsInitialized(prototype, arena);
1614 }
1615
// Dummy key method to avoid weak vtable.
// The body is intentionally empty. Defining this member out of line in this
// file gives LazyMessageExtension a single translation unit that owns its
// vtable.
void ExtensionSet::LazyMessageExtension::UnusedKeyMethod() {}
1618
FindOrNull(int key) const1619 const ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) const {
1620 if (flat_size_ == 0) {
1621 return nullptr;
1622 } else if (PROTOBUF_PREDICT_TRUE(!is_large())) {
1623 for (auto it = flat_begin(), end = flat_end();
1624 it != end && it->first <= key; ++it) {
1625 if (it->first == key) return &it->second;
1626 }
1627 return nullptr;
1628 } else {
1629 return FindOrNullInLargeMap(key);
1630 }
1631 }
1632
FindOrNullInLargeMap(int key) const1633 const ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(
1634 int key) const {
1635 assert(is_large());
1636 LargeMap::const_iterator it = map_.large->find(key);
1637 if (it != map_.large->end()) {
1638 return &it->second;
1639 }
1640 return nullptr;
1641 }
1642
FindOrNull(int key)1643 ExtensionSet::Extension* ExtensionSet::FindOrNull(int key) {
1644 const auto* const_this = this;
1645 return const_cast<ExtensionSet::Extension*>(const_this->FindOrNull(key));
1646 }
1647
FindOrNullInLargeMap(int key)1648 ExtensionSet::Extension* ExtensionSet::FindOrNullInLargeMap(int key) {
1649 const auto* const_this = this;
1650 return const_cast<ExtensionSet::Extension*>(
1651 const_this->FindOrNullInLargeMap(key));
1652 }
1653
// Finds or inserts the entry for `key`. Returns a pointer to the entry and
// whether it was newly inserted as a default-constructed Extension.
//
// Flat entries are kept sorted by key. A new entry goes before the first
// larger key, and the tail shifts right by one slot. Because of that shift,
// pointers previously obtained to later flat entries no longer refer to the
// same extension.
std::pair<ExtensionSet::Extension*, bool> ExtensionSet::Insert(int key) {
  if (PROTOBUF_PREDICT_FALSE(is_large())) {
    auto maybe = map_.large->insert({key, Extension()});
    return {&maybe.first->second, maybe.second};
  }
  KeyValue* end = flat_end();
  KeyValue* it = flat_begin();
  // Find the insertion point (first key greater than `key`), returning early
  // if `key` is already present.
  for (; it != end && it->first <= key; ++it) {
    if (it->first == key) return {&it->second, false};
  }
  if (flat_size_ < flat_capacity_) {
    // Room left: open a slot at `it` by shifting [it, end) right by one.
    std::copy_backward(it, end, end + 1);
    ++flat_size_;
    it->first = key;
    it->second = Extension();
    return {&it->second, true};
  }
  // Full: grow (which may switch to the large map), then retry.
  GrowCapacity(flat_size_ + 1);
  return Insert(key);
}
1674
namespace {
// True when `n` has at most one bit set. Note that this also returns true for
// n == 0.
constexpr bool IsPowerOfTwo(size_t n) { return !(n & (n - 1)); }
}  // namespace
1678
// Ensures the flat array can hold at least `minimum_new_capacity` entries.
// Does nothing if the set is already large or already has enough capacity.
//
// Capacity grows by factors of 4 starting from 1 (1, 4, 16, ...). If the
// resulting capacity would exceed kMaximumFlatCapacity, the entries move into
// a LargeMap instead. The old flat array is then released: deleted when there
// is no arena, otherwise handed back via Arena::ReturnArrayMemory().
void ExtensionSet::GrowCapacity(size_t minimum_new_capacity) {
  if (PROTOBUF_PREDICT_FALSE(is_large())) {
    return;  // LargeMap does not have a "reserve" method.
  }
  if (flat_capacity_ >= minimum_new_capacity) {
    return;
  }

  auto new_flat_capacity = flat_capacity_;
  do {
    new_flat_capacity = new_flat_capacity == 0 ? 1 : new_flat_capacity * 4;
  } while (new_flat_capacity < minimum_new_capacity);

  // Capture the old array bounds before flat_size_/map_ are changed below.
  KeyValue* begin = flat_begin();
  KeyValue* end = flat_end();
  AllocatedData new_map;
  Arena* const arena = arena_;
  if (new_flat_capacity > kMaximumFlatCapacity) {
    new_map.large = Arena::Create<LargeMap>(arena);
    LargeMap::iterator hint = new_map.large->begin();
    // Entries are visited in key order, so each insert is hinted with the
    // previous position.
    for (const KeyValue* it = begin; it != end; ++it) {
      hint = new_map.large->insert(hint, {it->first, it->second});
    }
    // flat_size_ == uint16_t(-1) is what makes is_large() true.
    flat_size_ = static_cast<uint16_t>(-1);
    ABSL_DCHECK(is_large());
  } else {
    new_map.flat = Arena::CreateArray<KeyValue>(arena, new_flat_capacity);
    std::copy(begin, end, new_map.flat);
  }

  // ReturnArrayMemory is more efficient with power-of-2 bytes, and
  // sizeof(KeyValue) is a power-of-2 on 64-bit platforms. flat_capacity_ is
  // always a power-of-2.
  ABSL_DCHECK(IsPowerOfTwo(sizeof(KeyValue)) || sizeof(void*) != 8)
      << sizeof(KeyValue) << " " << sizeof(void*);
  ABSL_DCHECK(IsPowerOfTwo(flat_capacity_));
  // Release the old flat array, if one was ever allocated.
  if (flat_capacity_ > 0) {
    if (arena == nullptr) {
      DeleteFlatMap(begin, flat_capacity_);
    } else {
      arena->ReturnArrayMemory(begin, sizeof(KeyValue) * flat_capacity_);
    }
  }
  flat_capacity_ = new_flat_capacity;
  map_ = new_map;
}
1725
// Before C++17 an in-class `static constexpr` data member is only a
// declaration, so one namespace-scope definition is provided in case it is
// odr-used. From C++17 on such members are implicitly inline and this
// definition is unnecessary.
#if (__cplusplus < 201703) && \
    (!defined(_MSC_VER) || (_MSC_VER >= 1900 && _MSC_VER < 1912))
// static
constexpr uint16_t ExtensionSet::kMaximumFlatCapacity;
#endif  // out-of-line kMaximumFlatCapacity definition (pre-C++17 only)
1732
Erase(int key)1733 void ExtensionSet::Erase(int key) {
1734 if (PROTOBUF_PREDICT_FALSE(is_large())) {
1735 map_.large->erase(key);
1736 return;
1737 }
1738 KeyValue* end = flat_end();
1739 for (KeyValue* it = flat_begin(); it != end && it->first <= key; ++it) {
1740 if (it->first == key) {
1741 std::copy(it + 1, end, it);
1742 --flat_size_;
1743 return;
1744 }
1745 }
1746 }
1747
1748 // ==================================================================
1749 // Default repeated field instances for iterator-compatible accessors
1750
default_instance()1751 const RepeatedPrimitiveDefaults* RepeatedPrimitiveDefaults::default_instance() {
1752 static auto instance = OnShutdownDelete(new RepeatedPrimitiveDefaults);
1753 return instance;
1754 }
1755
1756 const RepeatedStringTypeTraits::RepeatedFieldType*
GetDefaultRepeatedField()1757 RepeatedStringTypeTraits::GetDefaultRepeatedField() {
1758 static auto instance = OnShutdownDelete(new RepeatedFieldType);
1759 return instance;
1760 }
1761
InternalSerializeFieldWithCachedSizesToArray(const MessageLite * extendee,const ExtensionSet * extension_set,int number,uint8_t * target,io::EpsCopyOutputStream * stream) const1762 uint8_t* ExtensionSet::Extension::InternalSerializeFieldWithCachedSizesToArray(
1763 const MessageLite* extendee, const ExtensionSet* extension_set, int number,
1764 uint8_t* target, io::EpsCopyOutputStream* stream) const {
1765 if (is_repeated) {
1766 if (is_packed) {
1767 if (cached_size() == 0) return target;
1768
1769 target = stream->EnsureSpace(target);
1770 target = WireFormatLite::WriteTagToArray(
1771 number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
1772 target = WireFormatLite::WriteInt32NoTagToArray(cached_size(), target);
1773
1774 switch (real_type(type)) {
1775 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \
1776 case WireFormatLite::TYPE_##UPPERCASE: \
1777 for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1778 target = stream->EnsureSpace(target); \
1779 target = WireFormatLite::Write##CAMELCASE##NoTagToArray( \
1780 ptr.repeated_##LOWERCASE##_value->Get(i), target); \
1781 } \
1782 break
1783
1784 HANDLE_TYPE(INT32, Int32, int32_t);
1785 HANDLE_TYPE(INT64, Int64, int64_t);
1786 HANDLE_TYPE(UINT32, UInt32, uint32_t);
1787 HANDLE_TYPE(UINT64, UInt64, uint64_t);
1788 HANDLE_TYPE(SINT32, SInt32, int32_t);
1789 HANDLE_TYPE(SINT64, SInt64, int64_t);
1790 HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1791 HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1792 HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1793 HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1794 HANDLE_TYPE(FLOAT, Float, float);
1795 HANDLE_TYPE(DOUBLE, Double, double);
1796 HANDLE_TYPE(BOOL, Bool, bool);
1797 HANDLE_TYPE(ENUM, Enum, enum);
1798 #undef HANDLE_TYPE
1799
1800 case WireFormatLite::TYPE_STRING:
1801 case WireFormatLite::TYPE_BYTES:
1802 case WireFormatLite::TYPE_GROUP:
1803 case WireFormatLite::TYPE_MESSAGE:
1804 ABSL_LOG(FATAL) << "Non-primitive types can't be packed.";
1805 break;
1806 }
1807 } else {
1808 switch (real_type(type)) {
1809 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \
1810 case WireFormatLite::TYPE_##UPPERCASE: \
1811 for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1812 target = stream->EnsureSpace(target); \
1813 target = WireFormatLite::Write##CAMELCASE##ToArray( \
1814 number, ptr.repeated_##LOWERCASE##_value->Get(i), target); \
1815 } \
1816 break
1817
1818 HANDLE_TYPE(INT32, Int32, int32_t);
1819 HANDLE_TYPE(INT64, Int64, int64_t);
1820 HANDLE_TYPE(UINT32, UInt32, uint32_t);
1821 HANDLE_TYPE(UINT64, UInt64, uint64_t);
1822 HANDLE_TYPE(SINT32, SInt32, int32_t);
1823 HANDLE_TYPE(SINT64, SInt64, int64_t);
1824 HANDLE_TYPE(FIXED32, Fixed32, uint32_t);
1825 HANDLE_TYPE(FIXED64, Fixed64, uint64_t);
1826 HANDLE_TYPE(SFIXED32, SFixed32, int32_t);
1827 HANDLE_TYPE(SFIXED64, SFixed64, int64_t);
1828 HANDLE_TYPE(FLOAT, Float, float);
1829 HANDLE_TYPE(DOUBLE, Double, double);
1830 HANDLE_TYPE(BOOL, Bool, bool);
1831 HANDLE_TYPE(ENUM, Enum, enum);
1832 #undef HANDLE_TYPE
1833 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \
1834 case WireFormatLite::TYPE_##UPPERCASE: \
1835 for (int i = 0; i < ptr.repeated_##LOWERCASE##_value->size(); i++) { \
1836 target = stream->EnsureSpace(target); \
1837 target = stream->WriteString( \
1838 number, ptr.repeated_##LOWERCASE##_value->Get(i), target); \
1839 } \
1840 break
1841 HANDLE_TYPE(STRING, String, string);
1842 HANDLE_TYPE(BYTES, Bytes, string);
1843 #undef HANDLE_TYPE
1844 case WireFormatLite::TYPE_GROUP:
1845 for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1846 target = stream->EnsureSpace(target);
1847 target = WireFormatLite::InternalWriteGroup(
1848 number, ptr.repeated_message_value->Get(i), target, stream);
1849 }
1850 break;
1851 case WireFormatLite::TYPE_MESSAGE:
1852 for (int i = 0; i < ptr.repeated_message_value->size(); i++) {
1853 auto& msg = ptr.repeated_message_value->Get(i);
1854 target = WireFormatLite::InternalWriteMessage(
1855 number, msg, msg.GetCachedSize(), target, stream);
1856 }
1857 break;
1858 }
1859 }
1860 } else if (!is_cleared) {
1861 switch (real_type(type)) {
1862 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE) \
1863 case WireFormatLite::TYPE_##UPPERCASE: \
1864 target = stream->EnsureSpace(target); \
1865 target = WireFormatLite::Write##CAMELCASE##ToArray(number, VALUE, target); \
1866 break
1867
1868 HANDLE_TYPE(INT32, Int32, int32_t_value);
1869 HANDLE_TYPE(INT64, Int64, int64_t_value);
1870 HANDLE_TYPE(UINT32, UInt32, uint32_t_value);
1871 HANDLE_TYPE(UINT64, UInt64, uint64_t_value);
1872 HANDLE_TYPE(SINT32, SInt32, int32_t_value);
1873 HANDLE_TYPE(SINT64, SInt64, int64_t_value);
1874 HANDLE_TYPE(FIXED32, Fixed32, uint32_t_value);
1875 HANDLE_TYPE(FIXED64, Fixed64, uint64_t_value);
1876 HANDLE_TYPE(SFIXED32, SFixed32, int32_t_value);
1877 HANDLE_TYPE(SFIXED64, SFixed64, int64_t_value);
1878 HANDLE_TYPE(FLOAT, Float, float_value);
1879 HANDLE_TYPE(DOUBLE, Double, double_value);
1880 HANDLE_TYPE(BOOL, Bool, bool_value);
1881 HANDLE_TYPE(ENUM, Enum, enum_value);
1882 #undef HANDLE_TYPE
1883 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE) \
1884 case WireFormatLite::TYPE_##UPPERCASE: \
1885 target = stream->EnsureSpace(target); \
1886 target = stream->WriteString(number, VALUE, target); \
1887 break
1888 HANDLE_TYPE(STRING, String, *ptr.string_value);
1889 HANDLE_TYPE(BYTES, Bytes, *ptr.string_value);
1890 #undef HANDLE_TYPE
1891 case WireFormatLite::TYPE_GROUP:
1892 target = stream->EnsureSpace(target);
1893 target = WireFormatLite::InternalWriteGroup(number, *ptr.message_value,
1894 target, stream);
1895 break;
1896 case WireFormatLite::TYPE_MESSAGE:
1897 if (is_lazy) {
1898 const auto* prototype =
1899 extension_set->GetPrototypeForLazyMessage(extendee, number);
1900 target = ptr.lazymessage_value->WriteMessageToArray(prototype, number,
1901 target, stream);
1902 } else {
1903 target = WireFormatLite::InternalWriteMessage(
1904 number, *ptr.message_value, ptr.message_value->GetCachedSize(),
1905 target, stream);
1906 }
1907 break;
1908 }
1909 }
1910 return target;
1911 }
1912
// Resolves the message prototype registered for extension `number` of
// `extendee`, looked up as a length-delimited field. Returns nullptr when no
// matching extension is found.
const MessageLite* ExtensionSet::GetPrototypeForLazyMessage(
    const MessageLite* extendee, int number) const {
  GeneratedExtensionFinder finder(extendee);
  ExtensionInfo info;
  bool was_packed_on_wire = false;
  const bool found = FindExtensionInfoFromFieldNumber(
      WireFormatLite::WIRETYPE_LENGTH_DELIMITED, number, &finder, &info,
      &was_packed_on_wire);
  return found ? info.message_info.prototype : nullptr;
}
1925
1926 uint8_t*
InternalSerializeMessageSetItemWithCachedSizesToArray(const MessageLite * extendee,const ExtensionSet * extension_set,int number,uint8_t * target,io::EpsCopyOutputStream * stream) const1927 ExtensionSet::Extension::InternalSerializeMessageSetItemWithCachedSizesToArray(
1928 const MessageLite* extendee, const ExtensionSet* extension_set, int number,
1929 uint8_t* target, io::EpsCopyOutputStream* stream) const {
1930 if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
1931 // Not a valid MessageSet extension, but serialize it the normal way.
1932 ABSL_LOG(WARNING) << "Invalid message set extension.";
1933 return InternalSerializeFieldWithCachedSizesToArray(extendee, extension_set,
1934 number, target, stream);
1935 }
1936
1937 if (is_cleared) return target;
1938
1939 target = stream->EnsureSpace(target);
1940 // Start group.
1941 target = io::CodedOutputStream::WriteTagToArray(
1942 WireFormatLite::kMessageSetItemStartTag, target);
1943 // Write type ID.
1944 target = WireFormatLite::WriteUInt32ToArray(
1945 WireFormatLite::kMessageSetTypeIdNumber, number, target);
1946 // Write message.
1947 if (is_lazy) {
1948 const auto* prototype =
1949 extension_set->GetPrototypeForLazyMessage(extendee, number);
1950 target = ptr.lazymessage_value->WriteMessageToArray(
1951 prototype, WireFormatLite::kMessageSetMessageNumber, target, stream);
1952 } else {
1953 target = WireFormatLite::InternalWriteMessage(
1954 WireFormatLite::kMessageSetMessageNumber, *ptr.message_value,
1955 ptr.message_value->GetCachedSize(), target, stream);
1956 }
1957 // End group.
1958 target = stream->EnsureSpace(target);
1959 target = io::CodedOutputStream::WriteTagToArray(
1960 WireFormatLite::kMessageSetItemEndTag, target);
1961 return target;
1962 }
1963
MessageSetItemByteSize(int number) const1964 size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
1965 if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
1966 // Not a valid MessageSet extension, but compute the byte size for it the
1967 // normal way.
1968 return ByteSize(number);
1969 }
1970
1971 if (is_cleared) return 0;
1972
1973 size_t our_size = WireFormatLite::kMessageSetItemTagsSize;
1974
1975 // type_id
1976 our_size += io::CodedOutputStream::VarintSize32(number);
1977
1978 // message
1979 our_size += WireFormatLite::LengthDelimitedSize(
1980 is_lazy ? ptr.lazymessage_value->ByteSizeLong()
1981 : ptr.message_value->ByteSizeLong());
1982
1983 return our_size;
1984 }
1985
MessageSetByteSize() const1986 size_t ExtensionSet::MessageSetByteSize() const {
1987 size_t total_size = 0;
1988 ForEach(
1989 [&total_size](int number, const Extension& ext) {
1990 total_size += ext.MessageSetItemByteSize(number);
1991 },
1992 Prefetch{});
1993 return total_size;
1994 }
1995
FindExtensionLazyEagerVerifyFn(const MessageLite * extendee,int number)1996 LazyEagerVerifyFnType FindExtensionLazyEagerVerifyFn(
1997 const MessageLite* extendee, int number) {
1998 const ExtensionInfo* registered = FindRegisteredExtension(extendee, number);
1999 if (registered != nullptr) {
2000 return registered->lazy_eager_verify_func;
2001 }
2002 return nullptr;
2003 }
2004
2005 std::atomic<ExtensionSet::LazyMessageExtension* (*)(Arena* arena)>
2006 ExtensionSet::maybe_create_lazy_extension_;
2007
2008 } // namespace internal
2009 } // namespace protobuf
2010 } // namespace google
2011
2012 #include "google/protobuf/port_undef.inc"
2013