• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33 
34 #include <google/protobuf/pyext/message.h>
35 
36 #include <structmember.h>  // A Python header file.
37 
38 #include <map>
39 #include <memory>
40 #include <string>
41 #include <vector>
42 
43 #include <google/protobuf/stubs/strutil.h>
44 
45 #ifndef PyVarObject_HEAD_INIT
46 #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
47 #endif
48 #ifndef Py_TYPE
49 #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
50 #endif
51 #include <google/protobuf/stubs/common.h>
52 #include <google/protobuf/stubs/logging.h>
53 #include <google/protobuf/io/coded_stream.h>
54 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
55 #include <google/protobuf/descriptor.pb.h>
56 #include <google/protobuf/descriptor.h>
57 #include <google/protobuf/message.h>
58 #include <google/protobuf/text_format.h>
59 #include <google/protobuf/unknown_field_set.h>
60 #include <google/protobuf/pyext/descriptor.h>
61 #include <google/protobuf/pyext/descriptor_pool.h>
62 #include <google/protobuf/pyext/extension_dict.h>
63 #include <google/protobuf/pyext/field.h>
64 #include <google/protobuf/pyext/map_container.h>
65 #include <google/protobuf/pyext/message_factory.h>
66 #include <google/protobuf/pyext/repeated_composite_container.h>
67 #include <google/protobuf/pyext/repeated_scalar_container.h>
68 #include <google/protobuf/pyext/safe_numerics.h>
69 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
70 #include <google/protobuf/pyext/unknown_fields.h>
71 #include <google/protobuf/util/message_differencer.h>
72 #include <google/protobuf/io/strtod.h>
73 #include <google/protobuf/stubs/map_util.h>
74 
75 // clang-format off
76 #include <google/protobuf/port_def.inc>
77 // clang-format on
78 
// Compatibility shims: map the Python 2 "PyInt_*" / "PyString_*" spellings
// used in this file onto their Python 3 counterparts.
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
#define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_FromFormat PyUnicode_FromFormat
// Accepts either a unicode object (returned as UTF-8) or a bytes object.
#define PyString_AsString(ob) \
  (PyUnicode_Check(ob) ? PyUnicode_AsUTF8(ob) : PyBytes_AsString(ob))
// Same as above, also reporting the length; evaluates to -1 on failure and
// to 0 on success for unicode input, or to the PyBytes_AsStringAndSize result
// otherwise.
#define PyString_AsStringAndSize(ob, charpp, sizep)            \
  (PyUnicode_Check(ob)                                         \
       ? ((*(charpp) = const_cast<char*>(                      \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL  \
              ? -1                                             \
              : 0)                                             \
       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif  // PY_MAJOR_VERSION >= 3
100 
101 namespace google {
102 namespace protobuf {
103 namespace python {
104 
105 static PyObject* kDESCRIPTOR;
106 PyObject* EnumTypeWrapper_class;
107 static PyObject* PythonMessage_class;
108 static PyObject* kEmptyWeakref;
109 static PyObject* WKT_classes = NULL;
110 
111 namespace message_meta {
112 
113 static int InsertEmptyWeakref(PyTypeObject* base);
114 
115 namespace {
// Copied over from internal 'google/protobuf/stubs/strutil.h'.
//
// Lower-cases the ASCII letters 'A'..'Z' of *s in place; every other byte is
// left untouched. tolower() is deliberately avoided because its result
// depends on the current locale.
inline void LowerString(std::string* s) {
  for (char& ch : *s) {
    if (ch >= 'A' && ch <= 'Z') {
      ch = static_cast<char>(ch + ('a' - 'A'));
    }
  }
}
124 }
125 
126 // Finalize the creation of the Message class.
AddDescriptors(PyObject * cls,const Descriptor * descriptor)127 static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
128   // For each field set: cls.<field>_FIELD_NUMBER = <number>
129   for (int i = 0; i < descriptor->field_count(); ++i) {
130     const FieldDescriptor* field_descriptor = descriptor->field(i);
131     ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
132     if (property == NULL) {
133       return -1;
134     }
135     if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(),
136                                property.get()) < 0) {
137       return -1;
138     }
139   }
140 
141   // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
142   for (int i = 0; i < descriptor->enum_type_count(); ++i) {
143     const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
144     ScopedPyObjectPtr enum_type(
145         PyEnumDescriptor_FromDescriptor(enum_descriptor));
146     if (enum_type == NULL) {
147       return -1;
148      }
149     // Add wrapped enum type to message class.
150     ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
151         EnumTypeWrapper_class, enum_type.get(), NULL));
152     if (wrapped == NULL) {
153       return -1;
154     }
155     if (PyObject_SetAttrString(
156             cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) {
157       return -1;
158     }
159 
160     // For each enum value add cls.<name> = <number>
161     for (int j = 0; j < enum_descriptor->value_count(); ++j) {
162       const EnumValueDescriptor* enum_value_descriptor =
163           enum_descriptor->value(j);
164       ScopedPyObjectPtr value_number(PyInt_FromLong(
165           enum_value_descriptor->number()));
166       if (value_number == NULL) {
167         return -1;
168       }
169       if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(),
170                                  value_number.get()) == -1) {
171         return -1;
172       }
173     }
174   }
175 
176   // For each extension set cls.<extension name> = <extension descriptor>.
177   //
178   // Extension descriptors come from
179   // <message descriptor>.extensions_by_name[name]
180   // which was defined previously.
181   for (int i = 0; i < descriptor->extension_count(); ++i) {
182     const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
183     ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
184     if (extension_field == NULL) {
185       return -1;
186     }
187 
188     // Add the extension field to the message class.
189     if (PyObject_SetAttrString(
190             cls, field->name().c_str(), extension_field.get()) == -1) {
191       return -1;
192     }
193   }
194 
195   return 0;
196 }
197 
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)198 static PyObject* New(PyTypeObject* type,
199                      PyObject* args, PyObject* kwargs) {
200   static char *kwlist[] = {"name", "bases", "dict", 0};
201   PyObject *bases, *dict;
202   const char* name;
203 
204   // Check arguments: (name, bases, dict)
205   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist,
206                                    &name,
207                                    &PyTuple_Type, &bases,
208                                    &PyDict_Type, &dict)) {
209     return NULL;
210   }
211 
212   // Check bases: only (), or (message.Message,) are allowed
213   if (!(PyTuple_GET_SIZE(bases) == 0 ||
214         (PyTuple_GET_SIZE(bases) == 1 &&
215          PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
216     PyErr_SetString(PyExc_TypeError,
217                     "A Message class can only inherit from Message");
218     return NULL;
219   }
220 
221   // Check dict['DESCRIPTOR']
222   PyObject* descriptor_or_name = PyDict_GetItem(dict, kDESCRIPTOR);
223   if (descriptor_or_name == nullptr) {
224     PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
225     return NULL;
226   }
227 
228   Py_ssize_t name_size;
229   char* full_name;
230   const Descriptor* message_descriptor;
231   PyObject* py_descriptor;
232 
233   if (PyObject_TypeCheck(descriptor_or_name, &PyMessageDescriptor_Type)) {
234     py_descriptor = descriptor_or_name;
235     message_descriptor = PyMessageDescriptor_AsDescriptor(py_descriptor);
236     if (message_descriptor == nullptr) {
237       return nullptr;
238     }
239   } else {
240     if (PyString_AsStringAndSize(descriptor_or_name, &full_name, &name_size) <
241         0) {
242       return nullptr;
243     }
244     message_descriptor =
245         GetDefaultDescriptorPool()->pool->FindMessageTypeByName(
246             StringParam(full_name, name_size));
247     if (message_descriptor == nullptr) {
248       PyErr_Format(PyExc_KeyError,
249                    "Can not find message descriptor %s "
250                    "from pool",
251                    full_name);
252       return nullptr;
253     }
254     py_descriptor = PyMessageDescriptor_FromDescriptor(message_descriptor);
255     // reset the dict['DESCRIPTOR'] to py_descriptor.
256     PyDict_SetItem(dict, kDESCRIPTOR, py_descriptor);
257   }
258 
259   // Messages have no __dict__
260   ScopedPyObjectPtr slots(PyTuple_New(0));
261   if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
262     return NULL;
263   }
264 
265   // Build the arguments to the base metaclass.
266   // We change the __bases__ classes.
267   ScopedPyObjectPtr new_args;
268 
269   if (WKT_classes == NULL) {
270     ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
271         "google.protobuf.internal.well_known_types"));
272     GOOGLE_DCHECK(well_known_types != NULL);
273 
274     WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
275     GOOGLE_DCHECK(WKT_classes != NULL);
276   }
277 
278   PyObject* well_known_class = PyDict_GetItemString(
279       WKT_classes, message_descriptor->full_name().c_str());
280   if (well_known_class == NULL) {
281     new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
282                                  PythonMessage_class, dict));
283   } else {
284     new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
285                                  PythonMessage_class, well_known_class, dict));
286   }
287 
288   if (new_args == NULL) {
289     return NULL;
290   }
291   // Call the base metaclass.
292   ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL));
293   if (result == NULL) {
294     return NULL;
295   }
296   CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
297 
298   // Insert the empty weakref into the base classes.
299   if (InsertEmptyWeakref(
300           reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
301       InsertEmptyWeakref(CMessage_Type) < 0) {
302     return NULL;
303   }
304 
305   // Cache the descriptor, both as Python object and as C++ pointer.
306   const Descriptor* descriptor =
307       PyMessageDescriptor_AsDescriptor(py_descriptor);
308   if (descriptor == NULL) {
309     return NULL;
310   }
311   Py_INCREF(py_descriptor);
312   newtype->py_message_descriptor = py_descriptor;
313   newtype->message_descriptor = descriptor;
314   // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
315   // use the MessageFactory optionally passed in the class dict.
316   PyDescriptorPool* py_descriptor_pool =
317       GetDescriptorPool_FromPool(descriptor->file()->pool());
318   if (py_descriptor_pool == NULL) {
319     return NULL;
320   }
321   newtype->py_message_factory = py_descriptor_pool->py_message_factory;
322   Py_INCREF(newtype->py_message_factory);
323 
324   // Register the message in the MessageFactory.
325   // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
326   // MessageFactory is fully implemented in C++.
327   if (message_factory::RegisterMessageClass(newtype->py_message_factory,
328                                             descriptor, newtype) < 0) {
329     return NULL;
330   }
331 
332   // Continue with type initialization: add other descriptors, enum values...
333   if (AddDescriptors(result.get(), descriptor) < 0) {
334     return NULL;
335   }
336   return result.release();
337 }
338 
Dealloc(PyObject * pself)339 static void Dealloc(PyObject* pself) {
340   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
341   Py_XDECREF(self->py_message_descriptor);
342   Py_XDECREF(self->py_message_factory);
343   return PyType_Type.tp_dealloc(pself);
344 }
345 
GcTraverse(PyObject * pself,visitproc visit,void * arg)346 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
347   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
348   Py_VISIT(self->py_message_descriptor);
349   Py_VISIT(self->py_message_factory);
350   return PyType_Type.tp_traverse(pself, visit, arg);
351 }
352 
GcClear(PyObject * pself)353 static int GcClear(PyObject* pself) {
354   // It's important to keep the descriptor and factory alive, until the
355   // C++ message is fully destructed.
356   return PyType_Type.tp_clear(pself);
357 }
358 
359 // This function inserts and empty weakref at the end of the list of
360 // subclasses for the main protocol buffer Message class.
361 //
362 // This eliminates a O(n^2) behaviour in the internal add_subclass
363 // routine.
InsertEmptyWeakref(PyTypeObject * base_type)364 static int InsertEmptyWeakref(PyTypeObject *base_type) {
365 #if PY_MAJOR_VERSION >= 3
366   // Python 3.4 has already included the fix for the issue that this
367   // hack addresses. For further background and the fix please see
368   // https://bugs.python.org/issue17936.
369   return 0;
370 #else
371 #ifdef Py_DEBUG
372   // The code below causes all new subclasses to append an entry, which is never
373   // cleared. This is a small memory leak, which we disable in Py_DEBUG mode
374   // to have stable refcounting checks.
375 #else
376   PyObject *subclasses = base_type->tp_subclasses;
377   if (subclasses && PyList_CheckExact(subclasses)) {
378     return PyList_Append(subclasses, kEmptyWeakref);
379   }
380 #endif  // !Py_DEBUG
381   return 0;
382 #endif  // PY_MAJOR_VERSION >= 3
383 }
384 
385 // The _extensions_by_name dictionary is built on every access.
386 // TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
GetExtensionsByName(CMessageClass * self,void * closure)387 static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
388   if (self->message_descriptor == NULL) {
389     // This is the base Message object, simply raise AttributeError.
390     PyErr_SetString(PyExc_AttributeError,
391                     "Base Message class has no DESCRIPTOR");
392     return NULL;
393   }
394 
395   const PyDescriptorPool* pool = self->py_message_factory->pool;
396 
397   std::vector<const FieldDescriptor*> extensions;
398   pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
399 
400   ScopedPyObjectPtr result(PyDict_New());
401   for (int i = 0; i < extensions.size(); i++) {
402     ScopedPyObjectPtr extension(
403         PyFieldDescriptor_FromDescriptor(extensions[i]));
404     if (extension == NULL) {
405       return NULL;
406     }
407     if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
408                              extension.get()) < 0) {
409       return NULL;
410     }
411   }
412   return result.release();
413 }
414 
415 // The _extensions_by_number dictionary is built on every access.
416 // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
GetExtensionsByNumber(CMessageClass * self,void * closure)417 static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
418   if (self->message_descriptor == NULL) {
419     // This is the base Message object, simply raise AttributeError.
420     PyErr_SetString(PyExc_AttributeError,
421                     "Base Message class has no DESCRIPTOR");
422     return NULL;
423   }
424 
425   const PyDescriptorPool* pool = self->py_message_factory->pool;
426 
427   std::vector<const FieldDescriptor*> extensions;
428   pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
429 
430   ScopedPyObjectPtr result(PyDict_New());
431   for (int i = 0; i < extensions.size(); i++) {
432     ScopedPyObjectPtr extension(
433         PyFieldDescriptor_FromDescriptor(extensions[i]));
434     if (extension == NULL) {
435       return NULL;
436     }
437     ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
438     if (number == NULL) {
439       return NULL;
440     }
441     if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
442       return NULL;
443     }
444   }
445   return result.release();
446 }
447 
448 static PyGetSetDef Getters[] = {
449   {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
450   {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
451   {NULL}
452 };
453 
454 // Compute some class attributes on the fly:
455 // - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
456 // Returns a new reference, or NULL with an exception set.
GetClassAttribute(CMessageClass * self,PyObject * name)457 static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
458   char* attr;
459   Py_ssize_t attr_size;
460   static const char kSuffix[] = "_FIELD_NUMBER";
461   if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
462       HasSuffixString(StringPiece(attr, attr_size), kSuffix)) {
463     std::string field_name(attr, attr_size - sizeof(kSuffix) + 1);
464     LowerString(&field_name);
465 
466     // Try to find a field with the given name, without the suffix.
467     const FieldDescriptor* field =
468         self->message_descriptor->FindFieldByLowercaseName(field_name);
469     if (!field) {
470       // Search nested extensions as well.
471       field =
472           self->message_descriptor->FindExtensionByLowercaseName(field_name);
473     }
474     if (field) {
475       return PyInt_FromLong(field->number());
476     }
477   }
478   PyErr_SetObject(PyExc_AttributeError, name);
479   return NULL;
480 }
481 
GetAttr(CMessageClass * self,PyObject * name)482 static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
483   PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
484       reinterpret_cast<PyObject*>(self), name);
485   if (result != NULL) {
486     return result;
487   }
488   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
489     return NULL;
490   }
491 
492   PyErr_Clear();
493   return GetClassAttribute(self, name);
494 }
495 
496 }  // namespace message_meta
497 
498 static PyTypeObject _CMessageClass_Type = {
499     PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
500     ".MessageMeta",                       // tp_name
501     sizeof(CMessageClass),                // tp_basicsize
502     0,                                    // tp_itemsize
503     message_meta::Dealloc,                // tp_dealloc
504     0,                                    // tp_print
505     0,                                    // tp_getattr
506     0,                                    // tp_setattr
507     0,                                    // tp_compare
508     0,                                    // tp_repr
509     0,                                    // tp_as_number
510     0,                                    // tp_as_sequence
511     0,                                    // tp_as_mapping
512     0,                                    // tp_hash
513     0,                                    // tp_call
514     0,                                    // tp_str
515     (getattrofunc)message_meta::GetAttr,  // tp_getattro
516     0,                                    // tp_setattro
517     0,                                    // tp_as_buffer
518     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
519     "The metaclass of ProtocolMessages",                            // tp_doc
520     message_meta::GcTraverse,  // tp_traverse
521     message_meta::GcClear,     // tp_clear
522     0,                         // tp_richcompare
523     0,                         // tp_weaklistoffset
524     0,                         // tp_iter
525     0,                         // tp_iternext
526     0,                         // tp_methods
527     0,                         // tp_members
528     message_meta::Getters,     // tp_getset
529     0,                         // tp_base
530     0,                         // tp_dict
531     0,                         // tp_descr_get
532     0,                         // tp_descr_set
533     0,                         // tp_dictoffset
534     0,                         // tp_init
535     0,                         // tp_alloc
536     message_meta::New,         // tp_new
537 };
538 PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
539 
CheckMessageClass(PyTypeObject * cls)540 static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
541   if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
542     PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
543     return NULL;
544   }
545   return reinterpret_cast<CMessageClass*>(cls);
546 }
547 
GetMessageDescriptor(PyTypeObject * cls)548 static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
549   CMessageClass* type = CheckMessageClass(cls);
550   if (type == NULL) {
551     return NULL;
552   }
553   return type->message_descriptor;
554 }
555 
556 // Forward declarations
557 namespace cmessage {
558 int InternalReleaseFieldByDescriptor(
559     CMessage* self,
560     const FieldDescriptor* field_descriptor);
561 }  // namespace cmessage
562 
563 // ---------------------------------------------------------------------
564 
565 PyObject* EncodeError_class;
566 PyObject* DecodeError_class;
567 PyObject* PickleError_class;
568 
569 // Format an error message for unexpected types.
570 // Always return with an exception set.
FormatTypeError(PyObject * arg,char * expected_types)571 void FormatTypeError(PyObject* arg, char* expected_types) {
572   // This function is often called with an exception set.
573   // Clear it to call PyObject_Repr() in good conditions.
574   PyErr_Clear();
575   PyObject* repr = PyObject_Repr(arg);
576   if (repr) {
577     PyErr_Format(PyExc_TypeError,
578                  "%.100s has type %.100s, but expected one of: %s",
579                  PyString_AsString(repr),
580                  Py_TYPE(arg)->tp_name,
581                  expected_types);
582     Py_DECREF(repr);
583   }
584 }
585 
OutOfRangeError(PyObject * arg)586 void OutOfRangeError(PyObject* arg) {
587   PyObject *s = PyObject_Str(arg);
588   if (s) {
589     PyErr_Format(PyExc_ValueError,
590                  "Value out of range: %s",
591                  PyString_AsString(s));
592     Py_DECREF(s);
593   }
594 }
595 
596 template<class RangeType, class ValueType>
VerifyIntegerCastAndRange(PyObject * arg,ValueType value)597 bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
598   if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
599     if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
600       // Replace it with the same ValueError as pure python protos instead of
601       // the default one.
602       PyErr_Clear();
603       OutOfRangeError(arg);
604     }  // Otherwise propagate existing error.
605     return false;
606     }
607     if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
608       OutOfRangeError(arg);
609       return false;
610     }
611   return true;
612 }
613 
614 template<class T>
CheckAndGetInteger(PyObject * arg,T * value)615 bool CheckAndGetInteger(PyObject* arg, T* value) {
616   // The fast path.
617 #if PY_MAJOR_VERSION < 3
618   // For the typical case, offer a fast path.
619   if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) {
620     long int_result = PyInt_AsLong(arg);
621     if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
622       *value = static_cast<T>(int_result);
623       return true;
624     } else {
625       OutOfRangeError(arg);
626       return false;
627     }
628   }
629 #endif
630   // This effectively defines an integer as "an object that can be cast as
631   // an integer and can be used as an ordinal number".
632   // This definition includes everything that implements numbers.Integral
633   // and shouldn't cast the net too wide.
634     if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) {
635       FormatTypeError(arg, "int, long");
636       return false;
637     }
638 
639   // Now we have an integral number so we can safely use PyLong_ functions.
640   // We need to treat the signed and unsigned cases differently in case arg is
641   // holding a value above the maximum for signed longs.
642   if (std::numeric_limits<T>::min() == 0) {
643     // Unsigned case.
644     unsigned PY_LONG_LONG ulong_result;
645     if (PyLong_Check(arg)) {
646       ulong_result = PyLong_AsUnsignedLongLong(arg);
647     } else {
648       // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
649       // picky about the exact type.
650       PyObject* casted = PyNumber_Long(arg);
651       if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
652         // Propagate existing error.
653         return false;
654         }
655       ulong_result = PyLong_AsUnsignedLongLong(casted);
656       Py_DECREF(casted);
657     }
658     if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
659                                                             ulong_result)) {
660       *value = static_cast<T>(ulong_result);
661     } else {
662       return false;
663     }
664   } else {
665     // Signed case.
666     PY_LONG_LONG long_result;
667     PyNumberMethods *nb;
668     if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
669       // PyLong_AsLongLong requires it to be a long or to have an __int__()
670       // method.
671       long_result = PyLong_AsLongLong(arg);
672     } else {
673       // Valid subclasses of numbers.Integral should have a __long__() method
674       // so fall back to that.
675       PyObject* casted = PyNumber_Long(arg);
676       if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
677         // Propagate existing error.
678         return false;
679         }
680       long_result = PyLong_AsLongLong(casted);
681       Py_DECREF(casted);
682     }
683     if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
684       *value = static_cast<T>(long_result);
685     } else {
686       return false;
687     }
688   }
689 
690   return true;
691 }
692 
693 // These are referenced by repeated_scalar_container, and must
694 // be explicitly instantiated.
695 template bool CheckAndGetInteger<int32>(PyObject*, int32*);
696 template bool CheckAndGetInteger<int64>(PyObject*, int64*);
697 template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
698 template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
699 
CheckAndGetDouble(PyObject * arg,double * value)700 bool CheckAndGetDouble(PyObject* arg, double* value) {
701   *value = PyFloat_AsDouble(arg);
702   if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
703     FormatTypeError(arg, "int, long, float");
704     return false;
705     }
706   return true;
707 }
708 
CheckAndGetFloat(PyObject * arg,float * value)709 bool CheckAndGetFloat(PyObject* arg, float* value) {
710   double double_value;
711   if (!CheckAndGetDouble(arg, &double_value)) {
712     return false;
713   }
714   *value = io::SafeDoubleToFloat(double_value);
715   return true;
716 }
717 
CheckAndGetBool(PyObject * arg,bool * value)718 bool CheckAndGetBool(PyObject* arg, bool* value) {
719   long long_value = PyInt_AsLong(arg);
720   if (long_value == -1 && PyErr_Occurred()) {
721     FormatTypeError(arg, "int, long, bool");
722     return false;
723   }
724   *value = static_cast<bool>(long_value);
725 
726   return true;
727 }
728 
729 // Checks whether the given object (which must be "bytes" or "unicode") contains
730 // valid UTF-8.
IsValidUTF8(PyObject * obj)731 bool IsValidUTF8(PyObject* obj) {
732   if (PyBytes_Check(obj)) {
733     PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
734 
735     // Clear the error indicator; we report our own error when desired.
736     PyErr_Clear();
737 
738     if (unicode) {
739       Py_DECREF(unicode);
740       return true;
741     } else {
742       return false;
743     }
744   } else {
745     // Unicode object, known to be valid UTF-8.
746     return true;
747   }
748 }
749 
// Tells whether invalid UTF-8 is tolerated for the given string field.
// This implementation never allows it, whatever the field.
bool AllowInvalidUTF8(const FieldDescriptor* field) {
  return false;
}
751 
CheckString(PyObject * arg,const FieldDescriptor * descriptor)752 PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
753   GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
754          descriptor->type() == FieldDescriptor::TYPE_BYTES);
755   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
756     if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
757       FormatTypeError(arg, "bytes, unicode");
758       return NULL;
759     }
760 
761     if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
762       PyObject* repr = PyObject_Repr(arg);
763       PyErr_Format(PyExc_ValueError,
764                    "%s has type str, but isn't valid UTF-8 "
765                    "encoding. Non-UTF-8 strings must be converted to "
766                    "unicode objects before being added.",
767                    PyString_AsString(repr));
768       Py_DECREF(repr);
769       return NULL;
770     }
771   } else if (!PyBytes_Check(arg)) {
772     FormatTypeError(arg, "bytes");
773     return NULL;
774   }
775 
776   PyObject* encoded_string = NULL;
777   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
778     if (PyBytes_Check(arg)) {
779       // The bytes were already validated as correctly encoded UTF-8 above.
780       encoded_string = arg;  // Already encoded.
781       Py_INCREF(encoded_string);
782     } else {
783       encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
784     }
785   } else {
786     // In this case field type is "bytes".
787     encoded_string = arg;
788     Py_INCREF(encoded_string);
789   }
790 
791   return encoded_string;
792 }
793 
CheckAndSetString(PyObject * arg,Message * message,const FieldDescriptor * descriptor,const Reflection * reflection,bool append,int index)794 bool CheckAndSetString(
795     PyObject* arg, Message* message,
796     const FieldDescriptor* descriptor,
797     const Reflection* reflection,
798     bool append,
799     int index) {
800   ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
801 
802   if (encoded_string.get() == NULL) {
803     return false;
804   }
805 
806   char* value;
807   Py_ssize_t value_len;
808   if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
809     return false;
810   }
811 
812   string value_string(value, value_len);
813   if (append) {
814     reflection->AddString(message, descriptor, std::move(value_string));
815   } else if (index < 0) {
816     reflection->SetString(message, descriptor, std::move(value_string));
817   } else {
818     reflection->SetRepeatedString(message, descriptor, index,
819                                   std::move(value_string));
820   }
821   return true;
822 }
823 
ToStringObject(const FieldDescriptor * descriptor,const std::string & value)824 PyObject* ToStringObject(const FieldDescriptor* descriptor,
825                          const std::string& value) {
826   if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
827     return PyBytes_FromStringAndSize(value.c_str(), value.length());
828   }
829 
830   PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL);
831   // If the string can't be decoded in UTF-8, just return a string object that
832   // contains the raw bytes. This can't happen if the value was assigned using
833   // the members of the Python message object, but can happen if the values were
834   // parsed from the wire (binary).
835   if (result == NULL) {
836     PyErr_Clear();
837     result = PyBytes_FromStringAndSize(value.c_str(), value.length());
838   }
839   return result;
840 }
841 
CheckFieldBelongsToMessage(const FieldDescriptor * field_descriptor,const Message * message)842 bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
843                                 const Message* message) {
844   if (message->GetDescriptor() == field_descriptor->containing_type()) {
845     return true;
846   }
847   PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
848                field_descriptor->full_name().c_str(),
849                message->GetDescriptor()->full_name().c_str());
850   return false;
851 }
852 
853 namespace cmessage {
854 
GetFactoryForMessage(CMessage * message)855 PyMessageFactory* GetFactoryForMessage(CMessage* message) {
856   GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
857   return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
858 }
859 
MaybeReleaseOverlappingOneofField(CMessage * cmessage,const FieldDescriptor * field)860 static int MaybeReleaseOverlappingOneofField(
861     CMessage* cmessage,
862     const FieldDescriptor* field) {
863 #ifdef GOOGLE_PROTOBUF_HAS_ONEOF
864   Message* message = cmessage->message;
865   const Reflection* reflection = message->GetReflection();
866   if (!field->containing_oneof() ||
867       !reflection->HasOneof(*message, field->containing_oneof()) ||
868       reflection->HasField(*message, field)) {
869     // No other field in this oneof, no need to release.
870     return 0;
871   }
872 
873   const OneofDescriptor* oneof = field->containing_oneof();
874   const FieldDescriptor* existing_field =
875       reflection->GetOneofFieldDescriptor(*message, oneof);
876   if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
877     // Non-message fields don't need to be released.
878     return 0;
879   }
880   if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) {
881     return -1;
882   }
883 #endif
884   return 0;
885 }
886 
887 // After a Merge, visit every sub-message that was read-only, and
888 // eventually update their pointer if the Merge operation modified them.
FixupMessageAfterMerge(CMessage * self)889 int FixupMessageAfterMerge(CMessage* self) {
890   if (!self->composite_fields) {
891     return 0;
892   }
893   for (const auto& item : *self->composite_fields) {
894     const FieldDescriptor* descriptor = item.first;
895     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
896         !descriptor->is_repeated()) {
897       CMessage* cmsg = reinterpret_cast<CMessage*>(item.second);
898       if (cmsg->read_only == false) {
899         return 0;
900       }
901       Message* message = self->message;
902       const Reflection* reflection = message->GetReflection();
903       if (reflection->HasField(*message, descriptor)) {
904         // Message used to be read_only, but is no longer. Get the new pointer
905         // and record it.
906         Message* mutable_message =
907             reflection->MutableMessage(message, descriptor, nullptr);
908         cmsg->message = mutable_message;
909         cmsg->read_only = false;
910         if (FixupMessageAfterMerge(cmsg) < 0) {
911           return -1;
912         }
913       }
914     }
915   }
916 
917   return 0;
918 }
919 
920 // ---------------------------------------------------------------------
921 // Making a message writable
922 
AssureWritable(CMessage * self)923 int AssureWritable(CMessage* self) {
924   if (self == NULL || !self->read_only) {
925     return 0;
926   }
927 
928   // Toplevel messages are always mutable.
929   GOOGLE_DCHECK(self->parent);
930 
931   if (AssureWritable(self->parent) == -1)
932     return -1;
933 
934   // If this message is part of a oneof, there might be a field to release in
935   // the parent.
936   if (MaybeReleaseOverlappingOneofField(self->parent,
937                                         self->parent_field_descriptor) < 0) {
938     return -1;
939   }
940 
941   // Make self->message writable.
942   Message* parent_message = self->parent->message;
943   const Reflection* reflection = parent_message->GetReflection();
944   Message* mutable_message = reflection->MutableMessage(
945       parent_message, self->parent_field_descriptor,
946       GetFactoryForMessage(self->parent)->message_factory);
947   if (mutable_message == NULL) {
948     return -1;
949   }
950   self->message = mutable_message;
951   self->read_only = false;
952 
953   return 0;
954 }
955 
956 // --- Globals:
957 
958 // Retrieve a C++ FieldDescriptor for an extension handle.
GetExtensionDescriptor(PyObject * extension)959 const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
960   ScopedPyObjectPtr cdescriptor;
961   if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
962     // Most callers consider extensions as a plain dictionary.  We should
963     // allow input which is not a field descriptor, and simply pretend it does
964     // not exist.
965     PyErr_SetObject(PyExc_KeyError, extension);
966     return NULL;
967   }
968   return PyFieldDescriptor_AsDescriptor(extension);
969 }
970 
971 // If value is a string, convert it into an enum value based on the labels in
972 // descriptor, otherwise simply return value.  Always returns a new reference.
GetIntegerEnumValue(const FieldDescriptor & descriptor,PyObject * value)973 static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
974                                      PyObject* value) {
975   if (PyString_Check(value) || PyUnicode_Check(value)) {
976     const EnumDescriptor* enum_descriptor = descriptor.enum_type();
977     if (enum_descriptor == NULL) {
978       PyErr_SetString(PyExc_TypeError, "not an enum field");
979       return NULL;
980     }
981     char* enum_label;
982     Py_ssize_t size;
983     if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
984       return NULL;
985     }
986     const EnumValueDescriptor* enum_value_descriptor =
987         enum_descriptor->FindValueByName(StringParam(enum_label, size));
988     if (enum_value_descriptor == NULL) {
989       PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
990       return NULL;
991     }
992     return PyInt_FromLong(enum_value_descriptor->number());
993   }
994   Py_INCREF(value);
995   return value;
996 }
997 
998 // Delete a slice from a repeated field.
999 // The only way to remove items in C++ protos is to delete the last one,
1000 // so we swap items to move the deleted ones at the end, and then strip the
1001 // sequence.
DeleteRepeatedField(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * slice)1002 int DeleteRepeatedField(
1003     CMessage* self,
1004     const FieldDescriptor* field_descriptor,
1005     PyObject* slice) {
1006   Py_ssize_t length, from, to, step, slice_length;
1007   Message* message = self->message;
1008   const Reflection* reflection = message->GetReflection();
1009   int min, max;
1010   length = reflection->FieldSize(*message, field_descriptor);
1011 
1012   if (PySlice_Check(slice)) {
1013     from = to = step = slice_length = 0;
1014 #if PY_MAJOR_VERSION < 3
1015     PySlice_GetIndicesEx(
1016         reinterpret_cast<PySliceObject*>(slice),
1017         length, &from, &to, &step, &slice_length);
1018 #else
1019     PySlice_GetIndicesEx(
1020         slice,
1021         length, &from, &to, &step, &slice_length);
1022 #endif
1023     if (from < to) {
1024       min = from;
1025       max = to - 1;
1026     } else {
1027       min = to + 1;
1028       max = from;
1029     }
1030   } else {
1031     from = to = PyLong_AsLong(slice);
1032     if (from == -1 && PyErr_Occurred()) {
1033       PyErr_SetString(PyExc_TypeError, "list indices must be integers");
1034       return -1;
1035     }
1036 
1037     if (from < 0) {
1038       from = to = length + from;
1039     }
1040     step = 1;
1041     min = max = from;
1042 
1043     // Range check.
1044     if (from < 0 || from >= length) {
1045       PyErr_Format(PyExc_IndexError, "list assignment index out of range");
1046       return -1;
1047     }
1048   }
1049 
1050   Py_ssize_t i = from;
1051   std::vector<bool> to_delete(length, false);
1052   while (i >= min && i <= max) {
1053     to_delete[i] = true;
1054     i += step;
1055   }
1056 
1057   // Swap elements so that items to delete are at the end.
1058   to = 0;
1059   for (i = 0; i < length; ++i) {
1060     if (!to_delete[i]) {
1061       if (i != to) {
1062         reflection->SwapElements(message, field_descriptor, i, to);
1063       }
1064       ++to;
1065     }
1066   }
1067 
1068   // Remove items, starting from the end.
1069   for (; length > to; length--) {
1070     if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1071       reflection->RemoveLast(message, field_descriptor);
1072       continue;
1073     }
1074     // It seems that RemoveLast() is less efficient for sub-messages, and
1075     // the memory is not completely released. Prefer ReleaseLast().
1076     Message* sub_message = reflection->ReleaseLast(message, field_descriptor);
1077     // If there is a live weak reference to an item being removed, we "Release"
1078     // it, and it takes ownership of the message.
1079     if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) {
1080       released->message = sub_message;
1081     } else {
1082       // sub_message was not transferred, delete it.
1083       delete sub_message;
1084     }
1085   }
1086 
1087   return 0;
1088 }
1089 
1090 // Initializes fields of a message. Used in constructors.
InitAttributes(CMessage * self,PyObject * args,PyObject * kwargs)1091 int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
1092   if (args != NULL && PyTuple_Size(args) != 0) {
1093     PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
1094     return -1;
1095   }
1096 
1097   if (kwargs == NULL) {
1098     return 0;
1099   }
1100 
1101   Py_ssize_t pos = 0;
1102   PyObject* name;
1103   PyObject* value;
1104   while (PyDict_Next(kwargs, &pos, &name, &value)) {
1105     if (!(PyString_Check(name) || PyUnicode_Check(name))) {
1106       PyErr_SetString(PyExc_ValueError, "Field name must be a string");
1107       return -1;
1108     }
1109     ScopedPyObjectPtr property(
1110         PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
1111     if (property == NULL ||
1112         !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
1113       PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
1114                    self->message->GetDescriptor()->name().c_str(),
1115                    PyString_AsString(name));
1116       return -1;
1117     }
1118     const FieldDescriptor* descriptor =
1119         reinterpret_cast<PyMessageFieldProperty*>(property.get())
1120             ->field_descriptor;
1121     if (value == Py_None) {
1122       // field=None is the same as no field at all.
1123       continue;
1124     }
1125     if (descriptor->is_map()) {
1126       ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
1127       const FieldDescriptor* value_descriptor =
1128           descriptor->message_type()->FindFieldByName("value");
1129       if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1130         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1131         if (iter == NULL) {
1132           PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name));
1133           return -1;
1134         }
1135         ScopedPyObjectPtr next;
1136         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1137           ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
1138           ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
1139           if (source_value.get() == NULL || dest_value.get() == NULL) {
1140             return -1;
1141           }
1142           ScopedPyObjectPtr ok(PyObject_CallMethod(
1143               dest_value.get(), "MergeFrom", "O", source_value.get()));
1144           if (ok.get() == NULL) {
1145             return -1;
1146           }
1147         }
1148       } else {
1149         ScopedPyObjectPtr function_return;
1150         function_return.reset(
1151             PyObject_CallMethod(map.get(), "update", "O", value));
1152         if (function_return.get() == NULL) {
1153           return -1;
1154         }
1155       }
1156     } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1157       ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
1158       if (container == NULL) {
1159         return -1;
1160       }
1161       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1162         RepeatedCompositeContainer* rc_container =
1163             reinterpret_cast<RepeatedCompositeContainer*>(container.get());
1164         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1165         if (iter == NULL) {
1166           PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1167           return -1;
1168         }
1169         ScopedPyObjectPtr next;
1170         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1171           PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL);
1172           ScopedPyObjectPtr new_msg(
1173               repeated_composite_container::Add(rc_container, NULL, kwargs));
1174           if (new_msg == NULL) {
1175             return -1;
1176           }
1177           if (kwargs == NULL) {
1178             // next was not a dict, it's a message we need to merge
1179             ScopedPyObjectPtr merged(MergeFrom(
1180                 reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1181             if (merged.get() == NULL) {
1182               return -1;
1183             }
1184           }
1185         }
1186         if (PyErr_Occurred()) {
1187           // Check to see how PyIter_Next() exited.
1188           return -1;
1189         }
1190       } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1191         RepeatedScalarContainer* rs_container =
1192             reinterpret_cast<RepeatedScalarContainer*>(container.get());
1193         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1194         if (iter == NULL) {
1195           PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1196           return -1;
1197         }
1198         ScopedPyObjectPtr next;
1199         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1200           ScopedPyObjectPtr enum_value(
1201               GetIntegerEnumValue(*descriptor, next.get()));
1202           if (enum_value == NULL) {
1203             return -1;
1204           }
1205           ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
1206               rs_container, enum_value.get()));
1207           if (new_msg == NULL) {
1208             return -1;
1209           }
1210         }
1211         if (PyErr_Occurred()) {
1212           // Check to see how PyIter_Next() exited.
1213           return -1;
1214         }
1215       } else {
1216         if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
1217                 reinterpret_cast<RepeatedScalarContainer*>(container.get()),
1218                 value)) ==
1219             NULL) {
1220           return -1;
1221         }
1222       }
1223     } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1224       ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
1225       if (message == NULL) {
1226         return -1;
1227       }
1228       CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1229       if (PyDict_Check(value)) {
1230         // Make the message exist even if the dict is empty.
1231         AssureWritable(cmessage);
1232         if (InitAttributes(cmessage, NULL, value) < 0) {
1233           return -1;
1234         }
1235       } else {
1236         ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1237         if (merged == NULL) {
1238           return -1;
1239         }
1240       }
1241     } else {
1242       ScopedPyObjectPtr new_val;
1243       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1244         new_val.reset(GetIntegerEnumValue(*descriptor, value));
1245         if (new_val == NULL) {
1246           return -1;
1247         }
1248         value = new_val.get();
1249       }
1250       if (SetFieldValue(self, descriptor, value) < 0) {
1251         return -1;
1252       }
1253     }
1254   }
1255   return 0;
1256 }
1257 
1258 // Allocates an incomplete Python Message: the caller must fill self->message
1259 // and eventually self->parent.
NewEmptyMessage(CMessageClass * type)1260 CMessage* NewEmptyMessage(CMessageClass* type) {
1261   CMessage* self = reinterpret_cast<CMessage*>(
1262       PyType_GenericAlloc(&type->super.ht_type, 0));
1263   if (self == NULL) {
1264     return NULL;
1265   }
1266 
1267   self->message = NULL;
1268   self->parent = NULL;
1269   self->parent_field_descriptor = NULL;
1270   self->read_only = false;
1271 
1272   self->composite_fields = NULL;
1273   self->child_submessages = NULL;
1274 
1275   self->unknown_field_set = NULL;
1276 
1277   return self;
1278 }
1279 
1280 // The __new__ method of Message classes.
1281 // Creates a new C++ message and takes ownership.
New(PyTypeObject * cls,PyObject * unused_args,PyObject * unused_kwargs)1282 static PyObject* New(PyTypeObject* cls,
1283                      PyObject* unused_args, PyObject* unused_kwargs) {
1284   CMessageClass* type = CheckMessageClass(cls);
1285   if (type == NULL) {
1286     return NULL;
1287   }
1288   // Retrieve the message descriptor and the default instance (=prototype).
1289   const Descriptor* message_descriptor = type->message_descriptor;
1290   if (message_descriptor == NULL) {
1291     return NULL;
1292   }
1293   const Message* prototype =
1294       type->py_message_factory->message_factory->GetPrototype(
1295           message_descriptor);
1296   if (prototype == NULL) {
1297     PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
1298     return NULL;
1299   }
1300 
1301   CMessage* self = NewEmptyMessage(type);
1302   if (self == NULL) {
1303     return NULL;
1304   }
1305   self->message = prototype->New();
1306   self->parent = nullptr;  // This message owns its data.
1307   return reinterpret_cast<PyObject*>(self);
1308 }
1309 
1310 // The __init__ method of Message classes.
1311 // It initializes fields from keywords passed to the constructor.
Init(CMessage * self,PyObject * args,PyObject * kwargs)1312 static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
1313   return InitAttributes(self, args, kwargs);
1314 }
1315 
1316 // ---------------------------------------------------------------------
1317 // Deallocating a CMessage
1318 
Dealloc(CMessage * self)1319 static void Dealloc(CMessage* self) {
1320   if (self->weakreflist) {
1321     PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
1322   }
1323   // At this point all dependent objects have been removed.
1324   GOOGLE_DCHECK(!self->child_submessages || self->child_submessages->empty());
1325   GOOGLE_DCHECK(!self->composite_fields || self->composite_fields->empty());
1326   delete self->child_submessages;
1327   delete self->composite_fields;
1328   if (self->unknown_field_set) {
1329     unknown_fields::Clear(
1330         reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1331   }
1332 
1333   CMessage* parent = self->parent;
1334   if (!parent) {
1335     // No parent, we own the message.
1336     delete self->message;
1337   } else if (parent->AsPyObject() == Py_None) {
1338     // Message owned externally: Nothing to dealloc
1339     Py_CLEAR(self->parent);
1340   } else {
1341     // Clear this message from its parent's map.
1342     if (self->parent_field_descriptor->is_repeated()) {
1343       if (parent->child_submessages)
1344         parent->child_submessages->erase(self->message);
1345     } else {
1346       if (parent->composite_fields)
1347         parent->composite_fields->erase(self->parent_field_descriptor);
1348     }
1349     Py_CLEAR(self->parent);
1350   }
1351   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
1352 }
1353 
1354 // ---------------------------------------------------------------------
1355 
1356 
IsInitialized(CMessage * self,PyObject * args)1357 PyObject* IsInitialized(CMessage* self, PyObject* args) {
1358   PyObject* errors = NULL;
1359   if (!PyArg_ParseTuple(args, "|O", &errors)) {
1360     return NULL;
1361   }
1362   if (self->message->IsInitialized()) {
1363     Py_RETURN_TRUE;
1364   }
1365   if (errors != NULL) {
1366     ScopedPyObjectPtr initialization_errors(
1367         FindInitializationErrors(self));
1368     if (initialization_errors == NULL) {
1369       return NULL;
1370     }
1371     ScopedPyObjectPtr extend_name(PyString_FromString("extend"));
1372     if (extend_name == NULL) {
1373       return NULL;
1374     }
1375     ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
1376         errors,
1377         extend_name.get(),
1378         initialization_errors.get(),
1379         NULL));
1380     if (result == NULL) {
1381       return NULL;
1382     }
1383   }
1384   Py_RETURN_FALSE;
1385 }
1386 
HasFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1387 int HasFieldByDescriptor(CMessage* self,
1388                          const FieldDescriptor* field_descriptor) {
1389   Message* message = self->message;
1390   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
1391     return -1;
1392   }
1393   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1394     PyErr_SetString(PyExc_KeyError,
1395                     "Field is repeated. A singular method is required.");
1396     return -1;
1397   }
1398   return message->GetReflection()->HasField(*message, field_descriptor);
1399 }
1400 
FindFieldWithOneofs(const Message * message,ConstStringParam field_name,bool * in_oneof)1401 const FieldDescriptor* FindFieldWithOneofs(const Message* message,
1402                                            ConstStringParam field_name,
1403                                            bool* in_oneof) {
1404   *in_oneof = false;
1405   const Descriptor* descriptor = message->GetDescriptor();
1406   const FieldDescriptor* field_descriptor =
1407       descriptor->FindFieldByName(field_name);
1408   if (field_descriptor != NULL) {
1409     return field_descriptor;
1410   }
1411   const OneofDescriptor* oneof_desc =
1412       descriptor->FindOneofByName(field_name);
1413   if (oneof_desc != NULL) {
1414     *in_oneof = true;
1415     return message->GetReflection()->GetOneofFieldDescriptor(*message,
1416                                                              oneof_desc);
1417   }
1418   return NULL;
1419 }
1420 
CheckHasPresence(const FieldDescriptor * field_descriptor,bool in_oneof)1421 bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
1422   auto message_name = field_descriptor->containing_type()->name();
1423   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1424     PyErr_Format(PyExc_ValueError,
1425                  "Protocol message %s has no singular \"%s\" field.",
1426                  message_name.c_str(), field_descriptor->name().c_str());
1427     return false;
1428   }
1429 
1430   if (!field_descriptor->has_presence()) {
1431     PyErr_Format(PyExc_ValueError,
1432                  "Can't test non-optional, non-submessage field \"%s.%s\" for "
1433                  "presence in proto3.",
1434                  message_name.c_str(), field_descriptor->name().c_str());
1435     return false;
1436   }
1437 
1438   return true;
1439 }
1440 
HasField(CMessage * self,PyObject * arg)1441 PyObject* HasField(CMessage* self, PyObject* arg) {
1442   char* field_name;
1443   Py_ssize_t size;
1444 #if PY_MAJOR_VERSION < 3
1445   if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1446     return NULL;
1447   }
1448 #else
1449   field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
1450   if (!field_name) {
1451     return NULL;
1452   }
1453 #endif
1454 
1455   Message* message = self->message;
1456   bool is_in_oneof;
1457   const FieldDescriptor* field_descriptor =
1458       FindFieldWithOneofs(message, StringParam(field_name, size), &is_in_oneof);
1459   if (field_descriptor == NULL) {
1460     if (!is_in_oneof) {
1461       PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
1462                    message->GetDescriptor()->name().c_str(), field_name);
1463       return NULL;
1464     } else {
1465       Py_RETURN_FALSE;
1466     }
1467   }
1468 
1469   if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
1470     return NULL;
1471   }
1472 
1473   if (message->GetReflection()->HasField(*message, field_descriptor)) {
1474     Py_RETURN_TRUE;
1475   }
1476 
1477   Py_RETURN_FALSE;
1478 }
1479 
ClearExtension(CMessage * self,PyObject * extension)1480 PyObject* ClearExtension(CMessage* self, PyObject* extension) {
1481   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1482   if (descriptor == NULL) {
1483     return NULL;
1484   }
1485   if (ClearFieldByDescriptor(self, descriptor) < 0) {
1486     return nullptr;
1487   }
1488   Py_RETURN_NONE;
1489 }
1490 
HasExtension(CMessage * self,PyObject * extension)1491 PyObject* HasExtension(CMessage* self, PyObject* extension) {
1492   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1493   if (descriptor == NULL) {
1494     return NULL;
1495   }
1496   int has_field = HasFieldByDescriptor(self, descriptor);
1497   if (has_field < 0) {
1498     return nullptr;
1499   } else {
1500     return PyBool_FromLong(has_field);
1501   }
1502 }
1503 
1504 // ---------------------------------------------------------------------
1505 // Releasing messages
1506 //
// The Python API's ClearField() and Clear() methods behave
// differently than their C++ counterparts.  While the C++ versions
// clear the children, the Python versions detach the children,
// without touching their content.  This impedance mismatch causes
// some complexity in the implementation, which is captured in this
// section.
1513 //
1514 // When one or multiple fields are cleared we need to:
1515 //
1516 // * Gather all child objects that need to be detached from the message.
1517 //   In composite_fields and child_submessages.
1518 //
1519 // * Create a new Python message of the same kind. Use SwapFields() to move
1520 //   data from the original message.
1521 //
1522 // * Change the parent of all child objects: update their strong reference
1523 //   to their parent, and move their presence in composite_fields and
1524 //   child_submessages.
1525 
1526 // ---------------------------------------------------------------------
1527 // Release a composite child of a CMessage
1528 
InternalReparentFields(CMessage * self,const std::vector<CMessage * > & messages_to_release,const std::vector<ContainerBase * > & containers_to_release)1529 static int InternalReparentFields(
1530     CMessage* self, const std::vector<CMessage*>& messages_to_release,
1531     const std::vector<ContainerBase*>& containers_to_release) {
1532   if (messages_to_release.empty() && containers_to_release.empty()) {
1533     return 0;
1534   }
1535 
1536   // Move all the passed sub_messages to another message.
1537   CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass());
1538   if (new_message == nullptr) {
1539     return -1;
1540   }
1541   new_message->message = self->message->New();
1542   ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message));
1543   new_message->child_submessages = new CMessage::SubMessagesMap();
1544   new_message->composite_fields = new CMessage::CompositeFieldsMap();
1545   std::set<const FieldDescriptor*> fields_to_swap;
1546 
1547   // In case this the removed fields are the last reference to a message, keep
1548   // a reference.
1549   Py_INCREF(self);
1550 
1551   for (const auto& to_release : messages_to_release) {
1552     fields_to_swap.insert(to_release->parent_field_descriptor);
1553     // Reparent
1554     Py_INCREF(new_message);
1555     Py_DECREF(to_release->parent);
1556     to_release->parent = new_message;
1557     self->child_submessages->erase(to_release->message);
1558     new_message->child_submessages->emplace(to_release->message, to_release);
1559   }
1560 
1561   for (const auto& to_release : containers_to_release) {
1562     fields_to_swap.insert(to_release->parent_field_descriptor);
1563     Py_INCREF(new_message);
1564     Py_DECREF(to_release->parent);
1565     to_release->parent = new_message;
1566     self->composite_fields->erase(to_release->parent_field_descriptor);
1567     new_message->composite_fields->emplace(to_release->parent_field_descriptor,
1568                                            to_release);
1569   }
1570 
1571   self->message->GetReflection()->SwapFields(
1572       self->message, new_message->message,
1573       std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
1574                                           fields_to_swap.end()));
1575 
1576   // This might delete the Python message completely if all children were moved.
1577   Py_DECREF(self);
1578 
1579   return 0;
1580 }
1581 
InternalReleaseFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1582 int InternalReleaseFieldByDescriptor(
1583     CMessage* self,
1584     const FieldDescriptor* field_descriptor) {
1585   if (!field_descriptor->is_repeated() &&
1586       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1587     // Single scalars are not in any cache.
1588     return 0;
1589   }
1590   std::vector<CMessage*> messages_to_release;
1591   std::vector<ContainerBase*> containers_to_release;
1592   if (self->child_submessages && field_descriptor->is_repeated() &&
1593       field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1594     for (const auto& child_item : *self->child_submessages) {
1595       if (child_item.second->parent_field_descriptor == field_descriptor) {
1596         messages_to_release.push_back(child_item.second);
1597       }
1598     }
1599   }
1600   if (self->composite_fields) {
1601     CMessage::CompositeFieldsMap::iterator it =
1602         self->composite_fields->find(field_descriptor);
1603     if (it != self->composite_fields->end()) {
1604       containers_to_release.push_back(it->second);
1605     }
1606   }
1607 
1608   return InternalReparentFields(self, messages_to_release,
1609                                 containers_to_release);
1610 }
1611 
ClearFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1612 int ClearFieldByDescriptor(CMessage* self,
1613                            const FieldDescriptor* field_descriptor) {
1614   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
1615     return -1;
1616   }
1617   if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) {
1618     return -1;
1619   }
1620   AssureWritable(self);
1621   Message* message = self->message;
1622   message->GetReflection()->ClearField(message, field_descriptor);
1623   return 0;
1624 }
1625 
ClearField(CMessage * self,PyObject * arg)1626 PyObject* ClearField(CMessage* self, PyObject* arg) {
1627   char* field_name;
1628   Py_ssize_t field_size;
1629   if (PyString_AsStringAndSize(arg, &field_name, &field_size) < 0) {
1630     return NULL;
1631   }
1632   AssureWritable(self);
1633   bool is_in_oneof;
1634   const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
1635       self->message, StringParam(field_name, field_size), &is_in_oneof);
1636   if (field_descriptor == NULL) {
1637     if (is_in_oneof) {
1638       // We gave the name of a oneof, and none of its fields are set.
1639       Py_RETURN_NONE;
1640     } else {
1641       PyErr_Format(PyExc_ValueError,
1642                    "Protocol message has no \"%s\" field.", field_name);
1643       return NULL;
1644     }
1645   }
1646 
1647   if (ClearFieldByDescriptor(self, field_descriptor) < 0) {
1648     return nullptr;
1649   }
1650   Py_RETURN_NONE;
1651 }
1652 
Clear(CMessage * self)1653 PyObject* Clear(CMessage* self) {
1654   AssureWritable(self);
1655   // Detach all current fields of this message
1656   std::vector<CMessage*> messages_to_release;
1657   std::vector<ContainerBase*> containers_to_release;
1658   if (self->child_submessages) {
1659     for (const auto& item : *self->child_submessages) {
1660       messages_to_release.push_back(item.second);
1661     }
1662   }
1663   if (self->composite_fields) {
1664     for (const auto& item : *self->composite_fields) {
1665       containers_to_release.push_back(item.second);
1666     }
1667   }
1668   if (InternalReparentFields(self, messages_to_release, containers_to_release) <
1669       0) {
1670     return NULL;
1671   }
1672   if (self->unknown_field_set) {
1673     unknown_fields::Clear(
1674         reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1675     self->unknown_field_set = nullptr;
1676   }
1677   self->message->Clear();
1678   Py_RETURN_NONE;
1679 }
1680 
1681 // ---------------------------------------------------------------------
1682 
GetMessageName(CMessage * self)1683 static std::string GetMessageName(CMessage* self) {
1684   if (self->parent_field_descriptor != NULL) {
1685     return self->parent_field_descriptor->full_name();
1686   } else {
1687     return self->message->GetDescriptor()->full_name();
1688   }
1689 }
1690 
InternalSerializeToString(CMessage * self,PyObject * args,PyObject * kwargs,bool require_initialized)1691 static PyObject* InternalSerializeToString(
1692     CMessage* self, PyObject* args, PyObject* kwargs,
1693     bool require_initialized) {
1694   // Parse the "deterministic" kwarg; defaults to False.
1695   static char* kwlist[] = { "deterministic", 0 };
1696   PyObject* deterministic_obj = Py_None;
1697   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
1698                                    &deterministic_obj)) {
1699     return NULL;
1700   }
1701   // Preemptively convert to a bool first, so we don't need to back out of
1702   // allocating memory if this raises an exception.
1703   // NOTE: This is unused later if deterministic == Py_None, but that's fine.
1704   int deterministic = PyObject_IsTrue(deterministic_obj);
1705   if (deterministic < 0) {
1706     return NULL;
1707   }
1708 
1709   if (require_initialized && !self->message->IsInitialized()) {
1710     ScopedPyObjectPtr errors(FindInitializationErrors(self));
1711     if (errors == NULL) {
1712       return NULL;
1713     }
1714     ScopedPyObjectPtr comma(PyString_FromString(","));
1715     if (comma == NULL) {
1716       return NULL;
1717     }
1718     ScopedPyObjectPtr joined(
1719         PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
1720     if (joined == NULL) {
1721       return NULL;
1722     }
1723 
1724     // TODO(haberman): this is a (hopefully temporary) hack.  The unit testing
1725     // infrastructure reloads all pure-Python modules for every test, but not
1726     // C++ modules (because that's generally impossible:
1727     // http://bugs.python.org/issue1144263).  But if we cache EncodeError, we'll
1728     // return the EncodeError from a previous load of the module, which won't
1729     // match a user's attempt to catch EncodeError.  So we have to look it up
1730     // again every time.
1731     ScopedPyObjectPtr message_module(PyImport_ImportModule(
1732         "google.protobuf.message"));
1733     if (message_module.get() == NULL) {
1734       return NULL;
1735     }
1736 
1737     ScopedPyObjectPtr encode_error(
1738         PyObject_GetAttrString(message_module.get(), "EncodeError"));
1739     if (encode_error.get() == NULL) {
1740       return NULL;
1741     }
1742     PyErr_Format(encode_error.get(),
1743                  "Message %s is missing required fields: %s",
1744                  GetMessageName(self).c_str(), PyString_AsString(joined.get()));
1745     return NULL;
1746   }
1747 
1748   // Ok, arguments parsed and errors checked, now encode to a string
1749   const size_t size = self->message->ByteSizeLong();
1750   if (size == 0) {
1751     return PyBytes_FromString("");
1752   }
1753 
1754   if (size > INT_MAX) {
1755     PyErr_Format(PyExc_ValueError,
1756                  "Message %s exceeds maximum protobuf "
1757                  "size of 2GB: %zu",
1758                  GetMessageName(self).c_str(), size);
1759     return nullptr;
1760   }
1761 
1762   PyObject* result = PyBytes_FromStringAndSize(NULL, size);
1763   if (result == NULL) {
1764     return NULL;
1765   }
1766   io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
1767   io::CodedOutputStream coded_out(&out);
1768   if (deterministic_obj != Py_None) {
1769     coded_out.SetSerializationDeterministic(deterministic);
1770   }
1771   self->message->SerializeWithCachedSizes(&coded_out);
1772   GOOGLE_CHECK(!coded_out.HadError());
1773   return result;
1774 }
1775 
SerializeToString(CMessage * self,PyObject * args,PyObject * kwargs)1776 static PyObject* SerializeToString(
1777     CMessage* self, PyObject* args, PyObject* kwargs) {
1778   return InternalSerializeToString(self, args, kwargs,
1779                                    /*require_initialized=*/true);
1780 }
1781 
SerializePartialToString(CMessage * self,PyObject * args,PyObject * kwargs)1782 static PyObject* SerializePartialToString(
1783     CMessage* self, PyObject* args, PyObject* kwargs) {
1784   return InternalSerializeToString(self, args, kwargs,
1785                                    /*require_initialized=*/false);
1786 }
1787 
1788 // Formats proto fields for ascii dumps using python formatting functions where
1789 // appropriate.
1790 class PythonFieldValuePrinter : public TextFormat::FastFieldValuePrinter {
1791  public:
1792   // Python has some differences from C++ when printing floating point numbers.
1793   //
1794   // 1) Trailing .0 is always printed.
1795   // 2) (Python2) Output is rounded to 12 digits.
1796   // 3) (Python3) The full precision of the double is preserved (and Python uses
1797   //    David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
1798   //    differences, but they rarely happen)
1799   //
1800   // We override floating point printing with the C-API function for printing
1801   // Python floats to ensure consistency.
PrintFloat(float val,TextFormat::BaseTextGenerator * generator) const1802   void PrintFloat(float val,
1803                   TextFormat::BaseTextGenerator* generator) const override {
1804     PrintDouble(val, generator);
1805   }
PrintDouble(double val,TextFormat::BaseTextGenerator * generator) const1806   void PrintDouble(double val,
1807                    TextFormat::BaseTextGenerator* generator) const override {
1808     // This implementation is not highly optimized (it allocates two temporary
1809     // Python objects) but it is simple and portable.  If this is shown to be a
1810     // performance bottleneck, we can optimize it, but the results will likely
1811     // be more complicated to accommodate the differing behavior of double
1812     // formatting between Python 2 and Python 3.
1813     //
1814     // (Though a valid question is: do we really want to make out output
1815     // dependent on the Python version?)
1816     ScopedPyObjectPtr py_value(PyFloat_FromDouble(val));
1817     if (!py_value.get()) {
1818       return;
1819     }
1820 
1821     ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
1822     if (!py_str.get()) {
1823       return;
1824     }
1825 
1826     generator->PrintString(PyString_AsString(py_str.get()));
1827   }
1828 };
1829 
ToStr(CMessage * self)1830 static PyObject* ToStr(CMessage* self) {
1831   TextFormat::Printer printer;
1832   // Passes ownership
1833   printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
1834   printer.SetHideUnknownFields(true);
1835   std::string output;
1836   if (!printer.PrintToString(*self->message, &output)) {
1837     PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
1838     return NULL;
1839   }
1840   return PyString_FromString(output.c_str());
1841 }
1842 
MergeFrom(CMessage * self,PyObject * arg)1843 PyObject* MergeFrom(CMessage* self, PyObject* arg) {
1844   CMessage* other_message;
1845   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1846     PyErr_Format(PyExc_TypeError,
1847                  "Parameter to MergeFrom() must be instance of same class: "
1848                  "expected %s got %s.",
1849                  self->message->GetDescriptor()->full_name().c_str(),
1850                  Py_TYPE(arg)->tp_name);
1851     return NULL;
1852   }
1853 
1854   other_message = reinterpret_cast<CMessage*>(arg);
1855   if (other_message->message->GetDescriptor() !=
1856       self->message->GetDescriptor()) {
1857     PyErr_Format(PyExc_TypeError,
1858                  "Parameter to MergeFrom() must be instance of same class: "
1859                  "expected %s got %s.",
1860                  self->message->GetDescriptor()->full_name().c_str(),
1861                  other_message->message->GetDescriptor()->full_name().c_str());
1862     return NULL;
1863   }
1864   AssureWritable(self);
1865 
1866   self->message->MergeFrom(*other_message->message);
1867   // Child message might be lazily created before MergeFrom. Make sure they
1868   // are mutable at this point if child messages are really created.
1869   if (FixupMessageAfterMerge(self) < 0) {
1870     return NULL;
1871   }
1872 
1873   Py_RETURN_NONE;
1874 }
1875 
CopyFrom(CMessage * self,PyObject * arg)1876 static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
1877   CMessage* other_message;
1878   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1879     PyErr_Format(PyExc_TypeError,
1880                  "Parameter to CopyFrom() must be instance of same class: "
1881                  "expected %s got %s.",
1882                  self->message->GetDescriptor()->full_name().c_str(),
1883                  Py_TYPE(arg)->tp_name);
1884     return NULL;
1885   }
1886 
1887   other_message = reinterpret_cast<CMessage*>(arg);
1888 
1889   if (self == other_message) {
1890     Py_RETURN_NONE;
1891   }
1892 
1893   if (other_message->message->GetDescriptor() !=
1894       self->message->GetDescriptor()) {
1895     PyErr_Format(PyExc_TypeError,
1896                  "Parameter to CopyFrom() must be instance of same class: "
1897                  "expected %s got %s.",
1898                  self->message->GetDescriptor()->full_name().c_str(),
1899                  other_message->message->GetDescriptor()->full_name().c_str());
1900     return NULL;
1901   }
1902 
1903   AssureWritable(self);
1904 
1905   // CopyFrom on the message will not clean up self->composite_fields,
1906   // which can leave us in an inconsistent state, so clear it out here.
1907   (void)ScopedPyObjectPtr(Clear(self));
1908 
1909   self->message->CopyFrom(*other_message->message);
1910 
1911   Py_RETURN_NONE;
1912 }
1913 
// Protobuf has a 64MB limit built in; this flag overrides it.  Please do not
// enable it unless you fully understand the implications: protobufs must all
// be kept in memory at the same time, so if they grow too big you may get OOM
// errors.  The protobuf APIs do not provide any tools for processing protobufs
// in chunks.  If you have protos this big you should break them up if it is at
// all convenient to do so.
//
// The initial value is true only when PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS is
// defined at compile time.
static bool allow_oversize_protos =
#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
    true;
#else
    false;
#endif
1925 
1926 // Provide a method in the module to set allow_oversize_protos to a boolean
1927 // value. This method returns the newly value of allow_oversize_protos.
SetAllowOversizeProtos(PyObject * m,PyObject * arg)1928 PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
1929   if (!arg || !PyBool_Check(arg)) {
1930     PyErr_SetString(PyExc_TypeError,
1931                     "Argument to SetAllowOversizeProtos must be boolean");
1932     return NULL;
1933   }
1934   allow_oversize_protos = PyObject_IsTrue(arg);
1935   if (allow_oversize_protos) {
1936     Py_RETURN_TRUE;
1937   } else {
1938     Py_RETURN_FALSE;
1939   }
1940 }
1941 
MergeFromString(CMessage * self,PyObject * arg)1942 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
1943   const void* data;
1944   Py_ssize_t data_length;
1945   if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) {
1946     return NULL;
1947   }
1948 
1949   AssureWritable(self);
1950 
1951   PyMessageFactory* factory = GetFactoryForMessage(self);
1952   int depth = allow_oversize_protos
1953                   ? INT_MAX
1954                   : io::CodedInputStream::GetDefaultRecursionLimit();
1955   const char* ptr;
1956   internal::ParseContext ctx(
1957       depth, false, &ptr,
1958       StringPiece(static_cast<const char*>(data), data_length));
1959   ctx.data().pool = factory->pool->pool;
1960   ctx.data().factory = factory->message_factory;
1961 
1962   ptr = self->message->_InternalParse(ptr, &ctx);
1963 
1964   // Child message might be lazily created before MergeFrom. Make sure they
1965   // are mutable at this point if child messages are really created.
1966   if (FixupMessageAfterMerge(self) < 0) {
1967     return NULL;
1968   }
1969 
1970   // Python makes distinction in error message, between a general parse failure
1971   // and in-correct ending on a terminating tag. Hence we need to be a bit more
1972   // explicit in our correctness checks.
1973   if (ptr == nullptr || ctx.BytesUntilLimit(ptr) < 0) {
1974     // Parse error or the parser overshoot the limit.
1975     PyErr_Format(DecodeError_class, "Error parsing message");
1976     return NULL;
1977   }
1978   // ctx has an explicit limit set (length of string_view), so we have to
1979   // check we ended at that limit.
1980   if (!ctx.EndedAtLimit()) {
1981     // TODO(jieluo): Raise error and return NULL instead.
1982     // b/27494216
1983     PyErr_Warn(nullptr, "Unexpected end-group tag: Not all data was converted");
1984     return PyInt_FromLong(data_length - ctx.BytesUntilLimit(ptr));
1985   }
1986   return PyInt_FromLong(data_length);
1987 }
1988 
ParseFromString(CMessage * self,PyObject * arg)1989 static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
1990   if (ScopedPyObjectPtr(Clear(self)) == NULL) {
1991     return NULL;
1992   }
1993   return MergeFromString(self, arg);
1994 }
1995 
ByteSize(CMessage * self,PyObject * args)1996 static PyObject* ByteSize(CMessage* self, PyObject* args) {
1997   return PyLong_FromLong(self->message->ByteSizeLong());
1998 }
1999 
RegisterExtension(PyObject * cls,PyObject * extension_handle)2000 PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
2001   const FieldDescriptor* descriptor =
2002       GetExtensionDescriptor(extension_handle);
2003   if (descriptor == NULL) {
2004     return NULL;
2005   }
2006   if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
2007     PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
2008                  cls->ob_type->tp_name);
2009     return NULL;
2010   }
2011   CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
2012   if (message_class == NULL) {
2013     return NULL;
2014   }
2015   // If the extension was already registered, check that it is the same.
2016   const FieldDescriptor* existing_extension =
2017       message_class->py_message_factory->pool->pool->FindExtensionByNumber(
2018           descriptor->containing_type(), descriptor->number());
2019   if (existing_extension != NULL && existing_extension != descriptor) {
2020     PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
2021     return NULL;
2022   }
2023   Py_RETURN_NONE;
2024 }
2025 
SetInParent(CMessage * self,PyObject * args)2026 static PyObject* SetInParent(CMessage* self, PyObject* args) {
2027   AssureWritable(self);
2028   Py_RETURN_NONE;
2029 }
2030 
WhichOneof(CMessage * self,PyObject * arg)2031 static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
2032   Py_ssize_t name_size;
2033   char *name_data;
2034   if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0)
2035     return NULL;
2036   const OneofDescriptor* oneof_desc =
2037       self->message->GetDescriptor()->FindOneofByName(
2038           StringParam(name_data, name_size));
2039   if (oneof_desc == NULL) {
2040     PyErr_Format(PyExc_ValueError,
2041                  "Protocol message has no oneof \"%s\" field.", name_data);
2042     return NULL;
2043   }
2044   const FieldDescriptor* field_in_oneof =
2045       self->message->GetReflection()->GetOneofFieldDescriptor(
2046           *self->message, oneof_desc);
2047   if (field_in_oneof == NULL) {
2048     Py_RETURN_NONE;
2049   } else {
2050     const std::string& name = field_in_oneof->name();
2051     return PyString_FromStringAndSize(name.c_str(), name.size());
2052   }
2053 }
2054 
2055 static PyObject* GetExtensionDict(CMessage* self, void *closure);
2056 
ListFields(CMessage * self)2057 static PyObject* ListFields(CMessage* self) {
2058   std::vector<const FieldDescriptor*> fields;
2059   self->message->GetReflection()->ListFields(*self->message, &fields);
2060 
2061   // Normally, the list will be exactly the size of the fields.
2062   ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
2063   if (all_fields == NULL) {
2064     return NULL;
2065   }
2066 
2067   // When there are unknown extensions, the py list will *not* contain
2068   // the field information.  Thus the actual size of the py list will be
2069   // smaller than the size of fields.  Set the actual size at the end.
2070   Py_ssize_t actual_size = 0;
2071   for (size_t i = 0; i < fields.size(); ++i) {
2072     ScopedPyObjectPtr t(PyTuple_New(2));
2073     if (t == NULL) {
2074       return NULL;
2075     }
2076 
2077     if (fields[i]->is_extension()) {
2078       ScopedPyObjectPtr extension_field(
2079           PyFieldDescriptor_FromDescriptor(fields[i]));
2080       if (extension_field == NULL) {
2081         return NULL;
2082       }
2083       // With C++ descriptors, the field can always be retrieved, but for
2084       // unknown extensions which have not been imported in Python code, there
2085       // is no message class and we cannot retrieve the value.
2086       // TODO(amauryfa): consider building the class on the fly!
2087       if (fields[i]->message_type() != NULL &&
2088           message_factory::GetMessageClass(
2089               GetFactoryForMessage(self),
2090               fields[i]->message_type()) == NULL) {
2091         PyErr_Clear();
2092         continue;
2093       }
2094       ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
2095       if (extensions == NULL) {
2096         return NULL;
2097       }
2098       // 'extension' reference later stolen by PyTuple_SET_ITEM.
2099       PyObject* extension = PyObject_GetItem(
2100           extensions.get(), extension_field.get());
2101       if (extension == NULL) {
2102         return NULL;
2103       }
2104       PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
2105       // Steals reference to 'extension'
2106       PyTuple_SET_ITEM(t.get(), 1, extension);
2107     } else {
2108       // Normal field
2109       ScopedPyObjectPtr field_descriptor(
2110           PyFieldDescriptor_FromDescriptor(fields[i]));
2111       if (field_descriptor == NULL) {
2112         return NULL;
2113       }
2114 
2115       PyObject* field_value = GetFieldValue(self, fields[i]);
2116       if (field_value == NULL) {
2117         PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str());
2118         return NULL;
2119       }
2120       PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
2121       PyTuple_SET_ITEM(t.get(), 1, field_value);
2122     }
2123     PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
2124     ++actual_size;
2125   }
2126   if (static_cast<size_t>(actual_size) != fields.size() &&
2127       (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
2128        0)) {
2129     return NULL;
2130   }
2131   return all_fields.release();
2132 }
2133 
DiscardUnknownFields(CMessage * self)2134 static PyObject* DiscardUnknownFields(CMessage* self) {
2135   AssureWritable(self);
2136   self->message->DiscardUnknownFields();
2137   Py_RETURN_NONE;
2138 }
2139 
FindInitializationErrors(CMessage * self)2140 PyObject* FindInitializationErrors(CMessage* self) {
2141   Message* message = self->message;
2142   std::vector<std::string> errors;
2143   message->FindInitializationErrors(&errors);
2144 
2145   PyObject* error_list = PyList_New(errors.size());
2146   if (error_list == NULL) {
2147     return NULL;
2148   }
2149   for (size_t i = 0; i < errors.size(); ++i) {
2150     const std::string& error = errors[i];
2151     PyObject* error_string = PyString_FromStringAndSize(
2152         error.c_str(), error.length());
2153     if (error_string == NULL) {
2154       Py_DECREF(error_list);
2155       return NULL;
2156     }
2157     PyList_SET_ITEM(error_list, i, error_string);
2158   }
2159   return error_list;
2160 }
2161 
RichCompare(CMessage * self,PyObject * other,int opid)2162 static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
2163   // Only equality comparisons are implemented.
2164   if (opid != Py_EQ && opid != Py_NE) {
2165     Py_INCREF(Py_NotImplemented);
2166     return Py_NotImplemented;
2167   }
2168   bool equals = true;
2169   // If other is not a message, it cannot be equal.
2170   if (!PyObject_TypeCheck(other, CMessage_Type)) {
2171     equals = false;
2172   } else {
2173     // Otherwise, we have a CMessage whose message we can inspect.
2174     const google::protobuf::Message* other_message =
2175         reinterpret_cast<CMessage*>(other)->message;
2176     // If messages don't have the same descriptors, they are not equal.
2177     if (equals &&
2178         self->message->GetDescriptor() != other_message->GetDescriptor()) {
2179       equals = false;
2180     }
2181     // Check the message contents.
2182     if (equals &&
2183         !google::protobuf::util::MessageDifferencer::Equals(
2184             *self->message, *reinterpret_cast<CMessage*>(other)->message)) {
2185       equals = false;
2186     }
2187   }
2188 
2189   if (equals ^ (opid == Py_EQ)) {
2190     Py_RETURN_FALSE;
2191   } else {
2192     Py_RETURN_TRUE;
2193   }
2194 }
2195 
InternalGetScalar(const Message * message,const FieldDescriptor * field_descriptor)2196 PyObject* InternalGetScalar(const Message* message,
2197                             const FieldDescriptor* field_descriptor) {
2198   const Reflection* reflection = message->GetReflection();
2199 
2200   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2201     return NULL;
2202   }
2203 
2204   PyObject* result = NULL;
2205   switch (field_descriptor->cpp_type()) {
2206     case FieldDescriptor::CPPTYPE_INT32: {
2207       int32 value = reflection->GetInt32(*message, field_descriptor);
2208       result = PyInt_FromLong(value);
2209       break;
2210     }
2211     case FieldDescriptor::CPPTYPE_INT64: {
2212       int64 value = reflection->GetInt64(*message, field_descriptor);
2213       result = PyLong_FromLongLong(value);
2214       break;
2215     }
2216     case FieldDescriptor::CPPTYPE_UINT32: {
2217       uint32 value = reflection->GetUInt32(*message, field_descriptor);
2218       result = PyInt_FromSize_t(value);
2219       break;
2220     }
2221     case FieldDescriptor::CPPTYPE_UINT64: {
2222       uint64 value = reflection->GetUInt64(*message, field_descriptor);
2223       result = PyLong_FromUnsignedLongLong(value);
2224       break;
2225     }
2226     case FieldDescriptor::CPPTYPE_FLOAT: {
2227       float value = reflection->GetFloat(*message, field_descriptor);
2228       result = PyFloat_FromDouble(value);
2229       break;
2230     }
2231     case FieldDescriptor::CPPTYPE_DOUBLE: {
2232       double value = reflection->GetDouble(*message, field_descriptor);
2233       result = PyFloat_FromDouble(value);
2234       break;
2235     }
2236     case FieldDescriptor::CPPTYPE_BOOL: {
2237       bool value = reflection->GetBool(*message, field_descriptor);
2238       result = PyBool_FromLong(value);
2239       break;
2240     }
2241     case FieldDescriptor::CPPTYPE_STRING: {
2242       std::string scratch;
2243       const std::string& value =
2244           reflection->GetStringReference(*message, field_descriptor, &scratch);
2245       result = ToStringObject(field_descriptor, value);
2246       break;
2247     }
2248     case FieldDescriptor::CPPTYPE_ENUM: {
2249       const EnumValueDescriptor* enum_value =
2250           message->GetReflection()->GetEnum(*message, field_descriptor);
2251       result = PyInt_FromLong(enum_value->number());
2252       break;
2253     }
2254     default:
2255       PyErr_Format(
2256           PyExc_SystemError, "Getting a value from a field of unknown type %d",
2257           field_descriptor->cpp_type());
2258   }
2259 
2260   return result;
2261 }
2262 
InternalGetSubMessage(CMessage * self,const FieldDescriptor * field_descriptor)2263 CMessage* InternalGetSubMessage(
2264     CMessage* self, const FieldDescriptor* field_descriptor) {
2265   const Reflection* reflection = self->message->GetReflection();
2266   PyMessageFactory* factory = GetFactoryForMessage(self);
2267   const Message& sub_message = reflection->GetMessage(
2268       *self->message, field_descriptor, factory->message_factory);
2269 
2270   CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2271       factory, field_descriptor->message_type());
2272   ScopedPyObjectPtr message_class_owner(
2273       reinterpret_cast<PyObject*>(message_class));
2274   if (message_class == NULL) {
2275     return NULL;
2276   }
2277 
2278   CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2279   if (cmsg == NULL) {
2280     return NULL;
2281   }
2282 
2283   Py_INCREF(self);
2284   cmsg->parent = self;
2285   cmsg->parent_field_descriptor = field_descriptor;
2286   cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
2287   cmsg->message = const_cast<Message*>(&sub_message);
2288   return cmsg;
2289 }
2290 
InternalSetNonOneofScalar(Message * message,const FieldDescriptor * field_descriptor,PyObject * arg)2291 int InternalSetNonOneofScalar(
2292     Message* message,
2293     const FieldDescriptor* field_descriptor,
2294     PyObject* arg) {
2295   const Reflection* reflection = message->GetReflection();
2296 
2297   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2298     return -1;
2299   }
2300 
2301   switch (field_descriptor->cpp_type()) {
2302     case FieldDescriptor::CPPTYPE_INT32: {
2303       GOOGLE_CHECK_GET_INT32(arg, value, -1);
2304       reflection->SetInt32(message, field_descriptor, value);
2305       break;
2306     }
2307     case FieldDescriptor::CPPTYPE_INT64: {
2308       GOOGLE_CHECK_GET_INT64(arg, value, -1);
2309       reflection->SetInt64(message, field_descriptor, value);
2310       break;
2311     }
2312     case FieldDescriptor::CPPTYPE_UINT32: {
2313       GOOGLE_CHECK_GET_UINT32(arg, value, -1);
2314       reflection->SetUInt32(message, field_descriptor, value);
2315       break;
2316     }
2317     case FieldDescriptor::CPPTYPE_UINT64: {
2318       GOOGLE_CHECK_GET_UINT64(arg, value, -1);
2319       reflection->SetUInt64(message, field_descriptor, value);
2320       break;
2321     }
2322     case FieldDescriptor::CPPTYPE_FLOAT: {
2323       GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
2324       reflection->SetFloat(message, field_descriptor, value);
2325       break;
2326     }
2327     case FieldDescriptor::CPPTYPE_DOUBLE: {
2328       GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
2329       reflection->SetDouble(message, field_descriptor, value);
2330       break;
2331     }
2332     case FieldDescriptor::CPPTYPE_BOOL: {
2333       GOOGLE_CHECK_GET_BOOL(arg, value, -1);
2334       reflection->SetBool(message, field_descriptor, value);
2335       break;
2336     }
2337     case FieldDescriptor::CPPTYPE_STRING: {
2338       if (!CheckAndSetString(
2339           arg, message, field_descriptor, reflection, false, -1)) {
2340         return -1;
2341       }
2342       break;
2343     }
2344     case FieldDescriptor::CPPTYPE_ENUM: {
2345       GOOGLE_CHECK_GET_INT32(arg, value, -1);
2346       if (reflection->SupportsUnknownEnumValues()) {
2347         reflection->SetEnumValue(message, field_descriptor, value);
2348       } else {
2349         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
2350         const EnumValueDescriptor* enum_value =
2351             enum_descriptor->FindValueByNumber(value);
2352         if (enum_value != NULL) {
2353           reflection->SetEnum(message, field_descriptor, enum_value);
2354         } else {
2355           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
2356           return -1;
2357         }
2358       }
2359       break;
2360     }
2361     default:
2362       PyErr_Format(
2363           PyExc_SystemError, "Setting value to a field of unknown type %d",
2364           field_descriptor->cpp_type());
2365       return -1;
2366   }
2367 
2368   return 0;
2369 }
2370 
InternalSetScalar(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * arg)2371 int InternalSetScalar(
2372     CMessage* self,
2373     const FieldDescriptor* field_descriptor,
2374     PyObject* arg) {
2375   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
2376     return -1;
2377   }
2378 
2379   if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
2380     return -1;
2381   }
2382 
2383   return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
2384 }
2385 
FromString(PyTypeObject * cls,PyObject * serialized)2386 PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
2387   PyObject* py_cmsg = PyObject_CallObject(
2388       reinterpret_cast<PyObject*>(cls), NULL);
2389   if (py_cmsg == NULL) {
2390     return NULL;
2391   }
2392   CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
2393 
2394   ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
2395   if (py_length == NULL) {
2396     Py_DECREF(py_cmsg);
2397     return NULL;
2398   }
2399 
2400   return py_cmsg;
2401 }
2402 
DeepCopy(CMessage * self,PyObject * arg)2403 PyObject* DeepCopy(CMessage* self, PyObject* arg) {
2404   PyObject* clone = PyObject_CallObject(
2405       reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
2406   if (clone == NULL) {
2407     return NULL;
2408   }
2409   if (!PyObject_TypeCheck(clone, CMessage_Type)) {
2410     Py_DECREF(clone);
2411     return NULL;
2412   }
2413   if (ScopedPyObjectPtr(MergeFrom(
2414           reinterpret_cast<CMessage*>(clone),
2415           reinterpret_cast<PyObject*>(self))) == NULL) {
2416     Py_DECREF(clone);
2417     return NULL;
2418   }
2419   return clone;
2420 }
2421 
ToUnicode(CMessage * self)2422 PyObject* ToUnicode(CMessage* self) {
2423   // Lazy import to prevent circular dependencies
2424   ScopedPyObjectPtr text_format(
2425       PyImport_ImportModule("google.protobuf.text_format"));
2426   if (text_format == NULL) {
2427     return NULL;
2428   }
2429   ScopedPyObjectPtr method_name(PyString_FromString("MessageToString"));
2430   if (method_name == NULL) {
2431     return NULL;
2432   }
2433   Py_INCREF(Py_True);
2434   ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
2435       text_format.get(), method_name.get(), self, Py_True, NULL));
2436   Py_DECREF(Py_True);
2437   if (encoded == NULL) {
2438     return NULL;
2439   }
2440 #if PY_MAJOR_VERSION < 3
2441   PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL);
2442 #else
2443   PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL);
2444 #endif
2445   if (decoded == NULL) {
2446     return NULL;
2447   }
2448   return decoded;
2449 }
2450 
2451 // CMessage static methods:
_CheckCalledFromGeneratedFile(PyObject * unused,PyObject * unused_arg)2452 PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
2453                                         PyObject* unused_arg) {
2454   if (!_CalledFromGeneratedFile(1)) {
2455     PyErr_SetString(PyExc_TypeError,
2456                     "Descriptors should not be created directly, "
2457                     "but only retrieved from their parent.");
2458     return NULL;
2459   }
2460   Py_RETURN_NONE;
2461 }
2462 
GetExtensionDict(CMessage * self,void * closure)2463 static PyObject* GetExtensionDict(CMessage* self, void *closure) {
2464   // If there are extension_ranges, the message is "extendable". Allocate a
2465   // dictionary to store the extension fields.
2466   const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
2467   if (!descriptor->extension_range_count()) {
2468     PyErr_SetNone(PyExc_AttributeError);
2469     return NULL;
2470   }
2471   if (!self->composite_fields) {
2472     self->composite_fields = new CMessage::CompositeFieldsMap();
2473   }
2474   if (!self->composite_fields) {
2475     return NULL;
2476   }
2477   ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
2478   return reinterpret_cast<PyObject*>(extension_dict);
2479 }
2480 
UnknownFieldSet(CMessage * self)2481 static PyObject* UnknownFieldSet(CMessage* self) {
2482   if (self->unknown_field_set == NULL) {
2483     self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
2484   } else {
2485     Py_INCREF(self->unknown_field_set);
2486   }
2487   return self->unknown_field_set;
2488 }
2489 
GetExtensionsByName(CMessage * self,void * closure)2490 static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
2491   return message_meta::GetExtensionsByName(
2492       reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2493 }
2494 
GetExtensionsByNumber(CMessage * self,void * closure)2495 static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
2496   return message_meta::GetExtensionsByNumber(
2497       reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2498 }
2499 
2500 static PyGetSetDef Getters[] = {
2501   {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
2502   {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
2503   {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
2504   {NULL}
2505 };
2506 
2507 
2508 static PyMethodDef Methods[] = {
2509   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
2510     "Makes a deep copy of the class." },
2511   { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
2512     "Outputs a unicode representation of the message." },
2513   { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
2514     "Returns the size of the message in bytes." },
2515   { "Clear", (PyCFunction)Clear, METH_NOARGS,
2516     "Clears the message." },
2517   { "ClearExtension", (PyCFunction)ClearExtension, METH_O,
2518     "Clears a message field." },
2519   { "ClearField", (PyCFunction)ClearField, METH_O,
2520     "Clears a message field." },
2521   { "CopyFrom", (PyCFunction)CopyFrom, METH_O,
2522     "Copies a protocol message into the current message." },
2523   { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
2524     "Discards the unknown fields." },
2525   { "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
2526     METH_NOARGS,
2527     "Finds unset required fields." },
2528   { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
2529     "Creates new method instance from given serialized data." },
2530   { "HasExtension", (PyCFunction)HasExtension, METH_O,
2531     "Checks if a message field is set." },
2532   { "HasField", (PyCFunction)HasField, METH_O,
2533     "Checks if a message field is set." },
2534   { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
2535     "Checks if all required fields of a protocol message are set." },
2536   { "ListFields", (PyCFunction)ListFields, METH_NOARGS,
2537     "Lists all set fields of a message." },
2538   { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
2539     "Merges a protocol message into the current message." },
2540   { "MergeFromString", (PyCFunction)MergeFromString, METH_O,
2541     "Merges a serialized message into the current message." },
2542   { "ParseFromString", (PyCFunction)ParseFromString, METH_O,
2543     "Parses a serialized message into the current message." },
2544   { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
2545     "Registers an extension with the current message." },
2546   { "SerializePartialToString", (PyCFunction)SerializePartialToString,
2547     METH_VARARGS | METH_KEYWORDS,
2548     "Serializes the message to a string, even if it isn't initialized." },
2549   { "SerializeToString", (PyCFunction)SerializeToString,
2550     METH_VARARGS | METH_KEYWORDS,
2551     "Serializes the message to a string, only for initialized messages." },
2552   { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
2553     "Sets the has bit of the given field in its parent message." },
2554   { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS,
2555     "Parse unknown field set"},
2556   { "WhichOneof", (PyCFunction)WhichOneof, METH_O,
2557     "Returns the name of the field set inside a oneof, "
2558     "or None if no field is set." },
2559 
2560   // Static Methods.
2561   { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
2562     METH_NOARGS | METH_STATIC,
2563     "Raises TypeError if the caller is not in a _pb2.py file."},
2564   { NULL, NULL}
2565 };
2566 
SetCompositeField(CMessage * self,const FieldDescriptor * field,ContainerBase * value)2567 bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
2568                        ContainerBase* value) {
2569   if (self->composite_fields == NULL) {
2570     self->composite_fields = new CMessage::CompositeFieldsMap();
2571   }
2572   (*self->composite_fields)[field] = value;
2573   return true;
2574 }
2575 
SetSubmessage(CMessage * self,CMessage * submessage)2576 bool SetSubmessage(CMessage* self, CMessage* submessage) {
2577   if (self->child_submessages == NULL) {
2578     self->child_submessages = new CMessage::SubMessagesMap();
2579   }
2580   (*self->child_submessages)[submessage->message] = submessage;
2581   return true;
2582 }
2583 
GetAttr(PyObject * pself,PyObject * name)2584 PyObject* GetAttr(PyObject* pself, PyObject* name) {
2585   CMessage* self = reinterpret_cast<CMessage*>(pself);
2586   PyObject* result = PyObject_GenericGetAttr(
2587       reinterpret_cast<PyObject*>(self), name);
2588   if (result != NULL) {
2589     return result;
2590   }
2591   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
2592     return NULL;
2593   }
2594 
2595   PyErr_Clear();
2596   return message_meta::GetClassAttribute(
2597       CheckMessageClass(Py_TYPE(self)), name);
2598 }
2599 
GetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor)2600 PyObject* GetFieldValue(CMessage* self,
2601                         const FieldDescriptor* field_descriptor) {
2602   if (self->composite_fields) {
2603     CMessage::CompositeFieldsMap::iterator it =
2604         self->composite_fields->find(field_descriptor);
2605     if (it != self->composite_fields->end()) {
2606       ContainerBase* value = it->second;
2607       Py_INCREF(value);
2608       return value->AsPyObject();
2609     }
2610   }
2611 
2612   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2613     PyErr_Format(PyExc_TypeError,
2614                  "descriptor to field '%s' doesn't apply to '%s' object",
2615                  field_descriptor->full_name().c_str(),
2616                  Py_TYPE(self)->tp_name);
2617     return NULL;
2618   }
2619 
2620   if (!field_descriptor->is_repeated() &&
2621       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
2622     return InternalGetScalar(self->message, field_descriptor);
2623   }
2624 
2625   ContainerBase* py_container = nullptr;
2626   if (field_descriptor->is_map()) {
2627     const Descriptor* entry_type = field_descriptor->message_type();
2628     const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
2629     if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2630       CMessageClass* value_class = message_factory::GetMessageClass(
2631           GetFactoryForMessage(self), value_type->message_type());
2632       if (value_class == NULL) {
2633         return NULL;
2634       }
2635       py_container =
2636           NewMessageMapContainer(self, field_descriptor, value_class);
2637     } else {
2638       py_container = NewScalarMapContainer(self, field_descriptor);
2639     }
2640   } else if (field_descriptor->is_repeated()) {
2641     if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2642       CMessageClass* message_class = message_factory::GetMessageClass(
2643           GetFactoryForMessage(self), field_descriptor->message_type());
2644       if (message_class == NULL) {
2645         return NULL;
2646       }
2647       py_container = repeated_composite_container::NewContainer(
2648           self, field_descriptor, message_class);
2649     } else {
2650       py_container =
2651           repeated_scalar_container::NewContainer(self, field_descriptor);
2652     }
2653   } else if (field_descriptor->cpp_type() ==
2654              FieldDescriptor::CPPTYPE_MESSAGE) {
2655     py_container = InternalGetSubMessage(self, field_descriptor);
2656   } else {
2657     PyErr_SetString(PyExc_SystemError, "Should never happen");
2658   }
2659 
2660   if (py_container == NULL) {
2661     return NULL;
2662   }
2663   if (!SetCompositeField(self, field_descriptor, py_container)) {
2664     Py_DECREF(py_container);
2665     return NULL;
2666   }
2667   return py_container->AsPyObject();
2668 }
2669 
SetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * value)2670 int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
2671                   PyObject* value) {
2672   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2673     PyErr_Format(PyExc_TypeError,
2674                  "descriptor to field '%s' doesn't apply to '%s' object",
2675                  field_descriptor->full_name().c_str(),
2676                  Py_TYPE(self)->tp_name);
2677     return -1;
2678   } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
2679     PyErr_Format(PyExc_AttributeError,
2680                  "Assignment not allowed to repeated "
2681                  "field \"%s\" in protocol message object.",
2682                  field_descriptor->name().c_str());
2683     return -1;
2684   } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2685     PyErr_Format(PyExc_AttributeError,
2686                  "Assignment not allowed to "
2687                  "field \"%s\" in protocol message object.",
2688                  field_descriptor->name().c_str());
2689     return -1;
2690   } else {
2691     AssureWritable(self);
2692     return InternalSetScalar(self, field_descriptor, value);
2693   }
2694 }
2695 
2696 }  // namespace cmessage
2697 
2698 // All containers which are not messages:
2699 // - Make a new parent message
2700 // - Copy the field
2701 // - return the field.
DeepCopy()2702 PyObject* ContainerBase::DeepCopy() {
2703   CMessage* new_parent =
2704       cmessage::NewEmptyMessage(this->parent->GetMessageClass());
2705   new_parent->message = this->parent->message->New();
2706 
2707   // Copy the map field into the new message.
2708   this->parent->message->GetReflection()->SwapFields(
2709       this->parent->message, new_parent->message,
2710       {this->parent_field_descriptor});
2711   this->parent->message->MergeFrom(*new_parent->message);
2712 
2713   PyObject* result =
2714       cmessage::GetFieldValue(new_parent, this->parent_field_descriptor);
2715   Py_DECREF(new_parent);
2716   return result;
2717 }
2718 
RemoveFromParentCache()2719 void ContainerBase::RemoveFromParentCache() {
2720   CMessage* parent = this->parent;
2721   if (parent) {
2722     if (parent->composite_fields)
2723       parent->composite_fields->erase(this->parent_field_descriptor);
2724     Py_CLEAR(parent);
2725   }
2726 }
2727 
BuildSubMessageFromPointer(const FieldDescriptor * field_descriptor,Message * sub_message,CMessageClass * message_class)2728 CMessage* CMessage::BuildSubMessageFromPointer(
2729     const FieldDescriptor* field_descriptor, Message* sub_message,
2730     CMessageClass* message_class) {
2731   if (!this->child_submessages) {
2732     this->child_submessages = new CMessage::SubMessagesMap();
2733   }
2734   CMessage* cmsg = FindPtrOrNull(
2735       *this->child_submessages, sub_message);
2736   if (cmsg) {
2737     Py_INCREF(cmsg);
2738   } else {
2739     cmsg = cmessage::NewEmptyMessage(message_class);
2740 
2741     if (cmsg == NULL) {
2742       return NULL;
2743     }
2744     cmsg->message = sub_message;
2745     Py_INCREF(this);
2746     cmsg->parent = this;
2747     cmsg->parent_field_descriptor = field_descriptor;
2748     cmessage::SetSubmessage(this, cmsg);
2749   }
2750   return cmsg;
2751 }
2752 
MaybeReleaseSubMessage(Message * sub_message)2753 CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
2754   if (!this->child_submessages) {
2755     return nullptr;
2756   }
2757   CMessage* released = FindPtrOrNull(
2758       *this->child_submessages, sub_message);
2759   if (!released) {
2760     return nullptr;
2761   }
2762   // The target message will now own its content.
2763   Py_CLEAR(released->parent);
2764   released->parent_field_descriptor = nullptr;
2765   released->read_only = false;
2766   // Delete it from the cache.
2767   this->child_submessages->erase(sub_message);
2768   return released;
2769 }
2770 
2771 static CMessageClass _CMessage_Type = { { {
2772   PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0)
2773   FULL_MODULE_NAME ".CMessage",        // tp_name
2774   sizeof(CMessage),                    // tp_basicsize
2775   0,                                   //  tp_itemsize
2776   (destructor)cmessage::Dealloc,       //  tp_dealloc
2777   0,                                   //  tp_print
2778   0,                                   //  tp_getattr
2779   0,                                   //  tp_setattr
2780   0,                                   //  tp_compare
2781   (reprfunc)cmessage::ToStr,           //  tp_repr
2782   0,                                   //  tp_as_number
2783   0,                                   //  tp_as_sequence
2784   0,                                   //  tp_as_mapping
2785   PyObject_HashNotImplemented,         //  tp_hash
2786   0,                                   //  tp_call
2787   (reprfunc)cmessage::ToStr,           //  tp_str
2788   cmessage::GetAttr,                   //  tp_getattro
2789   0,                                   //  tp_setattro
2790   0,                                   //  tp_as_buffer
2791   Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
2792       | Py_TPFLAGS_HAVE_VERSION_TAG,   //  tp_flags
2793   "A ProtocolMessage",                 //  tp_doc
2794   0,                                   //  tp_traverse
2795   0,                                   //  tp_clear
2796   (richcmpfunc)cmessage::RichCompare,  //  tp_richcompare
2797   offsetof(CMessage, weakreflist),     //  tp_weaklistoffset
2798   0,                                   //  tp_iter
2799   0,                                   //  tp_iternext
2800   cmessage::Methods,                   //  tp_methods
2801   0,                                   //  tp_members
2802   cmessage::Getters,                   //  tp_getset
2803   0,                                   //  tp_base
2804   0,                                   //  tp_dict
2805   0,                                   //  tp_descr_get
2806   0,                                   //  tp_descr_set
2807   0,                                   //  tp_dictoffset
2808   (initproc)cmessage::Init,            //  tp_init
2809   0,                                   //  tp_alloc
2810   cmessage::New,                       //  tp_new
2811 } } };
2812 PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
2813 
2814 // --- Exposing the C proto living inside Python proto to C code:
2815 
2816 const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
2817 Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
2818 
GetCProtoInsidePyProtoImpl(PyObject * msg)2819 static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
2820   const Message* message = PyMessage_GetMessagePointer(msg);
2821   if (message == NULL) {
2822     PyErr_Clear();
2823     return NULL;
2824   }
2825   return message;
2826 }
2827 
MutableCProtoInsidePyProtoImpl(PyObject * msg)2828 static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
2829   Message* message = PyMessage_GetMutableMessagePointer(msg);
2830   if (message == NULL) {
2831     PyErr_Clear();
2832     return NULL;
2833   }
2834   return message;
2835 }
2836 
PyMessage_GetMessagePointer(PyObject * msg)2837 const Message* PyMessage_GetMessagePointer(PyObject* msg) {
2838   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2839     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2840     return NULL;
2841   }
2842   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2843   return cmsg->message;
2844 }
2845 
PyMessage_GetMutableMessagePointer(PyObject * msg)2846 Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
2847   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2848     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2849     return NULL;
2850   }
2851   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2852 
2853 
2854   if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) ||
2855       (cmsg->child_submessages && !cmsg->child_submessages->empty())) {
2856     // There is currently no way of accurately syncing arbitrary changes to
2857     // the underlying C++ message back to the CMessage (e.g. removed repeated
2858     // composite containers). We only allow direct mutation of the underlying
2859     // C++ message if there is no child data in the CMessage.
2860     PyErr_SetString(PyExc_ValueError,
2861                     "Cannot reliably get a mutable pointer "
2862                     "to a message with extra references");
2863     return NULL;
2864   }
2865   cmessage::AssureWritable(cmsg);
2866   return cmsg->message;
2867 }
2868 
PyMessage_NewMessageOwnedExternally(Message * message,PyObject * message_factory)2869 PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
2870                                               PyObject* message_factory) {
2871   if (message_factory) {
2872     PyErr_SetString(PyExc_NotImplementedError,
2873                     "Default message_factory=NULL is the only supported value");
2874     return NULL;
2875   }
2876   if (message->GetReflection()->GetMessageFactory() !=
2877       MessageFactory::generated_factory()) {
2878     PyErr_SetString(PyExc_TypeError,
2879                     "Message pointer was not created from the default factory");
2880     return NULL;
2881   }
2882 
2883   CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2884       GetDefaultDescriptorPool()->py_message_factory, message->GetDescriptor());
2885 
2886   CMessage* self = cmessage::NewEmptyMessage(message_class);
2887   if (self == NULL) {
2888     return NULL;
2889   }
2890   Py_DECREF(message_class);
2891   self->message = message;
2892   Py_INCREF(Py_None);
2893   self->parent = reinterpret_cast<CMessage*>(Py_None);
2894   return self->AsPyObject();
2895 }
2896 
InitGlobals()2897 void InitGlobals() {
2898   // TODO(gps): Check all return values in this function for NULL and propagate
2899   // the error (MemoryError) on up to result in an import failure.  These should
2900   // also be freed and reset to NULL during finalization.
2901   kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
2902 
2903   PyObject *dummy_obj = PySet_New(NULL);
2904   kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
2905   Py_DECREF(dummy_obj);
2906 }
2907 
InitProto2MessageModule(PyObject * m)2908 bool InitProto2MessageModule(PyObject *m) {
2909   // Initialize types and globals in descriptor.cc
2910   if (!InitDescriptor()) {
2911     return false;
2912   }
2913 
2914   // Initialize types and globals in descriptor_pool.cc
2915   if (!InitDescriptorPool()) {
2916     return false;
2917   }
2918 
2919   // Initialize types and globals in message_factory.cc
2920   if (!InitMessageFactory()) {
2921     return false;
2922   }
2923 
2924   // Initialize constants defined in this file.
2925   InitGlobals();
2926 
2927   CMessageClass_Type->tp_base = &PyType_Type;
2928   if (PyType_Ready(CMessageClass_Type) < 0) {
2929     return false;
2930   }
2931   PyModule_AddObject(m, "MessageMeta",
2932                      reinterpret_cast<PyObject*>(CMessageClass_Type));
2933 
2934   if (PyType_Ready(CMessage_Type) < 0) {
2935     return false;
2936   }
2937   if (PyType_Ready(CFieldProperty_Type) < 0) {
2938     return false;
2939   }
2940 
2941   // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
2942   // it here as well to document that subclasses need to set it.
2943   PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
2944   // Invalidate any cached data for the CMessage type.
2945   // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
2946   // after we have modified CMessage_Type.tp_dict.
2947   PyType_Modified(CMessage_Type);
2948 
2949   PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
2950 
2951   // Initialize Repeated container types.
2952   {
2953     if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
2954       return false;
2955     }
2956 
2957     PyModule_AddObject(m, "RepeatedScalarContainer",
2958                        reinterpret_cast<PyObject*>(
2959                            &RepeatedScalarContainer_Type));
2960 
2961     if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
2962       return false;
2963     }
2964 
2965     PyModule_AddObject(
2966         m, "RepeatedCompositeContainer",
2967         reinterpret_cast<PyObject*>(
2968             &RepeatedCompositeContainer_Type));
2969 
2970     // Register them as MutableSequence.
2971 #if PY_MAJOR_VERSION >= 3
2972     ScopedPyObjectPtr collections(PyImport_ImportModule("collections.abc"));
2973 #else
2974     ScopedPyObjectPtr collections(PyImport_ImportModule("collections"));
2975 #endif
2976     if (collections == NULL) {
2977       return false;
2978     }
2979     ScopedPyObjectPtr mutable_sequence(
2980         PyObject_GetAttrString(collections.get(), "MutableSequence"));
2981     if (mutable_sequence == NULL) {
2982       return false;
2983     }
2984     if (ScopedPyObjectPtr(
2985             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2986                                 &RepeatedScalarContainer_Type)) == NULL) {
2987       return false;
2988     }
2989     if (ScopedPyObjectPtr(
2990             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2991                                 &RepeatedCompositeContainer_Type)) == NULL) {
2992       return false;
2993     }
2994   }
2995 
2996   if (PyType_Ready(&PyUnknownFields_Type) < 0) {
2997     return false;
2998   }
2999 
3000   PyModule_AddObject(m, "UnknownFieldSet",
3001                      reinterpret_cast<PyObject*>(
3002                          &PyUnknownFields_Type));
3003 
3004   if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
3005     return false;
3006   }
3007 
3008   PyModule_AddObject(m, "UnknownField",
3009                      reinterpret_cast<PyObject*>(
3010                          &PyUnknownFieldRef_Type));
3011 
3012   // Initialize Map container types.
3013   if (!InitMapContainers()) {
3014     return false;
3015   }
3016   PyModule_AddObject(m, "ScalarMapContainer",
3017                      reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
3018   PyModule_AddObject(m, "MessageMapContainer",
3019                      reinterpret_cast<PyObject*>(MessageMapContainer_Type));
3020   PyModule_AddObject(m, "MapIterator",
3021                      reinterpret_cast<PyObject*>(&MapIterator_Type));
3022 
3023   if (PyType_Ready(&ExtensionDict_Type) < 0) {
3024     return false;
3025   }
3026   PyModule_AddObject(
3027       m, "ExtensionDict",
3028       reinterpret_cast<PyObject*>(&ExtensionDict_Type));
3029   if (PyType_Ready(&ExtensionIterator_Type) < 0) {
3030     return false;
3031   }
3032   PyModule_AddObject(m, "ExtensionIterator",
3033                      reinterpret_cast<PyObject*>(&ExtensionIterator_Type));
3034 
3035   // Expose the DescriptorPool used to hold all descriptors added from generated
3036   // pb2.py files.
3037   // PyModule_AddObject steals a reference.
3038   Py_INCREF(GetDefaultDescriptorPool());
3039   PyModule_AddObject(m, "default_pool",
3040                      reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
3041 
3042   PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
3043       &PyDescriptorPool_Type));
3044 
3045   PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>(
3046       &PyMessageDescriptor_Type));
3047   PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>(
3048       &PyFieldDescriptor_Type));
3049   PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>(
3050       &PyEnumDescriptor_Type));
3051   PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>(
3052       &PyEnumValueDescriptor_Type));
3053   PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>(
3054       &PyFileDescriptor_Type));
3055   PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
3056       &PyOneofDescriptor_Type));
3057   PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
3058       &PyServiceDescriptor_Type));
3059   PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
3060       &PyMethodDescriptor_Type));
3061 
3062   PyObject* enum_type_wrapper = PyImport_ImportModule(
3063       "google.protobuf.internal.enum_type_wrapper");
3064   if (enum_type_wrapper == NULL) {
3065     return false;
3066   }
3067   EnumTypeWrapper_class =
3068       PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
3069   Py_DECREF(enum_type_wrapper);
3070 
3071   PyObject* message_module = PyImport_ImportModule(
3072       "google.protobuf.message");
3073   if (message_module == NULL) {
3074     return false;
3075   }
3076   EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
3077   DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
3078   PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
3079   Py_DECREF(message_module);
3080 
3081   PyObject* pickle_module = PyImport_ImportModule("pickle");
3082   if (pickle_module == NULL) {
3083     return false;
3084   }
3085   PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
3086   Py_DECREF(pickle_module);
3087 
3088   // Override {Get,Mutable}CProtoInsidePyProto.
3089   GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
3090   MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
3091 
3092   return true;
3093 }
3094 
3095 }  // namespace python
3096 }  // namespace protobuf
3097 }  // namespace google
3098