• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33 
34 #include <google/protobuf/pyext/message.h>
35 
36 #include <map>
37 #include <memory>
38 #include <string>
39 #include <vector>
40 #include <structmember.h>  // A Python header file.
41 
42 #ifndef PyVarObject_HEAD_INIT
43 #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
44 #endif
45 #ifndef Py_TYPE
46 #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
47 #endif
48 #include <google/protobuf/stubs/common.h>
49 #include <google/protobuf/stubs/logging.h>
50 #include <google/protobuf/io/strtod.h>
51 #include <google/protobuf/io/coded_stream.h>
52 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
53 #include <google/protobuf/descriptor.pb.h>
54 #include <google/protobuf/descriptor.h>
55 #include <google/protobuf/message.h>
56 #include <google/protobuf/text_format.h>
57 #include <google/protobuf/unknown_field_set.h>
58 #include <google/protobuf/pyext/descriptor.h>
59 #include <google/protobuf/pyext/descriptor_pool.h>
60 #include <google/protobuf/pyext/extension_dict.h>
61 #include <google/protobuf/pyext/field.h>
62 #include <google/protobuf/pyext/map_container.h>
63 #include <google/protobuf/pyext/message_factory.h>
64 #include <google/protobuf/pyext/repeated_composite_container.h>
65 #include <google/protobuf/pyext/repeated_scalar_container.h>
66 #include <google/protobuf/pyext/unknown_fields.h>
67 #include <google/protobuf/pyext/safe_numerics.h>
68 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
69 #include <google/protobuf/util/message_differencer.h>
70 #include <google/protobuf/stubs/strutil.h>
71 #include <google/protobuf/stubs/map_util.h>
72 
73 #include <google/protobuf/port_def.inc>
74 
75 #if PY_MAJOR_VERSION >= 3
76   #define PyInt_AsLong PyLong_AsLong
77   #define PyInt_FromLong PyLong_FromLong
78   #define PyInt_FromSize_t PyLong_FromSize_t
79   #define PyString_Check PyUnicode_Check
80   #define PyString_FromString PyUnicode_FromString
81   #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
82   #define PyString_FromFormat PyUnicode_FromFormat
83   #if PY_VERSION_HEX < 0x03030000
84     #error "Python 3.0 - 3.2 are not supported."
85   #else
86   #define PyString_AsString(ob) \
87     (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
88 #define PyString_AsStringAndSize(ob, charpp, sizep)                           \
89   (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>(                     \
90                                PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
91                               ? -1                                            \
92                               : 0)                                            \
93                        : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
94 #endif
95 #endif
96 
97 namespace google {
98 namespace protobuf {
99 namespace python {
100 
101 static PyObject* kDESCRIPTOR;
102 PyObject* EnumTypeWrapper_class;
103 static PyObject* PythonMessage_class;
104 static PyObject* kEmptyWeakref;
105 static PyObject* WKT_classes = NULL;
106 
107 namespace message_meta {
108 
109 static int InsertEmptyWeakref(PyTypeObject* base);
110 
namespace {
// Lower-cases the ASCII letters of *s in place; all other bytes are left
// untouched. (Copied over from internal 'google/protobuf/stubs/strutil.h'.)
inline void LowerString(string * s) {
  for (char& ch : *s) {
    // Deliberately not tolower(): its result depends on the current locale.
    if (ch >= 'A' && ch <= 'Z') {
      ch = static_cast<char>(ch + ('a' - 'A'));
    }
  }
}
}  // namespace
121 
122 // Finalize the creation of the Message class.
AddDescriptors(PyObject * cls,const Descriptor * descriptor)123 static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
124   // For each field set: cls.<field>_FIELD_NUMBER = <number>
125   for (int i = 0; i < descriptor->field_count(); ++i) {
126     const FieldDescriptor* field_descriptor = descriptor->field(i);
127     ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
128     if (property == NULL) {
129       return -1;
130     }
131     if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(),
132                                property.get()) < 0) {
133       return -1;
134     }
135   }
136 
137   // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
138   for (int i = 0; i < descriptor->enum_type_count(); ++i) {
139     const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
140     ScopedPyObjectPtr enum_type(
141         PyEnumDescriptor_FromDescriptor(enum_descriptor));
142     if (enum_type == NULL) {
143       return -1;
144      }
145     // Add wrapped enum type to message class.
146     ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
147         EnumTypeWrapper_class, enum_type.get(), NULL));
148     if (wrapped == NULL) {
149       return -1;
150     }
151     if (PyObject_SetAttrString(
152             cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) {
153       return -1;
154     }
155 
156     // For each enum value add cls.<name> = <number>
157     for (int j = 0; j < enum_descriptor->value_count(); ++j) {
158       const EnumValueDescriptor* enum_value_descriptor =
159           enum_descriptor->value(j);
160       ScopedPyObjectPtr value_number(PyInt_FromLong(
161           enum_value_descriptor->number()));
162       if (value_number == NULL) {
163         return -1;
164       }
165       if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(),
166                                  value_number.get()) == -1) {
167         return -1;
168       }
169     }
170   }
171 
172   // For each extension set cls.<extension name> = <extension descriptor>.
173   //
174   // Extension descriptors come from
175   // <message descriptor>.extensions_by_name[name]
176   // which was defined previously.
177   for (int i = 0; i < descriptor->extension_count(); ++i) {
178     const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
179     ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
180     if (extension_field == NULL) {
181       return -1;
182     }
183 
184     // Add the extension field to the message class.
185     if (PyObject_SetAttrString(
186             cls, field->name().c_str(), extension_field.get()) == -1) {
187       return -1;
188     }
189   }
190 
191   return 0;
192 }
193 
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)194 static PyObject* New(PyTypeObject* type,
195                      PyObject* args, PyObject* kwargs) {
196   static char *kwlist[] = {"name", "bases", "dict", 0};
197   PyObject *bases, *dict;
198   const char* name;
199 
200   // Check arguments: (name, bases, dict)
201   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist,
202                                    &name,
203                                    &PyTuple_Type, &bases,
204                                    &PyDict_Type, &dict)) {
205     return NULL;
206   }
207 
208   // Check bases: only (), or (message.Message,) are allowed
209   if (!(PyTuple_GET_SIZE(bases) == 0 ||
210         (PyTuple_GET_SIZE(bases) == 1 &&
211          PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
212     PyErr_SetString(PyExc_TypeError,
213                     "A Message class can only inherit from Message");
214     return NULL;
215   }
216 
217   // Check dict['DESCRIPTOR']
218   PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
219   if (py_descriptor == NULL) {
220     PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
221     return NULL;
222   }
223   if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) {
224     PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s",
225                  py_descriptor->ob_type->tp_name);
226     return NULL;
227   }
228 
229   // Messages have no __dict__
230   ScopedPyObjectPtr slots(PyTuple_New(0));
231   if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
232     return NULL;
233   }
234 
235   // Build the arguments to the base metaclass.
236   // We change the __bases__ classes.
237   ScopedPyObjectPtr new_args;
238   const Descriptor* message_descriptor =
239       PyMessageDescriptor_AsDescriptor(py_descriptor);
240   if (message_descriptor == NULL) {
241     return NULL;
242   }
243 
244   if (WKT_classes == NULL) {
245     ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
246         "google.protobuf.internal.well_known_types"));
247     GOOGLE_DCHECK(well_known_types != NULL);
248 
249     WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
250     GOOGLE_DCHECK(WKT_classes != NULL);
251   }
252 
253   PyObject* well_known_class = PyDict_GetItemString(
254       WKT_classes, message_descriptor->full_name().c_str());
255   if (well_known_class == NULL) {
256     new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
257                                  PythonMessage_class, dict));
258   } else {
259     new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
260                                  PythonMessage_class, well_known_class, dict));
261   }
262 
263   if (new_args == NULL) {
264     return NULL;
265   }
266   // Call the base metaclass.
267   ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL));
268   if (result == NULL) {
269     return NULL;
270   }
271   CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
272 
273   // Insert the empty weakref into the base classes.
274   if (InsertEmptyWeakref(
275           reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
276       InsertEmptyWeakref(CMessage_Type) < 0) {
277     return NULL;
278   }
279 
280   // Cache the descriptor, both as Python object and as C++ pointer.
281   const Descriptor* descriptor =
282       PyMessageDescriptor_AsDescriptor(py_descriptor);
283   if (descriptor == NULL) {
284     return NULL;
285   }
286   Py_INCREF(py_descriptor);
287   newtype->py_message_descriptor = py_descriptor;
288   newtype->message_descriptor = descriptor;
289   // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
290   // use the MessageFactory optionally passed in the class dict.
291   PyDescriptorPool* py_descriptor_pool =
292       GetDescriptorPool_FromPool(descriptor->file()->pool());
293   if (py_descriptor_pool == NULL) {
294     return NULL;
295   }
296   newtype->py_message_factory = py_descriptor_pool->py_message_factory;
297   Py_INCREF(newtype->py_message_factory);
298 
299   // Register the message in the MessageFactory.
300   // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
301   // MessageFactory is fully implemented in C++.
302   if (message_factory::RegisterMessageClass(newtype->py_message_factory,
303                                             descriptor, newtype) < 0) {
304     return NULL;
305   }
306 
307   // Continue with type initialization: add other descriptors, enum values...
308   if (AddDescriptors(result.get(), descriptor) < 0) {
309     return NULL;
310   }
311   return result.release();
312 }
313 
Dealloc(PyObject * pself)314 static void Dealloc(PyObject* pself) {
315   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
316   Py_XDECREF(self->py_message_descriptor);
317   Py_XDECREF(self->py_message_factory);
318   return PyType_Type.tp_dealloc(pself);
319 }
320 
GcTraverse(PyObject * pself,visitproc visit,void * arg)321 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
322   CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
323   Py_VISIT(self->py_message_descriptor);
324   Py_VISIT(self->py_message_factory);
325   return PyType_Type.tp_traverse(pself, visit, arg);
326 }
327 
GcClear(PyObject * pself)328 static int GcClear(PyObject* pself) {
329   // It's important to keep the descriptor and factory alive, until the
330   // C++ message is fully destructed.
331   return PyType_Type.tp_clear(pself);
332 }
333 
334 // This function inserts and empty weakref at the end of the list of
335 // subclasses for the main protocol buffer Message class.
336 //
337 // This eliminates a O(n^2) behaviour in the internal add_subclass
338 // routine.
InsertEmptyWeakref(PyTypeObject * base_type)339 static int InsertEmptyWeakref(PyTypeObject *base_type) {
340 #if PY_MAJOR_VERSION >= 3
341   // Python 3.4 has already included the fix for the issue that this
342   // hack addresses. For further background and the fix please see
343   // https://bugs.python.org/issue17936.
344   return 0;
345 #else
346 #ifdef Py_DEBUG
347   // The code below causes all new subclasses to append an entry, which is never
348   // cleared. This is a small memory leak, which we disable in Py_DEBUG mode
349   // to have stable refcounting checks.
350 #else
351   PyObject *subclasses = base_type->tp_subclasses;
352   if (subclasses && PyList_CheckExact(subclasses)) {
353     return PyList_Append(subclasses, kEmptyWeakref);
354   }
355 #endif  // !Py_DEBUG
356   return 0;
357 #endif  // PY_MAJOR_VERSION >= 3
358 }
359 
360 // The _extensions_by_name dictionary is built on every access.
361 // TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
GetExtensionsByName(CMessageClass * self,void * closure)362 static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
363   if (self->message_descriptor == NULL) {
364     // This is the base Message object, simply raise AttributeError.
365     PyErr_SetString(PyExc_AttributeError,
366                     "Base Message class has no DESCRIPTOR");
367     return NULL;
368   }
369 
370   const PyDescriptorPool* pool = self->py_message_factory->pool;
371 
372   std::vector<const FieldDescriptor*> extensions;
373   pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
374 
375   ScopedPyObjectPtr result(PyDict_New());
376   for (int i = 0; i < extensions.size(); i++) {
377     ScopedPyObjectPtr extension(
378         PyFieldDescriptor_FromDescriptor(extensions[i]));
379     if (extension == NULL) {
380       return NULL;
381     }
382     if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
383                              extension.get()) < 0) {
384       return NULL;
385     }
386   }
387   return result.release();
388 }
389 
390 // The _extensions_by_number dictionary is built on every access.
391 // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
GetExtensionsByNumber(CMessageClass * self,void * closure)392 static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
393   if (self->message_descriptor == NULL) {
394     // This is the base Message object, simply raise AttributeError.
395     PyErr_SetString(PyExc_AttributeError,
396                     "Base Message class has no DESCRIPTOR");
397     return NULL;
398   }
399 
400   const PyDescriptorPool* pool = self->py_message_factory->pool;
401 
402   std::vector<const FieldDescriptor*> extensions;
403   pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
404 
405   ScopedPyObjectPtr result(PyDict_New());
406   for (int i = 0; i < extensions.size(); i++) {
407     ScopedPyObjectPtr extension(
408         PyFieldDescriptor_FromDescriptor(extensions[i]));
409     if (extension == NULL) {
410       return NULL;
411     }
412     ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
413     if (number == NULL) {
414       return NULL;
415     }
416     if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
417       return NULL;
418     }
419   }
420   return result.release();
421 }
422 
423 static PyGetSetDef Getters[] = {
424   {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
425   {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
426   {NULL}
427 };
428 
429 // Compute some class attributes on the fly:
430 // - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
431 // Returns a new reference, or NULL with an exception set.
GetClassAttribute(CMessageClass * self,PyObject * name)432 static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
433   char* attr;
434   Py_ssize_t attr_size;
435   static const char kSuffix[] = "_FIELD_NUMBER";
436   if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
437       strings::EndsWith(StringPiece(attr, attr_size), kSuffix)) {
438     string field_name(attr, attr_size - sizeof(kSuffix) + 1);
439     LowerString(&field_name);
440 
441     // Try to find a field with the given name, without the suffix.
442     const FieldDescriptor* field =
443         self->message_descriptor->FindFieldByLowercaseName(field_name);
444     if (!field) {
445       // Search nested extensions as well.
446       field =
447           self->message_descriptor->FindExtensionByLowercaseName(field_name);
448     }
449     if (field) {
450       return PyInt_FromLong(field->number());
451     }
452   }
453   PyErr_SetObject(PyExc_AttributeError, name);
454   return NULL;
455 }
456 
GetAttr(CMessageClass * self,PyObject * name)457 static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
458   PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
459       reinterpret_cast<PyObject*>(self), name);
460   if (result != NULL) {
461     return result;
462   }
463   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
464     return NULL;
465   }
466 
467   PyErr_Clear();
468   return GetClassAttribute(self, name);
469 }
470 
471 }  // namespace message_meta
472 
473 static PyTypeObject _CMessageClass_Type = {
474     PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
475     ".MessageMeta",                       // tp_name
476     sizeof(CMessageClass),                // tp_basicsize
477     0,                                    // tp_itemsize
478     message_meta::Dealloc,                // tp_dealloc
479     0,                                    // tp_print
480     0,                                    // tp_getattr
481     0,                                    // tp_setattr
482     0,                                    // tp_compare
483     0,                                    // tp_repr
484     0,                                    // tp_as_number
485     0,                                    // tp_as_sequence
486     0,                                    // tp_as_mapping
487     0,                                    // tp_hash
488     0,                                    // tp_call
489     0,                                    // tp_str
490     (getattrofunc)message_meta::GetAttr,  // tp_getattro
491     0,                                    // tp_setattro
492     0,                                    // tp_as_buffer
493     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
494     "The metaclass of ProtocolMessages",                            // tp_doc
495     message_meta::GcTraverse,  // tp_traverse
496     message_meta::GcClear,     // tp_clear
497     0,                         // tp_richcompare
498     0,                         // tp_weaklistoffset
499     0,                         // tp_iter
500     0,                         // tp_iternext
501     0,                         // tp_methods
502     0,                         // tp_members
503     message_meta::Getters,     // tp_getset
504     0,                         // tp_base
505     0,                         // tp_dict
506     0,                         // tp_descr_get
507     0,                         // tp_descr_set
508     0,                         // tp_dictoffset
509     0,                         // tp_init
510     0,                         // tp_alloc
511     message_meta::New,         // tp_new
512 };
513 PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
514 
CheckMessageClass(PyTypeObject * cls)515 static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
516   if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
517     PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
518     return NULL;
519   }
520   return reinterpret_cast<CMessageClass*>(cls);
521 }
522 
GetMessageDescriptor(PyTypeObject * cls)523 static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
524   CMessageClass* type = CheckMessageClass(cls);
525   if (type == NULL) {
526     return NULL;
527   }
528   return type->message_descriptor;
529 }
530 
531 // Forward declarations
532 namespace cmessage {
533 int InternalReleaseFieldByDescriptor(
534     CMessage* self,
535     const FieldDescriptor* field_descriptor);
536 }  // namespace cmessage
537 
538 // ---------------------------------------------------------------------
539 
540 PyObject* EncodeError_class;
541 PyObject* DecodeError_class;
542 PyObject* PickleError_class;
543 
544 // Format an error message for unexpected types.
545 // Always return with an exception set.
FormatTypeError(PyObject * arg,char * expected_types)546 void FormatTypeError(PyObject* arg, char* expected_types) {
547   // This function is often called with an exception set.
548   // Clear it to call PyObject_Repr() in good conditions.
549   PyErr_Clear();
550   PyObject* repr = PyObject_Repr(arg);
551   if (repr) {
552     PyErr_Format(PyExc_TypeError,
553                  "%.100s has type %.100s, but expected one of: %s",
554                  PyString_AsString(repr),
555                  Py_TYPE(arg)->tp_name,
556                  expected_types);
557     Py_DECREF(repr);
558   }
559 }
560 
OutOfRangeError(PyObject * arg)561 void OutOfRangeError(PyObject* arg) {
562   PyObject *s = PyObject_Str(arg);
563   if (s) {
564     PyErr_Format(PyExc_ValueError,
565                  "Value out of range: %s",
566                  PyString_AsString(s));
567     Py_DECREF(s);
568   }
569 }
570 
571 template<class RangeType, class ValueType>
VerifyIntegerCastAndRange(PyObject * arg,ValueType value)572 bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
573   if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
574     if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
575       // Replace it with the same ValueError as pure python protos instead of
576       // the default one.
577       PyErr_Clear();
578       OutOfRangeError(arg);
579     }  // Otherwise propagate existing error.
580     return false;
581     }
582     if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
583       OutOfRangeError(arg);
584       return false;
585     }
586   return true;
587 }
588 
589 template<class T>
CheckAndGetInteger(PyObject * arg,T * value)590 bool CheckAndGetInteger(PyObject* arg, T* value) {
591   // The fast path.
592 #if PY_MAJOR_VERSION < 3
593   // For the typical case, offer a fast path.
594   if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) {
595     long int_result = PyInt_AsLong(arg);
596     if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
597       *value = static_cast<T>(int_result);
598       return true;
599     } else {
600       OutOfRangeError(arg);
601       return false;
602     }
603   }
604 #endif
605   // This effectively defines an integer as "an object that can be cast as
606   // an integer and can be used as an ordinal number".
607   // This definition includes everything that implements numbers.Integral
608   // and shouldn't cast the net too wide.
609     if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) {
610       FormatTypeError(arg, "int, long");
611       return false;
612     }
613 
614   // Now we have an integral number so we can safely use PyLong_ functions.
615   // We need to treat the signed and unsigned cases differently in case arg is
616   // holding a value above the maximum for signed longs.
617   if (std::numeric_limits<T>::min() == 0) {
618     // Unsigned case.
619     unsigned PY_LONG_LONG ulong_result;
620     if (PyLong_Check(arg)) {
621       ulong_result = PyLong_AsUnsignedLongLong(arg);
622     } else {
623       // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
624       // picky about the exact type.
625       PyObject* casted = PyNumber_Long(arg);
626       if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
627         // Propagate existing error.
628         return false;
629         }
630       ulong_result = PyLong_AsUnsignedLongLong(casted);
631       Py_DECREF(casted);
632     }
633     if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
634                                                             ulong_result)) {
635       *value = static_cast<T>(ulong_result);
636     } else {
637       return false;
638     }
639   } else {
640     // Signed case.
641     PY_LONG_LONG long_result;
642     PyNumberMethods *nb;
643     if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
644       // PyLong_AsLongLong requires it to be a long or to have an __int__()
645       // method.
646       long_result = PyLong_AsLongLong(arg);
647     } else {
648       // Valid subclasses of numbers.Integral should have a __long__() method
649       // so fall back to that.
650       PyObject* casted = PyNumber_Long(arg);
651       if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
652         // Propagate existing error.
653         return false;
654         }
655       long_result = PyLong_AsLongLong(casted);
656       Py_DECREF(casted);
657     }
658     if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
659       *value = static_cast<T>(long_result);
660     } else {
661       return false;
662     }
663   }
664 
665   return true;
666 }
667 
668 // These are referenced by repeated_scalar_container, and must
669 // be explicitly instantiated.
670 template bool CheckAndGetInteger<int32>(PyObject*, int32*);
671 template bool CheckAndGetInteger<int64>(PyObject*, int64*);
672 template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
673 template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
674 
CheckAndGetDouble(PyObject * arg,double * value)675 bool CheckAndGetDouble(PyObject* arg, double* value) {
676   *value = PyFloat_AsDouble(arg);
677   if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
678     FormatTypeError(arg, "int, long, float");
679     return false;
680     }
681   return true;
682 }
683 
CheckAndGetFloat(PyObject * arg,float * value)684 bool CheckAndGetFloat(PyObject* arg, float* value) {
685   double double_value;
686   if (!CheckAndGetDouble(arg, &double_value)) {
687     return false;
688   }
689   *value = io::SafeDoubleToFloat(double_value);
690   return true;
691 }
692 
CheckAndGetBool(PyObject * arg,bool * value)693 bool CheckAndGetBool(PyObject* arg, bool* value) {
694   long long_value = PyInt_AsLong(arg);
695   if (long_value == -1 && PyErr_Occurred()) {
696     FormatTypeError(arg, "int, long, bool");
697     return false;
698   }
699   *value = static_cast<bool>(long_value);
700 
701   return true;
702 }
703 
704 // Checks whether the given object (which must be "bytes" or "unicode") contains
705 // valid UTF-8.
IsValidUTF8(PyObject * obj)706 bool IsValidUTF8(PyObject* obj) {
707   if (PyBytes_Check(obj)) {
708     PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
709 
710     // Clear the error indicator; we report our own error when desired.
711     PyErr_Clear();
712 
713     if (unicode) {
714       Py_DECREF(unicode);
715       return true;
716     } else {
717       return false;
718     }
719   } else {
720     // Unicode object, known to be valid UTF-8.
721     return true;
722   }
723 }
724 
// Decides whether string field `field` may hold non-UTF-8 bytes. This build
// never allows it, so invalid UTF-8 is rejected for every string field.
bool AllowInvalidUTF8(const FieldDescriptor* field) {
  (void)field;
  return false;
}
726 
CheckString(PyObject * arg,const FieldDescriptor * descriptor)727 PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
728   GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
729          descriptor->type() == FieldDescriptor::TYPE_BYTES);
730   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
731     if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
732       FormatTypeError(arg, "bytes, unicode");
733       return NULL;
734     }
735 
736     if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
737       PyObject* repr = PyObject_Repr(arg);
738       PyErr_Format(PyExc_ValueError,
739                    "%s has type str, but isn't valid UTF-8 "
740                    "encoding. Non-UTF-8 strings must be converted to "
741                    "unicode objects before being added.",
742                    PyString_AsString(repr));
743       Py_DECREF(repr);
744       return NULL;
745     }
746   } else if (!PyBytes_Check(arg)) {
747     FormatTypeError(arg, "bytes");
748     return NULL;
749   }
750 
751   PyObject* encoded_string = NULL;
752   if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
753     if (PyBytes_Check(arg)) {
754       // The bytes were already validated as correctly encoded UTF-8 above.
755       encoded_string = arg;  // Already encoded.
756       Py_INCREF(encoded_string);
757     } else {
758       encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
759     }
760   } else {
761     // In this case field type is "bytes".
762     encoded_string = arg;
763     Py_INCREF(encoded_string);
764   }
765 
766   return encoded_string;
767 }
768 
CheckAndSetString(PyObject * arg,Message * message,const FieldDescriptor * descriptor,const Reflection * reflection,bool append,int index)769 bool CheckAndSetString(
770     PyObject* arg, Message* message,
771     const FieldDescriptor* descriptor,
772     const Reflection* reflection,
773     bool append,
774     int index) {
775   ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
776 
777   if (encoded_string.get() == NULL) {
778     return false;
779   }
780 
781   char* value;
782   Py_ssize_t value_len;
783   if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
784     return false;
785   }
786 
787   string value_string(value, value_len);
788   if (append) {
789     reflection->AddString(message, descriptor, value_string);
790   } else if (index < 0) {
791     reflection->SetString(message, descriptor, value_string);
792   } else {
793     reflection->SetRepeatedString(message, descriptor, index, value_string);
794   }
795   return true;
796 }
797 
ToStringObject(const FieldDescriptor * descriptor,const string & value)798 PyObject* ToStringObject(const FieldDescriptor* descriptor,
799                          const string& value) {
800   if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
801     return PyBytes_FromStringAndSize(value.c_str(), value.length());
802   }
803 
804   PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL);
805   // If the string can't be decoded in UTF-8, just return a string object that
806   // contains the raw bytes. This can't happen if the value was assigned using
807   // the members of the Python message object, but can happen if the values were
808   // parsed from the wire (binary).
809   if (result == NULL) {
810     PyErr_Clear();
811     result = PyBytes_FromStringAndSize(value.c_str(), value.length());
812   }
813   return result;
814 }
815 
CheckFieldBelongsToMessage(const FieldDescriptor * field_descriptor,const Message * message)816 bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
817                                 const Message* message) {
818   if (message->GetDescriptor() == field_descriptor->containing_type()) {
819     return true;
820   }
821   PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
822                field_descriptor->full_name().c_str(),
823                message->GetDescriptor()->full_name().c_str());
824   return false;
825 }
826 
827 namespace cmessage {
828 
GetFactoryForMessage(CMessage * message)829 PyMessageFactory* GetFactoryForMessage(CMessage* message) {
830   GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
831   return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
832 }
833 
MaybeReleaseOverlappingOneofField(CMessage * cmessage,const FieldDescriptor * field)834 static int MaybeReleaseOverlappingOneofField(
835     CMessage* cmessage,
836     const FieldDescriptor* field) {
837 #ifdef GOOGLE_PROTOBUF_HAS_ONEOF
838   Message* message = cmessage->message;
839   const Reflection* reflection = message->GetReflection();
840   if (!field->containing_oneof() ||
841       !reflection->HasOneof(*message, field->containing_oneof()) ||
842       reflection->HasField(*message, field)) {
843     // No other field in this oneof, no need to release.
844     return 0;
845   }
846 
847   const OneofDescriptor* oneof = field->containing_oneof();
848   const FieldDescriptor* existing_field =
849       reflection->GetOneofFieldDescriptor(*message, oneof);
850   if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
851     // Non-message fields don't need to be released.
852     return 0;
853   }
854   if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) {
855     return -1;
856   }
857 #endif
858   return 0;
859 }
860 
861 // After a Merge, visit every sub-message that was read-only, and
862 // eventually update their pointer if the Merge operation modified them.
FixupMessageAfterMerge(CMessage * self)863 int FixupMessageAfterMerge(CMessage* self) {
864   if (!self->composite_fields) {
865     return 0;
866   }
867   for (const auto& item : *self->composite_fields) {
868     const FieldDescriptor* descriptor = item.first;
869     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
870         !descriptor->is_repeated()) {
871       CMessage* cmsg = reinterpret_cast<CMessage*>(item.second);
872       if (cmsg->read_only == false) {
873         return 0;
874       }
875       Message* message = self->message;
876       const Reflection* reflection = message->GetReflection();
877       if (reflection->HasField(*message, descriptor)) {
878         // Message used to be read_only, but is no longer. Get the new pointer
879         // and record it.
880         Message* mutable_message =
881             reflection->MutableMessage(message, descriptor, nullptr);
882         cmsg->message = mutable_message;
883         cmsg->read_only = false;
884         if (FixupMessageAfterMerge(cmsg) < 0) {
885           return -1;
886         }
887       }
888     }
889   }
890 
891   return 0;
892 }
893 
894 // ---------------------------------------------------------------------
895 // Making a message writable
896 
AssureWritable(CMessage * self)897 int AssureWritable(CMessage* self) {
898   if (self == NULL || !self->read_only) {
899     return 0;
900   }
901 
902   // Toplevel messages are always mutable.
903   GOOGLE_DCHECK(self->parent);
904 
905   if (AssureWritable(self->parent) == -1)
906     return -1;
907 
908   // If this message is part of a oneof, there might be a field to release in
909   // the parent.
910   if (MaybeReleaseOverlappingOneofField(self->parent,
911                                         self->parent_field_descriptor) < 0) {
912     return -1;
913   }
914 
915   // Make self->message writable.
916   Message* parent_message = self->parent->message;
917   const Reflection* reflection = parent_message->GetReflection();
918   Message* mutable_message = reflection->MutableMessage(
919       parent_message, self->parent_field_descriptor,
920       GetFactoryForMessage(self->parent)->message_factory);
921   if (mutable_message == NULL) {
922     return -1;
923   }
924   self->message = mutable_message;
925   self->read_only = false;
926 
927   return 0;
928 }
929 
930 // --- Globals:
931 
932 // Retrieve a C++ FieldDescriptor for an extension handle.
GetExtensionDescriptor(PyObject * extension)933 const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
934   ScopedPyObjectPtr cdescriptor;
935   if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
936     // Most callers consider extensions as a plain dictionary.  We should
937     // allow input which is not a field descriptor, and simply pretend it does
938     // not exist.
939     PyErr_SetObject(PyExc_KeyError, extension);
940     return NULL;
941   }
942   return PyFieldDescriptor_AsDescriptor(extension);
943 }
944 
945 // If value is a string, convert it into an enum value based on the labels in
946 // descriptor, otherwise simply return value.  Always returns a new reference.
GetIntegerEnumValue(const FieldDescriptor & descriptor,PyObject * value)947 static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
948                                      PyObject* value) {
949   if (PyString_Check(value) || PyUnicode_Check(value)) {
950     const EnumDescriptor* enum_descriptor = descriptor.enum_type();
951     if (enum_descriptor == NULL) {
952       PyErr_SetString(PyExc_TypeError, "not an enum field");
953       return NULL;
954     }
955     char* enum_label;
956     Py_ssize_t size;
957     if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
958       return NULL;
959     }
960     const EnumValueDescriptor* enum_value_descriptor =
961         enum_descriptor->FindValueByName(string(enum_label, size));
962     if (enum_value_descriptor == NULL) {
963       PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
964       return NULL;
965     }
966     return PyInt_FromLong(enum_value_descriptor->number());
967   }
968   Py_INCREF(value);
969   return value;
970 }
971 
972 // Delete a slice from a repeated field.
973 // The only way to remove items in C++ protos is to delete the last one,
974 // so we swap items to move the deleted ones at the end, and then strip the
975 // sequence.
DeleteRepeatedField(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * slice)976 int DeleteRepeatedField(
977     CMessage* self,
978     const FieldDescriptor* field_descriptor,
979     PyObject* slice) {
980   Py_ssize_t length, from, to, step, slice_length;
981   Message* message = self->message;
982   const Reflection* reflection = message->GetReflection();
983   int min, max;
984   length = reflection->FieldSize(*message, field_descriptor);
985 
986   if (PySlice_Check(slice)) {
987     from = to = step = slice_length = 0;
988 #if PY_MAJOR_VERSION < 3
989     PySlice_GetIndicesEx(
990         reinterpret_cast<PySliceObject*>(slice),
991         length, &from, &to, &step, &slice_length);
992 #else
993     PySlice_GetIndicesEx(
994         slice,
995         length, &from, &to, &step, &slice_length);
996 #endif
997     if (from < to) {
998       min = from;
999       max = to - 1;
1000     } else {
1001       min = to + 1;
1002       max = from;
1003     }
1004   } else {
1005     from = to = PyLong_AsLong(slice);
1006     if (from == -1 && PyErr_Occurred()) {
1007       PyErr_SetString(PyExc_TypeError, "list indices must be integers");
1008       return -1;
1009     }
1010 
1011     if (from < 0) {
1012       from = to = length + from;
1013     }
1014     step = 1;
1015     min = max = from;
1016 
1017     // Range check.
1018     if (from < 0 || from >= length) {
1019       PyErr_Format(PyExc_IndexError, "list assignment index out of range");
1020       return -1;
1021     }
1022   }
1023 
1024   Py_ssize_t i = from;
1025   std::vector<bool> to_delete(length, false);
1026   while (i >= min && i <= max) {
1027     to_delete[i] = true;
1028     i += step;
1029   }
1030 
1031   // Swap elements so that items to delete are at the end.
1032   to = 0;
1033   for (i = 0; i < length; ++i) {
1034     if (!to_delete[i]) {
1035       if (i != to) {
1036         reflection->SwapElements(message, field_descriptor, i, to);
1037       }
1038       ++to;
1039     }
1040   }
1041 
1042   // Remove items, starting from the end.
1043   for (; length > to; length--) {
1044     if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1045       reflection->RemoveLast(message, field_descriptor);
1046       continue;
1047     }
1048     // It seems that RemoveLast() is less efficient for sub-messages, and
1049     // the memory is not completely released. Prefer ReleaseLast().
1050     Message* sub_message = reflection->ReleaseLast(message, field_descriptor);
1051     // If there is a live weak reference to an item being removed, we "Release"
1052     // it, and it takes ownership of the message.
1053     if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) {
1054       released->message = sub_message;
1055     } else {
1056       // sub_message was not transferred, delete it.
1057       delete sub_message;
1058     }
1059   }
1060 
1061   return 0;
1062 }
1063 
1064 // Initializes fields of a message. Used in constructors.
InitAttributes(CMessage * self,PyObject * args,PyObject * kwargs)1065 int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
1066   if (args != NULL && PyTuple_Size(args) != 0) {
1067     PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
1068     return -1;
1069   }
1070 
1071   if (kwargs == NULL) {
1072     return 0;
1073   }
1074 
1075   Py_ssize_t pos = 0;
1076   PyObject* name;
1077   PyObject* value;
1078   while (PyDict_Next(kwargs, &pos, &name, &value)) {
1079     if (!(PyString_Check(name) || PyUnicode_Check(name))) {
1080       PyErr_SetString(PyExc_ValueError, "Field name must be a string");
1081       return -1;
1082     }
1083     ScopedPyObjectPtr property(
1084         PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
1085     if (property == NULL ||
1086         !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
1087       PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
1088                    self->message->GetDescriptor()->name().c_str(),
1089                    PyString_AsString(name));
1090       return -1;
1091     }
1092     const FieldDescriptor* descriptor =
1093         reinterpret_cast<PyMessageFieldProperty*>(property.get())
1094             ->field_descriptor;
1095     if (value == Py_None) {
1096       // field=None is the same as no field at all.
1097       continue;
1098     }
1099     if (descriptor->is_map()) {
1100       ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
1101       const FieldDescriptor* value_descriptor =
1102           descriptor->message_type()->FindFieldByName("value");
1103       if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1104         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1105         if (iter == NULL) {
1106           PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name));
1107           return -1;
1108         }
1109         ScopedPyObjectPtr next;
1110         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1111           ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
1112           ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
1113           if (source_value.get() == NULL || dest_value.get() == NULL) {
1114             return -1;
1115           }
1116           ScopedPyObjectPtr ok(PyObject_CallMethod(
1117               dest_value.get(), "MergeFrom", "O", source_value.get()));
1118           if (ok.get() == NULL) {
1119             return -1;
1120           }
1121         }
1122       } else {
1123         ScopedPyObjectPtr function_return;
1124         function_return.reset(
1125             PyObject_CallMethod(map.get(), "update", "O", value));
1126         if (function_return.get() == NULL) {
1127           return -1;
1128         }
1129       }
1130     } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1131       ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
1132       if (container == NULL) {
1133         return -1;
1134       }
1135       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1136         RepeatedCompositeContainer* rc_container =
1137             reinterpret_cast<RepeatedCompositeContainer*>(container.get());
1138         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1139         if (iter == NULL) {
1140           PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1141           return -1;
1142         }
1143         ScopedPyObjectPtr next;
1144         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1145           PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL);
1146           ScopedPyObjectPtr new_msg(
1147               repeated_composite_container::Add(rc_container, NULL, kwargs));
1148           if (new_msg == NULL) {
1149             return -1;
1150           }
1151           if (kwargs == NULL) {
1152             // next was not a dict, it's a message we need to merge
1153             ScopedPyObjectPtr merged(MergeFrom(
1154                 reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1155             if (merged.get() == NULL) {
1156               return -1;
1157             }
1158           }
1159         }
1160         if (PyErr_Occurred()) {
1161           // Check to see how PyIter_Next() exited.
1162           return -1;
1163         }
1164       } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1165         RepeatedScalarContainer* rs_container =
1166             reinterpret_cast<RepeatedScalarContainer*>(container.get());
1167         ScopedPyObjectPtr iter(PyObject_GetIter(value));
1168         if (iter == NULL) {
1169           PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1170           return -1;
1171         }
1172         ScopedPyObjectPtr next;
1173         while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1174           ScopedPyObjectPtr enum_value(
1175               GetIntegerEnumValue(*descriptor, next.get()));
1176           if (enum_value == NULL) {
1177             return -1;
1178           }
1179           ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
1180               rs_container, enum_value.get()));
1181           if (new_msg == NULL) {
1182             return -1;
1183           }
1184         }
1185         if (PyErr_Occurred()) {
1186           // Check to see how PyIter_Next() exited.
1187           return -1;
1188         }
1189       } else {
1190         if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
1191                 reinterpret_cast<RepeatedScalarContainer*>(container.get()),
1192                 value)) ==
1193             NULL) {
1194           return -1;
1195         }
1196       }
1197     } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1198       ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
1199       if (message == NULL) {
1200         return -1;
1201       }
1202       CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1203       if (PyDict_Check(value)) {
1204         // Make the message exist even if the dict is empty.
1205         AssureWritable(cmessage);
1206         if (InitAttributes(cmessage, NULL, value) < 0) {
1207           return -1;
1208         }
1209       } else {
1210         ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1211         if (merged == NULL) {
1212           return -1;
1213         }
1214       }
1215     } else {
1216       ScopedPyObjectPtr new_val;
1217       if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1218         new_val.reset(GetIntegerEnumValue(*descriptor, value));
1219         if (new_val == NULL) {
1220           return -1;
1221         }
1222         value = new_val.get();
1223       }
1224       if (SetFieldValue(self, descriptor, value) < 0) {
1225         return -1;
1226       }
1227     }
1228   }
1229   return 0;
1230 }
1231 
1232 // Allocates an incomplete Python Message: the caller must fill self->message
1233 // and eventually self->parent.
NewEmptyMessage(CMessageClass * type)1234 CMessage* NewEmptyMessage(CMessageClass* type) {
1235   CMessage* self = reinterpret_cast<CMessage*>(
1236       PyType_GenericAlloc(&type->super.ht_type, 0));
1237   if (self == NULL) {
1238     return NULL;
1239   }
1240 
1241   self->message = NULL;
1242   self->parent = NULL;
1243   self->parent_field_descriptor = NULL;
1244   self->read_only = false;
1245 
1246   self->composite_fields = NULL;
1247   self->child_submessages = NULL;
1248 
1249   self->unknown_field_set = NULL;
1250 
1251   return self;
1252 }
1253 
1254 // The __new__ method of Message classes.
1255 // Creates a new C++ message and takes ownership.
New(PyTypeObject * cls,PyObject * unused_args,PyObject * unused_kwargs)1256 static PyObject* New(PyTypeObject* cls,
1257                      PyObject* unused_args, PyObject* unused_kwargs) {
1258   CMessageClass* type = CheckMessageClass(cls);
1259   if (type == NULL) {
1260     return NULL;
1261   }
1262   // Retrieve the message descriptor and the default instance (=prototype).
1263   const Descriptor* message_descriptor = type->message_descriptor;
1264   if (message_descriptor == NULL) {
1265     return NULL;
1266   }
1267   const Message* prototype =
1268       type->py_message_factory->message_factory->GetPrototype(
1269           message_descriptor);
1270   if (prototype == NULL) {
1271     PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
1272     return NULL;
1273   }
1274 
1275   CMessage* self = NewEmptyMessage(type);
1276   if (self == NULL) {
1277     return NULL;
1278   }
1279   self->message = prototype->New();
1280   self->parent = nullptr;  // This message owns its data.
1281   return reinterpret_cast<PyObject*>(self);
1282 }
1283 
1284 // The __init__ method of Message classes.
1285 // It initializes fields from keywords passed to the constructor.
Init(CMessage * self,PyObject * args,PyObject * kwargs)1286 static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
1287   return InitAttributes(self, args, kwargs);
1288 }
1289 
1290 // ---------------------------------------------------------------------
1291 // Deallocating a CMessage
1292 
Dealloc(CMessage * self)1293 static void Dealloc(CMessage* self) {
1294   if (self->weakreflist) {
1295     PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
1296   }
1297   // At this point all dependent objects have been removed.
1298   GOOGLE_DCHECK(!self->child_submessages || self->child_submessages->empty());
1299   GOOGLE_DCHECK(!self->composite_fields || self->composite_fields->empty());
1300   delete self->child_submessages;
1301   delete self->composite_fields;
1302   if (self->unknown_field_set) {
1303     unknown_fields::Clear(
1304         reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1305   }
1306 
1307   CMessage* parent = self->parent;
1308   if (!parent) {
1309     // No parent, we own the message.
1310     delete self->message;
1311   } else if (parent->AsPyObject() == Py_None) {
1312     // Message owned externally: Nothing to dealloc
1313     Py_CLEAR(self->parent);
1314   } else {
1315     // Clear this message from its parent's map.
1316     if (self->parent_field_descriptor->is_repeated()) {
1317       if (parent->child_submessages)
1318         parent->child_submessages->erase(self->message);
1319     } else {
1320       if (parent->composite_fields)
1321         parent->composite_fields->erase(self->parent_field_descriptor);
1322     }
1323     Py_CLEAR(self->parent);
1324   }
1325   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
1326 }
1327 
1328 // ---------------------------------------------------------------------
1329 
1330 
IsInitialized(CMessage * self,PyObject * args)1331 PyObject* IsInitialized(CMessage* self, PyObject* args) {
1332   PyObject* errors = NULL;
1333   if (PyArg_ParseTuple(args, "|O", &errors) < 0) {
1334     return NULL;
1335   }
1336   if (self->message->IsInitialized()) {
1337     Py_RETURN_TRUE;
1338   }
1339   if (errors != NULL) {
1340     ScopedPyObjectPtr initialization_errors(
1341         FindInitializationErrors(self));
1342     if (initialization_errors == NULL) {
1343       return NULL;
1344     }
1345     ScopedPyObjectPtr extend_name(PyString_FromString("extend"));
1346     if (extend_name == NULL) {
1347       return NULL;
1348     }
1349     ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
1350         errors,
1351         extend_name.get(),
1352         initialization_errors.get(),
1353         NULL));
1354     if (result == NULL) {
1355       return NULL;
1356     }
1357   }
1358   Py_RETURN_FALSE;
1359 }
1360 
HasFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1361 PyObject* HasFieldByDescriptor(
1362     CMessage* self, const FieldDescriptor* field_descriptor) {
1363   Message* message = self->message;
1364   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
1365     return NULL;
1366   }
1367   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1368     PyErr_SetString(PyExc_KeyError,
1369                     "Field is repeated. A singular method is required.");
1370     return NULL;
1371   }
1372   bool has_field =
1373       message->GetReflection()->HasField(*message, field_descriptor);
1374   return PyBool_FromLong(has_field ? 1 : 0);
1375 }
1376 
FindFieldWithOneofs(const Message * message,const string & field_name,bool * in_oneof)1377 const FieldDescriptor* FindFieldWithOneofs(
1378     const Message* message, const string& field_name, bool* in_oneof) {
1379   *in_oneof = false;
1380   const Descriptor* descriptor = message->GetDescriptor();
1381   const FieldDescriptor* field_descriptor =
1382       descriptor->FindFieldByName(field_name);
1383   if (field_descriptor != NULL) {
1384     return field_descriptor;
1385   }
1386   const OneofDescriptor* oneof_desc =
1387       descriptor->FindOneofByName(field_name);
1388   if (oneof_desc != NULL) {
1389     *in_oneof = true;
1390     return message->GetReflection()->GetOneofFieldDescriptor(*message,
1391                                                              oneof_desc);
1392   }
1393   return NULL;
1394 }
1395 
CheckHasPresence(const FieldDescriptor * field_descriptor,bool in_oneof)1396 bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
1397   auto message_name = field_descriptor->containing_type()->name();
1398   if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1399     PyErr_Format(PyExc_ValueError,
1400                  "Protocol message %s has no singular \"%s\" field.",
1401                  message_name.c_str(), field_descriptor->name().c_str());
1402     return false;
1403   }
1404 
1405   if (field_descriptor->file()->syntax() == FileDescriptor::SYNTAX_PROTO3) {
1406     // HasField() for a oneof *itself* isn't supported.
1407     if (in_oneof) {
1408       PyErr_Format(PyExc_ValueError,
1409                    "Can't test oneof field \"%s.%s\" for presence in proto3, "
1410                    "use WhichOneof instead.", message_name.c_str(),
1411                    field_descriptor->containing_oneof()->name().c_str());
1412       return false;
1413     }
1414 
1415     // ...but HasField() for fields *in* a oneof is supported.
1416     if (field_descriptor->containing_oneof() != NULL) {
1417       return true;
1418     }
1419 
1420     if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1421       PyErr_Format(
1422           PyExc_ValueError,
1423           "Can't test non-submessage field \"%s.%s\" for presence in proto3.",
1424           message_name.c_str(), field_descriptor->name().c_str());
1425       return false;
1426     }
1427   }
1428 
1429   return true;
1430 }
1431 
HasField(CMessage * self,PyObject * arg)1432 PyObject* HasField(CMessage* self, PyObject* arg) {
1433   char* field_name;
1434   Py_ssize_t size;
1435 #if PY_MAJOR_VERSION < 3
1436   if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1437     return NULL;
1438   }
1439 #else
1440   field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
1441   if (!field_name) {
1442     return NULL;
1443   }
1444 #endif
1445 
1446   Message* message = self->message;
1447   bool is_in_oneof;
1448   const FieldDescriptor* field_descriptor =
1449       FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
1450   if (field_descriptor == NULL) {
1451     if (!is_in_oneof) {
1452       PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
1453                    message->GetDescriptor()->name().c_str(), field_name);
1454       return NULL;
1455     } else {
1456       Py_RETURN_FALSE;
1457     }
1458   }
1459 
1460   if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
1461     return NULL;
1462   }
1463 
1464   if (message->GetReflection()->HasField(*message, field_descriptor)) {
1465     Py_RETURN_TRUE;
1466   }
1467 
1468   Py_RETURN_FALSE;
1469 }
1470 
ClearExtension(CMessage * self,PyObject * extension)1471 PyObject* ClearExtension(CMessage* self, PyObject* extension) {
1472   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1473   if (descriptor == NULL) {
1474     return NULL;
1475   }
1476   if (InternalReleaseFieldByDescriptor(self, descriptor) < 0) {
1477     return NULL;
1478   }
1479   return ClearFieldByDescriptor(self, descriptor);
1480 }
1481 
HasExtension(CMessage * self,PyObject * extension)1482 PyObject* HasExtension(CMessage* self, PyObject* extension) {
1483   const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1484   if (descriptor == NULL) {
1485     return NULL;
1486   }
1487   return HasFieldByDescriptor(self, descriptor);
1488 }
1489 
1490 // ---------------------------------------------------------------------
1491 // Releasing messages
1492 //
// The Python API's ClearField() and Clear() methods behave
// differently from their C++ counterparts.  While the C++ versions
// clear the children, the Python versions detach the children
// without touching their content.  This impedance mismatch causes
// some complexity in the implementation, which is captured in this
// section.
//
// When one or more fields are cleared we need to:
//
// * Gather all child objects that need to be detached from the message
//   (they are found in composite_fields and child_submessages).
//
// * Create a new Python message of the same kind, and use SwapFields() to
//   move the data out of the original message.
//
// * Reparent all child objects: update their strong reference to their
//   parent, and move their entries from the original message's
//   composite_fields and child_submessages to the new message's.
1511 
1512 // ---------------------------------------------------------------------
1513 // Release a composite child of a CMessage
1514 
InternalReparentFields(CMessage * self,const std::vector<CMessage * > & messages_to_release,const std::vector<ContainerBase * > & containers_to_release)1515 static int InternalReparentFields(
1516     CMessage* self, const std::vector<CMessage*>& messages_to_release,
1517     const std::vector<ContainerBase*>& containers_to_release) {
1518   if (messages_to_release.empty() && containers_to_release.empty()) {
1519     return 0;
1520   }
1521 
1522   // Move all the passed sub_messages to another message.
1523   CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass());
1524   if (new_message == nullptr) {
1525     return -1;
1526   }
1527   new_message->message = self->message->New();
1528   ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message));
1529   new_message->child_submessages = new CMessage::SubMessagesMap();
1530   new_message->composite_fields = new CMessage::CompositeFieldsMap();
1531   std::set<const FieldDescriptor*> fields_to_swap;
1532 
1533   // In case this the removed fields are the last reference to a message, keep
1534   // a reference.
1535   Py_INCREF(self);
1536 
1537   for (const auto& to_release : messages_to_release) {
1538     fields_to_swap.insert(to_release->parent_field_descriptor);
1539     // Reparent
1540     Py_INCREF(new_message);
1541     Py_DECREF(to_release->parent);
1542     to_release->parent = new_message;
1543     self->child_submessages->erase(to_release->message);
1544     new_message->child_submessages->emplace(to_release->message, to_release);
1545   }
1546 
1547   for (const auto& to_release : containers_to_release) {
1548     fields_to_swap.insert(to_release->parent_field_descriptor);
1549     Py_INCREF(new_message);
1550     Py_DECREF(to_release->parent);
1551     to_release->parent = new_message;
1552     self->composite_fields->erase(to_release->parent_field_descriptor);
1553     new_message->composite_fields->emplace(to_release->parent_field_descriptor,
1554                                            to_release);
1555   }
1556 
1557   self->message->GetReflection()->SwapFields(
1558       self->message, new_message->message,
1559       std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
1560                                           fields_to_swap.end()));
1561 
1562   // This might delete the Python message completely if all children were moved.
1563   Py_DECREF(self);
1564 
1565   return 0;
1566 }
1567 
InternalReleaseFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1568 int InternalReleaseFieldByDescriptor(
1569     CMessage* self,
1570     const FieldDescriptor* field_descriptor) {
1571   if (!field_descriptor->is_repeated() &&
1572       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1573     // Single scalars are not in any cache.
1574     return 0;
1575   }
1576   std::vector<CMessage*> messages_to_release;
1577   std::vector<ContainerBase*> containers_to_release;
1578   if (self->child_submessages && field_descriptor->is_repeated() &&
1579       field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1580     for (const auto& child_item : *self->child_submessages) {
1581       if (child_item.second->parent_field_descriptor == field_descriptor) {
1582         messages_to_release.push_back(child_item.second);
1583       }
1584     }
1585   }
1586   if (self->composite_fields) {
1587     CMessage::CompositeFieldsMap::iterator it =
1588         self->composite_fields->find(field_descriptor);
1589     if (it != self->composite_fields->end()) {
1590       containers_to_release.push_back(it->second);
1591     }
1592   }
1593 
1594   return InternalReparentFields(self, messages_to_release,
1595                                 containers_to_release);
1596 }
1597 
ClearFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1598 PyObject* ClearFieldByDescriptor(
1599     CMessage* self,
1600     const FieldDescriptor* field_descriptor) {
1601   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
1602     return NULL;
1603   }
1604   AssureWritable(self);
1605   Message* message = self->message;
1606   message->GetReflection()->ClearField(message, field_descriptor);
1607   Py_RETURN_NONE;
1608 }
1609 
ClearField(CMessage * self,PyObject * arg)1610 PyObject* ClearField(CMessage* self, PyObject* arg) {
1611   if (!(PyString_Check(arg) || PyUnicode_Check(arg))) {
1612     PyErr_SetString(PyExc_TypeError, "field name must be a string");
1613     return NULL;
1614   }
1615 #if PY_MAJOR_VERSION < 3
1616   char* field_name;
1617   Py_ssize_t size;
1618   if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1619     return NULL;
1620   }
1621 #else
1622   Py_ssize_t size;
1623   const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size);
1624 #endif
1625   AssureWritable(self);
1626   Message* message = self->message;
1627   ScopedPyObjectPtr arg_in_oneof;
1628   bool is_in_oneof;
1629   const FieldDescriptor* field_descriptor =
1630       FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
1631   if (field_descriptor == NULL) {
1632     if (!is_in_oneof) {
1633       PyErr_Format(PyExc_ValueError,
1634                    "Protocol message has no \"%s\" field.", field_name);
1635       return NULL;
1636     } else {
1637       Py_RETURN_NONE;
1638     }
1639   } else if (is_in_oneof) {
1640     const string& name = field_descriptor->name();
1641     arg_in_oneof.reset(PyString_FromStringAndSize(name.c_str(), name.size()));
1642     arg = arg_in_oneof.get();
1643   }
1644 
1645   if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) {
1646     return NULL;
1647   }
1648   return ClearFieldByDescriptor(self, field_descriptor);
1649 }
1650 
Clear(CMessage * self)1651 PyObject* Clear(CMessage* self) {
1652   AssureWritable(self);
1653   // Detach all current fields of this message
1654   std::vector<CMessage*> messages_to_release;
1655   std::vector<ContainerBase*> containers_to_release;
1656   if (self->child_submessages) {
1657     for (const auto& item : *self->child_submessages) {
1658       messages_to_release.push_back(item.second);
1659     }
1660   }
1661   if (self->composite_fields) {
1662     for (const auto& item : *self->composite_fields) {
1663       containers_to_release.push_back(item.second);
1664     }
1665   }
1666   if (InternalReparentFields(self, messages_to_release, containers_to_release) <
1667       0) {
1668     return NULL;
1669   }
1670   if (self->unknown_field_set) {
1671     unknown_fields::Clear(
1672         reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1673     self->unknown_field_set = nullptr;
1674   }
1675   self->message->Clear();
1676   Py_RETURN_NONE;
1677 }
1678 
1679 // ---------------------------------------------------------------------
1680 
GetMessageName(CMessage * self)1681 static string GetMessageName(CMessage* self) {
1682   if (self->parent_field_descriptor != NULL) {
1683     return self->parent_field_descriptor->full_name();
1684   } else {
1685     return self->message->GetDescriptor()->full_name();
1686   }
1687 }
1688 
InternalSerializeToString(CMessage * self,PyObject * args,PyObject * kwargs,bool require_initialized)1689 static PyObject* InternalSerializeToString(
1690     CMessage* self, PyObject* args, PyObject* kwargs,
1691     bool require_initialized) {
1692   // Parse the "deterministic" kwarg; defaults to False.
1693   static char* kwlist[] = { "deterministic", 0 };
1694   PyObject* deterministic_obj = Py_None;
1695   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
1696                                    &deterministic_obj)) {
1697     return NULL;
1698   }
1699   // Preemptively convert to a bool first, so we don't need to back out of
1700   // allocating memory if this raises an exception.
1701   // NOTE: This is unused later if deterministic == Py_None, but that's fine.
1702   int deterministic = PyObject_IsTrue(deterministic_obj);
1703   if (deterministic < 0) {
1704     return NULL;
1705   }
1706 
1707   if (require_initialized && !self->message->IsInitialized()) {
1708     ScopedPyObjectPtr errors(FindInitializationErrors(self));
1709     if (errors == NULL) {
1710       return NULL;
1711     }
1712     ScopedPyObjectPtr comma(PyString_FromString(","));
1713     if (comma == NULL) {
1714       return NULL;
1715     }
1716     ScopedPyObjectPtr joined(
1717         PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
1718     if (joined == NULL) {
1719       return NULL;
1720     }
1721 
1722     // TODO(haberman): this is a (hopefully temporary) hack.  The unit testing
1723     // infrastructure reloads all pure-Python modules for every test, but not
1724     // C++ modules (because that's generally impossible:
1725     // http://bugs.python.org/issue1144263).  But if we cache EncodeError, we'll
1726     // return the EncodeError from a previous load of the module, which won't
1727     // match a user's attempt to catch EncodeError.  So we have to look it up
1728     // again every time.
1729     ScopedPyObjectPtr message_module(PyImport_ImportModule(
1730         "google.protobuf.message"));
1731     if (message_module.get() == NULL) {
1732       return NULL;
1733     }
1734 
1735     ScopedPyObjectPtr encode_error(
1736         PyObject_GetAttrString(message_module.get(), "EncodeError"));
1737     if (encode_error.get() == NULL) {
1738       return NULL;
1739     }
1740     PyErr_Format(encode_error.get(),
1741                  "Message %s is missing required fields: %s",
1742                  GetMessageName(self).c_str(), PyString_AsString(joined.get()));
1743     return NULL;
1744   }
1745 
1746   // Ok, arguments parsed and errors checked, now encode to a string
1747   const size_t size = self->message->ByteSizeLong();
1748   if (size == 0) {
1749     return PyBytes_FromString("");
1750   }
1751   PyObject* result = PyBytes_FromStringAndSize(NULL, size);
1752   if (result == NULL) {
1753     return NULL;
1754   }
1755   io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
1756   io::CodedOutputStream coded_out(&out);
1757   if (deterministic_obj != Py_None) {
1758     coded_out.SetSerializationDeterministic(deterministic);
1759   }
1760   self->message->SerializeWithCachedSizes(&coded_out);
1761   GOOGLE_CHECK(!coded_out.HadError());
1762   return result;
1763 }
1764 
SerializeToString(CMessage * self,PyObject * args,PyObject * kwargs)1765 static PyObject* SerializeToString(
1766     CMessage* self, PyObject* args, PyObject* kwargs) {
1767   return InternalSerializeToString(self, args, kwargs,
1768                                    /*require_initialized=*/true);
1769 }
1770 
SerializePartialToString(CMessage * self,PyObject * args,PyObject * kwargs)1771 static PyObject* SerializePartialToString(
1772     CMessage* self, PyObject* args, PyObject* kwargs) {
1773   return InternalSerializeToString(self, args, kwargs,
1774                                    /*require_initialized=*/false);
1775 }
1776 
1777 // Formats proto fields for ascii dumps using python formatting functions where
1778 // appropriate.
1779 class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter {
1780  public:
1781   // Python has some differences from C++ when printing floating point numbers.
1782   //
1783   // 1) Trailing .0 is always printed.
1784   // 2) (Python2) Output is rounded to 12 digits.
1785   // 3) (Python3) The full precision of the double is preserved (and Python uses
1786   //    David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
1787   //    differences, but they rarely happen)
1788   //
1789   // We override floating point printing with the C-API function for printing
1790   // Python floats to ensure consistency.
PrintFloat(float value) const1791   string PrintFloat(float value) const { return PrintDouble(value); }
PrintDouble(double value) const1792   string PrintDouble(double value) const {
1793     // This implementation is not highly optimized (it allocates two temporary
1794     // Python objects) but it is simple and portable.  If this is shown to be a
1795     // performance bottleneck, we can optimize it, but the results will likely
1796     // be more complicated to accommodate the differing behavior of double
1797     // formatting between Python 2 and Python 3.
1798     //
1799     // (Though a valid question is: do we really want to make out output
1800     // dependent on the Python version?)
1801     ScopedPyObjectPtr py_value(PyFloat_FromDouble(value));
1802     if (!py_value.get()) {
1803       return string();
1804     }
1805 
1806     ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
1807     if (!py_str.get()) {
1808       return string();
1809     }
1810 
1811     return string(PyString_AsString(py_str.get()));
1812   }
1813 };
1814 
ToStr(CMessage * self)1815 static PyObject* ToStr(CMessage* self) {
1816   TextFormat::Printer printer;
1817   // Passes ownership
1818   printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
1819   printer.SetHideUnknownFields(true);
1820   string output;
1821   if (!printer.PrintToString(*self->message, &output)) {
1822     PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
1823     return NULL;
1824   }
1825   return PyString_FromString(output.c_str());
1826 }
1827 
MergeFrom(CMessage * self,PyObject * arg)1828 PyObject* MergeFrom(CMessage* self, PyObject* arg) {
1829   CMessage* other_message;
1830   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1831     PyErr_Format(PyExc_TypeError,
1832                  "Parameter to MergeFrom() must be instance of same class: "
1833                  "expected %s got %s.",
1834                  self->message->GetDescriptor()->full_name().c_str(),
1835                  Py_TYPE(arg)->tp_name);
1836     return NULL;
1837   }
1838 
1839   other_message = reinterpret_cast<CMessage*>(arg);
1840   if (other_message->message->GetDescriptor() !=
1841       self->message->GetDescriptor()) {
1842     PyErr_Format(PyExc_TypeError,
1843                  "Parameter to MergeFrom() must be instance of same class: "
1844                  "expected %s got %s.",
1845                  self->message->GetDescriptor()->full_name().c_str(),
1846                  other_message->message->GetDescriptor()->full_name().c_str());
1847     return NULL;
1848   }
1849   AssureWritable(self);
1850 
1851   self->message->MergeFrom(*other_message->message);
1852   // Child message might be lazily created before MergeFrom. Make sure they
1853   // are mutable at this point if child messages are really created.
1854   if (FixupMessageAfterMerge(self) < 0) {
1855     return NULL;
1856   }
1857 
1858   Py_RETURN_NONE;
1859 }
1860 
CopyFrom(CMessage * self,PyObject * arg)1861 static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
1862   CMessage* other_message;
1863   if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1864     PyErr_Format(PyExc_TypeError,
1865                  "Parameter to CopyFrom() must be instance of same class: "
1866                  "expected %s got %s.",
1867                  self->message->GetDescriptor()->full_name().c_str(),
1868                  Py_TYPE(arg)->tp_name);
1869     return NULL;
1870   }
1871 
1872   other_message = reinterpret_cast<CMessage*>(arg);
1873 
1874   if (self == other_message) {
1875     Py_RETURN_NONE;
1876   }
1877 
1878   if (other_message->message->GetDescriptor() !=
1879       self->message->GetDescriptor()) {
1880     PyErr_Format(PyExc_TypeError,
1881                  "Parameter to CopyFrom() must be instance of same class: "
1882                  "expected %s got %s.",
1883                  self->message->GetDescriptor()->full_name().c_str(),
1884                  other_message->message->GetDescriptor()->full_name().c_str());
1885     return NULL;
1886   }
1887 
1888   AssureWritable(self);
1889 
1890   // CopyFrom on the message will not clean up self->composite_fields,
1891   // which can leave us in an inconsistent state, so clear it out here.
1892   (void)ScopedPyObjectPtr(Clear(self));
1893 
1894   self->message->CopyFrom(*other_message->message);
1895 
1896   Py_RETURN_NONE;
1897 }
1898 
// Protobuf has a 64MB limit built in; when allow_oversize_protos is true,
// MergeFromString() lifts that limit (and the recursion limit). Do not enable
// this unless you fully understand the implications: protobufs must be held
// entirely in memory, so very large ones can cause OOM errors, and the
// protobuf APIs offer no way to process them in chunks. If your protos are
// this big, break them up if at all convenient.
//
// Defining PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS at build time makes the
// limit start out lifted; SetAllowOversizeProtos() can change it at runtime.
#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
static const bool kAllowOversizeProtosByDefault = true;
#else
static const bool kAllowOversizeProtosByDefault = false;
#endif
static bool allow_oversize_protos = kAllowOversizeProtosByDefault;
1910 
1911 // Provide a method in the module to set allow_oversize_protos to a boolean
1912 // value. This method returns the newly value of allow_oversize_protos.
SetAllowOversizeProtos(PyObject * m,PyObject * arg)1913 PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
1914   if (!arg || !PyBool_Check(arg)) {
1915     PyErr_SetString(PyExc_TypeError,
1916                     "Argument to SetAllowOversizeProtos must be boolean");
1917     return NULL;
1918   }
1919   allow_oversize_protos = PyObject_IsTrue(arg);
1920   if (allow_oversize_protos) {
1921     Py_RETURN_TRUE;
1922   } else {
1923     Py_RETURN_FALSE;
1924   }
1925 }
1926 
MergeFromString(CMessage * self,PyObject * arg)1927 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
1928   const void* data;
1929   Py_ssize_t data_length;
1930   if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) {
1931     return NULL;
1932   }
1933 
1934   AssureWritable(self);
1935 
1936   io::CodedInputStream input(
1937       reinterpret_cast<const uint8*>(data), data_length);
1938   if (allow_oversize_protos) {
1939     input.SetTotalBytesLimit(INT_MAX, INT_MAX);
1940     input.SetRecursionLimit(INT_MAX);
1941   }
1942   PyMessageFactory* factory = GetFactoryForMessage(self);
1943   input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
1944   bool success = self->message->MergePartialFromCodedStream(&input);
1945   // Child message might be lazily created before MergeFrom. Make sure they
1946   // are mutable at this point if child messages are really created.
1947   if (FixupMessageAfterMerge(self) < 0) {
1948     return NULL;
1949   }
1950 
1951   if (success) {
1952     if (!input.ConsumedEntireMessage()) {
1953       // TODO(jieluo): Raise error and return NULL instead.
1954       // b/27494216
1955       PyErr_Warn(NULL, "Unexpected end-group tag: Not all data was converted");
1956     }
1957     return PyInt_FromLong(input.CurrentPosition());
1958   } else {
1959     PyErr_Format(DecodeError_class, "Error parsing message");
1960     return NULL;
1961   }
1962 }
1963 
ParseFromString(CMessage * self,PyObject * arg)1964 static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
1965   if (ScopedPyObjectPtr(Clear(self)) == NULL) {
1966     return NULL;
1967   }
1968   return MergeFromString(self, arg);
1969 }
1970 
ByteSize(CMessage * self,PyObject * args)1971 static PyObject* ByteSize(CMessage* self, PyObject* args) {
1972   return PyLong_FromLong(self->message->ByteSize());
1973 }
1974 
RegisterExtension(PyObject * cls,PyObject * extension_handle)1975 PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
1976   const FieldDescriptor* descriptor =
1977       GetExtensionDescriptor(extension_handle);
1978   if (descriptor == NULL) {
1979     return NULL;
1980   }
1981   if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
1982     PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
1983                  cls->ob_type->tp_name);
1984     return NULL;
1985   }
1986   CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
1987   if (message_class == NULL) {
1988     return NULL;
1989   }
1990   // If the extension was already registered, check that it is the same.
1991   const FieldDescriptor* existing_extension =
1992       message_class->py_message_factory->pool->pool->FindExtensionByNumber(
1993           descriptor->containing_type(), descriptor->number());
1994   if (existing_extension != NULL && existing_extension != descriptor) {
1995     PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
1996     return NULL;
1997   }
1998   Py_RETURN_NONE;
1999 }
2000 
SetInParent(CMessage * self,PyObject * args)2001 static PyObject* SetInParent(CMessage* self, PyObject* args) {
2002   AssureWritable(self);
2003   Py_RETURN_NONE;
2004 }
2005 
WhichOneof(CMessage * self,PyObject * arg)2006 static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
2007   Py_ssize_t name_size;
2008   char *name_data;
2009   if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0)
2010     return NULL;
2011   string oneof_name = string(name_data, name_size);
2012   const OneofDescriptor* oneof_desc =
2013       self->message->GetDescriptor()->FindOneofByName(oneof_name);
2014   if (oneof_desc == NULL) {
2015     PyErr_Format(PyExc_ValueError,
2016                  "Protocol message has no oneof \"%s\" field.",
2017                  oneof_name.c_str());
2018     return NULL;
2019   }
2020   const FieldDescriptor* field_in_oneof =
2021       self->message->GetReflection()->GetOneofFieldDescriptor(
2022           *self->message, oneof_desc);
2023   if (field_in_oneof == NULL) {
2024     Py_RETURN_NONE;
2025   } else {
2026     const string& name = field_in_oneof->name();
2027     return PyString_FromStringAndSize(name.c_str(), name.size());
2028   }
2029 }
2030 
2031 static PyObject* GetExtensionDict(CMessage* self, void *closure);
2032 
ListFields(CMessage * self)2033 static PyObject* ListFields(CMessage* self) {
2034   std::vector<const FieldDescriptor*> fields;
2035   self->message->GetReflection()->ListFields(*self->message, &fields);
2036 
2037   // Normally, the list will be exactly the size of the fields.
2038   ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
2039   if (all_fields == NULL) {
2040     return NULL;
2041   }
2042 
2043   // When there are unknown extensions, the py list will *not* contain
2044   // the field information.  Thus the actual size of the py list will be
2045   // smaller than the size of fields.  Set the actual size at the end.
2046   Py_ssize_t actual_size = 0;
2047   for (size_t i = 0; i < fields.size(); ++i) {
2048     ScopedPyObjectPtr t(PyTuple_New(2));
2049     if (t == NULL) {
2050       return NULL;
2051     }
2052 
2053     if (fields[i]->is_extension()) {
2054       ScopedPyObjectPtr extension_field(
2055           PyFieldDescriptor_FromDescriptor(fields[i]));
2056       if (extension_field == NULL) {
2057         return NULL;
2058       }
2059       // With C++ descriptors, the field can always be retrieved, but for
2060       // unknown extensions which have not been imported in Python code, there
2061       // is no message class and we cannot retrieve the value.
2062       // TODO(amauryfa): consider building the class on the fly!
2063       if (fields[i]->message_type() != NULL &&
2064           message_factory::GetMessageClass(
2065               GetFactoryForMessage(self),
2066               fields[i]->message_type()) == NULL) {
2067         PyErr_Clear();
2068         continue;
2069       }
2070       ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
2071       if (extensions == NULL) {
2072         return NULL;
2073       }
2074       // 'extension' reference later stolen by PyTuple_SET_ITEM.
2075       PyObject* extension = PyObject_GetItem(
2076           extensions.get(), extension_field.get());
2077       if (extension == NULL) {
2078         return NULL;
2079       }
2080       PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
2081       // Steals reference to 'extension'
2082       PyTuple_SET_ITEM(t.get(), 1, extension);
2083     } else {
2084       // Normal field
2085       ScopedPyObjectPtr field_descriptor(
2086           PyFieldDescriptor_FromDescriptor(fields[i]));
2087       if (field_descriptor == NULL) {
2088         return NULL;
2089       }
2090 
2091       PyObject* field_value = GetFieldValue(self, fields[i]);
2092       if (field_value == NULL) {
2093         PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str());
2094         return NULL;
2095       }
2096       PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
2097       PyTuple_SET_ITEM(t.get(), 1, field_value);
2098     }
2099     PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
2100     ++actual_size;
2101   }
2102   if (static_cast<size_t>(actual_size) != fields.size() &&
2103       (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
2104        0)) {
2105     return NULL;
2106   }
2107   return all_fields.release();
2108 }
2109 
DiscardUnknownFields(CMessage * self)2110 static PyObject* DiscardUnknownFields(CMessage* self) {
2111   AssureWritable(self);
2112   self->message->DiscardUnknownFields();
2113   Py_RETURN_NONE;
2114 }
2115 
FindInitializationErrors(CMessage * self)2116 PyObject* FindInitializationErrors(CMessage* self) {
2117   Message* message = self->message;
2118   std::vector<string> errors;
2119   message->FindInitializationErrors(&errors);
2120 
2121   PyObject* error_list = PyList_New(errors.size());
2122   if (error_list == NULL) {
2123     return NULL;
2124   }
2125   for (size_t i = 0; i < errors.size(); ++i) {
2126     const string& error = errors[i];
2127     PyObject* error_string = PyString_FromStringAndSize(
2128         error.c_str(), error.length());
2129     if (error_string == NULL) {
2130       Py_DECREF(error_list);
2131       return NULL;
2132     }
2133     PyList_SET_ITEM(error_list, i, error_string);
2134   }
2135   return error_list;
2136 }
2137 
RichCompare(CMessage * self,PyObject * other,int opid)2138 static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
2139   // Only equality comparisons are implemented.
2140   if (opid != Py_EQ && opid != Py_NE) {
2141     Py_INCREF(Py_NotImplemented);
2142     return Py_NotImplemented;
2143   }
2144   bool equals = true;
2145   // If other is not a message, it cannot be equal.
2146   if (!PyObject_TypeCheck(other, CMessage_Type)) {
2147     equals = false;
2148   }
2149   const google::protobuf::Message* other_message =
2150       reinterpret_cast<CMessage*>(other)->message;
2151   // If messages don't have the same descriptors, they are not equal.
2152   if (equals &&
2153       self->message->GetDescriptor() != other_message->GetDescriptor()) {
2154     equals = false;
2155   }
2156   // Check the message contents.
2157   if (equals && !google::protobuf::util::MessageDifferencer::Equals(
2158           *self->message,
2159           *reinterpret_cast<CMessage*>(other)->message)) {
2160     equals = false;
2161   }
2162 
2163   if (equals ^ (opid == Py_EQ)) {
2164     Py_RETURN_FALSE;
2165   } else {
2166     Py_RETURN_TRUE;
2167   }
2168 }
2169 
InternalGetScalar(const Message * message,const FieldDescriptor * field_descriptor)2170 PyObject* InternalGetScalar(const Message* message,
2171                             const FieldDescriptor* field_descriptor) {
2172   const Reflection* reflection = message->GetReflection();
2173 
2174   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2175     return NULL;
2176   }
2177 
2178   PyObject* result = NULL;
2179   switch (field_descriptor->cpp_type()) {
2180     case FieldDescriptor::CPPTYPE_INT32: {
2181       int32 value = reflection->GetInt32(*message, field_descriptor);
2182       result = PyInt_FromLong(value);
2183       break;
2184     }
2185     case FieldDescriptor::CPPTYPE_INT64: {
2186       int64 value = reflection->GetInt64(*message, field_descriptor);
2187       result = PyLong_FromLongLong(value);
2188       break;
2189     }
2190     case FieldDescriptor::CPPTYPE_UINT32: {
2191       uint32 value = reflection->GetUInt32(*message, field_descriptor);
2192       result = PyInt_FromSize_t(value);
2193       break;
2194     }
2195     case FieldDescriptor::CPPTYPE_UINT64: {
2196       uint64 value = reflection->GetUInt64(*message, field_descriptor);
2197       result = PyLong_FromUnsignedLongLong(value);
2198       break;
2199     }
2200     case FieldDescriptor::CPPTYPE_FLOAT: {
2201       float value = reflection->GetFloat(*message, field_descriptor);
2202       result = PyFloat_FromDouble(value);
2203       break;
2204     }
2205     case FieldDescriptor::CPPTYPE_DOUBLE: {
2206       double value = reflection->GetDouble(*message, field_descriptor);
2207       result = PyFloat_FromDouble(value);
2208       break;
2209     }
2210     case FieldDescriptor::CPPTYPE_BOOL: {
2211       bool value = reflection->GetBool(*message, field_descriptor);
2212       result = PyBool_FromLong(value);
2213       break;
2214     }
2215     case FieldDescriptor::CPPTYPE_STRING: {
2216       string scratch;
2217       const string& value =
2218           reflection->GetStringReference(*message, field_descriptor, &scratch);
2219       result = ToStringObject(field_descriptor, value);
2220       break;
2221     }
2222     case FieldDescriptor::CPPTYPE_ENUM: {
2223       const EnumValueDescriptor* enum_value =
2224           message->GetReflection()->GetEnum(*message, field_descriptor);
2225       result = PyInt_FromLong(enum_value->number());
2226       break;
2227     }
2228     default:
2229       PyErr_Format(
2230           PyExc_SystemError, "Getting a value from a field of unknown type %d",
2231           field_descriptor->cpp_type());
2232   }
2233 
2234   return result;
2235 }
2236 
InternalGetSubMessage(CMessage * self,const FieldDescriptor * field_descriptor)2237 CMessage* InternalGetSubMessage(
2238     CMessage* self, const FieldDescriptor* field_descriptor) {
2239   const Reflection* reflection = self->message->GetReflection();
2240   PyMessageFactory* factory = GetFactoryForMessage(self);
2241   const Message& sub_message = reflection->GetMessage(
2242       *self->message, field_descriptor, factory->message_factory);
2243 
2244   CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2245       factory, field_descriptor->message_type());
2246   ScopedPyObjectPtr message_class_owner(
2247       reinterpret_cast<PyObject*>(message_class));
2248   if (message_class == NULL) {
2249     return NULL;
2250   }
2251 
2252   CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2253   if (cmsg == NULL) {
2254     return NULL;
2255   }
2256 
2257   Py_INCREF(self);
2258   cmsg->parent = self;
2259   cmsg->parent_field_descriptor = field_descriptor;
2260   cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
2261   cmsg->message = const_cast<Message*>(&sub_message);
2262   return cmsg;
2263 }
2264 
InternalSetNonOneofScalar(Message * message,const FieldDescriptor * field_descriptor,PyObject * arg)2265 int InternalSetNonOneofScalar(
2266     Message* message,
2267     const FieldDescriptor* field_descriptor,
2268     PyObject* arg) {
2269   const Reflection* reflection = message->GetReflection();
2270 
2271   if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2272     return -1;
2273   }
2274 
2275   switch (field_descriptor->cpp_type()) {
2276     case FieldDescriptor::CPPTYPE_INT32: {
2277       GOOGLE_CHECK_GET_INT32(arg, value, -1);
2278       reflection->SetInt32(message, field_descriptor, value);
2279       break;
2280     }
2281     case FieldDescriptor::CPPTYPE_INT64: {
2282       GOOGLE_CHECK_GET_INT64(arg, value, -1);
2283       reflection->SetInt64(message, field_descriptor, value);
2284       break;
2285     }
2286     case FieldDescriptor::CPPTYPE_UINT32: {
2287       GOOGLE_CHECK_GET_UINT32(arg, value, -1);
2288       reflection->SetUInt32(message, field_descriptor, value);
2289       break;
2290     }
2291     case FieldDescriptor::CPPTYPE_UINT64: {
2292       GOOGLE_CHECK_GET_UINT64(arg, value, -1);
2293       reflection->SetUInt64(message, field_descriptor, value);
2294       break;
2295     }
2296     case FieldDescriptor::CPPTYPE_FLOAT: {
2297       GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
2298       reflection->SetFloat(message, field_descriptor, value);
2299       break;
2300     }
2301     case FieldDescriptor::CPPTYPE_DOUBLE: {
2302       GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
2303       reflection->SetDouble(message, field_descriptor, value);
2304       break;
2305     }
2306     case FieldDescriptor::CPPTYPE_BOOL: {
2307       GOOGLE_CHECK_GET_BOOL(arg, value, -1);
2308       reflection->SetBool(message, field_descriptor, value);
2309       break;
2310     }
2311     case FieldDescriptor::CPPTYPE_STRING: {
2312       if (!CheckAndSetString(
2313           arg, message, field_descriptor, reflection, false, -1)) {
2314         return -1;
2315       }
2316       break;
2317     }
2318     case FieldDescriptor::CPPTYPE_ENUM: {
2319       GOOGLE_CHECK_GET_INT32(arg, value, -1);
2320       if (reflection->SupportsUnknownEnumValues()) {
2321         reflection->SetEnumValue(message, field_descriptor, value);
2322       } else {
2323         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
2324         const EnumValueDescriptor* enum_value =
2325             enum_descriptor->FindValueByNumber(value);
2326         if (enum_value != NULL) {
2327           reflection->SetEnum(message, field_descriptor, enum_value);
2328         } else {
2329           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
2330           return -1;
2331         }
2332       }
2333       break;
2334     }
2335     default:
2336       PyErr_Format(
2337           PyExc_SystemError, "Setting value to a field of unknown type %d",
2338           field_descriptor->cpp_type());
2339       return -1;
2340   }
2341 
2342   return 0;
2343 }
2344 
InternalSetScalar(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * arg)2345 int InternalSetScalar(
2346     CMessage* self,
2347     const FieldDescriptor* field_descriptor,
2348     PyObject* arg) {
2349   if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
2350     return -1;
2351   }
2352 
2353   if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
2354     return -1;
2355   }
2356 
2357   return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
2358 }
2359 
FromString(PyTypeObject * cls,PyObject * serialized)2360 PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
2361   PyObject* py_cmsg = PyObject_CallObject(
2362       reinterpret_cast<PyObject*>(cls), NULL);
2363   if (py_cmsg == NULL) {
2364     return NULL;
2365   }
2366   CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
2367 
2368   ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
2369   if (py_length == NULL) {
2370     Py_DECREF(py_cmsg);
2371     return NULL;
2372   }
2373 
2374   return py_cmsg;
2375 }
2376 
DeepCopy(CMessage * self,PyObject * arg)2377 PyObject* DeepCopy(CMessage* self, PyObject* arg) {
2378   PyObject* clone = PyObject_CallObject(
2379       reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
2380   if (clone == NULL) {
2381     return NULL;
2382   }
2383   if (!PyObject_TypeCheck(clone, CMessage_Type)) {
2384     Py_DECREF(clone);
2385     return NULL;
2386   }
2387   if (ScopedPyObjectPtr(MergeFrom(
2388           reinterpret_cast<CMessage*>(clone),
2389           reinterpret_cast<PyObject*>(self))) == NULL) {
2390     Py_DECREF(clone);
2391     return NULL;
2392   }
2393   return clone;
2394 }
2395 
ToUnicode(CMessage * self)2396 PyObject* ToUnicode(CMessage* self) {
2397   // Lazy import to prevent circular dependencies
2398   ScopedPyObjectPtr text_format(
2399       PyImport_ImportModule("google.protobuf.text_format"));
2400   if (text_format == NULL) {
2401     return NULL;
2402   }
2403   ScopedPyObjectPtr method_name(PyString_FromString("MessageToString"));
2404   if (method_name == NULL) {
2405     return NULL;
2406   }
2407   Py_INCREF(Py_True);
2408   ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
2409       text_format.get(), method_name.get(), self, Py_True, NULL));
2410   Py_DECREF(Py_True);
2411   if (encoded == NULL) {
2412     return NULL;
2413   }
2414 #if PY_MAJOR_VERSION < 3
2415   PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL);
2416 #else
2417   PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL);
2418 #endif
2419   if (decoded == NULL) {
2420     return NULL;
2421   }
2422   return decoded;
2423 }
2424 
Reduce(CMessage * self)2425 PyObject* Reduce(CMessage* self) {
2426   ScopedPyObjectPtr constructor(reinterpret_cast<PyObject*>(Py_TYPE(self)));
2427   constructor.inc();
2428   ScopedPyObjectPtr args(PyTuple_New(0));
2429   if (args == NULL) {
2430     return NULL;
2431   }
2432   ScopedPyObjectPtr state(PyDict_New());
2433   if (state == NULL) {
2434     return  NULL;
2435   }
2436   string contents;
2437   self->message->SerializePartialToString(&contents);
2438   ScopedPyObjectPtr serialized(
2439       PyBytes_FromStringAndSize(contents.c_str(), contents.size()));
2440   if (serialized == NULL) {
2441     return NULL;
2442   }
2443   if (PyDict_SetItemString(state.get(), "serialized", serialized.get()) < 0) {
2444     return NULL;
2445   }
2446   return Py_BuildValue("OOO", constructor.get(), args.get(), state.get());
2447 }
2448 
SetState(CMessage * self,PyObject * state)2449 PyObject* SetState(CMessage* self, PyObject* state) {
2450   if (!PyDict_Check(state)) {
2451     PyErr_SetString(PyExc_TypeError, "state not a dict");
2452     return NULL;
2453   }
2454   PyObject* serialized = PyDict_GetItemString(state, "serialized");
2455   if (serialized == NULL) {
2456     return NULL;
2457   }
2458 #if PY_MAJOR_VERSION >= 3
2459   // On Python 3, using encoding='latin1' is required for unpickling
2460   // protos pickled by Python 2.
2461   if (!PyBytes_Check(serialized)) {
2462     serialized = PyUnicode_AsEncodedString(serialized, "latin1", NULL);
2463   }
2464 #endif
2465   if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) {
2466     return NULL;
2467   }
2468   Py_RETURN_NONE;
2469 }
2470 
2471 // CMessage static methods:
_CheckCalledFromGeneratedFile(PyObject * unused,PyObject * unused_arg)2472 PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
2473                                         PyObject* unused_arg) {
2474   if (!_CalledFromGeneratedFile(1)) {
2475     PyErr_SetString(PyExc_TypeError,
2476                     "Descriptors should not be created directly, "
2477                     "but only retrieved from their parent.");
2478     return NULL;
2479   }
2480   Py_RETURN_NONE;
2481 }
2482 
GetExtensionDict(CMessage * self,void * closure)2483 static PyObject* GetExtensionDict(CMessage* self, void *closure) {
2484   // If there are extension_ranges, the message is "extendable". Allocate a
2485   // dictionary to store the extension fields.
2486   const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
2487   if (!descriptor->extension_range_count()) {
2488     PyErr_SetNone(PyExc_AttributeError);
2489     return NULL;
2490   }
2491   if (!self->composite_fields) {
2492     self->composite_fields = new CMessage::CompositeFieldsMap();
2493   }
2494   if (!self->composite_fields) {
2495     return NULL;
2496   }
2497   ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
2498   return reinterpret_cast<PyObject*>(extension_dict);
2499 }
2500 
UnknownFieldSet(CMessage * self)2501 static PyObject* UnknownFieldSet(CMessage* self) {
2502   if (self->unknown_field_set == NULL) {
2503     self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
2504   } else {
2505     Py_INCREF(self->unknown_field_set);
2506   }
2507   return self->unknown_field_set;
2508 }
2509 
GetExtensionsByName(CMessage * self,void * closure)2510 static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
2511   return message_meta::GetExtensionsByName(
2512       reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2513 }
2514 
GetExtensionsByNumber(CMessage * self,void * closure)2515 static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
2516   return message_meta::GetExtensionsByNumber(
2517       reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2518 }
2519 
2520 static PyGetSetDef Getters[] = {
2521   {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
2522   {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
2523   {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
2524   {NULL}
2525 };
2526 
2527 
2528 static PyMethodDef Methods[] = {
2529   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
2530     "Makes a deep copy of the class." },
2531   { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
2532     "Outputs picklable representation of the message." },
2533   { "__setstate__", (PyCFunction)SetState, METH_O,
2534     "Inputs picklable representation of the message." },
2535   { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
2536     "Outputs a unicode representation of the message." },
2537   { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
2538     "Returns the size of the message in bytes." },
2539   { "Clear", (PyCFunction)Clear, METH_NOARGS,
2540     "Clears the message." },
2541   { "ClearExtension", (PyCFunction)ClearExtension, METH_O,
2542     "Clears a message field." },
2543   { "ClearField", (PyCFunction)ClearField, METH_O,
2544     "Clears a message field." },
2545   { "CopyFrom", (PyCFunction)CopyFrom, METH_O,
2546     "Copies a protocol message into the current message." },
2547   { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
2548     "Discards the unknown fields." },
2549   { "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
2550     METH_NOARGS,
2551     "Finds unset required fields." },
2552   { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
2553     "Creates new method instance from given serialized data." },
2554   { "HasExtension", (PyCFunction)HasExtension, METH_O,
2555     "Checks if a message field is set." },
2556   { "HasField", (PyCFunction)HasField, METH_O,
2557     "Checks if a message field is set." },
2558   { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
2559     "Checks if all required fields of a protocol message are set." },
2560   { "ListFields", (PyCFunction)ListFields, METH_NOARGS,
2561     "Lists all set fields of a message." },
2562   { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
2563     "Merges a protocol message into the current message." },
2564   { "MergeFromString", (PyCFunction)MergeFromString, METH_O,
2565     "Merges a serialized message into the current message." },
2566   { "ParseFromString", (PyCFunction)ParseFromString, METH_O,
2567     "Parses a serialized message into the current message." },
2568   { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
2569     "Registers an extension with the current message." },
2570   { "SerializePartialToString", (PyCFunction)SerializePartialToString,
2571     METH_VARARGS | METH_KEYWORDS,
2572     "Serializes the message to a string, even if it isn't initialized." },
2573   { "SerializeToString", (PyCFunction)SerializeToString,
2574     METH_VARARGS | METH_KEYWORDS,
2575     "Serializes the message to a string, only for initialized messages." },
2576   { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
2577     "Sets the has bit of the given field in its parent message." },
2578   { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS,
2579     "Parse unknown field set"},
2580   { "WhichOneof", (PyCFunction)WhichOneof, METH_O,
2581     "Returns the name of the field set inside a oneof, "
2582     "or None if no field is set." },
2583 
2584   // Static Methods.
2585   { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
2586     METH_NOARGS | METH_STATIC,
2587     "Raises TypeError if the caller is not in a _pb2.py file."},
2588   { NULL, NULL}
2589 };
2590 
SetCompositeField(CMessage * self,const FieldDescriptor * field,ContainerBase * value)2591 bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
2592                        ContainerBase* value) {
2593   if (self->composite_fields == NULL) {
2594     self->composite_fields = new CMessage::CompositeFieldsMap();
2595   }
2596   (*self->composite_fields)[field] = value;
2597   return true;
2598 }
2599 
SetSubmessage(CMessage * self,CMessage * submessage)2600 bool SetSubmessage(CMessage* self, CMessage* submessage) {
2601   if (self->child_submessages == NULL) {
2602     self->child_submessages = new CMessage::SubMessagesMap();
2603   }
2604   (*self->child_submessages)[submessage->message] = submessage;
2605   return true;
2606 }
2607 
GetAttr(PyObject * pself,PyObject * name)2608 PyObject* GetAttr(PyObject* pself, PyObject* name) {
2609   CMessage* self = reinterpret_cast<CMessage*>(pself);
2610   PyObject* result = PyObject_GenericGetAttr(
2611       reinterpret_cast<PyObject*>(self), name);
2612   if (result != NULL) {
2613     return result;
2614   }
2615   if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
2616     return NULL;
2617   }
2618 
2619   PyErr_Clear();
2620   return message_meta::GetClassAttribute(
2621       CheckMessageClass(Py_TYPE(self)), name);
2622 }
2623 
GetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor)2624 PyObject* GetFieldValue(CMessage* self,
2625                         const FieldDescriptor* field_descriptor) {
2626   if (self->composite_fields) {
2627     CMessage::CompositeFieldsMap::iterator it =
2628         self->composite_fields->find(field_descriptor);
2629     if (it != self->composite_fields->end()) {
2630       ContainerBase* value = it->second;
2631       Py_INCREF(value);
2632       return value->AsPyObject();
2633     }
2634   }
2635 
2636   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2637     PyErr_Format(PyExc_TypeError,
2638                  "descriptor to field '%s' doesn't apply to '%s' object",
2639                  field_descriptor->full_name().c_str(),
2640                  Py_TYPE(self)->tp_name);
2641     return NULL;
2642   }
2643 
2644   if (!field_descriptor->is_repeated() &&
2645       field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
2646     return InternalGetScalar(self->message, field_descriptor);
2647   }
2648 
2649   ContainerBase* py_container = nullptr;
2650   if (field_descriptor->is_map()) {
2651     const Descriptor* entry_type = field_descriptor->message_type();
2652     const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
2653     if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2654       CMessageClass* value_class = message_factory::GetMessageClass(
2655           GetFactoryForMessage(self), value_type->message_type());
2656       if (value_class == NULL) {
2657         return NULL;
2658       }
2659       py_container =
2660           NewMessageMapContainer(self, field_descriptor, value_class);
2661     } else {
2662       py_container = NewScalarMapContainer(self, field_descriptor);
2663     }
2664   } else if (field_descriptor->is_repeated()) {
2665     if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2666       CMessageClass* message_class = message_factory::GetMessageClass(
2667           GetFactoryForMessage(self), field_descriptor->message_type());
2668       if (message_class == NULL) {
2669         return NULL;
2670       }
2671       py_container = repeated_composite_container::NewContainer(
2672           self, field_descriptor, message_class);
2673     } else {
2674       py_container =
2675           repeated_scalar_container::NewContainer(self, field_descriptor);
2676     }
2677   } else if (field_descriptor->cpp_type() ==
2678              FieldDescriptor::CPPTYPE_MESSAGE) {
2679     py_container = InternalGetSubMessage(self, field_descriptor);
2680   } else {
2681     PyErr_SetString(PyExc_SystemError, "Should never happen");
2682   }
2683 
2684   if (py_container == NULL) {
2685     return NULL;
2686   }
2687   if (!SetCompositeField(self, field_descriptor, py_container)) {
2688     Py_DECREF(py_container);
2689     return NULL;
2690   }
2691   return py_container->AsPyObject();
2692 }
2693 
SetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * value)2694 int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
2695                   PyObject* value) {
2696   if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2697     PyErr_Format(PyExc_TypeError,
2698                  "descriptor to field '%s' doesn't apply to '%s' object",
2699                  field_descriptor->full_name().c_str(),
2700                  Py_TYPE(self)->tp_name);
2701     return -1;
2702   } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
2703     PyErr_Format(PyExc_AttributeError,
2704                  "Assignment not allowed to repeated "
2705                  "field \"%s\" in protocol message object.",
2706                  field_descriptor->name().c_str());
2707     return -1;
2708   } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2709     PyErr_Format(PyExc_AttributeError,
2710                  "Assignment not allowed to "
2711                  "field \"%s\" in protocol message object.",
2712                  field_descriptor->name().c_str());
2713     return -1;
2714   } else {
2715     AssureWritable(self);
2716     return InternalSetScalar(self, field_descriptor, value);
2717   }
2718 }
2719 
2720 }  // namespace cmessage
2721 
2722 // All containers which are not messages:
2723 // - Make a new parent message
2724 // - Copy the field
2725 // - return the field.
DeepCopy()2726 PyObject* ContainerBase::DeepCopy() {
2727   CMessage* new_parent =
2728       cmessage::NewEmptyMessage(this->parent->GetMessageClass());
2729   new_parent->message = this->parent->message->New();
2730 
2731   // Copy the map field into the new message.
2732   this->parent->message->GetReflection()->SwapFields(
2733       this->parent->message, new_parent->message,
2734       {this->parent_field_descriptor});
2735   this->parent->message->MergeFrom(*new_parent->message);
2736 
2737   PyObject* result =
2738       cmessage::GetFieldValue(new_parent, this->parent_field_descriptor);
2739   Py_DECREF(new_parent);
2740   return result;
2741 }
2742 
RemoveFromParentCache()2743 void ContainerBase::RemoveFromParentCache() {
2744   CMessage* parent = this->parent;
2745   if (parent) {
2746     if (parent->composite_fields)
2747       parent->composite_fields->erase(this->parent_field_descriptor);
2748     Py_CLEAR(parent);
2749   }
2750 }
2751 
BuildSubMessageFromPointer(const FieldDescriptor * field_descriptor,Message * sub_message,CMessageClass * message_class)2752 CMessage* CMessage::BuildSubMessageFromPointer(
2753     const FieldDescriptor* field_descriptor, Message* sub_message,
2754     CMessageClass* message_class) {
2755   if (!this->child_submessages) {
2756     this->child_submessages = new CMessage::SubMessagesMap();
2757   }
2758   CMessage* cmsg = FindPtrOrNull(
2759       *this->child_submessages, sub_message);
2760   if (cmsg) {
2761     Py_INCREF(cmsg);
2762   } else {
2763     cmsg = cmessage::NewEmptyMessage(message_class);
2764 
2765     if (cmsg == NULL) {
2766       return NULL;
2767     }
2768     cmsg->message = sub_message;
2769     Py_INCREF(this);
2770     cmsg->parent = this;
2771     cmsg->parent_field_descriptor = field_descriptor;
2772     cmessage::SetSubmessage(this, cmsg);
2773   }
2774   return cmsg;
2775 }
2776 
MaybeReleaseSubMessage(Message * sub_message)2777 CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
2778   if (!this->child_submessages) {
2779     return nullptr;
2780   }
2781   CMessage* released = FindPtrOrNull(
2782       *this->child_submessages, sub_message);
2783   if (!released) {
2784     return nullptr;
2785   }
2786   // The target message will now own its content.
2787   Py_CLEAR(released->parent);
2788   released->parent_field_descriptor = nullptr;
2789   released->read_only = false;
2790   // Delete it from the cache.
2791   this->child_submessages->erase(sub_message);
2792   return released;
2793 }
2794 
// Static type object for the base message class. Its metatype is
// _CMessageClass_Type (see PyVarObject_HEAD_INIT below), and it is exported
// to Python as "Message" by InitProto2MessageModule. The initializers are
// positional and must stay in CPython's PyTypeObject slot order.
static CMessageClass _CMessage_Type = { { {
  PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0)
  FULL_MODULE_NAME ".CMessage",        // tp_name
  sizeof(CMessage),                    // tp_basicsize
  0,                                   //  tp_itemsize
  (destructor)cmessage::Dealloc,       //  tp_dealloc
  0,                                   //  tp_print
  0,                                   //  tp_getattr
  0,                                   //  tp_setattr
  0,                                   //  tp_compare
  (reprfunc)cmessage::ToStr,           //  tp_repr (same function as tp_str)
  0,                                   //  tp_as_number
  0,                                   //  tp_as_sequence
  0,                                   //  tp_as_mapping
  PyObject_HashNotImplemented,         //  tp_hash: instances are unhashable
  0,                                   //  tp_call
  (reprfunc)cmessage::ToStr,           //  tp_str
  cmessage::GetAttr,                   //  tp_getattro: falls back to class attributes
  0,                                   //  tp_setattro
  0,                                   //  tp_as_buffer
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
      | Py_TPFLAGS_HAVE_VERSION_TAG,   //  tp_flags: subclassable; version tag requires PyType_Modified after tp_dict edits
  "A ProtocolMessage",                 //  tp_doc
  0,                                   //  tp_traverse
  0,                                   //  tp_clear
  (richcmpfunc)cmessage::RichCompare,  //  tp_richcompare
  offsetof(CMessage, weakreflist),     //  tp_weaklistoffset: instances support weak references
  0,                                   //  tp_iter
  0,                                   //  tp_iternext
  cmessage::Methods,                   //  tp_methods
  0,                                   //  tp_members
  cmessage::Getters,                   //  tp_getset
  0,                                   //  tp_base
  0,                                   //  tp_dict
  0,                                   //  tp_descr_get
  0,                                   //  tp_descr_set
  0,                                   //  tp_dictoffset
  (initproc)cmessage::Init,            //  tp_init
  0,                                   //  tp_alloc
  cmessage::New,                       //  tp_new
} } };
// Handle to the type object above, used elsewhere, e.g. in PyObject_TypeCheck.
PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
2837 
2838 // --- Exposing the C proto living inside Python proto to C code:
2839 
2840 const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
2841 Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
2842 
GetCProtoInsidePyProtoImpl(PyObject * msg)2843 static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
2844   const Message* message = PyMessage_GetMessagePointer(msg);
2845   if (message == NULL) {
2846     PyErr_Clear();
2847     return NULL;
2848   }
2849   return message;
2850 }
2851 
MutableCProtoInsidePyProtoImpl(PyObject * msg)2852 static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
2853   Message* message = PyMessage_GetMutableMessagePointer(msg);
2854   if (message == NULL) {
2855     PyErr_Clear();
2856     return NULL;
2857   }
2858   return message;
2859 }
2860 
PyMessage_GetMessagePointer(PyObject * msg)2861 const Message* PyMessage_GetMessagePointer(PyObject* msg) {
2862   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2863     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2864     return NULL;
2865   }
2866   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2867   return cmsg->message;
2868 }
2869 
PyMessage_GetMutableMessagePointer(PyObject * msg)2870 Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
2871   if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2872     PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2873     return NULL;
2874   }
2875   CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2876 
2877 
2878   if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) ||
2879       (cmsg->child_submessages && !cmsg->child_submessages->empty())) {
2880     // There is currently no way of accurately syncing arbitrary changes to
2881     // the underlying C++ message back to the CMessage (e.g. removed repeated
2882     // composite containers). We only allow direct mutation of the underlying
2883     // C++ message if there is no child data in the CMessage.
2884     PyErr_SetString(PyExc_ValueError,
2885                     "Cannot reliably get a mutable pointer "
2886                     "to a message with extra references");
2887     return NULL;
2888   }
2889   cmessage::AssureWritable(cmsg);
2890   return cmsg->message;
2891 }
2892 
PyMessage_NewMessageOwnedExternally(Message * message,PyObject * message_factory)2893 PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
2894                                               PyObject* message_factory) {
2895   if (message_factory) {
2896     PyErr_SetString(PyExc_NotImplementedError,
2897                     "Default message_factory=NULL is the only supported value");
2898     return NULL;
2899   }
2900   if (message->GetReflection()->GetMessageFactory() !=
2901       MessageFactory::generated_factory()) {
2902     PyErr_SetString(PyExc_TypeError,
2903                     "Message pointer was not created from the default factory");
2904     return NULL;
2905   }
2906 
2907   CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2908       GetDefaultDescriptorPool()->py_message_factory, message->GetDescriptor());
2909 
2910   CMessage* self = cmessage::NewEmptyMessage(message_class);
2911   if (self == NULL) {
2912     return NULL;
2913   }
2914   Py_DECREF(message_class);
2915   self->message = message;
2916   Py_INCREF(Py_None);
2917   self->parent = reinterpret_cast<CMessage*>(Py_None);
2918   return self->AsPyObject();
2919 }
2920 
InitGlobals()2921 void InitGlobals() {
2922   // TODO(gps): Check all return values in this function for NULL and propagate
2923   // the error (MemoryError) on up to result in an import failure.  These should
2924   // also be freed and reset to NULL during finalization.
2925   kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
2926 
2927   PyObject *dummy_obj = PySet_New(NULL);
2928   kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
2929   Py_DECREF(dummy_obj);
2930 }
2931 
InitProto2MessageModule(PyObject * m)2932 bool InitProto2MessageModule(PyObject *m) {
2933   // Initialize types and globals in descriptor.cc
2934   if (!InitDescriptor()) {
2935     return false;
2936   }
2937 
2938   // Initialize types and globals in descriptor_pool.cc
2939   if (!InitDescriptorPool()) {
2940     return false;
2941   }
2942 
2943   // Initialize types and globals in message_factory.cc
2944   if (!InitMessageFactory()) {
2945     return false;
2946   }
2947 
2948   // Initialize constants defined in this file.
2949   InitGlobals();
2950 
2951   CMessageClass_Type->tp_base = &PyType_Type;
2952   if (PyType_Ready(CMessageClass_Type) < 0) {
2953     return false;
2954   }
2955   PyModule_AddObject(m, "MessageMeta",
2956                      reinterpret_cast<PyObject*>(CMessageClass_Type));
2957 
2958   if (PyType_Ready(CMessage_Type) < 0) {
2959     return false;
2960   }
2961   if (PyType_Ready(CFieldProperty_Type) < 0) {
2962     return false;
2963   }
2964 
2965   // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
2966   // it here as well to document that subclasses need to set it.
2967   PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
2968   // Invalidate any cached data for the CMessage type.
2969   // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
2970   // after we have modified CMessage_Type.tp_dict.
2971   PyType_Modified(CMessage_Type);
2972 
2973   PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
2974 
2975   // Initialize Repeated container types.
2976   {
2977     if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
2978       return false;
2979     }
2980 
2981     PyModule_AddObject(m, "RepeatedScalarContainer",
2982                        reinterpret_cast<PyObject*>(
2983                            &RepeatedScalarContainer_Type));
2984 
2985     if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
2986       return false;
2987     }
2988 
2989     PyModule_AddObject(
2990         m, "RepeatedCompositeContainer",
2991         reinterpret_cast<PyObject*>(
2992             &RepeatedCompositeContainer_Type));
2993 
2994     // Register them as collections.Sequence
2995     ScopedPyObjectPtr collections(PyImport_ImportModule("collections"));
2996     if (collections == NULL) {
2997       return false;
2998     }
2999     ScopedPyObjectPtr mutable_sequence(
3000         PyObject_GetAttrString(collections.get(), "MutableSequence"));
3001     if (mutable_sequence == NULL) {
3002       return false;
3003     }
3004     if (ScopedPyObjectPtr(
3005             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
3006                                 &RepeatedScalarContainer_Type)) == NULL) {
3007       return false;
3008     }
3009     if (ScopedPyObjectPtr(
3010             PyObject_CallMethod(mutable_sequence.get(), "register", "O",
3011                                 &RepeatedCompositeContainer_Type)) == NULL) {
3012       return false;
3013     }
3014   }
3015 
3016   if (PyType_Ready(&PyUnknownFields_Type) < 0) {
3017     return false;
3018   }
3019 
3020   PyModule_AddObject(m, "UnknownFieldSet",
3021                      reinterpret_cast<PyObject*>(
3022                          &PyUnknownFields_Type));
3023 
3024   if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
3025     return false;
3026   }
3027 
3028   PyModule_AddObject(m, "UnknownField",
3029                      reinterpret_cast<PyObject*>(
3030                          &PyUnknownFieldRef_Type));
3031 
3032   // Initialize Map container types.
3033   if (!InitMapContainers()) {
3034     return false;
3035   }
3036   PyModule_AddObject(m, "ScalarMapContainer",
3037                      reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
3038   PyModule_AddObject(m, "MessageMapContainer",
3039                      reinterpret_cast<PyObject*>(MessageMapContainer_Type));
3040   PyModule_AddObject(m, "MapIterator",
3041                      reinterpret_cast<PyObject*>(&MapIterator_Type));
3042 
3043   if (PyType_Ready(&ExtensionDict_Type) < 0) {
3044     return false;
3045   }
3046   PyModule_AddObject(
3047       m, "ExtensionDict",
3048       reinterpret_cast<PyObject*>(&ExtensionDict_Type));
3049   if (PyType_Ready(&ExtensionIterator_Type) < 0) {
3050     return false;
3051   }
3052   PyModule_AddObject(m, "ExtensionIterator",
3053                      reinterpret_cast<PyObject*>(&ExtensionIterator_Type));
3054 
3055   // Expose the DescriptorPool used to hold all descriptors added from generated
3056   // pb2.py files.
3057   // PyModule_AddObject steals a reference.
3058   Py_INCREF(GetDefaultDescriptorPool());
3059   PyModule_AddObject(m, "default_pool",
3060                      reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
3061 
3062   PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
3063       &PyDescriptorPool_Type));
3064 
3065   // This implementation provides full Descriptor types, we advertise it so that
3066   // descriptor.py can use them in replacement of the Python classes.
3067   PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1);
3068 
3069   PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>(
3070       &PyMessageDescriptor_Type));
3071   PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>(
3072       &PyFieldDescriptor_Type));
3073   PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>(
3074       &PyEnumDescriptor_Type));
3075   PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>(
3076       &PyEnumValueDescriptor_Type));
3077   PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>(
3078       &PyFileDescriptor_Type));
3079   PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
3080       &PyOneofDescriptor_Type));
3081   PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
3082       &PyServiceDescriptor_Type));
3083   PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
3084       &PyMethodDescriptor_Type));
3085 
3086   PyObject* enum_type_wrapper = PyImport_ImportModule(
3087       "google.protobuf.internal.enum_type_wrapper");
3088   if (enum_type_wrapper == NULL) {
3089     return false;
3090   }
3091   EnumTypeWrapper_class =
3092       PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
3093   Py_DECREF(enum_type_wrapper);
3094 
3095   PyObject* message_module = PyImport_ImportModule(
3096       "google.protobuf.message");
3097   if (message_module == NULL) {
3098     return false;
3099   }
3100   EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
3101   DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
3102   PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
3103   Py_DECREF(message_module);
3104 
3105   PyObject* pickle_module = PyImport_ImportModule("pickle");
3106   if (pickle_module == NULL) {
3107     return false;
3108   }
3109   PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
3110   Py_DECREF(pickle_module);
3111 
3112   // Override {Get,Mutable}CProtoInsidePyProto.
3113   GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
3114   MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
3115 
3116   return true;
3117 }
3118 
3119 }  // namespace python
3120 }  // namespace protobuf
3121 }  // namespace google
3122