1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: anuraag@google.com (Anuraag Agrawal)
32 // Author: tibell@google.com (Johan Tibell)
33
34 #include <google/protobuf/pyext/extension_dict.h>
35
36 #include <cstdint>
37 #include <memory>
38
39 #include <google/protobuf/stubs/logging.h>
40 #include <google/protobuf/stubs/common.h>
41 #include <google/protobuf/descriptor.pb.h>
42 #include <google/protobuf/descriptor.h>
43 #include <google/protobuf/dynamic_message.h>
44 #include <google/protobuf/message.h>
45 #include <google/protobuf/pyext/descriptor.h>
46 #include <google/protobuf/pyext/message.h>
47 #include <google/protobuf/pyext/message_factory.h>
48 #include <google/protobuf/pyext/repeated_composite_container.h>
49 #include <google/protobuf/pyext/repeated_scalar_container.h>
50 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
51
// Extracts a character buffer and its length from either a unicode or a bytes
// object.  Unicode objects go through PyUnicode_AsUTF8AndSize; anything else
// is handed to PyBytes_AsStringAndSize.  Evaluates to 0 on success and to a
// negative value on failure.
#define PyString_AsStringAndSize(ob, charpp, sizep)                 \
  (!PyUnicode_Check(ob)                                             \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))             \
       : ((*(charpp) = const_cast<char*>(                           \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != nullptr    \
              ? 0                                                   \
              : -1))
59
60 namespace google {
61 namespace protobuf {
62 namespace python {
63
64 namespace extension_dict {
65
len(ExtensionDict * self)66 static Py_ssize_t len(ExtensionDict* self) {
67 Py_ssize_t size = 0;
68 std::vector<const FieldDescriptor*> fields;
69 self->parent->message->GetReflection()->ListFields(*self->parent->message,
70 &fields);
71
72 for (size_t i = 0; i < fields.size(); ++i) {
73 if (fields[i]->is_extension()) {
74 // With C++ descriptors, the field can always be retrieved, but for
75 // unknown extensions which have not been imported in Python code, there
76 // is no message class and we cannot retrieve the value.
77 // ListFields() has the same behavior.
78 if (fields[i]->message_type() != nullptr &&
79 message_factory::GetMessageClass(
80 cmessage::GetFactoryForMessage(self->parent),
81 fields[i]->message_type()) == nullptr) {
82 PyErr_Clear();
83 continue;
84 }
85 ++size;
86 }
87 }
88 return size;
89 }
90
91 struct ExtensionIterator {
92 PyObject_HEAD;
93 Py_ssize_t index;
94 std::vector<const FieldDescriptor*> fields;
95
96 // Owned reference, to keep the FieldDescriptors alive.
97 ExtensionDict* extension_dict;
98 };
99
GetIter(PyObject * _self)100 PyObject* GetIter(PyObject* _self) {
101 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
102
103 ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
104 if (obj == nullptr) {
105 return PyErr_Format(PyExc_MemoryError,
106 "Could not allocate extension iterator");
107 }
108
109 ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
110
111 // Call "placement new" to initialize. So the constructor of
112 // std::vector<...> fields will be called.
113 new (iter) ExtensionIterator;
114
115 self->parent->message->GetReflection()->ListFields(*self->parent->message,
116 &iter->fields);
117 iter->index = 0;
118 Py_INCREF(self);
119 iter->extension_dict = self;
120
121 return obj.release();
122 }
123
DeallocExtensionIterator(PyObject * _self)124 static void DeallocExtensionIterator(PyObject* _self) {
125 ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
126 self->fields.clear();
127 Py_XDECREF(self->extension_dict);
128 self->~ExtensionIterator();
129 Py_TYPE(_self)->tp_free(_self);
130 }
131
subscript(ExtensionDict * self,PyObject * key)132 PyObject* subscript(ExtensionDict* self, PyObject* key) {
133 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
134 if (descriptor == nullptr) {
135 return nullptr;
136 }
137 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
138 return nullptr;
139 }
140
141 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
142 descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
143 return cmessage::InternalGetScalar(self->parent->message, descriptor);
144 }
145
146 CMessage::CompositeFieldsMap::iterator iterator =
147 self->parent->composite_fields->find(descriptor);
148 if (iterator != self->parent->composite_fields->end()) {
149 Py_INCREF(iterator->second);
150 return iterator->second->AsPyObject();
151 }
152
153 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
154 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
155 // TODO(plabatut): consider building the class on the fly!
156 ContainerBase* sub_message = cmessage::InternalGetSubMessage(
157 self->parent, descriptor);
158 if (sub_message == nullptr) {
159 return nullptr;
160 }
161 (*self->parent->composite_fields)[descriptor] = sub_message;
162 return sub_message->AsPyObject();
163 }
164
165 if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
166 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
167 // On the fly message class creation is needed to support the following
168 // situation:
169 // 1- add FileDescriptor to the pool that contains extensions of a message
170 // defined by another proto file. Do not create any message classes.
171 // 2- instantiate an extended message, and access the extension using
172 // the field descriptor.
173 // 3- the extension submessage fails to be returned, because no class has
174 // been created.
175 // It happens when deserializing text proto format, or when enumerating
176 // fields of a deserialized message.
177 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
178 cmessage::GetFactoryForMessage(self->parent),
179 descriptor->message_type());
180 ScopedPyObjectPtr message_class_handler(
181 reinterpret_cast<PyObject*>(message_class));
182 if (message_class == nullptr) {
183 return nullptr;
184 }
185 ContainerBase* py_container = repeated_composite_container::NewContainer(
186 self->parent, descriptor, message_class);
187 if (py_container == nullptr) {
188 return nullptr;
189 }
190 (*self->parent->composite_fields)[descriptor] = py_container;
191 return py_container->AsPyObject();
192 } else {
193 ContainerBase* py_container = repeated_scalar_container::NewContainer(
194 self->parent, descriptor);
195 if (py_container == nullptr) {
196 return nullptr;
197 }
198 (*self->parent->composite_fields)[descriptor] = py_container;
199 return py_container->AsPyObject();
200 }
201 }
202 PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
203 return nullptr;
204 }
205
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)206 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
207 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
208 if (descriptor == nullptr) {
209 return -1;
210 }
211 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
212 return -1;
213 }
214
215 if (value == nullptr) {
216 return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
217 }
218
219 if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
220 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
221 PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
222 "type");
223 return -1;
224 }
225 cmessage::AssureWritable(self->parent);
226 if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
227 return -1;
228 }
229 return 0;
230 }
231
_FindExtensionByName(ExtensionDict * self,PyObject * arg)232 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
233 char* name;
234 Py_ssize_t name_size;
235 if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
236 return nullptr;
237 }
238
239 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
240 const FieldDescriptor* message_extension =
241 pool->pool->FindExtensionByName(StringParam(name, name_size));
242 if (message_extension == nullptr) {
243 // Is is the name of a message set extension?
244 const Descriptor* message_descriptor =
245 pool->pool->FindMessageTypeByName(StringParam(name, name_size));
246 if (message_descriptor && message_descriptor->extension_count() > 0) {
247 const FieldDescriptor* extension = message_descriptor->extension(0);
248 if (extension->is_extension() &&
249 extension->containing_type()->options().message_set_wire_format() &&
250 extension->type() == FieldDescriptor::TYPE_MESSAGE &&
251 extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
252 message_extension = extension;
253 }
254 }
255 }
256 if (message_extension == nullptr) {
257 Py_RETURN_NONE;
258 }
259
260 return PyFieldDescriptor_FromDescriptor(message_extension);
261 }
262
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)263 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
264 int64_t number = PyLong_AsLong(arg);
265 if (number == -1 && PyErr_Occurred()) {
266 return nullptr;
267 }
268
269 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
270 const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
271 self->parent->message->GetDescriptor(), number);
272 if (message_extension == nullptr) {
273 Py_RETURN_NONE;
274 }
275
276 return PyFieldDescriptor_FromDescriptor(message_extension);
277 }
278
Contains(PyObject * _self,PyObject * key)279 static int Contains(PyObject* _self, PyObject* key) {
280 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
281 const FieldDescriptor* field_descriptor =
282 cmessage::GetExtensionDescriptor(key);
283 if (field_descriptor == nullptr) {
284 return -1;
285 }
286
287 if (!field_descriptor->is_extension()) {
288 PyErr_Format(PyExc_KeyError, "%s is not an extension",
289 field_descriptor->full_name().c_str());
290 return -1;
291 }
292
293 const Message* message = self->parent->message;
294 const Reflection* reflection = message->GetReflection();
295 if (field_descriptor->is_repeated()) {
296 if (reflection->FieldSize(*message, field_descriptor) > 0) {
297 return 1;
298 }
299 } else {
300 if (reflection->HasField(*message, field_descriptor)) {
301 return 1;
302 }
303 }
304
305 return 0;
306 }
307
NewExtensionDict(CMessage * parent)308 ExtensionDict* NewExtensionDict(CMessage *parent) {
309 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
310 PyType_GenericAlloc(&ExtensionDict_Type, 0));
311 if (self == nullptr) {
312 return nullptr;
313 }
314
315 Py_INCREF(parent);
316 self->parent = parent;
317 return self;
318 }
319
dealloc(PyObject * pself)320 void dealloc(PyObject* pself) {
321 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
322 Py_CLEAR(self->parent);
323 Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
324 }
325
RichCompare(ExtensionDict * self,PyObject * other,int opid)326 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
327 // Only equality comparisons are implemented.
328 if (opid != Py_EQ && opid != Py_NE) {
329 Py_INCREF(Py_NotImplemented);
330 return Py_NotImplemented;
331 }
332 bool equals = false;
333 if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
334 equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
335 }
336 if (equals ^ (opid == Py_EQ)) {
337 Py_RETURN_FALSE;
338 } else {
339 Py_RETURN_TRUE;
340 }
341 }
342 static PySequenceMethods SeqMethods = {
343 (lenfunc)len, // sq_length
344 nullptr, // sq_concat
345 nullptr, // sq_repeat
346 nullptr, // sq_item
347 nullptr, // sq_slice
348 nullptr, // sq_ass_item
349 nullptr, // sq_ass_slice
350 (objobjproc)Contains, // sq_contains
351 };
352
353 static PyMappingMethods MpMethods = {
354 (lenfunc)len, /* mp_length */
355 (binaryfunc)subscript, /* mp_subscript */
356 (objobjargproc)ass_subscript,/* mp_ass_subscript */
357 };
358
359 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
360 static PyMethodDef Methods[] = {
361 EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
362 EDMETHOD(_FindExtensionByNumber, METH_O,
363 "Finds an extension by field number."),
364 {nullptr, nullptr},
365 };
366
367 } // namespace extension_dict
368
369 PyTypeObject ExtensionDict_Type = {
370 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
371 FULL_MODULE_NAME ".ExtensionDict", // tp_name
372 sizeof(ExtensionDict), // tp_basicsize
373 0, // tp_itemsize
374 (destructor)extension_dict::dealloc, // tp_dealloc
375 #if PY_VERSION_HEX < 0x03080000
376 nullptr, // tp_print
377 #else
378 0, // tp_vectorcall_offset
379 #endif
380 nullptr, // tp_getattr
381 nullptr, // tp_setattr
382 nullptr, // tp_compare
383 nullptr, // tp_repr
384 nullptr, // tp_as_number
385 &extension_dict::SeqMethods, // tp_as_sequence
386 &extension_dict::MpMethods, // tp_as_mapping
387 PyObject_HashNotImplemented, // tp_hash
388 nullptr, // tp_call
389 nullptr, // tp_str
390 nullptr, // tp_getattro
391 nullptr, // tp_setattro
392 nullptr, // tp_as_buffer
393 Py_TPFLAGS_DEFAULT, // tp_flags
394 "An extension dict", // tp_doc
395 nullptr, // tp_traverse
396 nullptr, // tp_clear
397 (richcmpfunc)extension_dict::RichCompare, // tp_richcompare
398 0, // tp_weaklistoffset
399 extension_dict::GetIter, // tp_iter
400 nullptr, // tp_iternext
401 extension_dict::Methods, // tp_methods
402 nullptr, // tp_members
403 nullptr, // tp_getset
404 nullptr, // tp_base
405 nullptr, // tp_dict
406 nullptr, // tp_descr_get
407 nullptr, // tp_descr_set
408 0, // tp_dictoffset
409 nullptr, // tp_init
410 };
411
IterNext(PyObject * _self)412 PyObject* IterNext(PyObject* _self) {
413 extension_dict::ExtensionIterator* self =
414 reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
415 Py_ssize_t total_size = self->fields.size();
416 Py_ssize_t index = self->index;
417 while (self->index < total_size) {
418 index = self->index;
419 ++self->index;
420 if (self->fields[index]->is_extension()) {
421 // With C++ descriptors, the field can always be retrieved, but for
422 // unknown extensions which have not been imported in Python code, there
423 // is no message class and we cannot retrieve the value.
424 // ListFields() has the same behavior.
425 if (self->fields[index]->message_type() != nullptr &&
426 message_factory::GetMessageClass(
427 cmessage::GetFactoryForMessage(self->extension_dict->parent),
428 self->fields[index]->message_type()) == nullptr) {
429 PyErr_Clear();
430 continue;
431 }
432
433 return PyFieldDescriptor_FromDescriptor(self->fields[index]);
434 }
435 }
436
437 return nullptr;
438 }
439
440 PyTypeObject ExtensionIterator_Type = {
441 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
442 FULL_MODULE_NAME ".ExtensionIterator", // tp_name
443 sizeof(extension_dict::ExtensionIterator), // tp_basicsize
444 0, // tp_itemsize
445 extension_dict::DeallocExtensionIterator, // tp_dealloc
446 #if PY_VERSION_HEX < 0x03080000
447 nullptr, // tp_print
448 #else
449 0, // tp_vectorcall_offset
450 #endif
451 nullptr, // tp_getattr
452 nullptr, // tp_setattr
453 nullptr, // tp_compare
454 nullptr, // tp_repr
455 nullptr, // tp_as_number
456 nullptr, // tp_as_sequence
457 nullptr, // tp_as_mapping
458 nullptr, // tp_hash
459 nullptr, // tp_call
460 nullptr, // tp_str
461 nullptr, // tp_getattro
462 nullptr, // tp_setattro
463 nullptr, // tp_as_buffer
464 Py_TPFLAGS_DEFAULT, // tp_flags
465 "A scalar map iterator", // tp_doc
466 nullptr, // tp_traverse
467 nullptr, // tp_clear
468 nullptr, // tp_richcompare
469 0, // tp_weaklistoffset
470 PyObject_SelfIter, // tp_iter
471 IterNext, // tp_iternext
472 nullptr, // tp_methods
473 nullptr, // tp_members
474 nullptr, // tp_getset
475 nullptr, // tp_base
476 nullptr, // tp_dict
477 nullptr, // tp_descr_get
478 nullptr, // tp_descr_set
479 0, // tp_dictoffset
480 nullptr, // tp_init
481 };
482 } // namespace python
483 } // namespace protobuf
484 } // namespace google
485