• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <memory>
36 
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/map_field.h>
40 #include <google/protobuf/map.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/pyext/message_factory.h>
43 #include <google/protobuf/pyext/message.h>
44 #include <google/protobuf/pyext/repeated_composite_container.h>
45 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
46 #include <google/protobuf/stubs/map_util.h>
47 
48 #if PY_MAJOR_VERSION >= 3
49   #define PyInt_FromLong PyLong_FromLong
50   #define PyInt_FromSize_t PyLong_FromSize_t
51 #endif
52 
53 namespace google {
54 namespace protobuf {
55 namespace python {
56 
57 // Functions that need access to map reflection functionality.
58 // They need to be contained in this class because it is friended.
59 class MapReflectionFriend {
60  public:
61   // Methods that are in common between the map types.
62   static PyObject* Contains(PyObject* _self, PyObject* key);
63   static Py_ssize_t Length(PyObject* _self);
64   static PyObject* GetIterator(PyObject *_self);
65   static PyObject* IterNext(PyObject* _self);
66   static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
67 
68   // Methods that differ between the map types.
69   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
70   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
71   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73   static PyObject* ScalarMapToStr(PyObject* _self);
74   static PyObject* MessageMapToStr(PyObject* _self);
75 };
76 
77 struct MapIterator {
78   PyObject_HEAD;
79 
80   std::unique_ptr<::google::protobuf::MapIterator> iter;
81 
82   // A pointer back to the container, so we can notice changes to the version.
83   // We own a ref on this.
84   MapContainer* container;
85 
86   // We need to keep a ref on the parent Message too, because
87   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
88   // the ref on container (above) would guarantee outlive semantics.  However in
89   // the case of ClearField(), the MapContainer points to a different message,
90   // a copy of the original.  But our iterator still points to the original,
91   // which could now get deleted before us.
92   //
93   // To prevent this, we ensure that the Message will always stay alive as long
94   // as this iterator does.  This is solely for the benefit of the MapIterator
95   // destructor -- we should never actually access the iterator in this state
96   // except to delete it.
97   CMessage* parent;
98   // The version of the map when we took the iterator to it.
99   //
100   // We store this so that if the map is modified during iteration we can throw
101   // an error.
102   uint64 version;
103 };
104 
GetMutableMessage()105 Message* MapContainer::GetMutableMessage() {
106   cmessage::AssureWritable(parent);
107   return parent->message;
108 }
109 
110 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,std::string * stl_string)111 static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
112   char *value;
113   Py_ssize_t value_len;
114 
115   if (!py_string) {
116     return false;
117   }
118   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
119     Py_DECREF(py_string);
120     return false;
121   } else {
122     stl_string->assign(value, value_len);
123     Py_DECREF(py_string);
124     return true;
125   }
126 }
127 
PythonToMapKey(MapContainer * self,PyObject * obj,MapKey * key)128 static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key) {
129   const FieldDescriptor* field_descriptor =
130       self->parent_field_descriptor->message_type()->map_key();
131   switch (field_descriptor->cpp_type()) {
132     case FieldDescriptor::CPPTYPE_INT32: {
133       GOOGLE_CHECK_GET_INT32(obj, value, false);
134       key->SetInt32Value(value);
135       break;
136     }
137     case FieldDescriptor::CPPTYPE_INT64: {
138       GOOGLE_CHECK_GET_INT64(obj, value, false);
139       key->SetInt64Value(value);
140       break;
141     }
142     case FieldDescriptor::CPPTYPE_UINT32: {
143       GOOGLE_CHECK_GET_UINT32(obj, value, false);
144       key->SetUInt32Value(value);
145       break;
146     }
147     case FieldDescriptor::CPPTYPE_UINT64: {
148       GOOGLE_CHECK_GET_UINT64(obj, value, false);
149       key->SetUInt64Value(value);
150       break;
151     }
152     case FieldDescriptor::CPPTYPE_BOOL: {
153       GOOGLE_CHECK_GET_BOOL(obj, value, false);
154       key->SetBoolValue(value);
155       break;
156     }
157     case FieldDescriptor::CPPTYPE_STRING: {
158       std::string str;
159       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
160         return false;
161       }
162       key->SetStringValue(str);
163       break;
164     }
165     default:
166       PyErr_Format(
167           PyExc_SystemError, "Type %d cannot be a map key",
168           field_descriptor->cpp_type());
169       return false;
170   }
171   return true;
172 }
173 
MapKeyToPython(MapContainer * self,const MapKey & key)174 static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
175   const FieldDescriptor* field_descriptor =
176       self->parent_field_descriptor->message_type()->map_key();
177   switch (field_descriptor->cpp_type()) {
178     case FieldDescriptor::CPPTYPE_INT32:
179       return PyInt_FromLong(key.GetInt32Value());
180     case FieldDescriptor::CPPTYPE_INT64:
181       return PyLong_FromLongLong(key.GetInt64Value());
182     case FieldDescriptor::CPPTYPE_UINT32:
183       return PyInt_FromSize_t(key.GetUInt32Value());
184     case FieldDescriptor::CPPTYPE_UINT64:
185       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
186     case FieldDescriptor::CPPTYPE_BOOL:
187       return PyBool_FromLong(key.GetBoolValue());
188     case FieldDescriptor::CPPTYPE_STRING:
189       return ToStringObject(field_descriptor, key.GetStringValue());
190     default:
191       PyErr_Format(
192           PyExc_SystemError, "Couldn't convert type %d to value",
193           field_descriptor->cpp_type());
194       return NULL;
195   }
196 }
197 
198 // This is only used for ScalarMap, so we don't need to handle the
199 // CPPTYPE_MESSAGE case.
MapValueRefToPython(MapContainer * self,const MapValueRef & value)200 PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
201   const FieldDescriptor* field_descriptor =
202       self->parent_field_descriptor->message_type()->map_value();
203   switch (field_descriptor->cpp_type()) {
204     case FieldDescriptor::CPPTYPE_INT32:
205       return PyInt_FromLong(value.GetInt32Value());
206     case FieldDescriptor::CPPTYPE_INT64:
207       return PyLong_FromLongLong(value.GetInt64Value());
208     case FieldDescriptor::CPPTYPE_UINT32:
209       return PyInt_FromSize_t(value.GetUInt32Value());
210     case FieldDescriptor::CPPTYPE_UINT64:
211       return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
212     case FieldDescriptor::CPPTYPE_FLOAT:
213       return PyFloat_FromDouble(value.GetFloatValue());
214     case FieldDescriptor::CPPTYPE_DOUBLE:
215       return PyFloat_FromDouble(value.GetDoubleValue());
216     case FieldDescriptor::CPPTYPE_BOOL:
217       return PyBool_FromLong(value.GetBoolValue());
218     case FieldDescriptor::CPPTYPE_STRING:
219       return ToStringObject(field_descriptor, value.GetStringValue());
220     case FieldDescriptor::CPPTYPE_ENUM:
221       return PyInt_FromLong(value.GetEnumValue());
222     default:
223       PyErr_Format(
224           PyExc_SystemError, "Couldn't convert type %d to value",
225           field_descriptor->cpp_type());
226       return NULL;
227   }
228 }
229 
230 // This is only used for ScalarMap, so we don't need to handle the
231 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(MapContainer * self,PyObject * obj,bool allow_unknown_enum_values,MapValueRef * value_ref)232 static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
233                                 bool allow_unknown_enum_values,
234                                 MapValueRef* value_ref) {
235   const FieldDescriptor* field_descriptor =
236       self->parent_field_descriptor->message_type()->map_value();
237   switch (field_descriptor->cpp_type()) {
238     case FieldDescriptor::CPPTYPE_INT32: {
239       GOOGLE_CHECK_GET_INT32(obj, value, false);
240       value_ref->SetInt32Value(value);
241       return true;
242     }
243     case FieldDescriptor::CPPTYPE_INT64: {
244       GOOGLE_CHECK_GET_INT64(obj, value, false);
245       value_ref->SetInt64Value(value);
246       return true;
247     }
248     case FieldDescriptor::CPPTYPE_UINT32: {
249       GOOGLE_CHECK_GET_UINT32(obj, value, false);
250       value_ref->SetUInt32Value(value);
251       return true;
252     }
253     case FieldDescriptor::CPPTYPE_UINT64: {
254       GOOGLE_CHECK_GET_UINT64(obj, value, false);
255       value_ref->SetUInt64Value(value);
256       return true;
257     }
258     case FieldDescriptor::CPPTYPE_FLOAT: {
259       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
260       value_ref->SetFloatValue(value);
261       return true;
262     }
263     case FieldDescriptor::CPPTYPE_DOUBLE: {
264       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
265       value_ref->SetDoubleValue(value);
266       return true;
267     }
268     case FieldDescriptor::CPPTYPE_BOOL: {
269       GOOGLE_CHECK_GET_BOOL(obj, value, false);
270       value_ref->SetBoolValue(value);
271       return true;;
272     }
273     case FieldDescriptor::CPPTYPE_STRING: {
274       std::string str;
275       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
276         return false;
277       }
278       value_ref->SetStringValue(str);
279       return true;
280     }
281     case FieldDescriptor::CPPTYPE_ENUM: {
282       GOOGLE_CHECK_GET_INT32(obj, value, false);
283       if (allow_unknown_enum_values) {
284         value_ref->SetEnumValue(value);
285         return true;
286       } else {
287         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
288         const EnumValueDescriptor* enum_value =
289             enum_descriptor->FindValueByNumber(value);
290         if (enum_value != NULL) {
291           value_ref->SetEnumValue(value);
292           return true;
293         } else {
294           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
295           return false;
296         }
297       }
298       break;
299     }
300     default:
301       PyErr_Format(
302           PyExc_SystemError, "Setting value to a field of unknown type %d",
303           field_descriptor->cpp_type());
304       return false;
305   }
306 }
307 
308 // Map methods common to ScalarMap and MessageMap //////////////////////////////
309 
GetMap(PyObject * obj)310 static MapContainer* GetMap(PyObject* obj) {
311   return reinterpret_cast<MapContainer*>(obj);
312 }
313 
Length(PyObject * _self)314 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
315   MapContainer* self = GetMap(_self);
316   const google::protobuf::Message* message = self->parent->message;
317   return message->GetReflection()->MapSize(*message,
318                                            self->parent_field_descriptor);
319 }
320 
Clear(PyObject * _self)321 PyObject* Clear(PyObject* _self) {
322   MapContainer* self = GetMap(_self);
323   Message* message = self->GetMutableMessage();
324   const Reflection* reflection = message->GetReflection();
325 
326   reflection->ClearField(message, self->parent_field_descriptor);
327 
328   Py_RETURN_NONE;
329 }
330 
GetEntryClass(PyObject * _self)331 PyObject* GetEntryClass(PyObject* _self) {
332   MapContainer* self = GetMap(_self);
333   CMessageClass* message_class = message_factory::GetMessageClass(
334       cmessage::GetFactoryForMessage(self->parent),
335       self->parent_field_descriptor->message_type());
336   Py_XINCREF(message_class);
337   return reinterpret_cast<PyObject*>(message_class);
338 }
339 
MergeFrom(PyObject * _self,PyObject * arg)340 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
341   MapContainer* self = GetMap(_self);
342   MapContainer* other_map = GetMap(arg);
343   Message* message = self->GetMutableMessage();
344   const Message* other_message = other_map->parent->message;
345   const Reflection* reflection = message->GetReflection();
346   const Reflection* other_reflection = other_message->GetReflection();
347   internal::MapFieldBase* field = reflection->MutableMapData(
348       message, self->parent_field_descriptor);
349   const internal::MapFieldBase* other_field = other_reflection->GetMapData(
350       *other_message, other_map->parent_field_descriptor);
351   field->MergeFrom(*other_field);
352   self->version++;
353   Py_RETURN_NONE;
354 }
355 
Contains(PyObject * _self,PyObject * key)356 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
357   MapContainer* self = GetMap(_self);
358 
359   const Message* message = self->parent->message;
360   const Reflection* reflection = message->GetReflection();
361   MapKey map_key;
362 
363   if (!PythonToMapKey(self, key, &map_key)) {
364     return NULL;
365   }
366 
367   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
368                                  map_key)) {
369     Py_RETURN_TRUE;
370   } else {
371     Py_RETURN_FALSE;
372   }
373 }
374 
375 // ScalarMap ///////////////////////////////////////////////////////////////////
376 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)377 MapContainer* NewScalarMapContainer(
378     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
379   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
380     return NULL;
381   }
382 
383   PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
384   if (obj == NULL) {
385     PyErr_Format(PyExc_RuntimeError,
386                  "Could not allocate new container.");
387     return NULL;
388   }
389 
390   MapContainer* self = GetMap(obj);
391 
392   Py_INCREF(parent);
393   self->parent = parent;
394   self->parent_field_descriptor = parent_field_descriptor;
395   self->version = 0;
396 
397   return self;
398 }
399 
ScalarMapGetItem(PyObject * _self,PyObject * key)400 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
401                                                 PyObject* key) {
402   MapContainer* self = GetMap(_self);
403 
404   Message* message = self->GetMutableMessage();
405   const Reflection* reflection = message->GetReflection();
406   MapKey map_key;
407   MapValueRef value;
408 
409   if (!PythonToMapKey(self, key, &map_key)) {
410     return NULL;
411   }
412 
413   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
414                                          map_key, &value)) {
415     self->version++;
416   }
417 
418   return MapValueRefToPython(self, value);
419 }
420 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)421 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
422                                           PyObject* v) {
423   MapContainer* self = GetMap(_self);
424 
425   Message* message = self->GetMutableMessage();
426   const Reflection* reflection = message->GetReflection();
427   MapKey map_key;
428   MapValueRef value;
429 
430   if (!PythonToMapKey(self, key, &map_key)) {
431     return -1;
432   }
433 
434   self->version++;
435 
436   if (v) {
437     // Set item to v.
438     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
439                                        map_key, &value);
440 
441     if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(),
442                              &value)) {
443       return -1;
444     }
445     return 0;
446   } else {
447     // Delete key from map.
448     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
449                                    map_key)) {
450       return 0;
451     } else {
452       PyErr_Format(PyExc_KeyError, "Key not present in map");
453       return -1;
454     }
455   }
456 }
457 
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)458 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
459                               PyObject* kwargs) {
460   static char* kwlist[] = {"key", "default", nullptr};
461   PyObject* key;
462   PyObject* default_value = NULL;
463   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
464                                    &default_value)) {
465     return NULL;
466   }
467 
468   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
469   if (is_present.get() == NULL) {
470     return NULL;
471   }
472 
473   if (PyObject_IsTrue(is_present.get())) {
474     return MapReflectionFriend::ScalarMapGetItem(self, key);
475   } else {
476     if (default_value != NULL) {
477       Py_INCREF(default_value);
478       return default_value;
479     } else {
480       Py_RETURN_NONE;
481     }
482   }
483 }
484 
ScalarMapToStr(PyObject * _self)485 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
486   ScopedPyObjectPtr dict(PyDict_New());
487   if (dict == NULL) {
488     return NULL;
489   }
490   ScopedPyObjectPtr key;
491   ScopedPyObjectPtr value;
492 
493   MapContainer* self = GetMap(_self);
494   Message* message = self->GetMutableMessage();
495   const Reflection* reflection = message->GetReflection();
496   for (google::protobuf::MapIterator it = reflection->MapBegin(
497            message, self->parent_field_descriptor);
498        it != reflection->MapEnd(message, self->parent_field_descriptor);
499        ++it) {
500     key.reset(MapKeyToPython(self, it.GetKey()));
501     if (key == NULL) {
502       return NULL;
503     }
504     value.reset(MapValueRefToPython(self, it.GetValueRef()));
505     if (value == NULL) {
506       return NULL;
507     }
508     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
509       return NULL;
510     }
511   }
512   return PyObject_Repr(dict.get());
513 }
514 
ScalarMapDealloc(PyObject * _self)515 static void ScalarMapDealloc(PyObject* _self) {
516   MapContainer* self = GetMap(_self);
517   self->RemoveFromParentCache();
518   PyTypeObject *type = Py_TYPE(_self);
519   type->tp_free(_self);
520   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
521     // With Python3, the Map class is not static, and must be managed.
522     Py_DECREF(type);
523   }
524 }
525 
526 static PyMethodDef ScalarMapMethods[] = {
527     {"__contains__", MapReflectionFriend::Contains, METH_O,
528      "Tests whether a key is a member of the map."},
529     {"clear", (PyCFunction)Clear, METH_NOARGS,
530      "Removes all elements from the map."},
531     {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
532      "Gets the value for the given key if present, or otherwise a default"},
533     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
534      "Return the class used to build Entries of (key, value) pairs."},
535     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
536      "Merges a map into the current map."},
537     /*
538     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
539       "Makes a deep copy of the class." },
540     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
541       "Outputs picklable representation of the repeated field." },
542     */
543     {NULL, NULL},
544 };
545 
546 PyTypeObject *ScalarMapContainer_Type;
547 #if PY_MAJOR_VERSION >= 3
548   static PyType_Slot ScalarMapContainer_Type_slots[] = {
549       {Py_tp_dealloc, (void *)ScalarMapDealloc},
550       {Py_mp_length, (void *)MapReflectionFriend::Length},
551       {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
552       {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
553       {Py_tp_methods, (void *)ScalarMapMethods},
554       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
555       {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
556       {0, 0},
557   };
558 
559   PyType_Spec ScalarMapContainer_Type_spec = {
560       FULL_MODULE_NAME ".ScalarMapContainer",
561       sizeof(MapContainer),
562       0,
563       Py_TPFLAGS_DEFAULT,
564       ScalarMapContainer_Type_slots
565   };
566 #else
567   static PyMappingMethods ScalarMapMappingMethods = {
568     MapReflectionFriend::Length,             // mp_length
569     MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
570     MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
571   };
572 
573   PyTypeObject _ScalarMapContainer_Type = {
574     PyVarObject_HEAD_INIT(&PyType_Type, 0)
575     FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
576     sizeof(MapContainer),                //  tp_basicsize
577     0,                                   //  tp_itemsize
578     ScalarMapDealloc,                    //  tp_dealloc
579     0,                                   //  tp_print
580     0,                                   //  tp_getattr
581     0,                                   //  tp_setattr
582     0,                                   //  tp_compare
583     MapReflectionFriend::ScalarMapToStr,  //  tp_repr
584     0,                                   //  tp_as_number
585     0,                                   //  tp_as_sequence
586     &ScalarMapMappingMethods,            //  tp_as_mapping
587     0,                                   //  tp_hash
588     0,                                   //  tp_call
589     0,                                   //  tp_str
590     0,                                   //  tp_getattro
591     0,                                   //  tp_setattro
592     0,                                   //  tp_as_buffer
593     Py_TPFLAGS_DEFAULT,                  //  tp_flags
594     "A scalar map container",            //  tp_doc
595     0,                                   //  tp_traverse
596     0,                                   //  tp_clear
597     0,                                   //  tp_richcompare
598     0,                                   //  tp_weaklistoffset
599     MapReflectionFriend::GetIterator,    //  tp_iter
600     0,                                   //  tp_iternext
601     ScalarMapMethods,                    //  tp_methods
602     0,                                   //  tp_members
603     0,                                   //  tp_getset
604     0,                                   //  tp_base
605     0,                                   //  tp_dict
606     0,                                   //  tp_descr_get
607     0,                                   //  tp_descr_set
608     0,                                   //  tp_dictoffset
609     0,                                   //  tp_init
610   };
611 #endif
612 
613 
614 // MessageMap //////////////////////////////////////////////////////////////////
615 
GetMessageMap(PyObject * obj)616 static MessageMapContainer* GetMessageMap(PyObject* obj) {
617   return reinterpret_cast<MessageMapContainer*>(obj);
618 }
619 
GetCMessage(MessageMapContainer * self,Message * message)620 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
621   // Get or create the CMessage object corresponding to this message.
622   return self->parent
623       ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
624                                    self->message_class)
625       ->AsPyObject();
626 }
627 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)628 MessageMapContainer* NewMessageMapContainer(
629     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
630     CMessageClass* message_class) {
631   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
632     return NULL;
633   }
634 
635   PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
636   if (obj == NULL) {
637     PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
638     return NULL;
639   }
640 
641   MessageMapContainer* self = GetMessageMap(obj);
642 
643   Py_INCREF(parent);
644   self->parent = parent;
645   self->parent_field_descriptor = parent_field_descriptor;
646   self->version = 0;
647 
648   Py_INCREF(message_class);
649   self->message_class = message_class;
650 
651   return self;
652 }
653 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)654 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
655                                            PyObject* v) {
656   if (v) {
657     PyErr_Format(PyExc_ValueError,
658                  "Direct assignment of submessage not allowed");
659     return -1;
660   }
661 
662   // Now we know that this is a delete, not a set.
663 
664   MessageMapContainer* self = GetMessageMap(_self);
665   Message* message = self->GetMutableMessage();
666   const Reflection* reflection = message->GetReflection();
667   MapKey map_key;
668   MapValueRef value;
669 
670   self->version++;
671 
672   if (!PythonToMapKey(self, key, &map_key)) {
673     return -1;
674   }
675 
676   // Delete key from map.
677   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
678                                  map_key)) {
679     // Delete key from CMessage dict.
680     MapValueRef value;
681     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
682                                        map_key, &value);
683     Message* sub_message = value.MutableMessageValue();
684     // If there is a living weak reference to an item, we "Release" it,
685     // otherwise we just discard the C++ value.
686     if (CMessage* released =
687             self->parent->MaybeReleaseSubMessage(sub_message)) {
688       Message* msg = released->message;
689       released->message = msg->New();
690       msg->GetReflection()->Swap(msg, released->message);
691     }
692 
693     // Delete key from map.
694     reflection->DeleteMapValue(message, self->parent_field_descriptor,
695                                map_key);
696     return 0;
697   } else {
698     PyErr_Format(PyExc_KeyError, "Key not present in map");
699     return -1;
700   }
701 }
702 
MessageMapGetItem(PyObject * _self,PyObject * key)703 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
704                                                  PyObject* key) {
705   MessageMapContainer* self = GetMessageMap(_self);
706 
707   Message* message = self->GetMutableMessage();
708   const Reflection* reflection = message->GetReflection();
709   MapKey map_key;
710   MapValueRef value;
711 
712   if (!PythonToMapKey(self, key, &map_key)) {
713     return NULL;
714   }
715 
716   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
717                                          map_key, &value)) {
718     self->version++;
719   }
720 
721   return GetCMessage(self, value.MutableMessageValue());
722 }
723 
MessageMapToStr(PyObject * _self)724 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
725   ScopedPyObjectPtr dict(PyDict_New());
726   if (dict == NULL) {
727     return NULL;
728   }
729   ScopedPyObjectPtr key;
730   ScopedPyObjectPtr value;
731 
732   MessageMapContainer* self = GetMessageMap(_self);
733   Message* message = self->GetMutableMessage();
734   const Reflection* reflection = message->GetReflection();
735   for (google::protobuf::MapIterator it = reflection->MapBegin(
736            message, self->parent_field_descriptor);
737        it != reflection->MapEnd(message, self->parent_field_descriptor);
738        ++it) {
739     key.reset(MapKeyToPython(self, it.GetKey()));
740     if (key == NULL) {
741       return NULL;
742     }
743     value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
744     if (value == NULL) {
745       return NULL;
746     }
747     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
748       return NULL;
749     }
750   }
751   return PyObject_Repr(dict.get());
752 }
753 
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)754 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
755   static char* kwlist[] = {"key", "default", nullptr};
756   PyObject* key;
757   PyObject* default_value = NULL;
758   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
759                                    &default_value)) {
760     return NULL;
761   }
762 
763   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
764   if (is_present.get() == NULL) {
765     return NULL;
766   }
767 
768   if (PyObject_IsTrue(is_present.get())) {
769     return MapReflectionFriend::MessageMapGetItem(self, key);
770   } else {
771     if (default_value != NULL) {
772       Py_INCREF(default_value);
773       return default_value;
774     } else {
775       Py_RETURN_NONE;
776     }
777   }
778 }
779 
MessageMapDealloc(PyObject * _self)780 static void MessageMapDealloc(PyObject* _self) {
781   MessageMapContainer* self = GetMessageMap(_self);
782   self->RemoveFromParentCache();
783   Py_DECREF(self->message_class);
784   PyTypeObject *type = Py_TYPE(_self);
785   type->tp_free(_self);
786   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
787     // With Python3, the Map class is not static, and must be managed.
788     Py_DECREF(type);
789   }
790 }
791 
792 static PyMethodDef MessageMapMethods[] = {
793     {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
794      "Tests whether the map contains this element."},
795     {"clear", (PyCFunction)Clear, METH_NOARGS,
796      "Removes all elements from the map."},
797     {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
798      "Gets the value for the given key if present, or otherwise a default"},
799     {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
800      "Alias for getitem, useful to make explicit that the map is mutated."},
801     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
802      "Return the class used to build Entries of (key, value) pairs."},
803     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
804      "Merges a map into the current map."},
805     /*
806     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
807       "Makes a deep copy of the class." },
808     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
809       "Outputs picklable representation of the repeated field." },
810     */
811     {NULL, NULL},
812 };
813 
// Type object for MessageMapContainer.
//
// On Python 3 this pointer is filled in by InitMapContainers(), which builds a
// heap type from MessageMapContainer_Type_spec with MutableMapping as its
// base; until then it is NULL. On Python 2 InitMapContainers() points it at
// the statically allocated _MessageMapContainer_Type below.
PyTypeObject *MessageMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
  // Slots for the Python 3 heap type. NOTE(review): unlike the Python 2 type
  // below, no Py_tp_doc slot is set, so the Python 3 type has no docstring.
  static PyType_Slot MessageMapContainer_Type_slots[] = {
      {Py_tp_dealloc, (void *)MessageMapDealloc},
      {Py_mp_length, (void *)MapReflectionFriend::Length},
      {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
      {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
      {Py_tp_methods, (void *)MessageMapMethods},
      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
      {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
      {0, 0}  // Sentinel.
  };

  // Spec passed to PyType_FromSpecWithBases() by InitMapContainers().
  PyType_Spec MessageMapContainer_Type_spec = {
      FULL_MODULE_NAME ".MessageMapContainer",  // name
      sizeof(MessageMapContainer),              // basicsize
      0,                                        // itemsize
      Py_TPFLAGS_DEFAULT,                       // flags
      MessageMapContainer_Type_slots            // slots
  };
#else
  // Mapping protocol for the Python 2 static type; the same functions are
  // installed via the Py_mp_* slots on Python 3.
  static PyMappingMethods MessageMapMappingMethods = {
    MapReflectionFriend::Length,              // mp_length
    MapReflectionFriend::MessageMapGetItem,   // mp_subscript
    MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
  };

  // Statically allocated Python 2 type. Its tp_base (MutableMapping) is set
  // and the type readied in InitMapContainers(). Fields are initialized
  // positionally, so their order must match PyTypeObject exactly.
  PyTypeObject _MessageMapContainer_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
    sizeof(MessageMapContainer),         //  tp_basicsize
    0,                                   //  tp_itemsize
    MessageMapDealloc,                   //  tp_dealloc
    0,                                   //  tp_print
    0,                                   //  tp_getattr
    0,                                   //  tp_setattr
    0,                                   //  tp_compare
    MapReflectionFriend::MessageMapToStr,  //  tp_repr
    0,                                   //  tp_as_number
    0,                                   //  tp_as_sequence
    &MessageMapMappingMethods,           //  tp_as_mapping
    0,                                   //  tp_hash
    0,                                   //  tp_call
    0,                                   //  tp_str
    0,                                   //  tp_getattro
    0,                                   //  tp_setattro
    0,                                   //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,                  //  tp_flags
    "A map container for message",       //  tp_doc
    0,                                   //  tp_traverse
    0,                                   //  tp_clear
    0,                                   //  tp_richcompare
    0,                                   //  tp_weaklistoffset
    MapReflectionFriend::GetIterator,    //  tp_iter
    0,                                   //  tp_iternext
    MessageMapMethods,                   //  tp_methods
    0,                                   //  tp_members
    0,                                   //  tp_getset
    0,                                   //  tp_base
    0,                                   //  tp_dict
    0,                                   //  tp_descr_get
    0,                                   //  tp_descr_set
    0,                                   //  tp_dictoffset
    0,                                   //  tp_init
  };
#endif
880 
881 // MapIterator /////////////////////////////////////////////////////////////////
882 
GetIter(PyObject * obj)883 static MapIterator* GetIter(PyObject* obj) {
884   return reinterpret_cast<MapIterator*>(obj);
885 }
886 
GetIterator(PyObject * _self)887 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
888   MapContainer* self = GetMap(_self);
889 
890   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
891   if (obj == NULL) {
892     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
893   }
894 
895   MapIterator* iter = GetIter(obj.get());
896 
897   Py_INCREF(self);
898   iter->container = self;
899   iter->version = self->version;
900   Py_INCREF(self->parent);
901   iter->parent = self->parent;
902 
903   if (MapReflectionFriend::Length(_self) > 0) {
904     Message* message = self->GetMutableMessage();
905     const Reflection* reflection = message->GetReflection();
906 
907     iter->iter.reset(new ::google::protobuf::MapIterator(
908         reflection->MapBegin(message, self->parent_field_descriptor)));
909   }
910 
911   return obj.release();
912 }
913 
IterNext(PyObject * _self)914 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
915   MapIterator* self = GetIter(_self);
916 
917   // This won't catch mutations to the map performed by MergeFrom(); no easy way
918   // to address that.
919   if (self->version != self->container->version) {
920     return PyErr_Format(PyExc_RuntimeError,
921                         "Map modified during iteration.");
922   }
923   if (self->parent != self->container->parent) {
924     return PyErr_Format(PyExc_RuntimeError,
925                         "Map cleared during iteration.");
926   }
927 
928   if (self->iter.get() == NULL) {
929     return NULL;
930   }
931 
932   Message* message = self->container->GetMutableMessage();
933   const Reflection* reflection = message->GetReflection();
934 
935   if (*self->iter ==
936       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
937     return NULL;
938   }
939 
940   PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
941 
942   ++(*self->iter);
943 
944   return ret;
945 }
946 
DeallocMapIterator(PyObject * _self)947 static void DeallocMapIterator(PyObject* _self) {
948   MapIterator* self = GetIter(_self);
949   self->iter.reset();
950   Py_CLEAR(self->container);
951   Py_CLEAR(self->parent);
952   Py_TYPE(_self)->tp_free(_self);
953 }
954 
// Type object for the iterator returned by iter() on map containers
// (MapReflectionFriend::GetIterator allocates instances of it). It is
// statically allocated on both Python 2 and Python 3 and is readied by
// InitMapContainers(). It is its own iterator (PyObject_SelfIter) and yields
// map keys through MapReflectionFriend::IterNext.
//
// NOTE(review): tp_doc says "A scalar map iterator", but GetIterator is also
// the tp_iter of MessageMapContainer, so this type iterates message maps too.
// Fields are initialized positionally; the comments use the Python 2 field
// names (on Python 3 some positions, e.g. tp_print/tp_compare, have other
// names, but those entries are 0 here).
PyTypeObject MapIterator_Type = {
  PyVarObject_HEAD_INIT(&PyType_Type, 0)
  FULL_MODULE_NAME ".MapIterator",     //  tp_name
  sizeof(MapIterator),                 //  tp_basicsize
  0,                                   //  tp_itemsize
  DeallocMapIterator,                  //  tp_dealloc
  0,                                   //  tp_print
  0,                                   //  tp_getattr
  0,                                   //  tp_setattr
  0,                                   //  tp_compare
  0,                                   //  tp_repr
  0,                                   //  tp_as_number
  0,                                   //  tp_as_sequence
  0,                                   //  tp_as_mapping
  0,                                   //  tp_hash
  0,                                   //  tp_call
  0,                                   //  tp_str
  0,                                   //  tp_getattro
  0,                                   //  tp_setattro
  0,                                   //  tp_as_buffer
  Py_TPFLAGS_DEFAULT,                  //  tp_flags
  "A scalar map iterator",             //  tp_doc
  0,                                   //  tp_traverse
  0,                                   //  tp_clear
  0,                                   //  tp_richcompare
  0,                                   //  tp_weaklistoffset
  PyObject_SelfIter,                   //  tp_iter
  MapReflectionFriend::IterNext,       //  tp_iternext
  0,                                   //  tp_methods
  0,                                   //  tp_members
  0,                                   //  tp_getset
  0,                                   //  tp_base
  0,                                   //  tp_dict
  0,                                   //  tp_descr_get
  0,                                   //  tp_descr_set
  0,                                   //  tp_dictoffset
  0,                                   //  tp_init
};
993 
InitMapContainers()994 bool InitMapContainers() {
995   // ScalarMapContainer_Type derives from our MutableMapping type.
996   ScopedPyObjectPtr containers(PyImport_ImportModule(
997       "google.protobuf.internal.containers"));
998   if (containers == NULL) {
999     return false;
1000   }
1001 
1002   ScopedPyObjectPtr mutable_mapping(
1003       PyObject_GetAttrString(containers.get(), "MutableMapping"));
1004   if (mutable_mapping == NULL) {
1005     return false;
1006   }
1007 
1008   Py_INCREF(mutable_mapping.get());
1009 #if PY_MAJOR_VERSION >= 3
1010   ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
1011   if (bases == NULL) {
1012     return false;
1013   }
1014 
1015   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1016       PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
1017 #else
1018   _ScalarMapContainer_Type.tp_base =
1019       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1020 
1021   if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
1022     return false;
1023   }
1024 
1025   ScalarMapContainer_Type = &_ScalarMapContainer_Type;
1026 #endif
1027 
1028   if (PyType_Ready(&MapIterator_Type) < 0) {
1029     return false;
1030   }
1031 
1032 #if PY_MAJOR_VERSION >= 3
1033   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1034       PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
1035 #else
1036   Py_INCREF(mutable_mapping.get());
1037   _MessageMapContainer_Type.tp_base =
1038       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1039 
1040   if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
1041     return false;
1042   }
1043 
1044   MessageMapContainer_Type = &_MessageMapContainer_Type;
1045 #endif
1046   return true;
1047 }
1048 
1049 }  // namespace python
1050 }  // namespace protobuf
1051 }  // namespace google
1052