• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33 
34 #include <google/protobuf/pyext/extension_dict.h>
35 #include <memory>
36 
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/dynamic_message.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/descriptor.pb.h>
43 #include <google/protobuf/pyext/descriptor.h>
44 #include <google/protobuf/pyext/message.h>
45 #include <google/protobuf/pyext/message_factory.h>
46 #include <google/protobuf/pyext/repeated_composite_container.h>
47 #include <google/protobuf/pyext/repeated_scalar_container.h>
48 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
49 
// Python 3 has no PyString_AsStringAndSize, so a shim with the same shape is
// defined here: it yields 0 on success and -1 on failure, and fills in the
// character pointer and size for the object.
#if PY_MAJOR_VERSION >= 3
  // NOTE(review): PyUnicode_AsUTF8AndSize is a Python 3.3+ API, which is
  // presumably why 3.0 - 3.2 are rejected here.
  #if PY_VERSION_HEX < 0x03030000
    #error "Python 3.0 - 3.2 are not supported."
  #endif
// str objects are exposed as their UTF-8 encoding; any other object is handed
// to PyBytes_AsStringAndSize. The const_cast adapts the returned pointer to the
// char** out-parameter that PyString_AsStringAndSize callers pass.
#define PyString_AsStringAndSize(ob, charpp, sizep)                           \
  (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>(                     \
                               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
                              ? -1                                            \
                              : 0)                                            \
                       : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
61 
62 namespace google {
63 namespace protobuf {
64 namespace python {
65 
66 namespace extension_dict {
67 
len(ExtensionDict * self)68 static Py_ssize_t len(ExtensionDict* self) {
69   Py_ssize_t size = 0;
70   std::vector<const FieldDescriptor*> fields;
71   self->parent->message->GetReflection()->ListFields(*self->parent->message,
72                                                      &fields);
73 
74   for (size_t i = 0; i < fields.size(); ++i) {
75     if (fields[i]->is_extension()) {
76       // With C++ descriptors, the field can always be retrieved, but for
77       // unknown extensions which have not been imported in Python code, there
78       // is no message class and we cannot retrieve the value.
79       // ListFields() has the same behavior.
80       if (fields[i]->message_type() != nullptr &&
81           message_factory::GetMessageClass(
82               cmessage::GetFactoryForMessage(self->parent),
83               fields[i]->message_type()) == nullptr) {
84         PyErr_Clear();
85         continue;
86       }
87       ++size;
88     }
89   }
90   return size;
91 }
92 
// Python iterator object returned by iter(extension_dict). It holds the list of
// fields that were set on the parent message when the iterator was created, and
// IterNext walks that list yielding only the extension fields.
//
// Instances are allocated by the Python allocator, so the non-trivial C++
// member `fields` has to be constructed and destroyed by hand (see GetIter and
// DeallocExtensionIterator).
struct ExtensionIterator {
  PyObject_HEAD;
  // Index of the next entry of `fields` to examine.
  Py_ssize_t index;
  // Result of Reflection::ListFields() at creation time; this includes
  // non-extension fields, which iteration skips.
  std::vector<const FieldDescriptor*> fields;

  // Owned reference, to keep the FieldDescriptors alive.
  ExtensionDict* extension_dict;
};
101 
GetIter(PyObject * _self)102 PyObject* GetIter(PyObject* _self) {
103   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
104 
105   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
106   if (obj == nullptr) {
107     return PyErr_Format(PyExc_MemoryError,
108                         "Could not allocate extension iterator");
109   }
110 
111   ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
112 
113   // Call "placement new" to initialize. So the constructor of
114   // std::vector<...> fields will be called.
115   new (iter) ExtensionIterator;
116 
117   self->parent->message->GetReflection()->ListFields(*self->parent->message,
118                                                      &iter->fields);
119   iter->index = 0;
120   Py_INCREF(self);
121   iter->extension_dict = self;
122 
123   return obj.release();
124 }
125 
DeallocExtensionIterator(PyObject * _self)126 static void DeallocExtensionIterator(PyObject* _self) {
127   ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
128   self->fields.clear();
129   Py_XDECREF(self->extension_dict);
130   self->~ExtensionIterator();
131   Py_TYPE(_self)->tp_free(_self);
132 }
133 
subscript(ExtensionDict * self,PyObject * key)134 PyObject* subscript(ExtensionDict* self, PyObject* key) {
135   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
136   if (descriptor == NULL) {
137     return NULL;
138   }
139   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
140     return NULL;
141   }
142 
143   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
144       descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
145     return cmessage::InternalGetScalar(self->parent->message, descriptor);
146   }
147 
148   CMessage::CompositeFieldsMap::iterator iterator =
149       self->parent->composite_fields->find(descriptor);
150   if (iterator != self->parent->composite_fields->end()) {
151     Py_INCREF(iterator->second);
152     return iterator->second->AsPyObject();
153   }
154 
155   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
156       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
157     // TODO(plabatut): consider building the class on the fly!
158     ContainerBase* sub_message = cmessage::InternalGetSubMessage(
159         self->parent, descriptor);
160     if (sub_message == NULL) {
161       return NULL;
162     }
163     (*self->parent->composite_fields)[descriptor] = sub_message;
164     return sub_message->AsPyObject();
165   }
166 
167   if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
168     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
169       // On the fly message class creation is needed to support the following
170       // situation:
171       // 1- add FileDescriptor to the pool that contains extensions of a message
172       //    defined by another proto file. Do not create any message classes.
173       // 2- instantiate an extended message, and access the extension using
174       //    the field descriptor.
175       // 3- the extension submessage fails to be returned, because no class has
176       //    been created.
177       // It happens when deserializing text proto format, or when enumerating
178       // fields of a deserialized message.
179       CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
180           cmessage::GetFactoryForMessage(self->parent),
181           descriptor->message_type());
182       ScopedPyObjectPtr message_class_handler(
183         reinterpret_cast<PyObject*>(message_class));
184       if (message_class == NULL) {
185         return NULL;
186       }
187       ContainerBase* py_container = repeated_composite_container::NewContainer(
188           self->parent, descriptor, message_class);
189       if (py_container == NULL) {
190         return NULL;
191       }
192       (*self->parent->composite_fields)[descriptor] = py_container;
193       return py_container->AsPyObject();
194     } else {
195       ContainerBase* py_container = repeated_scalar_container::NewContainer(
196           self->parent, descriptor);
197       if (py_container == NULL) {
198         return NULL;
199       }
200       (*self->parent->composite_fields)[descriptor] = py_container;
201       return py_container->AsPyObject();
202     }
203   }
204   PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
205   return NULL;
206 }
207 
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)208 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
209   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
210   if (descriptor == NULL) {
211     return -1;
212   }
213   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
214     return -1;
215   }
216 
217   if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
218       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
219     PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
220                     "type");
221     return -1;
222   }
223   cmessage::AssureWritable(self->parent);
224   if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
225     return -1;
226   }
227   return 0;
228 }
229 
_FindExtensionByName(ExtensionDict * self,PyObject * arg)230 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
231   char* name;
232   Py_ssize_t name_size;
233   if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
234     return NULL;
235   }
236 
237   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
238   const FieldDescriptor* message_extension =
239       pool->pool->FindExtensionByName(string(name, name_size));
240   if (message_extension == NULL) {
241     // Is is the name of a message set extension?
242     const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
243         string(name, name_size));
244     if (message_descriptor && message_descriptor->extension_count() > 0) {
245       const FieldDescriptor* extension = message_descriptor->extension(0);
246       if (extension->is_extension() &&
247           extension->containing_type()->options().message_set_wire_format() &&
248           extension->type() == FieldDescriptor::TYPE_MESSAGE &&
249           extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
250         message_extension = extension;
251       }
252     }
253   }
254   if (message_extension == NULL) {
255     Py_RETURN_NONE;
256   }
257 
258   return PyFieldDescriptor_FromDescriptor(message_extension);
259 }
260 
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)261 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
262   int64 number = PyLong_AsLong(arg);
263   if (number == -1 && PyErr_Occurred()) {
264     return NULL;
265   }
266 
267   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
268   const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
269       self->parent->message->GetDescriptor(), number);
270   if (message_extension == NULL) {
271     Py_RETURN_NONE;
272   }
273 
274   return PyFieldDescriptor_FromDescriptor(message_extension);
275 }
276 
Contains(PyObject * _self,PyObject * key)277 static int Contains(PyObject* _self, PyObject* key) {
278   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
279   const FieldDescriptor* field_descriptor =
280       cmessage::GetExtensionDescriptor(key);
281   if (field_descriptor == nullptr) {
282     return -1;
283   }
284 
285   if (!field_descriptor->is_extension()) {
286     PyErr_Format(PyExc_KeyError, "%s is not an extension",
287                  field_descriptor->full_name().c_str());
288     return -1;
289   }
290 
291   const Message* message = self->parent->message;
292   const Reflection* reflection = message->GetReflection();
293   if (field_descriptor->is_repeated()) {
294     if (reflection->FieldSize(*message, field_descriptor) > 0) {
295       return 1;
296     }
297   } else {
298     if (reflection->HasField(*message, field_descriptor)) {
299       return 1;
300     }
301   }
302 
303   return 0;
304 }
305 
NewExtensionDict(CMessage * parent)306 ExtensionDict* NewExtensionDict(CMessage *parent) {
307   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
308       PyType_GenericAlloc(&ExtensionDict_Type, 0));
309   if (self == NULL) {
310     return NULL;
311   }
312 
313   Py_INCREF(parent);
314   self->parent = parent;
315   return self;
316 }
317 
dealloc(PyObject * pself)318 void dealloc(PyObject* pself) {
319   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
320   Py_CLEAR(self->parent);
321   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
322 }
323 
RichCompare(ExtensionDict * self,PyObject * other,int opid)324 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
325   // Only equality comparisons are implemented.
326   if (opid != Py_EQ && opid != Py_NE) {
327     Py_INCREF(Py_NotImplemented);
328     return Py_NotImplemented;
329   }
330   bool equals = false;
331   if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
332     equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
333   }
334   if (equals ^ (opid == Py_EQ)) {
335     Py_RETURN_FALSE;
336   } else {
337     Py_RETURN_TRUE;
338   }
339 }
// Sequence-protocol slots for ExtensionDict: only len() and the `in` operator
// are provided; all other slots are left empty.
static PySequenceMethods SeqMethods = {
    (lenfunc)len,          // sq_length
    0,                     // sq_concat
    0,                     // sq_repeat
    0,                     // sq_item
    0,                     // sq_slice
    0,                     // sq_ass_item
    0,                     // sq_ass_slice
    (objobjproc)Contains,  // sq_contains
};
350 
// Mapping-protocol slots for ExtensionDict: len(), item lookup
// (ext_dict[key]) and item assignment, all keyed by extension descriptor.
static PyMappingMethods MpMethods = {
  (lenfunc)len,                /* mp_length */
  (binaryfunc)subscript,       /* mp_subscript */
  (objobjargproc)ass_subscript,/* mp_ass_subscript */
};
356 
// Builds a PyMethodDef entry whose Python-visible name is the C++ function's
// name.
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
// Methods exposed on ExtensionDict instances. Both take a single argument
// (METH_O). The {NULL, NULL} entry terminates the table.
static PyMethodDef Methods[] = {
    EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
    EDMETHOD(_FindExtensionByNumber, METH_O,
             "Finds an extension by field number."),
    {NULL, NULL},
};
364 
365 }  // namespace extension_dict
366 
// Python type object for ExtensionDict, the mapping from extension descriptors
// to values of a message. Slots are positional, so their order must follow
// PyTypeObject exactly. tp_hash is PyObject_HashNotImplemented, which makes
// instances unhashable; equality is provided by RichCompare.
PyTypeObject ExtensionDict_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)     //
    FULL_MODULE_NAME ".ExtensionDict",         // tp_name
    sizeof(ExtensionDict),                     // tp_basicsize
    0,                                         //  tp_itemsize
    (destructor)extension_dict::dealloc,       //  tp_dealloc
    0,                                         //  tp_print
    0,                                         //  tp_getattr
    0,                                         //  tp_setattr
    0,                                         //  tp_compare
    0,                                         //  tp_repr
    0,                                         //  tp_as_number
    &extension_dict::SeqMethods,               //  tp_as_sequence
    &extension_dict::MpMethods,                //  tp_as_mapping
    PyObject_HashNotImplemented,               //  tp_hash
    0,                                         //  tp_call
    0,                                         //  tp_str
    0,                                         //  tp_getattro
    0,                                         //  tp_setattro
    0,                                         //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,                        //  tp_flags
    "An extension dict",                       //  tp_doc
    0,                                         //  tp_traverse
    0,                                         //  tp_clear
    (richcmpfunc)extension_dict::RichCompare,  //  tp_richcompare
    0,                                         //  tp_weaklistoffset
    extension_dict::GetIter,                   //  tp_iter
    0,                                         //  tp_iternext
    extension_dict::Methods,                   //  tp_methods
    0,                                         //  tp_members
    0,                                         //  tp_getset
    0,                                         //  tp_base
    0,                                         //  tp_dict
    0,                                         //  tp_descr_get
    0,                                         //  tp_descr_set
    0,                                         //  tp_dictoffset
    0,                                         //  tp_init
};
405 
IterNext(PyObject * _self)406 PyObject* IterNext(PyObject* _self) {
407   extension_dict::ExtensionIterator* self =
408       reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
409   Py_ssize_t total_size = self->fields.size();
410   Py_ssize_t index = self->index;
411   while (self->index < total_size) {
412     index = self->index;
413     ++self->index;
414     if (self->fields[index]->is_extension()) {
415       // With C++ descriptors, the field can always be retrieved, but for
416       // unknown extensions which have not been imported in Python code, there
417       // is no message class and we cannot retrieve the value.
418       // ListFields() has the same behavior.
419       if (self->fields[index]->message_type() != nullptr &&
420           message_factory::GetMessageClass(
421               cmessage::GetFactoryForMessage(self->extension_dict->parent),
422               self->fields[index]->message_type()) == nullptr) {
423         PyErr_Clear();
424         continue;
425       }
426 
427       return PyFieldDescriptor_FromDescriptor(self->fields[index]);
428     }
429   }
430 
431   return nullptr;
432 }
433 
434 PyTypeObject ExtensionIterator_Type = {
435     PyVarObject_HEAD_INIT(&PyType_Type, 0)      //
436     FULL_MODULE_NAME ".ExtensionIterator",      //  tp_name
437     sizeof(extension_dict::ExtensionIterator),  //  tp_basicsize
438     0,                                          //  tp_itemsize
439     extension_dict::DeallocExtensionIterator,   //  tp_dealloc
440     0,                                          //  tp_print
441     0,                                          //  tp_getattr
442     0,                                          //  tp_setattr
443     0,                                          //  tp_compare
444     0,                                          //  tp_repr
445     0,                                          //  tp_as_number
446     0,                                          //  tp_as_sequence
447     0,                                          //  tp_as_mapping
448     0,                                          //  tp_hash
449     0,                                          //  tp_call
450     0,                                          //  tp_str
451     0,                                          //  tp_getattro
452     0,                                          //  tp_setattro
453     0,                                          //  tp_as_buffer
454     Py_TPFLAGS_DEFAULT,                         //  tp_flags
455     "A scalar map iterator",                    //  tp_doc
456     0,                                          //  tp_traverse
457     0,                                          //  tp_clear
458     0,                                          //  tp_richcompare
459     0,                                          //  tp_weaklistoffset
460     PyObject_SelfIter,                          //  tp_iter
461     IterNext,                                   //  tp_iternext
462     0,                                          //  tp_methods
463     0,                                          //  tp_members
464     0,                                          //  tp_getset
465     0,                                          //  tp_base
466     0,                                          //  tp_dict
467     0,                                          //  tp_descr_get
468     0,                                          //  tp_descr_set
469     0,                                          //  tp_dictoffset
470     0,                                          //  tp_init
471 };
472 }  // namespace python
473 }  // namespace protobuf
474 }  // namespace google
475