1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: kenton@google.com (Kenton Varda)
9 // Based on original Protocol Buffers design by
10 // Sanjay Ghemawat, Jeff Dean, and others.
11 //
12 // Contains methods defined in extension_set.h which cannot be part of the
13 // lite library because they use descriptors or reflection.
14
15 #include <cstddef>
16 #include <cstdint>
17 #include <initializer_list>
18 #include <vector>
19
20 #include "absl/base/attributes.h"
21 #include "absl/log/absl_check.h"
22 #include "google/protobuf/arena.h"
23 #include "google/protobuf/descriptor.h"
24 #include "google/protobuf/descriptor.pb.h"
25 #include "google/protobuf/extension_set.h"
26 #include "google/protobuf/extension_set_inl.h"
27 #include "google/protobuf/generated_message_reflection.h"
28 #include "google/protobuf/generated_message_tctable_impl.h"
29 #include "google/protobuf/io/coded_stream.h"
30 #include "google/protobuf/message.h"
31 #include "google/protobuf/message_lite.h"
32 #include "google/protobuf/parse_context.h"
33 #include "google/protobuf/port.h"
34 #include "google/protobuf/repeated_field.h"
35 #include "google/protobuf/unknown_field_set.h"
36 #include "google/protobuf/wire_format_lite.h"
37
38
39 // Must be included last.
40 #include "google/protobuf/port_def.inc"
41
42 namespace google {
43 namespace protobuf {
44 namespace internal {
45
46 // Implementation of ExtensionFinder which finds extensions in a given
47 // DescriptorPool, using the given MessageFactory to construct sub-objects.
48 // This class is implemented in extension_set_heavy.cc.
49 class DescriptorPoolExtensionFinder {
50 public:
DescriptorPoolExtensionFinder(const DescriptorPool * pool,MessageFactory * factory,const Descriptor * extendee)51 DescriptorPoolExtensionFinder(const DescriptorPool* pool,
52 MessageFactory* factory,
53 const Descriptor* extendee)
54 : pool_(pool), factory_(factory), containing_type_(extendee) {}
55
56 bool Find(int number, ExtensionInfo* output);
57
58 private:
59 const DescriptorPool* pool_;
60 MessageFactory* factory_;
61 const Descriptor* containing_type_;
62 };
63
AppendToList(const Descriptor * extendee,const DescriptorPool * pool,std::vector<const FieldDescriptor * > * output) const64 void ExtensionSet::AppendToList(
65 const Descriptor* extendee, const DescriptorPool* pool,
66 std::vector<const FieldDescriptor*>* output) const {
67 ForEach(
68 [extendee, pool, &output](int number, const Extension& ext) {
69 bool has = false;
70 if (ext.is_repeated) {
71 has = ext.GetSize() > 0;
72 } else {
73 has = !ext.is_cleared;
74 }
75
76 if (has) {
77 // TODO: Looking up each field by number is somewhat
78 // unfortunate.
79 // Is there a better way? The problem is that descriptors are
80 // lazily-initialized, so they might not even be constructed until
81 // AppendToList() is called.
82
83 if (ext.descriptor == nullptr) {
84 output->push_back(pool->FindExtensionByNumber(extendee, number));
85 } else {
86 output->push_back(ext.descriptor);
87 }
88 }
89 },
90 Prefetch{});
91 }
92
real_type(FieldType type)93 inline FieldDescriptor::Type real_type(FieldType type) {
94 ABSL_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
95 return static_cast<FieldDescriptor::Type>(type);
96 }
97
cpp_type(FieldType type)98 inline FieldDescriptor::CppType cpp_type(FieldType type) {
99 return FieldDescriptor::TypeToCppType(
100 static_cast<FieldDescriptor::Type>(type));
101 }
102
field_type(FieldType type)103 inline WireFormatLite::FieldType field_type(FieldType type) {
104 ABSL_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
105 return static_cast<WireFormatLite::FieldType>(type);
106 }
107
// Checks (via ABSL_DCHECK_EQ) that an Extension has the expected cardinality
// (LABEL: OPTIONAL or REPEATED) and C++ type category (CPPTYPE).
// Wrapped in do { } while (0) so the two checks always expand to exactly one
// statement and cannot be split apart by an unbraced `if`/`else` at a use site.
#define ABSL_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                            \
  do {                                                                         \
    ABSL_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED   \
                                           : FieldDescriptor::LABEL_OPTIONAL,  \
                   FieldDescriptor::LABEL_##LABEL);                            \
    ABSL_DCHECK_EQ(cpp_type((EXTENSION).type),                                 \
                   FieldDescriptor::CPPTYPE_##CPPTYPE);                        \
  } while (0)
113
GetMessage(int number,const Descriptor * message_type,MessageFactory * factory) const114 const MessageLite& ExtensionSet::GetMessage(int number,
115 const Descriptor* message_type,
116 MessageFactory* factory) const {
117 const Extension* extension = FindOrNull(number);
118 if (extension == nullptr || extension->is_cleared) {
119 // Not present. Return the default value.
120 return *factory->GetPrototype(message_type);
121 } else {
122 ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
123 if (extension->is_lazy) {
124 return extension->ptr.lazymessage_value->GetMessage(
125 *factory->GetPrototype(message_type), arena_);
126 } else {
127 return *extension->ptr.message_value;
128 }
129 }
130 }
131
MutableMessage(const FieldDescriptor * descriptor,MessageFactory * factory)132 MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
133 MessageFactory* factory) {
134 Extension* extension;
135 if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
136 extension->type = descriptor->type();
137 ABSL_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
138 extension->is_repeated = false;
139 extension->is_pointer = true;
140 extension->is_packed = false;
141 const MessageLite* prototype =
142 factory->GetPrototype(descriptor->message_type());
143 extension->is_lazy = false;
144 extension->ptr.message_value = prototype->New(arena_);
145 extension->is_cleared = false;
146 return extension->ptr.message_value;
147 } else {
148 ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
149 extension->is_cleared = false;
150 if (extension->is_lazy) {
151 return extension->ptr.lazymessage_value->MutableMessage(
152 *factory->GetPrototype(descriptor->message_type()), arena_);
153 } else {
154 return extension->ptr.message_value;
155 }
156 }
157 }
158
ReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)159 MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
160 MessageFactory* factory) {
161 Extension* extension = FindOrNull(descriptor->number());
162 if (extension == nullptr) {
163 // Not present. Return nullptr.
164 return nullptr;
165 } else {
166 ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
167 MessageLite* ret = nullptr;
168 if (extension->is_lazy) {
169 ret = extension->ptr.lazymessage_value->ReleaseMessage(
170 *factory->GetPrototype(descriptor->message_type()), arena_);
171 if (arena_ == nullptr) {
172 delete extension->ptr.lazymessage_value;
173 }
174 } else {
175 if (arena_ != nullptr) {
176 ret = extension->ptr.message_value->New();
177 ret->CheckTypeAndMergeFrom(*extension->ptr.message_value);
178 } else {
179 ret = extension->ptr.message_value;
180 }
181 }
182 Erase(descriptor->number());
183 return ret;
184 }
185 }
186
UnsafeArenaReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)187 MessageLite* ExtensionSet::UnsafeArenaReleaseMessage(
188 const FieldDescriptor* descriptor, MessageFactory* factory) {
189 Extension* extension = FindOrNull(descriptor->number());
190 if (extension == nullptr) {
191 // Not present. Return nullptr.
192 return nullptr;
193 } else {
194 ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
195 MessageLite* ret = nullptr;
196 if (extension->is_lazy) {
197 ret = extension->ptr.lazymessage_value->UnsafeArenaReleaseMessage(
198 *factory->GetPrototype(descriptor->message_type()), arena_);
199 if (arena_ == nullptr) {
200 delete extension->ptr.lazymessage_value;
201 }
202 } else {
203 ret = extension->ptr.message_value;
204 }
205 Erase(descriptor->number());
206 return ret;
207 }
208 }
209
MaybeNewRepeatedExtension(const FieldDescriptor * descriptor)210 ExtensionSet::Extension* ExtensionSet::MaybeNewRepeatedExtension(
211 const FieldDescriptor* descriptor) {
212 Extension* extension;
213 if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
214 extension->type = descriptor->type();
215 ABSL_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
216 extension->is_repeated = true;
217 extension->is_pointer = true;
218 extension->ptr.repeated_message_value =
219 Arena::Create<RepeatedPtrField<MessageLite> >(arena_);
220 } else {
221 ABSL_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
222 }
223 return extension;
224 }
225
AddMessage(const FieldDescriptor * descriptor,MessageFactory * factory)226 MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
227 MessageFactory* factory) {
228 Extension* extension = MaybeNewRepeatedExtension(descriptor);
229
230 // RepeatedPtrField<Message> does not know how to Add() since it cannot
231 // allocate an abstract object, so we have to be tricky.
232 MessageLite* result =
233 reinterpret_cast<internal::RepeatedPtrFieldBase*>(
234 extension->ptr.repeated_message_value)
235 ->AddFromCleared<GenericTypeHandler<MessageLite> >();
236 if (result == nullptr) {
237 const MessageLite* prototype;
238 if (extension->ptr.repeated_message_value->empty()) {
239 prototype = factory->GetPrototype(descriptor->message_type());
240 ABSL_CHECK(prototype != nullptr);
241 } else {
242 prototype = &extension->ptr.repeated_message_value->Get(0);
243 }
244 result = prototype->New(arena_);
245 extension->ptr.repeated_message_value->AddAllocated(result);
246 }
247 return result;
248 }
249
AddAllocatedMessage(const FieldDescriptor * descriptor,MessageLite * new_entry)250 void ExtensionSet::AddAllocatedMessage(const FieldDescriptor* descriptor,
251 MessageLite* new_entry) {
252 Extension* extension = MaybeNewRepeatedExtension(descriptor);
253
254 extension->ptr.repeated_message_value->AddAllocated(new_entry);
255 }
256
UnsafeArenaAddAllocatedMessage(const FieldDescriptor * descriptor,MessageLite * new_entry)257 void ExtensionSet::UnsafeArenaAddAllocatedMessage(
258 const FieldDescriptor* descriptor, MessageLite* new_entry) {
259 Extension* extension = MaybeNewRepeatedExtension(descriptor);
260
261 extension->ptr.repeated_message_value->UnsafeArenaAddAllocated(new_entry);
262 }
263
ValidateEnumUsingDescriptor(const void * arg,int number)264 static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
265 return reinterpret_cast<const EnumDescriptor*>(arg)->FindValueByNumber(
266 number) != nullptr;
267 }
268
Find(int number,ExtensionInfo * output)269 bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
270 const FieldDescriptor* extension =
271 pool_->FindExtensionByNumber(containing_type_, number);
272 if (extension == nullptr) {
273 return false;
274 } else {
275 output->type = extension->type();
276 output->is_repeated = extension->is_repeated();
277 output->is_packed = extension->is_packed();
278 output->descriptor = extension;
279 if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
280 output->message_info.prototype =
281 factory_->GetPrototype(extension->message_type());
282 output->message_info.tc_table =
283 output->message_info.prototype->GetTcParseTable();
284 ABSL_CHECK(output->message_info.prototype != nullptr)
285 << "Extension factory's GetPrototype() returned nullptr; extension: "
286 << extension->full_name();
287
288 if (extension->options().has_lazy()) {
289 output->is_lazy = extension->options().lazy() ? LazyAnnotation::kLazy
290 : LazyAnnotation::kEager;
291 }
292 } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
293 output->enum_validity_check.func = ValidateEnumUsingDescriptor;
294 output->enum_validity_check.arg = extension->enum_type();
295 }
296
297 return true;
298 }
299 }
300
301
FindExtension(int wire_type,uint32_t field,const Message * extendee,const internal::ParseContext * ctx,ExtensionInfo * extension,bool * was_packed_on_wire)302 bool ExtensionSet::FindExtension(int wire_type, uint32_t field,
303 const Message* extendee,
304 const internal::ParseContext* ctx,
305 ExtensionInfo* extension,
306 bool* was_packed_on_wire) {
307 if (ctx->data().pool == nullptr) {
308 GeneratedExtensionFinder finder(extendee);
309 if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension,
310 was_packed_on_wire)) {
311 return false;
312 }
313 } else {
314 DescriptorPoolExtensionFinder finder(ctx->data().pool, ctx->data().factory,
315 extendee->GetDescriptor());
316 if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension,
317 was_packed_on_wire)) {
318 return false;
319 }
320 }
321 return true;
322 }
323
324
ParseField(uint64_t tag,const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)325 const char* ExtensionSet::ParseField(uint64_t tag, const char* ptr,
326 const Message* extendee,
327 internal::InternalMetadata* metadata,
328 internal::ParseContext* ctx) {
329 int number = tag >> 3;
330 bool was_packed_on_wire;
331 ExtensionInfo extension;
332 if (!FindExtension(tag & 7, number, extendee, ctx, &extension,
333 &was_packed_on_wire)) {
334 return UnknownFieldParse(
335 tag, metadata->mutable_unknown_fields<UnknownFieldSet>(), ptr, ctx);
336 }
337 return ParseFieldWithExtensionInfo<UnknownFieldSet>(
338 number, was_packed_on_wire, extension, metadata, ptr, ctx);
339 }
340
ParseFieldMaybeLazily(uint64_t tag,const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)341 const char* ExtensionSet::ParseFieldMaybeLazily(
342 uint64_t tag, const char* ptr, const Message* extendee,
343 internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
344 return ParseField(tag, ptr, extendee, metadata, ctx);
345 }
346
ParseMessageSetItem(const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)347 const char* ExtensionSet::ParseMessageSetItem(
348 const char* ptr, const Message* extendee,
349 internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
350 return ParseMessageSetItemTmpl<Message, UnknownFieldSet>(ptr, extendee,
351 metadata, ctx);
352 }
353
SpaceUsedExcludingSelf() const354 int ExtensionSet::SpaceUsedExcludingSelf() const {
355 return internal::FromIntSize(SpaceUsedExcludingSelfLong());
356 }
357
SpaceUsedExcludingSelfLong() const358 size_t ExtensionSet::SpaceUsedExcludingSelfLong() const {
359 size_t total_size =
360 (is_large() ? map_.large->size() : flat_capacity_) * sizeof(KeyValue);
361 ForEach(
362 [&total_size](int /* number */, const Extension& ext) {
363 total_size += ext.SpaceUsedExcludingSelfLong();
364 },
365 Prefetch{});
366 return total_size;
367 }
368
RepeatedMessage_SpaceUsedExcludingSelfLong(RepeatedPtrFieldBase * field)369 inline size_t ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelfLong(
370 RepeatedPtrFieldBase* field) {
371 return field->SpaceUsedExcludingSelfLong<GenericTypeHandler<Message> >();
372 }
373
SpaceUsedExcludingSelfLong() const374 size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const {
375 size_t total_size = 0;
376 if (is_repeated) {
377 switch (cpp_type(type)) {
378 #define HANDLE_TYPE(UPPERCASE, LOWERCASE) \
379 case FieldDescriptor::CPPTYPE_##UPPERCASE: \
380 total_size += \
381 sizeof(*ptr.repeated_##LOWERCASE##_value) + \
382 ptr.repeated_##LOWERCASE##_value->SpaceUsedExcludingSelfLong(); \
383 break
384
385 HANDLE_TYPE(INT32, int32_t);
386 HANDLE_TYPE(INT64, int64_t);
387 HANDLE_TYPE(UINT32, uint32_t);
388 HANDLE_TYPE(UINT64, uint64_t);
389 HANDLE_TYPE(FLOAT, float);
390 HANDLE_TYPE(DOUBLE, double);
391 HANDLE_TYPE(BOOL, bool);
392 HANDLE_TYPE(ENUM, enum);
393 HANDLE_TYPE(STRING, string);
394 #undef HANDLE_TYPE
395
396 case FieldDescriptor::CPPTYPE_MESSAGE:
397 // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
398 // but MessageLite has no SpaceUsedLong(), so we must directly call
399 // RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong() with a different
400 // type handler.
401 total_size += sizeof(*ptr.repeated_message_value) +
402 RepeatedMessage_SpaceUsedExcludingSelfLong(
403 reinterpret_cast<internal::RepeatedPtrFieldBase*>(
404 ptr.repeated_message_value));
405 break;
406 }
407 } else {
408 switch (cpp_type(type)) {
409 case FieldDescriptor::CPPTYPE_STRING:
410 total_size += sizeof(*ptr.string_value) +
411 StringSpaceUsedExcludingSelfLong(*ptr.string_value);
412 break;
413 case FieldDescriptor::CPPTYPE_MESSAGE:
414 if (is_lazy) {
415 total_size += ptr.lazymessage_value->SpaceUsedLong();
416 } else {
417 total_size +=
418 DownCastMessage<Message>(ptr.message_value)->SpaceUsedLong();
419 }
420 break;
421 default:
422 // No extra storage costs for primitive types.
423 break;
424 }
425 }
426 return total_size;
427 }
428
SerializeMessageSetWithCachedSizesToArray(const MessageLite * extendee,uint8_t * target) const429 uint8_t* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
430 const MessageLite* extendee, uint8_t* target) const {
431 io::EpsCopyOutputStream stream(
432 target, MessageSetByteSize(),
433 io::CodedOutputStream::IsDefaultSerializationDeterministic());
434 return InternalSerializeMessageSetWithCachedSizesToArray(extendee, target,
435 &stream);
436 }
437
#if defined(PROTOBUF_DESCRIPTOR_WEAK_MESSAGES_ALLOWED)
// Decides whether registration should happen in the current phase.
// - During preregistration: true only if every weak prototype in `messages`
//   is already available.
// - Otherwise: true only if at least one of them is missing.
bool ExtensionSet::ShouldRegisterAtThisTime(
    std::initializer_list<WeakPrototypeRef> messages, bool is_preregistration) {
  for (const auto& ref : messages) {
    if (GetPrototypeForWeakDescriptor(ref.table, ref.index, false) == nullptr) {
      // Some prototype is unavailable; stop probing the rest.
      return !is_preregistration;
    }
  }
  return is_preregistration;
}
#endif  // PROTOBUF_DESCRIPTOR_WEAK_MESSAGES_ALLOWED
449
450 } // namespace internal
451 } // namespace protobuf
452 } // namespace google
453
454 #include "google/protobuf/port_undef.inc"
455