• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: anuraag@google.com (Anuraag Agrawal)
9 // Author: tibell@google.com (Johan Tibell)
10 
11 #include "google/protobuf/pyext/message.h"
12 
13 #include <Python.h>
14 #include <structmember.h>  // A Python header file.
15 
16 #include <cstdint>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/log/absl_check.h"
24 #include "absl/strings/match.h"
25 
26 #ifndef PyVarObject_HEAD_INIT
27 #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
28 #endif
29 #ifndef Py_TYPE
30 #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
31 #endif
32 #include "google/protobuf/stubs/common.h"
33 #include "google/protobuf/descriptor.pb.h"
34 #include "absl/strings/escaping.h"
35 #include "absl/strings/string_view.h"
36 #include "google/protobuf/descriptor.h"
37 #include "google/protobuf/io/coded_stream.h"
38 #include "google/protobuf/io/strtod.h"
39 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
40 #include "google/protobuf/map_field.h"
41 #include "google/protobuf/message.h"
42 #include "google/protobuf/text_format.h"
43 #include "google/protobuf/unknown_field_set.h"
44 #include "google/protobuf/util/message_differencer.h"
45 #include "google/protobuf/pyext/descriptor.h"
46 #include "google/protobuf/pyext/descriptor_pool.h"
47 #include "google/protobuf/pyext/extension_dict.h"
48 #include "google/protobuf/pyext/field.h"
49 #include "google/protobuf/pyext/map_container.h"
50 #include "google/protobuf/pyext/message_factory.h"
51 #include "google/protobuf/pyext/repeated_composite_container.h"
52 #include "google/protobuf/pyext/repeated_scalar_container.h"
53 #include "google/protobuf/pyext/safe_numerics.h"
54 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
55 #include "google/protobuf/pyext/unknown_field_set.h"
56 
57 // clang-format off
58 #include "google/protobuf/port_def.inc"
59 // clang-format on
60 
61 #define PyString_AsString(ob) \
62   (PyUnicode_Check(ob) ? PyUnicode_AsUTF8(ob) : PyBytes_AsString(ob))
63 #define PyString_AsStringAndSize(ob, charpp, sizep)              \
64   (PyUnicode_Check(ob)                                           \
65        ? ((*(charpp) = const_cast<char*>(                        \
66                PyUnicode_AsUTF8AndSize(ob, (sizep)))) == nullptr \
67               ? -1                                               \
68               : 0)                                               \
69        : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
70 
71 #define PROTOBUF_PYTHON_PUBLIC "google.protobuf"
72 #define PROTOBUF_PYTHON_INTERNAL "google.protobuf.internal"
73 
74 namespace google {
75 namespace protobuf {
76 namespace python {
77 
78 class MessageReflectionFriend {
79  public:
UnsafeShallowSwapFields(Message * lhs,Message * rhs,const std::vector<const FieldDescriptor * > & fields)80   static void UnsafeShallowSwapFields(
81       Message* lhs, Message* rhs,
82       const std::vector<const FieldDescriptor*>& fields) {
83     lhs->GetReflection()->UnsafeShallowSwapFields(lhs, rhs, fields);
84   }
IsLazyField(const Reflection * reflection,const Message & message,const FieldDescriptor * field)85   static bool IsLazyField(const Reflection* reflection, const Message& message,
86                           const FieldDescriptor* field) {
87     return reflection->IsLazyField(field) ||
88            reflection->IsLazyExtension(message, field);
89   }
ContainsMapKey(const Reflection * reflection,const Message & message,const FieldDescriptor * field,const MapKey & map_key)90   static bool ContainsMapKey(const Reflection* reflection,
91                              const Message& message,
92                              const FieldDescriptor* field,
93                              const MapKey& map_key) {
94     return reflection->ContainsMapKey(message, field, map_key);
95   }
96 };
97 
// File-level Python object handles. None of them is assigned in this part of
// the file except WKT_classes (see message_meta::New).
// Key used by message_meta::New to look up the descriptor in the class dict.
98 static PyObject* kDESCRIPTOR;
// Key used by message_meta::New to look up an optional message factory in the
// class dict.
99 static PyObject* kMessageFactory;
// Callable invoked with an enum descriptor to build the wrapper stored on
// message classes (see AddDescriptors).
100 PyObject* EnumTypeWrapper_class;
// The only base class message_meta::New accepts for a new message class.
101 static PyObject* PythonMessage_class;
// NOTE(review): not referenced in this excerpt; presumably a weakref sentinel
// -- confirm at its use sites.
102 static PyObject* kEmptyWeakref;
// "WKTBASES" attribute of the internal well_known_types module, loaded on
// first use by message_meta::New and looked up by message full name.
103 static PyObject* WKT_classes = nullptr;
104 
105 namespace message_meta {
106 
107 namespace {
// Copied over from internal 'google/protobuf/stubs/strutil.h'.
// Lowercases the ASCII letters 'A'..'Z' of *s in place; every other byte is
// left untouched.
inline void LowerString(std::string* s) {
  constexpr char kCaseOffset = 'a' - 'A';
  for (char& ch : *s) {
    // Deliberately not tolower(): its result depends on the current locale.
    const bool is_ascii_upper = ch >= 'A' && ch <= 'Z';
    if (is_ascii_upper) {
      ch += kCaseOffset;
    }
  }
}
116 }  // namespace
117 
118 // Finalize the creation of the Message class.
AddDescriptors(PyObject * cls,const Descriptor * descriptor)119 static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
120   // For each field set: cls.<field>_FIELD_NUMBER = <number>
121   for (int i = 0; i < descriptor->field_count(); ++i) {
122     const FieldDescriptor* field_descriptor = descriptor->field(i);
123     ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
124     if (property == nullptr) {
125       return -1;
126     }
127     if (PyObject_SetAttrString(cls,
128                                std::string(field_descriptor->name()).c_str(),
129                                property.get()) < 0) {
130       return -1;
131     }
132   }
133 
134   // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
135   for (int i = 0; i < descriptor->enum_type_count(); ++i) {
136     const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
137     ScopedPyObjectPtr enum_type(
138         PyEnumDescriptor_FromDescriptor(enum_descriptor));
139     if (enum_type == nullptr) {
140       return -1;
141     }
142     // Add wrapped enum type to message class.
143     ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
144         EnumTypeWrapper_class, enum_type.get(), nullptr));
145     if (wrapped == nullptr) {
146       return -1;
147     }
148     if (PyObject_SetAttrString(cls,
149                                std::string(enum_descriptor->name()).c_str(),
150                                wrapped.get()) == -1) {
151       return -1;
152     }
153 
154     // For each enum value add cls.<name> = <number>
155     for (int j = 0; j < enum_descriptor->value_count(); ++j) {
156       const EnumValueDescriptor* enum_value_descriptor =
157           enum_descriptor->value(j);
158       ScopedPyObjectPtr value_number(
159           PyLong_FromLong(enum_value_descriptor->number()));
160       if (value_number == nullptr) {
161         return -1;
162       }
163       if (PyObject_SetAttrString(
164               cls, std::string(enum_value_descriptor->name()).c_str(),
165               value_number.get()) == -1) {
166         return -1;
167       }
168     }
169   }
170 
171   // For each extension set cls.<extension name> = <extension descriptor>.
172   //
173   // Extension descriptors come from
174   // <message descriptor>.extensions_by_name[name]
175   // which was defined previously.
176   for (int i = 0; i < descriptor->extension_count(); ++i) {
177     const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
178     ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
179     if (extension_field == nullptr) {
180       return -1;
181     }
182 
183     // Add the extension field to the message class.
184     if (PyObject_SetAttrString(cls, std::string(field->name()).c_str(),
185                                extension_field.get()) == -1) {
186       return -1;
187     }
188   }
189 
190   return 0;
191 }
192 
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)193 static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
194   static const char* kwlist[] = {"name", "bases", "dict", nullptr};
195   PyObject *bases, *dict;
196   const char* name;
197 
198   // Check arguments: (name, bases, dict)
199   if (!PyArg_ParseTupleAndKeywords(
200           args, kwargs, "sO!O!:type", const_cast<char**>(kwlist), &name,
201           &PyTuple_Type, &bases, &PyDict_Type, &dict)) {
202     return nullptr;
203   }
204 
205   // Check bases: only (), or (message.Message,) are allowed
206   if (!(PyTuple_GET_SIZE(bases) == 0 ||
207         (PyTuple_GET_SIZE(bases) == 1 &&
208          PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
209     PyErr_SetString(PyExc_TypeError,
210                     "A Message class can only inherit from Message");
211     return nullptr;
212   }
213 
214   // Check dict['DESCRIPTOR']
215   PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
216   if (py_descriptor == nullptr) {
217     PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
218     return nullptr;
219   }
220   if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) {
221     PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s",
222                  py_descriptor->ob_type->tp_name);
223     return nullptr;
224   }
225   const Descriptor* message_descriptor =
226       PyMessageDescriptor_AsDescriptor(py_descriptor);
227   if (message_descriptor == nullptr) {
228     return nullptr;
229   }
230 
231   // Messages have no __dict__
232   ScopedPyObjectPtr slots(PyTuple_New(0));
233   if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
234     return nullptr;
235   }
236 
237   // Build the arguments to the base metaclass.
238   // We change the __bases__ classes.
239   ScopedPyObjectPtr new_args;
240 
241   if (WKT_classes == nullptr) {
242     ScopedPyObjectPtr well_known_types(
243         PyImport_ImportModule(PROTOBUF_PYTHON_INTERNAL ".well_known_types"));
244     ABSL_DCHECK(well_known_types != nullptr);
245 
246     WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
247     ABSL_DCHECK(WKT_classes != nullptr);
248   }
249 
250   PyObject* well_known_class = PyDict_GetItemString(
251       WKT_classes, std::string(message_descriptor->full_name()).c_str());
252   if (well_known_class == nullptr) {
253     new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
254                                  PythonMessage_class, dict));
255   } else {
256     new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
257                                  PythonMessage_class, well_known_class, dict));
258   }
259 
260   if (new_args == nullptr) {
261     return nullptr;
262   }
263   // Call the base metaclass.
264   ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), nullptr));
265   if (result == nullptr) {
266     return nullptr;
267   }
268   CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
269 
270   // Cache the descriptor, both as Python object and as C++ pointer.
271   const Descriptor* descriptor =
272       PyMessageDescriptor_AsDescriptor(py_descriptor);
273   if (descriptor == nullptr) {
274     return nullptr;
275   }
276   Py_INCREF(py_descriptor);
277   newtype->py_message_descriptor = py_descriptor;
278   newtype->message_descriptor = descriptor;
279   // TODO: Don't always use the canonical pool of the descriptor,
280   // use the MessageFactory optionally passed in the class dict.
281   PyDescriptorPool* py_descriptor_pool =
282       GetDescriptorPool_FromPool(descriptor->file()->pool());
283   if (py_descriptor_pool == nullptr) {
284     return nullptr;
285   }
286 
287   PyObject* py_message_factory_obj = PyDict_GetItem(dict, kMessageFactory);
288   PyMessageFactory* py_message_factory = nullptr;
289   if (py_message_factory_obj == nullptr) {
290     py_message_factory = py_descriptor_pool->py_message_factory;
291   } else {
292     py_message_factory =
293         reinterpret_cast<PyMessageFactory*>(py_message_factory_obj);
294   }
295 
296   newtype->py_message_factory = py_message_factory;
297   Py_INCREF(newtype->py_message_factory);
298 
299   // Register the message in the MessageFactory.
300   // TODO: Move this call to MessageFactory.GetPrototype() when the
301   // MessageFactory is fully implemented in C++.
302   if (message_factory::RegisterMessageClass(newtype->py_message_factory,
303                                             descriptor, newtype) < 0) {
304     return nullptr;
305   }
306 
307   // Continue with type initialization: add other descriptors, enum values...
308   if (AddDescriptors(result.get(), descriptor) < 0) {
309     return nullptr;
310   }
311   return result.release();
312 }
313 
Dealloc(PyObject * pself)314 static void Dealloc(PyObject* pself) {
315   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
316   Py_XDECREF(self->py_message_descriptor);
317   Py_XDECREF(self->py_message_factory);
318   return PyType_Type.tp_dealloc(pself);
319 }
320 
GcTraverse(PyObject * pself,visitproc visit,void * arg)321 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
322   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
323   Py_VISIT(self->py_message_descriptor);
324   Py_VISIT(self->py_message_factory);
325   return PyType_Type.tp_traverse(pself, visit, arg);
326 }
327 
GcClear(PyObject * pself)328 static int GcClear(PyObject* pself) {
329   // It's important to keep the descriptor and factory alive, until the
330   // C++ message is fully destructed.
331   return PyType_Type.tp_clear(pself);
332 }
333 
// tp_getset table of the metaclass. It defines no attributes: the single
// {nullptr} entry is the terminator.
334 static PyGetSetDef Getters[] = {
335     {nullptr},
336 };
337 
338 // Compute some class attributes on the fly:
339 // - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
340 // Returns a new reference, or NULL with an exception set.
GetClassAttribute(CMessageClass * self,PyObject * name)341 static PyObject* GetClassAttribute(CMessageClass* self, PyObject* name) {
342   char* attr;
343   Py_ssize_t attr_size;
344   static const char kSuffix[] = "_FIELD_NUMBER";
345   if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
346       absl::EndsWith(absl::string_view(attr, attr_size), kSuffix)) {
347     std::string field_name(attr, attr_size - sizeof(kSuffix) + 1);
348     LowerString(&field_name);
349 
350     // Try to find a field with the given name, without the suffix.
351     const FieldDescriptor* field =
352         self->message_descriptor->FindFieldByLowercaseName(field_name);
353     if (!field) {
354       // Search nested extensions as well.
355       field =
356           self->message_descriptor->FindExtensionByLowercaseName(field_name);
357     }
358     if (field) {
359       return PyLong_FromLong(field->number());
360     }
361   }
362   PyErr_SetObject(PyExc_AttributeError, name);
363   return nullptr;
364 }
365 
GetAttr(CMessageClass * self,PyObject * name)366 static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
367   PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
368       reinterpret_cast<PyObject*>(self), name);
369   if (result != nullptr) {
370     return result;
371   }
372   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
373     return nullptr;
374   }
375 
376   PyErr_Clear();
377   return GetClassAttribute(self, name);
378 }
379 
380 }  // namespace message_meta
381 
382 // Protobuf has a 64MB limit built in, this variable will override this. Please
383 // do not enable this unless you fully understand the implications: protobufs
384 // must all be kept in memory at the same time, so if they grow too big you may
385 // get OOM errors. The protobuf APIs do not provide any tools for processing
386 // protobufs in chunks.  If you have protos this big you should break them up if
387 // it is at all convenient to do so.
// Initial value of the flag is fixed at build time: true only when
// PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS is defined.
388 #ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
389 static bool allow_oversize_protos = true;
390 #else
391 static bool allow_oversize_protos = false;
392 #endif  // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
393 
// Type object of the metaclass of message classes ("<module>.MessageMeta").
// Instances are CMessageClass objects. The initializer is positional, so the
// order of entries must follow the PyTypeObject slot order exactly.
394 static PyTypeObject _CMessageClass_Type = {
395     PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
396     ".MessageMeta",         // tp_name
397     sizeof(CMessageClass),  // tp_basicsize
398     0,                      // tp_itemsize
399     message_meta::Dealloc,  // tp_dealloc
// The slot at this position changed meaning in Python 3.8.
400 #if PY_VERSION_HEX < 0x03080000
401     nullptr, /* tp_print */
402 #else
403     0, /* tp_vectorcall_offset */
404 #endif
405     nullptr,                              // tp_getattr
406     nullptr,                              // tp_setattr
407     nullptr,                              // tp_compare (legacy slot name)
408     nullptr,                              // tp_repr
409     nullptr,                              // tp_as_number
410     nullptr,                              // tp_as_sequence
411     nullptr,                              // tp_as_mapping
412     nullptr,                              // tp_hash
413     nullptr,                              // tp_call
414     nullptr,                              // tp_str
415     (getattrofunc)message_meta::GetAttr,  // tp_getattro
416     nullptr,                              // tp_setattro
417     nullptr,                              // tp_as_buffer
418     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
419     "The metaclass of ProtocolMessages",                            // tp_doc
420     message_meta::GcTraverse,  // tp_traverse
421     message_meta::GcClear,     // tp_clear
422     nullptr,                   // tp_richcompare
423     0,                         // tp_weaklistoffset
424     nullptr,                   // tp_iter
425     nullptr,                   // tp_iternext
426     nullptr,                   // tp_methods
427     nullptr,                   // tp_members
428     message_meta::Getters,     // tp_getset
// NOTE(review): tp_base is left null here although message_meta::GetAttr
// dereferences CMessageClass_Type->tp_base; presumably it is assigned during
// module initialization -- confirm.
429     nullptr,                   // tp_base
430     nullptr,                   // tp_dict
431     nullptr,                   // tp_descr_get
432     nullptr,                   // tp_descr_set
433     0,                         // tp_dictoffset
434     nullptr,                   // tp_init
435     nullptr,                   // tp_alloc
436     message_meta::New,         // tp_new
437 };
// Pointer through which the rest of the code refers to the metaclass type.
438 PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
439 
CheckMessageClass(PyTypeObject * cls)440 static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
441   if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
442     PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
443     return nullptr;
444   }
445   return reinterpret_cast<CMessageClass*>(cls);
446 }
447 
GetMessageDescriptor(PyTypeObject * cls)448 static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
449   CMessageClass* type = CheckMessageClass(cls);
450   if (type == nullptr) {
451     return nullptr;
452   }
453   return type->message_descriptor;
454 }
455 
456 // Forward declarations
457 namespace cmessage {
// Declared ahead of its definition because
// MaybeReleaseOverlappingOneofField() below calls it; callers treat a
// negative return value as failure.
458 int InternalReleaseFieldByDescriptor(CMessage* self,
459                                      const FieldDescriptor* field_descriptor);
460 }  // namespace cmessage
461 
462 // ---------------------------------------------------------------------
463 
// NOTE(review): none of these is assigned or used in this part of the file;
// presumably they hold the Python EncodeError / DecodeError / PickleError
// exception classes -- confirm where they are initialized.
464 PyObject* EncodeError_class;
465 PyObject* DecodeError_class;
466 PyObject* PickleError_class;
467 
468 // Format an error message for unexpected types.
469 // Always return with an exception set.
FormatTypeError(PyObject * arg,const char * expected_types)470 void FormatTypeError(PyObject* arg, const char* expected_types) {
471   // This function is often called with an exception set.
472   // Clear it to call PyObject_Repr() in good conditions.
473   PyErr_Clear();
474   PyObject* repr = PyObject_Repr(arg);
475   if (repr) {
476     PyErr_Format(
477         PyExc_TypeError, "%.100s has type %.100s, but expected one of: %s",
478         PyString_AsString(repr), Py_TYPE(arg)->tp_name, expected_types);
479     Py_DECREF(repr);
480   }
481 }
482 
OutOfRangeError(PyObject * arg)483 void OutOfRangeError(PyObject* arg) {
484   PyObject* s = PyObject_Str(arg);
485   if (s) {
486     PyErr_Format(PyExc_ValueError, "Value out of range: %s",
487                  PyString_AsString(s));
488     Py_DECREF(s);
489   }
490 }
491 
492 template <class RangeType, class ValueType>
VerifyIntegerCastAndRange(PyObject * arg,ValueType value)493 bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
494   if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
495     if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
496       // Replace it with the same ValueError as pure python protos instead of
497       // the default one.
498       PyErr_Clear();
499       OutOfRangeError(arg);
500     }  // Otherwise propagate existing error.
501     return false;
502   }
503   if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
504     OutOfRangeError(arg);
505     return false;
506   }
507   return true;
508 }
509 
510 template <class T>
CheckAndGetInteger(PyObject * arg,T * value)511 bool CheckAndGetInteger(PyObject* arg, T* value) {
512   // This effectively defines an integer as "an object that can be cast as
513   // an integer and can be used as an ordinal number".
514   // This definition includes everything with a valid __index__() implementation
515   // and shouldn't cast the net too wide.
516   if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.ndarray") ||
517       PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) {
518     FormatTypeError(arg, "int");
519     return false;
520   }
521 
522   PyObject* arg_py_int = PyNumber_Index(arg);
523   if (PyErr_Occurred()) {
524     // Propagate existing error.
525     return false;
526   }
527 
528   if (std::numeric_limits<T>::min() == 0) {
529     // Unsigned case.
530     unsigned PY_LONG_LONG ulong_result = PyLong_AsUnsignedLongLong(arg_py_int);
531     Py_DECREF(arg_py_int);
532     if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
533                                                             ulong_result)) {
534       *value = static_cast<T>(ulong_result);
535     } else {
536       return false;
537     }
538   } else {
539     // Signed case.
540     Py_DECREF(arg_py_int);
541     PY_LONG_LONG long_result = PyLong_AsLongLong(arg);
542     if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
543       *value = static_cast<T>(long_result);
544     } else {
545       return false;
546     }
547   }
548   return true;
549 }
550 
551 // These are referenced by repeated_scalar_container, and must
552 // be explicitly instantiated.
// One explicit instantiation per integer type converted this way: signed and
// unsigned, 32 and 64 bits.
553 template bool CheckAndGetInteger<int32_t>(PyObject*, int32_t*);
554 template bool CheckAndGetInteger<int64_t>(PyObject*, int64_t*);
555 template bool CheckAndGetInteger<uint32_t>(PyObject*, uint32_t*);
556 template bool CheckAndGetInteger<uint64_t>(PyObject*, uint64_t*);
557 
CheckAndGetDouble(PyObject * arg,double * value)558 bool CheckAndGetDouble(PyObject* arg, double* value) {
559   *value = PyFloat_AsDouble(arg);
560   if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.ndarray") ||
561       PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
562     FormatTypeError(arg, "int, float");
563     return false;
564   }
565   return true;
566 }
567 
CheckAndGetFloat(PyObject * arg,float * value)568 bool CheckAndGetFloat(PyObject* arg, float* value) {
569   double double_value;
570   if (!CheckAndGetDouble(arg, &double_value)) {
571     return false;
572   }
573   *value = io::SafeDoubleToFloat(double_value);
574   return true;
575 }
576 
CheckAndGetBool(PyObject * arg,bool * value)577 bool CheckAndGetBool(PyObject* arg, bool* value) {
578   long long_value = PyLong_AsLong(arg);  // NOLINT
579   if (!strcmp(Py_TYPE(arg)->tp_name, "numpy.ndarray") ||
580       (long_value == -1 && PyErr_Occurred())) {
581     FormatTypeError(arg, "int, bool");
582     return false;
583   }
584   *value = static_cast<bool>(long_value);
585 
586   return true;
587 }
588 
589 // Checks whether the given object (which must be "bytes" or "unicode") contains
590 // valid UTF-8.
IsValidUTF8(PyObject * obj)591 bool IsValidUTF8(PyObject* obj) {
592   if (PyBytes_Check(obj)) {
593     PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", nullptr);
594 
595     // Clear the error indicator; we report our own error when desired.
596     PyErr_Clear();
597 
598     if (unicode) {
599       Py_DECREF(unicode);
600       return true;
601     } else {
602       return false;
603     }
604   } else {
605     // Unicode object, known to be valid UTF-8.
606     return true;
607   }
608 }
609 
// Whether invalid UTF-8 may be stored in `field`. The argument is ignored:
// the answer is always "no".
bool AllowInvalidUTF8(const FieldDescriptor* field) {
  return false;
}
611 
CheckString(PyObject * arg,const FieldDescriptor * descriptor)612 PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
613   ABSL_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
614               descriptor->type() == FieldDescriptor::TYPE_BYTES);
615   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
616     if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
617       FormatTypeError(arg, "bytes, unicode");
618       return nullptr;
619     }
620 
621     if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
622       PyObject* repr = PyObject_Repr(arg);
623       PyErr_Format(PyExc_ValueError,
624                    "%s has type str, but isn't valid UTF-8 "
625                    "encoding. Non-UTF-8 strings must be converted to "
626                    "unicode objects before being added.",
627                    PyString_AsString(repr));
628       Py_DECREF(repr);
629       return nullptr;
630     }
631   } else if (!PyBytes_Check(arg)) {
632     FormatTypeError(arg, "bytes");
633     return nullptr;
634   }
635 
636   PyObject* encoded_string = nullptr;
637   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
638     if (PyBytes_Check(arg)) {
639       // The bytes were already validated as correctly encoded UTF-8 above.
640       encoded_string = arg;  // Already encoded.
641       Py_INCREF(encoded_string);
642     } else {
643       encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", nullptr);
644     }
645   } else {
646     // In this case field type is "bytes".
647     encoded_string = arg;
648     Py_INCREF(encoded_string);
649   }
650 
651   return encoded_string;
652 }
653 
CheckAndSetString(PyObject * arg,Message * message,const FieldDescriptor * descriptor,const Reflection * reflection,bool append,int index)654 bool CheckAndSetString(PyObject* arg, Message* message,
655                        const FieldDescriptor* descriptor,
656                        const Reflection* reflection, bool append, int index) {
657   ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
658 
659   if (encoded_string.get() == nullptr) {
660     return false;
661   }
662 
663   char* value;
664   Py_ssize_t value_len;
665   if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
666     return false;
667   }
668 
669   std::string value_string(value, value_len);
670   if (append) {
671     reflection->AddString(message, descriptor, std::move(value_string));
672   } else if (index < 0) {
673     reflection->SetString(message, descriptor, std::move(value_string));
674   } else {
675     reflection->SetRepeatedString(message, descriptor, index,
676                                   std::move(value_string));
677   }
678   return true;
679 }
680 
ToStringObject(const FieldDescriptor * descriptor,const absl::string_view value)681 PyObject* ToStringObject(const FieldDescriptor* descriptor,
682                          const absl::string_view value) {
683   if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
684     return PyBytes_FromStringAndSize(value.data(), value.length());
685   }
686 
687   PyObject* result =
688       PyUnicode_DecodeUTF8(value.data(), value.length(), nullptr);
689   // If the string can't be decoded in UTF-8, just return a string object that
690   // contains the raw bytes. This can't happen if the value was assigned using
691   // the members of the Python message object, but can happen if the values were
692   // parsed from the wire (binary).
693   if (result == nullptr) {
694     PyErr_Clear();
695     result = PyBytes_FromStringAndSize(value.data(), value.length());
696   }
697   return result;
698 }
699 
CheckFieldBelongsToMessage(const FieldDescriptor * field_descriptor,const Message * message)700 bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
701                                 const Message* message) {
702   if (message->GetDescriptor() == field_descriptor->containing_type()) {
703     return true;
704   }
705   PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
706                std::string(field_descriptor->full_name()).c_str(),
707                std::string(message->GetDescriptor()->full_name()).c_str());
708   return false;
709 }
710 
711 namespace cmessage {
712 
GetFactoryForMessage(CMessage * message)713 PyMessageFactory* GetFactoryForMessage(CMessage* message) {
714   ABSL_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
715   return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
716 }
717 
MaybeReleaseOverlappingOneofField(CMessage * cmessage,const FieldDescriptor * field)718 static int MaybeReleaseOverlappingOneofField(CMessage* cmessage,
719                                              const FieldDescriptor* field) {
720   Message* message = cmessage->message;
721   const Reflection* reflection = message->GetReflection();
722   if (!field->containing_oneof() ||
723       !reflection->HasOneof(*message, field->containing_oneof()) ||
724       reflection->HasField(*message, field)) {
725     // No other field in this oneof, no need to release.
726     return 0;
727   }
728 
729   const OneofDescriptor* oneof = field->containing_oneof();
730   const FieldDescriptor* existing_field =
731       reflection->GetOneofFieldDescriptor(*message, oneof);
732   if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
733     // Non-message fields don't need to be released.
734     return 0;
735   }
736   if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) {
737     return -1;
738   }
739   return 0;
740 }
741 
742 // After a Merge, visit every sub-message that was read-only, and
743 // eventually update their pointer if the Merge operation modified them.
FixupMessageAfterMerge(CMessage * self)744 int FixupMessageAfterMerge(CMessage* self) {
745   if (!self->composite_fields) {
746     return 0;
747   }
748   PyMessageFactory* factory = GetFactoryForMessage(self);
749   for (const auto& item : *self->composite_fields) {
750     const FieldDescriptor* descriptor = item.first;
751     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
752         !descriptor->is_repeated()) {
753       CMessage* cmsg = reinterpret_cast<CMessage*>(item.second);
754       if (cmsg->read_only == false) {
755         return 0;
756       }
757       Message* message = self->message;
758       const Reflection* reflection = message->GetReflection();
759       if (reflection->HasField(*message, descriptor)) {
760         // Message used to be read_only, but is no longer. Get the new pointer
761         // and record it.
762         Message* mutable_message = reflection->MutableMessage(
763             message, descriptor, factory->message_factory);
764         cmsg->message = mutable_message;
765         cmsg->read_only = false;
766         if (FixupMessageAfterMerge(cmsg) < 0) {
767           return -1;
768         }
769       }
770     }
771   }
772 
773   return 0;
774 }
775 
776 // ---------------------------------------------------------------------
777 // Making a message writable
778 
AssureWritable(CMessage * self)779 int AssureWritable(CMessage* self) {
780   if (self == nullptr || !self->read_only) {
781     return 0;
782   }
783 
784   // Toplevel messages are always mutable.
785   ABSL_DCHECK(self->parent);
786 
787   if (AssureWritable(self->parent) == -1) {
788     return -1;
789   }
790   // If this message is part of a oneof, there might be a field to release in
791   // the parent.
792   if (MaybeReleaseOverlappingOneofField(self->parent,
793                                         self->parent_field_descriptor) < 0) {
794     return -1;
795   }
796 
797   // Make self->message writable.
798   Message* parent_message = self->parent->message;
799   const Reflection* reflection = parent_message->GetReflection();
800   Message* mutable_message = reflection->MutableMessage(
801       parent_message, self->parent_field_descriptor,
802       GetFactoryForMessage(self->parent)->message_factory);
803   if (mutable_message == nullptr) {
804     return -1;
805   }
806   self->message = mutable_message;
807   self->read_only = false;
808 
809   return 0;
810 }
811 
812 // --- Globals:
813 
814 // Retrieve a C++ FieldDescriptor for an extension handle.
GetExtensionDescriptor(PyObject * extension)815 const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
816   ScopedPyObjectPtr cdescriptor;
817   if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
818     // Most callers consider extensions as a plain dictionary.  We should
819     // allow input which is not a field descriptor, and simply pretend it does
820     // not exist.
821     PyErr_SetObject(PyExc_KeyError, extension);
822     return nullptr;
823   }
824   return PyFieldDescriptor_AsDescriptor(extension);
825 }
826 
827 // If value is a string, convert it into an enum value based on the labels in
828 // descriptor, otherwise simply return value.  Always returns a new reference.
GetIntegerEnumValue(const FieldDescriptor & descriptor,PyObject * value)829 static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
830                                      PyObject* value) {
831   if (PyUnicode_Check(value)) {
832     const EnumDescriptor* enum_descriptor = descriptor.enum_type();
833     if (enum_descriptor == nullptr) {
834       PyErr_SetString(PyExc_TypeError, "not an enum field");
835       return nullptr;
836     }
837     char* enum_label;
838     Py_ssize_t size;
839     if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
840       return nullptr;
841     }
842     const EnumValueDescriptor* enum_value_descriptor =
843         enum_descriptor->FindValueByName(absl::string_view(enum_label, size));
844     if (enum_value_descriptor == nullptr) {
845       PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
846       return nullptr;
847     }
848     return PyLong_FromLong(enum_value_descriptor->number());
849   }
850   Py_INCREF(value);
851   return value;
852 }
853 
854 // Delete a slice from a repeated field.
855 // The only way to remove items in C++ protos is to delete the last one,
856 // so we swap items to move the deleted ones at the end, and then strip the
857 // sequence.
DeleteRepeatedField(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * slice)858 int DeleteRepeatedField(CMessage* self, const FieldDescriptor* field_descriptor,
859                         PyObject* slice) {
860   Py_ssize_t length, from, to, step, slice_length;
861   Message* message = self->message;
862   const Reflection* reflection = message->GetReflection();
863   int min, max;
864   length = reflection->FieldSize(*message, field_descriptor);
865 
866   if (PySlice_Check(slice)) {
867     from = to = step = slice_length = 0;
868     PySlice_GetIndicesEx(slice, length, &from, &to, &step, &slice_length);
869     if (from < to) {
870       min = from;
871       max = to - 1;
872     } else {
873       min = to + 1;
874       max = from;
875     }
876   } else {
877     from = to = PyLong_AsLong(slice);
878     if (from == -1 && PyErr_Occurred()) {
879       PyErr_SetString(PyExc_TypeError, "list indices must be integers");
880       return -1;
881     }
882 
883     if (from < 0) {
884       from = to = length + from;
885     }
886     step = 1;
887     min = max = from;
888 
889     // Range check.
890     if (from < 0 || from >= length) {
891       PyErr_Format(PyExc_IndexError, "list assignment index out of range");
892       return -1;
893     }
894   }
895 
896   Py_ssize_t i = from;
897   std::vector<bool> to_delete(length, false);
898   while (i >= min && i <= max) {
899     to_delete[i] = true;
900     i += step;
901   }
902 
903   // Swap elements so that items to delete are at the end.
904   to = 0;
905   for (i = 0; i < length; ++i) {
906     if (!to_delete[i]) {
907       if (i != to) {
908         reflection->SwapElements(message, field_descriptor, i, to);
909       }
910       ++to;
911     }
912   }
913 
914   Arena* arena = message->GetArena();
915   ABSL_DCHECK_EQ(arena, nullptr)
916       << "python protobuf is expected to be allocated from heap";
917   // Remove items, starting from the end.
918   for (; length > to; length--) {
919     if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
920       reflection->RemoveLast(message, field_descriptor);
921       continue;
922     }
923     // It seems that RemoveLast() is less efficient for sub-messages, and
924     // the memory is not completely released. Prefer ReleaseLast().
925     //
926     // To work around a debug hardening (PROTOBUF_FORCE_COPY_IN_RELEASE),
927     // explicitly use UnsafeArenaReleaseLast. To not break rare use cases where
928     // arena is used, we fallback to ReleaseLast (but ABSL_DCHECK to find/fix
929     // it).
930     //
931     // Note that arena is likely null and ABSL_DCHECK and ReleaseLast might be
932     // redundant. The current approach takes extra cautious path not to disrupt
933     // production.
934     Message* sub_message =
935         (arena == nullptr)
936             ? reflection->UnsafeArenaReleaseLast(message, field_descriptor)
937             : reflection->ReleaseLast(message, field_descriptor);
938     // If there is a live weak reference to an item being removed, we "Release"
939     // it, and it takes ownership of the message.
940     if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) {
941       released->message = sub_message;
942     } else {
943       // sub_message was not transferred, delete it.
944       delete sub_message;
945     }
946   }
947 
948   return 0;
949 }
950 
951 // Initializes fields of a message. Used in constructors.
InitAttributes(CMessage * self,PyObject * args,PyObject * kwargs)952 int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
953   if (args != nullptr && PyTuple_Size(args) != 0) {
954     PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
955     return -1;
956   }
957 
958   if (kwargs == nullptr) {
959     return 0;
960   }
961 
962   Py_ssize_t pos = 0;
963   PyObject* name;
964   PyObject* value;
965   while (PyDict_Next(kwargs, &pos, &name, &value)) {
966     if (!(PyUnicode_Check(name))) {
967       PyErr_SetString(PyExc_ValueError, "Field name must be a string");
968       return -1;
969     }
970     ScopedPyObjectPtr property(
971         PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
972     if (property == nullptr ||
973         !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
974       PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
975                    std::string(self->message->GetDescriptor()->name()).c_str(),
976                    PyString_AsString(name));
977       return -1;
978     }
979     const FieldDescriptor* descriptor =
980         reinterpret_cast<PyMessageFieldProperty*>(property.get())
981             ->field_descriptor;
982     if (value == Py_None) {
983       // field=None is the same as no field at all.
984       continue;
985     }
986     if (descriptor->is_map()) {
987       ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
988       const FieldDescriptor* value_descriptor =
989           descriptor->message_type()->map_value();
990       if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
991         ScopedPyObjectPtr iter(PyObject_GetIter(value));
992         if (iter == nullptr) {
993           PyErr_Format(PyExc_TypeError, "Argument %s is not iterable",
994                        PyString_AsString(name));
995           return -1;
996         }
997         ScopedPyObjectPtr next;
998         while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
999           ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
1000           ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
1001           if (source_value.get() == nullptr || dest_value.get() == nullptr) {
1002             return -1;
1003           }
1004           ScopedPyObjectPtr ok(PyObject_CallMethod(
1005               dest_value.get(), "MergeFrom", "O", source_value.get()));
1006           if (ok.get() == nullptr) {
1007             return -1;
1008           }
1009         }
1010       } else {
1011         ScopedPyObjectPtr function_return;
1012         function_return.reset(
1013             PyObject_CallMethod(map.get(), "update", "O", value));
1014         if (function_return.get() == nullptr) {
1015           return -1;
1016         }
1017       }
1018     } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1019       ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
1020       if (container == nullptr) {
1021         return -1;
1022       }
1023       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1024         RepeatedCompositeContainer* rc_container =
1025             reinterpret_cast<RepeatedCompositeContainer*>(container.get());
1026         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1027         if (iter == nullptr) {
1028           PyErr_Format(PyExc_TypeError, "Value of field '%s' must be iterable",
1029                        std::string(descriptor->name()).c_str());
1030           return -1;
1031         }
1032         ScopedPyObjectPtr next;
1033         while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
1034           PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : nullptr);
1035           ScopedPyObjectPtr new_msg(
1036               repeated_composite_container::Add(rc_container, nullptr, kwargs));
1037           if (new_msg == nullptr) {
1038             return -1;
1039           }
1040           if (kwargs == nullptr) {
1041             // next was not a dict, it's a message we need to merge
1042             ScopedPyObjectPtr merged(MergeFrom(
1043                 reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1044             if (merged.get() == nullptr) {
1045               return -1;
1046             }
1047           }
1048         }
1049         if (PyErr_Occurred()) {
1050           // Check to see how PyIter_Next() exited.
1051           return -1;
1052         }
1053       } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1054         RepeatedScalarContainer* rs_container =
1055             reinterpret_cast<RepeatedScalarContainer*>(container.get());
1056         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1057         if (iter == nullptr) {
1058           PyErr_Format(PyExc_TypeError, "Value of field '%s' must be iterable",
1059                        std::string(descriptor->name()).c_str());
1060           return -1;
1061         }
1062         ScopedPyObjectPtr next;
1063         while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
1064           ScopedPyObjectPtr enum_value(
1065               GetIntegerEnumValue(*descriptor, next.get()));
1066           if (enum_value == nullptr) {
1067             return -1;
1068           }
1069           ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
1070               rs_container, enum_value.get()));
1071           if (new_msg == nullptr) {
1072             return -1;
1073           }
1074         }
1075         if (PyErr_Occurred()) {
1076           // Check to see how PyIter_Next() exited.
1077           return -1;
1078         }
1079       } else {
1080         if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
1081                 reinterpret_cast<RepeatedScalarContainer*>(container.get()),
1082                 value)) == nullptr) {
1083           return -1;
1084         }
1085       }
1086     } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1087       ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
1088       if (message == nullptr) {
1089         return -1;
1090       }
1091       CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1092       if (PyDict_Check(value)) {
1093         // Make the message exist even if the dict is empty.
1094         AssureWritable(cmessage);
1095         if (descriptor->message_type()->well_known_type() ==
1096             Descriptor::WELLKNOWNTYPE_STRUCT) {
1097           ScopedPyObjectPtr ok(PyObject_CallMethod(
1098               reinterpret_cast<PyObject*>(cmessage), "update", "O", value));
1099           if (ok.get() == nullptr && PyDict_Size(value) == 1 &&
1100               PyDict_Contains(value, PyUnicode_FromString("fields"))) {
1101             // Fallback to init as normal message field.
1102             PyErr_Clear();
1103             PyObject* tmp = Clear(cmessage);
1104             Py_DECREF(tmp);
1105             if (InitAttributes(cmessage, nullptr, value) < 0) {
1106               return -1;
1107             }
1108           }
1109         } else {
1110           if (InitAttributes(cmessage, nullptr, value) < 0) {
1111             return -1;
1112           }
1113         }
1114       } else {
1115         if (PyObject_TypeCheck(value, CMessage_Type)) {
1116           ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1117           if (merged == nullptr) {
1118             return -1;
1119           }
1120         } else {
1121           if (descriptor->message_type()->well_known_type() !=
1122                   Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
1123               PyObject_HasAttrString(reinterpret_cast<PyObject*>(cmessage),
1124                                      "_internal_assign")) {
1125             AssureWritable(cmessage);
1126             ScopedPyObjectPtr ok(
1127                 PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
1128                                     "_internal_assign", "O", value));
1129             if (ok.get() == nullptr) {
1130               return -1;
1131             }
1132           } else {
1133             PyErr_Format(PyExc_TypeError,
1134                          "Parameter to initialize message field must be "
1135                          "dict or instance of same class: expected %s got %s.",
1136                          std::string(descriptor->full_name()).c_str(),
1137                          Py_TYPE(value)->tp_name);
1138             return -1;
1139           }
1140         }
1141       }
1142     } else {
1143       ScopedPyObjectPtr new_val;
1144       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1145         new_val.reset(GetIntegerEnumValue(*descriptor, value));
1146         if (new_val == nullptr) {
1147           return -1;
1148         }
1149         value = new_val.get();
1150       }
1151       if (SetFieldValue(self, descriptor, value) < 0) {
1152         return -1;
1153       }
1154     }
1155   }
1156   return 0;
1157 }
1158 
1159 // Allocates an incomplete Python Message: the caller must fill self->message
1160 // and eventually self->parent.
NewEmptyMessage(CMessageClass * type)1161 CMessage* NewEmptyMessage(CMessageClass* type) {
1162   CMessage* self =
1163       reinterpret_cast<CMessage*>(PyType_GenericAlloc(&type->super.ht_type, 0));
1164   if (self == nullptr) {
1165     return nullptr;
1166   }
1167 
1168   self->message = nullptr;
1169   self->parent = nullptr;
1170   self->parent_field_descriptor = nullptr;
1171   self->read_only = false;
1172 
1173   self->composite_fields = nullptr;
1174   self->child_submessages = nullptr;
1175 
1176   return self;
1177 }
1178 
1179 // The __new__ method of Message classes.
1180 // Creates a new C++ message and takes ownership.
NewCMessage(CMessageClass * type)1181 static CMessage* NewCMessage(CMessageClass* type) {
1182   // Retrieve the message descriptor and the default instance (=prototype).
1183   const Descriptor* message_descriptor = type->message_descriptor;
1184   if (message_descriptor == nullptr) {
1185     // This would be very unexpected since the CMessageClass has already
1186     // been checked.
1187     PyErr_Format(PyExc_TypeError,
1188                  "CMessageClass object '%s' has no descriptor.",
1189                  Py_TYPE(type)->tp_name);
1190     return nullptr;
1191   }
1192   const Message* prototype =
1193       type->py_message_factory->message_factory->GetPrototype(
1194           message_descriptor);
1195   if (prototype == nullptr) {
1196     PyErr_SetString(PyExc_TypeError,
1197                     std::string(message_descriptor->full_name()).c_str());
1198     return nullptr;
1199   }
1200 
1201   CMessage* self = NewEmptyMessage(type);
1202   if (self == nullptr) {
1203     return nullptr;
1204   }
1205   self->message = prototype->New(nullptr);  // Ensures no arena is used.
1206   self->parent = nullptr;                   // This message owns its data.
1207   return self;
1208 }
1209 
New(PyTypeObject * cls,PyObject * unused_args,PyObject * unused_kwargs)1210 static PyObject* New(PyTypeObject* cls, PyObject* unused_args,
1211                      PyObject* unused_kwargs) {
1212   CMessageClass* type = CheckMessageClass(cls);
1213   if (type == nullptr) {
1214     return nullptr;
1215   }
1216   return reinterpret_cast<PyObject*>(NewCMessage(type));
1217 }
1218 
1219 // The __init__ method of Message classes.
1220 // It initializes fields from keywords passed to the constructor.
Init(CMessage * self,PyObject * args,PyObject * kwargs)1221 static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
1222   return InitAttributes(self, args, kwargs);
1223 }
1224 
1225 // ---------------------------------------------------------------------
1226 // Deallocating a CMessage
1227 
Dealloc(CMessage * self)1228 static void Dealloc(CMessage* self) {
1229   if (self->weakreflist) {
1230     PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
1231   }
1232   // At this point all dependent objects have been removed.
1233   ABSL_DCHECK(!self->child_submessages || self->child_submessages->empty());
1234   ABSL_DCHECK(!self->composite_fields || self->composite_fields->empty());
1235   delete self->child_submessages;
1236   delete self->composite_fields;
1237 
1238   CMessage* parent = self->parent;
1239   if (!parent) {
1240     // No parent, we own the message.
1241     delete self->message;
1242   } else if (parent->AsPyObject() == Py_None) {
1243     // Message owned externally: Nothing to dealloc
1244     Py_CLEAR(self->parent);
1245   } else {
1246     // Clear this message from its parent's map.
1247     if (self->parent_field_descriptor->is_repeated()) {
1248       if (parent->child_submessages)
1249         parent->child_submessages->erase(self->message);
1250     } else {
1251       if (parent->composite_fields)
1252         parent->composite_fields->erase(self->parent_field_descriptor);
1253     }
1254     Py_CLEAR(self->parent);
1255   }
1256   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
1257 }
1258 
1259 // ---------------------------------------------------------------------
1260 
1261 
IsInitialized(CMessage * self,PyObject * args)1262 PyObject* IsInitialized(CMessage* self, PyObject* args) {
1263   PyObject* errors = nullptr;
1264   if (!PyArg_ParseTuple(args, "|O", &errors)) {
1265     return nullptr;
1266   }
1267   if (self->message->IsInitialized()) {
1268     Py_RETURN_TRUE;
1269   }
1270   if (errors != nullptr) {
1271     ScopedPyObjectPtr initialization_errors(FindInitializationErrors(self));
1272     if (initialization_errors == nullptr) {
1273       return nullptr;
1274     }
1275     ScopedPyObjectPtr extend_name(PyUnicode_FromString("extend"));
1276     if (extend_name == nullptr) {
1277       return nullptr;
1278     }
1279     ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
1280         errors, extend_name.get(), initialization_errors.get(), nullptr));
1281     if (result == nullptr) {
1282       return nullptr;
1283     }
1284   }
1285   Py_RETURN_FALSE;
1286 }
1287 
HasFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1288 int HasFieldByDescriptor(CMessage* self,
1289                          const FieldDescriptor* field_descriptor) {
1290   Message* message = self->message;
1291   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
1292     return -1;
1293   }
1294   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1295     PyErr_SetString(PyExc_KeyError,
1296                     "Field is repeated. A singular method is required.");
1297     return -1;
1298   }
1299   return message->GetReflection()->HasField(*message, field_descriptor);
1300 }
1301 
FindFieldWithOneofs(const Message * message,absl::string_view field_name,bool * in_oneof)1302 const FieldDescriptor* FindFieldWithOneofs(const Message* message,
1303                                            absl::string_view field_name,
1304                                            bool* in_oneof) {
1305   *in_oneof = false;
1306   const Descriptor* descriptor = message->GetDescriptor();
1307   const FieldDescriptor* field_descriptor =
1308       descriptor->FindFieldByName(field_name);
1309   if (field_descriptor != nullptr) {
1310     return field_descriptor;
1311   }
1312   const OneofDescriptor* oneof_desc = descriptor->FindOneofByName(field_name);
1313   if (oneof_desc != nullptr) {
1314     *in_oneof = true;
1315     return message->GetReflection()->GetOneofFieldDescriptor(*message,
1316                                                              oneof_desc);
1317   }
1318   return nullptr;
1319 }
1320 
CheckHasPresence(const FieldDescriptor * field_descriptor,bool in_oneof)1321 bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
1322   auto message_name = std::string(field_descriptor->containing_type()->name());
1323   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1324     PyErr_Format(
1325         PyExc_ValueError, "Protocol message %s has no singular \"%s\" field.",
1326         message_name.c_str(), std::string(field_descriptor->name()).c_str());
1327     return false;
1328   }
1329 
1330   if (!field_descriptor->has_presence()) {
1331     PyErr_Format(PyExc_ValueError,
1332                  "Can't test non-optional, non-submessage field \"%s.%s\" for "
1333                  "presence in proto3.",
1334                  message_name.c_str(),
1335                  std::string(field_descriptor->name()).c_str());
1336     return false;
1337   }
1338 
1339   return true;
1340 }
1341 
HasField(CMessage * self,PyObject * arg)1342 PyObject* HasField(CMessage* self, PyObject* arg) {
1343   char* field_name;
1344   Py_ssize_t size;
1345   field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
1346   Message* message = self->message;
1347 
1348   if (!field_name) {
1349     PyErr_Format(PyExc_ValueError,
1350                  "The field name passed to message %s"
1351                  " is not a str.",
1352                  std::string(message->GetDescriptor()->name()).c_str());
1353     return nullptr;
1354   }
1355 
1356   bool is_in_oneof;
1357   const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
1358       message, absl::string_view(field_name, size), &is_in_oneof);
1359   if (field_descriptor == nullptr) {
1360     if (!is_in_oneof) {
1361       PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
1362                    std::string(message->GetDescriptor()->name()).c_str(),
1363                    field_name);
1364       return nullptr;
1365     } else {
1366       Py_RETURN_FALSE;
1367     }
1368   }
1369 
1370   if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
1371     return nullptr;
1372   }
1373 
1374   if (message->GetReflection()->HasField(*message, field_descriptor)) {
1375     Py_RETURN_TRUE;
1376   }
1377 
1378   Py_RETURN_FALSE;
1379 }
1380 
ClearExtension(CMessage * self,PyObject * extension)1381 PyObject* ClearExtension(CMessage* self, PyObject* extension) {
1382   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1383   if (descriptor == nullptr) {
1384     return nullptr;
1385   }
1386   if (ClearFieldByDescriptor(self, descriptor) < 0) {
1387     return nullptr;
1388   }
1389   Py_RETURN_NONE;
1390 }
1391 
HasExtension(CMessage * self,PyObject * extension)1392 PyObject* HasExtension(CMessage* self, PyObject* extension) {
1393   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1394   if (descriptor == nullptr) {
1395     return nullptr;
1396   }
1397   int has_field = HasFieldByDescriptor(self, descriptor);
1398   if (has_field < 0) {
1399     return nullptr;
1400   } else {
1401     return PyBool_FromLong(has_field);
1402   }
1403 }
1404 
1405 // ---------------------------------------------------------------------
1406 // Releasing messages
1407 //
1408 // The Python API's ClearField() and Clear() methods behave
// differently than their C++ counterparts.  While the C++ versions
// clear the children, the Python versions detach the children,
1411 // without touching their content.  This impedance mismatch causes
1412 // some complexity in the implementation, which is captured in this
1413 // section.
1414 //
1415 // When one or multiple fields are cleared we need to:
1416 //
1417 // * Gather all child objects that need to be detached from the message.
1418 //   In composite_fields and child_submessages.
1419 //
1420 // * Create a new Python message of the same kind. Use SwapFields() to move
1421 //   data from the original message.
1422 //
1423 // * Change the parent of all child objects: update their strong reference
1424 //   to their parent, and move their presence in composite_fields and
1425 //   child_submessages.
1426 
1427 // ---------------------------------------------------------------------
1428 // Release a composite child of a CMessage
1429 
InternalReparentFields(CMessage * self,const std::vector<CMessage * > & messages_to_release,const std::vector<ContainerBase * > & containers_to_release)1430 static int InternalReparentFields(
1431     CMessage* self, const std::vector<CMessage*>& messages_to_release,
1432     const std::vector<ContainerBase*>& containers_to_release) {
1433   if (messages_to_release.empty() && containers_to_release.empty()) {
1434     return 0;
1435   }
1436 
1437   // Move all the passed sub_messages to another message.
1438   CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass());
1439   if (new_message == nullptr) {
1440     return -1;
1441   }
1442   new_message->message = self->message->New(nullptr);
1443   ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message));
1444   new_message->child_submessages = new CMessage::SubMessagesMap();
1445   new_message->composite_fields = new CMessage::CompositeFieldsMap();
1446   std::set<const FieldDescriptor*> fields_to_swap;
1447 
1448   // In case this the removed fields are the last reference to a message, keep
1449   // a reference.
1450   Py_INCREF(self);
1451 
1452   for (const auto& to_release : messages_to_release) {
1453     fields_to_swap.insert(to_release->parent_field_descriptor);
1454     // Reparent
1455     Py_INCREF(new_message);
1456     Py_DECREF(to_release->parent);
1457     to_release->parent = new_message;
1458     self->child_submessages->erase(to_release->message);
1459     new_message->child_submessages->emplace(to_release->message, to_release);
1460   }
1461 
1462   for (const auto& to_release : containers_to_release) {
1463     fields_to_swap.insert(to_release->parent_field_descriptor);
1464     Py_INCREF(new_message);
1465     Py_DECREF(to_release->parent);
1466     to_release->parent = new_message;
1467     self->composite_fields->erase(to_release->parent_field_descriptor);
1468     new_message->composite_fields->emplace(to_release->parent_field_descriptor,
1469                                            to_release);
1470   }
1471 
1472   if (self->message->GetArena() == new_message->message->GetArena()) {
1473     MessageReflectionFriend::UnsafeShallowSwapFields(
1474         self->message, new_message->message,
1475         std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
1476                                             fields_to_swap.end()));
1477   } else {
1478     self->message->GetReflection()->SwapFields(
1479         self->message, new_message->message,
1480         std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
1481                                             fields_to_swap.end()));
1482   }
1483 
1484   // This might delete the Python message completely if all children were moved.
1485   Py_DECREF(self);
1486 
1487   return 0;
1488 }
1489 
InternalReleaseFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1490 int InternalReleaseFieldByDescriptor(CMessage* self,
1491                                      const FieldDescriptor* field_descriptor) {
1492   if (!field_descriptor->is_repeated() &&
1493       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1494     // Single scalars are not in any cache.
1495     return 0;
1496   }
1497   std::vector<CMessage*> messages_to_release;
1498   std::vector<ContainerBase*> containers_to_release;
1499   if (self->child_submessages && field_descriptor->is_repeated() &&
1500       field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1501     for (const auto& child_item : *self->child_submessages) {
1502       if (child_item.second->parent_field_descriptor == field_descriptor) {
1503         messages_to_release.push_back(child_item.second);
1504       }
1505     }
1506   }
1507   if (self->composite_fields) {
1508     CMessage::CompositeFieldsMap::iterator it =
1509         self->composite_fields->find(field_descriptor);
1510     if (it != self->composite_fields->end()) {
1511       containers_to_release.push_back(it->second);
1512     }
1513   }
1514 
1515   return InternalReparentFields(self, messages_to_release,
1516                                 containers_to_release);
1517 }
1518 
ClearFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1519 int ClearFieldByDescriptor(CMessage* self,
1520                            const FieldDescriptor* field_descriptor) {
1521   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
1522     return -1;
1523   }
1524   if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) {
1525     return -1;
1526   }
1527   AssureWritable(self);
1528   Message* message = self->message;
1529   message->GetReflection()->ClearField(message, field_descriptor);
1530   return 0;
1531 }
1532 
ClearField(CMessage * self,PyObject * arg)1533 PyObject* ClearField(CMessage* self, PyObject* arg) {
1534   char* field_name;
1535   Py_ssize_t field_size;
1536   if (PyString_AsStringAndSize(arg, &field_name, &field_size) < 0) {
1537     return nullptr;
1538   }
1539   AssureWritable(self);
1540   bool is_in_oneof;
1541   const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
1542       self->message, absl::string_view(field_name, field_size), &is_in_oneof);
1543   if (field_descriptor == nullptr) {
1544     if (is_in_oneof) {
1545       // We gave the name of a oneof, and none of its fields are set.
1546       Py_RETURN_NONE;
1547     } else {
1548       PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.",
1549                    field_name);
1550       return nullptr;
1551     }
1552   }
1553 
1554   if (ClearFieldByDescriptor(self, field_descriptor) < 0) {
1555     return nullptr;
1556   }
1557   Py_RETURN_NONE;
1558 }
1559 
Clear(CMessage * self)1560 PyObject* Clear(CMessage* self) {
1561   AssureWritable(self);
1562   // Detach all current fields of this message
1563   std::vector<CMessage*> messages_to_release;
1564   std::vector<ContainerBase*> containers_to_release;
1565   if (self->child_submessages) {
1566     for (const auto& item : *self->child_submessages) {
1567       messages_to_release.push_back(item.second);
1568     }
1569   }
1570   if (self->composite_fields) {
1571     for (const auto& item : *self->composite_fields) {
1572       containers_to_release.push_back(item.second);
1573     }
1574   }
1575   if (InternalReparentFields(self, messages_to_release, containers_to_release) <
1576       0) {
1577     return nullptr;
1578   }
1579   self->message->Clear();
1580   Py_RETURN_NONE;
1581 }
1582 
1583 // ---------------------------------------------------------------------
1584 
GetMessageName(CMessage * self)1585 static std::string GetMessageName(CMessage* self) {
1586   if (self->parent_field_descriptor != nullptr) {
1587     return std::string(self->parent_field_descriptor->full_name());
1588   } else {
1589     return std::string(self->message->GetDescriptor()->full_name());
1590   }
1591 }
1592 
InternalSerializeToString(CMessage * self,PyObject * args,PyObject * kwargs,bool require_initialized)1593 static PyObject* InternalSerializeToString(CMessage* self, PyObject* args,
1594                                            PyObject* kwargs,
1595                                            bool require_initialized) {
1596   // Parse the "deterministic" kwarg; defaults to False.
1597   static const char* kwlist[] = {"deterministic", nullptr};
1598   PyObject* deterministic_obj = Py_None;
1599   if (!PyArg_ParseTupleAndKeywords(
1600           args, kwargs, "|O", const_cast<char**>(kwlist), &deterministic_obj)) {
1601     return nullptr;
1602   }
1603   // Preemptively convert to a bool first, so we don't need to back out of
1604   // allocating memory if this raises an exception.
1605   // NOTE: This is unused later if deterministic == Py_None, but that's fine.
1606   int deterministic = PyObject_IsTrue(deterministic_obj);
1607   if (deterministic < 0) {
1608     return nullptr;
1609   }
1610 
1611   if (require_initialized && !self->message->IsInitialized()) {
1612     ScopedPyObjectPtr errors(FindInitializationErrors(self));
1613     if (errors == nullptr) {
1614       return nullptr;
1615     }
1616     ScopedPyObjectPtr comma(PyUnicode_FromString(","));
1617     if (comma == nullptr) {
1618       return nullptr;
1619     }
1620     ScopedPyObjectPtr joined(
1621         PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
1622     if (joined == nullptr) {
1623       return nullptr;
1624     }
1625 
1626     // TODO: this is a (hopefully temporary) hack.  The unit testing
1627     // infrastructure reloads all pure-Python modules for every test, but not
1628     // C++ modules (because that's generally impossible:
1629     // http://bugs.python.org/issue1144263).  But if we cache EncodeError, we'll
1630     // return the EncodeError from a previous load of the module, which won't
1631     // match a user's attempt to catch EncodeError.  So we have to look it up
1632     // again every time.
1633     ScopedPyObjectPtr message_module(
1634         PyImport_ImportModule("google.protobuf.message"));
1635     if (message_module.get() == nullptr) {
1636       return nullptr;
1637     }
1638 
1639     ScopedPyObjectPtr encode_error(
1640         PyObject_GetAttrString(message_module.get(), "EncodeError"));
1641     if (encode_error.get() == nullptr) {
1642       return nullptr;
1643     }
1644     PyErr_Format(encode_error.get(),
1645                  "Message %s is missing required fields: %s",
1646                  GetMessageName(self).c_str(), PyString_AsString(joined.get()));
1647     return nullptr;
1648   }
1649 
1650   // Ok, arguments parsed and errors checked, now encode to a string
1651   const size_t size = self->message->ByteSizeLong();
1652   if (size == 0) {
1653     return PyBytes_FromString("");
1654   }
1655 
1656   if (size > INT_MAX) {
1657     PyErr_Format(PyExc_ValueError,
1658                  "Message %s exceeds maximum protobuf "
1659                  "size of 2GB: %zu",
1660                  GetMessageName(self).c_str(), size);
1661     return nullptr;
1662   }
1663 
1664   PyObject* result = PyBytes_FromStringAndSize(nullptr, size);
1665   if (result == nullptr) {
1666     return nullptr;
1667   }
1668   io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
1669   io::CodedOutputStream coded_out(&out);
1670   if (deterministic_obj != Py_None) {
1671     coded_out.SetSerializationDeterministic(deterministic);
1672   }
1673   self->message->SerializeWithCachedSizes(&coded_out);
1674   ABSL_CHECK(!coded_out.HadError());
1675   return result;
1676 }
1677 
SerializeToString(CMessage * self,PyObject * args,PyObject * kwargs)1678 static PyObject* SerializeToString(CMessage* self, PyObject* args,
1679                                    PyObject* kwargs) {
1680   return InternalSerializeToString(self, args, kwargs,
1681                                    /*require_initialized=*/true);
1682 }
1683 
SerializePartialToString(CMessage * self,PyObject * args,PyObject * kwargs)1684 static PyObject* SerializePartialToString(CMessage* self, PyObject* args,
1685                                           PyObject* kwargs) {
1686   return InternalSerializeToString(self, args, kwargs,
1687                                    /*require_initialized=*/false);
1688 }
1689 
1690 // Formats proto fields for ascii dumps using python formatting functions where
1691 // appropriate.
1692 class PythonFieldValuePrinter : public TextFormat::FastFieldValuePrinter {
1693  public:
1694   // Python has some differences from C++ when printing floating point numbers.
1695   //
1696   // 1) Trailing .0 is always printed.
1697   // 2) (Python2) Output is rounded to 12 digits.
1698   // 3) (Python3) The full precision of the double is preserved (and Python uses
1699   //    David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
1700   //    differences, but they rarely happen)
1701   //
1702   // We override floating point printing with the C-API function for printing
1703   // Python floats to ensure consistency.
PrintFloat(float val,TextFormat::BaseTextGenerator * generator) const1704   void PrintFloat(float val,
1705                   TextFormat::BaseTextGenerator* generator) const override {
1706     PrintDouble(val, generator);
1707   }
PrintDouble(double val,TextFormat::BaseTextGenerator * generator) const1708   void PrintDouble(double val,
1709                    TextFormat::BaseTextGenerator* generator) const override {
1710     // This implementation is not highly optimized (it allocates two temporary
1711     // Python objects) but it is simple and portable.  If this is shown to be a
1712     // performance bottleneck, we can optimize it, but the results will likely
1713     // be more complicated to accommodate the differing behavior of double
1714     // formatting between Python 2 and Python 3.
1715     //
1716     // (Though a valid question is: do we really want to make out output
1717     // dependent on the Python version?)
1718     ScopedPyObjectPtr py_value(PyFloat_FromDouble(val));
1719     if (!py_value.get()) {
1720       return;
1721     }
1722 
1723     ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
1724     if (!py_str.get()) {
1725       return;
1726     }
1727 
1728     generator->PrintString(PyString_AsString(py_str.get()));
1729   }
PrintString(const std::string & val,TextFormat::BaseTextGenerator * generator) const1730   void PrintString(const std::string& val,
1731                    TextFormat::BaseTextGenerator* generator) const override {
1732     TextFormat::Printer::HardenedPrintString(val, generator);
1733   }
PrintBytes(const std::string & val,TextFormat::BaseTextGenerator * generator) const1734   void PrintBytes(const std::string& val,
1735                   TextFormat::BaseTextGenerator* generator) const override {
1736     generator->PrintLiteral("\"");
1737     if (!val.empty()) {
1738       generator->PrintString(absl::CEscape(val));
1739     }
1740     generator->PrintLiteral("\"");
1741   }
1742 };
1743 
ToStr(CMessage * self)1744 static PyObject* ToStr(CMessage* self) {
1745   TextFormat::Printer printer;
1746   // Passes ownership
1747   printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
1748   printer.SetHideUnknownFields(true);
1749   std::string output;
1750   if (!printer.PrintToString(*self->message, &output)) {
1751     PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
1752     return nullptr;
1753   }
1754   return PyUnicode_FromString(output.c_str());
1755 }
1756 
MergeFrom(CMessage * self,PyObject * arg)1757 PyObject* MergeFrom(CMessage* self, PyObject* arg) {
1758   CMessage* other_message;
1759   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1760     PyErr_Format(
1761         PyExc_TypeError,
1762         "Parameter to MergeFrom() must be instance of same class: "
1763         "expected %s got %s.",
1764         std::string(self->message->GetDescriptor()->full_name()).c_str(),
1765         Py_TYPE(arg)->tp_name);
1766     return nullptr;
1767   }
1768 
1769   other_message = reinterpret_cast<CMessage*>(arg);
1770   if (other_message->message->GetDescriptor() !=
1771       self->message->GetDescriptor()) {
1772     PyErr_Format(
1773         PyExc_TypeError,
1774         "Parameter to MergeFrom() must be instance of same class: "
1775         "expected %s got %s.",
1776         std::string(self->message->GetDescriptor()->full_name()).c_str(),
1777         std::string(other_message->message->GetDescriptor()->full_name())
1778             .c_str());
1779     return nullptr;
1780   }
1781   AssureWritable(self);
1782 
1783   self->message->MergeFrom(*other_message->message);
1784   // Child message might be lazily created before MergeFrom. Make sure they
1785   // are mutable at this point if child messages are really created.
1786   if (FixupMessageAfterMerge(self) < 0) {
1787     return nullptr;
1788   }
1789 
1790   Py_RETURN_NONE;
1791 }
1792 
CopyFrom(CMessage * self,PyObject * arg)1793 static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
1794   CMessage* other_message;
1795   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1796     PyErr_Format(
1797         PyExc_TypeError,
1798         "Parameter to CopyFrom() must be instance of same class: "
1799         "expected %s got %s.",
1800         std::string(self->message->GetDescriptor()->full_name()).c_str(),
1801         Py_TYPE(arg)->tp_name);
1802     return nullptr;
1803   }
1804 
1805   other_message = reinterpret_cast<CMessage*>(arg);
1806 
1807   if (self == other_message) {
1808     Py_RETURN_NONE;
1809   }
1810 
1811   if (other_message->message->GetDescriptor() !=
1812       self->message->GetDescriptor()) {
1813     PyErr_Format(
1814         PyExc_TypeError,
1815         "Parameter to CopyFrom() must be instance of same class: "
1816         "expected %s got %s.",
1817         std::string(self->message->GetDescriptor()->full_name()).c_str(),
1818         std::string(other_message->message->GetDescriptor()->full_name())
1819             .c_str());
1820     return nullptr;
1821   }
1822 
1823   AssureWritable(self);
1824 
1825   // CopyFrom on the message will not clean up self->composite_fields,
1826   // which can leave us in an inconsistent state, so clear it out here.
1827   (void)ScopedPyObjectPtr(Clear(self));
1828 
1829   self->message->CopyFrom(*other_message->message);
1830 
1831   Py_RETURN_NONE;
1832 }
1833 
1834 // Provide a method in the module to set allow_oversize_protos to a boolean
1835 // value. This method returns the newly value of allow_oversize_protos.
SetAllowOversizeProtos(PyObject * m,PyObject * arg)1836 PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
1837   if (!arg || !PyBool_Check(arg)) {
1838     PyErr_SetString(PyExc_TypeError,
1839                     "Argument to SetAllowOversizeProtos must be boolean");
1840     return nullptr;
1841   }
1842   allow_oversize_protos = PyObject_IsTrue(arg);
1843   if (allow_oversize_protos) {
1844     Py_RETURN_TRUE;
1845   } else {
1846     Py_RETURN_FALSE;
1847   }
1848 }
1849 
MergeFromString(CMessage * self,PyObject * arg)1850 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
1851   Py_buffer data;
1852   if (PyObject_GetBuffer(arg, &data, PyBUF_SIMPLE) < 0) {
1853     return nullptr;
1854   }
1855 
1856   AssureWritable(self);
1857 
1858   PyMessageFactory* factory = GetFactoryForMessage(self);
1859   int depth = allow_oversize_protos
1860                   ? INT_MAX
1861                   : io::CodedInputStream::GetDefaultRecursionLimit();
1862   const char* ptr;
1863   internal::ParseContext ctx(
1864       depth, false, &ptr,
1865       absl::string_view(static_cast<const char*>(data.buf), data.len));
1866   PyBuffer_Release(&data);
1867   ctx.data().pool = factory->pool->pool;
1868   ctx.data().factory = factory->message_factory;
1869 
1870   ptr = self->message->_InternalParse(ptr, &ctx);
1871 
1872   // Child message might be lazily created before MergeFrom. Make sure they
1873   // are mutable at this point if child messages are really created.
1874   if (FixupMessageAfterMerge(self) < 0) {
1875     return nullptr;
1876   }
1877 
1878   // Python makes distinction in error message, between a general parse failure
1879   // and in-correct ending on a terminating tag. Hence we need to be a bit more
1880   // explicit in our correctness checks.
1881   if (ptr == nullptr) {
1882     // Parse error.
1883     PyErr_Format(
1884         DecodeError_class, "Error parsing message with type '%s'",
1885         std::string(self->GetMessageClass()->message_descriptor->full_name())
1886             .c_str());
1887     return nullptr;
1888   }
1889   if (ctx.BytesUntilLimit(ptr) < 0) {
1890     // The parser overshot the limit.
1891     PyErr_Format(
1892         DecodeError_class,
1893         "Error parsing message as the message exceeded the protobuf limit "
1894         "with type '%s'",
1895         std::string(self->GetMessageClass()->message_descriptor->full_name())
1896             .c_str());
1897     return nullptr;
1898   }
1899 
1900   // ctx has an explicit limit set (length of string_view), so we have to
1901   // check we ended at that limit.
1902   if (!ctx.EndedAtLimit()) {
1903     // TODO: Raise error and return NULL instead.
1904     // b/27494216
1905     PyErr_Warn(nullptr, "Unexpected end-group tag: Not all data was converted");
1906     return PyLong_FromLong(data.len - ctx.BytesUntilLimit(ptr));
1907   }
1908   return PyLong_FromLong(data.len);
1909 }
1910 
ParseFromString(CMessage * self,PyObject * arg)1911 static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
1912   if (ScopedPyObjectPtr(Clear(self)) == nullptr) {
1913     return nullptr;
1914   }
1915   return MergeFromString(self, arg);
1916 }
1917 
ByteSize(CMessage * self,PyObject * args)1918 static PyObject* ByteSize(CMessage* self, PyObject* args) {
1919   return PyLong_FromLong(self->message->ByteSizeLong());
1920 }
1921 
SetInParent(CMessage * self,PyObject * args)1922 static PyObject* SetInParent(CMessage* self, PyObject* args) {
1923   AssureWritable(self);
1924   Py_RETURN_NONE;
1925 }
1926 
WhichOneof(CMessage * self,PyObject * arg)1927 static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
1928   Py_ssize_t name_size;
1929   char* name_data;
1930   if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0) return nullptr;
1931   const OneofDescriptor* oneof_desc =
1932       self->message->GetDescriptor()->FindOneofByName(
1933           absl::string_view(name_data, name_size));
1934   if (oneof_desc == nullptr) {
1935     PyErr_Format(PyExc_ValueError,
1936                  "Protocol message has no oneof \"%s\" field.", name_data);
1937     return nullptr;
1938   }
1939   const FieldDescriptor* field_in_oneof =
1940       self->message->GetReflection()->GetOneofFieldDescriptor(*self->message,
1941                                                               oneof_desc);
1942   if (field_in_oneof == nullptr) {
1943     Py_RETURN_NONE;
1944   } else {
1945     const absl::string_view name = field_in_oneof->name();
1946     return PyUnicode_FromStringAndSize(name.data(), name.size());
1947   }
1948 }
1949 
1950 static PyObject* GetExtensionDict(CMessage* self, void* closure);
1951 
ListFields(CMessage * self)1952 static PyObject* ListFields(CMessage* self) {
1953   std::vector<const FieldDescriptor*> fields;
1954   self->message->GetReflection()->ListFields(*self->message, &fields);
1955 
1956   // Normally, the list will be exactly the size of the fields.
1957   ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
1958   if (all_fields == nullptr) {
1959     return nullptr;
1960   }
1961 
1962   // When there are unknown extensions, the py list will *not* contain
1963   // the field information.  Thus the actual size of the py list will be
1964   // smaller than the size of fields.  Set the actual size at the end.
1965   Py_ssize_t actual_size = 0;
1966   for (size_t i = 0; i < fields.size(); ++i) {
1967     ScopedPyObjectPtr t(PyTuple_New(2));
1968     if (t == nullptr) {
1969       return nullptr;
1970     }
1971 
1972     if (fields[i]->is_extension()) {
1973       ScopedPyObjectPtr extension_field(
1974           PyFieldDescriptor_FromDescriptor(fields[i]));
1975       if (extension_field == nullptr) {
1976         return nullptr;
1977       }
1978       // With C++ descriptors, the field can always be retrieved, but for
1979       // unknown extensions which have not been imported in Python code, there
1980       // is no message class and we cannot retrieve the value.
1981       // TODO: consider building the class on the fly!
1982       if (fields[i]->message_type() != nullptr &&
1983           message_factory::GetMessageClass(GetFactoryForMessage(self),
1984                                            fields[i]->message_type()) ==
1985               nullptr) {
1986         PyErr_Clear();
1987         continue;
1988       }
1989       ScopedPyObjectPtr extensions(GetExtensionDict(self, nullptr));
1990       if (extensions == nullptr) {
1991         return nullptr;
1992       }
1993       // 'extension' reference later stolen by PyTuple_SET_ITEM.
1994       PyObject* extension =
1995           PyObject_GetItem(extensions.get(), extension_field.get());
1996       if (extension == nullptr) {
1997         return nullptr;
1998       }
1999       PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
2000       // Steals reference to 'extension'
2001       PyTuple_SET_ITEM(t.get(), 1, extension);
2002     } else {
2003       // Normal field
2004       ScopedPyObjectPtr field_descriptor(
2005           PyFieldDescriptor_FromDescriptor(fields[i]));
2006       if (field_descriptor == nullptr) {
2007         return nullptr;
2008       }
2009 
2010       PyObject* field_value = GetFieldValue(self, fields[i]);
2011       if (field_value == nullptr) {
2012         PyErr_SetString(PyExc_ValueError,
2013                         std::string(fields[i]->name()).c_str());
2014         return nullptr;
2015       }
2016       PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
2017       PyTuple_SET_ITEM(t.get(), 1, field_value);
2018     }
2019     PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
2020     ++actual_size;
2021   }
2022   if (static_cast<size_t>(actual_size) != fields.size() &&
2023       (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), nullptr) <
2024        0)) {
2025     return nullptr;
2026   }
2027   return all_fields.release();
2028 }
2029 
DiscardUnknownFields(CMessage * self)2030 static PyObject* DiscardUnknownFields(CMessage* self) {
2031   AssureWritable(self);
2032   self->message->DiscardUnknownFields();
2033   Py_RETURN_NONE;
2034 }
2035 
FindInitializationErrors(CMessage * self)2036 PyObject* FindInitializationErrors(CMessage* self) {
2037   Message* message = self->message;
2038   std::vector<std::string> errors;
2039   message->FindInitializationErrors(&errors);
2040 
2041   PyObject* error_list = PyList_New(errors.size());
2042   if (error_list == nullptr) {
2043     return nullptr;
2044   }
2045   for (size_t i = 0; i < errors.size(); ++i) {
2046     const std::string& error = errors[i];
2047     PyObject* error_string =
2048         PyUnicode_FromStringAndSize(error.c_str(), error.length());
2049     if (error_string == nullptr) {
2050       Py_DECREF(error_list);
2051       return nullptr;
2052     }
2053     PyList_SET_ITEM(error_list, i, error_string);
2054   }
2055   return error_list;
2056 }
2057 
RichCompare(CMessage * self,PyObject * other,int opid)2058 static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
2059   // Only equality comparisons are implemented.
2060   if (opid != Py_EQ && opid != Py_NE) {
2061     Py_INCREF(Py_NotImplemented);
2062     return Py_NotImplemented;
2063   }
2064 
2065   const Descriptor* self_descriptor = self->message->GetDescriptor();
2066   Descriptor::WellKnownType wkt = self_descriptor->well_known_type();
2067   if ((wkt == Descriptor::WELLKNOWNTYPE_LISTVALUE && PyList_Check(other)) ||
2068       (wkt == Descriptor::WELLKNOWNTYPE_STRUCT && PyDict_Check(other))) {
2069     return PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
2070                                "_internal_compare", "O", other);
2071   }
2072 
2073   // If other is not a message, this implementation doesn't know how to perform
2074   // comparisons.
2075   if (!PyObject_TypeCheck(other, CMessage_Type)) {
2076     Py_INCREF(Py_NotImplemented);
2077     return Py_NotImplemented;
2078   }
2079   // Otherwise, we have a CMessage whose message we can inspect.
2080   bool equals = true;
2081   const google::protobuf::Message* other_message =
2082       reinterpret_cast<CMessage*>(other)->message;
2083   // If messages don't have the same descriptors, they are not equal.
2084   if (equals && self_descriptor != other_message->GetDescriptor()) {
2085     equals = false;
2086   }
2087   // Check the message contents.
2088   if (equals &&
2089       !google::protobuf::util::MessageDifferencer::Equals(
2090           *self->message, *reinterpret_cast<CMessage*>(other)->message)) {
2091     equals = false;
2092   }
2093 
2094   if (equals ^ (opid == Py_EQ)) {
2095     Py_RETURN_FALSE;
2096   } else {
2097     Py_RETURN_TRUE;
2098   }
2099 }
2100 
InternalGetScalar(const Message * message,const FieldDescriptor * field_descriptor)2101 PyObject* InternalGetScalar(const Message* message,
2102                             const FieldDescriptor* field_descriptor) {
2103   const Reflection* reflection = message->GetReflection();
2104 
2105   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2106     return nullptr;
2107   }
2108 
2109   PyObject* result = nullptr;
2110   switch (field_descriptor->cpp_type()) {
2111     case FieldDescriptor::CPPTYPE_INT32: {
2112       int32_t value = reflection->GetInt32(*message, field_descriptor);
2113       result = PyLong_FromLong(value);
2114       break;
2115     }
2116     case FieldDescriptor::CPPTYPE_INT64: {
2117       int64_t value = reflection->GetInt64(*message, field_descriptor);
2118       result = PyLong_FromLongLong(value);
2119       break;
2120     }
2121     case FieldDescriptor::CPPTYPE_UINT32: {
2122       uint32_t value = reflection->GetUInt32(*message, field_descriptor);
2123       result = PyLong_FromSsize_t(value);
2124       break;
2125     }
2126     case FieldDescriptor::CPPTYPE_UINT64: {
2127       uint64_t value = reflection->GetUInt64(*message, field_descriptor);
2128       result = PyLong_FromUnsignedLongLong(value);
2129       break;
2130     }
2131     case FieldDescriptor::CPPTYPE_FLOAT: {
2132       float value = reflection->GetFloat(*message, field_descriptor);
2133       result = PyFloat_FromDouble(value);
2134       break;
2135     }
2136     case FieldDescriptor::CPPTYPE_DOUBLE: {
2137       double value = reflection->GetDouble(*message, field_descriptor);
2138       result = PyFloat_FromDouble(value);
2139       break;
2140     }
2141     case FieldDescriptor::CPPTYPE_BOOL: {
2142       bool value = reflection->GetBool(*message, field_descriptor);
2143       result = PyBool_FromLong(value);
2144       break;
2145     }
2146     case FieldDescriptor::CPPTYPE_STRING: {
2147       std::string scratch;
2148       const std::string& value =
2149           reflection->GetStringReference(*message, field_descriptor, &scratch);
2150       result = ToStringObject(field_descriptor, value);
2151       break;
2152     }
2153     case FieldDescriptor::CPPTYPE_ENUM: {
2154       const EnumValueDescriptor* enum_value =
2155           message->GetReflection()->GetEnum(*message, field_descriptor);
2156       result = PyLong_FromLong(enum_value->number());
2157       break;
2158     }
2159     default:
2160       PyErr_Format(PyExc_SystemError,
2161                    "Getting a value from a field of unknown type %d",
2162                    field_descriptor->cpp_type());
2163   }
2164 
2165   return result;
2166 }
2167 
InternalGetSubMessage(CMessage * self,const FieldDescriptor * field_descriptor)2168 CMessage* InternalGetSubMessage(CMessage* self,
2169                                 const FieldDescriptor* field_descriptor) {
2170   const Reflection* reflection = self->message->GetReflection();
2171   PyMessageFactory* factory = GetFactoryForMessage(self);
2172 
2173   CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2174       factory, field_descriptor->message_type());
2175   ScopedPyObjectPtr message_class_owner(
2176       reinterpret_cast<PyObject*>(message_class));
2177   if (message_class == nullptr) {
2178     return nullptr;
2179   }
2180 
2181   CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2182   if (cmsg == nullptr) {
2183     return nullptr;
2184   }
2185 
2186   Py_INCREF(self);
2187   cmsg->parent = self;
2188   cmsg->parent_field_descriptor = field_descriptor;
2189   if (reflection->HasField(*self->message, field_descriptor)) {
2190     // Force triggering MutableMessage to set the lazy message 'Dirty'
2191     if (MessageReflectionFriend::IsLazyField(reflection, *self->message,
2192                                              field_descriptor)) {
2193       Message* sub_message = reflection->MutableMessage(
2194           self->message, field_descriptor, factory->message_factory);
2195       cmsg->read_only = false;
2196       cmsg->message = sub_message;
2197       return cmsg;
2198     }
2199   } else {
2200     cmsg->read_only = true;
2201   }
2202   const Message& sub_message = reflection->GetMessage(
2203       *self->message, field_descriptor, factory->message_factory);
2204   cmsg->message = const_cast<Message*>(&sub_message);
2205   return cmsg;
2206 }
2207 
InternalSetNonOneofScalar(Message * message,const FieldDescriptor * field_descriptor,PyObject * arg)2208 int InternalSetNonOneofScalar(Message* message,
2209                               const FieldDescriptor* field_descriptor,
2210                               PyObject* arg) {
2211   const Reflection* reflection = message->GetReflection();
2212 
2213   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2214     return -1;
2215   }
2216 
2217   switch (field_descriptor->cpp_type()) {
2218     case FieldDescriptor::CPPTYPE_INT32: {
2219       PROTOBUF_CHECK_GET_INT32(arg, value, -1);
2220       reflection->SetInt32(message, field_descriptor, value);
2221       break;
2222     }
2223     case FieldDescriptor::CPPTYPE_INT64: {
2224       PROTOBUF_CHECK_GET_INT64(arg, value, -1);
2225       reflection->SetInt64(message, field_descriptor, value);
2226       break;
2227     }
2228     case FieldDescriptor::CPPTYPE_UINT32: {
2229       PROTOBUF_CHECK_GET_UINT32(arg, value, -1);
2230       reflection->SetUInt32(message, field_descriptor, value);
2231       break;
2232     }
2233     case FieldDescriptor::CPPTYPE_UINT64: {
2234       PROTOBUF_CHECK_GET_UINT64(arg, value, -1);
2235       reflection->SetUInt64(message, field_descriptor, value);
2236       break;
2237     }
2238     case FieldDescriptor::CPPTYPE_FLOAT: {
2239       PROTOBUF_CHECK_GET_FLOAT(arg, value, -1);
2240       reflection->SetFloat(message, field_descriptor, value);
2241       break;
2242     }
2243     case FieldDescriptor::CPPTYPE_DOUBLE: {
2244       PROTOBUF_CHECK_GET_DOUBLE(arg, value, -1);
2245       reflection->SetDouble(message, field_descriptor, value);
2246       break;
2247     }
2248     case FieldDescriptor::CPPTYPE_BOOL: {
2249       PROTOBUF_CHECK_GET_BOOL(arg, value, -1);
2250       reflection->SetBool(message, field_descriptor, value);
2251       break;
2252     }
2253     case FieldDescriptor::CPPTYPE_STRING: {
2254       if (!CheckAndSetString(arg, message, field_descriptor, reflection, false,
2255                              -1)) {
2256         return -1;
2257       }
2258       break;
2259     }
2260     case FieldDescriptor::CPPTYPE_ENUM: {
2261       PROTOBUF_CHECK_GET_INT32(arg, value, -1);
2262       if (!field_descriptor->legacy_enum_field_treated_as_closed()) {
2263         reflection->SetEnumValue(message, field_descriptor, value);
2264       } else {
2265         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
2266         const EnumValueDescriptor* enum_value =
2267             enum_descriptor->FindValueByNumber(value);
2268         if (enum_value != nullptr) {
2269           reflection->SetEnum(message, field_descriptor, enum_value);
2270         } else {
2271           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
2272           return -1;
2273         }
2274       }
2275       break;
2276     }
2277     default:
2278       PyErr_Format(PyExc_SystemError,
2279                    "Setting value to a field of unknown type %d",
2280                    field_descriptor->cpp_type());
2281       return -1;
2282   }
2283 
2284   return 0;
2285 }
2286 
InternalSetScalar(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * arg)2287 int InternalSetScalar(CMessage* self, const FieldDescriptor* field_descriptor,
2288                       PyObject* arg) {
2289   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
2290     return -1;
2291   }
2292 
2293   if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
2294     return -1;
2295   }
2296 
2297   return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
2298 }
2299 
FromString(PyTypeObject * cls,PyObject * serialized)2300 PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
2301   PyObject* py_cmsg =
2302       PyObject_CallObject(reinterpret_cast<PyObject*>(cls), nullptr);
2303   if (py_cmsg == nullptr) {
2304     return nullptr;
2305   }
2306   CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
2307 
2308   ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
2309   if (py_length == nullptr) {
2310     Py_DECREF(py_cmsg);
2311     return nullptr;
2312   }
2313 
2314   return py_cmsg;
2315 }
2316 
DeepCopy(CMessage * self,PyObject * arg)2317 PyObject* DeepCopy(CMessage* self, PyObject* arg) {
2318   PyObject* clone =
2319       PyObject_CallObject(reinterpret_cast<PyObject*>(Py_TYPE(self)), nullptr);
2320   if (clone == nullptr) {
2321     return nullptr;
2322   }
2323   if (!PyObject_TypeCheck(clone, CMessage_Type)) {
2324     Py_DECREF(clone);
2325     return nullptr;
2326   }
2327   if (ScopedPyObjectPtr(MergeFrom(reinterpret_cast<CMessage*>(clone),
2328                                   reinterpret_cast<PyObject*>(self))) ==
2329       nullptr) {
2330     Py_DECREF(clone);
2331     return nullptr;
2332   }
2333   return clone;
2334 }
2335 
ToUnicode(CMessage * self)2336 PyObject* ToUnicode(CMessage* self) {
2337   // Lazy import to prevent circular dependencies
2338   ScopedPyObjectPtr text_format(
2339       PyImport_ImportModule(PROTOBUF_PYTHON_PUBLIC ".text_format"));
2340   if (text_format == nullptr) {
2341     return nullptr;
2342   }
2343   ScopedPyObjectPtr method_name(PyUnicode_FromString("MessageToString"));
2344   if (method_name == nullptr) {
2345     return nullptr;
2346   }
2347   Py_INCREF(Py_True);
2348   ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
2349       text_format.get(), method_name.get(), self, Py_True, nullptr));
2350   Py_DECREF(Py_True);
2351   if (encoded == nullptr) {
2352     return nullptr;
2353   }
2354   PyObject* decoded =
2355       PyUnicode_FromEncodedObject(encoded.get(), "utf-8", nullptr);
2356   if (decoded == nullptr) {
2357     return nullptr;
2358   }
2359   return decoded;
2360 }
2361 
Contains(CMessage * self,PyObject * arg)2362 PyObject* Contains(CMessage* self, PyObject* arg) {
2363   Message* message = self->message;
2364   const Descriptor* descriptor = message->GetDescriptor();
2365   switch (descriptor->well_known_type()) {
2366     case Descriptor::WELLKNOWNTYPE_STRUCT: {
2367       // For WKT Struct, check if the key is in the fields.
2368       const Reflection* reflection = message->GetReflection();
2369       const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
2370       const FieldDescriptor* key_field = map_field->message_type()->map_key();
2371       PyObject* py_string = CheckString(arg, key_field);
2372       if (!py_string) {
2373         PyErr_SetString(PyExc_TypeError,
2374                         "The key passed to Struct message must be a str.");
2375         return nullptr;
2376       }
2377       char* value;
2378       Py_ssize_t value_len;
2379       if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
2380         Py_DECREF(py_string);
2381         Py_RETURN_FALSE;
2382       }
2383       std::string key_str;
2384       key_str.assign(value, value_len);
2385       Py_DECREF(py_string);
2386 
2387       MapKey map_key;
2388       map_key.SetStringValue(key_str);
2389       return PyBool_FromLong(MessageReflectionFriend::ContainsMapKey(
2390           reflection, *message, map_field, map_key));
2391     }
2392     case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
2393       // For WKT ListValue, check if the key is in the items.
2394       PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
2395                                             "items", nullptr);
2396       return PyBool_FromLong(PySequence_Contains(items, arg));
2397     }
2398     default:
2399       // For other messages, check with HasField.
2400       return HasField(self, arg);
2401   }
2402 }
2403 
2404 // CMessage static methods:
_CheckCalledFromGeneratedFile(PyObject * unused,PyObject * unused_arg)2405 PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
2406                                         PyObject* unused_arg) {
2407   if (!_CalledFromGeneratedFile(1)) {
2408     PyErr_SetString(PyExc_TypeError,
2409                     "Descriptors should not be created directly, "
2410                     "but only retrieved from their parent.");
2411     return nullptr;
2412   }
2413   Py_RETURN_NONE;
2414 }
2415 
GetExtensionDict(CMessage * self,void * closure)2416 static PyObject* GetExtensionDict(CMessage* self, void* closure) {
2417   // If there are extension_ranges, the message is "extendable". Allocate a
2418   // dictionary to store the extension fields.
2419   const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
2420   if (!descriptor->extension_range_count()) {
2421     PyErr_SetNone(PyExc_AttributeError);
2422     return nullptr;
2423   }
2424   if (!self->composite_fields) {
2425     self->composite_fields = new CMessage::CompositeFieldsMap();
2426   }
2427   if (!self->composite_fields) {
2428     return nullptr;
2429   }
2430   ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
2431   return reinterpret_cast<PyObject*>(extension_dict);
2432 }
2433 
GetUnknownFields(CMessage * self)2434 static PyObject* GetUnknownFields(CMessage* self) {
2435   PyErr_Format(PyExc_NotImplementedError,
2436                "Please use the add-on feature "
2437                "unknown_fields.UnknownFieldSet(message) in "
2438                "unknown_fields.py instead.");
2439   return nullptr;
2440 }
2441 
2442 static PyGetSetDef Getters[] = {
2443     {"Extensions", (getter)GetExtensionDict, nullptr, "Extension dict"},
2444     {nullptr},
2445 };
2446 
2447 static PyMethodDef Methods[] = {
2448     {"__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
2449      "Makes a deep copy of the class."},
2450     {"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
2451      "Outputs a unicode representation of the message."},
2452     {"__contains__", (PyCFunction)Contains, METH_O,
2453      "Checks if a message field is set."},
2454     {"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
2455      "Returns the size of the message in bytes."},
2456     {"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},
2457     {"ClearExtension", (PyCFunction)ClearExtension, METH_O,
2458      "Clears a message field."},
2459     {"ClearField", (PyCFunction)ClearField, METH_O, "Clears a message field."},
2460     {"CopyFrom", (PyCFunction)CopyFrom, METH_O,
2461      "Copies a protocol message into the current message."},
2462     {"DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
2463      "Discards the unknown fields."},
2464     {"FindInitializationErrors", (PyCFunction)FindInitializationErrors,
2465      METH_NOARGS, "Finds unset required fields."},
2466     {"FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
2467      "Creates new method instance from given serialized data."},
2468     {"HasExtension", (PyCFunction)HasExtension, METH_O,
2469      "Checks if a message field is set."},
2470     {"HasField", (PyCFunction)HasField, METH_O,
2471      "Checks if a message field is set."},
2472     {"IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
2473      "Checks if all required fields of a protocol message are set."},
2474     {"ListFields", (PyCFunction)ListFields, METH_NOARGS,
2475      "Lists all set fields of a message."},
2476     {"MergeFrom", (PyCFunction)MergeFrom, METH_O,
2477      "Merges a protocol message into the current message."},
2478     {"MergeFromString", (PyCFunction)MergeFromString, METH_O,
2479      "Merges a serialized message into the current message."},
2480     {"ParseFromString", (PyCFunction)ParseFromString, METH_O,
2481      "Parses a serialized message into the current message."},
2482     {"SerializePartialToString", (PyCFunction)SerializePartialToString,
2483      METH_VARARGS | METH_KEYWORDS,
2484      "Serializes the message to a string, even if it isn't initialized."},
2485     {"SerializeToString", (PyCFunction)SerializeToString,
2486      METH_VARARGS | METH_KEYWORDS,
2487      "Serializes the message to a string, only for initialized messages."},
2488     {"SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
2489      "Sets the has bit of the given field in its parent message."},
2490     {"UnknownFields", (PyCFunction)GetUnknownFields, METH_NOARGS,
2491      "Parse unknown field set"},
2492     {"WhichOneof", (PyCFunction)WhichOneof, METH_O,
2493      "Returns the name of the field set inside a oneof, "
2494      "or None if no field is set."},
2495 
2496     // Static Methods.
2497     {"_CheckCalledFromGeneratedFile",
2498      (PyCFunction)_CheckCalledFromGeneratedFile, METH_NOARGS | METH_STATIC,
2499      "Raises TypeError if the caller is not in a _pb2.py file."},
2500     {nullptr, nullptr}};
2501 
SetCompositeField(CMessage * self,const FieldDescriptor * field,ContainerBase * value)2502 bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
2503                        ContainerBase* value) {
2504   if (self->composite_fields == nullptr) {
2505     self->composite_fields = new CMessage::CompositeFieldsMap();
2506   }
2507   (*self->composite_fields)[field] = value;
2508   return true;
2509 }
2510 
SetSubmessage(CMessage * self,CMessage * submessage)2511 bool SetSubmessage(CMessage* self, CMessage* submessage) {
2512   if (self->child_submessages == nullptr) {
2513     self->child_submessages = new CMessage::SubMessagesMap();
2514   }
2515   (*self->child_submessages)[submessage->message] = submessage;
2516   return true;
2517 }
2518 
GetAttr(PyObject * pself,PyObject * name)2519 PyObject* GetAttr(PyObject* pself, PyObject* name) {
2520   CMessage* self = reinterpret_cast<CMessage*>(pself);
2521   PyObject* result =
2522       PyObject_GenericGetAttr(reinterpret_cast<PyObject*>(self), name);
2523   if (result != nullptr) {
2524     return result;
2525   }
2526   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
2527     return nullptr;
2528   }
2529 
2530   PyErr_Clear();
2531   return message_meta::GetClassAttribute(CheckMessageClass(Py_TYPE(self)),
2532                                          name);
2533 }
2534 
GetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor)2535 PyObject* GetFieldValue(CMessage* self,
2536                         const FieldDescriptor* field_descriptor) {
2537   if (self->composite_fields) {
2538     CMessage::CompositeFieldsMap::iterator it =
2539         self->composite_fields->find(field_descriptor);
2540     if (it != self->composite_fields->end()) {
2541       ContainerBase* value = it->second;
2542       Py_INCREF(value);
2543       return value->AsPyObject();
2544     }
2545   }
2546 
2547   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2548     PyErr_Format(PyExc_TypeError,
2549                  "descriptor to field '%s' doesn't apply to '%s' object",
2550                  std::string(field_descriptor->full_name()).c_str(),
2551                  Py_TYPE(self)->tp_name);
2552     return nullptr;
2553   }
2554 
2555   if (!field_descriptor->is_repeated() &&
2556       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
2557     return InternalGetScalar(self->message, field_descriptor);
2558   }
2559 
2560   ContainerBase* py_container = nullptr;
2561   if (field_descriptor->is_map()) {
2562     const Descriptor* entry_type = field_descriptor->message_type();
2563     const FieldDescriptor* value_type = entry_type->map_value();
2564     if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2565       CMessageClass* value_class = message_factory::GetMessageClass(
2566           GetFactoryForMessage(self), value_type->message_type());
2567       if (value_class == nullptr) {
2568         return nullptr;
2569       }
2570       py_container =
2571           NewMessageMapContainer(self, field_descriptor, value_class);
2572     } else {
2573       py_container = NewScalarMapContainer(self, field_descriptor);
2574     }
2575   } else if (field_descriptor->is_repeated()) {
2576     if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2577       CMessageClass* message_class = message_factory::GetMessageClass(
2578           GetFactoryForMessage(self), field_descriptor->message_type());
2579       if (message_class == nullptr) {
2580         return nullptr;
2581       }
2582       py_container = repeated_composite_container::NewContainer(
2583           self, field_descriptor, message_class);
2584     } else {
2585       py_container =
2586           repeated_scalar_container::NewContainer(self, field_descriptor);
2587     }
2588   } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2589     py_container = InternalGetSubMessage(self, field_descriptor);
2590   } else {
2591     PyErr_SetString(PyExc_SystemError, "Should never happen");
2592   }
2593 
2594   if (py_container == nullptr) {
2595     return nullptr;
2596   }
2597   if (!SetCompositeField(self, field_descriptor, py_container)) {
2598     Py_DECREF(py_container);
2599     return nullptr;
2600   }
2601   return py_container->AsPyObject();
2602 }
2603 
SetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * value)2604 int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
2605                   PyObject* value) {
2606   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2607     PyErr_Format(PyExc_TypeError,
2608                  "descriptor to field '%s' doesn't apply to '%s' object",
2609                  std::string(field_descriptor->full_name()).c_str(),
2610                  Py_TYPE(self)->tp_name);
2611     return -1;
2612   } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
2613     PyErr_Format(PyExc_AttributeError,
2614                  "Assignment not allowed to repeated "
2615                  "field \"%s\" in protocol message object.",
2616                  std::string(field_descriptor->name()).c_str());
2617     return -1;
2618   } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2619     if (field_descriptor->message_type()->well_known_type() !=
2620         Descriptor::WELLKNOWNTYPE_UNSPECIFIED) {
2621       PyObject* sub_message = GetFieldValue(self, field_descriptor);
2622       if (PyObject_HasAttrString(sub_message, "_internal_assign")) {
2623         AssureWritable(self);
2624         ScopedPyObjectPtr ok(
2625             PyObject_CallMethod(sub_message, "_internal_assign", "O", value));
2626         if (ok.get() == nullptr) {
2627           return -1;
2628         }
2629         return 0;
2630       }
2631     }
2632     PyErr_Format(PyExc_AttributeError,
2633                  "Assignment not allowed to "
2634                  "field \"%s\" in protocol message object.",
2635                  std::string(field_descriptor->name()).c_str());
2636     return -1;
2637   } else {
2638     AssureWritable(self);
2639     return InternalSetScalar(self, field_descriptor, value);
2640   }
2641 }
2642 
2643 }  // namespace cmessage
2644 
2645 // All containers which are not messages:
2646 // - Make a new parent message
2647 // - Copy the field
2648 // - return the field.
DeepCopy()2649 PyObject* ContainerBase::DeepCopy() {
2650   CMessage* new_parent =
2651       cmessage::NewEmptyMessage(this->parent->GetMessageClass());
2652   new_parent->message = this->parent->message->New(nullptr);
2653 
2654   // There is no API to copy a single field. The closest operation we have is
2655   // SwapFields.
2656   // So, we copy the source into a disposable message and then swap the one
2657   // field we care about from it.
2658   // If the performance of this operation matters we can do a copy of the single
2659   // field, but that would require huge switches for each type+cardinality to
2660   // call the right read/write field functions.
2661   std::unique_ptr<Message> tmp(this->parent->message->New(nullptr));
2662   tmp->MergeFrom(*this->parent->message);
2663   tmp->GetReflection()->SwapFields(tmp.get(), new_parent->message,
2664                                    {this->parent_field_descriptor});
2665 
2666   PyObject* result =
2667       cmessage::GetFieldValue(new_parent, this->parent_field_descriptor);
2668   Py_DECREF(new_parent);
2669   return result;
2670 }
2671 
RemoveFromParentCache()2672 void ContainerBase::RemoveFromParentCache() {
2673   CMessage* parent = this->parent;
2674   if (parent) {
2675     if (parent->composite_fields)
2676       parent->composite_fields->erase(this->parent_field_descriptor);
2677     Py_CLEAR(parent);
2678   }
2679 }
2680 
BuildSubMessageFromPointer(const FieldDescriptor * field_descriptor,Message * sub_message,CMessageClass * message_class)2681 CMessage* CMessage::BuildSubMessageFromPointer(
2682     const FieldDescriptor* field_descriptor, Message* sub_message,
2683     CMessageClass* message_class) {
2684   if (!this->child_submessages) {
2685     this->child_submessages = new CMessage::SubMessagesMap();
2686   }
2687   auto it = this->child_submessages->find(sub_message);
2688   if (it != this->child_submessages->end()) {
2689     Py_INCREF(it->second);
2690     return it->second;
2691   }
2692 
2693   CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2694 
2695   if (cmsg == nullptr) {
2696     return nullptr;
2697   }
2698   cmsg->message = sub_message;
2699   Py_INCREF(this);
2700   cmsg->parent = this;
2701   cmsg->parent_field_descriptor = field_descriptor;
2702   cmessage::SetSubmessage(this, cmsg);
2703   return cmsg;
2704 }
2705 
MaybeReleaseSubMessage(Message * sub_message)2706 CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
2707   if (!this->child_submessages) {
2708     return nullptr;
2709   }
2710   auto it = this->child_submessages->find(sub_message);
2711   if (it == this->child_submessages->end()) return nullptr;
2712   CMessage* released = it->second;
2713 
2714   // The target message will now own its content.
2715   Py_CLEAR(released->parent);
2716   released->parent_field_descriptor = nullptr;
2717   released->read_only = false;
2718   // Delete it from the cache.
2719   this->child_submessages->erase(sub_message);
2720   return released;
2721 }
2722 
2723 static CMessageClass _CMessage_Type = {{{
2724     PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0) FULL_MODULE_NAME
2725     ".CMessage",                    // tp_name
2726     sizeof(CMessage),               // tp_basicsize
2727     0,                              //  tp_itemsize
2728     (destructor)cmessage::Dealloc,  //  tp_dealloc
2729 #if PY_VERSION_HEX < 0x03080000
2730     nullptr,  // tp_print
2731 #else
2732     0,  // tp_vectorcall_offset
2733 #endif
2734     nullptr,                      //  tp_getattr
2735     nullptr,                      //  tp_setattr
2736     nullptr,                      //  tp_compare
2737     (reprfunc)cmessage::ToStr,    //  tp_repr
2738     nullptr,                      //  tp_as_number
2739     nullptr,                      //  tp_as_sequence
2740     nullptr,                      //  tp_as_mapping
2741     PyObject_HashNotImplemented,  //  tp_hash
2742     nullptr,                      //  tp_call
2743     (reprfunc)cmessage::ToStr,    //  tp_str
2744     cmessage::GetAttr,            //  tp_getattro
2745     nullptr,                      //  tp_setattro
2746     nullptr,                      //  tp_as_buffer
2747     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
2748         Py_TPFLAGS_HAVE_VERSION_TAG,     //  tp_flags
2749     "A ProtocolMessage",                 //  tp_doc
2750     nullptr,                             //  tp_traverse
2751     nullptr,                             //  tp_clear
2752     (richcmpfunc)cmessage::RichCompare,  //  tp_richcompare
2753     offsetof(CMessage, weakreflist),     //  tp_weaklistoffset
2754     nullptr,                             //  tp_iter
2755     nullptr,                             //  tp_iternext
2756     cmessage::Methods,                   //  tp_methods
2757     nullptr,                             //  tp_members
2758     cmessage::Getters,                   //  tp_getset
2759     nullptr,                             //  tp_base
2760     nullptr,                             //  tp_dict
2761     nullptr,                             //  tp_descr_get
2762     nullptr,                             //  tp_descr_set
2763     0,                                   //  tp_dictoffset
2764     (initproc)cmessage::Init,            //  tp_init
2765     nullptr,                             //  tp_alloc
2766     cmessage::New,                       //  tp_new
2767 }}};
2768 PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
2769 
2770 // --- Exposing the C proto living inside Python proto to C code:
2771 
2772 const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
2773 Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
2774 
GetCProtoInsidePyProtoImpl(PyObject * msg)2775 static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
2776   const Message* message = PyMessage_GetMessagePointer(msg);
2777   if (message == nullptr) {
2778     PyErr_Clear();
2779     return nullptr;
2780   }
2781   return message;
2782 }
2783 
MutableCProtoInsidePyProtoImpl(PyObject * msg)2784 static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
2785   Message* message = PyMessage_GetMutableMessagePointer(msg);
2786   if (message == nullptr) {
2787     PyErr_Clear();
2788     return nullptr;
2789   }
2790   return message;
2791 }
2792 
PyMessage_GetMessagePointer(PyObject * msg)2793 const Message* PyMessage_GetMessagePointer(PyObject* msg) {
2794   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2795     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2796     return nullptr;
2797   }
2798   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2799   return cmsg->message;
2800 }
2801 
PyMessage_GetMutableMessagePointer(PyObject * msg)2802 Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
2803   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2804     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2805     return nullptr;
2806   }
2807   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2808 
2809 
2810   if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) ||
2811       (cmsg->child_submessages && !cmsg->child_submessages->empty())) {
2812     // There is currently no way of accurately syncing arbitrary changes to
2813     // the underlying C++ message back to the CMessage (e.g. removed repeated
2814     // composite containers). We only allow direct mutation of the underlying
2815     // C++ message if there is no child data in the CMessage.
2816     PyErr_SetString(PyExc_ValueError,
2817                     "Cannot reliably get a mutable pointer "
2818                     "to a message with extra references");
2819     return nullptr;
2820   }
2821   cmessage::AssureWritable(cmsg);
2822   return cmsg->message;
2823 }
2824 
2825 // Returns a new reference to the MessageClass to use for message creation.
2826 // - if a PyMessageFactory is passed, use it.
2827 // - Otherwise, if a PyDescriptorPool was created, use its factory.
GetMessageClassFromDescriptor(const Descriptor * descriptor,PyObject * py_message_factory)2828 static CMessageClass* GetMessageClassFromDescriptor(
2829     const Descriptor* descriptor, PyObject* py_message_factory) {
2830   PyMessageFactory* factory = nullptr;
2831   if (py_message_factory == nullptr) {
2832     PyDescriptorPool* pool =
2833         GetDescriptorPool_FromPool(descriptor->file()->pool());
2834     if (pool == nullptr) {
2835       PyErr_SetString(PyExc_TypeError,
2836                       "Unknown descriptor pool; C++ users should call "
2837                       "DescriptorPool_FromPool and keep it alive");
2838       return nullptr;
2839     }
2840     factory = pool->py_message_factory;
2841   } else if (PyObject_TypeCheck(py_message_factory, &PyMessageFactory_Type)) {
2842     factory = reinterpret_cast<PyMessageFactory*>(py_message_factory);
2843   } else {
2844     PyErr_SetString(PyExc_TypeError, "Expected a MessageFactory");
2845     return nullptr;
2846   }
2847 
2848   return message_factory::GetOrCreateMessageClass(factory, descriptor);
2849 }
2850 
PyMessage_New(const Descriptor * descriptor,PyObject * py_message_factory)2851 PyObject* PyMessage_New(const Descriptor* descriptor,
2852                         PyObject* py_message_factory) {
2853   CMessageClass* message_class =
2854       GetMessageClassFromDescriptor(descriptor, py_message_factory);
2855   if (message_class == nullptr) {
2856     return nullptr;
2857   }
2858 
2859   CMessage* self = cmessage::NewCMessage(message_class);
2860   Py_DECREF(message_class);
2861   if (self == nullptr) {
2862     return nullptr;
2863   }
2864   return self->AsPyObject();
2865 }
2866 
PyMessage_NewMessageOwnedExternally(Message * message,PyObject * py_message_factory)2867 PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
2868                                               PyObject* py_message_factory) {
2869   CMessageClass* message_class = GetMessageClassFromDescriptor(
2870       message->GetDescriptor(), py_message_factory);
2871   if (message_class == nullptr) {
2872     return nullptr;
2873   }
2874 
2875   CMessage* self = cmessage::NewEmptyMessage(message_class);
2876   Py_DECREF(message_class);
2877   if (self == nullptr) {
2878     return nullptr;
2879   }
2880   self->message = message;
2881   Py_INCREF(Py_None);
2882   self->parent = reinterpret_cast<CMessage*>(Py_None);
2883   return self->AsPyObject();
2884 }
2885 
InitGlobals()2886 void InitGlobals() {
2887   // TODO: Check all return values in this function for NULL and propagate
2888   // the error (MemoryError) on up to result in an import failure.  These should
2889   // also be freed and reset to NULL during finalization.
2890   kDESCRIPTOR = PyUnicode_FromString("DESCRIPTOR");
2891   kMessageFactory = PyUnicode_FromString("message_factory");
2892 
2893   PyObject* dummy_obj = PySet_New(nullptr);
2894   kEmptyWeakref = PyWeakref_NewRef(dummy_obj, nullptr);
2895   Py_DECREF(dummy_obj);
2896 }
2897 
InitProto2MessageModule(PyObject * m)2898 bool InitProto2MessageModule(PyObject* m) {
2899   // Initialize types and globals in descriptor.cc
2900   if (!InitDescriptor()) {
2901     return false;
2902   }
2903 
2904   // Initialize types and globals in descriptor_pool.cc
2905   if (!InitDescriptorPool()) {
2906     return false;
2907   }
2908 
2909   // Initialize types and globals in message_factory.cc
2910   if (!InitMessageFactory()) {
2911     return false;
2912   }
2913 
2914   // Initialize constants defined in this file.
2915   InitGlobals();
2916 
2917   CMessageClass_Type->tp_base = &PyType_Type;
2918   if (PyType_Ready(CMessageClass_Type) < 0) {
2919     return false;
2920   }
2921   PyModule_AddObject(m, "MessageMeta",
2922                      reinterpret_cast<PyObject*>(CMessageClass_Type));
2923 
2924   if (PyType_Ready(CMessage_Type) < 0) {
2925     return false;
2926   }
2927   if (PyType_Ready(CFieldProperty_Type) < 0) {
2928     return false;
2929   }
2930 
2931   // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
2932   // it here as well to document that subclasses need to set it.
2933   PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
2934   // Invalidate any cached data for the CMessage type.
2935   // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
2936   // after we have modified CMessage_Type.tp_dict.
2937   PyType_Modified(CMessage_Type);
2938 
2939   PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
2940 
2941   // Initialize Repeated container types.
2942   {
2943     if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
2944       return false;
2945     }
2946 
2947     PyModule_AddObject(
2948         m, "RepeatedScalarContainer",
2949         reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type));
2950 
2951     if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
2952       return false;
2953     }
2954 
2955     PyModule_AddObject(
2956         m, "RepeatedCompositeContainer",
2957         reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type));
2958 
2959     // Register them as MutableSequence.
2960     ScopedPyObjectPtr collections(PyImport_ImportModule("collections.abc"));
2961     if (collections == nullptr) {
2962       return false;
2963     }
2964     ScopedPyObjectPtr mutable_sequence(
2965         PyObject_GetAttrString(collections.get(), "MutableSequence"));
2966     if (mutable_sequence == nullptr) {
2967       return false;
2968     }
2969     if (ScopedPyObjectPtr(
2970             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2971                                 &RepeatedScalarContainer_Type)) == nullptr) {
2972       return false;
2973     }
2974     if (ScopedPyObjectPtr(
2975             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2976                                 &RepeatedCompositeContainer_Type)) == nullptr) {
2977       return false;
2978     }
2979   }
2980 
2981   if (PyType_Ready(&PyUnknownFieldSet_Type) < 0) {
2982     return false;
2983   }
2984 
2985   PyModule_AddObject(m, "UnknownFieldSet",
2986                      reinterpret_cast<PyObject*>(&PyUnknownFieldSet_Type));
2987 
2988   if (PyType_Ready(&PyUnknownField_Type) < 0) {
2989     return false;
2990   }
2991 
2992   // Initialize Map container types.
2993   if (!InitMapContainers()) {
2994     return false;
2995   }
2996   PyModule_AddObject(m, "ScalarMapContainer",
2997                      reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
2998   PyModule_AddObject(m, "MessageMapContainer",
2999                      reinterpret_cast<PyObject*>(MessageMapContainer_Type));
3000   PyModule_AddObject(m, "MapIterator",
3001                      reinterpret_cast<PyObject*>(&MapIterator_Type));
3002 
3003   if (PyType_Ready(&ExtensionDict_Type) < 0) {
3004     return false;
3005   }
3006   PyModule_AddObject(m, "ExtensionDict",
3007                      reinterpret_cast<PyObject*>(&ExtensionDict_Type));
3008   if (PyType_Ready(&ExtensionIterator_Type) < 0) {
3009     return false;
3010   }
3011   PyModule_AddObject(m, "ExtensionIterator",
3012                      reinterpret_cast<PyObject*>(&ExtensionIterator_Type));
3013 
3014   // Expose the DescriptorPool used to hold all descriptors added from generated
3015   // pb2.py files.
3016   // PyModule_AddObject steals a reference.
3017   Py_INCREF(GetDefaultDescriptorPool());
3018   PyModule_AddObject(m, "default_pool",
3019                      reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
3020 
3021   PyModule_AddObject(m, "DescriptorPool",
3022                      reinterpret_cast<PyObject*>(&PyDescriptorPool_Type));
3023   PyModule_AddObject(m, "Descriptor",
3024                      reinterpret_cast<PyObject*>(&PyMessageDescriptor_Type));
3025   PyModule_AddObject(m, "FieldDescriptor",
3026                      reinterpret_cast<PyObject*>(&PyFieldDescriptor_Type));
3027   PyModule_AddObject(m, "EnumDescriptor",
3028                      reinterpret_cast<PyObject*>(&PyEnumDescriptor_Type));
3029   PyModule_AddObject(m, "EnumValueDescriptor",
3030                      reinterpret_cast<PyObject*>(&PyEnumValueDescriptor_Type));
3031   PyModule_AddObject(m, "FileDescriptor",
3032                      reinterpret_cast<PyObject*>(&PyFileDescriptor_Type));
3033   PyModule_AddObject(m, "OneofDescriptor",
3034                      reinterpret_cast<PyObject*>(&PyOneofDescriptor_Type));
3035   PyModule_AddObject(m, "ServiceDescriptor",
3036                      reinterpret_cast<PyObject*>(&PyServiceDescriptor_Type));
3037   PyModule_AddObject(m, "MethodDescriptor",
3038                      reinterpret_cast<PyObject*>(&PyMethodDescriptor_Type));
3039 
3040   PyObject* enum_type_wrapper =
3041       PyImport_ImportModule(PROTOBUF_PYTHON_INTERNAL ".enum_type_wrapper");
3042   if (enum_type_wrapper == nullptr) {
3043     return false;
3044   }
3045   EnumTypeWrapper_class =
3046       PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
3047   Py_DECREF(enum_type_wrapper);
3048 
3049   PyObject* message_module =
3050       PyImport_ImportModule(PROTOBUF_PYTHON_PUBLIC ".message");
3051   if (message_module == nullptr) {
3052     return false;
3053   }
3054   EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
3055   DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
3056   PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
3057   Py_DECREF(message_module);
3058 
3059   PyObject* pickle_module = PyImport_ImportModule("pickle");
3060   if (pickle_module == nullptr) {
3061     return false;
3062   }
3063   PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
3064   Py_DECREF(pickle_module);
3065 
3066   // Override {Get,Mutable}CProtoInsidePyProto.
3067   GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
3068   MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
3069 
3070   return true;
3071 }
3072 
3073 }  // namespace python
3074 }  // namespace protobuf
3075 }  // namespace google
3076 
3077 #include "google/protobuf/port_undef.inc"
3078