1 #ifndef GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
2 #define GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
3
4 #include <cstdint>
5 #include <string>
6 #include <utility>
7
8 #include "absl/base/attributes.h"
9 #include "absl/log/absl_check.h"
10 #include "absl/strings/cord.h"
11 #include "google/protobuf/descriptor.h"
12 #include "google/protobuf/descriptor.pb.h"
13 #include "google/protobuf/descriptor_lite.h"
14 #include "google/protobuf/extension_set.h"
15 #include "google/protobuf/generated_message_reflection.h"
16 #include "google/protobuf/message.h"
17 #include "google/protobuf/port.h"
18 #include "google/protobuf/reflection.h"
19 #include "google/protobuf/reflection_visit_field_info.h"
20 #include "google/protobuf/repeated_field.h"
21 #include "google/protobuf/repeated_ptr_field.h"
22
23
24 // Must be the last include.
25 #include "google/protobuf/port_def.inc" // NOLINT
26
27 namespace google {
28 namespace protobuf {
29 namespace internal {
30
31 enum class FieldMask : uint32_t {
32 kInt32 = 1 << FieldDescriptor::CPPTYPE_INT32,
33 kInt64 = 1 << FieldDescriptor::CPPTYPE_INT64,
34 kUInt32 = 1 << FieldDescriptor::CPPTYPE_UINT32,
35 kUInt64 = 1 << FieldDescriptor::CPPTYPE_UINT64,
36 kDouble = 1 << FieldDescriptor::CPPTYPE_DOUBLE,
37 kFloat = 1 << FieldDescriptor::CPPTYPE_FLOAT,
38 kBool = 1 << FieldDescriptor::CPPTYPE_BOOL,
39 kEnum = 1 << FieldDescriptor::CPPTYPE_ENUM,
40 kString = 1 << FieldDescriptor::CPPTYPE_STRING,
41 kMessage = 1 << FieldDescriptor::CPPTYPE_MESSAGE,
42 kPrimitive =
43 kInt32 | kInt64 | kUInt32 | kUInt64 | kDouble | kFloat | kBool | kEnum,
44 kAll = 0xFFFFFFFFu,
45 };
46
47 inline FieldMask operator|(FieldMask lhs, FieldMask rhs) {
48 return static_cast<FieldMask>(static_cast<uint32_t>(lhs) |
49 static_cast<uint32_t>(rhs));
50 }
51
52 #ifdef __cpp_if_constexpr
53
// Forward declaration of the public VisitFields entry point; the definition
// appears further down in this header. "mask" defaults to FieldMask::kAll, so
// callers that omit it apply no type filtering.
template <typename MessageT, typename CallbackFn>
void VisitFields(MessageT& message, CallbackFn&& func,
                 FieldMask mask = FieldMask::kAll);
57
58 class ReflectionVisit final {
59 public:
60 template <typename MessageT, typename CallbackFn>
61 static void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask);
62
63 template <typename CallbackFn>
64 static void VisitMessageFields(const Message& message, CallbackFn&& func);
65
66 template <typename CallbackFn>
67 static void VisitMessageFields(Message& message, CallbackFn&& func);
68
69 private:
GetSchema(const Reflection * reflection)70 static const internal::ReflectionSchema& GetSchema(
71 const Reflection* reflection) {
72 return reflection->schema_;
73 }
GetDescriptor(const Reflection * reflection)74 static const Descriptor* GetDescriptor(const Reflection* reflection) {
75 return reflection->descriptor_;
76 }
ExtensionSet(const Reflection * reflection,const Message & message)77 static const internal::ExtensionSet& ExtensionSet(
78 const Reflection* reflection, const Message& message) {
79 return reflection->GetExtensionSet(message);
80 }
ExtensionSet(const Reflection * reflection,Message & message)81 static internal::ExtensionSet& ExtensionSet(const Reflection* reflection,
82 Message& message) {
83 return *reflection->MutableExtensionSet(&message);
84 }
85 };
86
ShouldVisit(FieldMask mask,FieldDescriptor::CppType cpptype)87 inline bool ShouldVisit(FieldMask mask, FieldDescriptor::CppType cpptype) {
88 if (PROTOBUF_PREDICT_TRUE(mask == FieldMask::kAll)) return true;
89 return (static_cast<uint32_t>(mask) & (1 << cpptype)) != 0;
90 }
91
92 template <typename MessageT, typename CallbackFn>
VisitFields(MessageT & message,CallbackFn && func,FieldMask mask)93 void ReflectionVisit::VisitFields(MessageT& message, CallbackFn&& func,
94 FieldMask mask) {
95 const Reflection* reflection = message.GetReflection();
96 const auto& schema = GetSchema(reflection);
97
98 ABSL_CHECK(!schema.HasWeakFields()) << "weak fields are not supported";
99
100 // See Reflection::ListFields for the optimization.
101 const uint32_t* const has_bits =
102 schema.HasHasbits() ? reflection->GetHasBits(message) : nullptr;
103 const uint32_t* const has_bits_indices = schema.has_bit_indices_;
104 const Descriptor* descriptor = GetDescriptor(reflection);
105 const int field_count = descriptor->field_count();
106
107 for (int i = 0; i < field_count; i++) {
108 const FieldDescriptor* field = descriptor->field(i);
109 ABSL_DCHECK(!field->options().weak()) << "weak fields are not supported";
110
111 if (!ShouldVisit(mask, field->cpp_type())) continue;
112
113 if (field->is_repeated()) {
114 switch (field->type()) {
115 #define PROTOBUF_HANDLE_REPEATED_CASE(TYPE, CPPTYPE, NAME) \
116 case FieldDescriptor::TYPE_##TYPE: { \
117 ABSL_DCHECK(!field->is_map()); \
118 const auto& rep = \
119 reflection->GetRawNonOneof<RepeatedField<CPPTYPE>>(message, field); \
120 if (rep.size() == 0) continue; \
121 func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{ \
122 reflection, message, field, rep}); \
123 break; \
124 }
125
126 PROTOBUF_HANDLE_REPEATED_CASE(DOUBLE, double, Double);
127 PROTOBUF_HANDLE_REPEATED_CASE(FLOAT, float, Float);
128 PROTOBUF_HANDLE_REPEATED_CASE(INT64, int64_t, Int64);
129 PROTOBUF_HANDLE_REPEATED_CASE(UINT64, uint64_t, UInt64);
130 PROTOBUF_HANDLE_REPEATED_CASE(INT32, int32_t, Int32);
131 PROTOBUF_HANDLE_REPEATED_CASE(FIXED64, uint64_t, Fixed64);
132 PROTOBUF_HANDLE_REPEATED_CASE(FIXED32, uint32_t, Fixed32);
133 PROTOBUF_HANDLE_REPEATED_CASE(BOOL, bool, Bool);
134 PROTOBUF_HANDLE_REPEATED_CASE(UINT32, uint32_t, UInt32);
135 PROTOBUF_HANDLE_REPEATED_CASE(ENUM, int, Enum);
136 PROTOBUF_HANDLE_REPEATED_CASE(SFIXED32, int32_t, SFixed32);
137 PROTOBUF_HANDLE_REPEATED_CASE(SFIXED64, int64_t, SFixed64);
138 PROTOBUF_HANDLE_REPEATED_CASE(SINT32, int32_t, SInt32);
139 PROTOBUF_HANDLE_REPEATED_CASE(SINT64, int64_t, SInt64);
140
141 #define PROTOBUF_HANDLE_REPEATED_PTR_CASE(TYPE, CPPTYPE, NAME) \
142 case FieldDescriptor::TYPE_##TYPE: { \
143 if (PROTOBUF_PREDICT_TRUE(!field->is_map())) { \
144 /* Handle repeated fields. */ \
145 const auto& rep = reflection->GetRawNonOneof<RepeatedPtrField<CPPTYPE>>( \
146 message, field); \
147 if (rep.size() == 0) continue; \
148 func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{ \
149 reflection, message, field, rep}); \
150 } else { \
151 /* Handle map fields. */ \
152 const auto& map = \
153 reflection->GetRawNonOneof<MapFieldBase>(message, field); \
154 if (map.size() == 0) continue; /* NOLINT */ \
155 const Descriptor* desc = field->message_type(); \
156 func(internal::MapDynamicFieldInfo<MessageT>{reflection, message, field, \
157 desc->map_key(), \
158 desc->map_value(), map}); \
159 } \
160 break; \
161 }
162
163 PROTOBUF_HANDLE_REPEATED_PTR_CASE(MESSAGE, Message, Message);
164 PROTOBUF_HANDLE_REPEATED_PTR_CASE(GROUP, Message, Group);
165
166 case FieldDescriptor::TYPE_BYTES:
167 case FieldDescriptor::TYPE_STRING:
168 #define PROTOBUF_IMPL_STRING_CASE(CPPTYPE, NAME) \
169 { \
170 const auto& rep = \
171 reflection->GetRawNonOneof<RepeatedPtrField<CPPTYPE>>(message, field); \
172 if (rep.size() == 0) continue; \
173 func(internal::Repeated##NAME##DynamicFieldInfo<MessageT>{ \
174 reflection, message, field, rep}); \
175 }
176
177 switch (cpp::EffectiveStringCType(field)) {
178 default:
179 case FieldOptions::STRING:
180 PROTOBUF_IMPL_STRING_CASE(std::string, String);
181 break;
182 }
183 break;
184 default:
185 internal::Unreachable();
186 break;
187 }
188 #undef PROTOBUF_HANDLE_REPEATED_CASE
189 #undef PROTOBUF_HANDLE_REPEATED_PTR_CASE
190 #undef PROTOBUF_IMPL_STRING_CASE
191 } else if (schema.InRealOneof(field)) {
192 const OneofDescriptor* containing_oneof = field->containing_oneof();
193 const uint32_t* const oneof_case_array =
194 internal::GetConstPointerAtOffset<uint32_t>(
195 &message, schema.oneof_case_offset_);
196 // Equivalent to: !HasOneofField(message, field)
197 if (static_cast<int64_t>(oneof_case_array[containing_oneof->index()]) !=
198 field->number()) {
199 continue;
200 }
201 switch (field->type()) {
202 #define PROTOBUF_HANDLE_CASE(TYPE, NAME) \
203 case FieldDescriptor::TYPE_##TYPE: \
204 func(internal::NAME##DynamicFieldInfo<MessageT, true>{reflection, message, \
205 field}); \
206 break;
207 PROTOBUF_HANDLE_CASE(DOUBLE, Double);
208 PROTOBUF_HANDLE_CASE(FLOAT, Float);
209 PROTOBUF_HANDLE_CASE(INT64, Int64);
210 PROTOBUF_HANDLE_CASE(UINT64, UInt64);
211 PROTOBUF_HANDLE_CASE(INT32, Int32);
212 PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
213 PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
214 PROTOBUF_HANDLE_CASE(BOOL, Bool);
215 PROTOBUF_HANDLE_CASE(UINT32, UInt32);
216 PROTOBUF_HANDLE_CASE(ENUM, Enum);
217 PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
218 PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
219 PROTOBUF_HANDLE_CASE(SINT32, SInt32);
220 PROTOBUF_HANDLE_CASE(SINT64, SInt64);
221
222 case FieldDescriptor::TYPE_MESSAGE:
223 case FieldDescriptor::TYPE_GROUP:
224 func(internal::MessageDynamicFieldInfo<MessageT, true>{
225 reflection, message, field});
226 break;
227
228 case FieldDescriptor::TYPE_BYTES:
229 case FieldDescriptor::TYPE_STRING: {
230 auto ctype = cpp::EffectiveStringCType(field);
231 if (ctype == FieldOptions::CORD) {
232 func(CordDynamicFieldInfo<MessageT, true>{reflection, message,
233 field});
234 } else {
235 func(StringDynamicFieldInfo<MessageT, true>{reflection, message,
236 field});
237 }
238 break;
239 }
240 default:
241 internal::Unreachable();
242 break;
243 #undef PROTOBUF_HANDLE_CASE
244 }
245 } else {
246 auto index = has_bits_indices[i];
247 bool check_hasbits = has_bits && index != static_cast<uint32_t>(-1);
248 if (PROTOBUF_PREDICT_TRUE(check_hasbits)) {
249 if ((has_bits[index / 32] & (1u << (index % 32))) == 0) continue;
250 } else {
251 // Skip if it has default values.
252 if (!reflection->HasFieldSingular(message, field)) continue;
253 }
254 switch (field->type()) {
255 #define PROTOBUF_HANDLE_CASE(TYPE, NAME) \
256 case FieldDescriptor::TYPE_##TYPE: \
257 func(internal::NAME##DynamicFieldInfo<MessageT, false>{reflection, \
258 message, field}); \
259 break;
260 PROTOBUF_HANDLE_CASE(DOUBLE, Double);
261 PROTOBUF_HANDLE_CASE(FLOAT, Float);
262 PROTOBUF_HANDLE_CASE(INT64, Int64);
263 PROTOBUF_HANDLE_CASE(UINT64, UInt64);
264 PROTOBUF_HANDLE_CASE(INT32, Int32);
265 PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
266 PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
267 PROTOBUF_HANDLE_CASE(BOOL, Bool);
268 PROTOBUF_HANDLE_CASE(UINT32, UInt32);
269 PROTOBUF_HANDLE_CASE(ENUM, Enum);
270 PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
271 PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
272 PROTOBUF_HANDLE_CASE(SINT32, SInt32);
273 PROTOBUF_HANDLE_CASE(SINT64, SInt64);
274
275 case FieldDescriptor::TYPE_MESSAGE:
276 case FieldDescriptor::TYPE_GROUP:
277 func(internal::MessageDynamicFieldInfo<MessageT, false>{
278 reflection, message, field});
279 break;
280 case FieldDescriptor::TYPE_BYTES:
281 case FieldDescriptor::TYPE_STRING: {
282 auto ctype = cpp::EffectiveStringCType(field);
283 if (ctype == FieldOptions::CORD) {
284 func(CordDynamicFieldInfo<MessageT, false>{reflection, message,
285 field});
286 } else {
287 func(StringDynamicFieldInfo<MessageT, false>{reflection, message,
288 field});
289 }
290 break;
291 }
292 default:
293 internal::Unreachable();
294 break;
295 #undef PROTOBUF_HANDLE_CASE
296 }
297 }
298 }
299
300 if (!schema.HasExtensionSet()) return;
301
302 auto& set = ExtensionSet(reflection, message);
303 auto* extendee = reflection->descriptor_;
304 auto* pool = reflection->descriptor_pool_;
305
306 set.ForEach(
307 [&](int number, auto& ext) {
308 ABSL_DCHECK_GT(ext.type, 0);
309 ABSL_DCHECK_LE(ext.type, FieldDescriptor::MAX_TYPE);
310
311 if (!ShouldVisit(mask,
312 FieldDescriptor::TypeToCppType(
313 static_cast<FieldDescriptor::Type>(ext.type)))) {
314 return;
315 }
316
317 if (ext.is_repeated) {
318 if (ext.GetSize() == 0) return;
319
320 switch (ext.type) {
321 #define PROTOBUF_HANDLE_CASE(TYPE, NAME) \
322 case FieldDescriptor::TYPE_##TYPE: \
323 func(internal::Repeated##NAME##DynamicExtensionInfo<decltype(ext)>{ \
324 ext, number}); \
325 break;
326 PROTOBUF_HANDLE_CASE(DOUBLE, Double);
327 PROTOBUF_HANDLE_CASE(FLOAT, Float);
328 PROTOBUF_HANDLE_CASE(INT64, Int64);
329 PROTOBUF_HANDLE_CASE(UINT64, UInt64);
330 PROTOBUF_HANDLE_CASE(INT32, Int32);
331 PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
332 PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
333 PROTOBUF_HANDLE_CASE(BOOL, Bool);
334 PROTOBUF_HANDLE_CASE(UINT32, UInt32);
335 PROTOBUF_HANDLE_CASE(ENUM, Enum);
336 PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
337 PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
338 PROTOBUF_HANDLE_CASE(SINT32, SInt32);
339 PROTOBUF_HANDLE_CASE(SINT64, SInt64);
340
341 PROTOBUF_HANDLE_CASE(MESSAGE, Message);
342 PROTOBUF_HANDLE_CASE(GROUP, Group);
343
344 case FieldDescriptor::TYPE_BYTES:
345 case FieldDescriptor::TYPE_STRING:
346 func(internal::RepeatedStringDynamicExtensionInfo<decltype(ext)>{
347 ext, number});
348 break;
349 default:
350 internal::Unreachable();
351 break;
352 #undef PROTOBUF_HANDLE_CASE
353 }
354 } else {
355 if (ext.is_cleared) return;
356
357 switch (ext.type) {
358 #define PROTOBUF_HANDLE_CASE(TYPE, NAME) \
359 case FieldDescriptor::TYPE_##TYPE: \
360 func(internal::NAME##DynamicExtensionInfo<decltype(ext)>{ext, number}); \
361 break;
362 PROTOBUF_HANDLE_CASE(DOUBLE, Double);
363 PROTOBUF_HANDLE_CASE(FLOAT, Float);
364 PROTOBUF_HANDLE_CASE(INT64, Int64);
365 PROTOBUF_HANDLE_CASE(UINT64, UInt64);
366 PROTOBUF_HANDLE_CASE(INT32, Int32);
367 PROTOBUF_HANDLE_CASE(FIXED64, Fixed64);
368 PROTOBUF_HANDLE_CASE(FIXED32, Fixed32);
369 PROTOBUF_HANDLE_CASE(BOOL, Bool);
370 PROTOBUF_HANDLE_CASE(UINT32, UInt32);
371 PROTOBUF_HANDLE_CASE(ENUM, Enum);
372 PROTOBUF_HANDLE_CASE(SFIXED32, SFixed32);
373 PROTOBUF_HANDLE_CASE(SFIXED64, SFixed64);
374 PROTOBUF_HANDLE_CASE(SINT32, SInt32);
375 PROTOBUF_HANDLE_CASE(SINT64, SInt64);
376
377 PROTOBUF_HANDLE_CASE(GROUP, Group);
378 case FieldDescriptor::TYPE_MESSAGE: {
379 const FieldDescriptor* field =
380 ext.descriptor != nullptr
381 ? ext.descriptor
382 : pool->FindExtensionByNumber(extendee, number);
383 ABSL_DCHECK_EQ(field->number(), number);
384 bool is_mset =
385 field->containing_type()->options().message_set_wire_format();
386 func(internal::MessageDynamicExtensionInfo<decltype(ext)>{
387 ext, number, is_mset});
388 break;
389 }
390
391 case FieldDescriptor::TYPE_BYTES:
392 case FieldDescriptor::TYPE_STRING:
393 func(internal::StringDynamicExtensionInfo<decltype(ext)>{ext,
394 number});
395 break;
396
397 default:
398 internal::Unreachable();
399 break;
400 #undef PROTOBUF_HANDLE_CASE
401 }
402 }
403 },
404 ExtensionSet::Prefetch{});
405 }
406
407 template <typename CallbackFn>
VisitMessageFields(const Message & message,CallbackFn && func)408 void ReflectionVisit::VisitMessageFields(const Message& message,
409 CallbackFn&& func) {
410 ReflectionVisit::VisitFields(
411 message,
412 [&](auto info) {
413 if constexpr (info.is_map) {
414 auto value_type = info.value_type();
415 if (value_type != FieldDescriptor::TYPE_MESSAGE &&
416 value_type != FieldDescriptor::TYPE_GROUP) {
417 return;
418 }
419 info.VisitElements([&](auto key, auto val) {
420 if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
421 func(val.Get());
422 }
423 });
424 } else if constexpr (info.cpp_type ==
425 FieldDescriptor::CPPTYPE_MESSAGE) {
426 if constexpr (info.is_repeated) {
427 for (const auto& it : info.Get()) {
428 func(DownCastMessage<Message>(it));
429 }
430 } else {
431 func(info.Get());
432 }
433 }
434 },
435 FieldMask::kMessage);
436 }
437
438 template <typename CallbackFn>
VisitMessageFields(Message & message,CallbackFn && func)439 void ReflectionVisit::VisitMessageFields(Message& message, CallbackFn&& func) {
440 ReflectionVisit::VisitFields(
441 message,
442 [&](auto info) {
443 if constexpr (info.is_map) {
444 auto value_type = info.value_type();
445 if (value_type != FieldDescriptor::TYPE_MESSAGE &&
446 value_type != FieldDescriptor::TYPE_GROUP) {
447 return;
448 }
449 info.VisitElements([&](auto key, auto val) {
450 if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
451 func(*val.Mutable());
452 }
453 });
454 } else if constexpr (info.cpp_type ==
455 FieldDescriptor::CPPTYPE_MESSAGE) {
456 if constexpr (info.is_repeated) {
457 for (auto& it : info.Mutable()) {
458 func(DownCastMessage<Message>(it));
459 }
460 } else {
461 func(info.Mutable());
462 }
463 }
464 },
465 FieldMask::kMessage);
466 }
467
// Visits present fields of "message" and calls the callback function "func"
// for each of them. Skips fields whose C++ type (FieldDescriptor::CppType) is
// not selected by "mask".
//
// This is a thin wrapper that forwards to ReflectionVisit::VisitFields. The
// default value of "mask" (FieldMask::kAll) comes from the forward declaration
// earlier in this header.
template <typename MessageT, typename CallbackFn>
void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask) {
  ReflectionVisit::VisitFields(message, std::forward<CallbackFn>(func), mask);
}
474
// Visits message fields of "message" and calls "func". Expects "func" to
// accept const Message&. Note the following divergence from VisitFields.
//
// --Each of N elements of a repeated message field is visited (total N).
// --Each of M elements of a map field whose value type is message is visited
//   (total M).
// --A map field whose value type is not message is ignored.
//
// This is a helper API built on top of VisitFields to hide specifics about
// extensions, repeated fields, etc.
template <typename CallbackFn>
void VisitMessageFields(const Message& message, CallbackFn&& func) {
  ReflectionVisit::VisitMessageFields(message, std::forward<CallbackFn>(func));
}
489
// Same as VisitMessageFields above but expects "func" to accept Message&. This
// is useful when mutable access is required. As mutable access can be
// expensive, use it only if it's necessary.
//
// Forwards to the ReflectionVisit::VisitMessageFields overload that takes a
// non-const Message&.
template <typename CallbackFn>
void VisitMutableMessageFields(Message& message, CallbackFn&& func) {
  ReflectionVisit::VisitMessageFields(message, std::forward<CallbackFn>(func));
}
497
498 #endif // __cpp_if_constexpr
499
500 } // namespace internal
501 } // namespace protobuf
502 } // namespace google
503
504 #include "google/protobuf/port_undef.inc"
505
506 #endif // GOOGLE_PROTOBUF_REFLECTION_VISIT_FIELDS_H__
507