• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: anuraag@google.com (Anuraag Agrawal)
9 // Author: tibell@google.com (Johan Tibell)
10 
11 #include "google/protobuf/pyext/extension_dict.h"
12 
13 #include <cstdint>
14 #include <memory>
15 #include <vector>
16 
17 #include "google/protobuf/descriptor.pb.h"
18 #include "google/protobuf/descriptor.h"
19 #include "google/protobuf/dynamic_message.h"
20 #include "google/protobuf/message.h"
21 #include "google/protobuf/pyext/descriptor.h"
22 #include "google/protobuf/pyext/message.h"
23 #include "google/protobuf/pyext/message_factory.h"
24 #include "google/protobuf/pyext/repeated_composite_container.h"
25 #include "google/protobuf/pyext/repeated_scalar_container.h"
26 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
27 #include "absl/strings/string_view.h"
28 
// Fetches the character buffer and length of `ob` into `*charpp` / `*sizep`.
// A unicode object is read through PyUnicode_AsUTF8AndSize(); anything else is
// handed to PyBytes_AsStringAndSize(). Evaluates to 0 on success and to -1 on
// failure.
#define PyString_AsStringAndSize(ob, charpp, sizep)                   \
  (!PyUnicode_Check(ob)                                               \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))               \
       : ((*(charpp) = const_cast<char*>(                             \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != nullptr      \
              ? 0                                                     \
              : -1))
36 
37 namespace google {
38 namespace protobuf {
39 namespace python {
40 
41 namespace extension_dict {
42 
len(ExtensionDict * self)43 static Py_ssize_t len(ExtensionDict* self) {
44   Py_ssize_t size = 0;
45   std::vector<const FieldDescriptor*> fields;
46   self->parent->message->GetReflection()->ListFields(*self->parent->message,
47                                                      &fields);
48 
49   for (size_t i = 0; i < fields.size(); ++i) {
50     if (fields[i]->is_extension()) {
51       // With C++ descriptors, the field can always be retrieved, but for
52       // unknown extensions which have not been imported in Python code, there
53       // is no message class and we cannot retrieve the value.
54       // ListFields() has the same behavior.
55       if (fields[i]->message_type() != nullptr &&
56           message_factory::GetMessageClass(
57               cmessage::GetFactoryForMessage(self->parent),
58               fields[i]->message_type()) == nullptr) {
59         PyErr_Clear();
60         continue;
61       }
62       ++size;
63     }
64   }
65   return size;
66 }
67 
68 struct ExtensionIterator {
69   PyObject_HEAD;
70   Py_ssize_t index;
71   std::vector<const FieldDescriptor*> fields;
72 
73   // Owned reference, to keep the FieldDescriptors alive.
74   ExtensionDict* extension_dict;
75 };
76 
GetIter(PyObject * _self)77 PyObject* GetIter(PyObject* _self) {
78   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
79 
80   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
81   if (obj == nullptr) {
82     return PyErr_Format(PyExc_MemoryError,
83                         "Could not allocate extension iterator");
84   }
85 
86   ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
87 
88   // Call "placement new" to initialize. So the constructor of
89   // std::vector<...> fields will be called.
90   new (iter) ExtensionIterator;
91 
92   self->parent->message->GetReflection()->ListFields(*self->parent->message,
93                                                      &iter->fields);
94   iter->index = 0;
95   Py_INCREF(self);
96   iter->extension_dict = self;
97 
98   return obj.release();
99 }
100 
DeallocExtensionIterator(PyObject * _self)101 static void DeallocExtensionIterator(PyObject* _self) {
102   ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
103   self->fields.clear();
104   Py_XDECREF(self->extension_dict);
105   freefunc tp_free = Py_TYPE(_self)->tp_free;
106   self->~ExtensionIterator();
107   (*tp_free)(_self);
108 }
109 
subscript(ExtensionDict * self,PyObject * key)110 PyObject* subscript(ExtensionDict* self, PyObject* key) {
111   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
112   if (descriptor == nullptr) {
113     return nullptr;
114   }
115   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
116     return nullptr;
117   }
118 
119   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
120       descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
121     return cmessage::InternalGetScalar(self->parent->message, descriptor);
122   }
123 
124   CMessage::CompositeFieldsMap::iterator iterator =
125       self->parent->composite_fields->find(descriptor);
126   if (iterator != self->parent->composite_fields->end()) {
127     Py_INCREF(iterator->second);
128     return iterator->second->AsPyObject();
129   }
130 
131   if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
132       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
133     // TODO: consider building the class on the fly!
134     ContainerBase* sub_message = cmessage::InternalGetSubMessage(
135         self->parent, descriptor);
136     if (sub_message == nullptr) {
137       return nullptr;
138     }
139     (*self->parent->composite_fields)[descriptor] = sub_message;
140     return sub_message->AsPyObject();
141   }
142 
143   if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
144     if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
145       // On the fly message class creation is needed to support the following
146       // situation:
147       // 1- add FileDescriptor to the pool that contains extensions of a message
148       //    defined by another proto file. Do not create any message classes.
149       // 2- instantiate an extended message, and access the extension using
150       //    the field descriptor.
151       // 3- the extension submessage fails to be returned, because no class has
152       //    been created.
153       // It happens when deserializing text proto format, or when enumerating
154       // fields of a deserialized message.
155       CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
156           cmessage::GetFactoryForMessage(self->parent),
157           descriptor->message_type());
158       ScopedPyObjectPtr message_class_handler(
159         reinterpret_cast<PyObject*>(message_class));
160       if (message_class == nullptr) {
161         return nullptr;
162       }
163       ContainerBase* py_container = repeated_composite_container::NewContainer(
164           self->parent, descriptor, message_class);
165       if (py_container == nullptr) {
166         return nullptr;
167       }
168       (*self->parent->composite_fields)[descriptor] = py_container;
169       return py_container->AsPyObject();
170     } else {
171       ContainerBase* py_container = repeated_scalar_container::NewContainer(
172           self->parent, descriptor);
173       if (py_container == nullptr) {
174         return nullptr;
175       }
176       (*self->parent->composite_fields)[descriptor] = py_container;
177       return py_container->AsPyObject();
178     }
179   }
180   PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
181   return nullptr;
182 }
183 
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)184 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
185   const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
186   if (descriptor == nullptr) {
187     return -1;
188   }
189   if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
190     return -1;
191   }
192 
193   if (value == nullptr) {
194     return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
195   }
196 
197   if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
198       descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
199     PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
200                     "type");
201     return -1;
202   }
203   cmessage::AssureWritable(self->parent);
204   if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
205     return -1;
206   }
207   return 0;
208 }
209 
_FindExtensionByName(ExtensionDict * self,PyObject * arg)210 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
211   char* name;
212   Py_ssize_t name_size;
213   if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
214     return nullptr;
215   }
216 
217   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
218   const FieldDescriptor* message_extension =
219       pool->pool->FindExtensionByName(absl::string_view(name, name_size));
220   if (message_extension == nullptr) {
221     // Is is the name of a message set extension?
222     const Descriptor* message_descriptor =
223         pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
224     if (message_descriptor && message_descriptor->extension_count() > 0) {
225       const FieldDescriptor* extension = message_descriptor->extension(0);
226       if (extension->is_extension() &&
227           extension->containing_type()->options().message_set_wire_format() &&
228           extension->type() == FieldDescriptor::TYPE_MESSAGE &&
229           extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
230         message_extension = extension;
231       }
232     }
233   }
234   if (message_extension == nullptr) {
235     Py_RETURN_NONE;
236   }
237 
238   return PyFieldDescriptor_FromDescriptor(message_extension);
239 }
240 
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)241 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
242   int64_t number = PyLong_AsLong(arg);
243   if (number == -1 && PyErr_Occurred()) {
244     return nullptr;
245   }
246 
247   PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
248   const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
249       self->parent->message->GetDescriptor(), number);
250   if (message_extension == nullptr) {
251     Py_RETURN_NONE;
252   }
253 
254   return PyFieldDescriptor_FromDescriptor(message_extension);
255 }
256 
Contains(PyObject * _self,PyObject * key)257 static int Contains(PyObject* _self, PyObject* key) {
258   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
259   const FieldDescriptor* field_descriptor =
260       cmessage::GetExtensionDescriptor(key);
261   if (field_descriptor == nullptr) {
262     return -1;
263   }
264 
265   if (!field_descriptor->is_extension()) {
266     PyErr_Format(PyExc_KeyError, "%s is not an extension",
267                  std::string(field_descriptor->full_name()).c_str());
268     return -1;
269   }
270 
271   const Message* message = self->parent->message;
272   const Reflection* reflection = message->GetReflection();
273   if (field_descriptor->is_repeated()) {
274     if (reflection->FieldSize(*message, field_descriptor) > 0) {
275       return 1;
276     }
277   } else {
278     if (reflection->HasField(*message, field_descriptor)) {
279       return 1;
280     }
281   }
282 
283   return 0;
284 }
285 
NewExtensionDict(CMessage * parent)286 ExtensionDict* NewExtensionDict(CMessage *parent) {
287   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
288       PyType_GenericAlloc(&ExtensionDict_Type, 0));
289   if (self == nullptr) {
290     return nullptr;
291   }
292 
293   Py_INCREF(parent);
294   self->parent = parent;
295   return self;
296 }
297 
dealloc(PyObject * pself)298 void dealloc(PyObject* pself) {
299   ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
300   Py_CLEAR(self->parent);
301   Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
302 }
303 
RichCompare(ExtensionDict * self,PyObject * other,int opid)304 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
305   // Only equality comparisons are implemented.
306   if (opid != Py_EQ && opid != Py_NE) {
307     Py_INCREF(Py_NotImplemented);
308     return Py_NotImplemented;
309   }
310   bool equals = false;
311   if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
312     equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;
313   }
314   if (equals ^ (opid == Py_EQ)) {
315     Py_RETURN_FALSE;
316   } else {
317     Py_RETURN_TRUE;
318   }
319 }
320 static PySequenceMethods SeqMethods = {
321     (lenfunc)len,          // sq_length
322     nullptr,               // sq_concat
323     nullptr,               // sq_repeat
324     nullptr,               // sq_item
325     nullptr,               // sq_slice
326     nullptr,               // sq_ass_item
327     nullptr,               // sq_ass_slice
328     (objobjproc)Contains,  // sq_contains
329 };
330 
331 static PyMappingMethods MpMethods = {
332   (lenfunc)len,                /* mp_length */
333   (binaryfunc)subscript,       /* mp_subscript */
334   (objobjargproc)ass_subscript,/* mp_ass_subscript */
335 };
336 
337 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
338 static PyMethodDef Methods[] = {
339     EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
340     EDMETHOD(_FindExtensionByNumber, METH_O,
341              "Finds an extension by field number."),
342     {nullptr, nullptr},
343 };
344 
345 }  // namespace extension_dict
346 
347 PyTypeObject ExtensionDict_Type = {
348     PyVarObject_HEAD_INIT(&PyType_Type, 0)  //
349     FULL_MODULE_NAME ".ExtensionDict",      // tp_name
350     sizeof(ExtensionDict),                  // tp_basicsize
351     0,                                      //  tp_itemsize
352     (destructor)extension_dict::dealloc,    //  tp_dealloc
353 #if PY_VERSION_HEX < 0x03080000
354     nullptr,  // tp_print
355 #else
356     0,  // tp_vectorcall_offset
357 #endif
358     nullptr,                                   //  tp_getattr
359     nullptr,                                   //  tp_setattr
360     nullptr,                                   //  tp_compare
361     nullptr,                                   //  tp_repr
362     nullptr,                                   //  tp_as_number
363     &extension_dict::SeqMethods,               //  tp_as_sequence
364     &extension_dict::MpMethods,                //  tp_as_mapping
365     PyObject_HashNotImplemented,               //  tp_hash
366     nullptr,                                   //  tp_call
367     nullptr,                                   //  tp_str
368     nullptr,                                   //  tp_getattro
369     nullptr,                                   //  tp_setattro
370     nullptr,                                   //  tp_as_buffer
371     Py_TPFLAGS_DEFAULT,                        //  tp_flags
372     "An extension dict",                       //  tp_doc
373     nullptr,                                   //  tp_traverse
374     nullptr,                                   //  tp_clear
375     (richcmpfunc)extension_dict::RichCompare,  //  tp_richcompare
376     0,                                         //  tp_weaklistoffset
377     extension_dict::GetIter,                   //  tp_iter
378     nullptr,                                   //  tp_iternext
379     extension_dict::Methods,                   //  tp_methods
380     nullptr,                                   //  tp_members
381     nullptr,                                   //  tp_getset
382     nullptr,                                   //  tp_base
383     nullptr,                                   //  tp_dict
384     nullptr,                                   //  tp_descr_get
385     nullptr,                                   //  tp_descr_set
386     0,                                         //  tp_dictoffset
387     nullptr,                                   //  tp_init
388 };
389 
IterNext(PyObject * _self)390 PyObject* IterNext(PyObject* _self) {
391   extension_dict::ExtensionIterator* self =
392       reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
393   Py_ssize_t total_size = self->fields.size();
394   Py_ssize_t index = self->index;
395   while (self->index < total_size) {
396     index = self->index;
397     ++self->index;
398     if (self->fields[index]->is_extension()) {
399       // With C++ descriptors, the field can always be retrieved, but for
400       // unknown extensions which have not been imported in Python code, there
401       // is no message class and we cannot retrieve the value.
402       // ListFields() has the same behavior.
403       if (self->fields[index]->message_type() != nullptr &&
404           message_factory::GetMessageClass(
405               cmessage::GetFactoryForMessage(self->extension_dict->parent),
406               self->fields[index]->message_type()) == nullptr) {
407         PyErr_Clear();
408         continue;
409       }
410 
411       return PyFieldDescriptor_FromDescriptor(self->fields[index]);
412     }
413   }
414 
415   return nullptr;
416 }
417 
418 PyTypeObject ExtensionIterator_Type = {
419     PyVarObject_HEAD_INIT(&PyType_Type, 0)      //
420     FULL_MODULE_NAME ".ExtensionIterator",      //  tp_name
421     sizeof(extension_dict::ExtensionIterator),  //  tp_basicsize
422     0,                                          //  tp_itemsize
423     extension_dict::DeallocExtensionIterator,   //  tp_dealloc
424 #if PY_VERSION_HEX < 0x03080000
425     nullptr,  // tp_print
426 #else
427     0,  // tp_vectorcall_offset
428 #endif
429     nullptr,                  //  tp_getattr
430     nullptr,                  //  tp_setattr
431     nullptr,                  //  tp_compare
432     nullptr,                  //  tp_repr
433     nullptr,                  //  tp_as_number
434     nullptr,                  //  tp_as_sequence
435     nullptr,                  //  tp_as_mapping
436     nullptr,                  //  tp_hash
437     nullptr,                  //  tp_call
438     nullptr,                  //  tp_str
439     nullptr,                  //  tp_getattro
440     nullptr,                  //  tp_setattro
441     nullptr,                  //  tp_as_buffer
442     Py_TPFLAGS_DEFAULT,       //  tp_flags
443     "A scalar map iterator",  //  tp_doc
444     nullptr,                  //  tp_traverse
445     nullptr,                  //  tp_clear
446     nullptr,                  //  tp_richcompare
447     0,                        //  tp_weaklistoffset
448     PyObject_SelfIter,        //  tp_iter
449     IterNext,                 //  tp_iternext
450     nullptr,                  //  tp_methods
451     nullptr,                  //  tp_members
452     nullptr,                  //  tp_getset
453     nullptr,                  //  tp_base
454     nullptr,                  //  tp_dict
455     nullptr,                  //  tp_descr_get
456     nullptr,                  //  tp_descr_set
457     0,                        //  tp_dictoffset
458     nullptr,                  //  tp_init
459 };
460 }  // namespace python
461 }  // namespace protobuf
462 }  // namespace google
463