1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: anuraag@google.com (Anuraag Agrawal)
9 // Author: tibell@google.com (Johan Tibell)
10
11 #include "google/protobuf/pyext/extension_dict.h"
12
13 #include <cstdint>
14 #include <memory>
15 #include <vector>
16
17 #include "google/protobuf/descriptor.pb.h"
18 #include "google/protobuf/descriptor.h"
19 #include "google/protobuf/dynamic_message.h"
20 #include "google/protobuf/message.h"
21 #include "google/protobuf/pyext/descriptor.h"
22 #include "google/protobuf/pyext/message.h"
23 #include "google/protobuf/pyext/message_factory.h"
24 #include "google/protobuf/pyext/repeated_composite_container.h"
25 #include "google/protobuf/pyext/repeated_scalar_container.h"
26 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
27 #include "absl/strings/string_view.h"
28
// Fetches a character buffer and its length from a Python object.
//
// When `ob` is a str, the UTF-8 buffer returned by PyUnicode_AsUTF8AndSize()
// is stored in `*charpp` (its length in `*sizep`) and the expression is 0, or
// -1 if that call returned nullptr. Any other object is handed to
// PyBytes_AsStringAndSize() and its result is used as-is.
//
// Note: `ob` is evaluated more than once, so pass an expression without side
// effects.
#define PyString_AsStringAndSize(ob, charpp, sizep)                 \
  (!PyUnicode_Check(ob)                                             \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))             \
       : ((*(charpp) = const_cast<char*>(                           \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != nullptr    \
              ? 0                                                   \
              : -1))
36
37 namespace google {
38 namespace protobuf {
39 namespace python {
40
41 namespace extension_dict {
42
len(ExtensionDict * self)43 static Py_ssize_t len(ExtensionDict* self) {
44 Py_ssize_t size = 0;
45 std::vector<const FieldDescriptor*> fields;
46 self->parent->message->GetReflection()->ListFields(*self->parent->message,
47 &fields);
48
49 for (size_t i = 0; i < fields.size(); ++i) {
50 if (fields[i]->is_extension()) {
51 // With C++ descriptors, the field can always be retrieved, but for
52 // unknown extensions which have not been imported in Python code, there
53 // is no message class and we cannot retrieve the value.
54 // ListFields() has the same behavior.
55 if (fields[i]->message_type() != nullptr &&
56 message_factory::GetMessageClass(
57 cmessage::GetFactoryForMessage(self->parent),
58 fields[i]->message_type()) == nullptr) {
59 PyErr_Clear();
60 continue;
61 }
62 ++size;
63 }
64 }
65 return size;
66 }
67
68 struct ExtensionIterator {
69 PyObject_HEAD;
70 Py_ssize_t index;
71 std::vector<const FieldDescriptor*> fields;
72
73 // Owned reference, to keep the FieldDescriptors alive.
74 ExtensionDict* extension_dict;
75 };
76
GetIter(PyObject * _self)77 PyObject* GetIter(PyObject* _self) {
78 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
79
80 ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
81 if (obj == nullptr) {
82 return PyErr_Format(PyExc_MemoryError,
83 "Could not allocate extension iterator");
84 }
85
86 ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
87
88 // Call "placement new" to initialize. So the constructor of
89 // std::vector<...> fields will be called.
90 new (iter) ExtensionIterator;
91
92 self->parent->message->GetReflection()->ListFields(*self->parent->message,
93 &iter->fields);
94 iter->index = 0;
95 Py_INCREF(self);
96 iter->extension_dict = self;
97
98 return obj.release();
99 }
100
DeallocExtensionIterator(PyObject * _self)101 static void DeallocExtensionIterator(PyObject* _self) {
102 ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
103 self->fields.clear();
104 Py_XDECREF(self->extension_dict);
105 freefunc tp_free = Py_TYPE(_self)->tp_free;
106 self->~ExtensionIterator();
107 (*tp_free)(_self);
108 }
109
subscript(ExtensionDict * self,PyObject * key)110 PyObject* subscript(ExtensionDict* self, PyObject* key) {
111 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
112 if (descriptor == nullptr) {
113 return nullptr;
114 }
115 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
116 return nullptr;
117 }
118
119 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
120 descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
121 return cmessage::InternalGetScalar(self->parent->message, descriptor);
122 }
123
124 CMessage::CompositeFieldsMap::iterator iterator =
125 self->parent->composite_fields->find(descriptor);
126 if (iterator != self->parent->composite_fields->end()) {
127 Py_INCREF(iterator->second);
128 return iterator->second->AsPyObject();
129 }
130
131 if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
132 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
133 // TODO: consider building the class on the fly!
134 ContainerBase* sub_message = cmessage::InternalGetSubMessage(
135 self->parent, descriptor);
136 if (sub_message == nullptr) {
137 return nullptr;
138 }
139 (*self->parent->composite_fields)[descriptor] = sub_message;
140 return sub_message->AsPyObject();
141 }
142
143 if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
144 if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
145 // On the fly message class creation is needed to support the following
146 // situation:
147 // 1- add FileDescriptor to the pool that contains extensions of a message
148 // defined by another proto file. Do not create any message classes.
149 // 2- instantiate an extended message, and access the extension using
150 // the field descriptor.
151 // 3- the extension submessage fails to be returned, because no class has
152 // been created.
153 // It happens when deserializing text proto format, or when enumerating
154 // fields of a deserialized message.
155 CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
156 cmessage::GetFactoryForMessage(self->parent),
157 descriptor->message_type());
158 ScopedPyObjectPtr message_class_handler(
159 reinterpret_cast<PyObject*>(message_class));
160 if (message_class == nullptr) {
161 return nullptr;
162 }
163 ContainerBase* py_container = repeated_composite_container::NewContainer(
164 self->parent, descriptor, message_class);
165 if (py_container == nullptr) {
166 return nullptr;
167 }
168 (*self->parent->composite_fields)[descriptor] = py_container;
169 return py_container->AsPyObject();
170 } else {
171 ContainerBase* py_container = repeated_scalar_container::NewContainer(
172 self->parent, descriptor);
173 if (py_container == nullptr) {
174 return nullptr;
175 }
176 (*self->parent->composite_fields)[descriptor] = py_container;
177 return py_container->AsPyObject();
178 }
179 }
180 PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
181 return nullptr;
182 }
183
ass_subscript(ExtensionDict * self,PyObject * key,PyObject * value)184 int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
185 const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
186 if (descriptor == nullptr) {
187 return -1;
188 }
189 if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
190 return -1;
191 }
192
193 if (value == nullptr) {
194 return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
195 }
196
197 if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
198 descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
199 PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
200 "type");
201 return -1;
202 }
203 cmessage::AssureWritable(self->parent);
204 if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
205 return -1;
206 }
207 return 0;
208 }
209
_FindExtensionByName(ExtensionDict * self,PyObject * arg)210 PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
211 char* name;
212 Py_ssize_t name_size;
213 if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
214 return nullptr;
215 }
216
217 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
218 const FieldDescriptor* message_extension =
219 pool->pool->FindExtensionByName(absl::string_view(name, name_size));
220 if (message_extension == nullptr) {
221 // Is is the name of a message set extension?
222 const Descriptor* message_descriptor =
223 pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
224 if (message_descriptor && message_descriptor->extension_count() > 0) {
225 const FieldDescriptor* extension = message_descriptor->extension(0);
226 if (extension->is_extension() &&
227 extension->containing_type()->options().message_set_wire_format() &&
228 extension->type() == FieldDescriptor::TYPE_MESSAGE &&
229 extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
230 message_extension = extension;
231 }
232 }
233 }
234 if (message_extension == nullptr) {
235 Py_RETURN_NONE;
236 }
237
238 return PyFieldDescriptor_FromDescriptor(message_extension);
239 }
240
_FindExtensionByNumber(ExtensionDict * self,PyObject * arg)241 PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
242 int64_t number = PyLong_AsLong(arg);
243 if (number == -1 && PyErr_Occurred()) {
244 return nullptr;
245 }
246
247 PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
248 const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
249 self->parent->message->GetDescriptor(), number);
250 if (message_extension == nullptr) {
251 Py_RETURN_NONE;
252 }
253
254 return PyFieldDescriptor_FromDescriptor(message_extension);
255 }
256
Contains(PyObject * _self,PyObject * key)257 static int Contains(PyObject* _self, PyObject* key) {
258 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
259 const FieldDescriptor* field_descriptor =
260 cmessage::GetExtensionDescriptor(key);
261 if (field_descriptor == nullptr) {
262 return -1;
263 }
264
265 if (!field_descriptor->is_extension()) {
266 PyErr_Format(PyExc_KeyError, "%s is not an extension",
267 std::string(field_descriptor->full_name()).c_str());
268 return -1;
269 }
270
271 const Message* message = self->parent->message;
272 const Reflection* reflection = message->GetReflection();
273 if (field_descriptor->is_repeated()) {
274 if (reflection->FieldSize(*message, field_descriptor) > 0) {
275 return 1;
276 }
277 } else {
278 if (reflection->HasField(*message, field_descriptor)) {
279 return 1;
280 }
281 }
282
283 return 0;
284 }
285
NewExtensionDict(CMessage * parent)286 ExtensionDict* NewExtensionDict(CMessage *parent) {
287 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
288 PyType_GenericAlloc(&ExtensionDict_Type, 0));
289 if (self == nullptr) {
290 return nullptr;
291 }
292
293 Py_INCREF(parent);
294 self->parent = parent;
295 return self;
296 }
297
dealloc(PyObject * pself)298 void dealloc(PyObject* pself) {
299 ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
300 Py_CLEAR(self->parent);
301 Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
302 }
303
RichCompare(ExtensionDict * self,PyObject * other,int opid)304 static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
305 // Only equality comparisons are implemented.
306 if (opid != Py_EQ && opid != Py_NE) {
307 Py_INCREF(Py_NotImplemented);
308 return Py_NotImplemented;
309 }
310 bool equals = false;
311 if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
312 equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;
313 }
314 if (equals ^ (opid == Py_EQ)) {
315 Py_RETURN_FALSE;
316 } else {
317 Py_RETURN_TRUE;
318 }
319 }
320 static PySequenceMethods SeqMethods = {
321 (lenfunc)len, // sq_length
322 nullptr, // sq_concat
323 nullptr, // sq_repeat
324 nullptr, // sq_item
325 nullptr, // sq_slice
326 nullptr, // sq_ass_item
327 nullptr, // sq_ass_slice
328 (objobjproc)Contains, // sq_contains
329 };
330
331 static PyMappingMethods MpMethods = {
332 (lenfunc)len, /* mp_length */
333 (binaryfunc)subscript, /* mp_subscript */
334 (objobjargproc)ass_subscript,/* mp_ass_subscript */
335 };
336
337 #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
338 static PyMethodDef Methods[] = {
339 EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
340 EDMETHOD(_FindExtensionByNumber, METH_O,
341 "Finds an extension by field number."),
342 {nullptr, nullptr},
343 };
344
345 } // namespace extension_dict
346
347 PyTypeObject ExtensionDict_Type = {
348 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
349 FULL_MODULE_NAME ".ExtensionDict", // tp_name
350 sizeof(ExtensionDict), // tp_basicsize
351 0, // tp_itemsize
352 (destructor)extension_dict::dealloc, // tp_dealloc
353 #if PY_VERSION_HEX < 0x03080000
354 nullptr, // tp_print
355 #else
356 0, // tp_vectorcall_offset
357 #endif
358 nullptr, // tp_getattr
359 nullptr, // tp_setattr
360 nullptr, // tp_compare
361 nullptr, // tp_repr
362 nullptr, // tp_as_number
363 &extension_dict::SeqMethods, // tp_as_sequence
364 &extension_dict::MpMethods, // tp_as_mapping
365 PyObject_HashNotImplemented, // tp_hash
366 nullptr, // tp_call
367 nullptr, // tp_str
368 nullptr, // tp_getattro
369 nullptr, // tp_setattro
370 nullptr, // tp_as_buffer
371 Py_TPFLAGS_DEFAULT, // tp_flags
372 "An extension dict", // tp_doc
373 nullptr, // tp_traverse
374 nullptr, // tp_clear
375 (richcmpfunc)extension_dict::RichCompare, // tp_richcompare
376 0, // tp_weaklistoffset
377 extension_dict::GetIter, // tp_iter
378 nullptr, // tp_iternext
379 extension_dict::Methods, // tp_methods
380 nullptr, // tp_members
381 nullptr, // tp_getset
382 nullptr, // tp_base
383 nullptr, // tp_dict
384 nullptr, // tp_descr_get
385 nullptr, // tp_descr_set
386 0, // tp_dictoffset
387 nullptr, // tp_init
388 };
389
IterNext(PyObject * _self)390 PyObject* IterNext(PyObject* _self) {
391 extension_dict::ExtensionIterator* self =
392 reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
393 Py_ssize_t total_size = self->fields.size();
394 Py_ssize_t index = self->index;
395 while (self->index < total_size) {
396 index = self->index;
397 ++self->index;
398 if (self->fields[index]->is_extension()) {
399 // With C++ descriptors, the field can always be retrieved, but for
400 // unknown extensions which have not been imported in Python code, there
401 // is no message class and we cannot retrieve the value.
402 // ListFields() has the same behavior.
403 if (self->fields[index]->message_type() != nullptr &&
404 message_factory::GetMessageClass(
405 cmessage::GetFactoryForMessage(self->extension_dict->parent),
406 self->fields[index]->message_type()) == nullptr) {
407 PyErr_Clear();
408 continue;
409 }
410
411 return PyFieldDescriptor_FromDescriptor(self->fields[index]);
412 }
413 }
414
415 return nullptr;
416 }
417
418 PyTypeObject ExtensionIterator_Type = {
419 PyVarObject_HEAD_INIT(&PyType_Type, 0) //
420 FULL_MODULE_NAME ".ExtensionIterator", // tp_name
421 sizeof(extension_dict::ExtensionIterator), // tp_basicsize
422 0, // tp_itemsize
423 extension_dict::DeallocExtensionIterator, // tp_dealloc
424 #if PY_VERSION_HEX < 0x03080000
425 nullptr, // tp_print
426 #else
427 0, // tp_vectorcall_offset
428 #endif
429 nullptr, // tp_getattr
430 nullptr, // tp_setattr
431 nullptr, // tp_compare
432 nullptr, // tp_repr
433 nullptr, // tp_as_number
434 nullptr, // tp_as_sequence
435 nullptr, // tp_as_mapping
436 nullptr, // tp_hash
437 nullptr, // tp_call
438 nullptr, // tp_str
439 nullptr, // tp_getattro
440 nullptr, // tp_setattro
441 nullptr, // tp_as_buffer
442 Py_TPFLAGS_DEFAULT, // tp_flags
443 "A scalar map iterator", // tp_doc
444 nullptr, // tp_traverse
445 nullptr, // tp_clear
446 nullptr, // tp_richcompare
447 0, // tp_weaklistoffset
448 PyObject_SelfIter, // tp_iter
449 IterNext, // tp_iternext
450 nullptr, // tp_methods
451 nullptr, // tp_members
452 nullptr, // tp_getset
453 nullptr, // tp_base
454 nullptr, // tp_dict
455 nullptr, // tp_descr_get
456 nullptr, // tp_descr_set
457 0, // tp_dictoffset
458 nullptr, // tp_init
459 };
460 } // namespace python
461 } // namespace protobuf
462 } // namespace google
463