• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: kenton@google.com (Kenton Varda)
9 //  Based on original Protocol Buffers design by
10 //  Sanjay Ghemawat, Jeff Dean, and others.
11 #include "google/protobuf/reflection_ops.h"
12 
13 #include <string>
14 #include <vector>
15 
16 #include "absl/log/absl_check.h"
17 #include "absl/log/absl_log.h"
18 #include "absl/strings/str_cat.h"
19 #include "google/protobuf/descriptor.h"
20 #include "google/protobuf/descriptor.pb.h"
21 #include "google/protobuf/map_field.h"
22 #include "google/protobuf/map_field_inl.h"
23 #include "google/protobuf/unknown_field_set.h"
24 
25 // Must be included last.
26 #include "google/protobuf/port_def.inc"
27 
28 namespace google {
29 namespace protobuf {
30 namespace internal {
31 
GetReflectionOrDie(const Message & m)32 static const Reflection* GetReflectionOrDie(const Message& m) {
33   const Reflection* r = m.GetReflection();
34   if (r == nullptr) {
35     const Descriptor* d = m.GetDescriptor();
36     // RawMessage is one known type for which GetReflection() returns nullptr.
37     ABSL_LOG(FATAL) << "Message does not support reflection (type "
38                     << (d ? d->name() : "unknown") << ").";
39   }
40   return r;
41 }
42 
Copy(const Message & from,Message * to)43 void ReflectionOps::Copy(const Message& from, Message* to) {
44   if (&from == to) return;
45   Clear(to);
46   Merge(from, to);
47 }
48 
Merge(const Message & from,Message * to)49 void ReflectionOps::Merge(const Message& from, Message* to) {
50   ABSL_CHECK_NE(&from, to);
51 
52   const Descriptor* descriptor = from.GetDescriptor();
53   ABSL_CHECK_EQ(to->GetDescriptor(), descriptor)
54       << "Tried to merge messages of different types "
55       << "(merge " << descriptor->full_name() << " to "
56       << to->GetDescriptor()->full_name() << ")";
57 
58   const Reflection* from_reflection = GetReflectionOrDie(from);
59   const Reflection* to_reflection = GetReflectionOrDie(*to);
60   bool is_from_generated = (from_reflection->GetMessageFactory() ==
61                             google::protobuf::MessageFactory::generated_factory());
62   bool is_to_generated = (to_reflection->GetMessageFactory() ==
63                           google::protobuf::MessageFactory::generated_factory());
64 
65   std::vector<const FieldDescriptor*> fields;
66   from_reflection->ListFields(from, &fields);
67   for (const FieldDescriptor* field : fields) {
68     if (field->is_repeated()) {
69       // Use map reflection if both are in map status and have the
70       // same map type to avoid sync with repeated field.
71       // Note: As from and to messages have the same descriptor, the
72       // map field types are the same if they are both generated
73       // messages or both dynamic messages.
74       if (is_from_generated == is_to_generated && field->is_map()) {
75         const MapFieldBase* from_field =
76             from_reflection->GetMapData(from, field);
77         MapFieldBase* to_field = to_reflection->MutableMapData(to, field);
78         if (to_field->IsMapValid() && from_field->IsMapValid()) {
79           to_field->MergeFrom(*from_field);
80           continue;
81         }
82       }
83       int count = from_reflection->FieldSize(from, field);
84       for (int j = 0; j < count; j++) {
85         switch (field->cpp_type()) {
86 #define HANDLE_TYPE(CPPTYPE, METHOD)                                      \
87   case FieldDescriptor::CPPTYPE_##CPPTYPE:                                \
88     to_reflection->Add##METHOD(                                           \
89         to, field, from_reflection->GetRepeated##METHOD(from, field, j)); \
90     break;
91 
92           HANDLE_TYPE(INT32, Int32);
93           HANDLE_TYPE(INT64, Int64);
94           HANDLE_TYPE(UINT32, UInt32);
95           HANDLE_TYPE(UINT64, UInt64);
96           HANDLE_TYPE(FLOAT, Float);
97           HANDLE_TYPE(DOUBLE, Double);
98           HANDLE_TYPE(BOOL, Bool);
99           HANDLE_TYPE(STRING, String);
100           HANDLE_TYPE(ENUM, Enum);
101 #undef HANDLE_TYPE
102 
103           case FieldDescriptor::CPPTYPE_MESSAGE:
104             const Message& from_child =
105                 from_reflection->GetRepeatedMessage(from, field, j);
106             if (from_reflection == to_reflection) {
107               to_reflection
108                   ->AddMessage(to, field,
109                                from_child.GetReflection()->GetMessageFactory())
110                   ->MergeFrom(from_child);
111             } else {
112               to_reflection->AddMessage(to, field)->MergeFrom(from_child);
113             }
114             break;
115         }
116       }
117     } else {
118       switch (field->cpp_type()) {
119 #define HANDLE_TYPE(CPPTYPE, METHOD)                                       \
120   case FieldDescriptor::CPPTYPE_##CPPTYPE:                                 \
121     to_reflection->Set##METHOD(to, field,                                  \
122                                from_reflection->Get##METHOD(from, field)); \
123     break;
124 
125         HANDLE_TYPE(INT32, Int32);
126         HANDLE_TYPE(INT64, Int64);
127         HANDLE_TYPE(UINT32, UInt32);
128         HANDLE_TYPE(UINT64, UInt64);
129         HANDLE_TYPE(FLOAT, Float);
130         HANDLE_TYPE(DOUBLE, Double);
131         HANDLE_TYPE(BOOL, Bool);
132         HANDLE_TYPE(STRING, String);
133         HANDLE_TYPE(ENUM, Enum);
134 #undef HANDLE_TYPE
135 
136         case FieldDescriptor::CPPTYPE_MESSAGE:
137           const Message& from_child = from_reflection->GetMessage(from, field);
138           if (from_reflection == to_reflection) {
139             to_reflection
140                 ->MutableMessage(
141                     to, field, from_child.GetReflection()->GetMessageFactory())
142                 ->MergeFrom(from_child);
143           } else {
144             to_reflection->MutableMessage(to, field)->MergeFrom(from_child);
145           }
146           break;
147       }
148     }
149   }
150 
151   if (!from_reflection->GetUnknownFields(from).empty()) {
152     to_reflection->MutableUnknownFields(to)->MergeFrom(
153         from_reflection->GetUnknownFields(from));
154   }
155 }
156 
Clear(Message * message)157 void ReflectionOps::Clear(Message* message) {
158   const Reflection* reflection = GetReflectionOrDie(*message);
159 
160   std::vector<const FieldDescriptor*> fields;
161   reflection->ListFields(*message, &fields);
162   for (const FieldDescriptor* field : fields) {
163     reflection->ClearField(message, field);
164   }
165 
166   if (reflection->GetInternalMetadata(*message).have_unknown_fields()) {
167     reflection->MutableUnknownFields(message)->Clear();
168   }
169 }
170 
IsInitialized(const Message & message,bool check_fields,bool check_descendants)171 bool ReflectionOps::IsInitialized(const Message& message, bool check_fields,
172                                   bool check_descendants) {
173   const Descriptor* descriptor = message.GetDescriptor();
174   const Reflection* reflection = GetReflectionOrDie(message);
175   if (const int field_count = descriptor->field_count()) {
176     const FieldDescriptor* begin = descriptor->field(0);
177     const FieldDescriptor* end = begin + field_count;
178     ABSL_DCHECK_EQ(descriptor->field(field_count - 1), end - 1);
179 
180     if (check_fields) {
181       // Check required fields of this message.
182       for (const FieldDescriptor* field = begin; field != end; ++field) {
183         if (field->is_required() && !reflection->HasField(message, field)) {
184           return false;
185         }
186       }
187     }
188 
189     if (check_descendants) {
190       for (const FieldDescriptor* field = begin; field != end; ++field) {
191         if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
192           const Descriptor* message_type = field->message_type();
193           if (PROTOBUF_PREDICT_FALSE(message_type->options().map_entry())) {
194             if (message_type->field(1)->cpp_type() ==
195                 FieldDescriptor::CPPTYPE_MESSAGE) {
196               const MapFieldBase* map_field =
197                   reflection->GetMapData(message, field);
198               if (map_field->IsMapValid()) {
199                 MapIterator it(const_cast<Message*>(&message), field);
200                 MapIterator end_map(const_cast<Message*>(&message), field);
201                 for (map_field->MapBegin(&it), map_field->MapEnd(&end_map);
202                      it != end_map; ++it) {
203                   if (!it.GetValueRef().GetMessageValue().IsInitialized()) {
204                     return false;
205                   }
206                 }
207               }
208             }
209           } else if (field->is_repeated()) {
210             const int size = reflection->FieldSize(message, field);
211             for (int j = 0; j < size; j++) {
212               if (!reflection->GetRepeatedMessage(message, field, j)
213                        .IsInitialized()) {
214                 return false;
215               }
216             }
217           } else if (reflection->HasField(message, field)) {
218             if (!reflection->GetMessage(message, field).IsInitialized()) {
219               return false;
220             }
221           }
222         }
223       }
224     }
225   }
226   if (check_descendants && reflection->HasExtensionSet(message)) {
227     // Note that "extendee" is only referenced if the extension is lazily parsed
228     // (e.g. LazyMessageExtensionImpl), which requires a verification function
229     // to be generated.
230     //
231     // Dynamic messages would get null prototype from the generated message
232     // factory but their verification functions are not generated. Therefore, it
233     // it will always be eagerly parsed and "extendee" here will not be
234     // referenced.
235     const Message* extendee =
236         MessageFactory::generated_factory()->GetPrototype(descriptor);
237     if (!reflection->GetExtensionSet(message).IsInitialized(extendee)) {
238       return false;
239     }
240   }
241   return true;
242 }
243 
IsInitialized(const Message & message)244 bool ReflectionOps::IsInitialized(const Message& message) {
245   const Descriptor* descriptor = message.GetDescriptor();
246   const Reflection* reflection = GetReflectionOrDie(message);
247 
248   // Check required fields of this message.
249   {
250     const int field_count = descriptor->field_count();
251     for (int i = 0; i < field_count; i++) {
252       if (descriptor->field(i)->is_required()) {
253         if (!reflection->HasField(message, descriptor->field(i))) {
254           return false;
255         }
256       }
257     }
258   }
259 
260   // Check that sub-messages are initialized.
261   std::vector<const FieldDescriptor*> fields;
262   // Should be safe to skip stripped fields because required fields are not
263   // stripped.
264   if (descriptor->options().map_entry()) {
265     // MapEntry objects always check the value regardless of has bit.
266     // We don't need to bother with the key.
267     fields = {descriptor->map_value()};
268   } else {
269     reflection->ListFields(message, &fields);
270   }
271   for (const FieldDescriptor* field : fields) {
272     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
273 
274       if (field->is_map()) {
275         const FieldDescriptor* value_field = field->message_type()->field(1);
276         if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
277           const MapFieldBase* map_field =
278               reflection->GetMapData(message, field);
279           if (map_field->IsMapValid()) {
280             MapIterator iter(const_cast<Message*>(&message), field);
281             MapIterator end(const_cast<Message*>(&message), field);
282             for (map_field->MapBegin(&iter), map_field->MapEnd(&end);
283                  iter != end; ++iter) {
284               if (!iter.GetValueRef().GetMessageValue().IsInitialized()) {
285                 return false;
286               }
287             }
288             continue;
289           }
290         } else {
291           continue;
292         }
293       }
294 
295       if (field->is_repeated()) {
296         int size = reflection->FieldSize(message, field);
297 
298         for (int j = 0; j < size; j++) {
299           if (!reflection->GetRepeatedMessage(message, field, j)
300                    .IsInitialized()) {
301             return false;
302           }
303         }
304       } else {
305         if (!reflection->GetMessage(message, field).IsInitialized()) {
306           return false;
307         }
308       }
309     }
310   }
311 
312   return true;
313 }
314 
IsMapValueMessageTyped(const FieldDescriptor * map_field)315 static bool IsMapValueMessageTyped(const FieldDescriptor* map_field) {
316   return map_field->message_type()->field(1)->cpp_type() ==
317          FieldDescriptor::CPPTYPE_MESSAGE;
318 }
319 
DiscardUnknownFields(Message * message)320 void ReflectionOps::DiscardUnknownFields(Message* message) {
321   const Reflection* reflection = GetReflectionOrDie(*message);
322 
323   reflection->MutableUnknownFields(message)->Clear();
324 
325   // Walk through the fields of this message and DiscardUnknownFields on any
326   // messages present.
327   std::vector<const FieldDescriptor*> fields;
328   reflection->ListFields(*message, &fields);
329   for (const FieldDescriptor* field : fields) {
330     // Skip over non-message fields.
331     if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
332       continue;
333     }
334     // Discard the unknown fields in maps that contain message values.
335     const MapFieldBase* map_field =
336         field->is_map() ? reflection->MutableMapData(message, field) : nullptr;
337     if (map_field != nullptr && map_field->IsMapValid()) {
338       if (IsMapValueMessageTyped(field)) {
339         MapIterator iter(message, field);
340         MapIterator end(message, field);
341         for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end;
342              ++iter) {
343           iter.MutableValueRef()->MutableMessageValue()->DiscardUnknownFields();
344         }
345       }
346       // Discard every unknown field inside messages in a repeated field.
347     } else if (field->is_repeated()) {
348       int size = reflection->FieldSize(*message, field);
349       for (int j = 0; j < size; j++) {
350         reflection->MutableRepeatedMessage(message, field, j)
351             ->DiscardUnknownFields();
352       }
353       // Discard the unknown fields inside an optional message.
354     } else {
355       reflection->MutableMessage(message, field)->DiscardUnknownFields();
356     }
357   }
358 }
359 
SubMessagePrefix(const std::string & prefix,const FieldDescriptor * field,int index)360 static std::string SubMessagePrefix(const std::string& prefix,
361                                     const FieldDescriptor* field, int index) {
362   std::string result(prefix);
363   if (field->is_extension()) {
364     result.append("(");
365     result.append(field->full_name());
366     result.append(")");
367   } else {
368     result.append(field->name());
369   }
370   if (index != -1) {
371     result.append("[");
372     result.append(absl::StrCat(index));
373     result.append("]");
374   }
375   result.append(".");
376   return result;
377 }
378 
FindInitializationErrors(const Message & message,const std::string & prefix,std::vector<std::string> * errors)379 void ReflectionOps::FindInitializationErrors(const Message& message,
380                                              const std::string& prefix,
381                                              std::vector<std::string>* errors) {
382   const Descriptor* descriptor = message.GetDescriptor();
383   const Reflection* reflection = GetReflectionOrDie(message);
384 
385   // Check required fields of this message.
386   {
387     const int field_count = descriptor->field_count();
388     for (int i = 0; i < field_count; i++) {
389       if (descriptor->field(i)->is_required()) {
390         if (!reflection->HasField(message, descriptor->field(i))) {
391           errors->push_back(absl::StrCat(prefix, descriptor->field(i)->name()));
392         }
393       }
394     }
395   }
396 
397   // Check sub-messages.
398   std::vector<const FieldDescriptor*> fields;
399   reflection->ListFields(message, &fields);
400   for (const FieldDescriptor* field : fields) {
401     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
402 
403       if (field->is_repeated()) {
404         int size = reflection->FieldSize(message, field);
405 
406         for (int j = 0; j < size; j++) {
407           const Message& sub_message =
408               reflection->GetRepeatedMessage(message, field, j);
409           FindInitializationErrors(sub_message,
410                                    SubMessagePrefix(prefix, field, j), errors);
411         }
412       } else {
413         const Message& sub_message = reflection->GetMessage(message, field);
414         FindInitializationErrors(sub_message,
415                                  SubMessagePrefix(prefix, field, -1), errors);
416       }
417     }
418   }
419 }
420 
GenericSwap(Message * lhs,Message * rhs)421 void GenericSwap(Message* lhs, Message* rhs) {
422   if (!internal::DebugHardenForceCopyInSwap()) {
423     ABSL_DCHECK(lhs->GetArena() != rhs->GetArena());
424     ABSL_DCHECK(lhs->GetArena() != nullptr || rhs->GetArena() != nullptr);
425   }
426   // At least one of these must have an arena, so make `rhs` point to it.
427   Arena* arena = rhs->GetArena();
428   if (arena == nullptr) {
429     std::swap(lhs, rhs);
430     arena = rhs->GetArena();
431   }
432 
433   // Improve efficiency by placing the temporary on an arena so that messages
434   // are copied twice rather than three times.
435   Message* tmp = rhs->New(arena);
436   tmp->CheckTypeAndMergeFrom(*lhs);
437   lhs->Clear();
438   lhs->CheckTypeAndMergeFrom(*rhs);
439   if (internal::DebugHardenForceCopyInSwap()) {
440     rhs->Clear();
441     rhs->CheckTypeAndMergeFrom(*tmp);
442     if (arena == nullptr) delete tmp;
443   } else {
444     rhs->GetReflection()->Swap(tmp, rhs);
445   }
446 }
447 
448 }  // namespace internal
449 }  // namespace protobuf
450 }  // namespace google
451 
452 #include "google/protobuf/port_undef.inc"
453