1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33
34 #include <google/protobuf/pyext/message.h>
35
36 #include <structmember.h> // A Python header file.
37
38 #include <map>
39 #include <memory>
40 #include <string>
41 #include <vector>
42
43 #include <google/protobuf/stubs/strutil.h>
44
45 #ifndef PyVarObject_HEAD_INIT
46 #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
47 #endif
48 #ifndef Py_TYPE
49 #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
50 #endif
51 #include <google/protobuf/stubs/common.h>
52 #include <google/protobuf/stubs/logging.h>
53 #include <google/protobuf/io/coded_stream.h>
54 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
55 #include <google/protobuf/descriptor.pb.h>
56 #include <google/protobuf/descriptor.h>
57 #include <google/protobuf/message.h>
58 #include <google/protobuf/text_format.h>
59 #include <google/protobuf/unknown_field_set.h>
60 #include <google/protobuf/pyext/descriptor.h>
61 #include <google/protobuf/pyext/descriptor_pool.h>
62 #include <google/protobuf/pyext/extension_dict.h>
63 #include <google/protobuf/pyext/field.h>
64 #include <google/protobuf/pyext/map_container.h>
65 #include <google/protobuf/pyext/message_factory.h>
66 #include <google/protobuf/pyext/repeated_composite_container.h>
67 #include <google/protobuf/pyext/repeated_scalar_container.h>
68 #include <google/protobuf/pyext/safe_numerics.h>
69 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
70 #include <google/protobuf/pyext/unknown_fields.h>
71 #include <google/protobuf/util/message_differencer.h>
72 #include <google/protobuf/io/strtod.h>
73 #include <google/protobuf/stubs/map_util.h>
74
75 // clang-format off
76 #include <google/protobuf/port_def.inc>
77 // clang-format on
78
79 #if PY_MAJOR_VERSION >= 3
80 #define PyInt_AsLong PyLong_AsLong
81 #define PyInt_FromLong PyLong_FromLong
82 #define PyInt_FromSize_t PyLong_FromSize_t
83 #define PyString_Check PyUnicode_Check
84 #define PyString_FromString PyUnicode_FromString
85 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
86 #define PyString_FromFormat PyUnicode_FromFormat
87 #if PY_VERSION_HEX < 0x03030000
88 #error "Python 3.0 - 3.2 are not supported."
89 #else
90 #define PyString_AsString(ob) \
91 (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
92 #define PyString_AsStringAndSize(ob, charpp, sizep) \
93 (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
94 PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
95 ? -1 \
96 : 0) \
97 : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
98 #endif
99 #endif
100
101 namespace google {
102 namespace protobuf {
103 namespace python {
104
105 static PyObject* kDESCRIPTOR;
106 PyObject* EnumTypeWrapper_class;
107 static PyObject* PythonMessage_class;
108 static PyObject* kEmptyWeakref;
109 static PyObject* WKT_classes = NULL;
110
111 namespace message_meta {
112
113 static int InsertEmptyWeakref(PyTypeObject* base);
114
115 namespace {
// Copied over from internal 'google/protobuf/stubs/strutil.h'.
// Lowercases the ASCII letters of *s in place; all other bytes are untouched.
inline void LowerString(std::string* s) {
  for (char& ch : *s) {
    // tolower() changes based on locale. We don't want this!
    if (ch >= 'A' && ch <= 'Z') {
      ch += 'a' - 'A';
    }
  }
}
124 }
125
126 // Finalize the creation of the Message class.
AddDescriptors(PyObject * cls,const Descriptor * descriptor)127 static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
128 // For each field set: cls.<field>_FIELD_NUMBER = <number>
129 for (int i = 0; i < descriptor->field_count(); ++i) {
130 const FieldDescriptor* field_descriptor = descriptor->field(i);
131 ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
132 if (property == NULL) {
133 return -1;
134 }
135 if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(),
136 property.get()) < 0) {
137 return -1;
138 }
139 }
140
141 // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
142 for (int i = 0; i < descriptor->enum_type_count(); ++i) {
143 const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
144 ScopedPyObjectPtr enum_type(
145 PyEnumDescriptor_FromDescriptor(enum_descriptor));
146 if (enum_type == NULL) {
147 return -1;
148 }
149 // Add wrapped enum type to message class.
150 ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
151 EnumTypeWrapper_class, enum_type.get(), NULL));
152 if (wrapped == NULL) {
153 return -1;
154 }
155 if (PyObject_SetAttrString(
156 cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) {
157 return -1;
158 }
159
160 // For each enum value add cls.<name> = <number>
161 for (int j = 0; j < enum_descriptor->value_count(); ++j) {
162 const EnumValueDescriptor* enum_value_descriptor =
163 enum_descriptor->value(j);
164 ScopedPyObjectPtr value_number(PyInt_FromLong(
165 enum_value_descriptor->number()));
166 if (value_number == NULL) {
167 return -1;
168 }
169 if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(),
170 value_number.get()) == -1) {
171 return -1;
172 }
173 }
174 }
175
176 // For each extension set cls.<extension name> = <extension descriptor>.
177 //
178 // Extension descriptors come from
179 // <message descriptor>.extensions_by_name[name]
180 // which was defined previously.
181 for (int i = 0; i < descriptor->extension_count(); ++i) {
182 const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
183 ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
184 if (extension_field == NULL) {
185 return -1;
186 }
187
188 // Add the extension field to the message class.
189 if (PyObject_SetAttrString(
190 cls, field->name().c_str(), extension_field.get()) == -1) {
191 return -1;
192 }
193 }
194
195 return 0;
196 }
197
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)198 static PyObject* New(PyTypeObject* type,
199 PyObject* args, PyObject* kwargs) {
200 static char *kwlist[] = {"name", "bases", "dict", 0};
201 PyObject *bases, *dict;
202 const char* name;
203
204 // Check arguments: (name, bases, dict)
205 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist,
206 &name,
207 &PyTuple_Type, &bases,
208 &PyDict_Type, &dict)) {
209 return NULL;
210 }
211
212 // Check bases: only (), or (message.Message,) are allowed
213 if (!(PyTuple_GET_SIZE(bases) == 0 ||
214 (PyTuple_GET_SIZE(bases) == 1 &&
215 PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
216 PyErr_SetString(PyExc_TypeError,
217 "A Message class can only inherit from Message");
218 return NULL;
219 }
220
221 // Check dict['DESCRIPTOR']
222 PyObject* descriptor_or_name = PyDict_GetItem(dict, kDESCRIPTOR);
223 if (descriptor_or_name == nullptr) {
224 PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
225 return NULL;
226 }
227
228 Py_ssize_t name_size;
229 char* full_name;
230 const Descriptor* message_descriptor;
231 PyObject* py_descriptor;
232
233 if (PyObject_TypeCheck(descriptor_or_name, &PyMessageDescriptor_Type)) {
234 py_descriptor = descriptor_or_name;
235 message_descriptor = PyMessageDescriptor_AsDescriptor(py_descriptor);
236 if (message_descriptor == nullptr) {
237 return nullptr;
238 }
239 } else {
240 if (PyString_AsStringAndSize(descriptor_or_name, &full_name, &name_size) <
241 0) {
242 return nullptr;
243 }
244 message_descriptor =
245 GetDefaultDescriptorPool()->pool->FindMessageTypeByName(
246 StringParam(full_name, name_size));
247 if (message_descriptor == nullptr) {
248 PyErr_Format(PyExc_KeyError,
249 "Can not find message descriptor %s "
250 "from pool",
251 full_name);
252 return nullptr;
253 }
254 py_descriptor = PyMessageDescriptor_FromDescriptor(message_descriptor);
255 // reset the dict['DESCRIPTOR'] to py_descriptor.
256 PyDict_SetItem(dict, kDESCRIPTOR, py_descriptor);
257 }
258
259 // Messages have no __dict__
260 ScopedPyObjectPtr slots(PyTuple_New(0));
261 if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
262 return NULL;
263 }
264
265 // Build the arguments to the base metaclass.
266 // We change the __bases__ classes.
267 ScopedPyObjectPtr new_args;
268
269 if (WKT_classes == NULL) {
270 ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
271 "google.protobuf.internal.well_known_types"));
272 GOOGLE_DCHECK(well_known_types != NULL);
273
274 WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
275 GOOGLE_DCHECK(WKT_classes != NULL);
276 }
277
278 PyObject* well_known_class = PyDict_GetItemString(
279 WKT_classes, message_descriptor->full_name().c_str());
280 if (well_known_class == NULL) {
281 new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
282 PythonMessage_class, dict));
283 } else {
284 new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
285 PythonMessage_class, well_known_class, dict));
286 }
287
288 if (new_args == NULL) {
289 return NULL;
290 }
291 // Call the base metaclass.
292 ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL));
293 if (result == NULL) {
294 return NULL;
295 }
296 CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
297
298 // Insert the empty weakref into the base classes.
299 if (InsertEmptyWeakref(
300 reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
301 InsertEmptyWeakref(CMessage_Type) < 0) {
302 return NULL;
303 }
304
305 // Cache the descriptor, both as Python object and as C++ pointer.
306 const Descriptor* descriptor =
307 PyMessageDescriptor_AsDescriptor(py_descriptor);
308 if (descriptor == NULL) {
309 return NULL;
310 }
311 Py_INCREF(py_descriptor);
312 newtype->py_message_descriptor = py_descriptor;
313 newtype->message_descriptor = descriptor;
314 // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
315 // use the MessageFactory optionally passed in the class dict.
316 PyDescriptorPool* py_descriptor_pool =
317 GetDescriptorPool_FromPool(descriptor->file()->pool());
318 if (py_descriptor_pool == NULL) {
319 return NULL;
320 }
321 newtype->py_message_factory = py_descriptor_pool->py_message_factory;
322 Py_INCREF(newtype->py_message_factory);
323
324 // Register the message in the MessageFactory.
325 // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
326 // MessageFactory is fully implemented in C++.
327 if (message_factory::RegisterMessageClass(newtype->py_message_factory,
328 descriptor, newtype) < 0) {
329 return NULL;
330 }
331
332 // Continue with type initialization: add other descriptors, enum values...
333 if (AddDescriptors(result.get(), descriptor) < 0) {
334 return NULL;
335 }
336 return result.release();
337 }
338
Dealloc(PyObject * pself)339 static void Dealloc(PyObject* pself) {
340 CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
341 Py_XDECREF(self->py_message_descriptor);
342 Py_XDECREF(self->py_message_factory);
343 return PyType_Type.tp_dealloc(pself);
344 }
345
GcTraverse(PyObject * pself,visitproc visit,void * arg)346 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
347 CMessageClass* self = reinterpret_cast<CMessageClass*>(pself);
348 Py_VISIT(self->py_message_descriptor);
349 Py_VISIT(self->py_message_factory);
350 return PyType_Type.tp_traverse(pself, visit, arg);
351 }
352
GcClear(PyObject * pself)353 static int GcClear(PyObject* pself) {
354 // It's important to keep the descriptor and factory alive, until the
355 // C++ message is fully destructed.
356 return PyType_Type.tp_clear(pself);
357 }
358
359 // This function inserts and empty weakref at the end of the list of
360 // subclasses for the main protocol buffer Message class.
361 //
362 // This eliminates a O(n^2) behaviour in the internal add_subclass
363 // routine.
InsertEmptyWeakref(PyTypeObject * base_type)364 static int InsertEmptyWeakref(PyTypeObject *base_type) {
365 #if PY_MAJOR_VERSION >= 3
366 // Python 3.4 has already included the fix for the issue that this
367 // hack addresses. For further background and the fix please see
368 // https://bugs.python.org/issue17936.
369 return 0;
370 #else
371 #ifdef Py_DEBUG
372 // The code below causes all new subclasses to append an entry, which is never
373 // cleared. This is a small memory leak, which we disable in Py_DEBUG mode
374 // to have stable refcounting checks.
375 #else
376 PyObject *subclasses = base_type->tp_subclasses;
377 if (subclasses && PyList_CheckExact(subclasses)) {
378 return PyList_Append(subclasses, kEmptyWeakref);
379 }
380 #endif // !Py_DEBUG
381 return 0;
382 #endif // PY_MAJOR_VERSION >= 3
383 }
384
385 // The _extensions_by_name dictionary is built on every access.
386 // TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
GetExtensionsByName(CMessageClass * self,void * closure)387 static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
388 if (self->message_descriptor == NULL) {
389 // This is the base Message object, simply raise AttributeError.
390 PyErr_SetString(PyExc_AttributeError,
391 "Base Message class has no DESCRIPTOR");
392 return NULL;
393 }
394
395 const PyDescriptorPool* pool = self->py_message_factory->pool;
396
397 std::vector<const FieldDescriptor*> extensions;
398 pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
399
400 ScopedPyObjectPtr result(PyDict_New());
401 for (int i = 0; i < extensions.size(); i++) {
402 ScopedPyObjectPtr extension(
403 PyFieldDescriptor_FromDescriptor(extensions[i]));
404 if (extension == NULL) {
405 return NULL;
406 }
407 if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
408 extension.get()) < 0) {
409 return NULL;
410 }
411 }
412 return result.release();
413 }
414
415 // The _extensions_by_number dictionary is built on every access.
416 // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
GetExtensionsByNumber(CMessageClass * self,void * closure)417 static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
418 if (self->message_descriptor == NULL) {
419 // This is the base Message object, simply raise AttributeError.
420 PyErr_SetString(PyExc_AttributeError,
421 "Base Message class has no DESCRIPTOR");
422 return NULL;
423 }
424
425 const PyDescriptorPool* pool = self->py_message_factory->pool;
426
427 std::vector<const FieldDescriptor*> extensions;
428 pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
429
430 ScopedPyObjectPtr result(PyDict_New());
431 for (int i = 0; i < extensions.size(); i++) {
432 ScopedPyObjectPtr extension(
433 PyFieldDescriptor_FromDescriptor(extensions[i]));
434 if (extension == NULL) {
435 return NULL;
436 }
437 ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
438 if (number == NULL) {
439 return NULL;
440 }
441 if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
442 return NULL;
443 }
444 }
445 return result.release();
446 }
447
448 static PyGetSetDef Getters[] = {
449 {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
450 {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
451 {NULL}
452 };
453
454 // Compute some class attributes on the fly:
455 // - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
456 // Returns a new reference, or NULL with an exception set.
GetClassAttribute(CMessageClass * self,PyObject * name)457 static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
458 char* attr;
459 Py_ssize_t attr_size;
460 static const char kSuffix[] = "_FIELD_NUMBER";
461 if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
462 HasSuffixString(StringPiece(attr, attr_size), kSuffix)) {
463 std::string field_name(attr, attr_size - sizeof(kSuffix) + 1);
464 LowerString(&field_name);
465
466 // Try to find a field with the given name, without the suffix.
467 const FieldDescriptor* field =
468 self->message_descriptor->FindFieldByLowercaseName(field_name);
469 if (!field) {
470 // Search nested extensions as well.
471 field =
472 self->message_descriptor->FindExtensionByLowercaseName(field_name);
473 }
474 if (field) {
475 return PyInt_FromLong(field->number());
476 }
477 }
478 PyErr_SetObject(PyExc_AttributeError, name);
479 return NULL;
480 }
481
GetAttr(CMessageClass * self,PyObject * name)482 static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
483 PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
484 reinterpret_cast<PyObject*>(self), name);
485 if (result != NULL) {
486 return result;
487 }
488 if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
489 return NULL;
490 }
491
492 PyErr_Clear();
493 return GetClassAttribute(self, name);
494 }
495
496 } // namespace message_meta
497
498 static PyTypeObject _CMessageClass_Type = {
499 PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
500 ".MessageMeta", // tp_name
501 sizeof(CMessageClass), // tp_basicsize
502 0, // tp_itemsize
503 message_meta::Dealloc, // tp_dealloc
504 0, // tp_print
505 0, // tp_getattr
506 0, // tp_setattr
507 0, // tp_compare
508 0, // tp_repr
509 0, // tp_as_number
510 0, // tp_as_sequence
511 0, // tp_as_mapping
512 0, // tp_hash
513 0, // tp_call
514 0, // tp_str
515 (getattrofunc)message_meta::GetAttr, // tp_getattro
516 0, // tp_setattro
517 0, // tp_as_buffer
518 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, // tp_flags
519 "The metaclass of ProtocolMessages", // tp_doc
520 message_meta::GcTraverse, // tp_traverse
521 message_meta::GcClear, // tp_clear
522 0, // tp_richcompare
523 0, // tp_weaklistoffset
524 0, // tp_iter
525 0, // tp_iternext
526 0, // tp_methods
527 0, // tp_members
528 message_meta::Getters, // tp_getset
529 0, // tp_base
530 0, // tp_dict
531 0, // tp_descr_get
532 0, // tp_descr_set
533 0, // tp_dictoffset
534 0, // tp_init
535 0, // tp_alloc
536 message_meta::New, // tp_new
537 };
538 PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
539
CheckMessageClass(PyTypeObject * cls)540 static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
541 if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
542 PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
543 return NULL;
544 }
545 return reinterpret_cast<CMessageClass*>(cls);
546 }
547
GetMessageDescriptor(PyTypeObject * cls)548 static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
549 CMessageClass* type = CheckMessageClass(cls);
550 if (type == NULL) {
551 return NULL;
552 }
553 return type->message_descriptor;
554 }
555
556 // Forward declarations
557 namespace cmessage {
558 int InternalReleaseFieldByDescriptor(
559 CMessage* self,
560 const FieldDescriptor* field_descriptor);
561 } // namespace cmessage
562
563 // ---------------------------------------------------------------------
564
565 PyObject* EncodeError_class;
566 PyObject* DecodeError_class;
567 PyObject* PickleError_class;
568
569 // Format an error message for unexpected types.
570 // Always return with an exception set.
FormatTypeError(PyObject * arg,char * expected_types)571 void FormatTypeError(PyObject* arg, char* expected_types) {
572 // This function is often called with an exception set.
573 // Clear it to call PyObject_Repr() in good conditions.
574 PyErr_Clear();
575 PyObject* repr = PyObject_Repr(arg);
576 if (repr) {
577 PyErr_Format(PyExc_TypeError,
578 "%.100s has type %.100s, but expected one of: %s",
579 PyString_AsString(repr),
580 Py_TYPE(arg)->tp_name,
581 expected_types);
582 Py_DECREF(repr);
583 }
584 }
585
OutOfRangeError(PyObject * arg)586 void OutOfRangeError(PyObject* arg) {
587 PyObject *s = PyObject_Str(arg);
588 if (s) {
589 PyErr_Format(PyExc_ValueError,
590 "Value out of range: %s",
591 PyString_AsString(s));
592 Py_DECREF(s);
593 }
594 }
595
596 template<class RangeType, class ValueType>
VerifyIntegerCastAndRange(PyObject * arg,ValueType value)597 bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
598 if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
599 if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
600 // Replace it with the same ValueError as pure python protos instead of
601 // the default one.
602 PyErr_Clear();
603 OutOfRangeError(arg);
604 } // Otherwise propagate existing error.
605 return false;
606 }
607 if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
608 OutOfRangeError(arg);
609 return false;
610 }
611 return true;
612 }
613
614 template<class T>
CheckAndGetInteger(PyObject * arg,T * value)615 bool CheckAndGetInteger(PyObject* arg, T* value) {
616 // The fast path.
617 #if PY_MAJOR_VERSION < 3
618 // For the typical case, offer a fast path.
619 if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) {
620 long int_result = PyInt_AsLong(arg);
621 if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
622 *value = static_cast<T>(int_result);
623 return true;
624 } else {
625 OutOfRangeError(arg);
626 return false;
627 }
628 }
629 #endif
630 // This effectively defines an integer as "an object that can be cast as
631 // an integer and can be used as an ordinal number".
632 // This definition includes everything that implements numbers.Integral
633 // and shouldn't cast the net too wide.
634 if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) {
635 FormatTypeError(arg, "int, long");
636 return false;
637 }
638
639 // Now we have an integral number so we can safely use PyLong_ functions.
640 // We need to treat the signed and unsigned cases differently in case arg is
641 // holding a value above the maximum for signed longs.
642 if (std::numeric_limits<T>::min() == 0) {
643 // Unsigned case.
644 unsigned PY_LONG_LONG ulong_result;
645 if (PyLong_Check(arg)) {
646 ulong_result = PyLong_AsUnsignedLongLong(arg);
647 } else {
648 // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
649 // picky about the exact type.
650 PyObject* casted = PyNumber_Long(arg);
651 if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
652 // Propagate existing error.
653 return false;
654 }
655 ulong_result = PyLong_AsUnsignedLongLong(casted);
656 Py_DECREF(casted);
657 }
658 if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
659 ulong_result)) {
660 *value = static_cast<T>(ulong_result);
661 } else {
662 return false;
663 }
664 } else {
665 // Signed case.
666 PY_LONG_LONG long_result;
667 PyNumberMethods *nb;
668 if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
669 // PyLong_AsLongLong requires it to be a long or to have an __int__()
670 // method.
671 long_result = PyLong_AsLongLong(arg);
672 } else {
673 // Valid subclasses of numbers.Integral should have a __long__() method
674 // so fall back to that.
675 PyObject* casted = PyNumber_Long(arg);
676 if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) {
677 // Propagate existing error.
678 return false;
679 }
680 long_result = PyLong_AsLongLong(casted);
681 Py_DECREF(casted);
682 }
683 if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
684 *value = static_cast<T>(long_result);
685 } else {
686 return false;
687 }
688 }
689
690 return true;
691 }
692
693 // These are referenced by repeated_scalar_container, and must
694 // be explicitly instantiated.
695 template bool CheckAndGetInteger<int32>(PyObject*, int32*);
696 template bool CheckAndGetInteger<int64>(PyObject*, int64*);
697 template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
698 template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
699
CheckAndGetDouble(PyObject * arg,double * value)700 bool CheckAndGetDouble(PyObject* arg, double* value) {
701 *value = PyFloat_AsDouble(arg);
702 if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
703 FormatTypeError(arg, "int, long, float");
704 return false;
705 }
706 return true;
707 }
708
CheckAndGetFloat(PyObject * arg,float * value)709 bool CheckAndGetFloat(PyObject* arg, float* value) {
710 double double_value;
711 if (!CheckAndGetDouble(arg, &double_value)) {
712 return false;
713 }
714 *value = io::SafeDoubleToFloat(double_value);
715 return true;
716 }
717
CheckAndGetBool(PyObject * arg,bool * value)718 bool CheckAndGetBool(PyObject* arg, bool* value) {
719 long long_value = PyInt_AsLong(arg);
720 if (long_value == -1 && PyErr_Occurred()) {
721 FormatTypeError(arg, "int, long, bool");
722 return false;
723 }
724 *value = static_cast<bool>(long_value);
725
726 return true;
727 }
728
729 // Checks whether the given object (which must be "bytes" or "unicode") contains
730 // valid UTF-8.
IsValidUTF8(PyObject * obj)731 bool IsValidUTF8(PyObject* obj) {
732 if (PyBytes_Check(obj)) {
733 PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
734
735 // Clear the error indicator; we report our own error when desired.
736 PyErr_Clear();
737
738 if (unicode) {
739 Py_DECREF(unicode);
740 return true;
741 } else {
742 return false;
743 }
744 } else {
745 // Unicode object, known to be valid UTF-8.
746 return true;
747 }
748 }
749
// Tells whether invalid UTF-8 may be stored in `field`. This implementation
// ignores `field` and never allows it.
bool AllowInvalidUTF8(const FieldDescriptor* field) {
  return false;
}
751
CheckString(PyObject * arg,const FieldDescriptor * descriptor)752 PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
753 GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
754 descriptor->type() == FieldDescriptor::TYPE_BYTES);
755 if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
756 if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
757 FormatTypeError(arg, "bytes, unicode");
758 return NULL;
759 }
760
761 if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
762 PyObject* repr = PyObject_Repr(arg);
763 PyErr_Format(PyExc_ValueError,
764 "%s has type str, but isn't valid UTF-8 "
765 "encoding. Non-UTF-8 strings must be converted to "
766 "unicode objects before being added.",
767 PyString_AsString(repr));
768 Py_DECREF(repr);
769 return NULL;
770 }
771 } else if (!PyBytes_Check(arg)) {
772 FormatTypeError(arg, "bytes");
773 return NULL;
774 }
775
776 PyObject* encoded_string = NULL;
777 if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
778 if (PyBytes_Check(arg)) {
779 // The bytes were already validated as correctly encoded UTF-8 above.
780 encoded_string = arg; // Already encoded.
781 Py_INCREF(encoded_string);
782 } else {
783 encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
784 }
785 } else {
786 // In this case field type is "bytes".
787 encoded_string = arg;
788 Py_INCREF(encoded_string);
789 }
790
791 return encoded_string;
792 }
793
CheckAndSetString(PyObject * arg,Message * message,const FieldDescriptor * descriptor,const Reflection * reflection,bool append,int index)794 bool CheckAndSetString(
795 PyObject* arg, Message* message,
796 const FieldDescriptor* descriptor,
797 const Reflection* reflection,
798 bool append,
799 int index) {
800 ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
801
802 if (encoded_string.get() == NULL) {
803 return false;
804 }
805
806 char* value;
807 Py_ssize_t value_len;
808 if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
809 return false;
810 }
811
812 string value_string(value, value_len);
813 if (append) {
814 reflection->AddString(message, descriptor, std::move(value_string));
815 } else if (index < 0) {
816 reflection->SetString(message, descriptor, std::move(value_string));
817 } else {
818 reflection->SetRepeatedString(message, descriptor, index,
819 std::move(value_string));
820 }
821 return true;
822 }
823
ToStringObject(const FieldDescriptor * descriptor,const std::string & value)824 PyObject* ToStringObject(const FieldDescriptor* descriptor,
825 const std::string& value) {
826 if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
827 return PyBytes_FromStringAndSize(value.c_str(), value.length());
828 }
829
830 PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL);
831 // If the string can't be decoded in UTF-8, just return a string object that
832 // contains the raw bytes. This can't happen if the value was assigned using
833 // the members of the Python message object, but can happen if the values were
834 // parsed from the wire (binary).
835 if (result == NULL) {
836 PyErr_Clear();
837 result = PyBytes_FromStringAndSize(value.c_str(), value.length());
838 }
839 return result;
840 }
841
CheckFieldBelongsToMessage(const FieldDescriptor * field_descriptor,const Message * message)842 bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
843 const Message* message) {
844 if (message->GetDescriptor() == field_descriptor->containing_type()) {
845 return true;
846 }
847 PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
848 field_descriptor->full_name().c_str(),
849 message->GetDescriptor()->full_name().c_str());
850 return false;
851 }
852
853 namespace cmessage {
854
GetFactoryForMessage(CMessage * message)855 PyMessageFactory* GetFactoryForMessage(CMessage* message) {
856 GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
857 return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
858 }
859
MaybeReleaseOverlappingOneofField(CMessage * cmessage,const FieldDescriptor * field)860 static int MaybeReleaseOverlappingOneofField(
861 CMessage* cmessage,
862 const FieldDescriptor* field) {
863 #ifdef GOOGLE_PROTOBUF_HAS_ONEOF
864 Message* message = cmessage->message;
865 const Reflection* reflection = message->GetReflection();
866 if (!field->containing_oneof() ||
867 !reflection->HasOneof(*message, field->containing_oneof()) ||
868 reflection->HasField(*message, field)) {
869 // No other field in this oneof, no need to release.
870 return 0;
871 }
872
873 const OneofDescriptor* oneof = field->containing_oneof();
874 const FieldDescriptor* existing_field =
875 reflection->GetOneofFieldDescriptor(*message, oneof);
876 if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
877 // Non-message fields don't need to be released.
878 return 0;
879 }
880 if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) {
881 return -1;
882 }
883 #endif
884 return 0;
885 }
886
887 // After a Merge, visit every sub-message that was read-only, and
888 // eventually update their pointer if the Merge operation modified them.
FixupMessageAfterMerge(CMessage * self)889 int FixupMessageAfterMerge(CMessage* self) {
890 if (!self->composite_fields) {
891 return 0;
892 }
893 for (const auto& item : *self->composite_fields) {
894 const FieldDescriptor* descriptor = item.first;
895 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
896 !descriptor->is_repeated()) {
897 CMessage* cmsg = reinterpret_cast<CMessage*>(item.second);
898 if (cmsg->read_only == false) {
899 return 0;
900 }
901 Message* message = self->message;
902 const Reflection* reflection = message->GetReflection();
903 if (reflection->HasField(*message, descriptor)) {
904 // Message used to be read_only, but is no longer. Get the new pointer
905 // and record it.
906 Message* mutable_message =
907 reflection->MutableMessage(message, descriptor, nullptr);
908 cmsg->message = mutable_message;
909 cmsg->read_only = false;
910 if (FixupMessageAfterMerge(cmsg) < 0) {
911 return -1;
912 }
913 }
914 }
915 }
916
917 return 0;
918 }
919
920 // ---------------------------------------------------------------------
921 // Making a message writable
922
AssureWritable(CMessage * self)923 int AssureWritable(CMessage* self) {
924 if (self == NULL || !self->read_only) {
925 return 0;
926 }
927
928 // Toplevel messages are always mutable.
929 GOOGLE_DCHECK(self->parent);
930
931 if (AssureWritable(self->parent) == -1)
932 return -1;
933
934 // If this message is part of a oneof, there might be a field to release in
935 // the parent.
936 if (MaybeReleaseOverlappingOneofField(self->parent,
937 self->parent_field_descriptor) < 0) {
938 return -1;
939 }
940
941 // Make self->message writable.
942 Message* parent_message = self->parent->message;
943 const Reflection* reflection = parent_message->GetReflection();
944 Message* mutable_message = reflection->MutableMessage(
945 parent_message, self->parent_field_descriptor,
946 GetFactoryForMessage(self->parent)->message_factory);
947 if (mutable_message == NULL) {
948 return -1;
949 }
950 self->message = mutable_message;
951 self->read_only = false;
952
953 return 0;
954 }
955
956 // --- Globals:
957
958 // Retrieve a C++ FieldDescriptor for an extension handle.
GetExtensionDescriptor(PyObject * extension)959 const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
960 ScopedPyObjectPtr cdescriptor;
961 if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
962 // Most callers consider extensions as a plain dictionary. We should
963 // allow input which is not a field descriptor, and simply pretend it does
964 // not exist.
965 PyErr_SetObject(PyExc_KeyError, extension);
966 return NULL;
967 }
968 return PyFieldDescriptor_AsDescriptor(extension);
969 }
970
971 // If value is a string, convert it into an enum value based on the labels in
972 // descriptor, otherwise simply return value. Always returns a new reference.
GetIntegerEnumValue(const FieldDescriptor & descriptor,PyObject * value)973 static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
974 PyObject* value) {
975 if (PyString_Check(value) || PyUnicode_Check(value)) {
976 const EnumDescriptor* enum_descriptor = descriptor.enum_type();
977 if (enum_descriptor == NULL) {
978 PyErr_SetString(PyExc_TypeError, "not an enum field");
979 return NULL;
980 }
981 char* enum_label;
982 Py_ssize_t size;
983 if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
984 return NULL;
985 }
986 const EnumValueDescriptor* enum_value_descriptor =
987 enum_descriptor->FindValueByName(StringParam(enum_label, size));
988 if (enum_value_descriptor == NULL) {
989 PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
990 return NULL;
991 }
992 return PyInt_FromLong(enum_value_descriptor->number());
993 }
994 Py_INCREF(value);
995 return value;
996 }
997
998 // Delete a slice from a repeated field.
999 // The only way to remove items in C++ protos is to delete the last one,
1000 // so we swap items to move the deleted ones at the end, and then strip the
1001 // sequence.
DeleteRepeatedField(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * slice)1002 int DeleteRepeatedField(
1003 CMessage* self,
1004 const FieldDescriptor* field_descriptor,
1005 PyObject* slice) {
1006 Py_ssize_t length, from, to, step, slice_length;
1007 Message* message = self->message;
1008 const Reflection* reflection = message->GetReflection();
1009 int min, max;
1010 length = reflection->FieldSize(*message, field_descriptor);
1011
1012 if (PySlice_Check(slice)) {
1013 from = to = step = slice_length = 0;
1014 #if PY_MAJOR_VERSION < 3
1015 PySlice_GetIndicesEx(
1016 reinterpret_cast<PySliceObject*>(slice),
1017 length, &from, &to, &step, &slice_length);
1018 #else
1019 PySlice_GetIndicesEx(
1020 slice,
1021 length, &from, &to, &step, &slice_length);
1022 #endif
1023 if (from < to) {
1024 min = from;
1025 max = to - 1;
1026 } else {
1027 min = to + 1;
1028 max = from;
1029 }
1030 } else {
1031 from = to = PyLong_AsLong(slice);
1032 if (from == -1 && PyErr_Occurred()) {
1033 PyErr_SetString(PyExc_TypeError, "list indices must be integers");
1034 return -1;
1035 }
1036
1037 if (from < 0) {
1038 from = to = length + from;
1039 }
1040 step = 1;
1041 min = max = from;
1042
1043 // Range check.
1044 if (from < 0 || from >= length) {
1045 PyErr_Format(PyExc_IndexError, "list assignment index out of range");
1046 return -1;
1047 }
1048 }
1049
1050 Py_ssize_t i = from;
1051 std::vector<bool> to_delete(length, false);
1052 while (i >= min && i <= max) {
1053 to_delete[i] = true;
1054 i += step;
1055 }
1056
1057 // Swap elements so that items to delete are at the end.
1058 to = 0;
1059 for (i = 0; i < length; ++i) {
1060 if (!to_delete[i]) {
1061 if (i != to) {
1062 reflection->SwapElements(message, field_descriptor, i, to);
1063 }
1064 ++to;
1065 }
1066 }
1067
1068 // Remove items, starting from the end.
1069 for (; length > to; length--) {
1070 if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1071 reflection->RemoveLast(message, field_descriptor);
1072 continue;
1073 }
1074 // It seems that RemoveLast() is less efficient for sub-messages, and
1075 // the memory is not completely released. Prefer ReleaseLast().
1076 Message* sub_message = reflection->ReleaseLast(message, field_descriptor);
1077 // If there is a live weak reference to an item being removed, we "Release"
1078 // it, and it takes ownership of the message.
1079 if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) {
1080 released->message = sub_message;
1081 } else {
1082 // sub_message was not transferred, delete it.
1083 delete sub_message;
1084 }
1085 }
1086
1087 return 0;
1088 }
1089
1090 // Initializes fields of a message. Used in constructors.
InitAttributes(CMessage * self,PyObject * args,PyObject * kwargs)1091 int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
1092 if (args != NULL && PyTuple_Size(args) != 0) {
1093 PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
1094 return -1;
1095 }
1096
1097 if (kwargs == NULL) {
1098 return 0;
1099 }
1100
1101 Py_ssize_t pos = 0;
1102 PyObject* name;
1103 PyObject* value;
1104 while (PyDict_Next(kwargs, &pos, &name, &value)) {
1105 if (!(PyString_Check(name) || PyUnicode_Check(name))) {
1106 PyErr_SetString(PyExc_ValueError, "Field name must be a string");
1107 return -1;
1108 }
1109 ScopedPyObjectPtr property(
1110 PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
1111 if (property == NULL ||
1112 !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
1113 PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
1114 self->message->GetDescriptor()->name().c_str(),
1115 PyString_AsString(name));
1116 return -1;
1117 }
1118 const FieldDescriptor* descriptor =
1119 reinterpret_cast<PyMessageFieldProperty*>(property.get())
1120 ->field_descriptor;
1121 if (value == Py_None) {
1122 // field=None is the same as no field at all.
1123 continue;
1124 }
1125 if (descriptor->is_map()) {
1126 ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
1127 const FieldDescriptor* value_descriptor =
1128 descriptor->message_type()->FindFieldByName("value");
1129 if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1130 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1131 if (iter == NULL) {
1132 PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name));
1133 return -1;
1134 }
1135 ScopedPyObjectPtr next;
1136 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1137 ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get()));
1138 ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get()));
1139 if (source_value.get() == NULL || dest_value.get() == NULL) {
1140 return -1;
1141 }
1142 ScopedPyObjectPtr ok(PyObject_CallMethod(
1143 dest_value.get(), "MergeFrom", "O", source_value.get()));
1144 if (ok.get() == NULL) {
1145 return -1;
1146 }
1147 }
1148 } else {
1149 ScopedPyObjectPtr function_return;
1150 function_return.reset(
1151 PyObject_CallMethod(map.get(), "update", "O", value));
1152 if (function_return.get() == NULL) {
1153 return -1;
1154 }
1155 }
1156 } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1157 ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
1158 if (container == NULL) {
1159 return -1;
1160 }
1161 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1162 RepeatedCompositeContainer* rc_container =
1163 reinterpret_cast<RepeatedCompositeContainer*>(container.get());
1164 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1165 if (iter == NULL) {
1166 PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1167 return -1;
1168 }
1169 ScopedPyObjectPtr next;
1170 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1171 PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL);
1172 ScopedPyObjectPtr new_msg(
1173 repeated_composite_container::Add(rc_container, NULL, kwargs));
1174 if (new_msg == NULL) {
1175 return -1;
1176 }
1177 if (kwargs == NULL) {
1178 // next was not a dict, it's a message we need to merge
1179 ScopedPyObjectPtr merged(MergeFrom(
1180 reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
1181 if (merged.get() == NULL) {
1182 return -1;
1183 }
1184 }
1185 }
1186 if (PyErr_Occurred()) {
1187 // Check to see how PyIter_Next() exited.
1188 return -1;
1189 }
1190 } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1191 RepeatedScalarContainer* rs_container =
1192 reinterpret_cast<RepeatedScalarContainer*>(container.get());
1193 ScopedPyObjectPtr iter(PyObject_GetIter(value));
1194 if (iter == NULL) {
1195 PyErr_SetString(PyExc_TypeError, "Value must be iterable");
1196 return -1;
1197 }
1198 ScopedPyObjectPtr next;
1199 while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
1200 ScopedPyObjectPtr enum_value(
1201 GetIntegerEnumValue(*descriptor, next.get()));
1202 if (enum_value == NULL) {
1203 return -1;
1204 }
1205 ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
1206 rs_container, enum_value.get()));
1207 if (new_msg == NULL) {
1208 return -1;
1209 }
1210 }
1211 if (PyErr_Occurred()) {
1212 // Check to see how PyIter_Next() exited.
1213 return -1;
1214 }
1215 } else {
1216 if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
1217 reinterpret_cast<RepeatedScalarContainer*>(container.get()),
1218 value)) ==
1219 NULL) {
1220 return -1;
1221 }
1222 }
1223 } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1224 ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
1225 if (message == NULL) {
1226 return -1;
1227 }
1228 CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
1229 if (PyDict_Check(value)) {
1230 // Make the message exist even if the dict is empty.
1231 AssureWritable(cmessage);
1232 if (InitAttributes(cmessage, NULL, value) < 0) {
1233 return -1;
1234 }
1235 } else {
1236 ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
1237 if (merged == NULL) {
1238 return -1;
1239 }
1240 }
1241 } else {
1242 ScopedPyObjectPtr new_val;
1243 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
1244 new_val.reset(GetIntegerEnumValue(*descriptor, value));
1245 if (new_val == NULL) {
1246 return -1;
1247 }
1248 value = new_val.get();
1249 }
1250 if (SetFieldValue(self, descriptor, value) < 0) {
1251 return -1;
1252 }
1253 }
1254 }
1255 return 0;
1256 }
1257
1258 // Allocates an incomplete Python Message: the caller must fill self->message
1259 // and eventually self->parent.
NewEmptyMessage(CMessageClass * type)1260 CMessage* NewEmptyMessage(CMessageClass* type) {
1261 CMessage* self = reinterpret_cast<CMessage*>(
1262 PyType_GenericAlloc(&type->super.ht_type, 0));
1263 if (self == NULL) {
1264 return NULL;
1265 }
1266
1267 self->message = NULL;
1268 self->parent = NULL;
1269 self->parent_field_descriptor = NULL;
1270 self->read_only = false;
1271
1272 self->composite_fields = NULL;
1273 self->child_submessages = NULL;
1274
1275 self->unknown_field_set = NULL;
1276
1277 return self;
1278 }
1279
1280 // The __new__ method of Message classes.
1281 // Creates a new C++ message and takes ownership.
New(PyTypeObject * cls,PyObject * unused_args,PyObject * unused_kwargs)1282 static PyObject* New(PyTypeObject* cls,
1283 PyObject* unused_args, PyObject* unused_kwargs) {
1284 CMessageClass* type = CheckMessageClass(cls);
1285 if (type == NULL) {
1286 return NULL;
1287 }
1288 // Retrieve the message descriptor and the default instance (=prototype).
1289 const Descriptor* message_descriptor = type->message_descriptor;
1290 if (message_descriptor == NULL) {
1291 return NULL;
1292 }
1293 const Message* prototype =
1294 type->py_message_factory->message_factory->GetPrototype(
1295 message_descriptor);
1296 if (prototype == NULL) {
1297 PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
1298 return NULL;
1299 }
1300
1301 CMessage* self = NewEmptyMessage(type);
1302 if (self == NULL) {
1303 return NULL;
1304 }
1305 self->message = prototype->New();
1306 self->parent = nullptr; // This message owns its data.
1307 return reinterpret_cast<PyObject*>(self);
1308 }
1309
1310 // The __init__ method of Message classes.
1311 // It initializes fields from keywords passed to the constructor.
Init(CMessage * self,PyObject * args,PyObject * kwargs)1312 static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
1313 return InitAttributes(self, args, kwargs);
1314 }
1315
1316 // ---------------------------------------------------------------------
1317 // Deallocating a CMessage
1318
// Deallocator for CMessage wrappers.
//
// Clears weak references, frees the child caches, clears the unknown-field
// wrapper if there is one, and then handles the C++ message according to the
// parent link:
//   * no parent:           the wrapper owns the C++ message and deletes it;
//   * parent is Py_None:   the message is owned externally, only the parent
//                          reference is dropped;
//   * any other parent:    the wrapper is removed from the parent's cache
//                          (child_submessages for repeated fields,
//                          composite_fields otherwise) and the parent
//                          reference is dropped.
// Finally the Python object memory is released through tp_free.
static void Dealloc(CMessage* self) {
  if (self->weakreflist) {
    PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
  }
  // At this point all dependent objects have been removed.
  GOOGLE_DCHECK(!self->child_submessages || self->child_submessages->empty());
  GOOGLE_DCHECK(!self->composite_fields || self->composite_fields->empty());
  delete self->child_submessages;
  delete self->composite_fields;
  if (self->unknown_field_set) {
    unknown_fields::Clear(
        reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
  }

  CMessage* parent = self->parent;
  if (!parent) {
    // No parent, we own the message.
    delete self->message;
  } else if (parent->AsPyObject() == Py_None) {
    // Message owned externally: Nothing to dealloc
    Py_CLEAR(self->parent);
  } else {
    // Clear this message from its parent's map.
    if (self->parent_field_descriptor->is_repeated()) {
      if (parent->child_submessages)
        parent->child_submessages->erase(self->message);
    } else {
      if (parent->composite_fields)
        parent->composite_fields->erase(self->parent_field_descriptor);
    }
    // Drop the strong reference to the parent only after its caches have been
    // updated.
    Py_CLEAR(self->parent);
  }
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
1353
1354 // ---------------------------------------------------------------------
1355
1356
IsInitialized(CMessage * self,PyObject * args)1357 PyObject* IsInitialized(CMessage* self, PyObject* args) {
1358 PyObject* errors = NULL;
1359 if (!PyArg_ParseTuple(args, "|O", &errors)) {
1360 return NULL;
1361 }
1362 if (self->message->IsInitialized()) {
1363 Py_RETURN_TRUE;
1364 }
1365 if (errors != NULL) {
1366 ScopedPyObjectPtr initialization_errors(
1367 FindInitializationErrors(self));
1368 if (initialization_errors == NULL) {
1369 return NULL;
1370 }
1371 ScopedPyObjectPtr extend_name(PyString_FromString("extend"));
1372 if (extend_name == NULL) {
1373 return NULL;
1374 }
1375 ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
1376 errors,
1377 extend_name.get(),
1378 initialization_errors.get(),
1379 NULL));
1380 if (result == NULL) {
1381 return NULL;
1382 }
1383 }
1384 Py_RETURN_FALSE;
1385 }
1386
HasFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1387 int HasFieldByDescriptor(CMessage* self,
1388 const FieldDescriptor* field_descriptor) {
1389 Message* message = self->message;
1390 if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
1391 return -1;
1392 }
1393 if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1394 PyErr_SetString(PyExc_KeyError,
1395 "Field is repeated. A singular method is required.");
1396 return -1;
1397 }
1398 return message->GetReflection()->HasField(*message, field_descriptor);
1399 }
1400
FindFieldWithOneofs(const Message * message,ConstStringParam field_name,bool * in_oneof)1401 const FieldDescriptor* FindFieldWithOneofs(const Message* message,
1402 ConstStringParam field_name,
1403 bool* in_oneof) {
1404 *in_oneof = false;
1405 const Descriptor* descriptor = message->GetDescriptor();
1406 const FieldDescriptor* field_descriptor =
1407 descriptor->FindFieldByName(field_name);
1408 if (field_descriptor != NULL) {
1409 return field_descriptor;
1410 }
1411 const OneofDescriptor* oneof_desc =
1412 descriptor->FindOneofByName(field_name);
1413 if (oneof_desc != NULL) {
1414 *in_oneof = true;
1415 return message->GetReflection()->GetOneofFieldDescriptor(*message,
1416 oneof_desc);
1417 }
1418 return NULL;
1419 }
1420
CheckHasPresence(const FieldDescriptor * field_descriptor,bool in_oneof)1421 bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
1422 auto message_name = field_descriptor->containing_type()->name();
1423 if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
1424 PyErr_Format(PyExc_ValueError,
1425 "Protocol message %s has no singular \"%s\" field.",
1426 message_name.c_str(), field_descriptor->name().c_str());
1427 return false;
1428 }
1429
1430 if (!field_descriptor->has_presence()) {
1431 PyErr_Format(PyExc_ValueError,
1432 "Can't test non-optional, non-submessage field \"%s.%s\" for "
1433 "presence in proto3.",
1434 message_name.c_str(), field_descriptor->name().c_str());
1435 return false;
1436 }
1437
1438 return true;
1439 }
1440
HasField(CMessage * self,PyObject * arg)1441 PyObject* HasField(CMessage* self, PyObject* arg) {
1442 char* field_name;
1443 Py_ssize_t size;
1444 #if PY_MAJOR_VERSION < 3
1445 if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
1446 return NULL;
1447 }
1448 #else
1449 field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
1450 if (!field_name) {
1451 return NULL;
1452 }
1453 #endif
1454
1455 Message* message = self->message;
1456 bool is_in_oneof;
1457 const FieldDescriptor* field_descriptor =
1458 FindFieldWithOneofs(message, StringParam(field_name, size), &is_in_oneof);
1459 if (field_descriptor == NULL) {
1460 if (!is_in_oneof) {
1461 PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
1462 message->GetDescriptor()->name().c_str(), field_name);
1463 return NULL;
1464 } else {
1465 Py_RETURN_FALSE;
1466 }
1467 }
1468
1469 if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
1470 return NULL;
1471 }
1472
1473 if (message->GetReflection()->HasField(*message, field_descriptor)) {
1474 Py_RETURN_TRUE;
1475 }
1476
1477 Py_RETURN_FALSE;
1478 }
1479
ClearExtension(CMessage * self,PyObject * extension)1480 PyObject* ClearExtension(CMessage* self, PyObject* extension) {
1481 const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1482 if (descriptor == NULL) {
1483 return NULL;
1484 }
1485 if (ClearFieldByDescriptor(self, descriptor) < 0) {
1486 return nullptr;
1487 }
1488 Py_RETURN_NONE;
1489 }
1490
HasExtension(CMessage * self,PyObject * extension)1491 PyObject* HasExtension(CMessage* self, PyObject* extension) {
1492 const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
1493 if (descriptor == NULL) {
1494 return NULL;
1495 }
1496 int has_field = HasFieldByDescriptor(self, descriptor);
1497 if (has_field < 0) {
1498 return nullptr;
1499 } else {
1500 return PyBool_FromLong(has_field);
1501 }
1502 }
1503
1504 // ---------------------------------------------------------------------
1505 // Releasing messages
1506 //
// The Python API's ClearField() and Clear() methods behave
// differently than their C++ counterparts. While the C++ versions
// clear the children, the Python versions detach the children,
// without touching their content. This impedance mismatch causes
// some complexity in the implementation, which is captured in this
// section.
1513 //
1514 // When one or multiple fields are cleared we need to:
1515 //
1516 // * Gather all child objects that need to be detached from the message.
1517 // In composite_fields and child_submessages.
1518 //
1519 // * Create a new Python message of the same kind. Use SwapFields() to move
1520 // data from the original message.
1521 //
1522 // * Change the parent of all child objects: update their strong reference
1523 // to their parent, and move their presence in composite_fields and
1524 // child_submessages.
1525
1526 // ---------------------------------------------------------------------
1527 // Release a composite child of a CMessage
1528
// Detaches the given child wrappers from `self` by moving them, together with
// the data of their fields, onto a freshly created message of the same class.
//
// Each entry of `messages_to_release` is moved from self->child_submessages
// to the new message's child_submessages; each entry of
// `containers_to_release` is moved from self->composite_fields to the new
// message's composite_fields. Every moved child gets the new message as its
// parent (with the reference counts adjusted accordingly), and the affected
// fields are then exchanged between the two C++ messages with SwapFields().
//
// Returns 0 on success (including when both lists are empty) and -1 if the
// new wrapper cannot be allocated.
static int InternalReparentFields(
    CMessage* self, const std::vector<CMessage*>& messages_to_release,
    const std::vector<ContainerBase*>& containers_to_release) {
  if (messages_to_release.empty() && containers_to_release.empty()) {
    return 0;
  }

  // Move all the passed sub_messages to another message.
  CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass());
  if (new_message == nullptr) {
    return -1;
  }
  new_message->message = self->message->New();
  // `holder` owns the initial reference to new_message for the duration of
  // this function; the children take their own references below.
  ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message));
  new_message->child_submessages = new CMessage::SubMessagesMap();
  new_message->composite_fields = new CMessage::CompositeFieldsMap();
  std::set<const FieldDescriptor*> fields_to_swap;

  // In case the removed fields hold the last reference to this message, keep
  // a reference.
  Py_INCREF(self);

  for (const auto& to_release : messages_to_release) {
    fields_to_swap.insert(to_release->parent_field_descriptor);
    // Reparent
    Py_INCREF(new_message);
    Py_DECREF(to_release->parent);
    to_release->parent = new_message;
    self->child_submessages->erase(to_release->message);
    new_message->child_submessages->emplace(to_release->message, to_release);
  }

  for (const auto& to_release : containers_to_release) {
    fields_to_swap.insert(to_release->parent_field_descriptor);
    // Reparent, same as above but through the composite_fields cache.
    Py_INCREF(new_message);
    Py_DECREF(to_release->parent);
    to_release->parent = new_message;
    self->composite_fields->erase(to_release->parent_field_descriptor);
    new_message->composite_fields->emplace(to_release->parent_field_descriptor,
                                           to_release);
  }

  // Exchange the content of the affected fields between the two C++ messages.
  self->message->GetReflection()->SwapFields(
      self->message, new_message->message,
      std::vector<const FieldDescriptor*>(fields_to_swap.begin(),
                                          fields_to_swap.end()));

  // This might delete the Python message completely if all children were moved.
  Py_DECREF(self);

  return 0;
}
1581
InternalReleaseFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1582 int InternalReleaseFieldByDescriptor(
1583 CMessage* self,
1584 const FieldDescriptor* field_descriptor) {
1585 if (!field_descriptor->is_repeated() &&
1586 field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
1587 // Single scalars are not in any cache.
1588 return 0;
1589 }
1590 std::vector<CMessage*> messages_to_release;
1591 std::vector<ContainerBase*> containers_to_release;
1592 if (self->child_submessages && field_descriptor->is_repeated() &&
1593 field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
1594 for (const auto& child_item : *self->child_submessages) {
1595 if (child_item.second->parent_field_descriptor == field_descriptor) {
1596 messages_to_release.push_back(child_item.second);
1597 }
1598 }
1599 }
1600 if (self->composite_fields) {
1601 CMessage::CompositeFieldsMap::iterator it =
1602 self->composite_fields->find(field_descriptor);
1603 if (it != self->composite_fields->end()) {
1604 containers_to_release.push_back(it->second);
1605 }
1606 }
1607
1608 return InternalReparentFields(self, messages_to_release,
1609 containers_to_release);
1610 }
1611
ClearFieldByDescriptor(CMessage * self,const FieldDescriptor * field_descriptor)1612 int ClearFieldByDescriptor(CMessage* self,
1613 const FieldDescriptor* field_descriptor) {
1614 if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
1615 return -1;
1616 }
1617 if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) {
1618 return -1;
1619 }
1620 AssureWritable(self);
1621 Message* message = self->message;
1622 message->GetReflection()->ClearField(message, field_descriptor);
1623 return 0;
1624 }
1625
ClearField(CMessage * self,PyObject * arg)1626 PyObject* ClearField(CMessage* self, PyObject* arg) {
1627 char* field_name;
1628 Py_ssize_t field_size;
1629 if (PyString_AsStringAndSize(arg, &field_name, &field_size) < 0) {
1630 return NULL;
1631 }
1632 AssureWritable(self);
1633 bool is_in_oneof;
1634 const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
1635 self->message, StringParam(field_name, field_size), &is_in_oneof);
1636 if (field_descriptor == NULL) {
1637 if (is_in_oneof) {
1638 // We gave the name of a oneof, and none of its fields are set.
1639 Py_RETURN_NONE;
1640 } else {
1641 PyErr_Format(PyExc_ValueError,
1642 "Protocol message has no \"%s\" field.", field_name);
1643 return NULL;
1644 }
1645 }
1646
1647 if (ClearFieldByDescriptor(self, field_descriptor) < 0) {
1648 return nullptr;
1649 }
1650 Py_RETURN_NONE;
1651 }
1652
Clear(CMessage * self)1653 PyObject* Clear(CMessage* self) {
1654 AssureWritable(self);
1655 // Detach all current fields of this message
1656 std::vector<CMessage*> messages_to_release;
1657 std::vector<ContainerBase*> containers_to_release;
1658 if (self->child_submessages) {
1659 for (const auto& item : *self->child_submessages) {
1660 messages_to_release.push_back(item.second);
1661 }
1662 }
1663 if (self->composite_fields) {
1664 for (const auto& item : *self->composite_fields) {
1665 containers_to_release.push_back(item.second);
1666 }
1667 }
1668 if (InternalReparentFields(self, messages_to_release, containers_to_release) <
1669 0) {
1670 return NULL;
1671 }
1672 if (self->unknown_field_set) {
1673 unknown_fields::Clear(
1674 reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
1675 self->unknown_field_set = nullptr;
1676 }
1677 self->message->Clear();
1678 Py_RETURN_NONE;
1679 }
1680
1681 // ---------------------------------------------------------------------
1682
GetMessageName(CMessage * self)1683 static std::string GetMessageName(CMessage* self) {
1684 if (self->parent_field_descriptor != NULL) {
1685 return self->parent_field_descriptor->full_name();
1686 } else {
1687 return self->message->GetDescriptor()->full_name();
1688 }
1689 }
1690
InternalSerializeToString(CMessage * self,PyObject * args,PyObject * kwargs,bool require_initialized)1691 static PyObject* InternalSerializeToString(
1692 CMessage* self, PyObject* args, PyObject* kwargs,
1693 bool require_initialized) {
1694 // Parse the "deterministic" kwarg; defaults to False.
1695 static char* kwlist[] = { "deterministic", 0 };
1696 PyObject* deterministic_obj = Py_None;
1697 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
1698 &deterministic_obj)) {
1699 return NULL;
1700 }
1701 // Preemptively convert to a bool first, so we don't need to back out of
1702 // allocating memory if this raises an exception.
1703 // NOTE: This is unused later if deterministic == Py_None, but that's fine.
1704 int deterministic = PyObject_IsTrue(deterministic_obj);
1705 if (deterministic < 0) {
1706 return NULL;
1707 }
1708
1709 if (require_initialized && !self->message->IsInitialized()) {
1710 ScopedPyObjectPtr errors(FindInitializationErrors(self));
1711 if (errors == NULL) {
1712 return NULL;
1713 }
1714 ScopedPyObjectPtr comma(PyString_FromString(","));
1715 if (comma == NULL) {
1716 return NULL;
1717 }
1718 ScopedPyObjectPtr joined(
1719 PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
1720 if (joined == NULL) {
1721 return NULL;
1722 }
1723
1724 // TODO(haberman): this is a (hopefully temporary) hack. The unit testing
1725 // infrastructure reloads all pure-Python modules for every test, but not
1726 // C++ modules (because that's generally impossible:
1727 // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll
1728 // return the EncodeError from a previous load of the module, which won't
1729 // match a user's attempt to catch EncodeError. So we have to look it up
1730 // again every time.
1731 ScopedPyObjectPtr message_module(PyImport_ImportModule(
1732 "google.protobuf.message"));
1733 if (message_module.get() == NULL) {
1734 return NULL;
1735 }
1736
1737 ScopedPyObjectPtr encode_error(
1738 PyObject_GetAttrString(message_module.get(), "EncodeError"));
1739 if (encode_error.get() == NULL) {
1740 return NULL;
1741 }
1742 PyErr_Format(encode_error.get(),
1743 "Message %s is missing required fields: %s",
1744 GetMessageName(self).c_str(), PyString_AsString(joined.get()));
1745 return NULL;
1746 }
1747
1748 // Ok, arguments parsed and errors checked, now encode to a string
1749 const size_t size = self->message->ByteSizeLong();
1750 if (size == 0) {
1751 return PyBytes_FromString("");
1752 }
1753
1754 if (size > INT_MAX) {
1755 PyErr_Format(PyExc_ValueError,
1756 "Message %s exceeds maximum protobuf "
1757 "size of 2GB: %zu",
1758 GetMessageName(self).c_str(), size);
1759 return nullptr;
1760 }
1761
1762 PyObject* result = PyBytes_FromStringAndSize(NULL, size);
1763 if (result == NULL) {
1764 return NULL;
1765 }
1766 io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
1767 io::CodedOutputStream coded_out(&out);
1768 if (deterministic_obj != Py_None) {
1769 coded_out.SetSerializationDeterministic(deterministic);
1770 }
1771 self->message->SerializeWithCachedSizes(&coded_out);
1772 GOOGLE_CHECK(!coded_out.HadError());
1773 return result;
1774 }
1775
SerializeToString(CMessage * self,PyObject * args,PyObject * kwargs)1776 static PyObject* SerializeToString(
1777 CMessage* self, PyObject* args, PyObject* kwargs) {
1778 return InternalSerializeToString(self, args, kwargs,
1779 /*require_initialized=*/true);
1780 }
1781
SerializePartialToString(CMessage * self,PyObject * args,PyObject * kwargs)1782 static PyObject* SerializePartialToString(
1783 CMessage* self, PyObject* args, PyObject* kwargs) {
1784 return InternalSerializeToString(self, args, kwargs,
1785 /*require_initialized=*/false);
1786 }
1787
1788 // Formats proto fields for ascii dumps using python formatting functions where
1789 // appropriate.
1790 class PythonFieldValuePrinter : public TextFormat::FastFieldValuePrinter {
1791 public:
1792 // Python has some differences from C++ when printing floating point numbers.
1793 //
1794 // 1) Trailing .0 is always printed.
1795 // 2) (Python2) Output is rounded to 12 digits.
1796 // 3) (Python3) The full precision of the double is preserved (and Python uses
1797 // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
1798 // differences, but they rarely happen)
1799 //
1800 // We override floating point printing with the C-API function for printing
1801 // Python floats to ensure consistency.
PrintFloat(float val,TextFormat::BaseTextGenerator * generator) const1802 void PrintFloat(float val,
1803 TextFormat::BaseTextGenerator* generator) const override {
1804 PrintDouble(val, generator);
1805 }
PrintDouble(double val,TextFormat::BaseTextGenerator * generator) const1806 void PrintDouble(double val,
1807 TextFormat::BaseTextGenerator* generator) const override {
1808 // This implementation is not highly optimized (it allocates two temporary
1809 // Python objects) but it is simple and portable. If this is shown to be a
1810 // performance bottleneck, we can optimize it, but the results will likely
1811 // be more complicated to accommodate the differing behavior of double
1812 // formatting between Python 2 and Python 3.
1813 //
1814 // (Though a valid question is: do we really want to make out output
1815 // dependent on the Python version?)
1816 ScopedPyObjectPtr py_value(PyFloat_FromDouble(val));
1817 if (!py_value.get()) {
1818 return;
1819 }
1820
1821 ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
1822 if (!py_str.get()) {
1823 return;
1824 }
1825
1826 generator->PrintString(PyString_AsString(py_str.get()));
1827 }
1828 };
1829
ToStr(CMessage * self)1830 static PyObject* ToStr(CMessage* self) {
1831 TextFormat::Printer printer;
1832 // Passes ownership
1833 printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
1834 printer.SetHideUnknownFields(true);
1835 std::string output;
1836 if (!printer.PrintToString(*self->message, &output)) {
1837 PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
1838 return NULL;
1839 }
1840 return PyString_FromString(output.c_str());
1841 }
1842
MergeFrom(CMessage * self,PyObject * arg)1843 PyObject* MergeFrom(CMessage* self, PyObject* arg) {
1844 CMessage* other_message;
1845 if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1846 PyErr_Format(PyExc_TypeError,
1847 "Parameter to MergeFrom() must be instance of same class: "
1848 "expected %s got %s.",
1849 self->message->GetDescriptor()->full_name().c_str(),
1850 Py_TYPE(arg)->tp_name);
1851 return NULL;
1852 }
1853
1854 other_message = reinterpret_cast<CMessage*>(arg);
1855 if (other_message->message->GetDescriptor() !=
1856 self->message->GetDescriptor()) {
1857 PyErr_Format(PyExc_TypeError,
1858 "Parameter to MergeFrom() must be instance of same class: "
1859 "expected %s got %s.",
1860 self->message->GetDescriptor()->full_name().c_str(),
1861 other_message->message->GetDescriptor()->full_name().c_str());
1862 return NULL;
1863 }
1864 AssureWritable(self);
1865
1866 self->message->MergeFrom(*other_message->message);
1867 // Child message might be lazily created before MergeFrom. Make sure they
1868 // are mutable at this point if child messages are really created.
1869 if (FixupMessageAfterMerge(self) < 0) {
1870 return NULL;
1871 }
1872
1873 Py_RETURN_NONE;
1874 }
1875
CopyFrom(CMessage * self,PyObject * arg)1876 static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
1877 CMessage* other_message;
1878 if (!PyObject_TypeCheck(arg, CMessage_Type)) {
1879 PyErr_Format(PyExc_TypeError,
1880 "Parameter to CopyFrom() must be instance of same class: "
1881 "expected %s got %s.",
1882 self->message->GetDescriptor()->full_name().c_str(),
1883 Py_TYPE(arg)->tp_name);
1884 return NULL;
1885 }
1886
1887 other_message = reinterpret_cast<CMessage*>(arg);
1888
1889 if (self == other_message) {
1890 Py_RETURN_NONE;
1891 }
1892
1893 if (other_message->message->GetDescriptor() !=
1894 self->message->GetDescriptor()) {
1895 PyErr_Format(PyExc_TypeError,
1896 "Parameter to CopyFrom() must be instance of same class: "
1897 "expected %s got %s.",
1898 self->message->GetDescriptor()->full_name().c_str(),
1899 other_message->message->GetDescriptor()->full_name().c_str());
1900 return NULL;
1901 }
1902
1903 AssureWritable(self);
1904
1905 // CopyFrom on the message will not clean up self->composite_fields,
1906 // which can leave us in an inconsistent state, so clear it out here.
1907 (void)ScopedPyObjectPtr(Clear(self));
1908
1909 self->message->CopyFrom(*other_message->message);
1910
1911 Py_RETURN_NONE;
1912 }
1913
// Protobuf has a 64MB limit built in, this variable will override this. Please
// do not enable this unless you fully understand the implications: protobufs
// must all be kept in memory at the same time, so if they grow too big you may
// get OOM errors. The protobuf APIs do not provide any tools for processing
// protobufs in chunks. If you have protos this big you should break them up if
// it is at all convenient to do so.
//
// The initial value is chosen at compile time: true when
// PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS is defined, false otherwise. It can be
// changed at runtime through SetAllowOversizeProtos. In this file the flag is
// read by MergeFromString, where it selects the recursion limit handed to the
// parser.
static bool allow_oversize_protos =
#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
    true;
#else
    false;
#endif
1925
1926 // Provide a method in the module to set allow_oversize_protos to a boolean
1927 // value. This method returns the newly value of allow_oversize_protos.
SetAllowOversizeProtos(PyObject * m,PyObject * arg)1928 PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
1929 if (!arg || !PyBool_Check(arg)) {
1930 PyErr_SetString(PyExc_TypeError,
1931 "Argument to SetAllowOversizeProtos must be boolean");
1932 return NULL;
1933 }
1934 allow_oversize_protos = PyObject_IsTrue(arg);
1935 if (allow_oversize_protos) {
1936 Py_RETURN_TRUE;
1937 } else {
1938 Py_RETURN_FALSE;
1939 }
1940 }
1941
MergeFromString(CMessage * self,PyObject * arg)1942 static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
1943 const void* data;
1944 Py_ssize_t data_length;
1945 if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) {
1946 return NULL;
1947 }
1948
1949 AssureWritable(self);
1950
1951 PyMessageFactory* factory = GetFactoryForMessage(self);
1952 int depth = allow_oversize_protos
1953 ? INT_MAX
1954 : io::CodedInputStream::GetDefaultRecursionLimit();
1955 const char* ptr;
1956 internal::ParseContext ctx(
1957 depth, false, &ptr,
1958 StringPiece(static_cast<const char*>(data), data_length));
1959 ctx.data().pool = factory->pool->pool;
1960 ctx.data().factory = factory->message_factory;
1961
1962 ptr = self->message->_InternalParse(ptr, &ctx);
1963
1964 // Child message might be lazily created before MergeFrom. Make sure they
1965 // are mutable at this point if child messages are really created.
1966 if (FixupMessageAfterMerge(self) < 0) {
1967 return NULL;
1968 }
1969
1970 // Python makes distinction in error message, between a general parse failure
1971 // and in-correct ending on a terminating tag. Hence we need to be a bit more
1972 // explicit in our correctness checks.
1973 if (ptr == nullptr || ctx.BytesUntilLimit(ptr) < 0) {
1974 // Parse error or the parser overshoot the limit.
1975 PyErr_Format(DecodeError_class, "Error parsing message");
1976 return NULL;
1977 }
1978 // ctx has an explicit limit set (length of string_view), so we have to
1979 // check we ended at that limit.
1980 if (!ctx.EndedAtLimit()) {
1981 // TODO(jieluo): Raise error and return NULL instead.
1982 // b/27494216
1983 PyErr_Warn(nullptr, "Unexpected end-group tag: Not all data was converted");
1984 return PyInt_FromLong(data_length - ctx.BytesUntilLimit(ptr));
1985 }
1986 return PyInt_FromLong(data_length);
1987 }
1988
ParseFromString(CMessage * self,PyObject * arg)1989 static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
1990 if (ScopedPyObjectPtr(Clear(self)) == NULL) {
1991 return NULL;
1992 }
1993 return MergeFromString(self, arg);
1994 }
1995
ByteSize(CMessage * self,PyObject * args)1996 static PyObject* ByteSize(CMessage* self, PyObject* args) {
1997 return PyLong_FromLong(self->message->ByteSizeLong());
1998 }
1999
RegisterExtension(PyObject * cls,PyObject * extension_handle)2000 PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
2001 const FieldDescriptor* descriptor =
2002 GetExtensionDescriptor(extension_handle);
2003 if (descriptor == NULL) {
2004 return NULL;
2005 }
2006 if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
2007 PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
2008 cls->ob_type->tp_name);
2009 return NULL;
2010 }
2011 CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
2012 if (message_class == NULL) {
2013 return NULL;
2014 }
2015 // If the extension was already registered, check that it is the same.
2016 const FieldDescriptor* existing_extension =
2017 message_class->py_message_factory->pool->pool->FindExtensionByNumber(
2018 descriptor->containing_type(), descriptor->number());
2019 if (existing_extension != NULL && existing_extension != descriptor) {
2020 PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
2021 return NULL;
2022 }
2023 Py_RETURN_NONE;
2024 }
2025
SetInParent(CMessage * self,PyObject * args)2026 static PyObject* SetInParent(CMessage* self, PyObject* args) {
2027 AssureWritable(self);
2028 Py_RETURN_NONE;
2029 }
2030
WhichOneof(CMessage * self,PyObject * arg)2031 static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
2032 Py_ssize_t name_size;
2033 char *name_data;
2034 if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0)
2035 return NULL;
2036 const OneofDescriptor* oneof_desc =
2037 self->message->GetDescriptor()->FindOneofByName(
2038 StringParam(name_data, name_size));
2039 if (oneof_desc == NULL) {
2040 PyErr_Format(PyExc_ValueError,
2041 "Protocol message has no oneof \"%s\" field.", name_data);
2042 return NULL;
2043 }
2044 const FieldDescriptor* field_in_oneof =
2045 self->message->GetReflection()->GetOneofFieldDescriptor(
2046 *self->message, oneof_desc);
2047 if (field_in_oneof == NULL) {
2048 Py_RETURN_NONE;
2049 } else {
2050 const std::string& name = field_in_oneof->name();
2051 return PyString_FromStringAndSize(name.c_str(), name.size());
2052 }
2053 }
2054
2055 static PyObject* GetExtensionDict(CMessage* self, void *closure);
2056
ListFields(CMessage * self)2057 static PyObject* ListFields(CMessage* self) {
2058 std::vector<const FieldDescriptor*> fields;
2059 self->message->GetReflection()->ListFields(*self->message, &fields);
2060
2061 // Normally, the list will be exactly the size of the fields.
2062 ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
2063 if (all_fields == NULL) {
2064 return NULL;
2065 }
2066
2067 // When there are unknown extensions, the py list will *not* contain
2068 // the field information. Thus the actual size of the py list will be
2069 // smaller than the size of fields. Set the actual size at the end.
2070 Py_ssize_t actual_size = 0;
2071 for (size_t i = 0; i < fields.size(); ++i) {
2072 ScopedPyObjectPtr t(PyTuple_New(2));
2073 if (t == NULL) {
2074 return NULL;
2075 }
2076
2077 if (fields[i]->is_extension()) {
2078 ScopedPyObjectPtr extension_field(
2079 PyFieldDescriptor_FromDescriptor(fields[i]));
2080 if (extension_field == NULL) {
2081 return NULL;
2082 }
2083 // With C++ descriptors, the field can always be retrieved, but for
2084 // unknown extensions which have not been imported in Python code, there
2085 // is no message class and we cannot retrieve the value.
2086 // TODO(amauryfa): consider building the class on the fly!
2087 if (fields[i]->message_type() != NULL &&
2088 message_factory::GetMessageClass(
2089 GetFactoryForMessage(self),
2090 fields[i]->message_type()) == NULL) {
2091 PyErr_Clear();
2092 continue;
2093 }
2094 ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
2095 if (extensions == NULL) {
2096 return NULL;
2097 }
2098 // 'extension' reference later stolen by PyTuple_SET_ITEM.
2099 PyObject* extension = PyObject_GetItem(
2100 extensions.get(), extension_field.get());
2101 if (extension == NULL) {
2102 return NULL;
2103 }
2104 PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
2105 // Steals reference to 'extension'
2106 PyTuple_SET_ITEM(t.get(), 1, extension);
2107 } else {
2108 // Normal field
2109 ScopedPyObjectPtr field_descriptor(
2110 PyFieldDescriptor_FromDescriptor(fields[i]));
2111 if (field_descriptor == NULL) {
2112 return NULL;
2113 }
2114
2115 PyObject* field_value = GetFieldValue(self, fields[i]);
2116 if (field_value == NULL) {
2117 PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str());
2118 return NULL;
2119 }
2120 PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
2121 PyTuple_SET_ITEM(t.get(), 1, field_value);
2122 }
2123 PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
2124 ++actual_size;
2125 }
2126 if (static_cast<size_t>(actual_size) != fields.size() &&
2127 (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
2128 0)) {
2129 return NULL;
2130 }
2131 return all_fields.release();
2132 }
2133
DiscardUnknownFields(CMessage * self)2134 static PyObject* DiscardUnknownFields(CMessage* self) {
2135 AssureWritable(self);
2136 self->message->DiscardUnknownFields();
2137 Py_RETURN_NONE;
2138 }
2139
FindInitializationErrors(CMessage * self)2140 PyObject* FindInitializationErrors(CMessage* self) {
2141 Message* message = self->message;
2142 std::vector<std::string> errors;
2143 message->FindInitializationErrors(&errors);
2144
2145 PyObject* error_list = PyList_New(errors.size());
2146 if (error_list == NULL) {
2147 return NULL;
2148 }
2149 for (size_t i = 0; i < errors.size(); ++i) {
2150 const std::string& error = errors[i];
2151 PyObject* error_string = PyString_FromStringAndSize(
2152 error.c_str(), error.length());
2153 if (error_string == NULL) {
2154 Py_DECREF(error_list);
2155 return NULL;
2156 }
2157 PyList_SET_ITEM(error_list, i, error_string);
2158 }
2159 return error_list;
2160 }
2161
RichCompare(CMessage * self,PyObject * other,int opid)2162 static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
2163 // Only equality comparisons are implemented.
2164 if (opid != Py_EQ && opid != Py_NE) {
2165 Py_INCREF(Py_NotImplemented);
2166 return Py_NotImplemented;
2167 }
2168 bool equals = true;
2169 // If other is not a message, it cannot be equal.
2170 if (!PyObject_TypeCheck(other, CMessage_Type)) {
2171 equals = false;
2172 } else {
2173 // Otherwise, we have a CMessage whose message we can inspect.
2174 const google::protobuf::Message* other_message =
2175 reinterpret_cast<CMessage*>(other)->message;
2176 // If messages don't have the same descriptors, they are not equal.
2177 if (equals &&
2178 self->message->GetDescriptor() != other_message->GetDescriptor()) {
2179 equals = false;
2180 }
2181 // Check the message contents.
2182 if (equals &&
2183 !google::protobuf::util::MessageDifferencer::Equals(
2184 *self->message, *reinterpret_cast<CMessage*>(other)->message)) {
2185 equals = false;
2186 }
2187 }
2188
2189 if (equals ^ (opid == Py_EQ)) {
2190 Py_RETURN_FALSE;
2191 } else {
2192 Py_RETURN_TRUE;
2193 }
2194 }
2195
InternalGetScalar(const Message * message,const FieldDescriptor * field_descriptor)2196 PyObject* InternalGetScalar(const Message* message,
2197 const FieldDescriptor* field_descriptor) {
2198 const Reflection* reflection = message->GetReflection();
2199
2200 if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2201 return NULL;
2202 }
2203
2204 PyObject* result = NULL;
2205 switch (field_descriptor->cpp_type()) {
2206 case FieldDescriptor::CPPTYPE_INT32: {
2207 int32 value = reflection->GetInt32(*message, field_descriptor);
2208 result = PyInt_FromLong(value);
2209 break;
2210 }
2211 case FieldDescriptor::CPPTYPE_INT64: {
2212 int64 value = reflection->GetInt64(*message, field_descriptor);
2213 result = PyLong_FromLongLong(value);
2214 break;
2215 }
2216 case FieldDescriptor::CPPTYPE_UINT32: {
2217 uint32 value = reflection->GetUInt32(*message, field_descriptor);
2218 result = PyInt_FromSize_t(value);
2219 break;
2220 }
2221 case FieldDescriptor::CPPTYPE_UINT64: {
2222 uint64 value = reflection->GetUInt64(*message, field_descriptor);
2223 result = PyLong_FromUnsignedLongLong(value);
2224 break;
2225 }
2226 case FieldDescriptor::CPPTYPE_FLOAT: {
2227 float value = reflection->GetFloat(*message, field_descriptor);
2228 result = PyFloat_FromDouble(value);
2229 break;
2230 }
2231 case FieldDescriptor::CPPTYPE_DOUBLE: {
2232 double value = reflection->GetDouble(*message, field_descriptor);
2233 result = PyFloat_FromDouble(value);
2234 break;
2235 }
2236 case FieldDescriptor::CPPTYPE_BOOL: {
2237 bool value = reflection->GetBool(*message, field_descriptor);
2238 result = PyBool_FromLong(value);
2239 break;
2240 }
2241 case FieldDescriptor::CPPTYPE_STRING: {
2242 std::string scratch;
2243 const std::string& value =
2244 reflection->GetStringReference(*message, field_descriptor, &scratch);
2245 result = ToStringObject(field_descriptor, value);
2246 break;
2247 }
2248 case FieldDescriptor::CPPTYPE_ENUM: {
2249 const EnumValueDescriptor* enum_value =
2250 message->GetReflection()->GetEnum(*message, field_descriptor);
2251 result = PyInt_FromLong(enum_value->number());
2252 break;
2253 }
2254 default:
2255 PyErr_Format(
2256 PyExc_SystemError, "Getting a value from a field of unknown type %d",
2257 field_descriptor->cpp_type());
2258 }
2259
2260 return result;
2261 }
2262
InternalGetSubMessage(CMessage * self,const FieldDescriptor * field_descriptor)2263 CMessage* InternalGetSubMessage(
2264 CMessage* self, const FieldDescriptor* field_descriptor) {
2265 const Reflection* reflection = self->message->GetReflection();
2266 PyMessageFactory* factory = GetFactoryForMessage(self);
2267 const Message& sub_message = reflection->GetMessage(
2268 *self->message, field_descriptor, factory->message_factory);
2269
2270 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2271 factory, field_descriptor->message_type());
2272 ScopedPyObjectPtr message_class_owner(
2273 reinterpret_cast<PyObject*>(message_class));
2274 if (message_class == NULL) {
2275 return NULL;
2276 }
2277
2278 CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
2279 if (cmsg == NULL) {
2280 return NULL;
2281 }
2282
2283 Py_INCREF(self);
2284 cmsg->parent = self;
2285 cmsg->parent_field_descriptor = field_descriptor;
2286 cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
2287 cmsg->message = const_cast<Message*>(&sub_message);
2288 return cmsg;
2289 }
2290
InternalSetNonOneofScalar(Message * message,const FieldDescriptor * field_descriptor,PyObject * arg)2291 int InternalSetNonOneofScalar(
2292 Message* message,
2293 const FieldDescriptor* field_descriptor,
2294 PyObject* arg) {
2295 const Reflection* reflection = message->GetReflection();
2296
2297 if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
2298 return -1;
2299 }
2300
2301 switch (field_descriptor->cpp_type()) {
2302 case FieldDescriptor::CPPTYPE_INT32: {
2303 GOOGLE_CHECK_GET_INT32(arg, value, -1);
2304 reflection->SetInt32(message, field_descriptor, value);
2305 break;
2306 }
2307 case FieldDescriptor::CPPTYPE_INT64: {
2308 GOOGLE_CHECK_GET_INT64(arg, value, -1);
2309 reflection->SetInt64(message, field_descriptor, value);
2310 break;
2311 }
2312 case FieldDescriptor::CPPTYPE_UINT32: {
2313 GOOGLE_CHECK_GET_UINT32(arg, value, -1);
2314 reflection->SetUInt32(message, field_descriptor, value);
2315 break;
2316 }
2317 case FieldDescriptor::CPPTYPE_UINT64: {
2318 GOOGLE_CHECK_GET_UINT64(arg, value, -1);
2319 reflection->SetUInt64(message, field_descriptor, value);
2320 break;
2321 }
2322 case FieldDescriptor::CPPTYPE_FLOAT: {
2323 GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
2324 reflection->SetFloat(message, field_descriptor, value);
2325 break;
2326 }
2327 case FieldDescriptor::CPPTYPE_DOUBLE: {
2328 GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
2329 reflection->SetDouble(message, field_descriptor, value);
2330 break;
2331 }
2332 case FieldDescriptor::CPPTYPE_BOOL: {
2333 GOOGLE_CHECK_GET_BOOL(arg, value, -1);
2334 reflection->SetBool(message, field_descriptor, value);
2335 break;
2336 }
2337 case FieldDescriptor::CPPTYPE_STRING: {
2338 if (!CheckAndSetString(
2339 arg, message, field_descriptor, reflection, false, -1)) {
2340 return -1;
2341 }
2342 break;
2343 }
2344 case FieldDescriptor::CPPTYPE_ENUM: {
2345 GOOGLE_CHECK_GET_INT32(arg, value, -1);
2346 if (reflection->SupportsUnknownEnumValues()) {
2347 reflection->SetEnumValue(message, field_descriptor, value);
2348 } else {
2349 const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
2350 const EnumValueDescriptor* enum_value =
2351 enum_descriptor->FindValueByNumber(value);
2352 if (enum_value != NULL) {
2353 reflection->SetEnum(message, field_descriptor, enum_value);
2354 } else {
2355 PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
2356 return -1;
2357 }
2358 }
2359 break;
2360 }
2361 default:
2362 PyErr_Format(
2363 PyExc_SystemError, "Setting value to a field of unknown type %d",
2364 field_descriptor->cpp_type());
2365 return -1;
2366 }
2367
2368 return 0;
2369 }
2370
InternalSetScalar(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * arg)2371 int InternalSetScalar(
2372 CMessage* self,
2373 const FieldDescriptor* field_descriptor,
2374 PyObject* arg) {
2375 if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
2376 return -1;
2377 }
2378
2379 if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
2380 return -1;
2381 }
2382
2383 return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
2384 }
2385
FromString(PyTypeObject * cls,PyObject * serialized)2386 PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
2387 PyObject* py_cmsg = PyObject_CallObject(
2388 reinterpret_cast<PyObject*>(cls), NULL);
2389 if (py_cmsg == NULL) {
2390 return NULL;
2391 }
2392 CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
2393
2394 ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
2395 if (py_length == NULL) {
2396 Py_DECREF(py_cmsg);
2397 return NULL;
2398 }
2399
2400 return py_cmsg;
2401 }
2402
DeepCopy(CMessage * self,PyObject * arg)2403 PyObject* DeepCopy(CMessage* self, PyObject* arg) {
2404 PyObject* clone = PyObject_CallObject(
2405 reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
2406 if (clone == NULL) {
2407 return NULL;
2408 }
2409 if (!PyObject_TypeCheck(clone, CMessage_Type)) {
2410 Py_DECREF(clone);
2411 return NULL;
2412 }
2413 if (ScopedPyObjectPtr(MergeFrom(
2414 reinterpret_cast<CMessage*>(clone),
2415 reinterpret_cast<PyObject*>(self))) == NULL) {
2416 Py_DECREF(clone);
2417 return NULL;
2418 }
2419 return clone;
2420 }
2421
ToUnicode(CMessage * self)2422 PyObject* ToUnicode(CMessage* self) {
2423 // Lazy import to prevent circular dependencies
2424 ScopedPyObjectPtr text_format(
2425 PyImport_ImportModule("google.protobuf.text_format"));
2426 if (text_format == NULL) {
2427 return NULL;
2428 }
2429 ScopedPyObjectPtr method_name(PyString_FromString("MessageToString"));
2430 if (method_name == NULL) {
2431 return NULL;
2432 }
2433 Py_INCREF(Py_True);
2434 ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
2435 text_format.get(), method_name.get(), self, Py_True, NULL));
2436 Py_DECREF(Py_True);
2437 if (encoded == NULL) {
2438 return NULL;
2439 }
2440 #if PY_MAJOR_VERSION < 3
2441 PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL);
2442 #else
2443 PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL);
2444 #endif
2445 if (decoded == NULL) {
2446 return NULL;
2447 }
2448 return decoded;
2449 }
2450
2451 // CMessage static methods:
_CheckCalledFromGeneratedFile(PyObject * unused,PyObject * unused_arg)2452 PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
2453 PyObject* unused_arg) {
2454 if (!_CalledFromGeneratedFile(1)) {
2455 PyErr_SetString(PyExc_TypeError,
2456 "Descriptors should not be created directly, "
2457 "but only retrieved from their parent.");
2458 return NULL;
2459 }
2460 Py_RETURN_NONE;
2461 }
2462
GetExtensionDict(CMessage * self,void * closure)2463 static PyObject* GetExtensionDict(CMessage* self, void *closure) {
2464 // If there are extension_ranges, the message is "extendable". Allocate a
2465 // dictionary to store the extension fields.
2466 const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
2467 if (!descriptor->extension_range_count()) {
2468 PyErr_SetNone(PyExc_AttributeError);
2469 return NULL;
2470 }
2471 if (!self->composite_fields) {
2472 self->composite_fields = new CMessage::CompositeFieldsMap();
2473 }
2474 if (!self->composite_fields) {
2475 return NULL;
2476 }
2477 ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
2478 return reinterpret_cast<PyObject*>(extension_dict);
2479 }
2480
UnknownFieldSet(CMessage * self)2481 static PyObject* UnknownFieldSet(CMessage* self) {
2482 if (self->unknown_field_set == NULL) {
2483 self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
2484 } else {
2485 Py_INCREF(self->unknown_field_set);
2486 }
2487 return self->unknown_field_set;
2488 }
2489
GetExtensionsByName(CMessage * self,void * closure)2490 static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
2491 return message_meta::GetExtensionsByName(
2492 reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2493 }
2494
GetExtensionsByNumber(CMessage * self,void * closure)2495 static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
2496 return message_meta::GetExtensionsByNumber(
2497 reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
2498 }
2499
2500 static PyGetSetDef Getters[] = {
2501 {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
2502 {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
2503 {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
2504 {NULL}
2505 };
2506
2507
2508 static PyMethodDef Methods[] = {
2509 { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
2510 "Makes a deep copy of the class." },
2511 { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
2512 "Outputs a unicode representation of the message." },
2513 { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
2514 "Returns the size of the message in bytes." },
2515 { "Clear", (PyCFunction)Clear, METH_NOARGS,
2516 "Clears the message." },
2517 { "ClearExtension", (PyCFunction)ClearExtension, METH_O,
2518 "Clears a message field." },
2519 { "ClearField", (PyCFunction)ClearField, METH_O,
2520 "Clears a message field." },
2521 { "CopyFrom", (PyCFunction)CopyFrom, METH_O,
2522 "Copies a protocol message into the current message." },
2523 { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
2524 "Discards the unknown fields." },
2525 { "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
2526 METH_NOARGS,
2527 "Finds unset required fields." },
2528 { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
2529 "Creates new method instance from given serialized data." },
2530 { "HasExtension", (PyCFunction)HasExtension, METH_O,
2531 "Checks if a message field is set." },
2532 { "HasField", (PyCFunction)HasField, METH_O,
2533 "Checks if a message field is set." },
2534 { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
2535 "Checks if all required fields of a protocol message are set." },
2536 { "ListFields", (PyCFunction)ListFields, METH_NOARGS,
2537 "Lists all set fields of a message." },
2538 { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
2539 "Merges a protocol message into the current message." },
2540 { "MergeFromString", (PyCFunction)MergeFromString, METH_O,
2541 "Merges a serialized message into the current message." },
2542 { "ParseFromString", (PyCFunction)ParseFromString, METH_O,
2543 "Parses a serialized message into the current message." },
2544 { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
2545 "Registers an extension with the current message." },
2546 { "SerializePartialToString", (PyCFunction)SerializePartialToString,
2547 METH_VARARGS | METH_KEYWORDS,
2548 "Serializes the message to a string, even if it isn't initialized." },
2549 { "SerializeToString", (PyCFunction)SerializeToString,
2550 METH_VARARGS | METH_KEYWORDS,
2551 "Serializes the message to a string, only for initialized messages." },
2552 { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
2553 "Sets the has bit of the given field in its parent message." },
2554 { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS,
2555 "Parse unknown field set"},
2556 { "WhichOneof", (PyCFunction)WhichOneof, METH_O,
2557 "Returns the name of the field set inside a oneof, "
2558 "or None if no field is set." },
2559
2560 // Static Methods.
2561 { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
2562 METH_NOARGS | METH_STATIC,
2563 "Raises TypeError if the caller is not in a _pb2.py file."},
2564 { NULL, NULL}
2565 };
2566
SetCompositeField(CMessage * self,const FieldDescriptor * field,ContainerBase * value)2567 bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
2568 ContainerBase* value) {
2569 if (self->composite_fields == NULL) {
2570 self->composite_fields = new CMessage::CompositeFieldsMap();
2571 }
2572 (*self->composite_fields)[field] = value;
2573 return true;
2574 }
2575
SetSubmessage(CMessage * self,CMessage * submessage)2576 bool SetSubmessage(CMessage* self, CMessage* submessage) {
2577 if (self->child_submessages == NULL) {
2578 self->child_submessages = new CMessage::SubMessagesMap();
2579 }
2580 (*self->child_submessages)[submessage->message] = submessage;
2581 return true;
2582 }
2583
GetAttr(PyObject * pself,PyObject * name)2584 PyObject* GetAttr(PyObject* pself, PyObject* name) {
2585 CMessage* self = reinterpret_cast<CMessage*>(pself);
2586 PyObject* result = PyObject_GenericGetAttr(
2587 reinterpret_cast<PyObject*>(self), name);
2588 if (result != NULL) {
2589 return result;
2590 }
2591 if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
2592 return NULL;
2593 }
2594
2595 PyErr_Clear();
2596 return message_meta::GetClassAttribute(
2597 CheckMessageClass(Py_TYPE(self)), name);
2598 }
2599
GetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor)2600 PyObject* GetFieldValue(CMessage* self,
2601 const FieldDescriptor* field_descriptor) {
2602 if (self->composite_fields) {
2603 CMessage::CompositeFieldsMap::iterator it =
2604 self->composite_fields->find(field_descriptor);
2605 if (it != self->composite_fields->end()) {
2606 ContainerBase* value = it->second;
2607 Py_INCREF(value);
2608 return value->AsPyObject();
2609 }
2610 }
2611
2612 if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2613 PyErr_Format(PyExc_TypeError,
2614 "descriptor to field '%s' doesn't apply to '%s' object",
2615 field_descriptor->full_name().c_str(),
2616 Py_TYPE(self)->tp_name);
2617 return NULL;
2618 }
2619
2620 if (!field_descriptor->is_repeated() &&
2621 field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
2622 return InternalGetScalar(self->message, field_descriptor);
2623 }
2624
2625 ContainerBase* py_container = nullptr;
2626 if (field_descriptor->is_map()) {
2627 const Descriptor* entry_type = field_descriptor->message_type();
2628 const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
2629 if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2630 CMessageClass* value_class = message_factory::GetMessageClass(
2631 GetFactoryForMessage(self), value_type->message_type());
2632 if (value_class == NULL) {
2633 return NULL;
2634 }
2635 py_container =
2636 NewMessageMapContainer(self, field_descriptor, value_class);
2637 } else {
2638 py_container = NewScalarMapContainer(self, field_descriptor);
2639 }
2640 } else if (field_descriptor->is_repeated()) {
2641 if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2642 CMessageClass* message_class = message_factory::GetMessageClass(
2643 GetFactoryForMessage(self), field_descriptor->message_type());
2644 if (message_class == NULL) {
2645 return NULL;
2646 }
2647 py_container = repeated_composite_container::NewContainer(
2648 self, field_descriptor, message_class);
2649 } else {
2650 py_container =
2651 repeated_scalar_container::NewContainer(self, field_descriptor);
2652 }
2653 } else if (field_descriptor->cpp_type() ==
2654 FieldDescriptor::CPPTYPE_MESSAGE) {
2655 py_container = InternalGetSubMessage(self, field_descriptor);
2656 } else {
2657 PyErr_SetString(PyExc_SystemError, "Should never happen");
2658 }
2659
2660 if (py_container == NULL) {
2661 return NULL;
2662 }
2663 if (!SetCompositeField(self, field_descriptor, py_container)) {
2664 Py_DECREF(py_container);
2665 return NULL;
2666 }
2667 return py_container->AsPyObject();
2668 }
2669
SetFieldValue(CMessage * self,const FieldDescriptor * field_descriptor,PyObject * value)2670 int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
2671 PyObject* value) {
2672 if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
2673 PyErr_Format(PyExc_TypeError,
2674 "descriptor to field '%s' doesn't apply to '%s' object",
2675 field_descriptor->full_name().c_str(),
2676 Py_TYPE(self)->tp_name);
2677 return -1;
2678 } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
2679 PyErr_Format(PyExc_AttributeError,
2680 "Assignment not allowed to repeated "
2681 "field \"%s\" in protocol message object.",
2682 field_descriptor->name().c_str());
2683 return -1;
2684 } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
2685 PyErr_Format(PyExc_AttributeError,
2686 "Assignment not allowed to "
2687 "field \"%s\" in protocol message object.",
2688 field_descriptor->name().c_str());
2689 return -1;
2690 } else {
2691 AssureWritable(self);
2692 return InternalSetScalar(self, field_descriptor, value);
2693 }
2694 }
2695
2696 } // namespace cmessage
2697
2698 // All containers which are not messages:
2699 // - Make a new parent message
2700 // - Copy the field
2701 // - return the field.
DeepCopy()2702 PyObject* ContainerBase::DeepCopy() {
2703 CMessage* new_parent =
2704 cmessage::NewEmptyMessage(this->parent->GetMessageClass());
2705 new_parent->message = this->parent->message->New();
2706
2707 // Copy the map field into the new message.
2708 this->parent->message->GetReflection()->SwapFields(
2709 this->parent->message, new_parent->message,
2710 {this->parent_field_descriptor});
2711 this->parent->message->MergeFrom(*new_parent->message);
2712
2713 PyObject* result =
2714 cmessage::GetFieldValue(new_parent, this->parent_field_descriptor);
2715 Py_DECREF(new_parent);
2716 return result;
2717 }
2718
RemoveFromParentCache()2719 void ContainerBase::RemoveFromParentCache() {
2720 CMessage* parent = this->parent;
2721 if (parent) {
2722 if (parent->composite_fields)
2723 parent->composite_fields->erase(this->parent_field_descriptor);
2724 Py_CLEAR(parent);
2725 }
2726 }
2727
BuildSubMessageFromPointer(const FieldDescriptor * field_descriptor,Message * sub_message,CMessageClass * message_class)2728 CMessage* CMessage::BuildSubMessageFromPointer(
2729 const FieldDescriptor* field_descriptor, Message* sub_message,
2730 CMessageClass* message_class) {
2731 if (!this->child_submessages) {
2732 this->child_submessages = new CMessage::SubMessagesMap();
2733 }
2734 CMessage* cmsg = FindPtrOrNull(
2735 *this->child_submessages, sub_message);
2736 if (cmsg) {
2737 Py_INCREF(cmsg);
2738 } else {
2739 cmsg = cmessage::NewEmptyMessage(message_class);
2740
2741 if (cmsg == NULL) {
2742 return NULL;
2743 }
2744 cmsg->message = sub_message;
2745 Py_INCREF(this);
2746 cmsg->parent = this;
2747 cmsg->parent_field_descriptor = field_descriptor;
2748 cmessage::SetSubmessage(this, cmsg);
2749 }
2750 return cmsg;
2751 }
2752
MaybeReleaseSubMessage(Message * sub_message)2753 CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
2754 if (!this->child_submessages) {
2755 return nullptr;
2756 }
2757 CMessage* released = FindPtrOrNull(
2758 *this->child_submessages, sub_message);
2759 if (!released) {
2760 return nullptr;
2761 }
2762 // The target message will now own its content.
2763 Py_CLEAR(released->parent);
2764 released->parent_field_descriptor = nullptr;
2765 released->read_only = false;
2766 // Delete it from the cache.
2767 this->child_submessages->erase(sub_message);
2768 return released;
2769 }
2770
// Statically initialized type object for the base CMessage Python class.
// Its type (metaclass) is _CMessageClass_Type, as set by the
// PyVarObject_HEAD_INIT below. Slots are filled positionally; the trailing
// comment on each line names the PyTypeObject slot being initialized.
static CMessageClass _CMessage_Type = { { {
  PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0)
  FULL_MODULE_NAME ".CMessage",        // tp_name
  sizeof(CMessage),                    // tp_basicsize
  0,                                   // tp_itemsize
  (destructor)cmessage::Dealloc,       // tp_dealloc
  0,                                   // tp_print
  0,                                   // tp_getattr
  0,                                   // tp_setattr
  0,                                   // tp_compare
  (reprfunc)cmessage::ToStr,           // tp_repr
  0,                                   // tp_as_number
  0,                                   // tp_as_sequence
  0,                                   // tp_as_mapping
  PyObject_HashNotImplemented,         // tp_hash
  0,                                   // tp_call
  (reprfunc)cmessage::ToStr,           // tp_str
  cmessage::GetAttr,                   // tp_getattro
  0,                                   // tp_setattro
  0,                                   // tp_as_buffer
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
  | Py_TPFLAGS_HAVE_VERSION_TAG,       // tp_flags
  "A ProtocolMessage",                 // tp_doc
  0,                                   // tp_traverse
  0,                                   // tp_clear
  (richcmpfunc)cmessage::RichCompare,  // tp_richcompare
  offsetof(CMessage, weakreflist),     // tp_weaklistoffset
  0,                                   // tp_iter
  0,                                   // tp_iternext
  cmessage::Methods,                   // tp_methods
  0,                                   // tp_members
  cmessage::Getters,                   // tp_getset
  0,                                   // tp_base
  0,                                   // tp_dict
  0,                                   // tp_descr_get
  0,                                   // tp_descr_set
  0,                                   // tp_dictoffset
  (initproc)cmessage::Init,            // tp_init
  0,                                   // tp_alloc
  cmessage::New,                       // tp_new
} } };
// Public handle to the PyTypeObject embedded in _CMessage_Type.
PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
2813
// --- Exposing the C proto living inside Python proto to C code:

// Function-pointer hooks for obtaining the C++ message held by a Python
// message object (read-only and mutable variants). They are assigned in
// InitProto2MessageModule() to GetCProtoInsidePyProtoImpl and
// MutableCProtoInsidePyProtoImpl respectively.
const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
2818
GetCProtoInsidePyProtoImpl(PyObject * msg)2819 static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
2820 const Message* message = PyMessage_GetMessagePointer(msg);
2821 if (message == NULL) {
2822 PyErr_Clear();
2823 return NULL;
2824 }
2825 return message;
2826 }
2827
MutableCProtoInsidePyProtoImpl(PyObject * msg)2828 static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
2829 Message* message = PyMessage_GetMutableMessagePointer(msg);
2830 if (message == NULL) {
2831 PyErr_Clear();
2832 return NULL;
2833 }
2834 return message;
2835 }
2836
PyMessage_GetMessagePointer(PyObject * msg)2837 const Message* PyMessage_GetMessagePointer(PyObject* msg) {
2838 if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2839 PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2840 return NULL;
2841 }
2842 CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2843 return cmsg->message;
2844 }
2845
PyMessage_GetMutableMessagePointer(PyObject * msg)2846 Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
2847 if (!PyObject_TypeCheck(msg, CMessage_Type)) {
2848 PyErr_SetString(PyExc_TypeError, "Not a Message instance");
2849 return NULL;
2850 }
2851 CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
2852
2853
2854 if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) ||
2855 (cmsg->child_submessages && !cmsg->child_submessages->empty())) {
2856 // There is currently no way of accurately syncing arbitrary changes to
2857 // the underlying C++ message back to the CMessage (e.g. removed repeated
2858 // composite containers). We only allow direct mutation of the underlying
2859 // C++ message if there is no child data in the CMessage.
2860 PyErr_SetString(PyExc_ValueError,
2861 "Cannot reliably get a mutable pointer "
2862 "to a message with extra references");
2863 return NULL;
2864 }
2865 cmessage::AssureWritable(cmsg);
2866 return cmsg->message;
2867 }
2868
PyMessage_NewMessageOwnedExternally(Message * message,PyObject * message_factory)2869 PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
2870 PyObject* message_factory) {
2871 if (message_factory) {
2872 PyErr_SetString(PyExc_NotImplementedError,
2873 "Default message_factory=NULL is the only supported value");
2874 return NULL;
2875 }
2876 if (message->GetReflection()->GetMessageFactory() !=
2877 MessageFactory::generated_factory()) {
2878 PyErr_SetString(PyExc_TypeError,
2879 "Message pointer was not created from the default factory");
2880 return NULL;
2881 }
2882
2883 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
2884 GetDefaultDescriptorPool()->py_message_factory, message->GetDescriptor());
2885
2886 CMessage* self = cmessage::NewEmptyMessage(message_class);
2887 if (self == NULL) {
2888 return NULL;
2889 }
2890 Py_DECREF(message_class);
2891 self->message = message;
2892 Py_INCREF(Py_None);
2893 self->parent = reinterpret_cast<CMessage*>(Py_None);
2894 return self->AsPyObject();
2895 }
2896
InitGlobals()2897 void InitGlobals() {
2898 // TODO(gps): Check all return values in this function for NULL and propagate
2899 // the error (MemoryError) on up to result in an import failure. These should
2900 // also be freed and reset to NULL during finalization.
2901 kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
2902
2903 PyObject *dummy_obj = PySet_New(NULL);
2904 kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
2905 Py_DECREF(dummy_obj);
2906 }
2907
InitProto2MessageModule(PyObject * m)2908 bool InitProto2MessageModule(PyObject *m) {
2909 // Initialize types and globals in descriptor.cc
2910 if (!InitDescriptor()) {
2911 return false;
2912 }
2913
2914 // Initialize types and globals in descriptor_pool.cc
2915 if (!InitDescriptorPool()) {
2916 return false;
2917 }
2918
2919 // Initialize types and globals in message_factory.cc
2920 if (!InitMessageFactory()) {
2921 return false;
2922 }
2923
2924 // Initialize constants defined in this file.
2925 InitGlobals();
2926
2927 CMessageClass_Type->tp_base = &PyType_Type;
2928 if (PyType_Ready(CMessageClass_Type) < 0) {
2929 return false;
2930 }
2931 PyModule_AddObject(m, "MessageMeta",
2932 reinterpret_cast<PyObject*>(CMessageClass_Type));
2933
2934 if (PyType_Ready(CMessage_Type) < 0) {
2935 return false;
2936 }
2937 if (PyType_Ready(CFieldProperty_Type) < 0) {
2938 return false;
2939 }
2940
2941 // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
2942 // it here as well to document that subclasses need to set it.
2943 PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
2944 // Invalidate any cached data for the CMessage type.
2945 // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
2946 // after we have modified CMessage_Type.tp_dict.
2947 PyType_Modified(CMessage_Type);
2948
2949 PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
2950
2951 // Initialize Repeated container types.
2952 {
2953 if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
2954 return false;
2955 }
2956
2957 PyModule_AddObject(m, "RepeatedScalarContainer",
2958 reinterpret_cast<PyObject*>(
2959 &RepeatedScalarContainer_Type));
2960
2961 if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
2962 return false;
2963 }
2964
2965 PyModule_AddObject(
2966 m, "RepeatedCompositeContainer",
2967 reinterpret_cast<PyObject*>(
2968 &RepeatedCompositeContainer_Type));
2969
2970 // Register them as MutableSequence.
2971 #if PY_MAJOR_VERSION >= 3
2972 ScopedPyObjectPtr collections(PyImport_ImportModule("collections.abc"));
2973 #else
2974 ScopedPyObjectPtr collections(PyImport_ImportModule("collections"));
2975 #endif
2976 if (collections == NULL) {
2977 return false;
2978 }
2979 ScopedPyObjectPtr mutable_sequence(
2980 PyObject_GetAttrString(collections.get(), "MutableSequence"));
2981 if (mutable_sequence == NULL) {
2982 return false;
2983 }
2984 if (ScopedPyObjectPtr(
2985 PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2986 &RepeatedScalarContainer_Type)) == NULL) {
2987 return false;
2988 }
2989 if (ScopedPyObjectPtr(
2990 PyObject_CallMethod(mutable_sequence.get(), "register", "O",
2991 &RepeatedCompositeContainer_Type)) == NULL) {
2992 return false;
2993 }
2994 }
2995
2996 if (PyType_Ready(&PyUnknownFields_Type) < 0) {
2997 return false;
2998 }
2999
3000 PyModule_AddObject(m, "UnknownFieldSet",
3001 reinterpret_cast<PyObject*>(
3002 &PyUnknownFields_Type));
3003
3004 if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
3005 return false;
3006 }
3007
3008 PyModule_AddObject(m, "UnknownField",
3009 reinterpret_cast<PyObject*>(
3010 &PyUnknownFieldRef_Type));
3011
3012 // Initialize Map container types.
3013 if (!InitMapContainers()) {
3014 return false;
3015 }
3016 PyModule_AddObject(m, "ScalarMapContainer",
3017 reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
3018 PyModule_AddObject(m, "MessageMapContainer",
3019 reinterpret_cast<PyObject*>(MessageMapContainer_Type));
3020 PyModule_AddObject(m, "MapIterator",
3021 reinterpret_cast<PyObject*>(&MapIterator_Type));
3022
3023 if (PyType_Ready(&ExtensionDict_Type) < 0) {
3024 return false;
3025 }
3026 PyModule_AddObject(
3027 m, "ExtensionDict",
3028 reinterpret_cast<PyObject*>(&ExtensionDict_Type));
3029 if (PyType_Ready(&ExtensionIterator_Type) < 0) {
3030 return false;
3031 }
3032 PyModule_AddObject(m, "ExtensionIterator",
3033 reinterpret_cast<PyObject*>(&ExtensionIterator_Type));
3034
3035 // Expose the DescriptorPool used to hold all descriptors added from generated
3036 // pb2.py files.
3037 // PyModule_AddObject steals a reference.
3038 Py_INCREF(GetDefaultDescriptorPool());
3039 PyModule_AddObject(m, "default_pool",
3040 reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
3041
3042 PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
3043 &PyDescriptorPool_Type));
3044
3045 PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>(
3046 &PyMessageDescriptor_Type));
3047 PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>(
3048 &PyFieldDescriptor_Type));
3049 PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>(
3050 &PyEnumDescriptor_Type));
3051 PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>(
3052 &PyEnumValueDescriptor_Type));
3053 PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>(
3054 &PyFileDescriptor_Type));
3055 PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
3056 &PyOneofDescriptor_Type));
3057 PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
3058 &PyServiceDescriptor_Type));
3059 PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
3060 &PyMethodDescriptor_Type));
3061
3062 PyObject* enum_type_wrapper = PyImport_ImportModule(
3063 "google.protobuf.internal.enum_type_wrapper");
3064 if (enum_type_wrapper == NULL) {
3065 return false;
3066 }
3067 EnumTypeWrapper_class =
3068 PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
3069 Py_DECREF(enum_type_wrapper);
3070
3071 PyObject* message_module = PyImport_ImportModule(
3072 "google.protobuf.message");
3073 if (message_module == NULL) {
3074 return false;
3075 }
3076 EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
3077 DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
3078 PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
3079 Py_DECREF(message_module);
3080
3081 PyObject* pickle_module = PyImport_ImportModule("pickle");
3082 if (pickle_module == NULL) {
3083 return false;
3084 }
3085 PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
3086 Py_DECREF(pickle_module);
3087
3088 // Override {Get,Mutable}CProtoInsidePyProto.
3089 GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
3090 MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
3091
3092 return true;
3093 }
3094
3095 } // namespace python
3096 } // namespace protobuf
3097 } // namespace google
3098