• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <memory>
36 
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/map_field.h>
40 #include <google/protobuf/map.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/pyext/message_factory.h>
43 #include <google/protobuf/pyext/message.h>
44 #include <google/protobuf/pyext/repeated_composite_container.h>
45 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
46 #include <google/protobuf/stubs/map_util.h>
47 
48 #if PY_MAJOR_VERSION >= 3
49   #define PyInt_FromLong PyLong_FromLong
50   #define PyInt_FromSize_t PyLong_FromSize_t
51 #endif
52 
53 namespace google {
54 namespace protobuf {
55 namespace python {
56 
57 // Functions that need access to map reflection functionality.
58 // They need to be contained in this class because it is friended.
59 class MapReflectionFriend {
60  public:
61   // Methods that are in common between the map types.
62   static PyObject* Contains(PyObject* _self, PyObject* key);
63   static Py_ssize_t Length(PyObject* _self);
64   static PyObject* GetIterator(PyObject *_self);
65   static PyObject* IterNext(PyObject* _self);
66   static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
67 
68   // Methods that differ between the map types.
69   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
70   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
71   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73   static PyObject* ScalarMapToStr(PyObject* _self);
74   static PyObject* MessageMapToStr(PyObject* _self);
75 };
76 
77 struct MapIterator {
78   PyObject_HEAD;
79 
80   std::unique_ptr<::google::protobuf::MapIterator> iter;
81 
82   // A pointer back to the container, so we can notice changes to the version.
83   // We own a ref on this.
84   MapContainer* container;
85 
86   // We need to keep a ref on the parent Message too, because
87   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
88   // the ref on container (above) would guarantee outlive semantics.  However in
89   // the case of ClearField(), the MapContainer points to a different message,
90   // a copy of the original.  But our iterator still points to the original,
91   // which could now get deleted before us.
92   //
93   // To prevent this, we ensure that the Message will always stay alive as long
94   // as this iterator does.  This is solely for the benefit of the MapIterator
95   // destructor -- we should never actually access the iterator in this state
96   // except to delete it.
97   CMessage* parent;
98   // The version of the map when we took the iterator to it.
99   //
100   // We store this so that if the map is modified during iteration we can throw
101   // an error.
102   uint64 version;
103 };
104 
GetMutableMessage()105 Message* MapContainer::GetMutableMessage() {
106   cmessage::AssureWritable(parent);
107   return parent->message;
108 }
109 
110 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,std::string * stl_string)111 static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
112   char *value;
113   Py_ssize_t value_len;
114 
115   if (!py_string) {
116     return false;
117   }
118   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
119     Py_DECREF(py_string);
120     return false;
121   } else {
122     stl_string->assign(value, value_len);
123     Py_DECREF(py_string);
124     return true;
125   }
126 }
127 
PythonToMapKey(MapContainer * self,PyObject * obj,MapKey * key)128 static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key) {
129   const FieldDescriptor* field_descriptor =
130       self->parent_field_descriptor->message_type()->map_key();
131   switch (field_descriptor->cpp_type()) {
132     case FieldDescriptor::CPPTYPE_INT32: {
133       GOOGLE_CHECK_GET_INT32(obj, value, false);
134       key->SetInt32Value(value);
135       break;
136     }
137     case FieldDescriptor::CPPTYPE_INT64: {
138       GOOGLE_CHECK_GET_INT64(obj, value, false);
139       key->SetInt64Value(value);
140       break;
141     }
142     case FieldDescriptor::CPPTYPE_UINT32: {
143       GOOGLE_CHECK_GET_UINT32(obj, value, false);
144       key->SetUInt32Value(value);
145       break;
146     }
147     case FieldDescriptor::CPPTYPE_UINT64: {
148       GOOGLE_CHECK_GET_UINT64(obj, value, false);
149       key->SetUInt64Value(value);
150       break;
151     }
152     case FieldDescriptor::CPPTYPE_BOOL: {
153       GOOGLE_CHECK_GET_BOOL(obj, value, false);
154       key->SetBoolValue(value);
155       break;
156     }
157     case FieldDescriptor::CPPTYPE_STRING: {
158       std::string str;
159       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
160         return false;
161       }
162       key->SetStringValue(str);
163       break;
164     }
165     default:
166       PyErr_Format(
167           PyExc_SystemError, "Type %d cannot be a map key",
168           field_descriptor->cpp_type());
169       return false;
170   }
171   return true;
172 }
173 
MapKeyToPython(MapContainer * self,const MapKey & key)174 static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
175   const FieldDescriptor* field_descriptor =
176       self->parent_field_descriptor->message_type()->map_key();
177   switch (field_descriptor->cpp_type()) {
178     case FieldDescriptor::CPPTYPE_INT32:
179       return PyInt_FromLong(key.GetInt32Value());
180     case FieldDescriptor::CPPTYPE_INT64:
181       return PyLong_FromLongLong(key.GetInt64Value());
182     case FieldDescriptor::CPPTYPE_UINT32:
183       return PyInt_FromSize_t(key.GetUInt32Value());
184     case FieldDescriptor::CPPTYPE_UINT64:
185       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
186     case FieldDescriptor::CPPTYPE_BOOL:
187       return PyBool_FromLong(key.GetBoolValue());
188     case FieldDescriptor::CPPTYPE_STRING:
189       return ToStringObject(field_descriptor, key.GetStringValue());
190     default:
191       PyErr_Format(
192           PyExc_SystemError, "Couldn't convert type %d to value",
193           field_descriptor->cpp_type());
194       return NULL;
195   }
196 }
197 
198 // This is only used for ScalarMap, so we don't need to handle the
199 // CPPTYPE_MESSAGE case.
MapValueRefToPython(MapContainer * self,const MapValueRef & value)200 PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
201   const FieldDescriptor* field_descriptor =
202       self->parent_field_descriptor->message_type()->map_value();
203   switch (field_descriptor->cpp_type()) {
204     case FieldDescriptor::CPPTYPE_INT32:
205       return PyInt_FromLong(value.GetInt32Value());
206     case FieldDescriptor::CPPTYPE_INT64:
207       return PyLong_FromLongLong(value.GetInt64Value());
208     case FieldDescriptor::CPPTYPE_UINT32:
209       return PyInt_FromSize_t(value.GetUInt32Value());
210     case FieldDescriptor::CPPTYPE_UINT64:
211       return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
212     case FieldDescriptor::CPPTYPE_FLOAT:
213       return PyFloat_FromDouble(value.GetFloatValue());
214     case FieldDescriptor::CPPTYPE_DOUBLE:
215       return PyFloat_FromDouble(value.GetDoubleValue());
216     case FieldDescriptor::CPPTYPE_BOOL:
217       return PyBool_FromLong(value.GetBoolValue());
218     case FieldDescriptor::CPPTYPE_STRING:
219       return ToStringObject(field_descriptor, value.GetStringValue());
220     case FieldDescriptor::CPPTYPE_ENUM:
221       return PyInt_FromLong(value.GetEnumValue());
222     default:
223       PyErr_Format(
224           PyExc_SystemError, "Couldn't convert type %d to value",
225           field_descriptor->cpp_type());
226       return NULL;
227   }
228 }
229 
230 // This is only used for ScalarMap, so we don't need to handle the
231 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(MapContainer * self,PyObject * obj,bool allow_unknown_enum_values,MapValueRef * value_ref)232 static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
233                                 bool allow_unknown_enum_values,
234                                 MapValueRef* value_ref) {
235   const FieldDescriptor* field_descriptor =
236       self->parent_field_descriptor->message_type()->map_value();
237   switch (field_descriptor->cpp_type()) {
238     case FieldDescriptor::CPPTYPE_INT32: {
239       GOOGLE_CHECK_GET_INT32(obj, value, false);
240       value_ref->SetInt32Value(value);
241       return true;
242     }
243     case FieldDescriptor::CPPTYPE_INT64: {
244       GOOGLE_CHECK_GET_INT64(obj, value, false);
245       value_ref->SetInt64Value(value);
246       return true;
247     }
248     case FieldDescriptor::CPPTYPE_UINT32: {
249       GOOGLE_CHECK_GET_UINT32(obj, value, false);
250       value_ref->SetUInt32Value(value);
251       return true;
252     }
253     case FieldDescriptor::CPPTYPE_UINT64: {
254       GOOGLE_CHECK_GET_UINT64(obj, value, false);
255       value_ref->SetUInt64Value(value);
256       return true;
257     }
258     case FieldDescriptor::CPPTYPE_FLOAT: {
259       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
260       value_ref->SetFloatValue(value);
261       return true;
262     }
263     case FieldDescriptor::CPPTYPE_DOUBLE: {
264       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
265       value_ref->SetDoubleValue(value);
266       return true;
267     }
268     case FieldDescriptor::CPPTYPE_BOOL: {
269       GOOGLE_CHECK_GET_BOOL(obj, value, false);
270       value_ref->SetBoolValue(value);
271       return true;;
272     }
273     case FieldDescriptor::CPPTYPE_STRING: {
274       std::string str;
275       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
276         return false;
277       }
278       value_ref->SetStringValue(str);
279       return true;
280     }
281     case FieldDescriptor::CPPTYPE_ENUM: {
282       GOOGLE_CHECK_GET_INT32(obj, value, false);
283       if (allow_unknown_enum_values) {
284         value_ref->SetEnumValue(value);
285         return true;
286       } else {
287         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
288         const EnumValueDescriptor* enum_value =
289             enum_descriptor->FindValueByNumber(value);
290         if (enum_value != NULL) {
291           value_ref->SetEnumValue(value);
292           return true;
293         } else {
294           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
295           return false;
296         }
297       }
298       break;
299     }
300     default:
301       PyErr_Format(
302           PyExc_SystemError, "Setting value to a field of unknown type %d",
303           field_descriptor->cpp_type());
304       return false;
305   }
306 }
307 
308 // Map methods common to ScalarMap and MessageMap //////////////////////////////
309 
GetMap(PyObject * obj)310 static MapContainer* GetMap(PyObject* obj) {
311   return reinterpret_cast<MapContainer*>(obj);
312 }
313 
Length(PyObject * _self)314 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
315   MapContainer* self = GetMap(_self);
316   const google::protobuf::Message* message = self->parent->message;
317   return message->GetReflection()->MapSize(*message,
318                                            self->parent_field_descriptor);
319 }
320 
Clear(PyObject * _self)321 PyObject* Clear(PyObject* _self) {
322   MapContainer* self = GetMap(_self);
323   Message* message = self->GetMutableMessage();
324   const Reflection* reflection = message->GetReflection();
325 
326   reflection->ClearField(message, self->parent_field_descriptor);
327 
328   Py_RETURN_NONE;
329 }
330 
GetEntryClass(PyObject * _self)331 PyObject* GetEntryClass(PyObject* _self) {
332   MapContainer* self = GetMap(_self);
333   CMessageClass* message_class = message_factory::GetMessageClass(
334       cmessage::GetFactoryForMessage(self->parent),
335       self->parent_field_descriptor->message_type());
336   Py_XINCREF(message_class);
337   return reinterpret_cast<PyObject*>(message_class);
338 }
339 
MergeFrom(PyObject * _self,PyObject * arg)340 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
341   MapContainer* self = GetMap(_self);
342   if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) &&
343       !PyObject_TypeCheck(arg, MessageMapContainer_Type)) {
344     PyErr_SetString(PyExc_AttributeError, "Not a map field");
345     return nullptr;
346   }
347   MapContainer* other_map = GetMap(arg);
348   Message* message = self->GetMutableMessage();
349   const Message* other_message = other_map->parent->message;
350   const Reflection* reflection = message->GetReflection();
351   const Reflection* other_reflection = other_message->GetReflection();
352   internal::MapFieldBase* field = reflection->MutableMapData(
353       message, self->parent_field_descriptor);
354   const internal::MapFieldBase* other_field = other_reflection->GetMapData(
355       *other_message, other_map->parent_field_descriptor);
356   field->MergeFrom(*other_field);
357   self->version++;
358   Py_RETURN_NONE;
359 }
360 
Contains(PyObject * _self,PyObject * key)361 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
362   MapContainer* self = GetMap(_self);
363 
364   const Message* message = self->parent->message;
365   const Reflection* reflection = message->GetReflection();
366   MapKey map_key;
367 
368   if (!PythonToMapKey(self, key, &map_key)) {
369     return NULL;
370   }
371 
372   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
373                                  map_key)) {
374     Py_RETURN_TRUE;
375   } else {
376     Py_RETURN_FALSE;
377   }
378 }
379 
380 // ScalarMap ///////////////////////////////////////////////////////////////////
381 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)382 MapContainer* NewScalarMapContainer(
383     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
384   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
385     return NULL;
386   }
387 
388   PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
389   if (obj == NULL) {
390     PyErr_Format(PyExc_RuntimeError,
391                  "Could not allocate new container.");
392     return NULL;
393   }
394 
395   MapContainer* self = GetMap(obj);
396 
397   Py_INCREF(parent);
398   self->parent = parent;
399   self->parent_field_descriptor = parent_field_descriptor;
400   self->version = 0;
401 
402   return self;
403 }
404 
ScalarMapGetItem(PyObject * _self,PyObject * key)405 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
406                                                 PyObject* key) {
407   MapContainer* self = GetMap(_self);
408 
409   Message* message = self->GetMutableMessage();
410   const Reflection* reflection = message->GetReflection();
411   MapKey map_key;
412   MapValueRef value;
413 
414   if (!PythonToMapKey(self, key, &map_key)) {
415     return NULL;
416   }
417 
418   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
419                                          map_key, &value)) {
420     self->version++;
421   }
422 
423   return MapValueRefToPython(self, value);
424 }
425 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)426 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
427                                           PyObject* v) {
428   MapContainer* self = GetMap(_self);
429 
430   Message* message = self->GetMutableMessage();
431   const Reflection* reflection = message->GetReflection();
432   MapKey map_key;
433   MapValueRef value;
434 
435   if (!PythonToMapKey(self, key, &map_key)) {
436     return -1;
437   }
438 
439   self->version++;
440 
441   if (v) {
442     // Set item to v.
443     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
444                                        map_key, &value);
445 
446     if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(),
447                              &value)) {
448       return -1;
449     }
450     return 0;
451   } else {
452     // Delete key from map.
453     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
454                                    map_key)) {
455       return 0;
456     } else {
457       PyErr_Format(PyExc_KeyError, "Key not present in map");
458       return -1;
459     }
460   }
461 }
462 
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)463 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
464                               PyObject* kwargs) {
465   static char* kwlist[] = {"key", "default", nullptr};
466   PyObject* key;
467   PyObject* default_value = NULL;
468   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
469                                    &default_value)) {
470     return NULL;
471   }
472 
473   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
474   if (is_present.get() == NULL) {
475     return NULL;
476   }
477 
478   if (PyObject_IsTrue(is_present.get())) {
479     return MapReflectionFriend::ScalarMapGetItem(self, key);
480   } else {
481     if (default_value != NULL) {
482       Py_INCREF(default_value);
483       return default_value;
484     } else {
485       Py_RETURN_NONE;
486     }
487   }
488 }
489 
ScalarMapToStr(PyObject * _self)490 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
491   ScopedPyObjectPtr dict(PyDict_New());
492   if (dict == NULL) {
493     return NULL;
494   }
495   ScopedPyObjectPtr key;
496   ScopedPyObjectPtr value;
497 
498   MapContainer* self = GetMap(_self);
499   Message* message = self->GetMutableMessage();
500   const Reflection* reflection = message->GetReflection();
501   for (google::protobuf::MapIterator it = reflection->MapBegin(
502            message, self->parent_field_descriptor);
503        it != reflection->MapEnd(message, self->parent_field_descriptor);
504        ++it) {
505     key.reset(MapKeyToPython(self, it.GetKey()));
506     if (key == NULL) {
507       return NULL;
508     }
509     value.reset(MapValueRefToPython(self, it.GetValueRef()));
510     if (value == NULL) {
511       return NULL;
512     }
513     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
514       return NULL;
515     }
516   }
517   return PyObject_Repr(dict.get());
518 }
519 
ScalarMapDealloc(PyObject * _self)520 static void ScalarMapDealloc(PyObject* _self) {
521   MapContainer* self = GetMap(_self);
522   self->RemoveFromParentCache();
523   PyTypeObject *type = Py_TYPE(_self);
524   type->tp_free(_self);
525   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
526     // With Python3, the Map class is not static, and must be managed.
527     Py_DECREF(type);
528   }
529 }
530 
531 static PyMethodDef ScalarMapMethods[] = {
532     {"__contains__", MapReflectionFriend::Contains, METH_O,
533      "Tests whether a key is a member of the map."},
534     {"clear", (PyCFunction)Clear, METH_NOARGS,
535      "Removes all elements from the map."},
536     {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
537      "Gets the value for the given key if present, or otherwise a default"},
538     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
539      "Return the class used to build Entries of (key, value) pairs."},
540     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
541      "Merges a map into the current map."},
542     /*
543     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
544       "Makes a deep copy of the class." },
545     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
546       "Outputs picklable representation of the repeated field." },
547     */
548     {NULL, NULL},
549 };
550 
551 PyTypeObject *ScalarMapContainer_Type;
552 #if PY_MAJOR_VERSION >= 3
553   static PyType_Slot ScalarMapContainer_Type_slots[] = {
554       {Py_tp_dealloc, (void *)ScalarMapDealloc},
555       {Py_mp_length, (void *)MapReflectionFriend::Length},
556       {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
557       {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
558       {Py_tp_methods, (void *)ScalarMapMethods},
559       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
560       {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
561       {0, 0},
562   };
563 
564   PyType_Spec ScalarMapContainer_Type_spec = {
565       FULL_MODULE_NAME ".ScalarMapContainer",
566       sizeof(MapContainer),
567       0,
568       Py_TPFLAGS_DEFAULT,
569       ScalarMapContainer_Type_slots
570   };
571 #else
572   static PyMappingMethods ScalarMapMappingMethods = {
573     MapReflectionFriend::Length,             // mp_length
574     MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
575     MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
576   };
577 
578   PyTypeObject _ScalarMapContainer_Type = {
579     PyVarObject_HEAD_INIT(&PyType_Type, 0)
580     FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
581     sizeof(MapContainer),                //  tp_basicsize
582     0,                                   //  tp_itemsize
583     ScalarMapDealloc,                    //  tp_dealloc
584     0,                                   //  tp_print
585     0,                                   //  tp_getattr
586     0,                                   //  tp_setattr
587     0,                                   //  tp_compare
588     MapReflectionFriend::ScalarMapToStr,  //  tp_repr
589     0,                                   //  tp_as_number
590     0,                                   //  tp_as_sequence
591     &ScalarMapMappingMethods,            //  tp_as_mapping
592     0,                                   //  tp_hash
593     0,                                   //  tp_call
594     0,                                   //  tp_str
595     0,                                   //  tp_getattro
596     0,                                   //  tp_setattro
597     0,                                   //  tp_as_buffer
598     Py_TPFLAGS_DEFAULT,                  //  tp_flags
599     "A scalar map container",            //  tp_doc
600     0,                                   //  tp_traverse
601     0,                                   //  tp_clear
602     0,                                   //  tp_richcompare
603     0,                                   //  tp_weaklistoffset
604     MapReflectionFriend::GetIterator,    //  tp_iter
605     0,                                   //  tp_iternext
606     ScalarMapMethods,                    //  tp_methods
607     0,                                   //  tp_members
608     0,                                   //  tp_getset
609     0,                                   //  tp_base
610     0,                                   //  tp_dict
611     0,                                   //  tp_descr_get
612     0,                                   //  tp_descr_set
613     0,                                   //  tp_dictoffset
614     0,                                   //  tp_init
615   };
616 #endif
617 
618 
619 // MessageMap //////////////////////////////////////////////////////////////////
620 
GetMessageMap(PyObject * obj)621 static MessageMapContainer* GetMessageMap(PyObject* obj) {
622   return reinterpret_cast<MessageMapContainer*>(obj);
623 }
624 
GetCMessage(MessageMapContainer * self,Message * message)625 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
626   // Get or create the CMessage object corresponding to this message.
627   return self->parent
628       ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
629                                    self->message_class)
630       ->AsPyObject();
631 }
632 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)633 MessageMapContainer* NewMessageMapContainer(
634     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
635     CMessageClass* message_class) {
636   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
637     return NULL;
638   }
639 
640   PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
641   if (obj == NULL) {
642     PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
643     return NULL;
644   }
645 
646   MessageMapContainer* self = GetMessageMap(obj);
647 
648   Py_INCREF(parent);
649   self->parent = parent;
650   self->parent_field_descriptor = parent_field_descriptor;
651   self->version = 0;
652 
653   Py_INCREF(message_class);
654   self->message_class = message_class;
655 
656   return self;
657 }
658 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)659 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
660                                            PyObject* v) {
661   if (v) {
662     PyErr_Format(PyExc_ValueError,
663                  "Direct assignment of submessage not allowed");
664     return -1;
665   }
666 
667   // Now we know that this is a delete, not a set.
668 
669   MessageMapContainer* self = GetMessageMap(_self);
670   Message* message = self->GetMutableMessage();
671   const Reflection* reflection = message->GetReflection();
672   MapKey map_key;
673   MapValueRef value;
674 
675   self->version++;
676 
677   if (!PythonToMapKey(self, key, &map_key)) {
678     return -1;
679   }
680 
681   // Delete key from map.
682   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
683                                  map_key)) {
684     // Delete key from CMessage dict.
685     MapValueRef value;
686     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
687                                        map_key, &value);
688     Message* sub_message = value.MutableMessageValue();
689     // If there is a living weak reference to an item, we "Release" it,
690     // otherwise we just discard the C++ value.
691     if (CMessage* released =
692             self->parent->MaybeReleaseSubMessage(sub_message)) {
693       Message* msg = released->message;
694       released->message = msg->New();
695       msg->GetReflection()->Swap(msg, released->message);
696     }
697 
698     // Delete key from map.
699     reflection->DeleteMapValue(message, self->parent_field_descriptor,
700                                map_key);
701     return 0;
702   } else {
703     PyErr_Format(PyExc_KeyError, "Key not present in map");
704     return -1;
705   }
706 }
707 
MessageMapGetItem(PyObject * _self,PyObject * key)708 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
709                                                  PyObject* key) {
710   MessageMapContainer* self = GetMessageMap(_self);
711 
712   Message* message = self->GetMutableMessage();
713   const Reflection* reflection = message->GetReflection();
714   MapKey map_key;
715   MapValueRef value;
716 
717   if (!PythonToMapKey(self, key, &map_key)) {
718     return NULL;
719   }
720 
721   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
722                                          map_key, &value)) {
723     self->version++;
724   }
725 
726   return GetCMessage(self, value.MutableMessageValue());
727 }
728 
MessageMapToStr(PyObject * _self)729 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
730   ScopedPyObjectPtr dict(PyDict_New());
731   if (dict == NULL) {
732     return NULL;
733   }
734   ScopedPyObjectPtr key;
735   ScopedPyObjectPtr value;
736 
737   MessageMapContainer* self = GetMessageMap(_self);
738   Message* message = self->GetMutableMessage();
739   const Reflection* reflection = message->GetReflection();
740   for (google::protobuf::MapIterator it = reflection->MapBegin(
741            message, self->parent_field_descriptor);
742        it != reflection->MapEnd(message, self->parent_field_descriptor);
743        ++it) {
744     key.reset(MapKeyToPython(self, it.GetKey()));
745     if (key == NULL) {
746       return NULL;
747     }
748     value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
749     if (value == NULL) {
750       return NULL;
751     }
752     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
753       return NULL;
754     }
755   }
756   return PyObject_Repr(dict.get());
757 }
758 
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)759 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
760   static char* kwlist[] = {"key", "default", nullptr};
761   PyObject* key;
762   PyObject* default_value = NULL;
763   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
764                                    &default_value)) {
765     return NULL;
766   }
767 
768   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
769   if (is_present.get() == NULL) {
770     return NULL;
771   }
772 
773   if (PyObject_IsTrue(is_present.get())) {
774     return MapReflectionFriend::MessageMapGetItem(self, key);
775   } else {
776     if (default_value != NULL) {
777       Py_INCREF(default_value);
778       return default_value;
779     } else {
780       Py_RETURN_NONE;
781     }
782   }
783 }
784 
MessageMapDealloc(PyObject * _self)785 static void MessageMapDealloc(PyObject* _self) {
786   MessageMapContainer* self = GetMessageMap(_self);
787   self->RemoveFromParentCache();
788   Py_DECREF(self->message_class);
789   PyTypeObject *type = Py_TYPE(_self);
790   type->tp_free(_self);
791   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
792     // With Python3, the Map class is not static, and must be managed.
793     Py_DECREF(type);
794   }
795 }
796 
797 static PyMethodDef MessageMapMethods[] = {
798     {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
799      "Tests whether the map contains this element."},
800     {"clear", (PyCFunction)Clear, METH_NOARGS,
801      "Removes all elements from the map."},
802     {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
803      "Gets the value for the given key if present, or otherwise a default"},
804     {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
805      "Alias for getitem, useful to make explicit that the map is mutated."},
806     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
807      "Return the class used to build Entries of (key, value) pairs."},
808     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
809      "Merges a map into the current map."},
810     /*
811     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
812       "Makes a deep copy of the class." },
813     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
814       "Outputs picklable representation of the repeated field." },
815     */
816     {NULL, NULL},
817 };
818 
819 PyTypeObject *MessageMapContainer_Type;
820 #if PY_MAJOR_VERSION >= 3
821   static PyType_Slot MessageMapContainer_Type_slots[] = {
822       {Py_tp_dealloc, (void *)MessageMapDealloc},
823       {Py_mp_length, (void *)MapReflectionFriend::Length},
824       {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
825       {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
826       {Py_tp_methods, (void *)MessageMapMethods},
827       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
828       {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
829       {0, 0}
830   };
831 
832   PyType_Spec MessageMapContainer_Type_spec = {
833       FULL_MODULE_NAME ".MessageMapContainer",
834       sizeof(MessageMapContainer),
835       0,
836       Py_TPFLAGS_DEFAULT,
837       MessageMapContainer_Type_slots
838   };
839 #else
840   static PyMappingMethods MessageMapMappingMethods = {
841     MapReflectionFriend::Length,              // mp_length
842     MapReflectionFriend::MessageMapGetItem,   // mp_subscript
843     MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
844   };
845 
846   PyTypeObject _MessageMapContainer_Type = {
847     PyVarObject_HEAD_INIT(&PyType_Type, 0)
848     FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
849     sizeof(MessageMapContainer),         //  tp_basicsize
850     0,                                   //  tp_itemsize
851     MessageMapDealloc,                   //  tp_dealloc
852     0,                                   //  tp_print
853     0,                                   //  tp_getattr
854     0,                                   //  tp_setattr
855     0,                                   //  tp_compare
856     MapReflectionFriend::MessageMapToStr,  //  tp_repr
857     0,                                   //  tp_as_number
858     0,                                   //  tp_as_sequence
859     &MessageMapMappingMethods,           //  tp_as_mapping
860     0,                                   //  tp_hash
861     0,                                   //  tp_call
862     0,                                   //  tp_str
863     0,                                   //  tp_getattro
864     0,                                   //  tp_setattro
865     0,                                   //  tp_as_buffer
866     Py_TPFLAGS_DEFAULT,                  //  tp_flags
867     "A map container for message",       //  tp_doc
868     0,                                   //  tp_traverse
869     0,                                   //  tp_clear
870     0,                                   //  tp_richcompare
871     0,                                   //  tp_weaklistoffset
872     MapReflectionFriend::GetIterator,    //  tp_iter
873     0,                                   //  tp_iternext
874     MessageMapMethods,                   //  tp_methods
875     0,                                   //  tp_members
876     0,                                   //  tp_getset
877     0,                                   //  tp_base
878     0,                                   //  tp_dict
879     0,                                   //  tp_descr_get
880     0,                                   //  tp_descr_set
881     0,                                   //  tp_dictoffset
882     0,                                   //  tp_init
883   };
884 #endif
885 
886 // MapIterator /////////////////////////////////////////////////////////////////
887 
GetIter(PyObject * obj)888 static MapIterator* GetIter(PyObject* obj) {
889   return reinterpret_cast<MapIterator*>(obj);
890 }
891 
GetIterator(PyObject * _self)892 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
893   MapContainer* self = GetMap(_self);
894 
895   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
896   if (obj == NULL) {
897     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
898   }
899 
900   MapIterator* iter = GetIter(obj.get());
901 
902   Py_INCREF(self);
903   iter->container = self;
904   iter->version = self->version;
905   Py_INCREF(self->parent);
906   iter->parent = self->parent;
907 
908   if (MapReflectionFriend::Length(_self) > 0) {
909     Message* message = self->GetMutableMessage();
910     const Reflection* reflection = message->GetReflection();
911 
912     iter->iter.reset(new ::google::protobuf::MapIterator(
913         reflection->MapBegin(message, self->parent_field_descriptor)));
914   }
915 
916   return obj.release();
917 }
918 
IterNext(PyObject * _self)919 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
920   MapIterator* self = GetIter(_self);
921 
922   // This won't catch mutations to the map performed by MergeFrom(); no easy way
923   // to address that.
924   if (self->version != self->container->version) {
925     return PyErr_Format(PyExc_RuntimeError,
926                         "Map modified during iteration.");
927   }
928   if (self->parent != self->container->parent) {
929     return PyErr_Format(PyExc_RuntimeError,
930                         "Map cleared during iteration.");
931   }
932 
933   if (self->iter.get() == NULL) {
934     return NULL;
935   }
936 
937   Message* message = self->container->GetMutableMessage();
938   const Reflection* reflection = message->GetReflection();
939 
940   if (*self->iter ==
941       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
942     return NULL;
943   }
944 
945   PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
946 
947   ++(*self->iter);
948 
949   return ret;
950 }
951 
DeallocMapIterator(PyObject * _self)952 static void DeallocMapIterator(PyObject* _self) {
953   MapIterator* self = GetIter(_self);
954   self->iter.reset();
955   Py_CLEAR(self->container);
956   Py_CLEAR(self->parent);
957   Py_TYPE(_self)->tp_free(_self);
958 }
959 
960 PyTypeObject MapIterator_Type = {
961   PyVarObject_HEAD_INIT(&PyType_Type, 0)
962   FULL_MODULE_NAME ".MapIterator",     //  tp_name
963   sizeof(MapIterator),                 //  tp_basicsize
964   0,                                   //  tp_itemsize
965   DeallocMapIterator,                  //  tp_dealloc
966   0,                                   //  tp_print
967   0,                                   //  tp_getattr
968   0,                                   //  tp_setattr
969   0,                                   //  tp_compare
970   0,                                   //  tp_repr
971   0,                                   //  tp_as_number
972   0,                                   //  tp_as_sequence
973   0,                                   //  tp_as_mapping
974   0,                                   //  tp_hash
975   0,                                   //  tp_call
976   0,                                   //  tp_str
977   0,                                   //  tp_getattro
978   0,                                   //  tp_setattro
979   0,                                   //  tp_as_buffer
980   Py_TPFLAGS_DEFAULT,                  //  tp_flags
981   "A scalar map iterator",             //  tp_doc
982   0,                                   //  tp_traverse
983   0,                                   //  tp_clear
984   0,                                   //  tp_richcompare
985   0,                                   //  tp_weaklistoffset
986   PyObject_SelfIter,                   //  tp_iter
987   MapReflectionFriend::IterNext,       //  tp_iternext
988   0,                                   //  tp_methods
989   0,                                   //  tp_members
990   0,                                   //  tp_getset
991   0,                                   //  tp_base
992   0,                                   //  tp_dict
993   0,                                   //  tp_descr_get
994   0,                                   //  tp_descr_set
995   0,                                   //  tp_dictoffset
996   0,                                   //  tp_init
997 };
998 
InitMapContainers()999 bool InitMapContainers() {
1000   // ScalarMapContainer_Type derives from our MutableMapping type.
1001   ScopedPyObjectPtr containers(PyImport_ImportModule(
1002       "google.protobuf.internal.containers"));
1003   if (containers == NULL) {
1004     return false;
1005   }
1006 
1007   ScopedPyObjectPtr mutable_mapping(
1008       PyObject_GetAttrString(containers.get(), "MutableMapping"));
1009   if (mutable_mapping == NULL) {
1010     return false;
1011   }
1012 
1013   Py_INCREF(mutable_mapping.get());
1014 #if PY_MAJOR_VERSION >= 3
1015   ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
1016   if (bases == NULL) {
1017     return false;
1018   }
1019 
1020   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1021       PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
1022 #else
1023   _ScalarMapContainer_Type.tp_base =
1024       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1025 
1026   if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
1027     return false;
1028   }
1029 
1030   ScalarMapContainer_Type = &_ScalarMapContainer_Type;
1031 #endif
1032 
1033   if (PyType_Ready(&MapIterator_Type) < 0) {
1034     return false;
1035   }
1036 
1037 #if PY_MAJOR_VERSION >= 3
1038   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1039       PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
1040 #else
1041   Py_INCREF(mutable_mapping.get());
1042   _MessageMapContainer_Type.tp_base =
1043       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1044 
1045   if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
1046     return false;
1047   }
1048 
1049   MessageMapContainer_Type = &_MessageMapContainer_Type;
1050 #endif
1051   return true;
1052 }
1053 
1054 }  // namespace python
1055 }  // namespace protobuf
1056 }  // namespace google
1057