1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33
34 #include <google/protobuf/pyext/extension_dict.h>
35 #include <memory>
36
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/dynamic_message.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/descriptor.pb.h>
43 #include <google/protobuf/pyext/descriptor.h>
44 #include <google/protobuf/pyext/message.h>
45 #include <google/protobuf/pyext/message_factory.h>
46 #include <google/protobuf/pyext/repeated_composite_container.h>
47 #include <google/protobuf/pyext/repeated_scalar_container.h>
48 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
49
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
// Python 3 stand-in for PyString_AsStringAndSize.
//
// For unicode objects the UTF-8 buffer returned by PyUnicode_AsUTF8AndSize is
// stored in *charpp (with constness cast away) and the expression evaluates to
// -1 when that call returned NULL, 0 otherwise.  Any other object is forwarded
// to PyBytes_AsStringAndSize.
//
// Note that `ob` appears more than once in the expansion, so it must be an
// expression without side effects.
#define PyString_AsStringAndSize(ob, charpp, sizep) \
  (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
                               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
                              ? -1 \
                              : 0) \
                       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
61
62 namespace google {
63 namespace protobuf {
64 namespace python {
65
66 namespace extension_dict {
67
len(ExtensionDict * self)68 static Py_ssize_t len(ExtensionDict* self) {
69 Py_ssize_t size = 0;
70 std::vector<const FieldDescriptor*> fields;
71 self->parent->message->GetReflection()->ListFields(*self->parent->message,
72 &fields);
73
74 for (size_t i = 0; i < fields.size(); ++i) {
75 if (fields[i]->is_extension()) {
76 // With C++ descriptors, the field can always be retrieved, but for
77 // unknown extensions which have not been imported in Python code, there
78 // is no message class and we cannot retrieve the value.
79 // ListFields() has the same behavior.
80 if (fields[i]->message_type() != nullptr &&
81 message_factory::GetMessageClass(
82 cmessage::GetFactoryForMessage(self->parent),
83 fields[i]->message_type()) == nullptr) {
84 PyErr_Clear();
85 continue;
86 }
87 ++size;
88 }
89 }
90 return size;
91 }
92
93 struct ExtensionIterator {
94 PyObject_HEAD;
95 Py_ssize_t index;
96 std::vector<const FieldDescriptor*> fields;
97
98 // Owned reference, to keep the FieldDescriptors alive.
99 ExtensionDict* extension_dict;
100 };
101
GetIter(PyObject * _self)102 PyObject* GetIter(PyObject* _self) {
103 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
104
105 ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
106 if (obj == nullptr) {
107 return PyErr_Format(PyExc_MemoryError,
108 "Could not allocate extension iterator");
109 }
110
111 ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
112
113 // Call "placement new" to initialize. So the constructor of
114 // std::vector<...> fields will be called.
115 new (iter) ExtensionIterator;
116
117 self->parent->message->GetReflection()->ListFields(*self->parent->message,
118 &iter->fields);
119 iter->index = 0;
120 Py_INCREF(self);
121 iter->extension_dict = self;
122
123 return obj.release();
124 }
125
DeallocExtensionIterator(PyObject * _self)126 static void DeallocExtensionIterator(PyObject* _self) {
127 ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
128 self->fields.clear();
129 Py_XDECREF(self->extension_dict);
130 self->~ExtensionIterator();
131 Py_TYPE(_self)->tp_free(_self);
132 }
133
subscript(ExtensionDict * self,PyObject * key)134 PyObject* subscript(ExtensionDict* self, PyObject* key) {
135 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
136 if (descriptor == NULL) {
137 return NULL;
138 }
139 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
140 return NULL;
141 }
142
143 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
144 descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
145 return cmessage::InternalGetScalar(self->parent->message, descriptor);
146 }
147
148 CMessage::CompositeFieldsMap::iterator iterator =
149 self->parent->composite_fields->find(descriptor);
150 if (iterator != self->parent->composite_fields->end()) {
151 Py_INCREF(iterator->second);
152 return iterator->second->AsPyObject();
153 }
154
155 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
156 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
157 // TODO(plabatut): consider building the class on the fly!
158 ContainerBase* sub_message = cmessage::InternalGetSubMessage(
159 self->parent, descriptor);
160 if (sub_message == NULL) {
161 return NULL;
162 }
163 (*self->parent->composite_fields)[descriptor] = sub_message;
164 return sub_message->AsPyObject();
165 }
166
167 if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
168 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
169 // On the fly message class creation is needed to support the following
170 // situation:
171 // 1- add FileDescriptor to the pool that contains extensions of a message
172 // defined by another proto file. Do not create any message classes.
173 // 2- instantiate an extended message, and access the extension using
174 // the field descriptor.
175 // 3- the extension submessage fails to be returned, because no class has
176 // been created.
177 // It happens when deserializing text proto format, or when enumerating
178 // fields of a deserialized message.
179 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
180 cmessage::GetFactoryForMessage(self->parent),
181 descriptor->message_type());
182 ScopedPyObjectPtr message_class_handler(
183 reinterpret_cast<PyObject*>(message_class));
184 if (message_class == NULL) {
185 return NULL;
186 }
187 ContainerBase* py_container = repeated_composite_container::NewContainer(
188 self->parent, descriptor, message_class);
189 if (py_container == NULL) {
190 return NULL;
191 }
192 (*self->parent->composite_fields)[descriptor] = py_container;
193 return py_container->AsPyObject();
194 } else {
195 ContainerBase* py_container = repeated_scalar_container::NewContainer(
196 self->parent, descriptor);
197 if (py_container == NULL) {
198 return NULL;
199 }
200 (*self->parent->composite_fields)[descriptor] = py_container;
201 return py_container->AsPyObject();
202 }
203 }
204 PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
205 return NULL;
206 }
207
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)208 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
209 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
210 if (descriptor == NULL) {
211 return -1;
212 }
213 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
214 return -1;
215 }
216
217 if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
218 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
219 PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
220 "type");
221 return -1;
222 }
223 cmessage::AssureWritable(self->parent);
224 if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
225 return -1;
226 }
227 return 0;
228 }
229
_FindExtensionByName(ExtensionDict * self,PyObject * arg)230 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
231 char* name;
232 Py_ssize_t name_size;
233 if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
234 return NULL;
235 }
236
237 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
238 const FieldDescriptor* message_extension =
239 pool->pool->FindExtensionByName(string(name, name_size));
240 if (message_extension == NULL) {
241 // Is is the name of a message set extension?
242 const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
243 string(name, name_size));
244 if (message_descriptor && message_descriptor->extension_count() > 0) {
245 const FieldDescriptor* extension = message_descriptor->extension(0);
246 if (extension->is_extension() &&
247 extension->containing_type()->options().message_set_wire_format() &&
248 extension->type() == FieldDescriptor::TYPE_MESSAGE &&
249 extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
250 message_extension = extension;
251 }
252 }
253 }
254 if (message_extension == NULL) {
255 Py_RETURN_NONE;
256 }
257
258 return PyFieldDescriptor_FromDescriptor(message_extension);
259 }
260
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)261 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
262 int64 number = PyLong_AsLong(arg);
263 if (number == -1 && PyErr_Occurred()) {
264 return NULL;
265 }
266
267 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
268 const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
269 self->parent->message->GetDescriptor(), number);
270 if (message_extension == NULL) {
271 Py_RETURN_NONE;
272 }
273
274 return PyFieldDescriptor_FromDescriptor(message_extension);
275 }
276
Contains(PyObject * _self,PyObject * key)277 static int Contains(PyObject* _self, PyObject* key) {
278 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
279 const FieldDescriptor* field_descriptor =
280 cmessage::GetExtensionDescriptor(key);
281 if (field_descriptor == nullptr) {
282 return -1;
283 }
284
285 if (!field_descriptor->is_extension()) {
286 PyErr_Format(PyExc_KeyError, "%s is not an extension",
287 field_descriptor->full_name().c_str());
288 return -1;
289 }
290
291 const Message* message = self->parent->message;
292 const Reflection* reflection = message->GetReflection();
293 if (field_descriptor->is_repeated()) {
294 if (reflection->FieldSize(*message, field_descriptor) > 0) {
295 return 1;
296 }
297 } else {
298 if (reflection->HasField(*message, field_descriptor)) {
299 return 1;
300 }
301 }
302
303 return 0;
304 }
305
NewExtensionDict(CMessage * parent)306 ExtensionDict* NewExtensionDict(CMessage *parent) {
307 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
308 PyType_GenericAlloc(&ExtensionDict_Type, 0));
309 if (self == NULL) {
310 return NULL;
311 }
312
313 Py_INCREF(parent);
314 self->parent = parent;
315 return self;
316 }
317
dealloc(PyObject * pself)318 void dealloc(PyObject* pself) {
319 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
320 Py_CLEAR(self->parent);
321 Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
322 }
323
RichCompare(ExtensionDict * self,PyObject * other,int opid)324 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
325 // Only equality comparisons are implemented.
326 if (opid != Py_EQ && opid != Py_NE) {
327 Py_INCREF(Py_NotImplemented);
328 return Py_NotImplemented;
329 }
330 bool equals = false;
331 if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
332 equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
333 }
334 if (equals ^ (opid == Py_EQ)) {
335 Py_RETURN_FALSE;
336 } else {
337 Py_RETURN_TRUE;
338 }
339 }
340 static PySequenceMethods SeqMethods = {
341 (lenfunc)len, // sq_length
342 0, // sq_concat
343 0, // sq_repeat
344 0, // sq_item
345 0, // sq_slice
346 0, // sq_ass_item
347 0, // sq_ass_slice
348 (objobjproc)Contains, // sq_contains
349 };
350
351 static PyMappingMethods MpMethods = {
352 (lenfunc)len, /* mp_length */
353 (binaryfunc)subscript, /* mp_subscript */
354 (objobjargproc)ass_subscript,/* mp_ass_subscript */
355 };
356
357 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
358 static PyMethodDef Methods[] = {
359 EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
360 EDMETHOD(_FindExtensionByNumber, METH_O,
361 "Finds an extension by field number."),
362 {NULL, NULL},
363 };
364
365 } // namespace extension_dict
366
// Python type object for ExtensionDict.
//
// The slots filled in below give instances: len() and `in` (sequence slots),
// item lookup and assignment (mapping slots), ==/!= comparison, iteration via
// GetIter, and the methods listed in extension_dict::Methods.  Hashing is
// explicitly disabled through PyObject_HashNotImplemented.  Entries are
// positional, so their order must follow the PyTypeObject layout.
PyTypeObject ExtensionDict_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)     //
    FULL_MODULE_NAME ".ExtensionDict",         // tp_name
    sizeof(ExtensionDict),                     // tp_basicsize
    0,                                         // tp_itemsize
    (destructor)extension_dict::dealloc,       // tp_dealloc
    0,                                         // tp_print
    0,                                         // tp_getattr
    0,                                         // tp_setattr
    0,                                         // tp_compare
    0,                                         // tp_repr
    0,                                         // tp_as_number
    &extension_dict::SeqMethods,               // tp_as_sequence
    &extension_dict::MpMethods,                // tp_as_mapping
    PyObject_HashNotImplemented,               // tp_hash
    0,                                         // tp_call
    0,                                         // tp_str
    0,                                         // tp_getattro
    0,                                         // tp_setattro
    0,                                         // tp_as_buffer
    Py_TPFLAGS_DEFAULT,                        // tp_flags
    "An extension dict",                       // tp_doc
    0,                                         // tp_traverse
    0,                                         // tp_clear
    (richcmpfunc)extension_dict::RichCompare,  // tp_richcompare
    0,                                         // tp_weaklistoffset
    extension_dict::GetIter,                   // tp_iter
    0,                                         // tp_iternext
    extension_dict::Methods,                   // tp_methods
    0,                                         // tp_members
    0,                                         // tp_getset
    0,                                         // tp_base
    0,                                         // tp_dict
    0,                                         // tp_descr_get
    0,                                         // tp_descr_set
    0,                                         // tp_dictoffset
    0,                                         // tp_init
};
405
IterNext(PyObject * _self)406 PyObject* IterNext(PyObject* _self) {
407 extension_dict::ExtensionIterator* self =
408 reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
409 Py_ssize_t total_size = self->fields.size();
410 Py_ssize_t index = self->index;
411 while (self->index < total_size) {
412 index = self->index;
413 ++self->index;
414 if (self->fields[index]->is_extension()) {
415 // With C++ descriptors, the field can always be retrieved, but for
416 // unknown extensions which have not been imported in Python code, there
417 // is no message class and we cannot retrieve the value.
418 // ListFields() has the same behavior.
419 if (self->fields[index]->message_type() != nullptr &&
420 message_factory::GetMessageClass(
421 cmessage::GetFactoryForMessage(self->extension_dict->parent),
422 self->fields[index]->message_type()) == nullptr) {
423 PyErr_Clear();
424 continue;
425 }
426
427 return PyFieldDescriptor_FromDescriptor(self->fields[index]);
428 }
429 }
430
431 return nullptr;
432 }
433
434 PyTypeObject ExtensionIterator_Type = {
435 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
436 FULL_MODULE_NAME ".ExtensionIterator", // tp_name
437 sizeof(extension_dict::ExtensionIterator), // tp_basicsize
438 0, // tp_itemsize
439 extension_dict::DeallocExtensionIterator, // tp_dealloc
440 0, // tp_print
441 0, // tp_getattr
442 0, // tp_setattr
443 0, // tp_compare
444 0, // tp_repr
445 0, // tp_as_number
446 0, // tp_as_sequence
447 0, // tp_as_mapping
448 0, // tp_hash
449 0, // tp_call
450 0, // tp_str
451 0, // tp_getattro
452 0, // tp_setattro
453 0, // tp_as_buffer
454 Py_TPFLAGS_DEFAULT, // tp_flags
455 "A scalar map iterator", // tp_doc
456 0, // tp_traverse
457 0, // tp_clear
458 0, // tp_richcompare
459 0, // tp_weaklistoffset
460 PyObject_SelfIter, // tp_iter
461 IterNext, // tp_iternext
462 0, // tp_methods
463 0, // tp_members
464 0, // tp_getset
465 0, // tp_base
466 0, // tp_dict
467 0, // tp_descr_get
468 0, // tp_descr_set
469 0, // tp_dictoffset
470 0, // tp_init
471 };
472 } // namespace python
473 } // namespace protobuf
474 } // namespace google
475