1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 #include <unordered_map>
9 #include <utility>
10
11 #define PY_SSIZE_T_CLEAN
12 #include <Python.h>
13
14 #include "google/protobuf/dynamic_message.h"
15 #include "google/protobuf/pyext/descriptor.h"
16 #include "google/protobuf/pyext/message.h"
17 #include "google/protobuf/pyext/message_factory.h"
18 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
19
// Extracts a char buffer and its length from `ob`, evaluating to 0 on success
// and -1 on failure.
// - For str (unicode) objects, *charpp is set to the UTF-8 data returned by
//   PyUnicode_AsUTF8AndSize. That pointer is const-cast to char*, so callers
//   must not write through it.
// - Any other object is handed to PyBytes_AsStringAndSize unchanged.
// NOTE(review): no use of this macro is visible in this file; confirm whether
// it is still needed.
#define PyString_AsStringAndSize(ob, charpp, sizep)              \
  (PyUnicode_Check(ob)                                           \
       ? ((*(charpp) = const_cast<char*>(                        \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == nullptr \
              ? -1                                               \
              : 0)                                               \
       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
27
28 namespace google {
29 namespace protobuf {
30 namespace python {
31
32 namespace message_factory {
33
NewMessageFactory(PyTypeObject * type,PyDescriptorPool * pool)34 PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
35 PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
36 PyType_GenericAlloc(type, 0));
37 if (factory == nullptr) {
38 return nullptr;
39 }
40
41 DynamicMessageFactory* message_factory = new DynamicMessageFactory();
42 // This option might be the default some day.
43 message_factory->SetDelegateToGeneratedFactory(true);
44 factory->message_factory = message_factory;
45
46 factory->pool = pool;
47 Py_INCREF(pool);
48
49 factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
50
51 return factory;
52 }
53
New(PyTypeObject * type,PyObject * args,PyObject * kwargs)54 PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
55 static const char* kwlist[] = {"pool", nullptr};
56 PyObject* pool = nullptr;
57 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O",
58 const_cast<char**>(kwlist), &pool)) {
59 return nullptr;
60 }
61 ScopedPyObjectPtr owned_pool;
62 if (pool == nullptr || pool == Py_None) {
63 owned_pool.reset(PyObject_CallFunction(
64 reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), nullptr));
65 if (owned_pool == nullptr) {
66 return nullptr;
67 }
68 pool = owned_pool.get();
69 } else {
70 if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
71 PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
72 pool->ob_type->tp_name);
73 return nullptr;
74 }
75 }
76
77 return reinterpret_cast<PyObject*>(
78 NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
79 }
80
Dealloc(PyObject * pself)81 static void Dealloc(PyObject* pself) {
82 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
83
84 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
85 for (iterator it = self->classes_by_descriptor->begin();
86 it != self->classes_by_descriptor->end(); ++it) {
87 Py_CLEAR(it->second);
88 }
89 delete self->classes_by_descriptor;
90 delete self->message_factory;
91 Py_CLEAR(self->pool);
92 Py_TYPE(self)->tp_free(pself);
93 }
94
GcTraverse(PyObject * pself,visitproc visit,void * arg)95 static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
96 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
97 Py_VISIT(self->pool);
98 for (const auto& desc_and_class : *self->classes_by_descriptor) {
99 Py_VISIT(desc_and_class.second);
100 }
101 return 0;
102 }
103
GcClear(PyObject * pself)104 static int GcClear(PyObject* pself) {
105 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
106 // Here it's important to not clear self->pool, so that the C++ DescriptorPool
107 // is still alive when self->message_factory is destructed.
108 for (auto& desc_and_class : *self->classes_by_descriptor) {
109 Py_CLEAR(desc_and_class.second);
110 }
111
112 return 0;
113 }
114
115 // Add a message class to our database.
RegisterMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor,CMessageClass * message_class)116 int RegisterMessageClass(PyMessageFactory* self,
117 const Descriptor* message_descriptor,
118 CMessageClass* message_class) {
119 Py_INCREF(message_class);
120 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
121 std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
122 std::make_pair(message_descriptor, message_class));
123 if (!ret.second) {
124 // Update case: DECREF the previous value.
125 Py_DECREF(ret.first->second);
126 ret.first->second = message_class;
127 }
128 return 0;
129 }
130
GetOrCreateMessageClass(PyMessageFactory * self,const Descriptor * descriptor)131 CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
132 const Descriptor* descriptor) {
133 // This is the same implementation as MessageFactory.GetPrototype().
134
135 // Do not create a MessageClass that already exists.
136 std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
137 self->classes_by_descriptor->find(descriptor);
138 if (it != self->classes_by_descriptor->end()) {
139 Py_INCREF(it->second);
140 return it->second;
141 }
142 ScopedPyObjectPtr py_descriptor(
143 PyMessageDescriptor_FromDescriptor(descriptor));
144 if (py_descriptor == nullptr) {
145 return nullptr;
146 }
147 // Create a new message class.
148 ScopedPyObjectPtr args(Py_BuildValue(
149 "s(){sOsOsO}", std::string(descriptor->name()).c_str(), "DESCRIPTOR",
150 py_descriptor.get(), "__module__", Py_None, "message_factory", self));
151 if (args == nullptr) {
152 return nullptr;
153 }
154 ScopedPyObjectPtr message_class(PyObject_CallObject(
155 reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
156 if (message_class == nullptr) {
157 return nullptr;
158 }
159 // Create messages class for the messages used by the fields, and registers
160 // all extensions for these messages during the recursion.
161 for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
162 const Descriptor* sub_descriptor =
163 descriptor->field(field_idx)->message_type();
164 // It is null if the field type is not a message.
165 if (sub_descriptor != nullptr) {
166 CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
167 if (result == nullptr) {
168 return nullptr;
169 }
170 Py_DECREF(result);
171 }
172 }
173
174 // Register extensions defined in this message.
175 for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
176 const FieldDescriptor* extension = descriptor->extension(ext_idx);
177 ScopedPyObjectPtr py_extended_class(
178 GetOrCreateMessageClass(self, extension->containing_type())
179 ->AsPyObject());
180 if (py_extended_class == nullptr) {
181 return nullptr;
182 }
183 ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
184 if (py_extension == nullptr) {
185 return nullptr;
186 }
187 }
188 return reinterpret_cast<CMessageClass*>(message_class.release());
189 }
190
191 // Retrieve the message class added to our database.
GetMessageClass(PyMessageFactory * self,const Descriptor * message_descriptor)192 CMessageClass* GetMessageClass(PyMessageFactory* self,
193 const Descriptor* message_descriptor) {
194 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
195 iterator ret = self->classes_by_descriptor->find(message_descriptor);
196 if (ret == self->classes_by_descriptor->end()) {
197 PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
198 std::string(message_descriptor->full_name()).c_str());
199 return nullptr;
200 } else {
201 return ret->second;
202 }
203 }
204
205 static PyMethodDef Methods[] = {
206 {nullptr},
207 };
208
GetPool(PyMessageFactory * self,void * closure)209 static PyObject* GetPool(PyMessageFactory* self, void* closure) {
210 Py_INCREF(self->pool);
211 return reinterpret_cast<PyObject*>(self->pool);
212 }
213
214 static PyGetSetDef Getters[] = {
215 {"pool", (getter)GetPool, nullptr, "DescriptorPool"},
216 {nullptr},
217 };
218
219 } // namespace message_factory
220
// Python type object for MessageFactory.
// The slots are initialized positionally, so their order must match the
// PyTypeObject layout of the targeted CPython version. Instances are GC-aware
// (Py_TPFLAGS_HAVE_GC, with tp_traverse/tp_clear below, freed by
// PyObject_GC_Del) and the type may be subclassed (Py_TPFLAGS_BASETYPE).
// NOTE(review): tp_alloc is left null; presumably PyType_Ready inherits it
// from the base type. Confirm if this initializer is changed.
PyTypeObject PyMessageFactory_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
    ".MessageFactory",                                              // tp_name
    sizeof(PyMessageFactory),                                       // tp_basicsize
    0,                                                              // tp_itemsize
    message_factory::Dealloc,                                       // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
    nullptr,                                                        // tp_print
#else
    0,                                                              // tp_vectorcall_offset
#endif
    nullptr,                                                        // tp_getattr
    nullptr,                                                        // tp_setattr
    nullptr,                                                        // tp_as_async (formerly tp_compare)
    nullptr,                                                        // tp_repr
    nullptr,                                                        // tp_as_number
    nullptr,                                                        // tp_as_sequence
    nullptr,                                                        // tp_as_mapping
    nullptr,                                                        // tp_hash
    nullptr,                                                        // tp_call
    nullptr,                                                        // tp_str
    nullptr,                                                        // tp_getattro
    nullptr,                                                        // tp_setattro
    nullptr,                                                        // tp_as_buffer
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
    "A static Message Factory",                                     // tp_doc
    message_factory::GcTraverse,                                    // tp_traverse
    message_factory::GcClear,                                       // tp_clear
    nullptr,                                                        // tp_richcompare
    0,                                                              // tp_weaklistoffset
    nullptr,                                                        // tp_iter
    nullptr,                                                        // tp_iternext
    message_factory::Methods,                                       // tp_methods
    nullptr,                                                        // tp_members
    message_factory::Getters,                                       // tp_getset
    nullptr,                                                        // tp_base
    nullptr,                                                        // tp_dict
    nullptr,                                                        // tp_descr_get
    nullptr,                                                        // tp_descr_set
    0,                                                              // tp_dictoffset
    nullptr,                                                        // tp_init
    nullptr,                                                        // tp_alloc
    message_factory::New,                                           // tp_new
    PyObject_GC_Del,                                                // tp_free
};
266
InitMessageFactory()267 bool InitMessageFactory() {
268 if (PyType_Ready(&PyMessageFactory_Type) < 0) {
269 return false;
270 }
271
272 return true;
273 }
274
275 } // namespace python
276 } // namespace protobuf
277 } // namespace google
278