• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33 
34 #include <google/protobuf/pyext/extension_dict.h>
35 #include <memory>
36 
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/dynamic_message.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/descriptor.pb.h>
43 #include <google/protobuf/pyext/descriptor.h>
44 #include <google/protobuf/pyext/message.h>
45 #include <google/protobuf/pyext/message_factory.h>
46 #include <google/protobuf/pyext/repeated_composite_container.h>
47 #include <google/protobuf/pyext/repeated_scalar_container.h>
48 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
49 
// Python 3 compatibility: Python 3.0 - 3.2 are rejected at compile time.
#if PY_MAJOR_VERSION >= 3
  #if PY_VERSION_HEX < 0x03030000
    #error "Python 3.0 - 3.2 are not supported."
  #endif
// Stand-in for PyString_AsStringAndSize() under Python 3.
//   * If `ob` passes PyUnicode_Check(), `*charpp` receives the result of
//     PyUnicode_AsUTF8AndSize() and the expression yields -1 when that result
//     is NULL, 0 otherwise.
//   * Any other object is forwarded to PyBytes_AsStringAndSize().
// Note that `ob` and `sizep` are each expanded in both branches of the
// conditional, so arguments should be free of side effects.
#define PyString_AsStringAndSize(ob, charpp, sizep)                    \
  (PyUnicode_Check(ob)                                                 \
       ? ((*(charpp) = const_cast<char*>(                              \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL          \
              ? -1                                                     \
              : 0)                                                     \
       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
61 
62 namespace google {
63 namespace protobuf {
64 namespace python {
65 
66 namespace extension_dict {
67 
len(ExtensionDict * self)68 static Py_ssize_t len(ExtensionDict* self) {
69   Py_ssize_t size = 0;
70   std::vector<const FieldDescriptor*> fields;
71   self->parent->message->GetReflection()->ListFields(*self->parent->message,
72                                                      &fields);
73 
74   for (size_t i = 0; i < fields.size(); ++i) {
75     if (fields[i]->is_extension()) {
76       // With C++ descriptors, the field can always be retrieved, but for
77       // unknown extensions which have not been imported in Python code, there
78       // is no message class and we cannot retrieve the value.
79       // ListFields() has the same behavior.
80       if (fields[i]->message_type() != nullptr &&
81           message_factory::GetMessageClass(
82               cmessage::GetFactoryForMessage(self->parent),
83               fields[i]->message_type()) == nullptr) {
84         PyErr_Clear();
85         continue;
86       }
87       ++size;
88     }
89   }
90   return size;
91 }
92 
93 struct ExtensionIterator {
94   PyObject_HEAD;
95   Py_ssize_t index;
96   std::vector<const FieldDescriptor*> fields;
97 
98   // Owned reference, to keep the FieldDescriptors alive.
99   ExtensionDict* extension_dict;
100 };
101 
GetIter(PyObject * _self)102 PyObject* GetIter(PyObject* _self) {
103   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
104 
105   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
106   if (obj == nullptr) {
107     return PyErr_Format(PyExc_MemoryError,
108                         "Could not allocate extension iterator");
109   }
110 
111   ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
112 
113   // Call "placement new" to initialize. So the constructor of
114   // std::vector<...> fields will be called.
115   new (iter) ExtensionIterator;
116 
117   self->parent->message->GetReflection()->ListFields(*self->parent->message,
118                                                      &iter->fields);
119   iter->index = 0;
120   Py_INCREF(self);
121   iter->extension_dict = self;
122 
123   return obj.release();
124 }
125 
DeallocExtensionIterator(PyObject * _self)126 static void DeallocExtensionIterator(PyObject* _self) {
127   ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
128   self->fields.clear();
129   Py_XDECREF(self->extension_dict);
130   self->~ExtensionIterator();
131   Py_TYPE(_self)->tp_free(_self);
132 }
133 
subscript(ExtensionDict * self,PyObject * key)134 PyObject* subscript(ExtensionDict* self, PyObject* key) {
135   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
136   if (descriptor == NULL) {
137     return NULL;
138   }
139   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
140     return NULL;
141   }
142 
143   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
144       descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
145     return cmessage::InternalGetScalar(self->parent->message, descriptor);
146   }
147 
148   CMessage::CompositeFieldsMap::iterator iterator =
149       self->parent->composite_fields->find(descriptor);
150   if (iterator != self->parent->composite_fields->end()) {
151     Py_INCREF(iterator->second);
152     return iterator->second->AsPyObject();
153   }
154 
155   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
156       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
157     // TODO(plabatut): consider building the class on the fly!
158     ContainerBase* sub_message = cmessage::InternalGetSubMessage(
159         self->parent, descriptor);
160     if (sub_message == NULL) {
161       return NULL;
162     }
163     (*self->parent->composite_fields)[descriptor] = sub_message;
164     return sub_message->AsPyObject();
165   }
166 
167   if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
168     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
169       // On the fly message class creation is needed to support the following
170       // situation:
171       // 1- add FileDescriptor to the pool that contains extensions of a message
172       //    defined by another proto file. Do not create any message classes.
173       // 2- instantiate an extended message, and access the extension using
174       //    the field descriptor.
175       // 3- the extension submessage fails to be returned, because no class has
176       //    been created.
177       // It happens when deserializing text proto format, or when enumerating
178       // fields of a deserialized message.
179       CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
180           cmessage::GetFactoryForMessage(self->parent),
181           descriptor->message_type());
182       ScopedPyObjectPtr message_class_handler(
183         reinterpret_cast<PyObject*>(message_class));
184       if (message_class == NULL) {
185         return NULL;
186       }
187       ContainerBase* py_container = repeated_composite_container::NewContainer(
188           self->parent, descriptor, message_class);
189       if (py_container == NULL) {
190         return NULL;
191       }
192       (*self->parent->composite_fields)[descriptor] = py_container;
193       return py_container->AsPyObject();
194     } else {
195       ContainerBase* py_container = repeated_scalar_container::NewContainer(
196           self->parent, descriptor);
197       if (py_container == NULL) {
198         return NULL;
199       }
200       (*self->parent->composite_fields)[descriptor] = py_container;
201       return py_container->AsPyObject();
202     }
203   }
204   PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
205   return NULL;
206 }
207 
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)208 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
209   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
210   if (descriptor == NULL) {
211     return -1;
212   }
213   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
214     return -1;
215   }
216 
217   if (value == nullptr) {
218     return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
219   }
220 
221   if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
222       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
223     PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
224                     "type");
225     return -1;
226   }
227   cmessage::AssureWritable(self->parent);
228   if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
229     return -1;
230   }
231   return 0;
232 }
233 
_FindExtensionByName(ExtensionDict * self,PyObject * arg)234 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
235   char* name;
236   Py_ssize_t name_size;
237   if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
238     return NULL;
239   }
240 
241   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
242   const FieldDescriptor* message_extension =
243       pool->pool->FindExtensionByName(StringParam(name, name_size));
244   if (message_extension == NULL) {
245     // Is is the name of a message set extension?
246     const Descriptor* message_descriptor =
247         pool->pool->FindMessageTypeByName(StringParam(name, name_size));
248     if (message_descriptor && message_descriptor->extension_count() > 0) {
249       const FieldDescriptor* extension = message_descriptor->extension(0);
250       if (extension->is_extension() &&
251           extension->containing_type()->options().message_set_wire_format() &&
252           extension->type() == FieldDescriptor::TYPE_MESSAGE &&
253           extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
254         message_extension = extension;
255       }
256     }
257   }
258   if (message_extension == NULL) {
259     Py_RETURN_NONE;
260   }
261 
262   return PyFieldDescriptor_FromDescriptor(message_extension);
263 }
264 
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)265 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
266   int64 number = PyLong_AsLong(arg);
267   if (number == -1 && PyErr_Occurred()) {
268     return NULL;
269   }
270 
271   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
272   const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
273       self->parent->message->GetDescriptor(), number);
274   if (message_extension == NULL) {
275     Py_RETURN_NONE;
276   }
277 
278   return PyFieldDescriptor_FromDescriptor(message_extension);
279 }
280 
Contains(PyObject * _self,PyObject * key)281 static int Contains(PyObject* _self, PyObject* key) {
282   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
283   const FieldDescriptor* field_descriptor =
284       cmessage::GetExtensionDescriptor(key);
285   if (field_descriptor == nullptr) {
286     return -1;
287   }
288 
289   if (!field_descriptor->is_extension()) {
290     PyErr_Format(PyExc_KeyError, "%s is not an extension",
291                  field_descriptor->full_name().c_str());
292     return -1;
293   }
294 
295   const Message* message = self->parent->message;
296   const Reflection* reflection = message->GetReflection();
297   if (field_descriptor->is_repeated()) {
298     if (reflection->FieldSize(*message, field_descriptor) > 0) {
299       return 1;
300     }
301   } else {
302     if (reflection->HasField(*message, field_descriptor)) {
303       return 1;
304     }
305   }
306 
307   return 0;
308 }
309 
NewExtensionDict(CMessage * parent)310 ExtensionDict* NewExtensionDict(CMessage *parent) {
311   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
312       PyType_GenericAlloc(&ExtensionDict_Type, 0));
313   if (self == NULL) {
314     return NULL;
315   }
316 
317   Py_INCREF(parent);
318   self->parent = parent;
319   return self;
320 }
321 
dealloc(PyObject * pself)322 void dealloc(PyObject* pself) {
323   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
324   Py_CLEAR(self->parent);
325   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
326 }
327 
RichCompare(ExtensionDict * self,PyObject * other,int opid)328 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
329   // Only equality comparisons are implemented.
330   if (opid != Py_EQ && opid != Py_NE) {
331     Py_INCREF(Py_NotImplemented);
332     return Py_NotImplemented;
333   }
334   bool equals = false;
335   if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
336     equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
337   }
338   if (equals ^ (opid == Py_EQ)) {
339     Py_RETURN_FALSE;
340   } else {
341     Py_RETURN_TRUE;
342   }
343 }
344 static PySequenceMethods SeqMethods = {
345     (lenfunc)len,          // sq_length
346     0,                     // sq_concat
347     0,                     // sq_repeat
348     0,                     // sq_item
349     0,                     // sq_slice
350     0,                     // sq_ass_item
351     0,                     // sq_ass_slice
352     (objobjproc)Contains,  // sq_contains
353 };
354 
355 static PyMappingMethods MpMethods = {
356   (lenfunc)len,                /* mp_length */
357   (binaryfunc)subscript,       /* mp_subscript */
358   (objobjargproc)ass_subscript,/* mp_ass_subscript */
359 };
360 
361 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
362 static PyMethodDef Methods[] = {
363     EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
364     EDMETHOD(_FindExtensionByNumber, METH_O,
365              "Finds an extension by field number."),
366     {NULL, NULL},
367 };
368 
369 }  // namespace extension_dict
370 
371 PyTypeObject ExtensionDict_Type = {
372     PyVarObject_HEAD_INIT(&PyType_Type, 0)     //
373     FULL_MODULE_NAME ".ExtensionDict",         // tp_name
374     sizeof(ExtensionDict),                     // tp_basicsize
375     0,                                         //  tp_itemsize
376     (destructor)extension_dict::dealloc,       //  tp_dealloc
377     0,                                         //  tp_print
378     0,                                         //  tp_getattr
379     0,                                         //  tp_setattr
380     0,                                         //  tp_compare
381     0,                                         //  tp_repr
382     0,                                         //  tp_as_number
383     &extension_dict::SeqMethods,               //  tp_as_sequence
384     &extension_dict::MpMethods,                //  tp_as_mapping
385     PyObject_HashNotImplemented,               //  tp_hash
386     0,                                         //  tp_call
387     0,                                         //  tp_str
388     0,                                         //  tp_getattro
389     0,                                         //  tp_setattro
390     0,                                         //  tp_as_buffer
391     Py_TPFLAGS_DEFAULT,                        //  tp_flags
392     "An extension dict",                       //  tp_doc
393     0,                                         //  tp_traverse
394     0,                                         //  tp_clear
395     (richcmpfunc)extension_dict::RichCompare,  //  tp_richcompare
396     0,                                         //  tp_weaklistoffset
397     extension_dict::GetIter,                   //  tp_iter
398     0,                                         //  tp_iternext
399     extension_dict::Methods,                   //  tp_methods
400     0,                                         //  tp_members
401     0,                                         //  tp_getset
402     0,                                         //  tp_base
403     0,                                         //  tp_dict
404     0,                                         //  tp_descr_get
405     0,                                         //  tp_descr_set
406     0,                                         //  tp_dictoffset
407     0,                                         //  tp_init
408 };
409 
IterNext(PyObject * _self)410 PyObject* IterNext(PyObject* _self) {
411   extension_dict::ExtensionIterator* self =
412       reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
413   Py_ssize_t total_size = self->fields.size();
414   Py_ssize_t index = self->index;
415   while (self->index < total_size) {
416     index = self->index;
417     ++self->index;
418     if (self->fields[index]->is_extension()) {
419       // With C++ descriptors, the field can always be retrieved, but for
420       // unknown extensions which have not been imported in Python code, there
421       // is no message class and we cannot retrieve the value.
422       // ListFields() has the same behavior.
423       if (self->fields[index]->message_type() != nullptr &&
424           message_factory::GetMessageClass(
425               cmessage::GetFactoryForMessage(self->extension_dict->parent),
426               self->fields[index]->message_type()) == nullptr) {
427         PyErr_Clear();
428         continue;
429       }
430 
431       return PyFieldDescriptor_FromDescriptor(self->fields[index]);
432     }
433   }
434 
435   return nullptr;
436 }
437 
438 PyTypeObject ExtensionIterator_Type = {
439     PyVarObject_HEAD_INIT(&PyType_Type, 0)      //
440     FULL_MODULE_NAME ".ExtensionIterator",      //  tp_name
441     sizeof(extension_dict::ExtensionIterator),  //  tp_basicsize
442     0,                                          //  tp_itemsize
443     extension_dict::DeallocExtensionIterator,   //  tp_dealloc
444     0,                                          //  tp_print
445     0,                                          //  tp_getattr
446     0,                                          //  tp_setattr
447     0,                                          //  tp_compare
448     0,                                          //  tp_repr
449     0,                                          //  tp_as_number
450     0,                                          //  tp_as_sequence
451     0,                                          //  tp_as_mapping
452     0,                                          //  tp_hash
453     0,                                          //  tp_call
454     0,                                          //  tp_str
455     0,                                          //  tp_getattro
456     0,                                          //  tp_setattro
457     0,                                          //  tp_as_buffer
458     Py_TPFLAGS_DEFAULT,                         //  tp_flags
459     "A scalar map iterator",                    //  tp_doc
460     0,                                          //  tp_traverse
461     0,                                          //  tp_clear
462     0,                                          //  tp_richcompare
463     0,                                          //  tp_weaklistoffset
464     PyObject_SelfIter,                          //  tp_iter
465     IterNext,                                   //  tp_iternext
466     0,                                          //  tp_methods
467     0,                                          //  tp_members
468     0,                                          //  tp_getset
469     0,                                          //  tp_base
470     0,                                          //  tp_dict
471     0,                                          //  tp_descr_get
472     0,                                          //  tp_descr_set
473     0,                                          //  tp_dictoffset
474     0,                                          //  tp_init
475 };
476 }  // namespace python
477 }  // namespace protobuf
478 }  // namespace google
479