1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: haberman@google.com (Josh Haberman)
9
10 #include "google/protobuf/pyext/map_container.h"
11
12 #include <cstdint>
13 #include <memory>
14 #include <string>
15
16 #include "google/protobuf/map.h"
17 #include "google/protobuf/map_field.h"
18 #include "google/protobuf/message.h"
19 #include "google/protobuf/pyext/message.h"
20 #include "google/protobuf/pyext/message_factory.h"
21 #include "google/protobuf/pyext/repeated_composite_container.h"
22 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
23
24 namespace google {
25 namespace protobuf {
26 namespace python {
27
28 // Functions that need access to map reflection functionality.
29 // They need to be contained in this class because it is friended.
30 class MapReflectionFriend {
31 public:
32 // Methods that are in common between the map types.
33 static PyObject* Contains(PyObject* _self, PyObject* key);
34 static Py_ssize_t Length(PyObject* _self);
35 static PyObject* GetIterator(PyObject *_self);
36 static PyObject* IterNext(PyObject* _self);
37 static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
38
39 // Methods that differ between the map types.
40 static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
41 static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
42 static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
43 static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
44 static PyObject* ScalarMapToStr(PyObject* _self);
45 static PyObject* MessageMapToStr(PyObject* _self);
46 };
47
48 struct MapIterator {
49 PyObject_HEAD;
50
51 std::unique_ptr<::google::protobuf::MapIterator> iter;
52
53 // A pointer back to the container, so we can notice changes to the version.
54 // We own a ref on this.
55 MapContainer* container;
56
57 // We need to keep a ref on the parent Message too, because
58 // MapIterator::~MapIterator() accesses it. Normally this would be ok because
59 // the ref on container (above) would guarantee outlive semantics. However in
60 // the case of ClearField(), the MapContainer points to a different message,
61 // a copy of the original. But our iterator still points to the original,
62 // which could now get deleted before us.
63 //
64 // To prevent this, we ensure that the Message will always stay alive as long
65 // as this iterator does. This is solely for the benefit of the MapIterator
66 // destructor -- we should never actually access the iterator in this state
67 // except to delete it.
68 CMessage* parent;
69 // The version of the map when we took the iterator to it.
70 //
71 // We store this so that if the map is modified during iteration we can throw
72 // an error.
73 uint64_t version;
74 };
75
GetMutableMessage()76 Message* MapContainer::GetMutableMessage() {
77 cmessage::AssureWritable(parent);
78 return parent->message;
79 }
80
81 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,std::string * stl_string)82 static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
83 char *value;
84 Py_ssize_t value_len;
85
86 if (!py_string) {
87 return false;
88 }
89 if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
90 Py_DECREF(py_string);
91 return false;
92 } else {
93 stl_string->assign(value, value_len);
94 Py_DECREF(py_string);
95 return true;
96 }
97 }
98
PythonToMapKey(MapContainer * self,PyObject * obj,MapKey * key,std::string * key_string)99 static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key,
100 std::string* key_string) {
101 const FieldDescriptor* field_descriptor =
102 self->parent_field_descriptor->message_type()->map_key();
103 switch (field_descriptor->cpp_type()) {
104 case FieldDescriptor::CPPTYPE_INT32: {
105 PROTOBUF_CHECK_GET_INT32(obj, value, false);
106 key->SetInt32Value(value);
107 break;
108 }
109 case FieldDescriptor::CPPTYPE_INT64: {
110 PROTOBUF_CHECK_GET_INT64(obj, value, false);
111 key->SetInt64Value(value);
112 break;
113 }
114 case FieldDescriptor::CPPTYPE_UINT32: {
115 PROTOBUF_CHECK_GET_UINT32(obj, value, false);
116 key->SetUInt32Value(value);
117 break;
118 }
119 case FieldDescriptor::CPPTYPE_UINT64: {
120 PROTOBUF_CHECK_GET_UINT64(obj, value, false);
121 key->SetUInt64Value(value);
122 break;
123 }
124 case FieldDescriptor::CPPTYPE_BOOL: {
125 PROTOBUF_CHECK_GET_BOOL(obj, value, false);
126 key->SetBoolValue(value);
127 break;
128 }
129 case FieldDescriptor::CPPTYPE_STRING: {
130 if (!PyStringToSTL(CheckString(obj, field_descriptor), key_string)) {
131 return false;
132 }
133 key->SetStringValue(*key_string);
134 break;
135 }
136 default:
137 PyErr_Format(
138 PyExc_SystemError, "Type %d cannot be a map key",
139 field_descriptor->cpp_type());
140 return false;
141 }
142 return true;
143 }
144
MapKeyToPython(MapContainer * self,const MapKey & key)145 static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
146 const FieldDescriptor* field_descriptor =
147 self->parent_field_descriptor->message_type()->map_key();
148 switch (field_descriptor->cpp_type()) {
149 case FieldDescriptor::CPPTYPE_INT32:
150 return PyLong_FromLong(key.GetInt32Value());
151 case FieldDescriptor::CPPTYPE_INT64:
152 return PyLong_FromLongLong(key.GetInt64Value());
153 case FieldDescriptor::CPPTYPE_UINT32:
154 return PyLong_FromSize_t(key.GetUInt32Value());
155 case FieldDescriptor::CPPTYPE_UINT64:
156 return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
157 case FieldDescriptor::CPPTYPE_BOOL:
158 return PyBool_FromLong(key.GetBoolValue());
159 case FieldDescriptor::CPPTYPE_STRING:
160 return ToStringObject(field_descriptor, key.GetStringValue());
161 default:
162 PyErr_Format(
163 PyExc_SystemError, "Couldn't convert type %d to value",
164 field_descriptor->cpp_type());
165 return nullptr;
166 }
167 }
168
169 // This is only used for ScalarMap, so we don't need to handle the
170 // CPPTYPE_MESSAGE case.
MapValueRefToPython(MapContainer * self,const MapValueRef & value)171 PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
172 const FieldDescriptor* field_descriptor =
173 self->parent_field_descriptor->message_type()->map_value();
174 switch (field_descriptor->cpp_type()) {
175 case FieldDescriptor::CPPTYPE_INT32:
176 return PyLong_FromLong(value.GetInt32Value());
177 case FieldDescriptor::CPPTYPE_INT64:
178 return PyLong_FromLongLong(value.GetInt64Value());
179 case FieldDescriptor::CPPTYPE_UINT32:
180 return PyLong_FromSize_t(value.GetUInt32Value());
181 case FieldDescriptor::CPPTYPE_UINT64:
182 return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
183 case FieldDescriptor::CPPTYPE_FLOAT:
184 return PyFloat_FromDouble(value.GetFloatValue());
185 case FieldDescriptor::CPPTYPE_DOUBLE:
186 return PyFloat_FromDouble(value.GetDoubleValue());
187 case FieldDescriptor::CPPTYPE_BOOL:
188 return PyBool_FromLong(value.GetBoolValue());
189 case FieldDescriptor::CPPTYPE_STRING:
190 return ToStringObject(field_descriptor, value.GetStringValue());
191 case FieldDescriptor::CPPTYPE_ENUM:
192 return PyLong_FromLong(value.GetEnumValue());
193 default:
194 PyErr_Format(
195 PyExc_SystemError, "Couldn't convert type %d to value",
196 field_descriptor->cpp_type());
197 return nullptr;
198 }
199 }
200
201 // This is only used for ScalarMap, so we don't need to handle the
202 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(MapContainer * self,PyObject * obj,bool allow_unknown_enum_values,MapValueRef * value_ref)203 static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
204 bool allow_unknown_enum_values,
205 MapValueRef* value_ref) {
206 const FieldDescriptor* field_descriptor =
207 self->parent_field_descriptor->message_type()->map_value();
208 switch (field_descriptor->cpp_type()) {
209 case FieldDescriptor::CPPTYPE_INT32: {
210 PROTOBUF_CHECK_GET_INT32(obj, value, false);
211 value_ref->SetInt32Value(value);
212 return true;
213 }
214 case FieldDescriptor::CPPTYPE_INT64: {
215 PROTOBUF_CHECK_GET_INT64(obj, value, false);
216 value_ref->SetInt64Value(value);
217 return true;
218 }
219 case FieldDescriptor::CPPTYPE_UINT32: {
220 PROTOBUF_CHECK_GET_UINT32(obj, value, false);
221 value_ref->SetUInt32Value(value);
222 return true;
223 }
224 case FieldDescriptor::CPPTYPE_UINT64: {
225 PROTOBUF_CHECK_GET_UINT64(obj, value, false);
226 value_ref->SetUInt64Value(value);
227 return true;
228 }
229 case FieldDescriptor::CPPTYPE_FLOAT: {
230 PROTOBUF_CHECK_GET_FLOAT(obj, value, false);
231 value_ref->SetFloatValue(value);
232 return true;
233 }
234 case FieldDescriptor::CPPTYPE_DOUBLE: {
235 PROTOBUF_CHECK_GET_DOUBLE(obj, value, false);
236 value_ref->SetDoubleValue(value);
237 return true;
238 }
239 case FieldDescriptor::CPPTYPE_BOOL: {
240 PROTOBUF_CHECK_GET_BOOL(obj, value, false);
241 value_ref->SetBoolValue(value);
242 return true;
243 }
244 case FieldDescriptor::CPPTYPE_STRING: {
245 std::string str;
246 if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
247 return false;
248 }
249 value_ref->SetStringValue(str);
250 return true;
251 }
252 case FieldDescriptor::CPPTYPE_ENUM: {
253 PROTOBUF_CHECK_GET_INT32(obj, value, false);
254 if (allow_unknown_enum_values) {
255 value_ref->SetEnumValue(value);
256 return true;
257 } else {
258 const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
259 const EnumValueDescriptor* enum_value =
260 enum_descriptor->FindValueByNumber(value);
261 if (enum_value != nullptr) {
262 value_ref->SetEnumValue(value);
263 return true;
264 } else {
265 PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
266 return false;
267 }
268 }
269 break;
270 }
271 default:
272 PyErr_Format(
273 PyExc_SystemError, "Setting value to a field of unknown type %d",
274 field_descriptor->cpp_type());
275 return false;
276 }
277 }
278
279 // Map methods common to ScalarMap and MessageMap //////////////////////////////
280
GetMap(PyObject * obj)281 static MapContainer* GetMap(PyObject* obj) {
282 return reinterpret_cast<MapContainer*>(obj);
283 }
284
Length(PyObject * _self)285 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
286 MapContainer* self = GetMap(_self);
287 const google::protobuf::Message* message = self->parent->message;
288 return message->GetReflection()->MapSize(*message,
289 self->parent_field_descriptor);
290 }
291
Clear(PyObject * _self)292 PyObject* Clear(PyObject* _self) {
293 MapContainer* self = GetMap(_self);
294 Message* message = self->GetMutableMessage();
295 const Reflection* reflection = message->GetReflection();
296
297 reflection->ClearField(message, self->parent_field_descriptor);
298
299 Py_RETURN_NONE;
300 }
301
GetEntryClass(PyObject * _self)302 PyObject* GetEntryClass(PyObject* _self) {
303 MapContainer* self = GetMap(_self);
304 CMessageClass* message_class = message_factory::GetMessageClass(
305 cmessage::GetFactoryForMessage(self->parent),
306 self->parent_field_descriptor->message_type());
307 Py_XINCREF(message_class);
308 return reinterpret_cast<PyObject*>(message_class);
309 }
310
MergeFrom(PyObject * _self,PyObject * arg)311 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
312 MapContainer* self = GetMap(_self);
313 if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) &&
314 !PyObject_TypeCheck(arg, MessageMapContainer_Type)) {
315 PyErr_SetString(PyExc_AttributeError, "Not a map field");
316 return nullptr;
317 }
318 MapContainer* other_map = GetMap(arg);
319 Message* message = self->GetMutableMessage();
320 const Message* other_message = other_map->parent->message;
321 const Reflection* reflection = message->GetReflection();
322 const Reflection* other_reflection = other_message->GetReflection();
323 internal::MapFieldBase* field = reflection->MutableMapData(
324 message, self->parent_field_descriptor);
325 const internal::MapFieldBase* other_field = other_reflection->GetMapData(
326 *other_message, other_map->parent_field_descriptor);
327 field->MergeFrom(*other_field);
328 self->version++;
329 Py_RETURN_NONE;
330 }
331
Contains(PyObject * _self,PyObject * key)332 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
333 MapContainer* self = GetMap(_self);
334
335 const Message* message = self->parent->message;
336 const Reflection* reflection = message->GetReflection();
337 std::string map_key_string;
338 MapKey map_key;
339
340 if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
341 return nullptr;
342 }
343
344 if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
345 map_key)) {
346 Py_RETURN_TRUE;
347 } else {
348 Py_RETURN_FALSE;
349 }
350 }
351
352 // ScalarMap ///////////////////////////////////////////////////////////////////
353
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)354 MapContainer* NewScalarMapContainer(
355 CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
356 if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
357 return nullptr;
358 }
359
360 PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
361 if (obj == nullptr) {
362 PyErr_Format(PyExc_RuntimeError,
363 "Could not allocate new container.");
364 return nullptr;
365 }
366
367 MapContainer* self = GetMap(obj);
368
369 Py_INCREF(parent);
370 self->parent = parent;
371 self->parent_field_descriptor = parent_field_descriptor;
372 self->version = 0;
373
374 return self;
375 }
376
ScalarMapGetItem(PyObject * _self,PyObject * key)377 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
378 PyObject* key) {
379 MapContainer* self = GetMap(_self);
380
381 Message* message = self->GetMutableMessage();
382 const Reflection* reflection = message->GetReflection();
383 std::string map_key_string;
384 MapKey map_key;
385 MapValueRef value;
386
387 if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
388 return nullptr;
389 }
390
391 if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
392 map_key, &value)) {
393 self->version++;
394 }
395
396 return MapValueRefToPython(self, value);
397 }
398
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)399 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
400 PyObject* v) {
401 MapContainer* self = GetMap(_self);
402
403 Message* message = self->GetMutableMessage();
404 const Reflection* reflection = message->GetReflection();
405 std::string map_key_string;
406 MapKey map_key;
407 MapValueRef value;
408
409 if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
410 return -1;
411 }
412
413 if (v) {
414 // Set item to v.
415 if (reflection->InsertOrLookupMapValue(
416 message, self->parent_field_descriptor, map_key, &value)) {
417 self->version++;
418 }
419
420 if (!PythonToMapValueRef(self, v,
421 !self->parent_field_descriptor->message_type()
422 ->map_value()
423 ->legacy_enum_field_treated_as_closed(),
424 &value)) {
425 return -1;
426 }
427 return 0;
428 } else {
429 // Delete key from map.
430 if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
431 map_key)) {
432 self->version++;
433 return 0;
434 } else {
435 PyErr_Format(PyExc_KeyError, "Key not present in map");
436 return -1;
437 }
438 }
439 }
440
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)441 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
442 PyObject* kwargs) {
443 static const char* kwlist[] = {"key", "default", nullptr};
444 PyObject* key;
445 PyObject* default_value = nullptr;
446 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
447 const_cast<char**>(kwlist), &key,
448 &default_value)) {
449 return nullptr;
450 }
451
452 ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
453 if (is_present.get() == nullptr) {
454 return nullptr;
455 }
456
457 if (PyObject_IsTrue(is_present.get())) {
458 return MapReflectionFriend::ScalarMapGetItem(self, key);
459 } else {
460 if (default_value != nullptr) {
461 Py_INCREF(default_value);
462 return default_value;
463 } else {
464 Py_RETURN_NONE;
465 }
466 }
467 }
468
ScalarMapToStr(PyObject * _self)469 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
470 ScopedPyObjectPtr dict(PyDict_New());
471 if (dict == nullptr) {
472 return nullptr;
473 }
474 ScopedPyObjectPtr key;
475 ScopedPyObjectPtr value;
476
477 MapContainer* self = GetMap(_self);
478 Message* message = self->GetMutableMessage();
479 const Reflection* reflection = message->GetReflection();
480 for (google::protobuf::MapIterator it = reflection->MapBegin(
481 message, self->parent_field_descriptor);
482 it != reflection->MapEnd(message, self->parent_field_descriptor);
483 ++it) {
484 key.reset(MapKeyToPython(self, it.GetKey()));
485 if (key == nullptr) {
486 return nullptr;
487 }
488 value.reset(MapValueRefToPython(self, it.GetValueRef()));
489 if (value == nullptr) {
490 return nullptr;
491 }
492 if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
493 return nullptr;
494 }
495 }
496 return PyObject_Repr(dict.get());
497 }
498
ScalarMapDealloc(PyObject * _self)499 static void ScalarMapDealloc(PyObject* _self) {
500 MapContainer* self = GetMap(_self);
501 self->RemoveFromParentCache();
502 PyTypeObject *type = Py_TYPE(_self);
503 type->tp_free(_self);
504 if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
505 // With Python3, the Map class is not static, and must be managed.
506 Py_DECREF(type);
507 }
508 }
509
510 static PyMethodDef ScalarMapMethods[] = {
511 {"__contains__", MapReflectionFriend::Contains, METH_O,
512 "Tests whether a key is a member of the map."},
513 {"clear", (PyCFunction)Clear, METH_NOARGS,
514 "Removes all elements from the map."},
515 {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
516 "Gets the value for the given key if present, or otherwise a default"},
517 {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
518 "Return the class used to build Entries of (key, value) pairs."},
519 {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
520 "Merges a map into the current map."},
521 /*
522 { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
523 "Makes a deep copy of the class." },
524 { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
525 "Outputs picklable representation of the repeated field." },
526 */
527 {nullptr, nullptr},
528 };
529
530 PyTypeObject* ScalarMapContainer_Type;
531 static PyType_Slot ScalarMapContainer_Type_slots[] = {
532 {Py_tp_dealloc, (void*)ScalarMapDealloc},
533 {Py_mp_length, (void*)MapReflectionFriend::Length},
534 {Py_mp_subscript, (void*)MapReflectionFriend::ScalarMapGetItem},
535 {Py_mp_ass_subscript, (void*)MapReflectionFriend::ScalarMapSetItem},
536 {Py_tp_methods, (void*)ScalarMapMethods},
537 {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
538 {Py_tp_repr, (void*)MapReflectionFriend::ScalarMapToStr},
539 {0, nullptr},
540 };
541
542 PyType_Spec ScalarMapContainer_Type_spec = {
543 FULL_MODULE_NAME ".ScalarMapContainer", sizeof(MapContainer), 0,
544 Py_TPFLAGS_DEFAULT, ScalarMapContainer_Type_slots};
545
546 // MessageMap //////////////////////////////////////////////////////////////////
547
GetMessageMap(PyObject * obj)548 static MessageMapContainer* GetMessageMap(PyObject* obj) {
549 return reinterpret_cast<MessageMapContainer*>(obj);
550 }
551
GetCMessage(MessageMapContainer * self,Message * message)552 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
553 // Get or create the CMessage object corresponding to this message.
554 return self->parent
555 ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
556 self->message_class)
557 ->AsPyObject();
558 }
559
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)560 MessageMapContainer* NewMessageMapContainer(
561 CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
562 CMessageClass* message_class) {
563 if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
564 return nullptr;
565 }
566
567 PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
568 if (obj == nullptr) {
569 PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
570 return nullptr;
571 }
572
573 MessageMapContainer* self = GetMessageMap(obj);
574
575 Py_INCREF(parent);
576 self->parent = parent;
577 self->parent_field_descriptor = parent_field_descriptor;
578 self->version = 0;
579
580 Py_INCREF(message_class);
581 self->message_class = message_class;
582
583 return self;
584 }
585
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)586 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
587 PyObject* v) {
588 if (v) {
589 PyErr_Format(PyExc_ValueError,
590 "Direct assignment of submessage not allowed");
591 return -1;
592 }
593
594 // Now we know that this is a delete, not a set.
595
596 MessageMapContainer* self = GetMessageMap(_self);
597 Message* message = self->GetMutableMessage();
598 const Reflection* reflection = message->GetReflection();
599 std::string map_key_string;
600 MapKey map_key;
601 MapValueRef value;
602
603 self->version++;
604
605 if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
606 return -1;
607 }
608
609 // Delete key from map.
610 if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
611 map_key)) {
612 // Delete key from CMessage dict.
613 MapValueRef value;
614 reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
615 map_key, &value);
616 Message* sub_message = value.MutableMessageValue();
617 // If there is a living weak reference to an item, we "Release" it,
618 // otherwise we just discard the C++ value.
619 if (CMessage* released =
620 self->parent->MaybeReleaseSubMessage(sub_message)) {
621 Message* msg = released->message;
622 released->message = msg->New();
623 msg->GetReflection()->Swap(msg, released->message);
624 }
625
626 // Delete key from map.
627 reflection->DeleteMapValue(message, self->parent_field_descriptor,
628 map_key);
629 return 0;
630 } else {
631 PyErr_Format(PyExc_KeyError, "Key not present in map");
632 return -1;
633 }
634 }
635
MessageMapGetItem(PyObject * _self,PyObject * key)636 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
637 PyObject* key) {
638 MessageMapContainer* self = GetMessageMap(_self);
639
640 Message* message = self->GetMutableMessage();
641 const Reflection* reflection = message->GetReflection();
642 std::string map_key_string;
643 MapKey map_key;
644 MapValueRef value;
645
646 if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
647 return nullptr;
648 }
649
650 if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
651 map_key, &value)) {
652 self->version++;
653 }
654
655 return GetCMessage(self, value.MutableMessageValue());
656 }
657
MessageMapToStr(PyObject * _self)658 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
659 ScopedPyObjectPtr dict(PyDict_New());
660 if (dict == nullptr) {
661 return nullptr;
662 }
663 ScopedPyObjectPtr key;
664 ScopedPyObjectPtr value;
665
666 MessageMapContainer* self = GetMessageMap(_self);
667 Message* message = self->GetMutableMessage();
668 const Reflection* reflection = message->GetReflection();
669 for (google::protobuf::MapIterator it = reflection->MapBegin(
670 message, self->parent_field_descriptor);
671 it != reflection->MapEnd(message, self->parent_field_descriptor);
672 ++it) {
673 key.reset(MapKeyToPython(self, it.GetKey()));
674 if (key == nullptr) {
675 return nullptr;
676 }
677 value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
678 if (value == nullptr) {
679 return nullptr;
680 }
681 if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
682 return nullptr;
683 }
684 }
685 return PyObject_Repr(dict.get());
686 }
687
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)688 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
689 static const char* kwlist[] = {"key", "default", nullptr};
690 PyObject* key;
691 PyObject* default_value = nullptr;
692 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
693 const_cast<char**>(kwlist), &key,
694 &default_value)) {
695 return nullptr;
696 }
697
698 ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
699 if (is_present.get() == nullptr) {
700 return nullptr;
701 }
702
703 if (PyObject_IsTrue(is_present.get())) {
704 return MapReflectionFriend::MessageMapGetItem(self, key);
705 } else {
706 if (default_value != nullptr) {
707 Py_INCREF(default_value);
708 return default_value;
709 } else {
710 Py_RETURN_NONE;
711 }
712 }
713 }
714
MessageMapDealloc(PyObject * _self)715 static void MessageMapDealloc(PyObject* _self) {
716 MessageMapContainer* self = GetMessageMap(_self);
717 self->RemoveFromParentCache();
718 Py_DECREF(self->message_class);
719 PyTypeObject *type = Py_TYPE(_self);
720 type->tp_free(_self);
721 if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
722 // With Python3, the Map class is not static, and must be managed.
723 Py_DECREF(type);
724 }
725 }
726
727 static PyMethodDef MessageMapMethods[] = {
728 {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
729 "Tests whether the map contains this element."},
730 {"clear", (PyCFunction)Clear, METH_NOARGS,
731 "Removes all elements from the map."},
732 {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
733 "Gets the value for the given key if present, or otherwise a default"},
734 {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
735 "Alias for getitem, useful to make explicit that the map is mutated."},
736 {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
737 "Return the class used to build Entries of (key, value) pairs."},
738 {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
739 "Merges a map into the current map."},
740 /*
741 { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
742 "Makes a deep copy of the class." },
743 { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
744 "Outputs picklable representation of the repeated field." },
745 */
746 {nullptr, nullptr},
747 };
748
749 PyTypeObject* MessageMapContainer_Type;
750 static PyType_Slot MessageMapContainer_Type_slots[] = {
751 {Py_tp_dealloc, (void*)MessageMapDealloc},
752 {Py_mp_length, (void*)MapReflectionFriend::Length},
753 {Py_mp_subscript, (void*)MapReflectionFriend::MessageMapGetItem},
754 {Py_mp_ass_subscript, (void*)MapReflectionFriend::MessageMapSetItem},
755 {Py_tp_methods, (void*)MessageMapMethods},
756 {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
757 {Py_tp_repr, (void*)MapReflectionFriend::MessageMapToStr},
758 {0, nullptr}};
759
760 PyType_Spec MessageMapContainer_Type_spec = {
761 FULL_MODULE_NAME ".MessageMapContainer", sizeof(MessageMapContainer), 0,
762 Py_TPFLAGS_DEFAULT, MessageMapContainer_Type_slots};
763
764 // MapIterator /////////////////////////////////////////////////////////////////
765
GetIter(PyObject * obj)766 static MapIterator* GetIter(PyObject* obj) {
767 return reinterpret_cast<MapIterator*>(obj);
768 }
769
GetIterator(PyObject * _self)770 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
771 MapContainer* self = GetMap(_self);
772
773 ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
774 if (obj == nullptr) {
775 return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
776 }
777
778 MapIterator* iter = GetIter(obj.get());
779
780 Py_INCREF(self);
781 iter->container = self;
782 iter->version = self->version;
783 Py_INCREF(self->parent);
784 iter->parent = self->parent;
785
786 if (MapReflectionFriend::Length(_self) > 0) {
787 Message* message = self->GetMutableMessage();
788 const Reflection* reflection = message->GetReflection();
789
790 iter->iter.reset(new ::google::protobuf::MapIterator(
791 reflection->MapBegin(message, self->parent_field_descriptor)));
792 }
793
794 return obj.release();
795 }
796
IterNext(PyObject * _self)797 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
798 MapIterator* self = GetIter(_self);
799
800 // This won't catch mutations to the map performed by MergeFrom(); no easy way
801 // to address that.
802 if (self->version != self->container->version) {
803 return PyErr_Format(PyExc_RuntimeError,
804 "Map modified during iteration.");
805 }
806 if (self->parent != self->container->parent) {
807 return PyErr_Format(PyExc_RuntimeError,
808 "Map cleared during iteration.");
809 }
810
811 if (self->iter.get() == nullptr) {
812 return nullptr;
813 }
814
815 Message* message = self->container->GetMutableMessage();
816 const Reflection* reflection = message->GetReflection();
817
818 if (*self->iter ==
819 reflection->MapEnd(message, self->container->parent_field_descriptor)) {
820 return nullptr;
821 }
822
823 PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
824
825 ++(*self->iter);
826
827 return ret;
828 }
829
DeallocMapIterator(PyObject * _self)830 static void DeallocMapIterator(PyObject* _self) {
831 MapIterator* self = GetIter(_self);
832 self->iter.reset();
833 Py_CLEAR(self->container);
834 Py_CLEAR(self->parent);
835 Py_TYPE(_self)->tp_free(_self);
836 }
837
838 PyTypeObject MapIterator_Type = {
839 PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
840 ".MapIterator", // tp_name
841 sizeof(MapIterator), // tp_basicsize
842 0, // tp_itemsize
843 DeallocMapIterator, // tp_dealloc
844 #if PY_VERSION_HEX < 0x03080000
845 nullptr, // tp_print
846 #else
847 0, // tp_vectorcall_offset
848 #endif
849 nullptr, // tp_getattr
850 nullptr, // tp_setattr
851 nullptr, // tp_compare
852 nullptr, // tp_repr
853 nullptr, // tp_as_number
854 nullptr, // tp_as_sequence
855 nullptr, // tp_as_mapping
856 nullptr, // tp_hash
857 nullptr, // tp_call
858 nullptr, // tp_str
859 nullptr, // tp_getattro
860 nullptr, // tp_setattro
861 nullptr, // tp_as_buffer
862 Py_TPFLAGS_DEFAULT, // tp_flags
863 "A scalar map iterator", // tp_doc
864 nullptr, // tp_traverse
865 nullptr, // tp_clear
866 nullptr, // tp_richcompare
867 0, // tp_weaklistoffset
868 PyObject_SelfIter, // tp_iter
869 MapReflectionFriend::IterNext, // tp_iternext
870 nullptr, // tp_methods
871 nullptr, // tp_members
872 nullptr, // tp_getset
873 nullptr, // tp_base
874 nullptr, // tp_dict
875 nullptr, // tp_descr_get
876 nullptr, // tp_descr_set
877 0, // tp_dictoffset
878 nullptr, // tp_init
879 };
880
InitMapContainers()881 bool InitMapContainers() {
882 // ScalarMapContainer_Type derives from our MutableMapping type.
883 ScopedPyObjectPtr abc(PyImport_ImportModule("collections.abc"));
884 if (abc == nullptr) {
885 return false;
886 }
887
888 ScopedPyObjectPtr mutable_mapping(
889 PyObject_GetAttrString(abc.get(), "MutableMapping"));
890 if (mutable_mapping == nullptr) {
891 return false;
892 }
893
894 Py_INCREF(mutable_mapping.get());
895 ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
896 if (bases == nullptr) {
897 return false;
898 }
899
900 ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
901 PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
902
903 if (PyType_Ready(&MapIterator_Type) < 0) {
904 return false;
905 }
906
907 MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
908 PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
909 return true;
910 }
911
912 } // namespace python
913 } // namespace protobuf
914 } // namespace google
915