• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33 
34 #include <google/protobuf/pyext/extension_dict.h>
35 
36 #include <cstdint>
37 #include <memory>
38 
39 #include <google/protobuf/stubs/logging.h>
40 #include <google/protobuf/stubs/common.h>
41 #include <google/protobuf/descriptor.pb.h>
42 #include <google/protobuf/descriptor.h>
43 #include <google/protobuf/dynamic_message.h>
44 #include <google/protobuf/message.h>
45 #include <google/protobuf/pyext/descriptor.h>
46 #include <google/protobuf/pyext/message.h>
47 #include <google/protobuf/pyext/message_factory.h>
48 #include <google/protobuf/pyext/repeated_composite_container.h>
49 #include <google/protobuf/pyext/repeated_scalar_container.h>
50 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
51 
// Gives this file a PyString_AsStringAndSize() that accepts both unicode and
// bytes objects. For a unicode object, *charpp is pointed at the buffer
// returned by PyUnicode_AsUTF8AndSize() and the expression evaluates to 0, or
// to -1 when that call returns nullptr. Any other object is forwarded to
// PyBytes_AsStringAndSize(). Note that `ob` is evaluated more than once.
#define PyString_AsStringAndSize(ob, charpp, sizep)                \
  (!PyUnicode_Check(ob)                                            \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))            \
       : (((*(charpp) = const_cast<char*>(                         \
                PyUnicode_AsUTF8AndSize(ob, (sizep)))) != nullptr) \
              ? 0                                                  \
              : -1))
59 
60 namespace google {
61 namespace protobuf {
62 namespace python {
63 
64 namespace extension_dict {
65 
len(ExtensionDict * self)66 static Py_ssize_t len(ExtensionDict* self) {
67   Py_ssize_t size = 0;
68   std::vector<const FieldDescriptor*> fields;
69   self->parent->message->GetReflection()->ListFields(*self->parent->message,
70                                                      &fields);
71 
72   for (size_t i = 0; i < fields.size(); ++i) {
73     if (fields[i]->is_extension()) {
74       // With C++ descriptors, the field can always be retrieved, but for
75       // unknown extensions which have not been imported in Python code, there
76       // is no message class and we cannot retrieve the value.
77       // ListFields() has the same behavior.
78       if (fields[i]->message_type() != nullptr &&
79           message_factory::GetMessageClass(
80               cmessage::GetFactoryForMessage(self->parent),
81               fields[i]->message_type()) == nullptr) {
82         PyErr_Clear();
83         continue;
84       }
85       ++size;
86     }
87   }
88   return size;
89 }
90 
91 struct ExtensionIterator {
92   PyObject_HEAD;
93   Py_ssize_t index;
94   std::vector<const FieldDescriptor*> fields;
95 
96   // Owned reference, to keep the FieldDescriptors alive.
97   ExtensionDict* extension_dict;
98 };
99 
GetIter(PyObject * _self)100 PyObject* GetIter(PyObject* _self) {
101   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
102 
103   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
104   if (obj == nullptr) {
105     return PyErr_Format(PyExc_MemoryError,
106                         "Could not allocate extension iterator");
107   }
108 
109   ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
110 
111   // Call "placement new" to initialize. So the constructor of
112   // std::vector<...> fields will be called.
113   new (iter) ExtensionIterator;
114 
115   self->parent->message->GetReflection()->ListFields(*self->parent->message,
116                                                      &iter->fields);
117   iter->index = 0;
118   Py_INCREF(self);
119   iter->extension_dict = self;
120 
121   return obj.release();
122 }
123 
DeallocExtensionIterator(PyObject * _self)124 static void DeallocExtensionIterator(PyObject* _self) {
125   ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
126   self->fields.clear();
127   Py_XDECREF(self->extension_dict);
128   self->~ExtensionIterator();
129   Py_TYPE(_self)->tp_free(_self);
130 }
131 
subscript(ExtensionDict * self,PyObject * key)132 PyObject* subscript(ExtensionDict* self, PyObject* key) {
133   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
134   if (descriptor == nullptr) {
135     return nullptr;
136   }
137   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
138     return nullptr;
139   }
140 
141   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
142       descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
143     return cmessage::InternalGetScalar(self->parent->message, descriptor);
144   }
145 
146   CMessage::CompositeFieldsMap::iterator iterator =
147       self->parent->composite_fields->find(descriptor);
148   if (iterator != self->parent->composite_fields->end()) {
149     Py_INCREF(iterator->second);
150     return iterator->second->AsPyObject();
151   }
152 
153   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
154       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
155     // TODO(plabatut): consider building the class on the fly!
156     ContainerBase* sub_message = cmessage::InternalGetSubMessage(
157         self->parent, descriptor);
158     if (sub_message == nullptr) {
159       return nullptr;
160     }
161     (*self->parent->composite_fields)[descriptor] = sub_message;
162     return sub_message->AsPyObject();
163   }
164 
165   if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
166     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
167       // On the fly message class creation is needed to support the following
168       // situation:
169       // 1- add FileDescriptor to the pool that contains extensions of a message
170       //    defined by another proto file. Do not create any message classes.
171       // 2- instantiate an extended message, and access the extension using
172       //    the field descriptor.
173       // 3- the extension submessage fails to be returned, because no class has
174       //    been created.
175       // It happens when deserializing text proto format, or when enumerating
176       // fields of a deserialized message.
177       CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
178           cmessage::GetFactoryForMessage(self->parent),
179           descriptor->message_type());
180       ScopedPyObjectPtr message_class_handler(
181         reinterpret_cast<PyObject*>(message_class));
182       if (message_class == nullptr) {
183         return nullptr;
184       }
185       ContainerBase* py_container = repeated_composite_container::NewContainer(
186           self->parent, descriptor, message_class);
187       if (py_container == nullptr) {
188         return nullptr;
189       }
190       (*self->parent->composite_fields)[descriptor] = py_container;
191       return py_container->AsPyObject();
192     } else {
193       ContainerBase* py_container = repeated_scalar_container::NewContainer(
194           self->parent, descriptor);
195       if (py_container == nullptr) {
196         return nullptr;
197       }
198       (*self->parent->composite_fields)[descriptor] = py_container;
199       return py_container->AsPyObject();
200     }
201   }
202   PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
203   return nullptr;
204 }
205 
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)206 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
207   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
208   if (descriptor == nullptr) {
209     return -1;
210   }
211   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
212     return -1;
213   }
214 
215   if (value == nullptr) {
216     return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
217   }
218 
219   if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
220       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
221     PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
222                     "type");
223     return -1;
224   }
225   cmessage::AssureWritable(self->parent);
226   if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
227     return -1;
228   }
229   return 0;
230 }
231 
_FindExtensionByName(ExtensionDict * self,PyObject * arg)232 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
233   char* name;
234   Py_ssize_t name_size;
235   if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
236     return nullptr;
237   }
238 
239   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
240   const FieldDescriptor* message_extension =
241       pool->pool->FindExtensionByName(StringParam(name, name_size));
242   if (message_extension == nullptr) {
243     // Is is the name of a message set extension?
244     const Descriptor* message_descriptor =
245         pool->pool->FindMessageTypeByName(StringParam(name, name_size));
246     if (message_descriptor && message_descriptor->extension_count() > 0) {
247       const FieldDescriptor* extension = message_descriptor->extension(0);
248       if (extension->is_extension() &&
249           extension->containing_type()->options().message_set_wire_format() &&
250           extension->type() == FieldDescriptor::TYPE_MESSAGE &&
251           extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
252         message_extension = extension;
253       }
254     }
255   }
256   if (message_extension == nullptr) {
257     Py_RETURN_NONE;
258   }
259 
260   return PyFieldDescriptor_FromDescriptor(message_extension);
261 }
262 
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)263 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
264   int64_t number = PyLong_AsLong(arg);
265   if (number == -1 && PyErr_Occurred()) {
266     return nullptr;
267   }
268 
269   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
270   const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
271       self->parent->message->GetDescriptor(), number);
272   if (message_extension == nullptr) {
273     Py_RETURN_NONE;
274   }
275 
276   return PyFieldDescriptor_FromDescriptor(message_extension);
277 }
278 
Contains(PyObject * _self,PyObject * key)279 static int Contains(PyObject* _self, PyObject* key) {
280   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
281   const FieldDescriptor* field_descriptor =
282       cmessage::GetExtensionDescriptor(key);
283   if (field_descriptor == nullptr) {
284     return -1;
285   }
286 
287   if (!field_descriptor->is_extension()) {
288     PyErr_Format(PyExc_KeyError, "%s is not an extension",
289                  field_descriptor->full_name().c_str());
290     return -1;
291   }
292 
293   const Message* message = self->parent->message;
294   const Reflection* reflection = message->GetReflection();
295   if (field_descriptor->is_repeated()) {
296     if (reflection->FieldSize(*message, field_descriptor) > 0) {
297       return 1;
298     }
299   } else {
300     if (reflection->HasField(*message, field_descriptor)) {
301       return 1;
302     }
303   }
304 
305   return 0;
306 }
307 
NewExtensionDict(CMessage * parent)308 ExtensionDict* NewExtensionDict(CMessage *parent) {
309   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
310       PyType_GenericAlloc(&ExtensionDict_Type, 0));
311   if (self == nullptr) {
312     return nullptr;
313   }
314 
315   Py_INCREF(parent);
316   self->parent = parent;
317   return self;
318 }
319 
dealloc(PyObject * pself)320 void dealloc(PyObject* pself) {
321   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
322   Py_CLEAR(self->parent);
323   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
324 }
325 
RichCompare(ExtensionDict * self,PyObject * other,int opid)326 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
327   // Only equality comparisons are implemented.
328   if (opid != Py_EQ && opid != Py_NE) {
329     Py_INCREF(Py_NotImplemented);
330     return Py_NotImplemented;
331   }
332   bool equals = false;
333   if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
334     equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
335   }
336   if (equals ^ (opid == Py_EQ)) {
337     Py_RETURN_FALSE;
338   } else {
339     Py_RETURN_TRUE;
340   }
341 }
342 static PySequenceMethods SeqMethods = {
343     (lenfunc)len,          // sq_length
344     nullptr,               // sq_concat
345     nullptr,               // sq_repeat
346     nullptr,               // sq_item
347     nullptr,               // sq_slice
348     nullptr,               // sq_ass_item
349     nullptr,               // sq_ass_slice
350     (objobjproc)Contains,  // sq_contains
351 };
352 
353 static PyMappingMethods MpMethods = {
354   (lenfunc)len,                /* mp_length */
355   (binaryfunc)subscript,       /* mp_subscript */
356   (objobjargproc)ass_subscript,/* mp_ass_subscript */
357 };
358 
359 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
360 static PyMethodDef Methods[] = {
361     EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
362     EDMETHOD(_FindExtensionByNumber, METH_O,
363              "Finds an extension by field number."),
364     {nullptr, nullptr},
365 };
366 
367 }  // namespace extension_dict
368 
369 PyTypeObject ExtensionDict_Type = {
370     PyVarObject_HEAD_INIT(&PyType_Type, 0)  //
371     FULL_MODULE_NAME ".ExtensionDict",      // tp_name
372     sizeof(ExtensionDict),                  // tp_basicsize
373     0,                                      //  tp_itemsize
374     (destructor)extension_dict::dealloc,    //  tp_dealloc
375 #if PY_VERSION_HEX < 0x03080000
376     nullptr,  // tp_print
377 #else
378     0,  // tp_vectorcall_offset
379 #endif
380     nullptr,                                   //  tp_getattr
381     nullptr,                                   //  tp_setattr
382     nullptr,                                   //  tp_compare
383     nullptr,                                   //  tp_repr
384     nullptr,                                   //  tp_as_number
385     &extension_dict::SeqMethods,               //  tp_as_sequence
386     &extension_dict::MpMethods,                //  tp_as_mapping
387     PyObject_HashNotImplemented,               //  tp_hash
388     nullptr,                                   //  tp_call
389     nullptr,                                   //  tp_str
390     nullptr,                                   //  tp_getattro
391     nullptr,                                   //  tp_setattro
392     nullptr,                                   //  tp_as_buffer
393     Py_TPFLAGS_DEFAULT,                        //  tp_flags
394     "An extension dict",                       //  tp_doc
395     nullptr,                                   //  tp_traverse
396     nullptr,                                   //  tp_clear
397     (richcmpfunc)extension_dict::RichCompare,  //  tp_richcompare
398     0,                                         //  tp_weaklistoffset
399     extension_dict::GetIter,                   //  tp_iter
400     nullptr,                                   //  tp_iternext
401     extension_dict::Methods,                   //  tp_methods
402     nullptr,                                   //  tp_members
403     nullptr,                                   //  tp_getset
404     nullptr,                                   //  tp_base
405     nullptr,                                   //  tp_dict
406     nullptr,                                   //  tp_descr_get
407     nullptr,                                   //  tp_descr_set
408     0,                                         //  tp_dictoffset
409     nullptr,                                   //  tp_init
410 };
411 
IterNext(PyObject * _self)412 PyObject* IterNext(PyObject* _self) {
413   extension_dict::ExtensionIterator* self =
414       reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
415   Py_ssize_t total_size = self->fields.size();
416   Py_ssize_t index = self->index;
417   while (self->index < total_size) {
418     index = self->index;
419     ++self->index;
420     if (self->fields[index]->is_extension()) {
421       // With C++ descriptors, the field can always be retrieved, but for
422       // unknown extensions which have not been imported in Python code, there
423       // is no message class and we cannot retrieve the value.
424       // ListFields() has the same behavior.
425       if (self->fields[index]->message_type() != nullptr &&
426           message_factory::GetMessageClass(
427               cmessage::GetFactoryForMessage(self->extension_dict->parent),
428               self->fields[index]->message_type()) == nullptr) {
429         PyErr_Clear();
430         continue;
431       }
432 
433       return PyFieldDescriptor_FromDescriptor(self->fields[index]);
434     }
435   }
436 
437   return nullptr;
438 }
439 
440 PyTypeObject ExtensionIterator_Type = {
441     PyVarObject_HEAD_INIT(&PyType_Type, 0)      //
442     FULL_MODULE_NAME ".ExtensionIterator",      //  tp_name
443     sizeof(extension_dict::ExtensionIterator),  //  tp_basicsize
444     0,                                          //  tp_itemsize
445     extension_dict::DeallocExtensionIterator,   //  tp_dealloc
446 #if PY_VERSION_HEX < 0x03080000
447     nullptr,  // tp_print
448 #else
449     0,  // tp_vectorcall_offset
450 #endif
451     nullptr,                  //  tp_getattr
452     nullptr,                  //  tp_setattr
453     nullptr,                  //  tp_compare
454     nullptr,                  //  tp_repr
455     nullptr,                  //  tp_as_number
456     nullptr,                  //  tp_as_sequence
457     nullptr,                  //  tp_as_mapping
458     nullptr,                  //  tp_hash
459     nullptr,                  //  tp_call
460     nullptr,                  //  tp_str
461     nullptr,                  //  tp_getattro
462     nullptr,                  //  tp_setattro
463     nullptr,                  //  tp_as_buffer
464     Py_TPFLAGS_DEFAULT,       //  tp_flags
465     "A scalar map iterator",  //  tp_doc
466     nullptr,                  //  tp_traverse
467     nullptr,                  //  tp_clear
468     nullptr,                  //  tp_richcompare
469     0,                        //  tp_weaklistoffset
470     PyObject_SelfIter,        //  tp_iter
471     IterNext,                 //  tp_iternext
472     nullptr,                  //  tp_methods
473     nullptr,                  //  tp_members
474     nullptr,                  //  tp_getset
475     nullptr,                  //  tp_base
476     nullptr,                  //  tp_dict
477     nullptr,                  //  tp_descr_get
478     nullptr,                  //  tp_descr_set
479     0,                        //  tp_dictoffset
480     nullptr,                  //  tp_init
481 };
482 }  // namespace python
483 }  // namespace protobuf
484 }  // namespace google
485