• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
2 #define GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
3 
4 #include <cstdint>
5 #include <string>
6 #include <utility>
7 
8 #include "absl/base/attributes.h"
9 #include "absl/log/absl_check.h"
10 #include "absl/strings/cord.h"
11 #include "google/protobuf/descriptor.h"
12 #include "google/protobuf/descriptor.pb.h"
13 #include "google/protobuf/descriptor_lite.h"
14 #include "google/protobuf/extension_set.h"
15 #include "google/protobuf/generated_message_reflection.h"
16 #include "google/protobuf/message.h"
17 #include "google/protobuf/port.h"
18 #include "google/protobuf/reflection.h"
19 #include "google/protobuf/reflection_visit_field_info.h"
20 #include "google/protobuf/repeated_field.h"
21 #include "google/protobuf/repeated_ptr_field.h"
22 
23 
24 // Must be the last include.
25 #include "google/protobuf/port_def.inc"  // NOLINT
26 
27 namespace google {
28 namespace protobuf {
29 namespace internal {
30 
31 enum class FieldMask : uint32_t {
32   kInt32 = 1 << FieldDescriptor::CPPTYPE_INT32,
33   kInt64 = 1 << FieldDescriptor::CPPTYPE_INT64,
34   kUInt32 = 1 << FieldDescriptor::CPPTYPE_UINT32,
35   kUInt64 = 1 << FieldDescriptor::CPPTYPE_UINT64,
36   kDouble = 1 << FieldDescriptor::CPPTYPE_DOUBLE,
37   kFloat = 1 << FieldDescriptor::CPPTYPE_FLOAT,
38   kBool = 1 << FieldDescriptor::CPPTYPE_BOOL,
39   kEnum = 1 << FieldDescriptor::CPPTYPE_ENUM,
40   kString = 1 << FieldDescriptor::CPPTYPE_STRING,
41   kMessage = 1 << FieldDescriptor::CPPTYPE_MESSAGE,
42   kPrimitive =
43       kInt32 | kInt64 | kUInt32 | kUInt64 | kDouble | kFloat | kBool | kEnum,
44   kAll = 0xFFFFFFFFu,
45 };
46 
47 inline FieldMask operator|(FieldMask lhs, FieldMask rhs) {
48   return static_cast<FieldMask>(static_cast<uint32_t>(lhs) |
49                                 static_cast<uint32_t>(rhs));
50 }
51 
52 #ifdef __cpp_if_constexpr
53 
54 template <typename MessageT, typename CallbackFn>
55 void VisitFields(MessageT& message, CallbackFn&& func,
56                  FieldMask mask = FieldMask::kAll);
57 
58 class ReflectionVisit final {
59  public:
60   template <typename MessageT, typename CallbackFn>
61   static void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask);
62 
63   template <typename CallbackFn>
64   static void VisitMessageFields(const Message& message, CallbackFn&& func);
65 
66   template <typename CallbackFn>
67   static void VisitMessageFields(Message& message, CallbackFn&& func);
68 
69  private:
GetSchema(const Reflection * reflection)70   static const internal::ReflectionSchema& GetSchema(
71       const Reflection* reflection) {
72     return reflection->schema_;
73   }
GetDescriptor(const Reflection * reflection)74   static const Descriptor* GetDescriptor(const Reflection* reflection) {
75     return reflection->descriptor_;
76   }
ExtensionSet(const Reflection * reflection,const Message & message)77   static const internal::ExtensionSet& ExtensionSet(
78       const Reflection* reflection, const Message& message) {
79     return reflection->GetExtensionSet(message);
80   }
ExtensionSet(const Reflection * reflection,Message & message)81   static internal::ExtensionSet& ExtensionSet(const Reflection* reflection,
82                                               Message& message) {
83     return *reflection->MutableExtensionSet(&message);
84   }
85 };
86 
ShouldVisit(FieldMask mask,FieldDescriptor::CppType cpptype)87 inline bool ShouldVisit(FieldMask mask, FieldDescriptor::CppType cpptype) {
88   if (PROTOBUF_PREDICT_TRUE(mask == FieldMask::kAll)) return true;
89   return (static_cast<uint32_t>(mask) & (1 << cpptype)) != 0;
90 }
91 
92 template <typename MessageT, typename CallbackFn>
VisitFields(MessageT & message,CallbackFn && func,FieldMask mask)93 void ReflectionVisit::VisitFields(MessageT& message, CallbackFn&& func,
94                                   FieldMask mask) {
95   const Reflection* reflection = message.GetReflection();
96   const auto& schema = GetSchema(reflection);
97 
98   ABSL_CHECK(!schema.HasWeakFields()) << "weak fields are not supported";
99 
100   // See Reflection::ListFields for the optimization.
101   const uint32_t* const has_bits =
102       schema.HasHasbits() ? reflection->GetHasBits(message) : nullptr;
103   const uint32_t* const has_bits_indices = schema.has_bit_indices_;
104   const Descriptor* descriptor = GetDescriptor(reflection);
105   const int field_count = descriptor->field_count();
106 
107   for (int i = 0; i < field_count; i++) {
108     const FieldDescriptor* field = descriptor->field(i);
109     ABSL_DCHECK(!field->options().weak()) << "weak fields are not supported";
110 
111     if (!ShouldVisit(mask, field->cpp_type())) continue;
112 
113     if (field->is_repeated()) {
114       switch (field->type()) {
115 #define PROTOBUF_HANDLE_REPEATED_CASE(TYPE, CPPTYPE, NAME)                  \
116   case FieldDescriptor::TYPE_##TYPE: {                                      \
117     ABSL_DCHECK(!field->is_map());                                          \
118     const auto& rep =                                                       \
119         reflection->GetRawNonOneof<RepeatedField<CPPTYPE>>(message, field); \
120     if (rep.size() == 0) continue;                                          \
121     func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{              \
122         reflection, message, field, rep});                                  \
123     break;                                                                  \
124   }
125 
126         PROTOBUF_HANDLE_REPEATED_CASE(DOUBLE, double, Double);
127         PROTOBUF_HANDLE_REPEATED_CASE(FLOAT, float, Float);
128         PROTOBUF_HANDLE_REPEATED_CASE(INT64, int64_t, Int64);
129         PROTOBUF_HANDLE_REPEATED_CASE(UINT64, uint64_t, UInt64);
130         PROTOBUF_HANDLE_REPEATED_CASE(INT32, int32_t, Int32);
131         PROTOBUF_HANDLE_REPEATED_CASE(FIXED64, uint64_t, Fixed64);
132         PROTOBUF_HANDLE_REPEATED_CASE(FIXED32, uint32_t, Fixed32);
133         PROTOBUF_HANDLE_REPEATED_CASE(BOOL, bool, Bool);
134         PROTOBUF_HANDLE_REPEATED_CASE(UINT32, uint32_t, UInt32);
135         PROTOBUF_HANDLE_REPEATED_CASE(ENUM, int, Enum);
136         PROTOBUF_HANDLE_REPEATED_CASE(SFIXED32, int32_t, SFixed32);
137         PROTOBUF_HANDLE_REPEATED_CASE(SFIXED64, int64_t, SFixed64);
138         PROTOBUF_HANDLE_REPEATED_CASE(SINT32, int32_t, SInt32);
139         PROTOBUF_HANDLE_REPEATED_CASE(SINT64, int64_t, SInt64);
140 
141 #define PROTOBUF_HANDLE_REPEATED_PTR_CASE(TYPE, CPPTYPE, NAME)                 \
142   case FieldDescriptor::TYPE_##TYPE: {                                         \
143     if (PROTOBUF_PREDICT_TRUE(!field->is_map())) {                             \
144       /* Handle repeated fields. */                                            \
145       const auto& rep = reflection->GetRawNonOneof<RepeatedPtrField<CPPTYPE>>( \
146           message, field);                                                     \
147       if (rep.size() == 0) continue;                                           \
148       func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{               \
149           reflection, message, field, rep});                                   \
150     } else {                                                                   \
151       /* Handle map fields. */                                                 \
152       const auto& map =                                                        \
153           reflection->GetRawNonOneof<MapFieldBase>(message, field);            \
154       if (map.size() == 0) continue; /* NOLINT */                              \
155       const Descriptor* desc = field->message_type();                          \
156       func(internal::MapDynamicFieldInfo<MessageT>{reflection, message, field, \
157                                                    desc->map_key(),            \
158                                                    desc->map_value(), map});   \
159     }                                                                          \
160     break;                                                                     \
161   }
162 
163         PROTOBUF_HANDLE_REPEATED_PTR_CASE(MESSAGE, Message, Message);
164         PROTOBUF_HANDLE_REPEATED_PTR_CASE(GROUP, Message, Group);
165 
166         case FieldDescriptor::TYPE_BYTES:
167         case FieldDescriptor::TYPE_STRING:
168 #define PROTOBUF_IMPL_STRING_CASE(CPPTYPE, NAME)                               \
169   {                                                                            \
170     const auto& rep =                                                          \
171         reflection->GetRawNonOneof<RepeatedPtrField<CPPTYPE>>(message, field); \
172     if (rep.size() == 0) continue;                                             \
173     func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{                 \
174         reflection, message, field, rep});                                     \
175   }
176 
177           switch (cpp::EffectiveStringCType(field)) {
178             default:
179             case FieldOptions::STRING:
180               PROTOBUF_IMPL_STRING_CASE(std::string, String);
181               break;
182           }
183           break;
184         default:
185           internal::Unreachable();
186           break;
187       }
188 #undef PROTOBUF_HANDLE_REPEATED_CASE
189 #undef PROTOBUF_HANDLE_REPEATED_PTR_CASE
190 #undef PROTOBUF_IMPL_STRING_CASE
191     } else if (schema.InRealOneof(field)) {
192       const OneofDescriptor* containing_oneof = field->containing_oneof();
193       const uint32_t* const oneof_case_array =
194           internal::GetConstPointerAtOffset<uint32_t>(
195               &message, schema.oneof_case_offset_);
196       // Equivalent to: !HasOneofField(message, field)
197       if (static_cast<int64_t>(oneof_case_array[containing_oneof->index()]) !=
198           field->number()) {
199         continue;
200       }
201       switch (field->type()) {
202 #define PROTOBUF_HANDLE_CASE(TYPE, NAME)                                       \
203   case FieldDescriptor::TYPE_##TYPE:                                           \
204     func(internal::NAME##DynamicFieldInfo<MessageT, true>{reflection, message, \
205                                                           field});             \
206     break;
207         PROTOBUF_HANDLE_CASE(DOUBLE, Double);
208         PROTOBUF_HANDLE_CASE(FLOAT, Float);
209         PROTOBUF_HANDLE_CASE(INT64, Int64);
210         PROTOBUF_HANDLE_CASE(UINT64, UInt64);
211         PROTOBUF_HANDLE_CASE(INT32, Int32);
212         PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
213         PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
214         PROTOBUF_HANDLE_CASE(BOOL, Bool);
215         PROTOBUF_HANDLE_CASE(UINT32, UInt32);
216         PROTOBUF_HANDLE_CASE(ENUM, Enum);
217         PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
218         PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
219         PROTOBUF_HANDLE_CASE(SINT32, SInt32);
220         PROTOBUF_HANDLE_CASE(SINT64, SInt64);
221 
222         case FieldDescriptor::TYPE_MESSAGE:
223         case FieldDescriptor::TYPE_GROUP:
224           func(internal::MessageDynamicFieldInfo<MessageT, true>{
225               reflection, message, field});
226           break;
227 
228         case FieldDescriptor::TYPE_BYTES:
229         case FieldDescriptor::TYPE_STRING: {
230           auto ctype = cpp::EffectiveStringCType(field);
231           if (ctype == FieldOptions::CORD) {
232             func(CordDynamicFieldInfo<MessageT, true>{reflection, message,
233                                                       field});
234           } else {
235             func(StringDynamicFieldInfo<MessageT, true>{reflection, message,
236                                                         field});
237           }
238           break;
239         }
240         default:
241           internal::Unreachable();
242           break;
243 #undef PROTOBUF_HANDLE_CASE
244       }
245     } else {
246       auto index = has_bits_indices[i];
247       bool check_hasbits = has_bits && index != static_cast<uint32_t>(-1);
248       if (PROTOBUF_PREDICT_TRUE(check_hasbits)) {
249         if ((has_bits[index / 32] & (1u << (index % 32))) == 0) continue;
250       } else {
251         // Skip if it has default values.
252         if (!reflection->HasFieldSingular(message, field)) continue;
253       }
254       switch (field->type()) {
255 #define PROTOBUF_HANDLE_CASE(TYPE, NAME)                                     \
256   case FieldDescriptor::TYPE_##TYPE:                                         \
257     func(internal::NAME##DynamicFieldInfo<MessageT, false>{reflection,       \
258                                                            message, field}); \
259     break;
260         PROTOBUF_HANDLE_CASE(DOUBLE, Double);
261         PROTOBUF_HANDLE_CASE(FLOAT, Float);
262         PROTOBUF_HANDLE_CASE(INT64, Int64);
263         PROTOBUF_HANDLE_CASE(UINT64, UInt64);
264         PROTOBUF_HANDLE_CASE(INT32, Int32);
265         PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
266         PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
267         PROTOBUF_HANDLE_CASE(BOOL, Bool);
268         PROTOBUF_HANDLE_CASE(UINT32, UInt32);
269         PROTOBUF_HANDLE_CASE(ENUM, Enum);
270         PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
271         PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
272         PROTOBUF_HANDLE_CASE(SINT32, SInt32);
273         PROTOBUF_HANDLE_CASE(SINT64, SInt64);
274 
275         case FieldDescriptor::TYPE_MESSAGE:
276         case FieldDescriptor::TYPE_GROUP:
277           func(internal::MessageDynamicFieldInfo<MessageT, false>{
278               reflection, message, field});
279           break;
280         case FieldDescriptor::TYPE_BYTES:
281         case FieldDescriptor::TYPE_STRING: {
282           auto ctype = cpp::EffectiveStringCType(field);
283           if (ctype == FieldOptions::CORD) {
284             func(CordDynamicFieldInfo<MessageT, false>{reflection, message,
285                                                        field});
286           } else {
287             func(StringDynamicFieldInfo<MessageT, false>{reflection, message,
288                                                          field});
289           }
290           break;
291         }
292         default:
293           internal::Unreachable();
294           break;
295 #undef PROTOBUF_HANDLE_CASE
296       }
297     }
298   }
299 
300   if (!schema.HasExtensionSet()) return;
301 
302   auto& set = ExtensionSet(reflection, message);
303   auto* extendee = reflection->descriptor_;
304   auto* pool = reflection->descriptor_pool_;
305 
306   set.ForEach(
307       [&](int number, auto& ext) {
308         ABSL_DCHECK_GT(ext.type, 0);
309         ABSL_DCHECK_LE(ext.type, FieldDescriptor::MAX_TYPE);
310 
311         if (!ShouldVisit(mask,
312                          FieldDescriptor::TypeToCppType(
313                              static_cast<FieldDescriptor::Type>(ext.type)))) {
314           return;
315         }
316 
317         if (ext.is_repeated) {
318           if (ext.GetSize() == 0) return;
319 
320           switch (ext.type) {
321 #define PROTOBUF_HANDLE_CASE(TYPE, NAME)                                \
322   case FieldDescriptor::TYPE_##TYPE:                                    \
323     func(internal::Repeated##NAME##DynamicExtensionInfo<decltype(ext)>{ \
324         ext, number});                                                  \
325     break;
326             PROTOBUF_HANDLE_CASE(DOUBLE, Double);
327             PROTOBUF_HANDLE_CASE(FLOAT, Float);
328             PROTOBUF_HANDLE_CASE(INT64, Int64);
329             PROTOBUF_HANDLE_CASE(UINT64, UInt64);
330             PROTOBUF_HANDLE_CASE(INT32, Int32);
331             PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
332             PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
333             PROTOBUF_HANDLE_CASE(BOOL, Bool);
334             PROTOBUF_HANDLE_CASE(UINT32, UInt32);
335             PROTOBUF_HANDLE_CASE(ENUM, Enum);
336             PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
337             PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
338             PROTOBUF_HANDLE_CASE(SINT32, SInt32);
339             PROTOBUF_HANDLE_CASE(SINT64, SInt64);
340 
341             PROTOBUF_HANDLE_CASE(MESSAGE, Message);
342             PROTOBUF_HANDLE_CASE(GROUP, Group);
343 
344             case FieldDescriptor::TYPE_BYTES:
345             case FieldDescriptor::TYPE_STRING:
346               func(internal::RepeatedStringDynamicExtensionInfo<decltype(ext)>{
347                   ext, number});
348               break;
349             default:
350               internal::Unreachable();
351               break;
352 #undef PROTOBUF_HANDLE_CASE
353           }
354         } else {
355           if (ext.is_cleared) return;
356 
357           switch (ext.type) {
358 #define PROTOBUF_HANDLE_CASE(TYPE, NAME)                                    \
359   case FieldDescriptor::TYPE_##TYPE:                                        \
360     func(internal::NAME##DynamicExtensionInfo<decltype(ext)>{ext, number}); \
361     break;
362             PROTOBUF_HANDLE_CASE(DOUBLE, Double);
363             PROTOBUF_HANDLE_CASE(FLOAT, Float);
364             PROTOBUF_HANDLE_CASE(INT64, Int64);
365             PROTOBUF_HANDLE_CASE(UINT64, UInt64);
366             PROTOBUF_HANDLE_CASE(INT32, Int32);
367             PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
368             PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
369             PROTOBUF_HANDLE_CASE(BOOL, Bool);
370             PROTOBUF_HANDLE_CASE(UINT32, UInt32);
371             PROTOBUF_HANDLE_CASE(ENUM, Enum);
372             PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
373             PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
374             PROTOBUF_HANDLE_CASE(SINT32, SInt32);
375             PROTOBUF_HANDLE_CASE(SINT64, SInt64);
376 
377             PROTOBUF_HANDLE_CASE(GROUP, Group);
378             case FieldDescriptor::TYPE_MESSAGE: {
379               const FieldDescriptor* field =
380                   ext.descriptor != nullptr
381                       ? ext.descriptor
382                       : pool->FindExtensionByNumber(extendee, number);
383               ABSL_DCHECK_EQ(field->number(), number);
384               bool is_mset =
385                   field->containing_type()->options().message_set_wire_format();
386               func(internal::MessageDynamicExtensionInfo<decltype(ext)>{
387                   ext, number, is_mset});
388               break;
389             }
390 
391             case FieldDescriptor::TYPE_BYTES:
392             case FieldDescriptor::TYPE_STRING:
393               func(internal::StringDynamicExtensionInfo<decltype(ext)>{ext,
394                                                                        number});
395               break;
396 
397             default:
398               internal::Unreachable();
399               break;
400 #undef PROTOBUF_HANDLE_CASE
401           }
402         }
403       },
404       ExtensionSet::Prefetch{});
405 }
406 
407 template <typename CallbackFn>
VisitMessageFields(const Message & message,CallbackFn && func)408 void ReflectionVisit::VisitMessageFields(const Message& message,
409                                          CallbackFn&& func) {
410   ReflectionVisit::VisitFields(
411       message,
412       [&](auto info) {
413         if constexpr (info.is_map) {
414           auto value_type = info.value_type();
415           if (value_type != FieldDescriptor::TYPE_MESSAGE &&
416               value_type != FieldDescriptor::TYPE_GROUP) {
417             return;
418           }
419           info.VisitElements([&](auto key, auto val) {
420             if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
421               func(val.Get());
422             }
423           });
424         } else if constexpr (info.cpp_type ==
425                              FieldDescriptor::CPPTYPE_MESSAGE) {
426           if constexpr (info.is_repeated) {
427             for (const auto& it : info.Get()) {
428               func(DownCastMessage<Message>(it));
429             }
430           } else {
431             func(info.Get());
432           }
433         }
434       },
435       FieldMask::kMessage);
436 }
437 
438 template <typename CallbackFn>
VisitMessageFields(Message & message,CallbackFn && func)439 void ReflectionVisit::VisitMessageFields(Message& message, CallbackFn&& func) {
440   ReflectionVisit::VisitFields(
441       message,
442       [&](auto info) {
443         if constexpr (info.is_map) {
444           auto value_type = info.value_type();
445           if (value_type != FieldDescriptor::TYPE_MESSAGE &&
446               value_type != FieldDescriptor::TYPE_GROUP) {
447             return;
448           }
449           info.VisitElements([&](auto key, auto val) {
450             if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
451               func(*val.Mutable());
452             }
453           });
454         } else if constexpr (info.cpp_type ==
455                              FieldDescriptor::CPPTYPE_MESSAGE) {
456           if constexpr (info.is_repeated) {
457             for (auto& it : info.Mutable()) {
458               func(DownCastMessage<Message>(it));
459             }
460           } else {
461             func(info.Mutable());
462           }
463         }
464       },
465       FieldMask::kMessage);
466 }
467 
468 // Visits present fields of "message" and calls the callback function "func".
469 // Skips fields whose ctypes are missing in "mask".
470 template <typename MessageT, typename CallbackFn>
VisitFields(MessageT & message,CallbackFn && func,FieldMask mask)471 void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask) {
472   ReflectionVisit::VisitFields(message, std::forward<CallbackFn>(func), mask);
473 }
474 
475 // Visits message fields of "message" and calls "func". Expects "func" to
476 // accept const Message&. Note the following divergence from VisitFields.
477 //
478 // --Each of N elements of a repeated message field is visited (total N).
479 // --Each of M elements of a map field whose value type is message are visited
480 //   (total M).
481 // --A map field whose value type is not message is ignored.
482 //
483 // This is a helper API built on top of VisitFields to hide specifics about
484 // extensions, repeated fields, etc.
485 template <typename CallbackFn>
VisitMessageFields(const Message & message,CallbackFn && func)486 void VisitMessageFields(const Message& message, CallbackFn&& func) {
487   ReflectionVisit::VisitMessageFields(message, std::forward<CallbackFn>(func));
488 }
489 
490 // Same as VisitMessageFields above but expects "func" to accept Message&. This
491 // is useful when mutable access is required. As mutable access can be
492 // expensive, use it only if it's necessary.
493 template <typename CallbackFn>
VisitMutableMessageFields(Message & message,CallbackFn && func)494 void VisitMutableMessageFields(Message& message, CallbackFn&& func) {
495   ReflectionVisit::VisitMessageFields(message, std::forward<CallbackFn>(func));
496 }
497 
498 #endif  // __cpp_if_constexpr
499 
500 }  // namespace internal
501 }  // namespace protobuf
502 }  // namespace google
503 
504 #include "google/protobuf/port_undef.inc"
505 
506 #endif  // GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
507