• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 #include <unordered_map>
9 #include <utility>
10 
11 #define PY_SSIZE_T_CLEAN
12 #include <Python.h>
13 
14 #include "google/protobuf/dynamic_message.h"
15 #include "google/protobuf/pyext/descriptor.h"
16 #include "google/protobuf/pyext/message.h"
17 #include "google/protobuf/pyext/message_factory.h"
18 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
19 
// Compatibility shim: extracts a (char*, size) view from either a str object
// (via its UTF-8 encoding) or a bytes object.
//   ob     - the Python object to read.
//   charpp - char** receiving the buffer pointer.
//   sizep  - Py_ssize_t* receiving the length.
// Evaluates to 0 on success and -1 on failure.
// NOTE(review): no use of this macro is visible in this file; consider
// removing it if nothing else depends on it.
#define PyString_AsStringAndSize(ob, charpp, sizep)              \
  (!PyUnicode_Check(ob)                                          \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))          \
       : ((*(charpp) = const_cast<char*>(                        \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != nullptr \
              ? 0                                                \
              : -1))
27 
28 namespace google {
29 namespace protobuf {
30 namespace python {
31 
32 namespace message_factory {
33 
NewMessageFactory(PyTypeObject * type,PyDescriptorPool * pool)34 PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
35   PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
36       PyType_GenericAlloc(type, 0));
37   if (factory == nullptr) {
38     return nullptr;
39   }
40 
41   DynamicMessageFactory* message_factory = new DynamicMessageFactory();
42   // This option might be the default some day.
43   message_factory->SetDelegateToGeneratedFactory(true);
44   factory->message_factory = message_factory;
45 
46   factory->pool = pool;
47   Py_INCREF(pool);
48 
49   factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
50 
51   return factory;
52 }
53 
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)54 PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
55   static const char* kwlist[] = {"pool", nullptr};
56   PyObject* pool = nullptr;
57   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O",
58                                    const_cast<char**>(kwlist), &pool)) {
59     return nullptr;
60   }
61   ScopedPyObjectPtr owned_pool;
62   if (pool == nullptr || pool == Py_None) {
63     owned_pool.reset(PyObject_CallFunction(
64         reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), nullptr));
65     if (owned_pool == nullptr) {
66       return nullptr;
67     }
68     pool = owned_pool.get();
69   } else {
70     if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
71       PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
72                    pool->ob_type->tp_name);
73       return nullptr;
74     }
75   }
76 
77   return reinterpret_cast<PyObject*>(
78       NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
79 }
80 
Dealloc(PyObject * pself)81 static void Dealloc(PyObject* pself) {
82   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
83 
84   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
85   for (iterator it = self->classes_by_descriptor->begin();
86        it != self->classes_by_descriptor->end(); ++it) {
87     Py_CLEAR(it->second);
88   }
89   delete self->classes_by_descriptor;
90   delete self->message_factory;
91   Py_CLEAR(self->pool);
92   Py_TYPE(self)->tp_free(pself);
93 }
94 
GcTraverse(PyObject * pself,visitproc visit,void * arg)95 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
96   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
97   Py_VISIT(self->pool);
98   for (const auto& desc_and_class : *self->classes_by_descriptor) {
99     Py_VISIT(desc_and_class.second);
100   }
101   return 0;
102 }
103 
GcClear(PyObject * pself)104 static int GcClear(PyObject* pself) {
105   PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
106   // Here it's important to not clear self->pool, so that the C++ DescriptorPool
107   // is still alive when self->message_factory is destructed.
108   for (auto& desc_and_class : *self->classes_by_descriptor) {
109     Py_CLEAR(desc_and_class.second);
110   }
111 
112   return 0;
113 }
114 
115 // Add a message class to our database.
RegisterMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor,CMessageClass * message_class)116 int RegisterMessageClass(PyMessageFactory* self,
117                          const Descriptor* message_descriptor,
118                          CMessageClass* message_class) {
119   Py_INCREF(message_class);
120   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
121   std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
122       std::make_pair(message_descriptor, message_class));
123   if (!ret.second) {
124     // Update case: DECREF the previous value.
125     Py_DECREF(ret.first->second);
126     ret.first->second = message_class;
127   }
128   return 0;
129 }
130 
GetOrCreateMessageClass(PyMessageFactory * self,const Descriptor * descriptor)131 CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
132                                        const Descriptor* descriptor) {
133   // This is the same implementation as MessageFactory.GetPrototype().
134 
135   // Do not create a MessageClass that already exists.
136   std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
137       self->classes_by_descriptor->find(descriptor);
138   if (it != self->classes_by_descriptor->end()) {
139     Py_INCREF(it->second);
140     return it->second;
141   }
142   ScopedPyObjectPtr py_descriptor(
143       PyMessageDescriptor_FromDescriptor(descriptor));
144   if (py_descriptor == nullptr) {
145     return nullptr;
146   }
147   // Create a new message class.
148   ScopedPyObjectPtr args(Py_BuildValue(
149       "s(){sOsOsO}", std::string(descriptor->name()).c_str(), "DESCRIPTOR",
150       py_descriptor.get(), "__module__", Py_None, "message_factory", self));
151   if (args == nullptr) {
152     return nullptr;
153   }
154   ScopedPyObjectPtr message_class(PyObject_CallObject(
155       reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
156   if (message_class == nullptr) {
157     return nullptr;
158   }
159   // Create messages class for the messages used by the fields, and registers
160   // all extensions for these messages during the recursion.
161   for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
162     const Descriptor* sub_descriptor =
163         descriptor->field(field_idx)->message_type();
164     // It is null if the field type is not a message.
165     if (sub_descriptor != nullptr) {
166       CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
167       if (result == nullptr) {
168         return nullptr;
169       }
170       Py_DECREF(result);
171     }
172   }
173 
174   // Register extensions defined in this message.
175   for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
176     const FieldDescriptor* extension = descriptor->extension(ext_idx);
177     ScopedPyObjectPtr py_extended_class(
178         GetOrCreateMessageClass(self, extension->containing_type())
179             ->AsPyObject());
180     if (py_extended_class == nullptr) {
181       return nullptr;
182     }
183     ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
184     if (py_extension == nullptr) {
185       return nullptr;
186     }
187   }
188   return reinterpret_cast<CMessageClass*>(message_class.release());
189 }
190 
191 // Retrieve the message class added to our database.
GetMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor)192 CMessageClass* GetMessageClass(PyMessageFactory* self,
193                                const Descriptor* message_descriptor) {
194   typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
195   iterator ret = self->classes_by_descriptor->find(message_descriptor);
196   if (ret == self->classes_by_descriptor->end()) {
197     PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
198                  std::string(message_descriptor->full_name()).c_str());
199     return nullptr;
200   } else {
201     return ret->second;
202   }
203 }
204 
205 static PyMethodDef Methods[] = {
206     {nullptr},
207 };
208 
GetPool(PyMessageFactory * self,void * closure)209 static PyObject* GetPool(PyMessageFactory* self, void* closure) {
210   Py_INCREF(self->pool);
211   return reinterpret_cast<PyObject*>(self->pool);
212 }
213 
214 static PyGetSetDef Getters[] = {
215     {"pool", (getter)GetPool, nullptr, "DescriptorPool"},
216     {nullptr},
217 };
218 
219 }  // namespace message_factory
220 
221 PyTypeObject PyMessageFactory_Type = {
222     PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
223     ".MessageFactory",         // tp_name
224     sizeof(PyMessageFactory),  // tp_basicsize
225     0,                         // tp_itemsize
226     message_factory::Dealloc,  // tp_dealloc
227 #if PY_VERSION_HEX < 0x03080000
228     nullptr,  // tp_print
229 #else
230     0,  // tp_vectorcall_offset
231 #endif
232     nullptr,  // tp_getattr
233     nullptr,  // tp_setattr
234     nullptr,  // tp_compare
235     nullptr,  // tp_repr
236     nullptr,  // tp_as_number
237     nullptr,  // tp_as_sequence
238     nullptr,  // tp_as_mapping
239     nullptr,  // tp_hash
240     nullptr,  // tp_call
241     nullptr,  // tp_str
242     nullptr,  // tp_getattro
243     nullptr,  // tp_setattro
244     nullptr,  // tp_as_buffer
245     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
246     "A static Message Factory",                                     // tp_doc
247     message_factory::GcTraverse,  // tp_traverse
248     message_factory::GcClear,     // tp_clear
249     nullptr,                      // tp_richcompare
250     0,                            // tp_weaklistoffset
251     nullptr,                      // tp_iter
252     nullptr,                      // tp_iternext
253     message_factory::Methods,     // tp_methods
254     nullptr,                      // tp_members
255     message_factory::Getters,     // tp_getset
256     nullptr,                      // tp_base
257     nullptr,                      // tp_dict
258     nullptr,                      // tp_descr_get
259     nullptr,                      // tp_descr_set
260     0,                            // tp_dictoffset
261     nullptr,                      // tp_init
262     nullptr,                      // tp_alloc
263     message_factory::New,         // tp_new
264     PyObject_GC_Del,              // tp_free
265 };
266 
InitMessageFactory()267 bool InitMessageFactory() {
268   if (PyType_Ready(&PyMessageFactory_Type) < 0) {
269     return false;
270   }
271 
272   return true;
273 }
274 
275 }  // namespace python
276 }  // namespace protobuf
277 }  // namespace google
278