1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: kenton@google.com (Kenton Varda)
9 // Based on original Protocol Buffers design by
10 // Sanjay Ghemawat, Jeff Dean, and others.
11 #include "google/protobuf/reflection_ops.h"
12
13 #include <string>
14 #include <vector>
15
16 #include "absl/log/absl_check.h"
17 #include "absl/log/absl_log.h"
18 #include "absl/strings/str_cat.h"
19 #include "google/protobuf/descriptor.h"
20 #include "google/protobuf/descriptor.pb.h"
21 #include "google/protobuf/map_field.h"
22 #include "google/protobuf/map_field_inl.h"
23 #include "google/protobuf/unknown_field_set.h"
24
25 // Must be included last.
26 #include "google/protobuf/port_def.inc"
27
28 namespace google {
29 namespace protobuf {
30 namespace internal {
31
GetReflectionOrDie(const Message & m)32 static const Reflection* GetReflectionOrDie(const Message& m) {
33 const Reflection* r = m.GetReflection();
34 if (r == nullptr) {
35 const Descriptor* d = m.GetDescriptor();
36 // RawMessage is one known type for which GetReflection() returns nullptr.
37 ABSL_LOG(FATAL) << "Message does not support reflection (type "
38 << (d ? d->name() : "unknown") << ").";
39 }
40 return r;
41 }
42
Copy(const Message & from,Message * to)43 void ReflectionOps::Copy(const Message& from, Message* to) {
44 if (&from == to) return;
45 Clear(to);
46 Merge(from, to);
47 }
48
Merge(const Message & from,Message * to)49 void ReflectionOps::Merge(const Message& from, Message* to) {
50 ABSL_CHECK_NE(&from, to);
51
52 const Descriptor* descriptor = from.GetDescriptor();
53 ABSL_CHECK_EQ(to->GetDescriptor(), descriptor)
54 << "Tried to merge messages of different types "
55 << "(merge " << descriptor->full_name() << " to "
56 << to->GetDescriptor()->full_name() << ")";
57
58 const Reflection* from_reflection = GetReflectionOrDie(from);
59 const Reflection* to_reflection = GetReflectionOrDie(*to);
60 bool is_from_generated = (from_reflection->GetMessageFactory() ==
61 google::protobuf::MessageFactory::generated_factory());
62 bool is_to_generated = (to_reflection->GetMessageFactory() ==
63 google::protobuf::MessageFactory::generated_factory());
64
65 std::vector<const FieldDescriptor*> fields;
66 from_reflection->ListFields(from, &fields);
67 for (const FieldDescriptor* field : fields) {
68 if (field->is_repeated()) {
69 // Use map reflection if both are in map status and have the
70 // same map type to avoid sync with repeated field.
71 // Note: As from and to messages have the same descriptor, the
72 // map field types are the same if they are both generated
73 // messages or both dynamic messages.
74 if (is_from_generated == is_to_generated && field->is_map()) {
75 const MapFieldBase* from_field =
76 from_reflection->GetMapData(from, field);
77 MapFieldBase* to_field = to_reflection->MutableMapData(to, field);
78 if (to_field->IsMapValid() && from_field->IsMapValid()) {
79 to_field->MergeFrom(*from_field);
80 continue;
81 }
82 }
83 int count = from_reflection->FieldSize(from, field);
84 for (int j = 0; j < count; j++) {
85 switch (field->cpp_type()) {
86 #define HANDLE_TYPE(CPPTYPE, METHOD) \
87 case FieldDescriptor::CPPTYPE_##CPPTYPE: \
88 to_reflection->Add##METHOD( \
89 to, field, from_reflection->GetRepeated##METHOD(from, field, j)); \
90 break;
91
92 HANDLE_TYPE(INT32, Int32);
93 HANDLE_TYPE(INT64, Int64);
94 HANDLE_TYPE(UINT32, UInt32);
95 HANDLE_TYPE(UINT64, UInt64);
96 HANDLE_TYPE(FLOAT, Float);
97 HANDLE_TYPE(DOUBLE, Double);
98 HANDLE_TYPE(BOOL, Bool);
99 HANDLE_TYPE(STRING, String);
100 HANDLE_TYPE(ENUM, Enum);
101 #undef HANDLE_TYPE
102
103 case FieldDescriptor::CPPTYPE_MESSAGE:
104 const Message& from_child =
105 from_reflection->GetRepeatedMessage(from, field, j);
106 if (from_reflection == to_reflection) {
107 to_reflection
108 ->AddMessage(to, field,
109 from_child.GetReflection()->GetMessageFactory())
110 ->MergeFrom(from_child);
111 } else {
112 to_reflection->AddMessage(to, field)->MergeFrom(from_child);
113 }
114 break;
115 }
116 }
117 } else {
118 switch (field->cpp_type()) {
119 #define HANDLE_TYPE(CPPTYPE, METHOD) \
120 case FieldDescriptor::CPPTYPE_##CPPTYPE: \
121 to_reflection->Set##METHOD(to, field, \
122 from_reflection->Get##METHOD(from, field)); \
123 break;
124
125 HANDLE_TYPE(INT32, Int32);
126 HANDLE_TYPE(INT64, Int64);
127 HANDLE_TYPE(UINT32, UInt32);
128 HANDLE_TYPE(UINT64, UInt64);
129 HANDLE_TYPE(FLOAT, Float);
130 HANDLE_TYPE(DOUBLE, Double);
131 HANDLE_TYPE(BOOL, Bool);
132 HANDLE_TYPE(STRING, String);
133 HANDLE_TYPE(ENUM, Enum);
134 #undef HANDLE_TYPE
135
136 case FieldDescriptor::CPPTYPE_MESSAGE:
137 const Message& from_child = from_reflection->GetMessage(from, field);
138 if (from_reflection == to_reflection) {
139 to_reflection
140 ->MutableMessage(
141 to, field, from_child.GetReflection()->GetMessageFactory())
142 ->MergeFrom(from_child);
143 } else {
144 to_reflection->MutableMessage(to, field)->MergeFrom(from_child);
145 }
146 break;
147 }
148 }
149 }
150
151 if (!from_reflection->GetUnknownFields(from).empty()) {
152 to_reflection->MutableUnknownFields(to)->MergeFrom(
153 from_reflection->GetUnknownFields(from));
154 }
155 }
156
Clear(Message * message)157 void ReflectionOps::Clear(Message* message) {
158 const Reflection* reflection = GetReflectionOrDie(*message);
159
160 std::vector<const FieldDescriptor*> fields;
161 reflection->ListFields(*message, &fields);
162 for (const FieldDescriptor* field : fields) {
163 reflection->ClearField(message, field);
164 }
165
166 if (reflection->GetInternalMetadata(*message).have_unknown_fields()) {
167 reflection->MutableUnknownFields(message)->Clear();
168 }
169 }
170
IsInitialized(const Message & message,bool check_fields,bool check_descendants)171 bool ReflectionOps::IsInitialized(const Message& message, bool check_fields,
172 bool check_descendants) {
173 const Descriptor* descriptor = message.GetDescriptor();
174 const Reflection* reflection = GetReflectionOrDie(message);
175 if (const int field_count = descriptor->field_count()) {
176 const FieldDescriptor* begin = descriptor->field(0);
177 const FieldDescriptor* end = begin + field_count;
178 ABSL_DCHECK_EQ(descriptor->field(field_count - 1), end - 1);
179
180 if (check_fields) {
181 // Check required fields of this message.
182 for (const FieldDescriptor* field = begin; field != end; ++field) {
183 if (field->is_required() && !reflection->HasField(message, field)) {
184 return false;
185 }
186 }
187 }
188
189 if (check_descendants) {
190 for (const FieldDescriptor* field = begin; field != end; ++field) {
191 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
192 const Descriptor* message_type = field->message_type();
193 if (PROTOBUF_PREDICT_FALSE(message_type->options().map_entry())) {
194 if (message_type->field(1)->cpp_type() ==
195 FieldDescriptor::CPPTYPE_MESSAGE) {
196 const MapFieldBase* map_field =
197 reflection->GetMapData(message, field);
198 if (map_field->IsMapValid()) {
199 MapIterator it(const_cast<Message*>(&message), field);
200 MapIterator end_map(const_cast<Message*>(&message), field);
201 for (map_field->MapBegin(&it), map_field->MapEnd(&end_map);
202 it != end_map; ++it) {
203 if (!it.GetValueRef().GetMessageValue().IsInitialized()) {
204 return false;
205 }
206 }
207 }
208 }
209 } else if (field->is_repeated()) {
210 const int size = reflection->FieldSize(message, field);
211 for (int j = 0; j < size; j++) {
212 if (!reflection->GetRepeatedMessage(message, field, j)
213 .IsInitialized()) {
214 return false;
215 }
216 }
217 } else if (reflection->HasField(message, field)) {
218 if (!reflection->GetMessage(message, field).IsInitialized()) {
219 return false;
220 }
221 }
222 }
223 }
224 }
225 }
226 if (check_descendants && reflection->HasExtensionSet(message)) {
227 // Note that "extendee" is only referenced if the extension is lazily parsed
228 // (e.g. LazyMessageExtensionImpl), which requires a verification function
229 // to be generated.
230 //
231 // Dynamic messages would get null prototype from the generated message
232 // factory but their verification functions are not generated. Therefore, it
233 // it will always be eagerly parsed and "extendee" here will not be
234 // referenced.
235 const Message* extendee =
236 MessageFactory::generated_factory()->GetPrototype(descriptor);
237 if (!reflection->GetExtensionSet(message).IsInitialized(extendee)) {
238 return false;
239 }
240 }
241 return true;
242 }
243
IsInitialized(const Message & message)244 bool ReflectionOps::IsInitialized(const Message& message) {
245 const Descriptor* descriptor = message.GetDescriptor();
246 const Reflection* reflection = GetReflectionOrDie(message);
247
248 // Check required fields of this message.
249 {
250 const int field_count = descriptor->field_count();
251 for (int i = 0; i < field_count; i++) {
252 if (descriptor->field(i)->is_required()) {
253 if (!reflection->HasField(message, descriptor->field(i))) {
254 return false;
255 }
256 }
257 }
258 }
259
260 // Check that sub-messages are initialized.
261 std::vector<const FieldDescriptor*> fields;
262 // Should be safe to skip stripped fields because required fields are not
263 // stripped.
264 if (descriptor->options().map_entry()) {
265 // MapEntry objects always check the value regardless of has bit.
266 // We don't need to bother with the key.
267 fields = {descriptor->map_value()};
268 } else {
269 reflection->ListFields(message, &fields);
270 }
271 for (const FieldDescriptor* field : fields) {
272 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
273
274 if (field->is_map()) {
275 const FieldDescriptor* value_field = field->message_type()->field(1);
276 if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
277 const MapFieldBase* map_field =
278 reflection->GetMapData(message, field);
279 if (map_field->IsMapValid()) {
280 MapIterator iter(const_cast<Message*>(&message), field);
281 MapIterator end(const_cast<Message*>(&message), field);
282 for (map_field->MapBegin(&iter), map_field->MapEnd(&end);
283 iter != end; ++iter) {
284 if (!iter.GetValueRef().GetMessageValue().IsInitialized()) {
285 return false;
286 }
287 }
288 continue;
289 }
290 } else {
291 continue;
292 }
293 }
294
295 if (field->is_repeated()) {
296 int size = reflection->FieldSize(message, field);
297
298 for (int j = 0; j < size; j++) {
299 if (!reflection->GetRepeatedMessage(message, field, j)
300 .IsInitialized()) {
301 return false;
302 }
303 }
304 } else {
305 if (!reflection->GetMessage(message, field).IsInitialized()) {
306 return false;
307 }
308 }
309 }
310 }
311
312 return true;
313 }
314
IsMapValueMessageTyped(const FieldDescriptor * map_field)315 static bool IsMapValueMessageTyped(const FieldDescriptor* map_field) {
316 return map_field->message_type()->field(1)->cpp_type() ==
317 FieldDescriptor::CPPTYPE_MESSAGE;
318 }
319
DiscardUnknownFields(Message * message)320 void ReflectionOps::DiscardUnknownFields(Message* message) {
321 const Reflection* reflection = GetReflectionOrDie(*message);
322
323 reflection->MutableUnknownFields(message)->Clear();
324
325 // Walk through the fields of this message and DiscardUnknownFields on any
326 // messages present.
327 std::vector<const FieldDescriptor*> fields;
328 reflection->ListFields(*message, &fields);
329 for (const FieldDescriptor* field : fields) {
330 // Skip over non-message fields.
331 if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
332 continue;
333 }
334 // Discard the unknown fields in maps that contain message values.
335 const MapFieldBase* map_field =
336 field->is_map() ? reflection->MutableMapData(message, field) : nullptr;
337 if (map_field != nullptr && map_field->IsMapValid()) {
338 if (IsMapValueMessageTyped(field)) {
339 MapIterator iter(message, field);
340 MapIterator end(message, field);
341 for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end;
342 ++iter) {
343 iter.MutableValueRef()->MutableMessageValue()->DiscardUnknownFields();
344 }
345 }
346 // Discard every unknown field inside messages in a repeated field.
347 } else if (field->is_repeated()) {
348 int size = reflection->FieldSize(*message, field);
349 for (int j = 0; j < size; j++) {
350 reflection->MutableRepeatedMessage(message, field, j)
351 ->DiscardUnknownFields();
352 }
353 // Discard the unknown fields inside an optional message.
354 } else {
355 reflection->MutableMessage(message, field)->DiscardUnknownFields();
356 }
357 }
358 }
359
SubMessagePrefix(const std::string & prefix,const FieldDescriptor * field,int index)360 static std::string SubMessagePrefix(const std::string& prefix,
361 const FieldDescriptor* field, int index) {
362 std::string result(prefix);
363 if (field->is_extension()) {
364 result.append("(");
365 result.append(field->full_name());
366 result.append(")");
367 } else {
368 result.append(field->name());
369 }
370 if (index != -1) {
371 result.append("[");
372 result.append(absl::StrCat(index));
373 result.append("]");
374 }
375 result.append(".");
376 return result;
377 }
378
FindInitializationErrors(const Message & message,const std::string & prefix,std::vector<std::string> * errors)379 void ReflectionOps::FindInitializationErrors(const Message& message,
380 const std::string& prefix,
381 std::vector<std::string>* errors) {
382 const Descriptor* descriptor = message.GetDescriptor();
383 const Reflection* reflection = GetReflectionOrDie(message);
384
385 // Check required fields of this message.
386 {
387 const int field_count = descriptor->field_count();
388 for (int i = 0; i < field_count; i++) {
389 if (descriptor->field(i)->is_required()) {
390 if (!reflection->HasField(message, descriptor->field(i))) {
391 errors->push_back(absl::StrCat(prefix, descriptor->field(i)->name()));
392 }
393 }
394 }
395 }
396
397 // Check sub-messages.
398 std::vector<const FieldDescriptor*> fields;
399 reflection->ListFields(message, &fields);
400 for (const FieldDescriptor* field : fields) {
401 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
402
403 if (field->is_repeated()) {
404 int size = reflection->FieldSize(message, field);
405
406 for (int j = 0; j < size; j++) {
407 const Message& sub_message =
408 reflection->GetRepeatedMessage(message, field, j);
409 FindInitializationErrors(sub_message,
410 SubMessagePrefix(prefix, field, j), errors);
411 }
412 } else {
413 const Message& sub_message = reflection->GetMessage(message, field);
414 FindInitializationErrors(sub_message,
415 SubMessagePrefix(prefix, field, -1), errors);
416 }
417 }
418 }
419 }
420
GenericSwap(Message * lhs,Message * rhs)421 void GenericSwap(Message* lhs, Message* rhs) {
422 if (!internal::DebugHardenForceCopyInSwap()) {
423 ABSL_DCHECK(lhs->GetArena() != rhs->GetArena());
424 ABSL_DCHECK(lhs->GetArena() != nullptr || rhs->GetArena() != nullptr);
425 }
426 // At least one of these must have an arena, so make `rhs` point to it.
427 Arena* arena = rhs->GetArena();
428 if (arena == nullptr) {
429 std::swap(lhs, rhs);
430 arena = rhs->GetArena();
431 }
432
433 // Improve efficiency by placing the temporary on an arena so that messages
434 // are copied twice rather than three times.
435 Message* tmp = rhs->New(arena);
436 tmp->CheckTypeAndMergeFrom(*lhs);
437 lhs->Clear();
438 lhs->CheckTypeAndMergeFrom(*rhs);
439 if (internal::DebugHardenForceCopyInSwap()) {
440 rhs->Clear();
441 rhs->CheckTypeAndMergeFrom(*tmp);
442 if (arena == nullptr) delete tmp;
443 } else {
444 rhs->GetReflection()->Swap(tmp, rhs);
445 }
446 }
447
448 } // namespace internal
449 } // namespace protobuf
450 } // namespace google
451
452 #include "google/protobuf/port_undef.inc"
453