• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: haberman@google.com (Josh Haberman)
9 
10 #include "google/protobuf/pyext/map_container.h"
11 
12 #include <cstdint>
13 #include <memory>
14 #include <string>
15 
16 #include "google/protobuf/map.h"
17 #include "google/protobuf/map_field.h"
18 #include "google/protobuf/message.h"
19 #include "google/protobuf/pyext/message.h"
20 #include "google/protobuf/pyext/message_factory.h"
21 #include "google/protobuf/pyext/repeated_composite_container.h"
22 #include "google/protobuf/pyext/scoped_pyobject_ptr.h"
23 
24 namespace google {
25 namespace protobuf {
26 namespace python {
27 
28 // Functions that need access to map reflection functionality.
29 // They need to be contained in this class because it is friended.
30 class MapReflectionFriend {
31  public:
32   // Methods that are in common between the map types.
33   static PyObject* Contains(PyObject* _self, PyObject* key);
34   static Py_ssize_t Length(PyObject* _self);
35   static PyObject* GetIterator(PyObject *_self);
36   static PyObject* IterNext(PyObject* _self);
37   static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
38 
39   // Methods that differ between the map types.
40   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
41   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
42   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
43   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
44   static PyObject* ScalarMapToStr(PyObject* _self);
45   static PyObject* MessageMapToStr(PyObject* _self);
46 };
47 
48 struct MapIterator {
49   PyObject_HEAD;
50 
51   std::unique_ptr<::google::protobuf::MapIterator> iter;
52 
53   // A pointer back to the container, so we can notice changes to the version.
54   // We own a ref on this.
55   MapContainer* container;
56 
57   // We need to keep a ref on the parent Message too, because
58   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
59   // the ref on container (above) would guarantee outlive semantics.  However in
60   // the case of ClearField(), the MapContainer points to a different message,
61   // a copy of the original.  But our iterator still points to the original,
62   // which could now get deleted before us.
63   //
64   // To prevent this, we ensure that the Message will always stay alive as long
65   // as this iterator does.  This is solely for the benefit of the MapIterator
66   // destructor -- we should never actually access the iterator in this state
67   // except to delete it.
68   CMessage* parent;
69   // The version of the map when we took the iterator to it.
70   //
71   // We store this so that if the map is modified during iteration we can throw
72   // an error.
73   uint64_t version;
74 };
75 
GetMutableMessage()76 Message* MapContainer::GetMutableMessage() {
77   cmessage::AssureWritable(parent);
78   return parent->message;
79 }
80 
81 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,std::string * stl_string)82 static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
83   char *value;
84   Py_ssize_t value_len;
85 
86   if (!py_string) {
87     return false;
88   }
89   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
90     Py_DECREF(py_string);
91     return false;
92   } else {
93     stl_string->assign(value, value_len);
94     Py_DECREF(py_string);
95     return true;
96   }
97 }
98 
PythonToMapKey(MapContainer * self,PyObject * obj,MapKey * key,std::string * key_string)99 static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key,
100                            std::string* key_string) {
101   const FieldDescriptor* field_descriptor =
102       self->parent_field_descriptor->message_type()->map_key();
103   switch (field_descriptor->cpp_type()) {
104     case FieldDescriptor::CPPTYPE_INT32: {
105       PROTOBUF_CHECK_GET_INT32(obj, value, false);
106       key->SetInt32Value(value);
107       break;
108     }
109     case FieldDescriptor::CPPTYPE_INT64: {
110       PROTOBUF_CHECK_GET_INT64(obj, value, false);
111       key->SetInt64Value(value);
112       break;
113     }
114     case FieldDescriptor::CPPTYPE_UINT32: {
115       PROTOBUF_CHECK_GET_UINT32(obj, value, false);
116       key->SetUInt32Value(value);
117       break;
118     }
119     case FieldDescriptor::CPPTYPE_UINT64: {
120       PROTOBUF_CHECK_GET_UINT64(obj, value, false);
121       key->SetUInt64Value(value);
122       break;
123     }
124     case FieldDescriptor::CPPTYPE_BOOL: {
125       PROTOBUF_CHECK_GET_BOOL(obj, value, false);
126       key->SetBoolValue(value);
127       break;
128     }
129     case FieldDescriptor::CPPTYPE_STRING: {
130       if (!PyStringToSTL(CheckString(obj, field_descriptor), key_string)) {
131         return false;
132       }
133       key->SetStringValue(*key_string);
134       break;
135     }
136     default:
137       PyErr_Format(
138           PyExc_SystemError, "Type %d cannot be a map key",
139           field_descriptor->cpp_type());
140       return false;
141   }
142   return true;
143 }
144 
MapKeyToPython(MapContainer * self,const MapKey & key)145 static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
146   const FieldDescriptor* field_descriptor =
147       self->parent_field_descriptor->message_type()->map_key();
148   switch (field_descriptor->cpp_type()) {
149     case FieldDescriptor::CPPTYPE_INT32:
150       return PyLong_FromLong(key.GetInt32Value());
151     case FieldDescriptor::CPPTYPE_INT64:
152       return PyLong_FromLongLong(key.GetInt64Value());
153     case FieldDescriptor::CPPTYPE_UINT32:
154       return PyLong_FromSize_t(key.GetUInt32Value());
155     case FieldDescriptor::CPPTYPE_UINT64:
156       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
157     case FieldDescriptor::CPPTYPE_BOOL:
158       return PyBool_FromLong(key.GetBoolValue());
159     case FieldDescriptor::CPPTYPE_STRING:
160       return ToStringObject(field_descriptor, key.GetStringValue());
161     default:
162       PyErr_Format(
163           PyExc_SystemError, "Couldn't convert type %d to value",
164           field_descriptor->cpp_type());
165       return nullptr;
166   }
167 }
168 
169 // This is only used for ScalarMap, so we don't need to handle the
170 // CPPTYPE_MESSAGE case.
MapValueRefToPython(MapContainer * self,const MapValueRef & value)171 PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
172   const FieldDescriptor* field_descriptor =
173       self->parent_field_descriptor->message_type()->map_value();
174   switch (field_descriptor->cpp_type()) {
175     case FieldDescriptor::CPPTYPE_INT32:
176       return PyLong_FromLong(value.GetInt32Value());
177     case FieldDescriptor::CPPTYPE_INT64:
178       return PyLong_FromLongLong(value.GetInt64Value());
179     case FieldDescriptor::CPPTYPE_UINT32:
180       return PyLong_FromSize_t(value.GetUInt32Value());
181     case FieldDescriptor::CPPTYPE_UINT64:
182       return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
183     case FieldDescriptor::CPPTYPE_FLOAT:
184       return PyFloat_FromDouble(value.GetFloatValue());
185     case FieldDescriptor::CPPTYPE_DOUBLE:
186       return PyFloat_FromDouble(value.GetDoubleValue());
187     case FieldDescriptor::CPPTYPE_BOOL:
188       return PyBool_FromLong(value.GetBoolValue());
189     case FieldDescriptor::CPPTYPE_STRING:
190       return ToStringObject(field_descriptor, value.GetStringValue());
191     case FieldDescriptor::CPPTYPE_ENUM:
192       return PyLong_FromLong(value.GetEnumValue());
193     default:
194       PyErr_Format(
195           PyExc_SystemError, "Couldn't convert type %d to value",
196           field_descriptor->cpp_type());
197       return nullptr;
198   }
199 }
200 
201 // This is only used for ScalarMap, so we don't need to handle the
202 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(MapContainer * self,PyObject * obj,bool allow_unknown_enum_values,MapValueRef * value_ref)203 static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
204                                 bool allow_unknown_enum_values,
205                                 MapValueRef* value_ref) {
206   const FieldDescriptor* field_descriptor =
207       self->parent_field_descriptor->message_type()->map_value();
208   switch (field_descriptor->cpp_type()) {
209     case FieldDescriptor::CPPTYPE_INT32: {
210       PROTOBUF_CHECK_GET_INT32(obj, value, false);
211       value_ref->SetInt32Value(value);
212       return true;
213     }
214     case FieldDescriptor::CPPTYPE_INT64: {
215       PROTOBUF_CHECK_GET_INT64(obj, value, false);
216       value_ref->SetInt64Value(value);
217       return true;
218     }
219     case FieldDescriptor::CPPTYPE_UINT32: {
220       PROTOBUF_CHECK_GET_UINT32(obj, value, false);
221       value_ref->SetUInt32Value(value);
222       return true;
223     }
224     case FieldDescriptor::CPPTYPE_UINT64: {
225       PROTOBUF_CHECK_GET_UINT64(obj, value, false);
226       value_ref->SetUInt64Value(value);
227       return true;
228     }
229     case FieldDescriptor::CPPTYPE_FLOAT: {
230       PROTOBUF_CHECK_GET_FLOAT(obj, value, false);
231       value_ref->SetFloatValue(value);
232       return true;
233     }
234     case FieldDescriptor::CPPTYPE_DOUBLE: {
235       PROTOBUF_CHECK_GET_DOUBLE(obj, value, false);
236       value_ref->SetDoubleValue(value);
237       return true;
238     }
239     case FieldDescriptor::CPPTYPE_BOOL: {
240       PROTOBUF_CHECK_GET_BOOL(obj, value, false);
241       value_ref->SetBoolValue(value);
242       return true;
243     }
244     case FieldDescriptor::CPPTYPE_STRING: {
245       std::string str;
246       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
247         return false;
248       }
249       value_ref->SetStringValue(str);
250       return true;
251     }
252     case FieldDescriptor::CPPTYPE_ENUM: {
253       PROTOBUF_CHECK_GET_INT32(obj, value, false);
254       if (allow_unknown_enum_values) {
255         value_ref->SetEnumValue(value);
256         return true;
257       } else {
258         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
259         const EnumValueDescriptor* enum_value =
260             enum_descriptor->FindValueByNumber(value);
261         if (enum_value != nullptr) {
262           value_ref->SetEnumValue(value);
263           return true;
264         } else {
265           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
266           return false;
267         }
268       }
269       break;
270     }
271     default:
272       PyErr_Format(
273           PyExc_SystemError, "Setting value to a field of unknown type %d",
274           field_descriptor->cpp_type());
275       return false;
276   }
277 }
278 
279 // Map methods common to ScalarMap and MessageMap //////////////////////////////
280 
GetMap(PyObject * obj)281 static MapContainer* GetMap(PyObject* obj) {
282   return reinterpret_cast<MapContainer*>(obj);
283 }
284 
Length(PyObject * _self)285 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
286   MapContainer* self = GetMap(_self);
287   const google::protobuf::Message* message = self->parent->message;
288   return message->GetReflection()->MapSize(*message,
289                                            self->parent_field_descriptor);
290 }
291 
Clear(PyObject * _self)292 PyObject* Clear(PyObject* _self) {
293   MapContainer* self = GetMap(_self);
294   Message* message = self->GetMutableMessage();
295   const Reflection* reflection = message->GetReflection();
296 
297   reflection->ClearField(message, self->parent_field_descriptor);
298 
299   Py_RETURN_NONE;
300 }
301 
GetEntryClass(PyObject * _self)302 PyObject* GetEntryClass(PyObject* _self) {
303   MapContainer* self = GetMap(_self);
304   CMessageClass* message_class = message_factory::GetMessageClass(
305       cmessage::GetFactoryForMessage(self->parent),
306       self->parent_field_descriptor->message_type());
307   Py_XINCREF(message_class);
308   return reinterpret_cast<PyObject*>(message_class);
309 }
310 
MergeFrom(PyObject * _self,PyObject * arg)311 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
312   MapContainer* self = GetMap(_self);
313   if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) &&
314       !PyObject_TypeCheck(arg, MessageMapContainer_Type)) {
315     PyErr_SetString(PyExc_AttributeError, "Not a map field");
316     return nullptr;
317   }
318   MapContainer* other_map = GetMap(arg);
319   Message* message = self->GetMutableMessage();
320   const Message* other_message = other_map->parent->message;
321   const Reflection* reflection = message->GetReflection();
322   const Reflection* other_reflection = other_message->GetReflection();
323   internal::MapFieldBase* field = reflection->MutableMapData(
324       message, self->parent_field_descriptor);
325   const internal::MapFieldBase* other_field = other_reflection->GetMapData(
326       *other_message, other_map->parent_field_descriptor);
327   field->MergeFrom(*other_field);
328   self->version++;
329   Py_RETURN_NONE;
330 }
331 
Contains(PyObject * _self,PyObject * key)332 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
333   MapContainer* self = GetMap(_self);
334 
335   const Message* message = self->parent->message;
336   const Reflection* reflection = message->GetReflection();
337   std::string map_key_string;
338   MapKey map_key;
339 
340   if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
341     return nullptr;
342   }
343 
344   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
345                                  map_key)) {
346     Py_RETURN_TRUE;
347   } else {
348     Py_RETURN_FALSE;
349   }
350 }
351 
352 // ScalarMap ///////////////////////////////////////////////////////////////////
353 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)354 MapContainer* NewScalarMapContainer(
355     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
356   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
357     return nullptr;
358   }
359 
360   PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
361   if (obj == nullptr) {
362     PyErr_Format(PyExc_RuntimeError,
363                  "Could not allocate new container.");
364     return nullptr;
365   }
366 
367   MapContainer* self = GetMap(obj);
368 
369   Py_INCREF(parent);
370   self->parent = parent;
371   self->parent_field_descriptor = parent_field_descriptor;
372   self->version = 0;
373 
374   return self;
375 }
376 
ScalarMapGetItem(PyObject * _self,PyObject * key)377 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
378                                                 PyObject* key) {
379   MapContainer* self = GetMap(_self);
380 
381   Message* message = self->GetMutableMessage();
382   const Reflection* reflection = message->GetReflection();
383   std::string map_key_string;
384   MapKey map_key;
385   MapValueRef value;
386 
387   if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
388     return nullptr;
389   }
390 
391   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
392                                          map_key, &value)) {
393     self->version++;
394   }
395 
396   return MapValueRefToPython(self, value);
397 }
398 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)399 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
400                                           PyObject* v) {
401   MapContainer* self = GetMap(_self);
402 
403   Message* message = self->GetMutableMessage();
404   const Reflection* reflection = message->GetReflection();
405   std::string map_key_string;
406   MapKey map_key;
407   MapValueRef value;
408 
409   if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
410     return -1;
411   }
412 
413   if (v) {
414     // Set item to v.
415     if (reflection->InsertOrLookupMapValue(
416             message, self->parent_field_descriptor, map_key, &value)) {
417       self->version++;
418     }
419 
420     if (!PythonToMapValueRef(self, v,
421                              !self->parent_field_descriptor->message_type()
422                                   ->map_value()
423                                   ->legacy_enum_field_treated_as_closed(),
424                              &value)) {
425       return -1;
426     }
427     return 0;
428   } else {
429     // Delete key from map.
430     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
431                                    map_key)) {
432       self->version++;
433       return 0;
434     } else {
435       PyErr_Format(PyExc_KeyError, "Key not present in map");
436       return -1;
437     }
438   }
439 }
440 
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)441 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
442                               PyObject* kwargs) {
443   static const char* kwlist[] = {"key", "default", nullptr};
444   PyObject* key;
445   PyObject* default_value = nullptr;
446   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
447                                    const_cast<char**>(kwlist), &key,
448                                    &default_value)) {
449     return nullptr;
450   }
451 
452   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
453   if (is_present.get() == nullptr) {
454     return nullptr;
455   }
456 
457   if (PyObject_IsTrue(is_present.get())) {
458     return MapReflectionFriend::ScalarMapGetItem(self, key);
459   } else {
460     if (default_value != nullptr) {
461       Py_INCREF(default_value);
462       return default_value;
463     } else {
464       Py_RETURN_NONE;
465     }
466   }
467 }
468 
ScalarMapToStr(PyObject * _self)469 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
470   ScopedPyObjectPtr dict(PyDict_New());
471   if (dict == nullptr) {
472     return nullptr;
473   }
474   ScopedPyObjectPtr key;
475   ScopedPyObjectPtr value;
476 
477   MapContainer* self = GetMap(_self);
478   Message* message = self->GetMutableMessage();
479   const Reflection* reflection = message->GetReflection();
480   for (google::protobuf::MapIterator it = reflection->MapBegin(
481            message, self->parent_field_descriptor);
482        it != reflection->MapEnd(message, self->parent_field_descriptor);
483        ++it) {
484     key.reset(MapKeyToPython(self, it.GetKey()));
485     if (key == nullptr) {
486       return nullptr;
487     }
488     value.reset(MapValueRefToPython(self, it.GetValueRef()));
489     if (value == nullptr) {
490       return nullptr;
491     }
492     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
493       return nullptr;
494     }
495   }
496   return PyObject_Repr(dict.get());
497 }
498 
ScalarMapDealloc(PyObject * _self)499 static void ScalarMapDealloc(PyObject* _self) {
500   MapContainer* self = GetMap(_self);
501   self->RemoveFromParentCache();
502   PyTypeObject *type = Py_TYPE(_self);
503   type->tp_free(_self);
504   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
505     // With Python3, the Map class is not static, and must be managed.
506     Py_DECREF(type);
507   }
508 }
509 
510 static PyMethodDef ScalarMapMethods[] = {
511     {"__contains__", MapReflectionFriend::Contains, METH_O,
512      "Tests whether a key is a member of the map."},
513     {"clear", (PyCFunction)Clear, METH_NOARGS,
514      "Removes all elements from the map."},
515     {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
516      "Gets the value for the given key if present, or otherwise a default"},
517     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
518      "Return the class used to build Entries of (key, value) pairs."},
519     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
520      "Merges a map into the current map."},
521     /*
522     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
523       "Makes a deep copy of the class." },
524     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
525       "Outputs picklable representation of the repeated field." },
526     */
527     {nullptr, nullptr},
528 };
529 
530 PyTypeObject* ScalarMapContainer_Type;
531 static PyType_Slot ScalarMapContainer_Type_slots[] = {
532     {Py_tp_dealloc, (void*)ScalarMapDealloc},
533     {Py_mp_length, (void*)MapReflectionFriend::Length},
534     {Py_mp_subscript, (void*)MapReflectionFriend::ScalarMapGetItem},
535     {Py_mp_ass_subscript, (void*)MapReflectionFriend::ScalarMapSetItem},
536     {Py_tp_methods, (void*)ScalarMapMethods},
537     {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
538     {Py_tp_repr, (void*)MapReflectionFriend::ScalarMapToStr},
539     {0, nullptr},
540 };
541 
542 PyType_Spec ScalarMapContainer_Type_spec = {
543     FULL_MODULE_NAME ".ScalarMapContainer", sizeof(MapContainer), 0,
544     Py_TPFLAGS_DEFAULT, ScalarMapContainer_Type_slots};
545 
546 // MessageMap //////////////////////////////////////////////////////////////////
547 
GetMessageMap(PyObject * obj)548 static MessageMapContainer* GetMessageMap(PyObject* obj) {
549   return reinterpret_cast<MessageMapContainer*>(obj);
550 }
551 
GetCMessage(MessageMapContainer * self,Message * message)552 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
553   // Get or create the CMessage object corresponding to this message.
554   return self->parent
555       ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
556                                    self->message_class)
557       ->AsPyObject();
558 }
559 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)560 MessageMapContainer* NewMessageMapContainer(
561     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
562     CMessageClass* message_class) {
563   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
564     return nullptr;
565   }
566 
567   PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
568   if (obj == nullptr) {
569     PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
570     return nullptr;
571   }
572 
573   MessageMapContainer* self = GetMessageMap(obj);
574 
575   Py_INCREF(parent);
576   self->parent = parent;
577   self->parent_field_descriptor = parent_field_descriptor;
578   self->version = 0;
579 
580   Py_INCREF(message_class);
581   self->message_class = message_class;
582 
583   return self;
584 }
585 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)586 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
587                                            PyObject* v) {
588   if (v) {
589     PyErr_Format(PyExc_ValueError,
590                  "Direct assignment of submessage not allowed");
591     return -1;
592   }
593 
594   // Now we know that this is a delete, not a set.
595 
596   MessageMapContainer* self = GetMessageMap(_self);
597   Message* message = self->GetMutableMessage();
598   const Reflection* reflection = message->GetReflection();
599   std::string map_key_string;
600   MapKey map_key;
601   MapValueRef value;
602 
603   self->version++;
604 
605   if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
606     return -1;
607   }
608 
609   // Delete key from map.
610   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
611                                  map_key)) {
612     // Delete key from CMessage dict.
613     MapValueRef value;
614     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
615                                        map_key, &value);
616     Message* sub_message = value.MutableMessageValue();
617     // If there is a living weak reference to an item, we "Release" it,
618     // otherwise we just discard the C++ value.
619     if (CMessage* released =
620             self->parent->MaybeReleaseSubMessage(sub_message)) {
621       Message* msg = released->message;
622       released->message = msg->New();
623       msg->GetReflection()->Swap(msg, released->message);
624     }
625 
626     // Delete key from map.
627     reflection->DeleteMapValue(message, self->parent_field_descriptor,
628                                map_key);
629     return 0;
630   } else {
631     PyErr_Format(PyExc_KeyError, "Key not present in map");
632     return -1;
633   }
634 }
635 
MessageMapGetItem(PyObject * _self,PyObject * key)636 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
637                                                  PyObject* key) {
638   MessageMapContainer* self = GetMessageMap(_self);
639 
640   Message* message = self->GetMutableMessage();
641   const Reflection* reflection = message->GetReflection();
642   std::string map_key_string;
643   MapKey map_key;
644   MapValueRef value;
645 
646   if (!PythonToMapKey(self, key, &map_key, &map_key_string)) {
647     return nullptr;
648   }
649 
650   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
651                                          map_key, &value)) {
652     self->version++;
653   }
654 
655   return GetCMessage(self, value.MutableMessageValue());
656 }
657 
MessageMapToStr(PyObject * _self)658 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
659   ScopedPyObjectPtr dict(PyDict_New());
660   if (dict == nullptr) {
661     return nullptr;
662   }
663   ScopedPyObjectPtr key;
664   ScopedPyObjectPtr value;
665 
666   MessageMapContainer* self = GetMessageMap(_self);
667   Message* message = self->GetMutableMessage();
668   const Reflection* reflection = message->GetReflection();
669   for (google::protobuf::MapIterator it = reflection->MapBegin(
670            message, self->parent_field_descriptor);
671        it != reflection->MapEnd(message, self->parent_field_descriptor);
672        ++it) {
673     key.reset(MapKeyToPython(self, it.GetKey()));
674     if (key == nullptr) {
675       return nullptr;
676     }
677     value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
678     if (value == nullptr) {
679       return nullptr;
680     }
681     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
682       return nullptr;
683     }
684   }
685   return PyObject_Repr(dict.get());
686 }
687 
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)688 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
689   static const char* kwlist[] = {"key", "default", nullptr};
690   PyObject* key;
691   PyObject* default_value = nullptr;
692   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
693                                    const_cast<char**>(kwlist), &key,
694                                    &default_value)) {
695     return nullptr;
696   }
697 
698   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
699   if (is_present.get() == nullptr) {
700     return nullptr;
701   }
702 
703   if (PyObject_IsTrue(is_present.get())) {
704     return MapReflectionFriend::MessageMapGetItem(self, key);
705   } else {
706     if (default_value != nullptr) {
707       Py_INCREF(default_value);
708       return default_value;
709     } else {
710       Py_RETURN_NONE;
711     }
712   }
713 }
714 
MessageMapDealloc(PyObject * _self)715 static void MessageMapDealloc(PyObject* _self) {
716   MessageMapContainer* self = GetMessageMap(_self);
717   self->RemoveFromParentCache();
718   Py_DECREF(self->message_class);
719   PyTypeObject *type = Py_TYPE(_self);
720   type->tp_free(_self);
721   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
722     // With Python3, the Map class is not static, and must be managed.
723     Py_DECREF(type);
724   }
725 }
726 
727 static PyMethodDef MessageMapMethods[] = {
728     {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
729      "Tests whether the map contains this element."},
730     {"clear", (PyCFunction)Clear, METH_NOARGS,
731      "Removes all elements from the map."},
732     {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
733      "Gets the value for the given key if present, or otherwise a default"},
734     {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
735      "Alias for getitem, useful to make explicit that the map is mutated."},
736     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
737      "Return the class used to build Entries of (key, value) pairs."},
738     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
739      "Merges a map into the current map."},
740     /*
741     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
742       "Makes a deep copy of the class." },
743     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
744       "Outputs picklable representation of the repeated field." },
745     */
746     {nullptr, nullptr},
747 };
748 
749 PyTypeObject* MessageMapContainer_Type;
750 static PyType_Slot MessageMapContainer_Type_slots[] = {
751     {Py_tp_dealloc, (void*)MessageMapDealloc},
752     {Py_mp_length, (void*)MapReflectionFriend::Length},
753     {Py_mp_subscript, (void*)MapReflectionFriend::MessageMapGetItem},
754     {Py_mp_ass_subscript, (void*)MapReflectionFriend::MessageMapSetItem},
755     {Py_tp_methods, (void*)MessageMapMethods},
756     {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
757     {Py_tp_repr, (void*)MapReflectionFriend::MessageMapToStr},
758     {0, nullptr}};
759 
760 PyType_Spec MessageMapContainer_Type_spec = {
761     FULL_MODULE_NAME ".MessageMapContainer", sizeof(MessageMapContainer), 0,
762     Py_TPFLAGS_DEFAULT, MessageMapContainer_Type_slots};
763 
764 // MapIterator /////////////////////////////////////////////////////////////////
765 
GetIter(PyObject * obj)766 static MapIterator* GetIter(PyObject* obj) {
767   return reinterpret_cast<MapIterator*>(obj);
768 }
769 
GetIterator(PyObject * _self)770 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
771   MapContainer* self = GetMap(_self);
772 
773   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
774   if (obj == nullptr) {
775     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
776   }
777 
778   MapIterator* iter = GetIter(obj.get());
779 
780   Py_INCREF(self);
781   iter->container = self;
782   iter->version = self->version;
783   Py_INCREF(self->parent);
784   iter->parent = self->parent;
785 
786   if (MapReflectionFriend::Length(_self) > 0) {
787     Message* message = self->GetMutableMessage();
788     const Reflection* reflection = message->GetReflection();
789 
790     iter->iter.reset(new ::google::protobuf::MapIterator(
791         reflection->MapBegin(message, self->parent_field_descriptor)));
792   }
793 
794   return obj.release();
795 }
796 
IterNext(PyObject * _self)797 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
798   MapIterator* self = GetIter(_self);
799 
800   // This won't catch mutations to the map performed by MergeFrom(); no easy way
801   // to address that.
802   if (self->version != self->container->version) {
803     return PyErr_Format(PyExc_RuntimeError,
804                         "Map modified during iteration.");
805   }
806   if (self->parent != self->container->parent) {
807     return PyErr_Format(PyExc_RuntimeError,
808                         "Map cleared during iteration.");
809   }
810 
811   if (self->iter.get() == nullptr) {
812     return nullptr;
813   }
814 
815   Message* message = self->container->GetMutableMessage();
816   const Reflection* reflection = message->GetReflection();
817 
818   if (*self->iter ==
819       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
820     return nullptr;
821   }
822 
823   PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
824 
825   ++(*self->iter);
826 
827   return ret;
828 }
829 
DeallocMapIterator(PyObject * _self)830 static void DeallocMapIterator(PyObject* _self) {
831   MapIterator* self = GetIter(_self);
832   self->iter.reset();
833   Py_CLEAR(self->container);
834   Py_CLEAR(self->parent);
835   Py_TYPE(_self)->tp_free(_self);
836 }
837 
// Static type object for the iterator returned by GetIterator(), which serves
// as tp_iter for both map container types. InitMapContainers() readies it with
// PyType_Ready().
//
// NOTE(review): the flags do not include Py_TPFLAGS_HAVE_GC and there is no
// tp_traverse. The references held to the container and parent message are
// therefore invisible to the cyclic garbage collector.
// NOTE(review): tp_doc says "scalar", but message maps use this type as well.
PyTypeObject MapIterator_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
    ".MapIterator",       //  tp_name
    sizeof(MapIterator),  //  tp_basicsize
    0,                    //  tp_itemsize
    DeallocMapIterator,   //  tp_dealloc
#if PY_VERSION_HEX < 0x03080000
    nullptr,  // tp_print
#else
    0,  // tp_vectorcall_offset
#endif
    nullptr,                        //  tp_getattr
    nullptr,                        //  tp_setattr
    nullptr,                        //  tp_as_async (formerly tp_compare)
    nullptr,                        //  tp_repr
    nullptr,                        //  tp_as_number
    nullptr,                        //  tp_as_sequence
    nullptr,                        //  tp_as_mapping
    nullptr,                        //  tp_hash
    nullptr,                        //  tp_call
    nullptr,                        //  tp_str
    nullptr,                        //  tp_getattro
    nullptr,                        //  tp_setattro
    nullptr,                        //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,             //  tp_flags
    "A scalar map iterator",        //  tp_doc
    nullptr,                        //  tp_traverse
    nullptr,                        //  tp_clear
    nullptr,                        //  tp_richcompare
    0,                              //  tp_weaklistoffset
    PyObject_SelfIter,              //  tp_iter: an iterator is its own iterable
    MapReflectionFriend::IterNext,  //  tp_iternext
    nullptr,                        //  tp_methods
    nullptr,                        //  tp_members
    nullptr,                        //  tp_getset
    nullptr,                        //  tp_base
    nullptr,                        //  tp_dict
    nullptr,                        //  tp_descr_get
    nullptr,                        //  tp_descr_set
    0,                              //  tp_dictoffset
    nullptr,                        //  tp_init
};
880 
InitMapContainers()881 bool InitMapContainers() {
882   // ScalarMapContainer_Type derives from our MutableMapping type.
883   ScopedPyObjectPtr abc(PyImport_ImportModule("collections.abc"));
884   if (abc == nullptr) {
885     return false;
886   }
887 
888   ScopedPyObjectPtr mutable_mapping(
889       PyObject_GetAttrString(abc.get(), "MutableMapping"));
890   if (mutable_mapping == nullptr) {
891     return false;
892   }
893 
894   Py_INCREF(mutable_mapping.get());
895   ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
896   if (bases == nullptr) {
897     return false;
898   }
899 
900   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
901       PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
902 
903   if (PyType_Ready(&MapIterator_Type) < 0) {
904     return false;
905   }
906 
907   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
908       PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
909   return true;
910 }
911 
912 }  // namespace python
913 }  // namespace protobuf
914 }  // namespace google
915