1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33
34 #include <google/protobuf/pyext/message.h>
35
36 #include <map>
37 #include <memory>
38 #include <string>
39 #include <vector>
40 #include <structmember.h> // A Python header file.
41
// Fallback definitions, used only when the Python headers included above do
// not already provide these macros.
#ifndef PyVarObject_HEAD_INIT
// Expands to the generic object header initializer followed by the size field.
#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
#endif
#ifndef Py_TYPE
// Reads the type pointer of an arbitrary object pointer.
#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
#endif
48 #include <google/protobuf/stubs/common.h>
49 #include <google/protobuf/stubs/logging.h>
50 #include <google/protobuf/io/strtod.h>
51 #include <google/protobuf/io/coded_stream.h>
52 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
53 #include <google/protobuf/descriptor.pb.h>
54 #include <google/protobuf/descriptor.h>
55 #include <google/protobuf/message.h>
56 #include <google/protobuf/text_format.h>
57 #include <google/protobuf/unknown_field_set.h>
58 #include <google/protobuf/pyext/descriptor.h>
59 #include <google/protobuf/pyext/descriptor_pool.h>
60 #include <google/protobuf/pyext/extension_dict.h>
61 #include <google/protobuf/pyext/field.h>
62 #include <google/protobuf/pyext/map_container.h>
63 #include <google/protobuf/pyext/message_factory.h>
64 #include <google/protobuf/pyext/repeated_composite_container.h>
65 #include <google/protobuf/pyext/repeated_scalar_container.h>
66 #include <google/protobuf/pyext/unknown_fields.h>
67 #include <google/protobuf/pyext/safe_numerics.h>
68 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
69 #include <google/protobuf/util/message_differencer.h>
70 #include <google/protobuf/stubs/strutil.h>
71 #include <google/protobuf/stubs/map_util.h>
72
73 #include <google/protobuf/port_def.inc>
74
// Python 3 compatibility shim: the rest of this file is written against the
// Python 2 style PyInt_* / PyString_* names, which are remapped here to their
// PyLong_* / PyUnicode_* counterparts.
#if PY_MAJOR_VERSION >= 3
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
#define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_FromFormat PyUnicode_FromFormat
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#else
// Accepts either a unicode or a bytes object; unicode is returned as UTF-8.
// Note that `ob` is evaluated more than once.
#define PyString_AsString(ob) \
  (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
// Same dual handling as above, but also reports the length through `sizep`.
// For unicode input the expression yields -1 when the UTF-8 conversion
// returns NULL and 0 otherwise; for other input it yields whatever
// PyBytes_AsStringAndSize() returns.
#define PyString_AsStringAndSize(ob, charpp, sizep) \
  (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
                               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
                              ? -1 \
                              : 0) \
                       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
#endif
96
97 namespace google {
98 namespace protobuf {
99 namespace python {
100
// Key object used to look up the 'DESCRIPTOR' entry of a class dict.
// NOTE(review): assigned outside the visible part of this file.
static PyObject* kDESCRIPTOR;
// Callable invoked with an enum descriptor to build the wrapper object that is
// stored on the message class for each nested enum type.
PyObject* EnumTypeWrapper_class;
// The Python-level Message base class; the only base a new class may name.
static PyObject* PythonMessage_class;
// Placeholder object appended to tp_subclasses by InsertEmptyWeakref().
static PyObject* kEmptyWeakref;
// Mapping from message full name to an extra base class; loaded lazily from
// google.protobuf.internal.well_known_types.WKTBASES on first class creation.
static PyObject* WKT_classes = NULL;
106
107 namespace message_meta {
108
// Forward declaration; the definition appears later in this namespace, after
// New(), which calls it.
static int InsertEmptyWeakref(PyTypeObject* base);
110
111 namespace {
112 // Copied over from internal 'google/protobuf/stubs/strutil.h'.
// Lowercases *s in place, touching only the ASCII letters 'A'..'Z'; every
// other byte is left as it is. tolower() is deliberately not used because its
// result changes with the locale.
inline void LowerString(string * s) {
  for (char& ch : *s) {
    if (ch >= 'A' && ch <= 'Z') {
      ch += 'a' - 'A';
    }
  }
}
120 }
121
122 // Finalize the creation of the Message class.
AddDescriptors(PyObject * cls,const Descriptor * descriptor)123 static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
124 // For each field set: cls.<field>_FIELD_NUMBER = <number>
125 for (int i = 0; i < descriptor->field_count(); ++i) {
126 const FieldDescriptor* field_descriptor = descriptor->field(i);
127 ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
128 if (property == NULL) {
129 return -1;
130 }
131 if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(),
132 property.get()) < 0) {
133 return -1;
134 }
135 }
136
137 // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
138 for (int i = 0; i < descriptor->enum_type_count(); ++i) {
139 const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
140 ScopedPyObjectPtr enum_type(
141 PyEnumDescriptor_FromDescriptor(enum_descriptor));
142 if (enum_type == NULL) {
143 return -1;
144 }
145 // Add wrapped enum type to message class.
146 ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
147 EnumTypeWrapper_class, enum_type.get(), NULL));
148 if (wrapped == NULL) {
149 return -1;
150 }
151 if (PyObject_SetAttrString(
152 cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) {
153 return -1;
154 }
155
156 // For each enum value add cls.<name> = <number>
157 for (int j = 0; j < enum_descriptor->value_count(); ++j) {
158 const EnumValueDescriptor* enum_value_descriptor =
159 enum_descriptor->value(j);
160 ScopedPyObjectPtr value_number(PyInt_FromLong(
161 enum_value_descriptor->number()));
162 if (value_number == NULL) {
163 return -1;
164 }
165 if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(),
166 value_number.get()) == -1) {
167 return -1;
168 }
169 }
170 }
171
172 // For each extension set cls.<extension name> = <extension descriptor>.
173 //
174 // Extension descriptors come from
175 // <message descriptor>.extensions_by_name[name]
176 // which was defined previously.
177 for (int i = 0; i < descriptor->extension_count(); ++i) {
178 const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
179 ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
180 if (extension_field == NULL) {
181 return -1;
182 }
183
184 // Add the extension field to the message class.
185 if (PyObject_SetAttrString(
186 cls, field->name().c_str(), extension_field.get()) == -1) {
187 return -1;
188 }
189 }
190
191 return 0;
192 }
193
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)194 static PyObject* New(PyTypeObject* type,
195 PyObject* args, PyObject* kwargs) {
196 static char *kwlist[] = {"name", "bases", "dict", 0};
197 PyObject *bases, *dict;
198 const char* name;
199
200 // Check arguments: (name, bases, dict)
201 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist,
202 &name,
203 &PyTuple_Type, &bases,
204 &PyDict_Type, &dict)) {
205 return NULL;
206 }
207
208 // Check bases: only (), or (message.Message,) are allowed
209 if (!(PyTuple_GET_SIZE(bases) == 0 ||
210 (PyTuple_GET_SIZE(bases) == 1 &&
211 PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
212 PyErr_SetString(PyExc_TypeError,
213 "A Message class can only inherit from Message");
214 return NULL;
215 }
216
217 // Check dict['DESCRIPTOR']
218 PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
219 if (py_descriptor == NULL) {
220 PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
221 return NULL;
222 }
223 if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) {
224 PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s",
225 py_descriptor->ob_type->tp_name);
226 return NULL;
227 }
228
229 // Messages have no __dict__
230 ScopedPyObjectPtr slots(PyTuple_New(0));
231 if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
232 return NULL;
233 }
234
235 // Build the arguments to the base metaclass.
236 // We change the __bases__ classes.
237 ScopedPyObjectPtr new_args;
238 const Descriptor* message_descriptor =
239 PyMessageDescriptor_AsDescriptor(py_descriptor);
240 if (message_descriptor == NULL) {
241 return NULL;
242 }
243
244 if (WKT_classes == NULL) {
245 ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
246 "google.protobuf.internal.well_known_types"));
247 GOOGLE_DCHECK(well_known_types != NULL);
248
249 WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
250 GOOGLE_DCHECK(WKT_classes != NULL);
251 }
252
253 PyObject* well_known_class = PyDict_GetItemString(
254 WKT_classes, message_descriptor->full_name().c_str());
255 if (well_known_class == NULL) {
256 new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
257 PythonMessage_class, dict));
258 } else {
259 new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
260 PythonMessage_class, well_known_class, dict));
261 }
262
263 if (new_args == NULL) {
264 return NULL;
265 }
266 // Call the base metaclass.
267 ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL));
268 if (result == NULL) {
269 return NULL;
270 }
271 CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
272
273 // Insert the empty weakref into the base classes.
274 if (InsertEmptyWeakref(
275 reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
276 InsertEmptyWeakref(CMessage_Type) < 0) {
277 return NULL;
278 }
279
280 // Cache the descriptor, both as Python object and as C++ pointer.
281 const Descriptor* descriptor =
282 PyMessageDescriptor_AsDescriptor(py_descriptor);
283 if (descriptor == NULL) {
284 return NULL;
285 }
286 Py_INCREF(py_descriptor);
287 newtype->py_message_descriptor = py_descriptor;
288 newtype->message_descriptor = descriptor;
289 // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
290 // use the MessageFactory optionally passed in the class dict.
291 PyDescriptorPool* py_descriptor_pool =
292 GetDescriptorPool_FromPool(descriptor->file()->pool());
293 if (py_descriptor_pool == NULL) {
294 return NULL;
295 }
296 newtype->py_message_factory = py_descriptor_pool->py_message_factory;
297 Py_INCREF(newtype->py_message_factory);
298
299 // Register the message in the MessageFactory.
300 // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
301 // MessageFactory is fully implemented in C++.
302 if (message_factory::RegisterMessageClass(newtype->py_message_factory,
303 descriptor, newtype) < 0) {
304 return NULL;
305 }
306
307 // Continue with type initialization: add other descriptors, enum values...
308 if (AddDescriptors(result.get(), descriptor) < 0) {
309 return NULL;
310 }
311 return result.release();
312 }
313
Dealloc(PyObject * pself)314 static void Dealloc(PyObject* pself) {
315 CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
316 Py_XDECREF(self->py_message_descriptor);
317 Py_XDECREF(self->py_message_factory);
318 return PyType_Type.tp_dealloc(pself);
319 }
320
GcTraverse(PyObject * pself,visitproc visit,void * arg)321 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
322 CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
323 Py_VISIT(self->py_message_descriptor);
324 Py_VISIT(self->py_message_factory);
325 return PyType_Type.tp_traverse(pself, visit, arg);
326 }
327
GcClear(PyObject * pself)328 static int GcClear(PyObject* pself) {
329 // It's important to keep the descriptor and factory alive, until the
330 // C++ message is fully destructed.
331 return PyType_Type.tp_clear(pself);
332 }
333
// This function inserts an empty weakref at the end of the list of
// subclasses for the main protocol buffer Message class.
//
// This eliminates an O(n^2) behaviour in the internal add_subclass
// routine.
InsertEmptyWeakref(PyTypeObject * base_type)339 static int InsertEmptyWeakref(PyTypeObject *base_type) {
340 #if PY_MAJOR_VERSION >= 3
341 // Python 3.4 has already included the fix for the issue that this
342 // hack addresses. For further background and the fix please see
343 // https://bugs.python.org/issue17936.
344 return 0;
345 #else
346 #ifdef Py_DEBUG
347 // The code below causes all new subclasses to append an entry, which is never
348 // cleared. This is a small memory leak, which we disable in Py_DEBUG mode
349 // to have stable refcounting checks.
350 #else
351 PyObject *subclasses = base_type->tp_subclasses;
352 if (subclasses && PyList_CheckExact(subclasses)) {
353 return PyList_Append(subclasses, kEmptyWeakref);
354 }
355 #endif // !Py_DEBUG
356 return 0;
357 #endif // PY_MAJOR_VERSION >= 3
358 }
359
360 // The _extensions_by_name dictionary is built on every access.
361 // TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
GetExtensionsByName(CMessageClass * self,void * closure)362 static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
363 if (self->message_descriptor == NULL) {
364 // This is the base Message object, simply raise AttributeError.
365 PyErr_SetString(PyExc_AttributeError,
366 "Base Message class has no DESCRIPTOR");
367 return NULL;
368 }
369
370 const PyDescriptorPool* pool = self->py_message_factory->pool;
371
372 std::vector<const FieldDescriptor*> extensions;
373 pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
374
375 ScopedPyObjectPtr result(PyDict_New());
376 for (int i = 0; i < extensions.size(); i++) {
377 ScopedPyObjectPtr extension(
378 PyFieldDescriptor_FromDescriptor(extensions[i]));
379 if (extension == NULL) {
380 return NULL;
381 }
382 if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
383 extension.get()) < 0) {
384 return NULL;
385 }
386 }
387 return result.release();
388 }
389
390 // The _extensions_by_number dictionary is built on every access.
391 // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
GetExtensionsByNumber(CMessageClass * self,void * closure)392 static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
393 if (self->message_descriptor == NULL) {
394 // This is the base Message object, simply raise AttributeError.
395 PyErr_SetString(PyExc_AttributeError,
396 "Base Message class has no DESCRIPTOR");
397 return NULL;
398 }
399
400 const PyDescriptorPool* pool = self->py_message_factory->pool;
401
402 std::vector<const FieldDescriptor*> extensions;
403 pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
404
405 ScopedPyObjectPtr result(PyDict_New());
406 for (int i = 0; i < extensions.size(); i++) {
407 ScopedPyObjectPtr extension(
408 PyFieldDescriptor_FromDescriptor(extensions[i]));
409 if (extension == NULL) {
410 return NULL;
411 }
412 ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
413 if (number == NULL) {
414 return NULL;
415 }
416 if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
417 return NULL;
418 }
419 }
420 return result.release();
421 }
422
// Computed attributes installed through tp_getset of the metaclass. Both are
// read-only (no setter is given) and rebuild their dictionary on every access.
// The table is terminated by an entry whose name is NULL.
static PyGetSetDef Getters[] = {
  {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
  {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
  {NULL}
};
428
429 // Compute some class attributes on the fly:
430 // - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
431 // Returns a new reference, or NULL with an exception set.
GetClassAttribute(CMessageClass * self,PyObject * name)432 static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
433 char* attr;
434 Py_ssize_t attr_size;
435 static const char kSuffix[] = "_FIELD_NUMBER";
436 if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
437 strings::EndsWith(StringPiece(attr, attr_size), kSuffix)) {
438 string field_name(attr, attr_size - sizeof(kSuffix) + 1);
439 LowerString(&field_name);
440
441 // Try to find a field with the given name, without the suffix.
442 const FieldDescriptor* field =
443 self->message_descriptor->FindFieldByLowercaseName(field_name);
444 if (!field) {
445 // Search nested extensions as well.
446 field =
447 self->message_descriptor->FindExtensionByLowercaseName(field_name);
448 }
449 if (field) {
450 return PyInt_FromLong(field->number());
451 }
452 }
453 PyErr_SetObject(PyExc_AttributeError, name);
454 return NULL;
455 }
456
GetAttr(CMessageClass * self,PyObject * name)457 static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
458 PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
459 reinterpret_cast<PyObject*>(self), name);
460 if (result != NULL) {
461 return result;
462 }
463 if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
464 return NULL;
465 }
466
467 PyErr_Clear();
468 return GetClassAttribute(self, name);
469 }
470
471 } // namespace message_meta
472
// Static type object of the metaclass used for message classes. Its own type
// is `type` (see the header initializer). Only the slots implemented in
// message_meta are filled in; everything else is left at 0.
// NOTE(review): tp_base is 0 here although GetAttr() reads
// CMessageClass_Type->tp_base, so it is presumably assigned before the type is
// readied elsewhere -- confirm in the module initialization code.
static PyTypeObject _CMessageClass_Type = {
  PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
  ".MessageMeta",                       // tp_name
  sizeof(CMessageClass),                // tp_basicsize
  0,                                    // tp_itemsize
  message_meta::Dealloc,                // tp_dealloc
  0,                                    // tp_print
  0,                                    // tp_getattr
  0,                                    // tp_setattr
  0,                                    // tp_compare
  0,                                    // tp_repr
  0,                                    // tp_as_number
  0,                                    // tp_as_sequence
  0,                                    // tp_as_mapping
  0,                                    // tp_hash
  0,                                    // tp_call
  0,                                    // tp_str
  (getattrofunc)message_meta::GetAttr,  // tp_getattro
  0,                                    // tp_setattro
  0,                                    // tp_as_buffer
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
  "The metaclass of ProtocolMessages",  // tp_doc
  message_meta::GcTraverse,             // tp_traverse
  message_meta::GcClear,                // tp_clear
  0,                                    // tp_richcompare
  0,                                    // tp_weaklistoffset
  0,                                    // tp_iter
  0,                                    // tp_iternext
  0,                                    // tp_methods
  0,                                    // tp_members
  message_meta::Getters,                // tp_getset
  0,                                    // tp_base
  0,                                    // tp_dict
  0,                                    // tp_descr_get
  0,                                    // tp_descr_set
  0,                                    // tp_dictoffset
  0,                                    // tp_init
  0,                                    // tp_alloc
  message_meta::New,                    // tp_new
};
// Non-static handle through which the rest of the code refers to the type.
PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
514
CheckMessageClass(PyTypeObject * cls)515 static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
516 if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
517 PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
518 return NULL;
519 }
520 return reinterpret_cast<CMessageClass*>(cls);
521 }
522
GetMessageDescriptor(PyTypeObject * cls)523 static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
524 CMessageClass* type = CheckMessageClass(cls);
525 if (type == NULL) {
526 return NULL;
527 }
528 return type->message_descriptor;
529 }
530
531 // Forward declarations
namespace cmessage {
// Declared early because MaybeReleaseOverlappingOneofField() calls it; the
// definition is not part of the code shown before that call site. Callers
// treat a negative return value as failure.
int InternalReleaseFieldByDescriptor(
    CMessage* self,
    const FieldDescriptor* field_descriptor);
}  // namespace cmessage
537
538 // ---------------------------------------------------------------------
539
// Python objects with external linkage; they are only defined here and are
// not assigned anywhere in the visible part of this file.
// NOTE(review): going by the names these hold the EncodeError, DecodeError and
// PickleError exception classes -- confirm where they are initialized.
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
543
544 // Format an error message for unexpected types.
545 // Always return with an exception set.
FormatTypeError(PyObject * arg,char * expected_types)546 void FormatTypeError(PyObject* arg, char* expected_types) {
547 // This function is often called with an exception set.
548 // Clear it to call PyObject_Repr() in good conditions.
549 PyErr_Clear();
550 PyObject* repr = PyObject_Repr(arg);
551 if (repr) {
552 PyErr_Format(PyExc_TypeError,
553 "%.100s has type %.100s, but expected one of: %s",
554 PyString_AsString(repr),
555 Py_TYPE(arg)->tp_name,
556 expected_types);
557 Py_DECREF(repr);
558 }
559 }
560
OutOfRangeError(PyObject * arg)561 void OutOfRangeError(PyObject* arg) {
562 PyObject *s = PyObject_Str(arg);
563 if (s) {
564 PyErr_Format(PyExc_ValueError,
565 "Value out of range: %s",
566 PyString_AsString(s));
567 Py_DECREF(s);
568 }
569 }
570
571 template<class RangeType, class ValueType>
VerifyIntegerCastAndRange(PyObject * arg,ValueType value)572 bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
573 if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
574 if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
575 // Replace it with the same ValueError as pure python protos instead of
576 // the default one.
577 PyErr_Clear();
578 OutOfRangeError(arg);
579 } // Otherwise propagate existing error.
580 return false;
581 }
582 if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
583 OutOfRangeError(arg);
584 return false;
585 }
586 return true;
587 }
588
589 template<class T>
CheckAndGetInteger(PyObject * arg,T * value)590 bool CheckAndGetInteger(PyObject* arg, T* value) {
591 // The fast path.
592 #if PY_MAJOR_VERSION < 3
593 // For the typical case, offer a fast path.
594 if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) {
595 long int_result = PyInt_AsLong(arg);
596 if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
597 *value = static_cast<T>(int_result);
598 return true;
599 } else {
600 OutOfRangeError(arg);
601 return false;
602 }
603 }
604 #endif
605 // This effectively defines an integer as "an object that can be cast as
606 // an integer and can be used as an ordinal number".
607 // This definition includes everything that implements numbers.Integral
608 // and shouldn't cast the net too wide.
609 if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) {
610 FormatTypeError(arg, "int, long");
611 return false;
612 }
613
614 // Now we have an integral number so we can safely use PyLong_ functions.
615 // We need to treat the signed and unsigned cases differently in case arg is
616 // holding a value above the maximum for signed longs.
617 if (std::numeric_limits<T>::min() == 0) {
618 // Unsigned case.
619 unsigned PY_LONG_LONG ulong_result;
620 if (PyLong_Check(arg)) {
621 ulong_result = PyLong_AsUnsignedLongLong(arg);
622 } else {
623 // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
624 // picky about the exact type.
625 PyObject* casted = PyNumber_Long(arg);
626 if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
627 // Propagate existing error.
628 return false;
629 }
630 ulong_result = PyLong_AsUnsignedLongLong(casted);
631 Py_DECREF(casted);
632 }
633 if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
634 ulong_result)) {
635 *value = static_cast<T>(ulong_result);
636 } else {
637 return false;
638 }
639 } else {
640 // Signed case.
641 PY_LONG_LONG long_result;
642 PyNumberMethods *nb;
643 if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
644 // PyLong_AsLongLong requires it to be a long or to have an __int__()
645 // method.
646 long_result = PyLong_AsLongLong(arg);
647 } else {
648 // Valid subclasses of numbers.Integral should have a __long__() method
649 // so fall back to that.
650 PyObject* casted = PyNumber_Long(arg);
651 if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
652 // Propagate existing error.
653 return false;
654 }
655 long_result = PyLong_AsLongLong(casted);
656 Py_DECREF(casted);
657 }
658 if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
659 *value = static_cast<T>(long_result);
660 } else {
661 return false;
662 }
663 }
664
665 return true;
666 }
667
668 // These are referenced by repeated_scalar_container, and must
669 // be explicitly instantiated.
// One explicit instantiation per supported integer type: signed and unsigned,
// 32 and 64 bits.
template bool CheckAndGetInteger<int32>(PyObject*, int32*);
template bool CheckAndGetInteger<int64>(PyObject*, int64*);
template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
674
CheckAndGetDouble(PyObject * arg,double * value)675 bool CheckAndGetDouble(PyObject* arg, double* value) {
676 *value = PyFloat_AsDouble(arg);
677 if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
678 FormatTypeError(arg, "int, long, float");
679 return false;
680 }
681 return true;
682 }
683
CheckAndGetFloat(PyObject * arg,float * value)684 bool CheckAndGetFloat(PyObject* arg, float* value) {
685 double double_value;
686 if (!CheckAndGetDouble(arg, &double_value)) {
687 return false;
688 }
689 *value = io::SafeDoubleToFloat(double_value);
690 return true;
691 }
692
CheckAndGetBool(PyObject * arg,bool * value)693 bool CheckAndGetBool(PyObject* arg, bool* value) {
694 long long_value = PyInt_AsLong(arg);
695 if (long_value == -1 && PyErr_Occurred()) {
696 FormatTypeError(arg, "int, long, bool");
697 return false;
698 }
699 *value = static_cast<bool>(long_value);
700
701 return true;
702 }
703
704 // Checks whether the given object (which must be "bytes" or "unicode") contains
705 // valid UTF-8.
IsValidUTF8(PyObject * obj)706 bool IsValidUTF8(PyObject* obj) {
707 if (PyBytes_Check(obj)) {
708 PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
709
710 // Clear the error indicator; we report our own error when desired.
711 PyErr_Clear();
712
713 if (unicode) {
714 Py_DECREF(unicode);
715 return true;
716 } else {
717 return false;
718 }
719 } else {
720 // Unicode object, known to be valid UTF-8.
721 return true;
722 }
723 }
724
// Policy hook consulted by CheckString(): invalid UTF-8 is never accepted,
// whatever the field.
bool AllowInvalidUTF8(const FieldDescriptor* field) {
  return false;
}
726
CheckString(PyObject * arg,const FieldDescriptor * descriptor)727 PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
728 GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
729 descriptor->type() == FieldDescriptor::TYPE_BYTES);
730 if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
731 if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
732 FormatTypeError(arg, "bytes, unicode");
733 return NULL;
734 }
735
736 if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
737 PyObject* repr = PyObject_Repr(arg);
738 PyErr_Format(PyExc_ValueError,
739 "%s has type str, but isn't valid UTF-8 "
740 "encoding. Non-UTF-8 strings must be converted to "
741 "unicode objects before being added.",
742 PyString_AsString(repr));
743 Py_DECREF(repr);
744 return NULL;
745 }
746 } else if (!PyBytes_Check(arg)) {
747 FormatTypeError(arg, "bytes");
748 return NULL;
749 }
750
751 PyObject* encoded_string = NULL;
752 if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
753 if (PyBytes_Check(arg)) {
754 // The bytes were already validated as correctly encoded UTF-8 above.
755 encoded_string = arg; // Already encoded.
756 Py_INCREF(encoded_string);
757 } else {
758 encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
759 }
760 } else {
761 // In this case field type is "bytes".
762 encoded_string = arg;
763 Py_INCREF(encoded_string);
764 }
765
766 return encoded_string;
767 }
768
CheckAndSetString(PyObject * arg,Message * message,const FieldDescriptor * descriptor,const Reflection * reflection,bool append,int index)769 bool CheckAndSetString(
770 PyObject* arg, Message* message,
771 const FieldDescriptor* descriptor,
772 const Reflection* reflection,
773 bool append,
774 int index) {
775 ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
776
777 if (encoded_string.get() == NULL) {
778 return false;
779 }
780
781 char* value;
782 Py_ssize_t value_len;
783 if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
784 return false;
785 }
786
787 string value_string(value, value_len);
788 if (append) {
789 reflection->AddString(message, descriptor, value_string);
790 } else if (index < 0) {
791 reflection->SetString(message, descriptor, value_string);
792 } else {
793 reflection->SetRepeatedString(message, descriptor, index, value_string);
794 }
795 return true;
796 }
797
ToStringObject(const FieldDescriptor * descriptor,const string & value)798 PyObject* ToStringObject(const FieldDescriptor* descriptor,
799 const string& value) {
800 if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
801 return PyBytes_FromStringAndSize(value.c_str(), value.length());
802 }
803
804 PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL);
805 // If the string can't be decoded in UTF-8, just return a string object that
806 // contains the raw bytes. This can't happen if the value was assigned using
807 // the members of the Python message object, but can happen if the values were
808 // parsed from the wire (binary).
809 if (result == NULL) {
810 PyErr_Clear();
811 result = PyBytes_FromStringAndSize(value.c_str(), value.length());
812 }
813 return result;
814 }
815
CheckFieldBelongsToMessage(const FieldDescriptor * field_descriptor,const Message * message)816 bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
817 const Message* message) {
818 if (message->GetDescriptor() == field_descriptor->containing_type()) {
819 return true;
820 }
821 PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
822 field_descriptor->full_name().c_str(),
823 message->GetDescriptor()->full_name().c_str());
824 return false;
825 }
826
827 namespace cmessage {
828
GetFactoryForMessage(CMessage * message)829 PyMessageFactory* GetFactoryForMessage(CMessage* message) {
830 GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
831 return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
832 }
833
MaybeReleaseOverlappingOneofField(CMessage * cmessage,const FieldDescriptor * field)834 static int MaybeReleaseOverlappingOneofField(
835 CMessage* cmessage,
836 const FieldDescriptor* field) {
837 #ifdef GOOGLE_PROTOBUF_HAS_ONEOF
838 Message* message = cmessage->message;
839 const Reflection* reflection = message->GetReflection();
840 if (!field->containing_oneof() ||
841 !reflection->HasOneof(*message, field->containing_oneof()) ||
842 reflection->HasField(*message, field)) {
843 // No other field in this oneof, no need to release.
844 return 0;
845 }
846
847 const OneofDescriptor* oneof = field->containing_oneof();
848 const FieldDescriptor* existing_field =
849 reflection->GetOneofFieldDescriptor(*message, oneof);
850 if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
851 // Non-message fields don't need to be released.
852 return 0;
853 }
854 if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) {
855 return -1;
856 }
857 #endif
858 return 0;
859 }
860
// After a Merge, visit every sub-message that was read-only and, if the Merge
// operation populated it, update its pointer to the now-mutable message.
FixupMessageAfterMerge(CMessage * self)863 int FixupMessageAfterMerge(CMessage* self) {
864 if (!self->composite_fields) {
865 return 0;
866 }
867 for (const auto& item : *self->composite_fields) {
868 const FieldDescriptor* descriptor = item.first;
869 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
870 !descriptor->is_repeated()) {
871 CMessage* cmsg = reinterpret_cast<CMessage*>(item.second);
872 if (cmsg->read_only == false) {
873 return 0;
874 }
875 Message* message = self->message;
876 const Reflection* reflection = message->GetReflection();
877 if (reflection->HasField(*message, descriptor)) {
878 // Message used to be read_only, but is no longer. Get the new pointer
879 // and record it.
880 Message* mutable_message =
881 reflection->MutableMessage(message, descriptor, nullptr);
882 cmsg->message = mutable_message;
883 cmsg->read_only = false;
884 if (FixupMessageAfterMerge(cmsg) < 0) {
885 return -1;
886 }
887 }
888 }
889 }
890
891 return 0;
892 }
893
894 // ---------------------------------------------------------------------
895 // Making a message writable
896
AssureWritable(CMessage * self)897 int AssureWritable(CMessage* self) {
898 if (self == NULL || !self->read_only) {
899 return 0;
900 }
901
902 // Toplevel messages are always mutable.
903 GOOGLE_DCHECK(self->parent);
904
905 if (AssureWritable(self->parent) == -1)
906 return -1;
907
908 // If this message is part of a oneof, there might be a field to release in
909 // the parent.
910 if (MaybeReleaseOverlappingOneofField(self->parent,
911 self->parent_field_descriptor) < 0) {
912 return -1;
913 }
914
915 // Make self->message writable.
916 Message* parent_message = self->parent->message;
917 const Reflection* reflection = parent_message->GetReflection();
918 Message* mutable_message = reflection->MutableMessage(
919 parent_message, self->parent_field_descriptor,
920 GetFactoryForMessage(self->parent)->message_factory);
921 if (mutable_message == NULL) {
922 return -1;
923 }
924 self->message = mutable_message;
925 self->read_only = false;
926
927 return 0;
928 }
929
930 // --- Globals:
931
932 // Retrieve a C++ FieldDescriptor for an extension handle.
GetExtensionDescriptor(PyObject * extension)933 const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
934 ScopedPyObjectPtr cdescriptor;
935 if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
936 // Most callers consider extensions as a plain dictionary. We should
937 // allow input which is not a field descriptor, and simply pretend it does
938 // not exist.
939 PyErr_SetObject(PyExc_KeyError, extension);
940 return NULL;
941 }
942 return PyFieldDescriptor_AsDescriptor(extension);
943 }
944
945 // If value is a string, convert it into an enum value based on the labels in
946 // descriptor, otherwise simply return value. Always returns a new reference.
GetIntegerEnumValue(const FieldDescriptor & descriptor,PyObject * value)947 static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
948 PyObject* value) {
949 if (PyString_Check(value) || PyUnicode_Check(value)) {
950 const EnumDescriptor* enum_descriptor = descriptor.enum_type();
951 if (enum_descriptor == NULL) {
952 PyErr_SetString(PyExc_TypeError, "not an enum field");
953 return NULL;
954 }
955 char* enum_label;
956 Py_ssize_t size;
957 if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
958 return NULL;
959 }
960 const EnumValueDescriptor* enum_value_descriptor =
961 enum_descriptor->FindValueByName(string(enum_label, size));
962 if (enum_value_descriptor == NULL) {
963 PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
964 return NULL;
965 }
966 return PyInt_FromLong(enum_value_descriptor->number());
967 }
968 Py_INCREF(value);
969 return value;
970 }
971
972 // Delete a slice from a repeated field.
973 // The only way to remove items in C++ protos is to delete the last one,
974 // so we swap items to move the deleted ones at the end, and then strip the
975 // sequence.
DeleteRepeatedField(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * slice)976 int DeleteRepeatedField(
977 CMessage* self,
978 const FieldDescriptor* field_descriptor,
979 PyObject* slice) {
980 Py_ssize_t length, from, to, step, slice_length;
981 Message* message = self->message;
982 const Reflection* reflection = message->GetReflection();
983 int min, max;
984 length = reflection->FieldSize(*message, field_descriptor);
985
986 if (PySlice_Check(slice)) {
987 from = to = step = slice_length = 0;
988 #if PY_MAJOR_VERSION < 3
989 PySlice_GetIndicesEx(
990 reinterpret_cast<PySliceObject*>(slice),
991 length, &from, &to, &step, &slice_length);
992 #else
993 PySlice_GetIndicesEx(
994 slice,
995 length, &from, &to, &step, &slice_length);
996 #endif
997 if (from < to) {
998 min = from;
999 max = to - 1;
1000 } else {
1001 min = to + 1;
1002 max = from;
1003 }
1004 } else {
1005 from = to = PyLong_AsLong(slice);
1006 if (from == -1 && PyErr_Occurred()) {
1007 PyErr_SetString(PyExc_TypeError, "list indices must be integers");
1008 return -1;
1009 }
1010
1011 if (from < 0) {
1012 from = to = length + from;
1013 }
1014 step = 1;
1015 min = max = from;
1016
1017 // Range check.
1018 if (from < 0 || from >= length) {
1019 PyErr_Format(PyExc_IndexError, "list assignment index out of range");
1020 return -1;
1021 }
1022 }
1023
1024 Py_ssize_t i = from;
1025 std::vector<bool> to_delete(length, false);
1026 while (i >= min && i <= max) {
1027 to_delete[i] = true;
1028 i += step;
1029 }
1030
1031 // Swap elements so that items to delete are at the end.
1032 to = 0;
1033 for (i = 0; i < length; ++i) {
1034 if (!to_delete[i]) {
1035 if (i != to) {
1036 reflection->SwapElements(message, field_descriptor, i, to);
1037 }
1038 ++to;
1039 }
1040 }
1041
1042 // Remove items, starting from the end.
1043 for (; length > to; length--) {
1044 if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1045 reflection->RemoveLast(message, field_descriptor);
1046 continue;
1047 }
1048 // It seems that RemoveLast() is less efficient for sub-messages, and
1049 // the memory is not completely released. Prefer ReleaseLast().
1050 Message* sub_message = reflection->ReleaseLast(message, field_descriptor);
1051 // If there is a live weak reference to an item being removed, we "Release"
1052 // it, and it takes ownership of the message.
1053 if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) {
1054 released->message = sub_message;
1055 } else {
1056 // sub_message was not transferred, delete it.
1057 delete sub_message;
1058 }
1059 }
1060
1061 return 0;
1062 }
1063
1064 // Initializes fields of a message. Used in constructors.
InitAttributes(CMessage * self,PyObject * args,PyObject * kwargs)1065 int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
1066 if (args != NULL && PyTuple_Size(args) != 0) {
1067 PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
1068 return -1;
1069 }
1070
1071 if (kwargs == NULL) {
1072 return 0;
1073 }
1074
1075 Py_ssize_t pos = 0;
1076 PyObject* name;
1077 PyObject* value;
1078 while (PyDict_Next(kwargs, &pos, &name, &value)) {
1079 if (!(PyString_Check(name) || PyUnicode_Check(name))) {
1080 PyErr_SetString(PyExc_ValueError, "Field name must be a string");
1081 return -1;
1082 }
1083 ScopedPyObjectPtr property(
1084 PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
1085 if (property == NULL ||
1086 !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
1087 PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
1088 self->message->GetDescriptor()->name().c_str(),
1089 PyString_AsString(name));
1090 return -1;
1091 }
1092 const FieldDescriptor* descriptor =
1093 reinterpret_cast<PyMessageFieldProperty*>(property.get())
1094 ->field_descriptor;
1095 if (value == Py_None) {
1096 // field=None is the same as no field at all.
1097 continue;
1098 }
1099 if (descriptor->is_map()) {
1100 ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
1101 const FieldDescriptor* value_descriptor =
1102 descriptor->message_type()->FindFieldByName("value");
1103 if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1104 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1105 if (iter == NULL) {
1106 PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name));
1107 return -1;
1108 }
1109 ScopedPyObjectPtr next;
1110 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1111 ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
1112 ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
1113 if (source_value.get() == NULL || dest_value.get() == NULL) {
1114 return -1;
1115 }
1116 ScopedPyObjectPtr ok(PyObject_CallMethod(
1117 dest_value.get(), "MergeFrom", "O", source_value.get()));
1118 if (ok.get() == NULL) {
1119 return -1;
1120 }
1121 }
1122 } else {
1123 ScopedPyObjectPtr function_return;
1124 function_return.reset(
1125 PyObject_CallMethod(map.get(), "update", "O", value));
1126 if (function_return.get() == NULL) {
1127 return -1;
1128 }
1129 }
1130 } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1131 ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
1132 if (container == NULL) {
1133 return -1;
1134 }
1135 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1136 RepeatedCompositeContainer* rc_container =
1137 reinterpret_cast<RepeatedCompositeContainer*>(container.get());
1138 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1139 if (iter == NULL) {
1140 PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1141 return -1;
1142 }
1143 ScopedPyObjectPtr next;
1144 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1145 PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL);
1146 ScopedPyObjectPtr new_msg(
1147 repeated_composite_container::Add(rc_container, NULL, kwargs));
1148 if (new_msg == NULL) {
1149 return -1;
1150 }
1151 if (kwargs == NULL) {
1152 // next was not a dict, it's a message we need to merge
1153 ScopedPyObjectPtr merged(MergeFrom(
1154 reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1155 if (merged.get() == NULL) {
1156 return -1;
1157 }
1158 }
1159 }
1160 if (PyErr_Occurred()) {
1161 // Check to see how PyIter_Next() exited.
1162 return -1;
1163 }
1164 } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1165 RepeatedScalarContainer* rs_container =
1166 reinterpret_cast<RepeatedScalarContainer*>(container.get());
1167 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1168 if (iter == NULL) {
1169 PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1170 return -1;
1171 }
1172 ScopedPyObjectPtr next;
1173 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1174 ScopedPyObjectPtr enum_value(
1175 GetIntegerEnumValue(*descriptor, next.get()));
1176 if (enum_value == NULL) {
1177 return -1;
1178 }
1179 ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
1180 rs_container, enum_value.get()));
1181 if (new_msg == NULL) {
1182 return -1;
1183 }
1184 }
1185 if (PyErr_Occurred()) {
1186 // Check to see how PyIter_Next() exited.
1187 return -1;
1188 }
1189 } else {
1190 if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
1191 reinterpret_cast<RepeatedScalarContainer*>(container.get()),
1192 value)) ==
1193 NULL) {
1194 return -1;
1195 }
1196 }
1197 } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1198 ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
1199 if (message == NULL) {
1200 return -1;
1201 }
1202 CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1203 if (PyDict_Check(value)) {
1204 // Make the message exist even if the dict is empty.
1205 AssureWritable(cmessage);
1206 if (InitAttributes(cmessage, NULL, value) < 0) {
1207 return -1;
1208 }
1209 } else {
1210 ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1211 if (merged == NULL) {
1212 return -1;
1213 }
1214 }
1215 } else {
1216 ScopedPyObjectPtr new_val;
1217 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1218 new_val.reset(GetIntegerEnumValue(*descriptor, value));
1219 if (new_val == NULL) {
1220 return -1;
1221 }
1222 value = new_val.get();
1223 }
1224 if (SetFieldValue(self, descriptor, value) < 0) {
1225 return -1;
1226 }
1227 }
1228 }
1229 return 0;
1230 }
1231
1232 // Allocates an incomplete Python Message: the caller must fill self->message
1233 // and eventually self->parent.
NewEmptyMessage(CMessageClass * type)1234 CMessage* NewEmptyMessage(CMessageClass* type) {
1235 CMessage* self = reinterpret_cast<CMessage*>(
1236 PyType_GenericAlloc(&type->super.ht_type, 0));
1237 if (self == NULL) {
1238 return NULL;
1239 }
1240
1241 self->message = NULL;
1242 self->parent = NULL;
1243 self->parent_field_descriptor = NULL;
1244 self->read_only = false;
1245
1246 self->composite_fields = NULL;
1247 self->child_submessages = NULL;
1248
1249 self->unknown_field_set = NULL;
1250
1251 return self;
1252 }
1253
1254 // The __new__ method of Message classes.
1255 // Creates a new C++ message and takes ownership.
New(PyTypeObject * cls,PyObject * unused_args,PyObject * unused_kwargs)1256 static PyObject* New(PyTypeObject* cls,
1257 PyObject* unused_args, PyObject* unused_kwargs) {
1258 CMessageClass* type = CheckMessageClass(cls);
1259 if (type == NULL) {
1260 return NULL;
1261 }
1262 // Retrieve the message descriptor and the default instance (=prototype).
1263 const Descriptor* message_descriptor = type->message_descriptor;
1264 if (message_descriptor == NULL) {
1265 return NULL;
1266 }
1267 const Message* prototype =
1268 type->py_message_factory->message_factory->GetPrototype(
1269 message_descriptor);
1270 if (prototype == NULL) {
1271 PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
1272 return NULL;
1273 }
1274
1275 CMessage* self = NewEmptyMessage(type);
1276 if (self == NULL) {
1277 return NULL;
1278 }
1279 self->message = prototype->New();
1280 self->parent = nullptr; // This message owns its data.
1281 return reinterpret_cast<PyObject*>(self);
1282 }
1283
1284 // The __init__ method of Message classes.
1285 // It initializes fields from keywords passed to the constructor.
Init(CMessage * self,PyObject * args,PyObject * kwargs)1286 static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
1287 return InitAttributes(self, args, kwargs);
1288 }
1289
1290 // ---------------------------------------------------------------------
1291 // Deallocating a CMessage
1292
// Destroys a CMessage wrapper: clears weak references, frees the child
// caches and unknown-field state, then either deletes the owned C++ message
// (no parent) or unregisters this wrapper from its parent's cache and drops
// the reference to the parent. Finally frees the Python object itself.
static void Dealloc(CMessage* self) {
  // Weak references are cleared before any other teardown.
  if (self->weakreflist) {
    PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
  }
  // At this point all dependent objects have been removed.
  GOOGLE_DCHECK(!self->child_submessages || self->child_submessages->empty());
  GOOGLE_DCHECK(!self->composite_fields || self->composite_fields->empty());
  delete self->child_submessages;
  delete self->composite_fields;
  if (self->unknown_field_set) {
    unknown_fields::Clear(
        reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
  }

  CMessage* parent = self->parent;
  if (!parent) {
    // No parent, we own the message.
    delete self->message;
  } else if (parent->AsPyObject() == Py_None) {
    // Message owned externally: Nothing to dealloc
    Py_CLEAR(self->parent);
  } else {
    // Clear this message from its parent's map. Repeated-field elements are
    // keyed by C++ message pointer in child_submessages; singular fields are
    // keyed by field descriptor in composite_fields.
    if (self->parent_field_descriptor->is_repeated()) {
      if (parent->child_submessages)
        parent->child_submessages->erase(self->message);
    } else {
      if (parent->composite_fields)
        parent->composite_fields->erase(self->parent_field_descriptor);
    }
    // Drop our reference to the parent only after it has been updated.
    Py_CLEAR(self->parent);
  }
  // Release the memory of the Python object through its type's tp_free.
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
1327
1328 // ---------------------------------------------------------------------
1329
1330
IsInitialized(CMessage * self,PyObject * args)1331 PyObject* IsInitialized(CMessage* self, PyObject* args) {
1332 PyObject* errors = NULL;
1333 if (PyArg_ParseTuple(args, "|O", &errors) < 0) {
1334 return NULL;
1335 }
1336 if (self->message->IsInitialized()) {
1337 Py_RETURN_TRUE;
1338 }
1339 if (errors != NULL) {
1340 ScopedPyObjectPtr initialization_errors(
1341 FindInitializationErrors(self));
1342 if (initialization_errors == NULL) {
1343 return NULL;
1344 }
1345 ScopedPyObjectPtr extend_name(PyString_FromString("extend"));
1346 if (extend_name == NULL) {
1347 return NULL;
1348 }
1349 ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
1350 errors,
1351 extend_name.get(),
1352 initialization_errors.get(),
1353 NULL));
1354 if (result == NULL) {
1355 return NULL;
1356 }
1357 }
1358 Py_RETURN_FALSE;
1359 }
1360
HasFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1361 PyObject* HasFieldByDescriptor(
1362 CMessage* self, const FieldDescriptor* field_descriptor) {
1363 Message* message = self->message;
1364 if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
1365 return NULL;
1366 }
1367 if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1368 PyErr_SetString(PyExc_KeyError,
1369 "Field is repeated. A singular method is required.");
1370 return NULL;
1371 }
1372 bool has_field =
1373 message->GetReflection()->HasField(*message, field_descriptor);
1374 return PyBool_FromLong(has_field ? 1 : 0);
1375 }
1376
FindFieldWithOneofs(const Message * message,const string & field_name,bool * in_oneof)1377 const FieldDescriptor* FindFieldWithOneofs(
1378 const Message* message, const string& field_name, bool* in_oneof) {
1379 *in_oneof = false;
1380 const Descriptor* descriptor = message->GetDescriptor();
1381 const FieldDescriptor* field_descriptor =
1382 descriptor->FindFieldByName(field_name);
1383 if (field_descriptor != NULL) {
1384 return field_descriptor;
1385 }
1386 const OneofDescriptor* oneof_desc =
1387 descriptor->FindOneofByName(field_name);
1388 if (oneof_desc != NULL) {
1389 *in_oneof = true;
1390 return message->GetReflection()->GetOneofFieldDescriptor(*message,
1391 oneof_desc);
1392 }
1393 return NULL;
1394 }
1395
CheckHasPresence(const FieldDescriptor * field_descriptor,bool in_oneof)1396 bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
1397 auto message_name = field_descriptor->containing_type()->name();
1398 if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1399 PyErr_Format(PyExc_ValueError,
1400 "Protocol message %s has no singular \"%s\" field.",
1401 message_name.c_str(), field_descriptor->name().c_str());
1402 return false;
1403 }
1404
1405 if (field_descriptor->file()->syntax() == FileDescriptor::SYNTAX_PROTO3) {
1406 // HasField() for a oneof *itself* isn't supported.
1407 if (in_oneof) {
1408 PyErr_Format(PyExc_ValueError,
1409 "Can't test oneof field \"%s.%s\" for presence in proto3, "
1410 "use WhichOneof instead.", message_name.c_str(),
1411 field_descriptor->containing_oneof()->name().c_str());
1412 return false;
1413 }
1414
1415 // ...but HasField() for fields *in* a oneof is supported.
1416 if (field_descriptor->containing_oneof() != NULL) {
1417 return true;
1418 }
1419
1420 if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1421 PyErr_Format(
1422 PyExc_ValueError,
1423 "Can't test non-submessage field \"%s.%s\" for presence in proto3.",
1424 message_name.c_str(), field_descriptor->name().c_str());
1425 return false;
1426 }
1427 }
1428
1429 return true;
1430 }
1431
HasField(CMessage * self,PyObject * arg)1432 PyObject* HasField(CMessage* self, PyObject* arg) {
1433 char* field_name;
1434 Py_ssize_t size;
1435 #if PY_MAJOR_VERSION < 3
1436 if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1437 return NULL;
1438 }
1439 #else
1440 field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
1441 if (!field_name) {
1442 return NULL;
1443 }
1444 #endif
1445
1446 Message* message = self->message;
1447 bool is_in_oneof;
1448 const FieldDescriptor* field_descriptor =
1449 FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
1450 if (field_descriptor == NULL) {
1451 if (!is_in_oneof) {
1452 PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
1453 message->GetDescriptor()->name().c_str(), field_name);
1454 return NULL;
1455 } else {
1456 Py_RETURN_FALSE;
1457 }
1458 }
1459
1460 if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
1461 return NULL;
1462 }
1463
1464 if (message->GetReflection()->HasField(*message, field_descriptor)) {
1465 Py_RETURN_TRUE;
1466 }
1467
1468 Py_RETURN_FALSE;
1469 }
1470
ClearExtension(CMessage * self,PyObject * extension)1471 PyObject* ClearExtension(CMessage* self, PyObject* extension) {
1472 const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1473 if (descriptor == NULL) {
1474 return NULL;
1475 }
1476 if (InternalReleaseFieldByDescriptor(self, descriptor) < 0) {
1477 return NULL;
1478 }
1479 return ClearFieldByDescriptor(self, descriptor);
1480 }
1481
HasExtension(CMessage * self,PyObject * extension)1482 PyObject* HasExtension(CMessage* self, PyObject* extension) {
1483 const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1484 if (descriptor == NULL) {
1485 return NULL;
1486 }
1487 return HasFieldByDescriptor(self, descriptor);
1488 }
1489
1490 // ---------------------------------------------------------------------
1491 // Releasing messages
1492 //
// The Python API's ClearField() and Clear() methods behave
// differently than their C++ counterparts. While the C++ versions
// clear the children, the Python versions detach the children,
// without touching their content. This impedance mismatch causes
// some complexity in the implementation, which is captured in this
// section.
1499 //
1500 // When one or multiple fields are cleared we need to:
1501 //
1502 // * Gather all child objects that need to be detached from the message.
1503 // In composite_fields and child_submessages.
1504 //
1505 // * Create a new Python message of the same kind. Use SwapFields() to move
1506 // data from the original message.
1507 //
1508 // * Change the parent of all child objects: update their strong reference
1509 // to their parent, and move their presence in composite_fields and
1510 // child_submessages.
1511
1512 // ---------------------------------------------------------------------
1513 // Release a composite child of a CMessage
1514
// Detaches the given child wrappers from `self` by moving them, together
// with the C++ data of their fields, to a freshly created message of the
// same class.
//
// `messages_to_release` are wrappers tracked in self->child_submessages;
// `containers_to_release` are wrappers tracked in self->composite_fields.
// Each child's parent reference is switched to the new message, and the
// affected fields are moved with Reflection::SwapFields().
//
// Returns 0 on success (including when both lists are empty) and -1 if the
// new message object cannot be allocated.
static int InternalReparentFields(
    CMessage* self, const std::vector<CMessage*>& messages_to_release,
    const std::vector<ContainerBase*>& containers_to_release) {
  if (messages_to_release.empty() && containers_to_release.empty()) {
    return 0;
  }

  // Move all the passed sub_messages to another message.
  CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass());
  if (new_message == nullptr) {
    return -1;
  }
  new_message->message = self->message->New();
  // `holder` owns the initial reference to new_message; once it goes out of
  // scope only the reparented children keep new_message alive.
  ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message));
  new_message->child_submessages = new CMessage::SubMessagesMap();
  new_message->composite_fields = new CMessage::CompositeFieldsMap();
  // Collected in a set so that each field is swapped only once.
  std::set<const FieldDescriptor*> fields_to_swap;

  // In case the removed fields hold the last reference to this message, keep
  // a reference for the duration of the reparenting.
  Py_INCREF(self);

  for (const auto& to_release : messages_to_release) {
    fields_to_swap.insert(to_release->parent_field_descriptor);
    // Reparent: the child now holds a reference to new_message instead of
    // its previous parent.
    Py_INCREF(new_message);
    Py_DECREF(to_release->parent);
    to_release->parent = new_message;
    self->child_submessages->erase(to_release->message);
    new_message->child_submessages->emplace(to_release->message, to_release);
  }

  for (const auto& to_release : containers_to_release) {
    fields_to_swap.insert(to_release->parent_field_descriptor);
    // Same reparenting as above, for entries of composite_fields.
    Py_INCREF(new_message);
    Py_DECREF(to_release->parent);
    to_release->parent = new_message;
    self->composite_fields->erase(to_release->parent_field_descriptor);
    new_message->composite_fields->emplace(to_release->parent_field_descriptor,
                                           to_release);
  }

  // Move the C++ content of the released fields into the new message.
  self->message->GetReflection()->SwapFields(
      self->message, new_message->message,
      std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
                                          fields_to_swap.end()));

  // This might delete the Python message completely if all children were moved.
  Py_DECREF(self);

  return 0;
}
1567
InternalReleaseFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1568 int InternalReleaseFieldByDescriptor(
1569 CMessage* self,
1570 const FieldDescriptor* field_descriptor) {
1571 if (!field_descriptor->is_repeated() &&
1572 field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1573 // Single scalars are not in any cache.
1574 return 0;
1575 }
1576 std::vector<CMessage*> messages_to_release;
1577 std::vector<ContainerBase*> containers_to_release;
1578 if (self->child_submessages && field_descriptor->is_repeated() &&
1579 field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1580 for (const auto& child_item : *self->child_submessages) {
1581 if (child_item.second->parent_field_descriptor == field_descriptor) {
1582 messages_to_release.push_back(child_item.second);
1583 }
1584 }
1585 }
1586 if (self->composite_fields) {
1587 CMessage::CompositeFieldsMap::iterator it =
1588 self->composite_fields->find(field_descriptor);
1589 if (it != self->composite_fields->end()) {
1590 containers_to_release.push_back(it->second);
1591 }
1592 }
1593
1594 return InternalReparentFields(self, messages_to_release,
1595 containers_to_release);
1596 }
1597
ClearFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1598 PyObject* ClearFieldByDescriptor(
1599 CMessage* self,
1600 const FieldDescriptor* field_descriptor) {
1601 if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
1602 return NULL;
1603 }
1604 AssureWritable(self);
1605 Message* message = self->message;
1606 message->GetReflection()->ClearField(message, field_descriptor);
1607 Py_RETURN_NONE;
1608 }
1609
ClearField(CMessage * self,PyObject * arg)1610 PyObject* ClearField(CMessage* self, PyObject* arg) {
1611 if (!(PyString_Check(arg) || PyUnicode_Check(arg))) {
1612 PyErr_SetString(PyExc_TypeError, "field name must be a string");
1613 return NULL;
1614 }
1615 #if PY_MAJOR_VERSION < 3
1616 char* field_name;
1617 Py_ssize_t size;
1618 if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1619 return NULL;
1620 }
1621 #else
1622 Py_ssize_t size;
1623 const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size);
1624 #endif
1625 AssureWritable(self);
1626 Message* message = self->message;
1627 ScopedPyObjectPtr arg_in_oneof;
1628 bool is_in_oneof;
1629 const FieldDescriptor* field_descriptor =
1630 FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
1631 if (field_descriptor == NULL) {
1632 if (!is_in_oneof) {
1633 PyErr_Format(PyExc_ValueError,
1634 "Protocol message has no \"%s\" field.", field_name);
1635 return NULL;
1636 } else {
1637 Py_RETURN_NONE;
1638 }
1639 } else if (is_in_oneof) {
1640 const string& name = field_descriptor->name();
1641 arg_in_oneof.reset(PyString_FromStringAndSize(name.c_str(), name.size()));
1642 arg = arg_in_oneof.get();
1643 }
1644
1645 if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) {
1646 return NULL;
1647 }
1648 return ClearFieldByDescriptor(self, field_descriptor);
1649 }
1650
Clear(CMessage * self)1651 PyObject* Clear(CMessage* self) {
1652 AssureWritable(self);
1653 // Detach all current fields of this message
1654 std::vector<CMessage*> messages_to_release;
1655 std::vector<ContainerBase*> containers_to_release;
1656 if (self->child_submessages) {
1657 for (const auto& item : *self->child_submessages) {
1658 messages_to_release.push_back(item.second);
1659 }
1660 }
1661 if (self->composite_fields) {
1662 for (const auto& item : *self->composite_fields) {
1663 containers_to_release.push_back(item.second);
1664 }
1665 }
1666 if (InternalReparentFields(self, messages_to_release, containers_to_release) <
1667 0) {
1668 return NULL;
1669 }
1670 if (self->unknown_field_set) {
1671 unknown_fields::Clear(
1672 reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1673 self->unknown_field_set = nullptr;
1674 }
1675 self->message->Clear();
1676 Py_RETURN_NONE;
1677 }
1678
1679 // ---------------------------------------------------------------------
1680
GetMessageName(CMessage * self)1681 static string GetMessageName(CMessage* self) {
1682 if (self->parent_field_descriptor != NULL) {
1683 return self->parent_field_descriptor->full_name();
1684 } else {
1685 return self->message->GetDescriptor()->full_name();
1686 }
1687 }
1688
InternalSerializeToString(CMessage * self,PyObject * args,PyObject * kwargs,bool require_initialized)1689 static PyObject* InternalSerializeToString(
1690 CMessage* self, PyObject* args, PyObject* kwargs,
1691 bool require_initialized) {
1692 // Parse the "deterministic" kwarg; defaults to False.
1693 static char* kwlist[] = { "deterministic", 0 };
1694 PyObject* deterministic_obj = Py_None;
1695 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
1696 &deterministic_obj)) {
1697 return NULL;
1698 }
1699 // Preemptively convert to a bool first, so we don't need to back out of
1700 // allocating memory if this raises an exception.
1701 // NOTE: This is unused later if deterministic == Py_None, but that's fine.
1702 int deterministic = PyObject_IsTrue(deterministic_obj);
1703 if (deterministic < 0) {
1704 return NULL;
1705 }
1706
1707 if (require_initialized && !self->message->IsInitialized()) {
1708 ScopedPyObjectPtr errors(FindInitializationErrors(self));
1709 if (errors == NULL) {
1710 return NULL;
1711 }
1712 ScopedPyObjectPtr comma(PyString_FromString(","));
1713 if (comma == NULL) {
1714 return NULL;
1715 }
1716 ScopedPyObjectPtr joined(
1717 PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
1718 if (joined == NULL) {
1719 return NULL;
1720 }
1721
1722 // TODO(haberman): this is a (hopefully temporary) hack. The unit testing
1723 // infrastructure reloads all pure-Python modules for every test, but not
1724 // C++ modules (because that's generally impossible:
1725 // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll
1726 // return the EncodeError from a previous load of the module, which won't
1727 // match a user's attempt to catch EncodeError. So we have to look it up
1728 // again every time.
1729 ScopedPyObjectPtr message_module(PyImport_ImportModule(
1730 "google.protobuf.message"));
1731 if (message_module.get() == NULL) {
1732 return NULL;
1733 }
1734
1735 ScopedPyObjectPtr encode_error(
1736 PyObject_GetAttrString(message_module.get(), "EncodeError"));
1737 if (encode_error.get() == NULL) {
1738 return NULL;
1739 }
1740 PyErr_Format(encode_error.get(),
1741 "Message %s is missing required fields: %s",
1742 GetMessageName(self).c_str(), PyString_AsString(joined.get()));
1743 return NULL;
1744 }
1745
1746 // Ok, arguments parsed and errors checked, now encode to a string
1747 const size_t size = self->message->ByteSizeLong();
1748 if (size == 0) {
1749 return PyBytes_FromString("");
1750 }
1751 PyObject* result = PyBytes_FromStringAndSize(NULL, size);
1752 if (result == NULL) {
1753 return NULL;
1754 }
1755 io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
1756 io::CodedOutputStream coded_out(&out);
1757 if (deterministic_obj != Py_None) {
1758 coded_out.SetSerializationDeterministic(deterministic);
1759 }
1760 self->message->SerializeWithCachedSizes(&coded_out);
1761 GOOGLE_CHECK(!coded_out.HadError());
1762 return result;
1763 }
1764
SerializeToString(CMessage * self,PyObject * args,PyObject * kwargs)1765 static PyObject* SerializeToString(
1766 CMessage* self, PyObject* args, PyObject* kwargs) {
1767 return InternalSerializeToString(self, args, kwargs,
1768 /*require_initialized=*/true);
1769 }
1770
SerializePartialToString(CMessage * self,PyObject * args,PyObject * kwargs)1771 static PyObject* SerializePartialToString(
1772 CMessage* self, PyObject* args, PyObject* kwargs) {
1773 return InternalSerializeToString(self, args, kwargs,
1774 /*require_initialized=*/false);
1775 }
1776
1777 // Formats proto fields for ascii dumps using python formatting functions where
1778 // appropriate.
1779 class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter {
1780 public:
1781 // Python has some differences from C++ when printing floating point numbers.
1782 //
1783 // 1) Trailing .0 is always printed.
1784 // 2) (Python2) Output is rounded to 12 digits.
1785 // 3) (Python3) The full precision of the double is preserved (and Python uses
1786 // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
1787 // differences, but they rarely happen)
1788 //
1789 // We override floating point printing with the C-API function for printing
1790 // Python floats to ensure consistency.
PrintFloat(float value) const1791 string PrintFloat(float value) const { return PrintDouble(value); }
PrintDouble(double value) const1792 string PrintDouble(double value) const {
1793 // This implementation is not highly optimized (it allocates two temporary
1794 // Python objects) but it is simple and portable. If this is shown to be a
1795 // performance bottleneck, we can optimize it, but the results will likely
1796 // be more complicated to accommodate the differing behavior of double
1797 // formatting between Python 2 and Python 3.
1798 //
1799 // (Though a valid question is: do we really want to make out output
1800 // dependent on the Python version?)
1801 ScopedPyObjectPtr py_value(PyFloat_FromDouble(value));
1802 if (!py_value.get()) {
1803 return string();
1804 }
1805
1806 ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
1807 if (!py_str.get()) {
1808 return string();
1809 }
1810
1811 return string(PyString_AsString(py_str.get()));
1812 }
1813 };
1814
ToStr(CMessage * self)1815 static PyObject* ToStr(CMessage* self) {
1816 TextFormat::Printer printer;
1817 // Passes ownership
1818 printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
1819 printer.SetHideUnknownFields(true);
1820 string output;
1821 if (!printer.PrintToString(*self->message, &output)) {
1822 PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
1823 return NULL;
1824 }
1825 return PyString_FromString(output.c_str());
1826 }
1827
MergeFrom(CMessage * self,PyObject * arg)1828 PyObject* MergeFrom(CMessage* self, PyObject* arg) {
1829 CMessage* other_message;
1830 if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1831 PyErr_Format(PyExc_TypeError,
1832 "Parameter to MergeFrom() must be instance of same class: "
1833 "expected %s got %s.",
1834 self->message->GetDescriptor()->full_name().c_str(),
1835 Py_TYPE(arg)->tp_name);
1836 return NULL;
1837 }
1838
1839 other_message = reinterpret_cast<CMessage*>(arg);
1840 if (other_message->message->GetDescriptor() !=
1841 self->message->GetDescriptor()) {
1842 PyErr_Format(PyExc_TypeError,
1843 "Parameter to MergeFrom() must be instance of same class: "
1844 "expected %s got %s.",
1845 self->message->GetDescriptor()->full_name().c_str(),
1846 other_message->message->GetDescriptor()->full_name().c_str());
1847 return NULL;
1848 }
1849 AssureWritable(self);
1850
1851 self->message->MergeFrom(*other_message->message);
1852 // Child message might be lazily created before MergeFrom. Make sure they
1853 // are mutable at this point if child messages are really created.
1854 if (FixupMessageAfterMerge(self) < 0) {
1855 return NULL;
1856 }
1857
1858 Py_RETURN_NONE;
1859 }
1860
CopyFrom(CMessage * self,PyObject * arg)1861 static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
1862 CMessage* other_message;
1863 if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1864 PyErr_Format(PyExc_TypeError,
1865 "Parameter to CopyFrom() must be instance of same class: "
1866 "expected %s got %s.",
1867 self->message->GetDescriptor()->full_name().c_str(),
1868 Py_TYPE(arg)->tp_name);
1869 return NULL;
1870 }
1871
1872 other_message = reinterpret_cast<CMessage*>(arg);
1873
1874 if (self == other_message) {
1875 Py_RETURN_NONE;
1876 }
1877
1878 if (other_message->message->GetDescriptor() !=
1879 self->message->GetDescriptor()) {
1880 PyErr_Format(PyExc_TypeError,
1881 "Parameter to CopyFrom() must be instance of same class: "
1882 "expected %s got %s.",
1883 self->message->GetDescriptor()->full_name().c_str(),
1884 other_message->message->GetDescriptor()->full_name().c_str());
1885 return NULL;
1886 }
1887
1888 AssureWritable(self);
1889
1890 // CopyFrom on the message will not clean up self->composite_fields,
1891 // which can leave us in an inconsistent state, so clear it out here.
1892 (void)ScopedPyObjectPtr(Clear(self));
1893
1894 self->message->CopyFrom(*other_message->message);
1895
1896 Py_RETURN_NONE;
1897 }
1898
// Protobuf has a 64MB limit built in, this variable will override this. Please
// do not enable this unless you fully understand the implications: protobufs
// must all be kept in memory at the same time, so if they grow too big you may
// get OOM errors. The protobuf APIs do not provide any tools for processing
// protobufs in chunks. If you have protos this big you should break them up if
// it is at all convenient to do so.
//
// The compile-time macro only selects the initial value; the flag can still be
// changed at runtime through SetAllowOversizeProtos(), and is consulted by
// MergeFromString() to lift the parser's size and recursion limits.
#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
static bool allow_oversize_protos = true;
#else  // !PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
static bool allow_oversize_protos = false;
#endif  // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
1910
1911 // Provide a method in the module to set allow_oversize_protos to a boolean
1912 // value. This method returns the newly value of allow_oversize_protos.
SetAllowOversizeProtos(PyObject * m,PyObject * arg)1913 PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
1914 if (!arg || !PyBool_Check(arg)) {
1915 PyErr_SetString(PyExc_TypeError,
1916 "Argument to SetAllowOversizeProtos must be boolean");
1917 return NULL;
1918 }
1919 allow_oversize_protos = PyObject_IsTrue(arg);
1920 if (allow_oversize_protos) {
1921 Py_RETURN_TRUE;
1922 } else {
1923 Py_RETURN_FALSE;
1924 }
1925 }
1926
MergeFromString(CMessage * self,PyObject * arg)1927 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
1928 const void* data;
1929 Py_ssize_t data_length;
1930 if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) {
1931 return NULL;
1932 }
1933
1934 AssureWritable(self);
1935
1936 io::CodedInputStream input(
1937 reinterpret_cast<const uint8*>(data), data_length);
1938 if (allow_oversize_protos) {
1939 input.SetTotalBytesLimit(INT_MAX, INT_MAX);
1940 input.SetRecursionLimit(INT_MAX);
1941 }
1942 PyMessageFactory* factory = GetFactoryForMessage(self);
1943 input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
1944 bool success = self->message->MergePartialFromCodedStream(&input);
1945 // Child message might be lazily created before MergeFrom. Make sure they
1946 // are mutable at this point if child messages are really created.
1947 if (FixupMessageAfterMerge(self) < 0) {
1948 return NULL;
1949 }
1950
1951 if (success) {
1952 if (!input.ConsumedEntireMessage()) {
1953 // TODO(jieluo): Raise error and return NULL instead.
1954 // b/27494216
1955 PyErr_Warn(NULL, "Unexpected end-group tag: Not all data was converted");
1956 }
1957 return PyInt_FromLong(input.CurrentPosition());
1958 } else {
1959 PyErr_Format(DecodeError_class, "Error parsing message");
1960 return NULL;
1961 }
1962 }
1963
ParseFromString(CMessage * self,PyObject * arg)1964 static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
1965 if (ScopedPyObjectPtr(Clear(self)) == NULL) {
1966 return NULL;
1967 }
1968 return MergeFromString(self, arg);
1969 }
1970
ByteSize(CMessage * self,PyObject * args)1971 static PyObject* ByteSize(CMessage* self, PyObject* args) {
1972 return PyLong_FromLong(self->message->ByteSize());
1973 }
1974
RegisterExtension(PyObject * cls,PyObject * extension_handle)1975 PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
1976 const FieldDescriptor* descriptor =
1977 GetExtensionDescriptor(extension_handle);
1978 if (descriptor == NULL) {
1979 return NULL;
1980 }
1981 if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
1982 PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
1983 cls->ob_type->tp_name);
1984 return NULL;
1985 }
1986 CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
1987 if (message_class == NULL) {
1988 return NULL;
1989 }
1990 // If the extension was already registered, check that it is the same.
1991 const FieldDescriptor* existing_extension =
1992 message_class->py_message_factory->pool->pool->FindExtensionByNumber(
1993 descriptor->containing_type(), descriptor->number());
1994 if (existing_extension != NULL && existing_extension != descriptor) {
1995 PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
1996 return NULL;
1997 }
1998 Py_RETURN_NONE;
1999 }
2000
SetInParent(CMessage * self,PyObject * args)2001 static PyObject* SetInParent(CMessage* self, PyObject* args) {
2002 AssureWritable(self);
2003 Py_RETURN_NONE;
2004 }
2005
WhichOneof(CMessage * self,PyObject * arg)2006 static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
2007 Py_ssize_t name_size;
2008 char *name_data;
2009 if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0)
2010 return NULL;
2011 string oneof_name = string(name_data, name_size);
2012 const OneofDescriptor* oneof_desc =
2013 self->message->GetDescriptor()->FindOneofByName(oneof_name);
2014 if (oneof_desc == NULL) {
2015 PyErr_Format(PyExc_ValueError,
2016 "Protocol message has no oneof \"%s\" field.",
2017 oneof_name.c_str());
2018 return NULL;
2019 }
2020 const FieldDescriptor* field_in_oneof =
2021 self->message->GetReflection()->GetOneofFieldDescriptor(
2022 *self->message, oneof_desc);
2023 if (field_in_oneof == NULL) {
2024 Py_RETURN_NONE;
2025 } else {
2026 const string& name = field_in_oneof->name();
2027 return PyString_FromStringAndSize(name.c_str(), name.size());
2028 }
2029 }
2030
// Forward declaration: ListFields() below needs the extension dict getter,
// which is defined later in this file.
static PyObject* GetExtensionDict(CMessage* self, void *closure);
2032
ListFields(CMessage * self)2033 static PyObject* ListFields(CMessage* self) {
2034 std::vector<const FieldDescriptor*> fields;
2035 self->message->GetReflection()->ListFields(*self->message, &fields);
2036
2037 // Normally, the list will be exactly the size of the fields.
2038 ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
2039 if (all_fields == NULL) {
2040 return NULL;
2041 }
2042
2043 // When there are unknown extensions, the py list will *not* contain
2044 // the field information. Thus the actual size of the py list will be
2045 // smaller than the size of fields. Set the actual size at the end.
2046 Py_ssize_t actual_size = 0;
2047 for (size_t i = 0; i < fields.size(); ++i) {
2048 ScopedPyObjectPtr t(PyTuple_New(2));
2049 if (t == NULL) {
2050 return NULL;
2051 }
2052
2053 if (fields[i]->is_extension()) {
2054 ScopedPyObjectPtr extension_field(
2055 PyFieldDescriptor_FromDescriptor(fields[i]));
2056 if (extension_field == NULL) {
2057 return NULL;
2058 }
2059 // With C++ descriptors, the field can always be retrieved, but for
2060 // unknown extensions which have not been imported in Python code, there
2061 // is no message class and we cannot retrieve the value.
2062 // TODO(amauryfa): consider building the class on the fly!
2063 if (fields[i]->message_type() != NULL &&
2064 message_factory::GetMessageClass(
2065 GetFactoryForMessage(self),
2066 fields[i]->message_type()) == NULL) {
2067 PyErr_Clear();
2068 continue;
2069 }
2070 ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
2071 if (extensions == NULL) {
2072 return NULL;
2073 }
2074 // 'extension' reference later stolen by PyTuple_SET_ITEM.
2075 PyObject* extension = PyObject_GetItem(
2076 extensions.get(), extension_field.get());
2077 if (extension == NULL) {
2078 return NULL;
2079 }
2080 PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
2081 // Steals reference to 'extension'
2082 PyTuple_SET_ITEM(t.get(), 1, extension);
2083 } else {
2084 // Normal field
2085 ScopedPyObjectPtr field_descriptor(
2086 PyFieldDescriptor_FromDescriptor(fields[i]));
2087 if (field_descriptor == NULL) {
2088 return NULL;
2089 }
2090
2091 PyObject* field_value = GetFieldValue(self, fields[i]);
2092 if (field_value == NULL) {
2093 PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str());
2094 return NULL;
2095 }
2096 PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
2097 PyTuple_SET_ITEM(t.get(), 1, field_value);
2098 }
2099 PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
2100 ++actual_size;
2101 }
2102 if (static_cast<size_t>(actual_size) != fields.size() &&
2103 (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
2104 0)) {
2105 return NULL;
2106 }
2107 return all_fields.release();
2108 }
2109
DiscardUnknownFields(CMessage * self)2110 static PyObject* DiscardUnknownFields(CMessage* self) {
2111 AssureWritable(self);
2112 self->message->DiscardUnknownFields();
2113 Py_RETURN_NONE;
2114 }
2115
FindInitializationErrors(CMessage * self)2116 PyObject* FindInitializationErrors(CMessage* self) {
2117 Message* message = self->message;
2118 std::vector<string> errors;
2119 message->FindInitializationErrors(&errors);
2120
2121 PyObject* error_list = PyList_New(errors.size());
2122 if (error_list == NULL) {
2123 return NULL;
2124 }
2125 for (size_t i = 0; i < errors.size(); ++i) {
2126 const string& error = errors[i];
2127 PyObject* error_string = PyString_FromStringAndSize(
2128 error.c_str(), error.length());
2129 if (error_string == NULL) {
2130 Py_DECREF(error_list);
2131 return NULL;
2132 }
2133 PyList_SET_ITEM(error_list, i, error_string);
2134 }
2135 return error_list;
2136 }
2137
RichCompare(CMessage * self,PyObject * other,int opid)2138 static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
2139 // Only equality comparisons are implemented.
2140 if (opid != Py_EQ && opid != Py_NE) {
2141 Py_INCREF(Py_NotImplemented);
2142 return Py_NotImplemented;
2143 }
2144 bool equals = true;
2145 // If other is not a message, it cannot be equal.
2146 if (!PyObject_TypeCheck(other, CMessage_Type)) {
2147 equals = false;
2148 }
2149 const google::protobuf::Message* other_message =
2150 reinterpret_cast<CMessage*>(other)->message;
2151 // If messages don't have the same descriptors, they are not equal.
2152 if (equals &&
2153 self->message->GetDescriptor() != other_message->GetDescriptor()) {
2154 equals = false;
2155 }
2156 // Check the message contents.
2157 if (equals && !google::protobuf::util::MessageDifferencer::Equals(
2158 *self->message,
2159 *reinterpret_cast<CMessage*>(other)->message)) {
2160 equals = false;
2161 }
2162
2163 if (equals ^ (opid == Py_EQ)) {
2164 Py_RETURN_FALSE;
2165 } else {
2166 Py_RETURN_TRUE;
2167 }
2168 }
2169
InternalGetScalar(const Message * message,const FieldDescriptor * field_descriptor)2170 PyObject* InternalGetScalar(const Message* message,
2171 const FieldDescriptor* field_descriptor) {
2172 const Reflection* reflection = message->GetReflection();
2173
2174 if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2175 return NULL;
2176 }
2177
2178 PyObject* result = NULL;
2179 switch (field_descriptor->cpp_type()) {
2180 case FieldDescriptor::CPPTYPE_INT32: {
2181 int32 value = reflection->GetInt32(*message, field_descriptor);
2182 result = PyInt_FromLong(value);
2183 break;
2184 }
2185 case FieldDescriptor::CPPTYPE_INT64: {
2186 int64 value = reflection->GetInt64(*message, field_descriptor);
2187 result = PyLong_FromLongLong(value);
2188 break;
2189 }
2190 case FieldDescriptor::CPPTYPE_UINT32: {
2191 uint32 value = reflection->GetUInt32(*message, field_descriptor);
2192 result = PyInt_FromSize_t(value);
2193 break;
2194 }
2195 case FieldDescriptor::CPPTYPE_UINT64: {
2196 uint64 value = reflection->GetUInt64(*message, field_descriptor);
2197 result = PyLong_FromUnsignedLongLong(value);
2198 break;
2199 }
2200 case FieldDescriptor::CPPTYPE_FLOAT: {
2201 float value = reflection->GetFloat(*message, field_descriptor);
2202 result = PyFloat_FromDouble(value);
2203 break;
2204 }
2205 case FieldDescriptor::CPPTYPE_DOUBLE: {
2206 double value = reflection->GetDouble(*message, field_descriptor);
2207 result = PyFloat_FromDouble(value);
2208 break;
2209 }
2210 case FieldDescriptor::CPPTYPE_BOOL: {
2211 bool value = reflection->GetBool(*message, field_descriptor);
2212 result = PyBool_FromLong(value);
2213 break;
2214 }
2215 case FieldDescriptor::CPPTYPE_STRING: {
2216 string scratch;
2217 const string& value =
2218 reflection->GetStringReference(*message, field_descriptor, &scratch);
2219 result = ToStringObject(field_descriptor, value);
2220 break;
2221 }
2222 case FieldDescriptor::CPPTYPE_ENUM: {
2223 const EnumValueDescriptor* enum_value =
2224 message->GetReflection()->GetEnum(*message, field_descriptor);
2225 result = PyInt_FromLong(enum_value->number());
2226 break;
2227 }
2228 default:
2229 PyErr_Format(
2230 PyExc_SystemError, "Getting a value from a field of unknown type %d",
2231 field_descriptor->cpp_type());
2232 }
2233
2234 return result;
2235 }
2236
InternalGetSubMessage(CMessage * self,const FieldDescriptor * field_descriptor)2237 CMessage* InternalGetSubMessage(
2238 CMessage* self, const FieldDescriptor* field_descriptor) {
2239 const Reflection* reflection = self->message->GetReflection();
2240 PyMessageFactory* factory = GetFactoryForMessage(self);
2241 const Message& sub_message = reflection->GetMessage(
2242 *self->message, field_descriptor, factory->message_factory);
2243
2244 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2245 factory, field_descriptor->message_type());
2246 ScopedPyObjectPtr message_class_owner(
2247 reinterpret_cast<PyObject*>(message_class));
2248 if (message_class == NULL) {
2249 return NULL;
2250 }
2251
2252 CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2253 if (cmsg == NULL) {
2254 return NULL;
2255 }
2256
2257 Py_INCREF(self);
2258 cmsg->parent = self;
2259 cmsg->parent_field_descriptor = field_descriptor;
2260 cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
2261 cmsg->message = const_cast<Message*>(&sub_message);
2262 return cmsg;
2263 }
2264
// Converts the Python object `arg` to the C++ type of the singular scalar
// field `field_descriptor` and stores it in `message` through reflection.
//
// No oneof handling is done here; InternalSetScalar() performs it before
// delegating to this function.
//
// Returns 0 on success and -1 on failure (field rejected by
// CheckFieldBelongsToMessage(), conversion failure, unknown enum number, or a
// field type that is not handled below).
//
// NOTE(review): the GOOGLE_CHECK_GET_* macros are defined outside this chunk;
// presumably each declares the local named by its second argument, converts
// `arg` into it, and returns the third argument (-1) on failure — confirm
// against the macro definitions.
int InternalSetNonOneofScalar(
    Message* message,
    const FieldDescriptor* field_descriptor,
    PyObject* arg) {
  const Reflection* reflection = message->GetReflection();

  if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
    return -1;
  }

  switch (field_descriptor->cpp_type()) {
    case FieldDescriptor::CPPTYPE_INT32: {
      GOOGLE_CHECK_GET_INT32(arg, value, -1);
      reflection->SetInt32(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_INT64: {
      GOOGLE_CHECK_GET_INT64(arg, value, -1);
      reflection->SetInt64(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_UINT32: {
      GOOGLE_CHECK_GET_UINT32(arg, value, -1);
      reflection->SetUInt32(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_UINT64: {
      GOOGLE_CHECK_GET_UINT64(arg, value, -1);
      reflection->SetUInt64(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_FLOAT: {
      GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
      reflection->SetFloat(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_DOUBLE: {
      GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
      reflection->SetDouble(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_BOOL: {
      GOOGLE_CHECK_GET_BOOL(arg, value, -1);
      reflection->SetBool(message, field_descriptor, value);
      break;
    }
    case FieldDescriptor::CPPTYPE_STRING: {
      // String validation and assignment are delegated to CheckAndSetString().
      if (!CheckAndSetString(
          arg, message, field_descriptor, reflection, false, -1)) {
        return -1;
      }
      break;
    }
    case FieldDescriptor::CPPTYPE_ENUM: {
      GOOGLE_CHECK_GET_INT32(arg, value, -1);
      if (reflection->SupportsUnknownEnumValues()) {
        // The raw number is stored without looking it up in the enum.
        reflection->SetEnumValue(message, field_descriptor, value);
      } else {
        // The number must name a declared value of the enum, otherwise
        // ValueError is raised.
        const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
        const EnumValueDescriptor* enum_value =
            enum_descriptor->FindValueByNumber(value);
        if (enum_value != NULL) {
          reflection->SetEnum(message, field_descriptor, enum_value);
        } else {
          PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
          return -1;
        }
      }
      break;
    }
    default:
      // Any other C++ type is not handled by this function.
      PyErr_Format(
          PyExc_SystemError, "Setting value to a field of unknown type %d",
          field_descriptor->cpp_type());
      return -1;
  }

  return 0;
}
2344
InternalSetScalar(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * arg)2345 int InternalSetScalar(
2346 CMessage* self,
2347 const FieldDescriptor* field_descriptor,
2348 PyObject* arg) {
2349 if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
2350 return -1;
2351 }
2352
2353 if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
2354 return -1;
2355 }
2356
2357 return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
2358 }
2359
FromString(PyTypeObject * cls,PyObject * serialized)2360 PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
2361 PyObject* py_cmsg = PyObject_CallObject(
2362 reinterpret_cast<PyObject*>(cls), NULL);
2363 if (py_cmsg == NULL) {
2364 return NULL;
2365 }
2366 CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
2367
2368 ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
2369 if (py_length == NULL) {
2370 Py_DECREF(py_cmsg);
2371 return NULL;
2372 }
2373
2374 return py_cmsg;
2375 }
2376
DeepCopy(CMessage * self,PyObject * arg)2377 PyObject* DeepCopy(CMessage* self, PyObject* arg) {
2378 PyObject* clone = PyObject_CallObject(
2379 reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
2380 if (clone == NULL) {
2381 return NULL;
2382 }
2383 if (!PyObject_TypeCheck(clone, CMessage_Type)) {
2384 Py_DECREF(clone);
2385 return NULL;
2386 }
2387 if (ScopedPyObjectPtr(MergeFrom(
2388 reinterpret_cast<CMessage*>(clone),
2389 reinterpret_cast<PyObject*>(self))) == NULL) {
2390 Py_DECREF(clone);
2391 return NULL;
2392 }
2393 return clone;
2394 }
2395
ToUnicode(CMessage * self)2396 PyObject* ToUnicode(CMessage* self) {
2397 // Lazy import to prevent circular dependencies
2398 ScopedPyObjectPtr text_format(
2399 PyImport_ImportModule("google.protobuf.text_format"));
2400 if (text_format == NULL) {
2401 return NULL;
2402 }
2403 ScopedPyObjectPtr method_name(PyString_FromString("MessageToString"));
2404 if (method_name == NULL) {
2405 return NULL;
2406 }
2407 Py_INCREF(Py_True);
2408 ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
2409 text_format.get(), method_name.get(), self, Py_True, NULL));
2410 Py_DECREF(Py_True);
2411 if (encoded == NULL) {
2412 return NULL;
2413 }
2414 #if PY_MAJOR_VERSION < 3
2415 PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL);
2416 #else
2417 PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL);
2418 #endif
2419 if (decoded == NULL) {
2420 return NULL;
2421 }
2422 return decoded;
2423 }
2424
Reduce(CMessage * self)2425 PyObject* Reduce(CMessage* self) {
2426 ScopedPyObjectPtr constructor(reinterpret_cast<PyObject*>(Py_TYPE(self)));
2427 constructor.inc();
2428 ScopedPyObjectPtr args(PyTuple_New(0));
2429 if (args == NULL) {
2430 return NULL;
2431 }
2432 ScopedPyObjectPtr state(PyDict_New());
2433 if (state == NULL) {
2434 return NULL;
2435 }
2436 string contents;
2437 self->message->SerializePartialToString(&contents);
2438 ScopedPyObjectPtr serialized(
2439 PyBytes_FromStringAndSize(contents.c_str(), contents.size()));
2440 if (serialized == NULL) {
2441 return NULL;
2442 }
2443 if (PyDict_SetItemString(state.get(), "serialized", serialized.get()) < 0) {
2444 return NULL;
2445 }
2446 return Py_BuildValue("OOO", constructor.get(), args.get(), state.get());
2447 }
2448
SetState(CMessage * self,PyObject * state)2449 PyObject* SetState(CMessage* self, PyObject* state) {
2450 if (!PyDict_Check(state)) {
2451 PyErr_SetString(PyExc_TypeError, "state not a dict");
2452 return NULL;
2453 }
2454 PyObject* serialized = PyDict_GetItemString(state, "serialized");
2455 if (serialized == NULL) {
2456 return NULL;
2457 }
2458 #if PY_MAJOR_VERSION >= 3
2459 // On Python 3, using encoding='latin1' is required for unpickling
2460 // protos pickled by Python 2.
2461 if (!PyBytes_Check(serialized)) {
2462 serialized = PyUnicode_AsEncodedString(serialized, "latin1", NULL);
2463 }
2464 #endif
2465 if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) {
2466 return NULL;
2467 }
2468 Py_RETURN_NONE;
2469 }
2470
2471 // CMessage static methods:
_CheckCalledFromGeneratedFile(PyObject * unused,PyObject * unused_arg)2472 PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
2473 PyObject* unused_arg) {
2474 if (!_CalledFromGeneratedFile(1)) {
2475 PyErr_SetString(PyExc_TypeError,
2476 "Descriptors should not be created directly, "
2477 "but only retrieved from their parent.");
2478 return NULL;
2479 }
2480 Py_RETURN_NONE;
2481 }
2482
GetExtensionDict(CMessage * self,void * closure)2483 static PyObject* GetExtensionDict(CMessage* self, void *closure) {
2484 // If there are extension_ranges, the message is "extendable". Allocate a
2485 // dictionary to store the extension fields.
2486 const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
2487 if (!descriptor->extension_range_count()) {
2488 PyErr_SetNone(PyExc_AttributeError);
2489 return NULL;
2490 }
2491 if (!self->composite_fields) {
2492 self->composite_fields = new CMessage::CompositeFieldsMap();
2493 }
2494 if (!self->composite_fields) {
2495 return NULL;
2496 }
2497 ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
2498 return reinterpret_cast<PyObject*>(extension_dict);
2499 }
2500
UnknownFieldSet(CMessage * self)2501 static PyObject* UnknownFieldSet(CMessage* self) {
2502 if (self->unknown_field_set == NULL) {
2503 self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
2504 } else {
2505 Py_INCREF(self->unknown_field_set);
2506 }
2507 return self->unknown_field_set;
2508 }
2509
GetExtensionsByName(CMessage * self,void * closure)2510 static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
2511 return message_meta::GetExtensionsByName(
2512 reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2513 }
2514
GetExtensionsByNumber(CMessage * self,void * closure)2515 static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
2516 return message_meta::GetExtensionsByNumber(
2517 reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2518 }
2519
// Attribute getters exposed on message instances. No setter is supplied for
// any entry (the third member is NULL or omitted). The table ends with a
// {NULL} sentinel.
static PyGetSetDef Getters[] = {
  {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
  {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
  {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
  {NULL}
};
2526
2527
// Method table for message objects. Each entry gives the Python-visible
// name, the C function implementing it, the calling-convention flags
// (METH_NOARGS, METH_O, METH_VARARGS, optionally combined with METH_KEYWORDS,
// METH_CLASS or METH_STATIC) and the docstring. The table ends with a
// { NULL, NULL } sentinel.
static PyMethodDef Methods[] = {
  { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
    "Makes a deep copy of the class." },
  { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
    "Outputs picklable representation of the message." },
  { "__setstate__", (PyCFunction)SetState, METH_O,
    "Inputs picklable representation of the message." },
  { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
    "Outputs a unicode representation of the message." },
  { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
    "Returns the size of the message in bytes." },
  { "Clear", (PyCFunction)Clear, METH_NOARGS,
    "Clears the message." },
  { "ClearExtension", (PyCFunction)ClearExtension, METH_O,
    "Clears a message field." },
  { "ClearField", (PyCFunction)ClearField, METH_O,
    "Clears a message field." },
  { "CopyFrom", (PyCFunction)CopyFrom, METH_O,
    "Copies a protocol message into the current message." },
  { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
    "Discards the unknown fields." },
  { "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
    METH_NOARGS,
    "Finds unset required fields." },
  { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
    "Creates new method instance from given serialized data." },
  { "HasExtension", (PyCFunction)HasExtension, METH_O,
    "Checks if a message field is set." },
  { "HasField", (PyCFunction)HasField, METH_O,
    "Checks if a message field is set." },
  { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
    "Checks if all required fields of a protocol message are set." },
  { "ListFields", (PyCFunction)ListFields, METH_NOARGS,
    "Lists all set fields of a message." },
  { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
    "Merges a protocol message into the current message." },
  { "MergeFromString", (PyCFunction)MergeFromString, METH_O,
    "Merges a serialized message into the current message." },
  { "ParseFromString", (PyCFunction)ParseFromString, METH_O,
    "Parses a serialized message into the current message." },
  { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
    "Registers an extension with the current message." },
  // The two serializers accept positional and keyword arguments, which are
  // forwarded unchanged to InternalSerializeToString().
  { "SerializePartialToString", (PyCFunction)SerializePartialToString,
    METH_VARARGS | METH_KEYWORDS,
    "Serializes the message to a string, even if it isn't initialized." },
  { "SerializeToString", (PyCFunction)SerializeToString,
    METH_VARARGS | METH_KEYWORDS,
    "Serializes the message to a string, only for initialized messages." },
  { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
    "Sets the has bit of the given field in its parent message." },
  { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS,
    "Parse unknown field set"},
  { "WhichOneof", (PyCFunction)WhichOneof, METH_O,
    "Returns the name of the field set inside a oneof, "
    "or None if no field is set." },

  // Static Methods.
  { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
    METH_NOARGS | METH_STATIC,
    "Raises TypeError if the caller is not in a _pb2.py file."},
  { NULL, NULL}
};
2590
SetCompositeField(CMessage * self,const FieldDescriptor * field,ContainerBase * value)2591 bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
2592 ContainerBase* value) {
2593 if (self->composite_fields == NULL) {
2594 self->composite_fields = new CMessage::CompositeFieldsMap();
2595 }
2596 (*self->composite_fields)[field] = value;
2597 return true;
2598 }
2599
SetSubmessage(CMessage * self,CMessage * submessage)2600 bool SetSubmessage(CMessage* self, CMessage* submessage) {
2601 if (self->child_submessages == NULL) {
2602 self->child_submessages = new CMessage::SubMessagesMap();
2603 }
2604 (*self->child_submessages)[submessage->message] = submessage;
2605 return true;
2606 }
2607
GetAttr(PyObject * pself,PyObject * name)2608 PyObject* GetAttr(PyObject* pself, PyObject* name) {
2609 CMessage* self = reinterpret_cast<CMessage*>(pself);
2610 PyObject* result = PyObject_GenericGetAttr(
2611 reinterpret_cast<PyObject*>(self), name);
2612 if (result != NULL) {
2613 return result;
2614 }
2615 if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
2616 return NULL;
2617 }
2618
2619 PyErr_Clear();
2620 return message_meta::GetClassAttribute(
2621 CheckMessageClass(Py_TYPE(self)), name);
2622 }
2623
// Returns the Python value of field `field_descriptor` on `self`.
//
// Singular non-message fields are converted by InternalGetScalar(). Map,
// repeated and sub-message fields are represented by container objects that
// are remembered in self->composite_fields, so later lookups of the same
// field return the same object.
//
// Returns NULL on failure; TypeError is set when the descriptor does not
// belong to this message's type.
PyObject* GetFieldValue(CMessage* self,
                        const FieldDescriptor* field_descriptor) {
  // Fast path: a container for this field was already created.
  if (self->composite_fields) {
    CMessage::CompositeFieldsMap::iterator it =
        self->composite_fields->find(field_descriptor);
    if (it != self->composite_fields->end()) {
      ContainerBase* value = it->second;
      // The caller receives its own reference to the cached container.
      Py_INCREF(value);
      return value->AsPyObject();
    }
  }

  if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
    PyErr_Format(PyExc_TypeError,
                 "descriptor to field '%s' doesn't apply to '%s' object",
                 field_descriptor->full_name().c_str(),
                 Py_TYPE(self)->tp_name);
    return NULL;
  }

  // Singular scalars are not cached; they are converted on every access.
  if (!field_descriptor->is_repeated() &&
      field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
    return InternalGetScalar(self->message, field_descriptor);
  }

  ContainerBase* py_container = nullptr;
  if (field_descriptor->is_map()) {
    // The map's entry type carries a "value" field whose C++ type selects
    // between the message-valued and scalar-valued map containers.
    const Descriptor* entry_type = field_descriptor->message_type();
    const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
    if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      CMessageClass* value_class = message_factory::GetMessageClass(
          GetFactoryForMessage(self), value_type->message_type());
      if (value_class == NULL) {
        return NULL;
      }
      py_container =
          NewMessageMapContainer(self, field_descriptor, value_class);
    } else {
      py_container = NewScalarMapContainer(self, field_descriptor);
    }
  } else if (field_descriptor->is_repeated()) {
    if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      CMessageClass* message_class = message_factory::GetMessageClass(
          GetFactoryForMessage(self), field_descriptor->message_type());
      if (message_class == NULL) {
        return NULL;
      }
      py_container = repeated_composite_container::NewContainer(
          self, field_descriptor, message_class);
    } else {
      py_container =
          repeated_scalar_container::NewContainer(self, field_descriptor);
    }
  } else if (field_descriptor->cpp_type() ==
             FieldDescriptor::CPPTYPE_MESSAGE) {
    py_container = InternalGetSubMessage(self, field_descriptor);
  } else {
    // Singular non-message fields already returned above, so this branch is
    // not expected to run.
    PyErr_SetString(PyExc_SystemError, "Should never happen");
  }

  if (py_container == NULL) {
    return NULL;
  }
  // NOTE(review): the container is stored in the map and returned to the
  // caller without an extra INCREF — presumably the map entry is a non-owning
  // pointer; confirm against the container deallocators.
  if (!SetCompositeField(self, field_descriptor, py_container)) {
    Py_DECREF(py_container);
    return NULL;
  }
  return py_container->AsPyObject();
}
2693
SetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * value)2694 int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
2695 PyObject* value) {
2696 if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2697 PyErr_Format(PyExc_TypeError,
2698 "descriptor to field '%s' doesn't apply to '%s' object",
2699 field_descriptor->full_name().c_str(),
2700 Py_TYPE(self)->tp_name);
2701 return -1;
2702 } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
2703 PyErr_Format(PyExc_AttributeError,
2704 "Assignment not allowed to repeated "
2705 "field \"%s\" in protocol message object.",
2706 field_descriptor->name().c_str());
2707 return -1;
2708 } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2709 PyErr_Format(PyExc_AttributeError,
2710 "Assignment not allowed to "
2711 "field \"%s\" in protocol message object.",
2712 field_descriptor->name().c_str());
2713 return -1;
2714 } else {
2715 AssureWritable(self);
2716 return InternalSetScalar(self, field_descriptor, value);
2717 }
2718 }
2719
2720 } // namespace cmessage
2721
2722 // All containers which are not messages:
2723 // - Make a new parent message
2724 // - Copy the field
2725 // - return the field.
DeepCopy()2726 PyObject* ContainerBase::DeepCopy() {
2727 CMessage* new_parent =
2728 cmessage::NewEmptyMessage(this->parent->GetMessageClass());
2729 new_parent->message = this->parent->message->New();
2730
2731 // Copy the map field into the new message.
2732 this->parent->message->GetReflection()->SwapFields(
2733 this->parent->message, new_parent->message,
2734 {this->parent_field_descriptor});
2735 this->parent->message->MergeFrom(*new_parent->message);
2736
2737 PyObject* result =
2738 cmessage::GetFieldValue(new_parent, this->parent_field_descriptor);
2739 Py_DECREF(new_parent);
2740 return result;
2741 }
2742
RemoveFromParentCache()2743 void ContainerBase::RemoveFromParentCache() {
2744 CMessage* parent = this->parent;
2745 if (parent) {
2746 if (parent->composite_fields)
2747 parent->composite_fields->erase(this->parent_field_descriptor);
2748 Py_CLEAR(parent);
2749 }
2750 }
2751
BuildSubMessageFromPointer(const FieldDescriptor * field_descriptor,Message * sub_message,CMessageClass * message_class)2752 CMessage* CMessage::BuildSubMessageFromPointer(
2753 const FieldDescriptor* field_descriptor, Message* sub_message,
2754 CMessageClass* message_class) {
2755 if (!this->child_submessages) {
2756 this->child_submessages = new CMessage::SubMessagesMap();
2757 }
2758 CMessage* cmsg = FindPtrOrNull(
2759 *this->child_submessages, sub_message);
2760 if (cmsg) {
2761 Py_INCREF(cmsg);
2762 } else {
2763 cmsg = cmessage::NewEmptyMessage(message_class);
2764
2765 if (cmsg == NULL) {
2766 return NULL;
2767 }
2768 cmsg->message = sub_message;
2769 Py_INCREF(this);
2770 cmsg->parent = this;
2771 cmsg->parent_field_descriptor = field_descriptor;
2772 cmessage::SetSubmessage(this, cmsg);
2773 }
2774 return cmsg;
2775 }
2776
MaybeReleaseSubMessage(Message * sub_message)2777 CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
2778 if (!this->child_submessages) {
2779 return nullptr;
2780 }
2781 CMessage* released = FindPtrOrNull(
2782 *this->child_submessages, sub_message);
2783 if (!released) {
2784 return nullptr;
2785 }
2786 // The target message will now own its content.
2787 Py_CLEAR(released->parent);
2788 released->parent_field_descriptor = nullptr;
2789 released->read_only = false;
2790 // Delete it from the cache.
2791 this->child_submessages->erase(sub_message);
2792 return released;
2793 }
2794
2795 static CMessageClass _CMessage_Type = { { {
2796 PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0)
2797 FULL_MODULE_NAME ".CMessage", // tp_name
2798 sizeof(CMessage), // tp_basicsize
2799 0, // tp_itemsize
2800 (destructor)cmessage::Dealloc, // tp_dealloc
2801 0, // tp_print
2802 0, // tp_getattr
2803 0, // tp_setattr
2804 0, // tp_compare
2805 (reprfunc)cmessage::ToStr, // tp_repr
2806 0, // tp_as_number
2807 0, // tp_as_sequence
2808 0, // tp_as_mapping
2809 PyObject_HashNotImplemented, // tp_hash
2810 0, // tp_call
2811 (reprfunc)cmessage::ToStr, // tp_str
2812 cmessage::GetAttr, // tp_getattro
2813 0, // tp_setattro
2814 0, // tp_as_buffer
2815 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
2816 | Py_TPFLAGS_HAVE_VERSION_TAG, // tp_flags
2817 "A ProtocolMessage", // tp_doc
2818 0, // tp_traverse
2819 0, // tp_clear
2820 (richcmpfunc)cmessage::RichCompare, // tp_richcompare
2821 offsetof(CMessage, weakreflist), // tp_weaklistoffset
2822 0, // tp_iter
2823 0, // tp_iternext
2824 cmessage::Methods, // tp_methods
2825 0, // tp_members
2826 cmessage::Getters, // tp_getset
2827 0, // tp_base
2828 0, // tp_dict
2829 0, // tp_descr_get
2830 0, // tp_descr_set
2831 0, // tp_dictoffset
2832 (initproc)cmessage::Init, // tp_init
2833 0, // tp_alloc
2834 cmessage::New, // tp_new
2835 } } };
2836 PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
2837
2838 // --- Exposing the C proto living inside Python proto to C code:
2839
2840 const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
2841 Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
2842
GetCProtoInsidePyProtoImpl(PyObject * msg)2843 static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
2844 const Message* message = PyMessage_GetMessagePointer(msg);
2845 if (message == NULL) {
2846 PyErr_Clear();
2847 return NULL;
2848 }
2849 return message;
2850 }
2851
MutableCProtoInsidePyProtoImpl(PyObject * msg)2852 static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
2853 Message* message = PyMessage_GetMutableMessagePointer(msg);
2854 if (message == NULL) {
2855 PyErr_Clear();
2856 return NULL;
2857 }
2858 return message;
2859 }
2860
PyMessage_GetMessagePointer(PyObject * msg)2861 const Message* PyMessage_GetMessagePointer(PyObject* msg) {
2862 if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2863 PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2864 return NULL;
2865 }
2866 CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2867 return cmsg->message;
2868 }
2869
PyMessage_GetMutableMessagePointer(PyObject * msg)2870 Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
2871 if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2872 PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2873 return NULL;
2874 }
2875 CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2876
2877
2878 if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) ||
2879 (cmsg->child_submessages && !cmsg->child_submessages->empty())) {
2880 // There is currently no way of accurately syncing arbitrary changes to
2881 // the underlying C++ message back to the CMessage (e.g. removed repeated
2882 // composite containers). We only allow direct mutation of the underlying
2883 // C++ message if there is no child data in the CMessage.
2884 PyErr_SetString(PyExc_ValueError,
2885 "Cannot reliably get a mutable pointer "
2886 "to a message with extra references");
2887 return NULL;
2888 }
2889 cmessage::AssureWritable(cmsg);
2890 return cmsg->message;
2891 }
2892
PyMessage_NewMessageOwnedExternally(Message * message,PyObject * message_factory)2893 PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
2894 PyObject* message_factory) {
2895 if (message_factory) {
2896 PyErr_SetString(PyExc_NotImplementedError,
2897 "Default message_factory=NULL is the only supported value");
2898 return NULL;
2899 }
2900 if (message->GetReflection()->GetMessageFactory() !=
2901 MessageFactory::generated_factory()) {
2902 PyErr_SetString(PyExc_TypeError,
2903 "Message pointer was not created from the default factory");
2904 return NULL;
2905 }
2906
2907 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2908 GetDefaultDescriptorPool()->py_message_factory, message->GetDescriptor());
2909
2910 CMessage* self = cmessage::NewEmptyMessage(message_class);
2911 if (self == NULL) {
2912 return NULL;
2913 }
2914 Py_DECREF(message_class);
2915 self->message = message;
2916 Py_INCREF(Py_None);
2917 self->parent = reinterpret_cast<CMessage*>(Py_None);
2918 return self->AsPyObject();
2919 }
2920
InitGlobals()2921 void InitGlobals() {
2922 // TODO(gps): Check all return values in this function for NULL and propagate
2923 // the error (MemoryError) on up to result in an import failure. These should
2924 // also be freed and reset to NULL during finalization.
2925 kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
2926
2927 PyObject *dummy_obj = PySet_New(NULL);
2928 kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
2929 Py_DECREF(dummy_obj);
2930 }
2931
InitProto2MessageModule(PyObject * m)2932 bool InitProto2MessageModule(PyObject *m) {
2933 // Initialize types and globals in descriptor.cc
2934 if (!InitDescriptor()) {
2935 return false;
2936 }
2937
2938 // Initialize types and globals in descriptor_pool.cc
2939 if (!InitDescriptorPool()) {
2940 return false;
2941 }
2942
2943 // Initialize types and globals in message_factory.cc
2944 if (!InitMessageFactory()) {
2945 return false;
2946 }
2947
2948 // Initialize constants defined in this file.
2949 InitGlobals();
2950
2951 CMessageClass_Type->tp_base = &PyType_Type;
2952 if (PyType_Ready(CMessageClass_Type) < 0) {
2953 return false;
2954 }
2955 PyModule_AddObject(m, "MessageMeta",
2956 reinterpret_cast<PyObject*>(CMessageClass_Type));
2957
2958 if (PyType_Ready(CMessage_Type) < 0) {
2959 return false;
2960 }
2961 if (PyType_Ready(CFieldProperty_Type) < 0) {
2962 return false;
2963 }
2964
2965 // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
2966 // it here as well to document that subclasses need to set it.
2967 PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
2968 // Invalidate any cached data for the CMessage type.
2969 // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
2970 // after we have modified CMessage_Type.tp_dict.
2971 PyType_Modified(CMessage_Type);
2972
2973 PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
2974
2975 // Initialize Repeated container types.
2976 {
2977 if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
2978 return false;
2979 }
2980
2981 PyModule_AddObject(m, "RepeatedScalarContainer",
2982 reinterpret_cast<PyObject*>(
2983 &RepeatedScalarContainer_Type));
2984
2985 if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
2986 return false;
2987 }
2988
2989 PyModule_AddObject(
2990 m, "RepeatedCompositeContainer",
2991 reinterpret_cast<PyObject*>(
2992 &RepeatedCompositeContainer_Type));
2993
2994 // Register them as collections.Sequence
2995 ScopedPyObjectPtr collections(PyImport_ImportModule("collections"));
2996 if (collections == NULL) {
2997 return false;
2998 }
2999 ScopedPyObjectPtr mutable_sequence(
3000 PyObject_GetAttrString(collections.get(), "MutableSequence"));
3001 if (mutable_sequence == NULL) {
3002 return false;
3003 }
3004 if (ScopedPyObjectPtr(
3005 PyObject_CallMethod(mutable_sequence.get(), "register", "O",
3006 &RepeatedScalarContainer_Type)) == NULL) {
3007 return false;
3008 }
3009 if (ScopedPyObjectPtr(
3010 PyObject_CallMethod(mutable_sequence.get(), "register", "O",
3011 &RepeatedCompositeContainer_Type)) == NULL) {
3012 return false;
3013 }
3014 }
3015
3016 if (PyType_Ready(&PyUnknownFields_Type) < 0) {
3017 return false;
3018 }
3019
3020 PyModule_AddObject(m, "UnknownFieldSet",
3021 reinterpret_cast<PyObject*>(
3022 &PyUnknownFields_Type));
3023
3024 if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
3025 return false;
3026 }
3027
3028 PyModule_AddObject(m, "UnknownField",
3029 reinterpret_cast<PyObject*>(
3030 &PyUnknownFieldRef_Type));
3031
3032 // Initialize Map container types.
3033 if (!InitMapContainers()) {
3034 return false;
3035 }
3036 PyModule_AddObject(m, "ScalarMapContainer",
3037 reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
3038 PyModule_AddObject(m, "MessageMapContainer",
3039 reinterpret_cast<PyObject*>(MessageMapContainer_Type));
3040 PyModule_AddObject(m, "MapIterator",
3041 reinterpret_cast<PyObject*>(&MapIterator_Type));
3042
3043 if (PyType_Ready(&ExtensionDict_Type) < 0) {
3044 return false;
3045 }
3046 PyModule_AddObject(
3047 m, "ExtensionDict",
3048 reinterpret_cast<PyObject*>(&ExtensionDict_Type));
3049 if (PyType_Ready(&ExtensionIterator_Type) < 0) {
3050 return false;
3051 }
3052 PyModule_AddObject(m, "ExtensionIterator",
3053 reinterpret_cast<PyObject*>(&ExtensionIterator_Type));
3054
3055 // Expose the DescriptorPool used to hold all descriptors added from generated
3056 // pb2.py files.
3057 // PyModule_AddObject steals a reference.
3058 Py_INCREF(GetDefaultDescriptorPool());
3059 PyModule_AddObject(m, "default_pool",
3060 reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
3061
3062 PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
3063 &PyDescriptorPool_Type));
3064
3065 // This implementation provides full Descriptor types, we advertise it so that
3066 // descriptor.py can use them in replacement of the Python classes.
3067 PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1);
3068
3069 PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>(
3070 &PyMessageDescriptor_Type));
3071 PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>(
3072 &PyFieldDescriptor_Type));
3073 PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>(
3074 &PyEnumDescriptor_Type));
3075 PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>(
3076 &PyEnumValueDescriptor_Type));
3077 PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>(
3078 &PyFileDescriptor_Type));
3079 PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
3080 &PyOneofDescriptor_Type));
3081 PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
3082 &PyServiceDescriptor_Type));
3083 PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
3084 &PyMethodDescriptor_Type));
3085
3086 PyObject* enum_type_wrapper = PyImport_ImportModule(
3087 "google.protobuf.internal.enum_type_wrapper");
3088 if (enum_type_wrapper == NULL) {
3089 return false;
3090 }
3091 EnumTypeWrapper_class =
3092 PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
3093 Py_DECREF(enum_type_wrapper);
3094
3095 PyObject* message_module = PyImport_ImportModule(
3096 "google.protobuf.message");
3097 if (message_module == NULL) {
3098 return false;
3099 }
3100 EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
3101 DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
3102 PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
3103 Py_DECREF(message_module);
3104
3105 PyObject* pickle_module = PyImport_ImportModule("pickle");
3106 if (pickle_module == NULL) {
3107 return false;
3108 }
3109 PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
3110 Py_DECREF(pickle_module);
3111
3112 // Override {Get,Mutable}CProtoInsidePyProto.
3113 GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
3114 MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
3115
3116 return true;
3117 }
3118
3119 } // namespace python
3120 } // namespace protobuf
3121 } // namespace google
3122