1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 #include <unordered_map>
32
33 #include <Python.h>
34
35 #include <google/protobuf/dynamic_message.h>
36 #include <google/protobuf/pyext/descriptor.h>
37 #include <google/protobuf/pyext/message.h>
38 #include <google/protobuf/pyext/message_factory.h>
39 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
40
// Python 3 compatibility shim for the Python 2 API PyString_AsStringAndSize.
//
// On Python 3, str (unicode) objects are converted with
// PyUnicode_AsUTF8AndSize; any other object is passed to
// PyBytes_AsStringAndSize. The expression yields -1 when the UTF-8 conversion
// returns NULL and 0 otherwise, mirroring the Python 2 return convention.
// Python 3.0 - 3.2 are rejected at compile time (presumably because
// PyUnicode_AsUTF8AndSize is not available there).
// NOTE(review): nothing in this translation unit appears to use this macro;
// confirm before relying on it, or consider removing it.
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
#define PyString_AsStringAndSize(ob, charpp, sizep) \
  (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
                               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
                              ? -1 \
                              : 0) \
                       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
52
53 namespace google {
54 namespace protobuf {
55 namespace python {
56
57 namespace message_factory {
58
NewMessageFactory(PyTypeObject * type,PyDescriptorPool * pool)59 PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
60 PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
61 PyType_GenericAlloc(type, 0));
62 if (factory == NULL) {
63 return NULL;
64 }
65
66 DynamicMessageFactory* message_factory = new DynamicMessageFactory();
67 // This option might be the default some day.
68 message_factory->SetDelegateToGeneratedFactory(true);
69 factory->message_factory = message_factory;
70
71 factory->pool = pool;
72 Py_INCREF(pool);
73
74 factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
75
76 return factory;
77 }
78
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)79 PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
80 static char* kwlist[] = {"pool", 0};
81 PyObject* pool = NULL;
82 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
83 return NULL;
84 }
85 ScopedPyObjectPtr owned_pool;
86 if (pool == NULL || pool == Py_None) {
87 owned_pool.reset(PyObject_CallFunction(
88 reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
89 if (owned_pool == NULL) {
90 return NULL;
91 }
92 pool = owned_pool.get();
93 } else {
94 if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
95 PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
96 pool->ob_type->tp_name);
97 return NULL;
98 }
99 }
100
101 return reinterpret_cast<PyObject*>(
102 NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
103 }
104
Dealloc(PyObject * pself)105 static void Dealloc(PyObject* pself) {
106 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
107
108 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
109 for (iterator it = self->classes_by_descriptor->begin();
110 it != self->classes_by_descriptor->end(); ++it) {
111 Py_CLEAR(it->second);
112 }
113 delete self->classes_by_descriptor;
114 delete self->message_factory;
115 Py_CLEAR(self->pool);
116 Py_TYPE(self)->tp_free(pself);
117 }
118
GcTraverse(PyObject * pself,visitproc visit,void * arg)119 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
120 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
121 Py_VISIT(self->pool);
122 for (const auto& desc_and_class : *self->classes_by_descriptor) {
123 Py_VISIT(desc_and_class.second);
124 }
125 return 0;
126 }
127
GcClear(PyObject * pself)128 static int GcClear(PyObject* pself) {
129 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
130 // Here it's important to not clear self->pool, so that the C++ DescriptorPool
131 // is still alive when self->message_factory is destructed.
132 for (auto& desc_and_class : *self->classes_by_descriptor) {
133 Py_CLEAR(desc_and_class.second);
134 }
135
136 return 0;
137 }
138
139 // Add a message class to our database.
RegisterMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor,CMessageClass * message_class)140 int RegisterMessageClass(PyMessageFactory* self,
141 const Descriptor* message_descriptor,
142 CMessageClass* message_class) {
143 Py_INCREF(message_class);
144 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
145 std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
146 std::make_pair(message_descriptor, message_class));
147 if (!ret.second) {
148 // Update case: DECREF the previous value.
149 Py_DECREF(ret.first->second);
150 ret.first->second = message_class;
151 }
152 return 0;
153 }
154
GetOrCreateMessageClass(PyMessageFactory * self,const Descriptor * descriptor)155 CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
156 const Descriptor* descriptor) {
157 // This is the same implementation as MessageFactory.GetPrototype().
158
159 // Do not create a MessageClass that already exists.
160 std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
161 self->classes_by_descriptor->find(descriptor);
162 if (it != self->classes_by_descriptor->end()) {
163 Py_INCREF(it->second);
164 return it->second;
165 }
166 ScopedPyObjectPtr py_descriptor(
167 PyMessageDescriptor_FromDescriptor(descriptor));
168 if (py_descriptor == NULL) {
169 return NULL;
170 }
171 // Create a new message class.
172 ScopedPyObjectPtr args(Py_BuildValue(
173 "s(){sOsOsO}", descriptor->name().c_str(),
174 "DESCRIPTOR", py_descriptor.get(),
175 "__module__", Py_None,
176 "message_factory", self));
177 if (args == NULL) {
178 return NULL;
179 }
180 ScopedPyObjectPtr message_class(PyObject_CallObject(
181 reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
182 if (message_class == NULL) {
183 return NULL;
184 }
185 // Create messages class for the messages used by the fields, and registers
186 // all extensions for these messages during the recursion.
187 for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
188 const Descriptor* sub_descriptor =
189 descriptor->field(field_idx)->message_type();
190 // It is NULL if the field type is not a message.
191 if (sub_descriptor != NULL) {
192 CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
193 if (result == NULL) {
194 return NULL;
195 }
196 Py_DECREF(result);
197 }
198 }
199
200 // Register extensions defined in this message.
201 for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
202 const FieldDescriptor* extension = descriptor->extension(ext_idx);
203 ScopedPyObjectPtr py_extended_class(
204 GetOrCreateMessageClass(self, extension->containing_type())
205 ->AsPyObject());
206 if (py_extended_class == NULL) {
207 return NULL;
208 }
209 ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
210 if (py_extension == NULL) {
211 return NULL;
212 }
213 ScopedPyObjectPtr result(cmessage::RegisterExtension(
214 py_extended_class.get(), py_extension.get()));
215 if (result == NULL) {
216 return NULL;
217 }
218 }
219 return reinterpret_cast<CMessageClass*>(message_class.release());
220 }
221
222 // Retrieve the message class added to our database.
GetMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor)223 CMessageClass* GetMessageClass(PyMessageFactory* self,
224 const Descriptor* message_descriptor) {
225 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
226 iterator ret = self->classes_by_descriptor->find(message_descriptor);
227 if (ret == self->classes_by_descriptor->end()) {
228 PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
229 message_descriptor->full_name().c_str());
230 return NULL;
231 } else {
232 return ret->second;
233 }
234 }
235
236 static PyMethodDef Methods[] = {
237 {NULL}};
238
GetPool(PyMessageFactory * self,void * closure)239 static PyObject* GetPool(PyMessageFactory* self, void* closure) {
240 Py_INCREF(self->pool);
241 return reinterpret_cast<PyObject*>(self->pool);
242 }
243
244 static PyGetSetDef Getters[] = {
245 {"pool", (getter)GetPool, NULL, "DescriptorPool"},
246 {NULL}
247 };
248
249 } // namespace message_factory
250
// Python type object for MessageFactory (tp_name is FULL_MODULE_NAME
// ".MessageFactory").
//
// Instances are created by message_factory::New. They take part in cyclic
// garbage collection (Py_TPFLAGS_HAVE_GC with GcTraverse / GcClear) and are
// freed through PyObject_GC_Del. Python code may subclass the type
// (Py_TPFLAGS_BASETYPE). Slots left as 0 are not implemented here;
// PyType_Ready may inherit some of them (e.g. tp_alloc) from the base type.
// The initializer is positional, so entries must stay in slot order.
PyTypeObject PyMessageFactory_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
    ".MessageFactory",                        // tp_name
    sizeof(PyMessageFactory),                 // tp_basicsize
    0,                                        // tp_itemsize
    message_factory::Dealloc,                 // tp_dealloc
    0,                                        // tp_print
    0,                                        // tp_getattr
    0,                                        // tp_setattr
    0,                                        // tp_compare (tp_as_async on Python 3.5+)
    0,                                        // tp_repr
    0,                                        // tp_as_number
    0,                                        // tp_as_sequence
    0,                                        // tp_as_mapping
    0,                                        // tp_hash
    0,                                        // tp_call
    0,                                        // tp_str
    0,                                        // tp_getattro
    0,                                        // tp_setattro
    0,                                        // tp_as_buffer
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
    "A static Message Factory",               // tp_doc
    message_factory::GcTraverse,              // tp_traverse
    message_factory::GcClear,                 // tp_clear
    0,                                        // tp_richcompare
    0,                                        // tp_weaklistoffset
    0,                                        // tp_iter
    0,                                        // tp_iternext
    message_factory::Methods,                 // tp_methods
    0,                                        // tp_members
    message_factory::Getters,                 // tp_getset
    0,                                        // tp_base
    0,                                        // tp_dict
    0,                                        // tp_descr_get
    0,                                        // tp_descr_set
    0,                                        // tp_dictoffset
    0,                                        // tp_init
    0,                                        // tp_alloc
    message_factory::New,                     // tp_new
    PyObject_GC_Del,                          // tp_free (GC-aware, matches HAVE_GC)
};
292
InitMessageFactory()293 bool InitMessageFactory() {
294 if (PyType_Ready(&PyMessageFactory_Type) < 0) {
295 return false;
296 }
297
298 return true;
299 }
300
301 } // namespace python
302 } // namespace protobuf
303 } // namespace google
304