• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: kenton@google.com (Kenton Varda)
9 //  Based on original Protocol Buffers design by
10 //  Sanjay Ghemawat, Jeff Dean, and others.
11 //
12 // Contains methods defined in extension_set.h which cannot be part of the
13 // lite library because they use descriptors or reflection.
14 
15 #include <cstddef>
16 #include <cstdint>
17 #include <initializer_list>
18 #include <vector>
19 
20 #include "absl/base/attributes.h"
21 #include "absl/log/absl_check.h"
22 #include "google/protobuf/arena.h"
23 #include "google/protobuf/descriptor.h"
24 #include "google/protobuf/descriptor.pb.h"
25 #include "google/protobuf/extension_set.h"
26 #include "google/protobuf/extension_set_inl.h"
27 #include "google/protobuf/generated_message_reflection.h"
28 #include "google/protobuf/generated_message_tctable_impl.h"
29 #include "google/protobuf/io/coded_stream.h"
30 #include "google/protobuf/message.h"
31 #include "google/protobuf/message_lite.h"
32 #include "google/protobuf/parse_context.h"
33 #include "google/protobuf/port.h"
34 #include "google/protobuf/repeated_field.h"
35 #include "google/protobuf/unknown_field_set.h"
36 #include "google/protobuf/wire_format_lite.h"
37 
38 
39 // Must be included last.
40 #include "google/protobuf/port_def.inc"
41 
42 namespace google {
43 namespace protobuf {
44 namespace internal {
45 
46 // Implementation of ExtensionFinder which finds extensions in a given
47 // DescriptorPool, using the given MessageFactory to construct sub-objects.
48 // This class is implemented in extension_set_heavy.cc.
49 class DescriptorPoolExtensionFinder {
50  public:
DescriptorPoolExtensionFinder(const DescriptorPool * pool,MessageFactory * factory,const Descriptor * extendee)51   DescriptorPoolExtensionFinder(const DescriptorPool* pool,
52                                 MessageFactory* factory,
53                                 const Descriptor* extendee)
54       : pool_(pool), factory_(factory), containing_type_(extendee) {}
55 
56   bool Find(int number, ExtensionInfo* output);
57 
58  private:
59   const DescriptorPool* pool_;
60   MessageFactory* factory_;
61   const Descriptor* containing_type_;
62 };
63 
AppendToList(const Descriptor * extendee,const DescriptorPool * pool,std::vector<const FieldDescriptor * > * output) const64 void ExtensionSet::AppendToList(
65     const Descriptor* extendee, const DescriptorPool* pool,
66     std::vector<const FieldDescriptor*>* output) const {
67   ForEach(
68       [extendee, pool, &output](int number, const Extension& ext) {
69         bool has = false;
70         if (ext.is_repeated) {
71           has = ext.GetSize() > 0;
72         } else {
73           has = !ext.is_cleared;
74         }
75 
76         if (has) {
77           // TODO: Looking up each field by number is somewhat
78           // unfortunate.
79           //   Is there a better way?  The problem is that descriptors are
80           //   lazily-initialized, so they might not even be constructed until
81           //   AppendToList() is called.
82 
83           if (ext.descriptor == nullptr) {
84             output->push_back(pool->FindExtensionByNumber(extendee, number));
85           } else {
86             output->push_back(ext.descriptor);
87           }
88         }
89       },
90       Prefetch{});
91 }
92 
real_type(FieldType type)93 inline FieldDescriptor::Type real_type(FieldType type) {
94   ABSL_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
95   return static_cast<FieldDescriptor::Type>(type);
96 }
97 
cpp_type(FieldType type)98 inline FieldDescriptor::CppType cpp_type(FieldType type) {
99   return FieldDescriptor::TypeToCppType(
100       static_cast<FieldDescriptor::Type>(type));
101 }
102 
field_type(FieldType type)103 inline WireFormatLite::FieldType field_type(FieldType type) {
104   ABSL_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
105   return static_cast<WireFormatLite::FieldType>(type);
106 }
107 
// Debug-only assertion that EXTENSION has the expected label and C++ type
// category. The label is derived from the extension's is_repeated flag
// (REPEATED when set, OPTIONAL otherwise). The expansion is two statements
// and leaves off the final semicolon, which each use site supplies.
#define ABSL_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                          \
  ABSL_DCHECK_EQ(FieldDescriptor::LABEL_##LABEL,                             \
                 ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED  \
                                          : FieldDescriptor::LABEL_OPTIONAL)); \
  ABSL_DCHECK_EQ(FieldDescriptor::CPPTYPE_##CPPTYPE, cpp_type((EXTENSION).type))
113 
GetMessage(int number,const Descriptor * message_type,MessageFactory * factory) const114 const MessageLite& ExtensionSet::GetMessage(int number,
115                                             const Descriptor* message_type,
116                                             MessageFactory* factory) const {
117   const Extension* extension = FindOrNull(number);
118   if (extension == nullptr || extension->is_cleared) {
119     // Not present.  Return the default value.
120     return *factory->GetPrototype(message_type);
121   } else {
122     ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
123     if (extension->is_lazy) {
124       return extension->ptr.lazymessage_value->GetMessage(
125           *factory->GetPrototype(message_type), arena_);
126     } else {
127       return *extension->ptr.message_value;
128     }
129   }
130 }
131 
MutableMessage(const FieldDescriptor * descriptor,MessageFactory * factory)132 MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
133                                           MessageFactory* factory) {
134   Extension* extension;
135   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
136     extension->type = descriptor->type();
137     ABSL_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
138     extension->is_repeated = false;
139     extension->is_pointer = true;
140     extension->is_packed = false;
141     const MessageLite* prototype =
142         factory->GetPrototype(descriptor->message_type());
143     extension->is_lazy = false;
144     extension->ptr.message_value = prototype->New(arena_);
145     extension->is_cleared = false;
146     return extension->ptr.message_value;
147   } else {
148     ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
149     extension->is_cleared = false;
150     if (extension->is_lazy) {
151       return extension->ptr.lazymessage_value->MutableMessage(
152           *factory->GetPrototype(descriptor->message_type()), arena_);
153     } else {
154       return extension->ptr.message_value;
155     }
156   }
157 }
158 
ReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)159 MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
160                                           MessageFactory* factory) {
161   Extension* extension = FindOrNull(descriptor->number());
162   if (extension == nullptr) {
163     // Not present.  Return nullptr.
164     return nullptr;
165   } else {
166     ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
167     MessageLite* ret = nullptr;
168     if (extension->is_lazy) {
169       ret = extension->ptr.lazymessage_value->ReleaseMessage(
170           *factory->GetPrototype(descriptor->message_type()), arena_);
171       if (arena_ == nullptr) {
172         delete extension->ptr.lazymessage_value;
173       }
174     } else {
175       if (arena_ != nullptr) {
176         ret = extension->ptr.message_value->New();
177         ret->CheckTypeAndMergeFrom(*extension->ptr.message_value);
178       } else {
179         ret = extension->ptr.message_value;
180       }
181     }
182     Erase(descriptor->number());
183     return ret;
184   }
185 }
186 
UnsafeArenaReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)187 MessageLite* ExtensionSet::UnsafeArenaReleaseMessage(
188     const FieldDescriptor* descriptor, MessageFactory* factory) {
189   Extension* extension = FindOrNull(descriptor->number());
190   if (extension == nullptr) {
191     // Not present.  Return nullptr.
192     return nullptr;
193   } else {
194     ABSL_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
195     MessageLite* ret = nullptr;
196     if (extension->is_lazy) {
197       ret = extension->ptr.lazymessage_value->UnsafeArenaReleaseMessage(
198           *factory->GetPrototype(descriptor->message_type()), arena_);
199       if (arena_ == nullptr) {
200         delete extension->ptr.lazymessage_value;
201       }
202     } else {
203       ret = extension->ptr.message_value;
204     }
205     Erase(descriptor->number());
206     return ret;
207   }
208 }
209 
MaybeNewRepeatedExtension(const FieldDescriptor * descriptor)210 ExtensionSet::Extension* ExtensionSet::MaybeNewRepeatedExtension(
211     const FieldDescriptor* descriptor) {
212   Extension* extension;
213   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
214     extension->type = descriptor->type();
215     ABSL_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
216     extension->is_repeated = true;
217     extension->is_pointer = true;
218     extension->ptr.repeated_message_value =
219         Arena::Create<RepeatedPtrField<MessageLite> >(arena_);
220   } else {
221     ABSL_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
222   }
223   return extension;
224 }
225 
AddMessage(const FieldDescriptor * descriptor,MessageFactory * factory)226 MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
227                                       MessageFactory* factory) {
228   Extension* extension = MaybeNewRepeatedExtension(descriptor);
229 
230   // RepeatedPtrField<Message> does not know how to Add() since it cannot
231   // allocate an abstract object, so we have to be tricky.
232   MessageLite* result =
233       reinterpret_cast<internal::RepeatedPtrFieldBase*>(
234           extension->ptr.repeated_message_value)
235           ->AddFromCleared<GenericTypeHandler<MessageLite> >();
236   if (result == nullptr) {
237     const MessageLite* prototype;
238     if (extension->ptr.repeated_message_value->empty()) {
239       prototype = factory->GetPrototype(descriptor->message_type());
240       ABSL_CHECK(prototype != nullptr);
241     } else {
242       prototype = &extension->ptr.repeated_message_value->Get(0);
243     }
244     result = prototype->New(arena_);
245     extension->ptr.repeated_message_value->AddAllocated(result);
246   }
247   return result;
248 }
249 
AddAllocatedMessage(const FieldDescriptor * descriptor,MessageLite * new_entry)250 void ExtensionSet::AddAllocatedMessage(const FieldDescriptor* descriptor,
251                                        MessageLite* new_entry) {
252   Extension* extension = MaybeNewRepeatedExtension(descriptor);
253 
254   extension->ptr.repeated_message_value->AddAllocated(new_entry);
255 }
256 
UnsafeArenaAddAllocatedMessage(const FieldDescriptor * descriptor,MessageLite * new_entry)257 void ExtensionSet::UnsafeArenaAddAllocatedMessage(
258     const FieldDescriptor* descriptor, MessageLite* new_entry) {
259   Extension* extension = MaybeNewRepeatedExtension(descriptor);
260 
261   extension->ptr.repeated_message_value->UnsafeArenaAddAllocated(new_entry);
262 }
263 
ValidateEnumUsingDescriptor(const void * arg,int number)264 static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
265   return reinterpret_cast<const EnumDescriptor*>(arg)->FindValueByNumber(
266              number) != nullptr;
267 }
268 
Find(int number,ExtensionInfo * output)269 bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
270   const FieldDescriptor* extension =
271       pool_->FindExtensionByNumber(containing_type_, number);
272   if (extension == nullptr) {
273     return false;
274   } else {
275     output->type = extension->type();
276     output->is_repeated = extension->is_repeated();
277     output->is_packed = extension->is_packed();
278     output->descriptor = extension;
279     if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
280       output->message_info.prototype =
281           factory_->GetPrototype(extension->message_type());
282       output->message_info.tc_table =
283           output->message_info.prototype->GetTcParseTable();
284       ABSL_CHECK(output->message_info.prototype != nullptr)
285           << "Extension factory's GetPrototype() returned nullptr; extension: "
286           << extension->full_name();
287 
288       if (extension->options().has_lazy()) {
289         output->is_lazy = extension->options().lazy() ? LazyAnnotation::kLazy
290                                                       : LazyAnnotation::kEager;
291       }
292     } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
293       output->enum_validity_check.func = ValidateEnumUsingDescriptor;
294       output->enum_validity_check.arg = extension->enum_type();
295     }
296 
297     return true;
298   }
299 }
300 
301 
FindExtension(int wire_type,uint32_t field,const Message * extendee,const internal::ParseContext * ctx,ExtensionInfo * extension,bool * was_packed_on_wire)302 bool ExtensionSet::FindExtension(int wire_type, uint32_t field,
303                                  const Message* extendee,
304                                  const internal::ParseContext* ctx,
305                                  ExtensionInfo* extension,
306                                  bool* was_packed_on_wire) {
307   if (ctx->data().pool == nullptr) {
308     GeneratedExtensionFinder finder(extendee);
309     if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension,
310                                           was_packed_on_wire)) {
311       return false;
312     }
313   } else {
314     DescriptorPoolExtensionFinder finder(ctx->data().pool, ctx->data().factory,
315                                          extendee->GetDescriptor());
316     if (!FindExtensionInfoFromFieldNumber(wire_type, field, &finder, extension,
317                                           was_packed_on_wire)) {
318       return false;
319     }
320   }
321   return true;
322 }
323 
324 
ParseField(uint64_t tag,const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)325 const char* ExtensionSet::ParseField(uint64_t tag, const char* ptr,
326                                      const Message* extendee,
327                                      internal::InternalMetadata* metadata,
328                                      internal::ParseContext* ctx) {
329   int number = tag >> 3;
330   bool was_packed_on_wire;
331   ExtensionInfo extension;
332   if (!FindExtension(tag & 7, number, extendee, ctx, &extension,
333                      &was_packed_on_wire)) {
334     return UnknownFieldParse(
335         tag, metadata->mutable_unknown_fields<UnknownFieldSet>(), ptr, ctx);
336   }
337   return ParseFieldWithExtensionInfo<UnknownFieldSet>(
338       number, was_packed_on_wire, extension, metadata, ptr, ctx);
339 }
340 
ParseFieldMaybeLazily(uint64_t tag,const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)341 const char* ExtensionSet::ParseFieldMaybeLazily(
342     uint64_t tag, const char* ptr, const Message* extendee,
343     internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
344   return ParseField(tag, ptr, extendee, metadata, ctx);
345 }
346 
ParseMessageSetItem(const char * ptr,const Message * extendee,internal::InternalMetadata * metadata,internal::ParseContext * ctx)347 const char* ExtensionSet::ParseMessageSetItem(
348     const char* ptr, const Message* extendee,
349     internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
350   return ParseMessageSetItemTmpl<Message, UnknownFieldSet>(ptr, extendee,
351                                                            metadata, ctx);
352 }
353 
SpaceUsedExcludingSelf() const354 int ExtensionSet::SpaceUsedExcludingSelf() const {
355   return internal::FromIntSize(SpaceUsedExcludingSelfLong());
356 }
357 
SpaceUsedExcludingSelfLong() const358 size_t ExtensionSet::SpaceUsedExcludingSelfLong() const {
359   size_t total_size =
360       (is_large() ? map_.large->size() : flat_capacity_) * sizeof(KeyValue);
361   ForEach(
362       [&total_size](int /* number */, const Extension& ext) {
363         total_size += ext.SpaceUsedExcludingSelfLong();
364       },
365       Prefetch{});
366   return total_size;
367 }
368 
RepeatedMessage_SpaceUsedExcludingSelfLong(RepeatedPtrFieldBase * field)369 inline size_t ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelfLong(
370     RepeatedPtrFieldBase* field) {
371   return field->SpaceUsedExcludingSelfLong<GenericTypeHandler<Message> >();
372 }
373 
SpaceUsedExcludingSelfLong() const374 size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const {
375   size_t total_size = 0;
376   if (is_repeated) {
377     switch (cpp_type(type)) {
378 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)                               \
379   case FieldDescriptor::CPPTYPE_##UPPERCASE:                            \
380     total_size +=                                                       \
381         sizeof(*ptr.repeated_##LOWERCASE##_value) +                     \
382         ptr.repeated_##LOWERCASE##_value->SpaceUsedExcludingSelfLong(); \
383     break
384 
385       HANDLE_TYPE(INT32, int32_t);
386       HANDLE_TYPE(INT64, int64_t);
387       HANDLE_TYPE(UINT32, uint32_t);
388       HANDLE_TYPE(UINT64, uint64_t);
389       HANDLE_TYPE(FLOAT, float);
390       HANDLE_TYPE(DOUBLE, double);
391       HANDLE_TYPE(BOOL, bool);
392       HANDLE_TYPE(ENUM, enum);
393       HANDLE_TYPE(STRING, string);
394 #undef HANDLE_TYPE
395 
396       case FieldDescriptor::CPPTYPE_MESSAGE:
397         // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
398         // but MessageLite has no SpaceUsedLong(), so we must directly call
399         // RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong() with a different
400         // type handler.
401         total_size += sizeof(*ptr.repeated_message_value) +
402                       RepeatedMessage_SpaceUsedExcludingSelfLong(
403                           reinterpret_cast<internal::RepeatedPtrFieldBase*>(
404                               ptr.repeated_message_value));
405         break;
406     }
407   } else {
408     switch (cpp_type(type)) {
409       case FieldDescriptor::CPPTYPE_STRING:
410         total_size += sizeof(*ptr.string_value) +
411                       StringSpaceUsedExcludingSelfLong(*ptr.string_value);
412         break;
413       case FieldDescriptor::CPPTYPE_MESSAGE:
414         if (is_lazy) {
415           total_size += ptr.lazymessage_value->SpaceUsedLong();
416         } else {
417           total_size +=
418               DownCastMessage<Message>(ptr.message_value)->SpaceUsedLong();
419         }
420         break;
421       default:
422         // No extra storage costs for primitive types.
423         break;
424     }
425   }
426   return total_size;
427 }
428 
SerializeMessageSetWithCachedSizesToArray(const MessageLite * extendee,uint8_t * target) const429 uint8_t* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
430     const MessageLite* extendee, uint8_t* target) const {
431   io::EpsCopyOutputStream stream(
432       target, MessageSetByteSize(),
433       io::CodedOutputStream::IsDefaultSerializationDeterministic());
434   return InternalSerializeMessageSetWithCachedSizesToArray(extendee, target,
435                                                            &stream);
436 }
437 
438 #if defined(PROTOBUF_DESCRIPTOR_WEAK_MESSAGES_ALLOWED)
ShouldRegisterAtThisTime(std::initializer_list<WeakPrototypeRef> messages,bool is_preregistration)439 bool ExtensionSet::ShouldRegisterAtThisTime(
440     std::initializer_list<WeakPrototypeRef> messages, bool is_preregistration) {
441   bool has_all = true;
442   for (auto ref : messages) {
443     has_all = has_all && GetPrototypeForWeakDescriptor(ref.table, ref.index,
444                                                        false) != nullptr;
445   }
446   return has_all == is_preregistration;
447 }
448 #endif  // PROTOBUF_DESCRIPTOR_WEAK_MESSAGES_ALLOWED
449 
450 }  // namespace internal
451 }  // namespace protobuf
452 }  // namespace google
453 
454 #include "google/protobuf/port_undef.inc"
455