1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33
34 #include <google/protobuf/pyext/extension_dict.h>
35 #include <memory>
36
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/dynamic_message.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/descriptor.pb.h>
43 #include <google/protobuf/pyext/descriptor.h>
44 #include <google/protobuf/pyext/message.h>
45 #include <google/protobuf/pyext/message_factory.h>
46 #include <google/protobuf/pyext/repeated_composite_container.h>
47 #include <google/protobuf/pyext/repeated_scalar_container.h>
48 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
49
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
// Python 3 shim for the Python 2 PyString_AsStringAndSize() call used below.
// For a unicode object the UTF-8 data/size are fetched with
// PyUnicode_AsUTF8AndSize() and the result is mapped to 0 / -1; any other
// object is forwarded to PyBytes_AsStringAndSize().
#define PyString_AsStringAndSize(ob, charpp, sizep)                    \
  (!PyUnicode_Check(ob)                                                \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))                \
       : ((*(charpp) = const_cast<char*>(                              \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != NULL          \
              ? 0                                                      \
              : -1))
#endif
61
62 namespace google {
63 namespace protobuf {
64 namespace python {
65
66 namespace extension_dict {
67
len(ExtensionDict * self)68 static Py_ssize_t len(ExtensionDict* self) {
69 Py_ssize_t size = 0;
70 std::vector<const FieldDescriptor*> fields;
71 self->parent->message->GetReflection()->ListFields(*self->parent->message,
72 &fields);
73
74 for (size_t i = 0; i < fields.size(); ++i) {
75 if (fields[i]->is_extension()) {
76 // With C++ descriptors, the field can always be retrieved, but for
77 // unknown extensions which have not been imported in Python code, there
78 // is no message class and we cannot retrieve the value.
79 // ListFields() has the same behavior.
80 if (fields[i]->message_type() != nullptr &&
81 message_factory::GetMessageClass(
82 cmessage::GetFactoryForMessage(self->parent),
83 fields[i]->message_type()) == nullptr) {
84 PyErr_Clear();
85 continue;
86 }
87 ++size;
88 }
89 }
90 return size;
91 }
92
93 struct ExtensionIterator {
94 PyObject_HEAD;
95 Py_ssize_t index;
96 std::vector<const FieldDescriptor*> fields;
97
98 // Owned reference, to keep the FieldDescriptors alive.
99 ExtensionDict* extension_dict;
100 };
101
GetIter(PyObject * _self)102 PyObject* GetIter(PyObject* _self) {
103 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
104
105 ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
106 if (obj == nullptr) {
107 return PyErr_Format(PyExc_MemoryError,
108 "Could not allocate extension iterator");
109 }
110
111 ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
112
113 // Call "placement new" to initialize. So the constructor of
114 // std::vector<...> fields will be called.
115 new (iter) ExtensionIterator;
116
117 self->parent->message->GetReflection()->ListFields(*self->parent->message,
118 &iter->fields);
119 iter->index = 0;
120 Py_INCREF(self);
121 iter->extension_dict = self;
122
123 return obj.release();
124 }
125
DeallocExtensionIterator(PyObject * _self)126 static void DeallocExtensionIterator(PyObject* _self) {
127 ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
128 self->fields.clear();
129 Py_XDECREF(self->extension_dict);
130 self->~ExtensionIterator();
131 Py_TYPE(_self)->tp_free(_self);
132 }
133
subscript(ExtensionDict * self,PyObject * key)134 PyObject* subscript(ExtensionDict* self, PyObject* key) {
135 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
136 if (descriptor == NULL) {
137 return NULL;
138 }
139 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
140 return NULL;
141 }
142
143 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
144 descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
145 return cmessage::InternalGetScalar(self->parent->message, descriptor);
146 }
147
148 CMessage::CompositeFieldsMap::iterator iterator =
149 self->parent->composite_fields->find(descriptor);
150 if (iterator != self->parent->composite_fields->end()) {
151 Py_INCREF(iterator->second);
152 return iterator->second->AsPyObject();
153 }
154
155 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
156 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
157 // TODO(plabatut): consider building the class on the fly!
158 ContainerBase* sub_message = cmessage::InternalGetSubMessage(
159 self->parent, descriptor);
160 if (sub_message == NULL) {
161 return NULL;
162 }
163 (*self->parent->composite_fields)[descriptor] = sub_message;
164 return sub_message->AsPyObject();
165 }
166
167 if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
168 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
169 // On the fly message class creation is needed to support the following
170 // situation:
171 // 1- add FileDescriptor to the pool that contains extensions of a message
172 // defined by another proto file. Do not create any message classes.
173 // 2- instantiate an extended message, and access the extension using
174 // the field descriptor.
175 // 3- the extension submessage fails to be returned, because no class has
176 // been created.
177 // It happens when deserializing text proto format, or when enumerating
178 // fields of a deserialized message.
179 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
180 cmessage::GetFactoryForMessage(self->parent),
181 descriptor->message_type());
182 ScopedPyObjectPtr message_class_handler(
183 reinterpret_cast<PyObject*>(message_class));
184 if (message_class == NULL) {
185 return NULL;
186 }
187 ContainerBase* py_container = repeated_composite_container::NewContainer(
188 self->parent, descriptor, message_class);
189 if (py_container == NULL) {
190 return NULL;
191 }
192 (*self->parent->composite_fields)[descriptor] = py_container;
193 return py_container->AsPyObject();
194 } else {
195 ContainerBase* py_container = repeated_scalar_container::NewContainer(
196 self->parent, descriptor);
197 if (py_container == NULL) {
198 return NULL;
199 }
200 (*self->parent->composite_fields)[descriptor] = py_container;
201 return py_container->AsPyObject();
202 }
203 }
204 PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
205 return NULL;
206 }
207
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)208 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
209 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
210 if (descriptor == NULL) {
211 return -1;
212 }
213 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
214 return -1;
215 }
216
217 if (value == nullptr) {
218 return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
219 }
220
221 if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
222 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
223 PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
224 "type");
225 return -1;
226 }
227 cmessage::AssureWritable(self->parent);
228 if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
229 return -1;
230 }
231 return 0;
232 }
233
_FindExtensionByName(ExtensionDict * self,PyObject * arg)234 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
235 char* name;
236 Py_ssize_t name_size;
237 if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
238 return NULL;
239 }
240
241 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
242 const FieldDescriptor* message_extension =
243 pool->pool->FindExtensionByName(StringParam(name, name_size));
244 if (message_extension == NULL) {
245 // Is is the name of a message set extension?
246 const Descriptor* message_descriptor =
247 pool->pool->FindMessageTypeByName(StringParam(name, name_size));
248 if (message_descriptor && message_descriptor->extension_count() > 0) {
249 const FieldDescriptor* extension = message_descriptor->extension(0);
250 if (extension->is_extension() &&
251 extension->containing_type()->options().message_set_wire_format() &&
252 extension->type() == FieldDescriptor::TYPE_MESSAGE &&
253 extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
254 message_extension = extension;
255 }
256 }
257 }
258 if (message_extension == NULL) {
259 Py_RETURN_NONE;
260 }
261
262 return PyFieldDescriptor_FromDescriptor(message_extension);
263 }
264
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)265 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
266 int64 number = PyLong_AsLong(arg);
267 if (number == -1 && PyErr_Occurred()) {
268 return NULL;
269 }
270
271 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
272 const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
273 self->parent->message->GetDescriptor(), number);
274 if (message_extension == NULL) {
275 Py_RETURN_NONE;
276 }
277
278 return PyFieldDescriptor_FromDescriptor(message_extension);
279 }
280
Contains(PyObject * _self,PyObject * key)281 static int Contains(PyObject* _self, PyObject* key) {
282 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
283 const FieldDescriptor* field_descriptor =
284 cmessage::GetExtensionDescriptor(key);
285 if (field_descriptor == nullptr) {
286 return -1;
287 }
288
289 if (!field_descriptor->is_extension()) {
290 PyErr_Format(PyExc_KeyError, "%s is not an extension",
291 field_descriptor->full_name().c_str());
292 return -1;
293 }
294
295 const Message* message = self->parent->message;
296 const Reflection* reflection = message->GetReflection();
297 if (field_descriptor->is_repeated()) {
298 if (reflection->FieldSize(*message, field_descriptor) > 0) {
299 return 1;
300 }
301 } else {
302 if (reflection->HasField(*message, field_descriptor)) {
303 return 1;
304 }
305 }
306
307 return 0;
308 }
309
NewExtensionDict(CMessage * parent)310 ExtensionDict* NewExtensionDict(CMessage *parent) {
311 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
312 PyType_GenericAlloc(&ExtensionDict_Type, 0));
313 if (self == NULL) {
314 return NULL;
315 }
316
317 Py_INCREF(parent);
318 self->parent = parent;
319 return self;
320 }
321
dealloc(PyObject * pself)322 void dealloc(PyObject* pself) {
323 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
324 Py_CLEAR(self->parent);
325 Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
326 }
327
RichCompare(ExtensionDict * self,PyObject * other,int opid)328 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
329 // Only equality comparisons are implemented.
330 if (opid != Py_EQ && opid != Py_NE) {
331 Py_INCREF(Py_NotImplemented);
332 return Py_NotImplemented;
333 }
334 bool equals = false;
335 if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
336 equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
337 }
338 if (equals ^ (opid == Py_EQ)) {
339 Py_RETURN_FALSE;
340 } else {
341 Py_RETURN_TRUE;
342 }
343 }
344 static PySequenceMethods SeqMethods = {
345 (lenfunc)len, // sq_length
346 0, // sq_concat
347 0, // sq_repeat
348 0, // sq_item
349 0, // sq_slice
350 0, // sq_ass_item
351 0, // sq_ass_slice
352 (objobjproc)Contains, // sq_contains
353 };
354
355 static PyMappingMethods MpMethods = {
356 (lenfunc)len, /* mp_length */
357 (binaryfunc)subscript, /* mp_subscript */
358 (objobjargproc)ass_subscript,/* mp_ass_subscript */
359 };
360
361 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
362 static PyMethodDef Methods[] = {
363 EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
364 EDMETHOD(_FindExtensionByNumber, METH_O,
365 "Finds an extension by field number."),
366 {NULL, NULL},
367 };
368
369 } // namespace extension_dict
370
371 PyTypeObject ExtensionDict_Type = {
372 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
373 FULL_MODULE_NAME ".ExtensionDict", // tp_name
374 sizeof(ExtensionDict), // tp_basicsize
375 0, // tp_itemsize
376 (destructor)extension_dict::dealloc, // tp_dealloc
377 0, // tp_print
378 0, // tp_getattr
379 0, // tp_setattr
380 0, // tp_compare
381 0, // tp_repr
382 0, // tp_as_number
383 &extension_dict::SeqMethods, // tp_as_sequence
384 &extension_dict::MpMethods, // tp_as_mapping
385 PyObject_HashNotImplemented, // tp_hash
386 0, // tp_call
387 0, // tp_str
388 0, // tp_getattro
389 0, // tp_setattro
390 0, // tp_as_buffer
391 Py_TPFLAGS_DEFAULT, // tp_flags
392 "An extension dict", // tp_doc
393 0, // tp_traverse
394 0, // tp_clear
395 (richcmpfunc)extension_dict::RichCompare, // tp_richcompare
396 0, // tp_weaklistoffset
397 extension_dict::GetIter, // tp_iter
398 0, // tp_iternext
399 extension_dict::Methods, // tp_methods
400 0, // tp_members
401 0, // tp_getset
402 0, // tp_base
403 0, // tp_dict
404 0, // tp_descr_get
405 0, // tp_descr_set
406 0, // tp_dictoffset
407 0, // tp_init
408 };
409
IterNext(PyObject * _self)410 PyObject* IterNext(PyObject* _self) {
411 extension_dict::ExtensionIterator* self =
412 reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
413 Py_ssize_t total_size = self->fields.size();
414 Py_ssize_t index = self->index;
415 while (self->index < total_size) {
416 index = self->index;
417 ++self->index;
418 if (self->fields[index]->is_extension()) {
419 // With C++ descriptors, the field can always be retrieved, but for
420 // unknown extensions which have not been imported in Python code, there
421 // is no message class and we cannot retrieve the value.
422 // ListFields() has the same behavior.
423 if (self->fields[index]->message_type() != nullptr &&
424 message_factory::GetMessageClass(
425 cmessage::GetFactoryForMessage(self->extension_dict->parent),
426 self->fields[index]->message_type()) == nullptr) {
427 PyErr_Clear();
428 continue;
429 }
430
431 return PyFieldDescriptor_FromDescriptor(self->fields[index]);
432 }
433 }
434
435 return nullptr;
436 }
437
438 PyTypeObject ExtensionIterator_Type = {
439 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
440 FULL_MODULE_NAME ".ExtensionIterator", // tp_name
441 sizeof(extension_dict::ExtensionIterator), // tp_basicsize
442 0, // tp_itemsize
443 extension_dict::DeallocExtensionIterator, // tp_dealloc
444 0, // tp_print
445 0, // tp_getattr
446 0, // tp_setattr
447 0, // tp_compare
448 0, // tp_repr
449 0, // tp_as_number
450 0, // tp_as_sequence
451 0, // tp_as_mapping
452 0, // tp_hash
453 0, // tp_call
454 0, // tp_str
455 0, // tp_getattro
456 0, // tp_setattro
457 0, // tp_as_buffer
458 Py_TPFLAGS_DEFAULT, // tp_flags
459 "A scalar map iterator", // tp_doc
460 0, // tp_traverse
461 0, // tp_clear
462 0, // tp_richcompare
463 0, // tp_weaklistoffset
464 PyObject_SelfIter, // tp_iter
465 IterNext, // tp_iternext
466 0, // tp_methods
467 0, // tp_members
468 0, // tp_getset
469 0, // tp_base
470 0, // tp_dict
471 0, // tp_descr_get
472 0, // tp_descr_set
473 0, // tp_dictoffset
474 0, // tp_init
475 };
476 } // namespace python
477 } // namespace protobuf
478 } // namespace google
479