• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 #include <unordered_map>
32 
33 #include <Python.h>
34 
35 #include <google/protobuf/dynamic_message.h>
36 #include <google/protobuf/pyext/descriptor.h>
37 #include <google/protobuf/pyext/message.h>
38 #include <google/protobuf/pyext/message_factory.h>
39 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
40 
#if PY_MAJOR_VERSION >= 3
#  if PY_VERSION_HEX < 0x03030000
#    error "Python 3.0 - 3.2 are not supported."
#  endif
// Python 3 shim for the Python 2 PyString_AsStringAndSize() API.
// - str (unicode) objects yield their UTF-8 buffer via PyUnicode_AsUTF8AndSize,
//   and the macro evaluates to -1 when that returns NULL, 0 otherwise.
// - Any other object is forwarded to PyBytes_AsStringAndSize.
// NOTE(review): this macro is not referenced elsewhere in this file.
#  define PyString_AsStringAndSize(ob, charpp, sizep)                        \
     (PyUnicode_Check(ob)                                                    \
          ? ((*(charpp) = const_cast<char*>(                                 \
                  PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL             \
                 ? -1                                                        \
                 : 0)                                                        \
          : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
52 
53 namespace google {
54 namespace protobuf {
55 namespace python {
56 
57 namespace message_factory {
58 
NewMessageFactory(PyTypeObject * type,PyDescriptorPool * pool)59 PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
60   PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
61       PyType_GenericAlloc(type, 0));
62   if (factory == NULL) {
63     return NULL;
64   }
65 
66   DynamicMessageFactory* message_factory = new DynamicMessageFactory();
67   // This option might be the default some day.
68   message_factory->SetDelegateToGeneratedFactory(true);
69   factory->message_factory = message_factory;
70 
71   factory->pool = pool;
72   Py_INCREF(pool);
73 
74   factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
75 
76   return factory;
77 }
78 
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)79 PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
80   static char* kwlist[] = {"pool", 0};
81   PyObject* pool = NULL;
82   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
83     return NULL;
84   }
85   ScopedPyObjectPtr owned_pool;
86   if (pool == NULL || pool == Py_None) {
87     owned_pool.reset(PyObject_CallFunction(
88         reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
89     if (owned_pool == NULL) {
90       return NULL;
91     }
92     pool = owned_pool.get();
93   } else {
94     if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
95       PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
96                    pool->ob_type->tp_name);
97       return NULL;
98     }
99   }
100 
101   return reinterpret_cast<PyObject*>(
102       NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
103 }
104 
Dealloc(PyObject * pself)105 static void Dealloc(PyObject* pself) {
106   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
107 
108   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
109   for (iterator it = self->classes_by_descriptor->begin();
110        it != self->classes_by_descriptor->end(); ++it) {
111     Py_CLEAR(it->second);
112   }
113   delete self->classes_by_descriptor;
114   delete self->message_factory;
115   Py_CLEAR(self->pool);
116   Py_TYPE(self)->tp_free(pself);
117 }
118 
GcTraverse(PyObject * pself,visitproc visit,void * arg)119 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
120   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
121   Py_VISIT(self->pool);
122   for (const auto& desc_and_class : *self->classes_by_descriptor) {
123     Py_VISIT(desc_and_class.second);
124   }
125   return 0;
126 }
127 
GcClear(PyObject * pself)128 static int GcClear(PyObject* pself) {
129   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
130   // Here it's important to not clear self->pool, so that the C++ DescriptorPool
131   // is still alive when self->message_factory is destructed.
132   for (auto& desc_and_class : *self->classes_by_descriptor) {
133     Py_CLEAR(desc_and_class.second);
134   }
135 
136   return 0;
137 }
138 
139 // Add a message class to our database.
RegisterMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor,CMessageClass * message_class)140 int RegisterMessageClass(PyMessageFactory* self,
141                          const Descriptor* message_descriptor,
142                          CMessageClass* message_class) {
143   Py_INCREF(message_class);
144   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
145   std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
146       std::make_pair(message_descriptor, message_class));
147   if (!ret.second) {
148     // Update case: DECREF the previous value.
149     Py_DECREF(ret.first->second);
150     ret.first->second = message_class;
151   }
152   return 0;
153 }
154 
GetOrCreateMessageClass(PyMessageFactory * self,const Descriptor * descriptor)155 CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
156                                        const Descriptor* descriptor) {
157   // This is the same implementation as MessageFactory.GetPrototype().
158 
159   // Do not create a MessageClass that already exists.
160   std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
161       self->classes_by_descriptor->find(descriptor);
162   if (it != self->classes_by_descriptor->end()) {
163     Py_INCREF(it->second);
164     return it->second;
165   }
166   ScopedPyObjectPtr py_descriptor(
167       PyMessageDescriptor_FromDescriptor(descriptor));
168   if (py_descriptor == NULL) {
169     return NULL;
170   }
171   // Create a new message class.
172   ScopedPyObjectPtr args(Py_BuildValue(
173       "s(){sOsOsO}", descriptor->name().c_str(),
174       "DESCRIPTOR", py_descriptor.get(),
175       "__module__", Py_None,
176       "message_factory", self));
177   if (args == NULL) {
178     return NULL;
179   }
180   ScopedPyObjectPtr message_class(PyObject_CallObject(
181       reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
182   if (message_class == NULL) {
183     return NULL;
184   }
185   // Create messages class for the messages used by the fields, and registers
186   // all extensions for these messages during the recursion.
187   for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
188     const Descriptor* sub_descriptor =
189         descriptor->field(field_idx)->message_type();
190     // It is NULL if the field type is not a message.
191     if (sub_descriptor != NULL) {
192       CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
193       if (result == NULL) {
194         return NULL;
195       }
196       Py_DECREF(result);
197     }
198   }
199 
200   // Register extensions defined in this message.
201   for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
202     const FieldDescriptor* extension = descriptor->extension(ext_idx);
203     ScopedPyObjectPtr py_extended_class(
204         GetOrCreateMessageClass(self, extension->containing_type())
205             ->AsPyObject());
206     if (py_extended_class == NULL) {
207       return NULL;
208     }
209     ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
210     if (py_extension == NULL) {
211       return NULL;
212     }
213     ScopedPyObjectPtr result(cmessage::RegisterExtension(
214         py_extended_class.get(), py_extension.get()));
215     if (result == NULL) {
216       return NULL;
217     }
218   }
219   return reinterpret_cast<CMessageClass*>(message_class.release());
220 }
221 
222 // Retrieve the message class added to our database.
GetMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor)223 CMessageClass* GetMessageClass(PyMessageFactory* self,
224                                const Descriptor* message_descriptor) {
225   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
226   iterator ret = self->classes_by_descriptor->find(message_descriptor);
227   if (ret == self->classes_by_descriptor->end()) {
228     PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
229                  message_descriptor->full_name().c_str());
230     return NULL;
231   } else {
232     return ret->second;
233   }
234 }
235 
236 static PyMethodDef Methods[] = {
237     {NULL}};
238 
GetPool(PyMessageFactory * self,void * closure)239 static PyObject* GetPool(PyMessageFactory* self, void* closure) {
240   Py_INCREF(self->pool);
241   return reinterpret_cast<PyObject*>(self->pool);
242 }
243 
244 static PyGetSetDef Getters[] = {
245     {"pool", (getter)GetPool, NULL, "DescriptorPool"},
246     {NULL}
247 };
248 
249 }  // namespace message_factory
250 
251 PyTypeObject PyMessageFactory_Type = {
252     PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
253     ".MessageFactory",         // tp_name
254     sizeof(PyMessageFactory),  // tp_basicsize
255     0,                         // tp_itemsize
256     message_factory::Dealloc,  // tp_dealloc
257     0,                         // tp_print
258     0,                         // tp_getattr
259     0,                         // tp_setattr
260     0,                         // tp_compare
261     0,                         // tp_repr
262     0,                         // tp_as_number
263     0,                         // tp_as_sequence
264     0,                         // tp_as_mapping
265     0,                         // tp_hash
266     0,                         // tp_call
267     0,                         // tp_str
268     0,                         // tp_getattro
269     0,                         // tp_setattro
270     0,                         // tp_as_buffer
271     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
272     "A static Message Factory",                                     // tp_doc
273     message_factory::GcTraverse,  // tp_traverse
274     message_factory::GcClear,     // tp_clear
275     0,                            // tp_richcompare
276     0,                            // tp_weaklistoffset
277     0,                            // tp_iter
278     0,                            // tp_iternext
279     message_factory::Methods,     // tp_methods
280     0,                            // tp_members
281     message_factory::Getters,     // tp_getset
282     0,                            // tp_base
283     0,                            // tp_dict
284     0,                            // tp_descr_get
285     0,                            // tp_descr_set
286     0,                            // tp_dictoffset
287     0,                            // tp_init
288     0,                            // tp_alloc
289     message_factory::New,         // tp_new
290     PyObject_GC_Del,              // tp_free
291 };
292 
InitMessageFactory()293 bool InitMessageFactory() {
294   if (PyType_Ready(&PyMessageFactory_Type) < 0) {
295     return false;
296   }
297 
298   return true;
299 }
300 
301 }  // namespace python
302 }  // namespace protobuf
303 }  // namespace google
304