• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <memory>
36 
37 #include <google/protobuf/stubs/logging.h>
38 #include <google/protobuf/stubs/common.h>
39 #include <google/protobuf/map_field.h>
40 #include <google/protobuf/map.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/pyext/message_factory.h>
43 #include <google/protobuf/pyext/message.h>
44 #include <google/protobuf/pyext/repeated_composite_container.h>
45 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
46 #include <google/protobuf/stubs/map_util.h>
47 
// When building against Python 3, route the PyInt_* constructors used in this
// file to their PyLong_* equivalents.
#if PY_MAJOR_VERSION >= 3
#  define PyInt_FromLong PyLong_FromLong
#  define PyInt_FromSize_t PyLong_FromSize_t
#endif
52 
53 namespace google {
54 namespace protobuf {
55 namespace python {
56 
57 // Functions that need access to map reflection functionality.
58 // They need to be contained in this class because it is friended.
59 class MapReflectionFriend {
60  public:
61   // Methods that are in common between the map types.
62   static PyObject* Contains(PyObject* _self, PyObject* key);
63   static Py_ssize_t Length(PyObject* _self);
64   static PyObject* GetIterator(PyObject *_self);
65   static PyObject* IterNext(PyObject* _self);
66   static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
67 
68   // Methods that differ between the map types.
69   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
70   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
71   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73   static PyObject* ScalarMapToStr(PyObject* _self);
74   static PyObject* MessageMapToStr(PyObject* _self);
75 };
76 
77 struct MapIterator {
78   PyObject_HEAD;
79 
80   std::unique_ptr<::google::protobuf::MapIterator> iter;
81 
82   // A pointer back to the container, so we can notice changes to the version.
83   // We own a ref on this.
84   MapContainer* container;
85 
86   // We need to keep a ref on the parent Message too, because
87   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
88   // the ref on container (above) would guarantee outlive semantics.  However in
89   // the case of ClearField(), the MapContainer points to a different message,
90   // a copy of the original.  But our iterator still points to the original,
91   // which could now get deleted before us.
92   //
93   // To prevent this, we ensure that the Message will always stay alive as long
94   // as this iterator does.  This is solely for the benefit of the MapIterator
95   // destructor -- we should never actually access the iterator in this state
96   // except to delete it.
97   CMessage* parent;
98   // The version of the map when we took the iterator to it.
99   //
100   // We store this so that if the map is modified during iteration we can throw
101   // an error.
102   uint64 version;
103 };
104 
GetMutableMessage()105 Message* MapContainer::GetMutableMessage() {
106   cmessage::AssureWritable(parent);
107   return parent->message;
108 }
109 
110 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,string * stl_string)111 static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
112   char *value;
113   Py_ssize_t value_len;
114 
115   if (!py_string) {
116     return false;
117   }
118   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
119     Py_DECREF(py_string);
120     return false;
121   } else {
122     stl_string->assign(value, value_len);
123     Py_DECREF(py_string);
124     return true;
125   }
126 }
127 
PythonToMapKey(PyObject * obj,const FieldDescriptor * field_descriptor,MapKey * key)128 static bool PythonToMapKey(PyObject* obj,
129                            const FieldDescriptor* field_descriptor,
130                            MapKey* key) {
131   switch (field_descriptor->cpp_type()) {
132     case FieldDescriptor::CPPTYPE_INT32: {
133       GOOGLE_CHECK_GET_INT32(obj, value, false);
134       key->SetInt32Value(value);
135       break;
136     }
137     case FieldDescriptor::CPPTYPE_INT64: {
138       GOOGLE_CHECK_GET_INT64(obj, value, false);
139       key->SetInt64Value(value);
140       break;
141     }
142     case FieldDescriptor::CPPTYPE_UINT32: {
143       GOOGLE_CHECK_GET_UINT32(obj, value, false);
144       key->SetUInt32Value(value);
145       break;
146     }
147     case FieldDescriptor::CPPTYPE_UINT64: {
148       GOOGLE_CHECK_GET_UINT64(obj, value, false);
149       key->SetUInt64Value(value);
150       break;
151     }
152     case FieldDescriptor::CPPTYPE_BOOL: {
153       GOOGLE_CHECK_GET_BOOL(obj, value, false);
154       key->SetBoolValue(value);
155       break;
156     }
157     case FieldDescriptor::CPPTYPE_STRING: {
158       string str;
159       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
160         return false;
161       }
162       key->SetStringValue(str);
163       break;
164     }
165     default:
166       PyErr_Format(
167           PyExc_SystemError, "Type %d cannot be a map key",
168           field_descriptor->cpp_type());
169       return false;
170   }
171   return true;
172 }
173 
MapKeyToPython(const FieldDescriptor * field_descriptor,const MapKey & key)174 static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
175                                 const MapKey& key) {
176   switch (field_descriptor->cpp_type()) {
177     case FieldDescriptor::CPPTYPE_INT32:
178       return PyInt_FromLong(key.GetInt32Value());
179     case FieldDescriptor::CPPTYPE_INT64:
180       return PyLong_FromLongLong(key.GetInt64Value());
181     case FieldDescriptor::CPPTYPE_UINT32:
182       return PyInt_FromSize_t(key.GetUInt32Value());
183     case FieldDescriptor::CPPTYPE_UINT64:
184       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
185     case FieldDescriptor::CPPTYPE_BOOL:
186       return PyBool_FromLong(key.GetBoolValue());
187     case FieldDescriptor::CPPTYPE_STRING:
188       return ToStringObject(field_descriptor, key.GetStringValue());
189     default:
190       PyErr_Format(
191           PyExc_SystemError, "Couldn't convert type %d to value",
192           field_descriptor->cpp_type());
193       return NULL;
194   }
195 }
196 
197 // This is only used for ScalarMap, so we don't need to handle the
198 // CPPTYPE_MESSAGE case.
MapValueRefToPython(const FieldDescriptor * field_descriptor,const MapValueRef & value)199 PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
200                               const MapValueRef& value) {
201   switch (field_descriptor->cpp_type()) {
202     case FieldDescriptor::CPPTYPE_INT32:
203       return PyInt_FromLong(value.GetInt32Value());
204     case FieldDescriptor::CPPTYPE_INT64:
205       return PyLong_FromLongLong(value.GetInt64Value());
206     case FieldDescriptor::CPPTYPE_UINT32:
207       return PyInt_FromSize_t(value.GetUInt32Value());
208     case FieldDescriptor::CPPTYPE_UINT64:
209       return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
210     case FieldDescriptor::CPPTYPE_FLOAT:
211       return PyFloat_FromDouble(value.GetFloatValue());
212     case FieldDescriptor::CPPTYPE_DOUBLE:
213       return PyFloat_FromDouble(value.GetDoubleValue());
214     case FieldDescriptor::CPPTYPE_BOOL:
215       return PyBool_FromLong(value.GetBoolValue());
216     case FieldDescriptor::CPPTYPE_STRING:
217       return ToStringObject(field_descriptor, value.GetStringValue());
218     case FieldDescriptor::CPPTYPE_ENUM:
219       return PyInt_FromLong(value.GetEnumValue());
220     default:
221       PyErr_Format(
222           PyExc_SystemError, "Couldn't convert type %d to value",
223           field_descriptor->cpp_type());
224       return NULL;
225   }
226 }
227 
228 // This is only used for ScalarMap, so we don't need to handle the
229 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(PyObject * obj,const FieldDescriptor * field_descriptor,bool allow_unknown_enum_values,MapValueRef * value_ref)230 static bool PythonToMapValueRef(PyObject* obj,
231                                 const FieldDescriptor* field_descriptor,
232                                 bool allow_unknown_enum_values,
233                                 MapValueRef* value_ref) {
234   switch (field_descriptor->cpp_type()) {
235     case FieldDescriptor::CPPTYPE_INT32: {
236       GOOGLE_CHECK_GET_INT32(obj, value, false);
237       value_ref->SetInt32Value(value);
238       return true;
239     }
240     case FieldDescriptor::CPPTYPE_INT64: {
241       GOOGLE_CHECK_GET_INT64(obj, value, false);
242       value_ref->SetInt64Value(value);
243       return true;
244     }
245     case FieldDescriptor::CPPTYPE_UINT32: {
246       GOOGLE_CHECK_GET_UINT32(obj, value, false);
247       value_ref->SetUInt32Value(value);
248       return true;
249     }
250     case FieldDescriptor::CPPTYPE_UINT64: {
251       GOOGLE_CHECK_GET_UINT64(obj, value, false);
252       value_ref->SetUInt64Value(value);
253       return true;
254     }
255     case FieldDescriptor::CPPTYPE_FLOAT: {
256       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
257       value_ref->SetFloatValue(value);
258       return true;
259     }
260     case FieldDescriptor::CPPTYPE_DOUBLE: {
261       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
262       value_ref->SetDoubleValue(value);
263       return true;
264     }
265     case FieldDescriptor::CPPTYPE_BOOL: {
266       GOOGLE_CHECK_GET_BOOL(obj, value, false);
267       value_ref->SetBoolValue(value);
268       return true;;
269     }
270     case FieldDescriptor::CPPTYPE_STRING: {
271       string str;
272       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
273         return false;
274       }
275       value_ref->SetStringValue(str);
276       return true;
277     }
278     case FieldDescriptor::CPPTYPE_ENUM: {
279       GOOGLE_CHECK_GET_INT32(obj, value, false);
280       if (allow_unknown_enum_values) {
281         value_ref->SetEnumValue(value);
282         return true;
283       } else {
284         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
285         const EnumValueDescriptor* enum_value =
286             enum_descriptor->FindValueByNumber(value);
287         if (enum_value != NULL) {
288           value_ref->SetEnumValue(value);
289           return true;
290         } else {
291           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
292           return false;
293         }
294       }
295       break;
296     }
297     default:
298       PyErr_Format(
299           PyExc_SystemError, "Setting value to a field of unknown type %d",
300           field_descriptor->cpp_type());
301       return false;
302   }
303 }
304 
305 // Map methods common to ScalarMap and MessageMap //////////////////////////////
306 
GetMap(PyObject * obj)307 static MapContainer* GetMap(PyObject* obj) {
308   return reinterpret_cast<MapContainer*>(obj);
309 }
310 
Length(PyObject * _self)311 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
312   MapContainer* self = GetMap(_self);
313   const google::protobuf::Message* message = self->parent->message;
314   return message->GetReflection()->MapSize(*message,
315                                            self->parent_field_descriptor);
316 }
317 
Clear(PyObject * _self)318 PyObject* Clear(PyObject* _self) {
319   MapContainer* self = GetMap(_self);
320   Message* message = self->GetMutableMessage();
321   const Reflection* reflection = message->GetReflection();
322 
323   reflection->ClearField(message, self->parent_field_descriptor);
324 
325   Py_RETURN_NONE;
326 }
327 
GetEntryClass(PyObject * _self)328 PyObject* GetEntryClass(PyObject* _self) {
329   MapContainer* self = GetMap(_self);
330   CMessageClass* message_class = message_factory::GetMessageClass(
331       cmessage::GetFactoryForMessage(self->parent),
332       self->parent_field_descriptor->message_type());
333   Py_XINCREF(message_class);
334   return reinterpret_cast<PyObject*>(message_class);
335 }
336 
MergeFrom(PyObject * _self,PyObject * arg)337 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
338   MapContainer* self = GetMap(_self);
339   MapContainer* other_map = GetMap(arg);
340   Message* message = self->GetMutableMessage();
341   const Message* other_message = other_map->parent->message;
342   const Reflection* reflection = message->GetReflection();
343   const Reflection* other_reflection = other_message->GetReflection();
344   internal::MapFieldBase* field = reflection->MutableMapData(
345       message, self->parent_field_descriptor);
346   const internal::MapFieldBase* other_field =
347       other_reflection->GetMapData(*other_message,
348                                    self->parent_field_descriptor);
349   field->MergeFrom(*other_field);
350   self->version++;
351   Py_RETURN_NONE;
352 }
353 
Contains(PyObject * _self,PyObject * key)354 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
355   MapContainer* self = GetMap(_self);
356 
357   const Message* message = self->parent->message;
358   const Reflection* reflection = message->GetReflection();
359   MapKey map_key;
360 
361   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
362     return NULL;
363   }
364 
365   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
366                                  map_key)) {
367     Py_RETURN_TRUE;
368   } else {
369     Py_RETURN_FALSE;
370   }
371 }
372 
373 // ScalarMap ///////////////////////////////////////////////////////////////////
374 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)375 MapContainer* NewScalarMapContainer(
376     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
377   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
378     return NULL;
379   }
380 
381   PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
382   if (obj == NULL) {
383     PyErr_Format(PyExc_RuntimeError,
384                  "Could not allocate new container.");
385     return NULL;
386   }
387 
388   MapContainer* self = GetMap(obj);
389 
390   Py_INCREF(parent);
391   self->parent = parent;
392   self->parent_field_descriptor = parent_field_descriptor;
393   self->version = 0;
394 
395   self->key_field_descriptor =
396       parent_field_descriptor->message_type()->FindFieldByName("key");
397   self->value_field_descriptor =
398       parent_field_descriptor->message_type()->FindFieldByName("value");
399 
400   if (self->key_field_descriptor == NULL ||
401       self->value_field_descriptor == NULL) {
402     PyErr_Format(PyExc_KeyError,
403                  "Map entry descriptor did not have key/value fields");
404     return NULL;
405   }
406 
407   return self;
408 }
409 
ScalarMapGetItem(PyObject * _self,PyObject * key)410 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
411                                                 PyObject* key) {
412   MapContainer* self = GetMap(_self);
413 
414   Message* message = self->GetMutableMessage();
415   const Reflection* reflection = message->GetReflection();
416   MapKey map_key;
417   MapValueRef value;
418 
419   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
420     return NULL;
421   }
422 
423   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
424                                          map_key, &value)) {
425     self->version++;
426   }
427 
428   return MapValueRefToPython(self->value_field_descriptor, value);
429 }
430 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)431 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
432                                           PyObject* v) {
433   MapContainer* self = GetMap(_self);
434 
435   Message* message = self->GetMutableMessage();
436   const Reflection* reflection = message->GetReflection();
437   MapKey map_key;
438   MapValueRef value;
439 
440   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
441     return -1;
442   }
443 
444   self->version++;
445 
446   if (v) {
447     // Set item to v.
448     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
449                                        map_key, &value);
450 
451     return PythonToMapValueRef(v, self->value_field_descriptor,
452                                reflection->SupportsUnknownEnumValues(), &value)
453                ? 0
454                : -1;
455   } else {
456     // Delete key from map.
457     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
458                                    map_key)) {
459       return 0;
460     } else {
461       PyErr_Format(PyExc_KeyError, "Key not present in map");
462       return -1;
463     }
464   }
465 }
466 
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)467 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
468                               PyObject* kwargs) {
469   static char* kwlist[] = {"key", "default", nullptr};
470   PyObject* key;
471   PyObject* default_value = NULL;
472   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
473                                    &default_value)) {
474     return NULL;
475   }
476 
477   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
478   if (is_present.get() == NULL) {
479     return NULL;
480   }
481 
482   if (PyObject_IsTrue(is_present.get())) {
483     return MapReflectionFriend::ScalarMapGetItem(self, key);
484   } else {
485     if (default_value != NULL) {
486       Py_INCREF(default_value);
487       return default_value;
488     } else {
489       Py_RETURN_NONE;
490     }
491   }
492 }
493 
ScalarMapToStr(PyObject * _self)494 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
495   ScopedPyObjectPtr dict(PyDict_New());
496   if (dict == NULL) {
497     return NULL;
498   }
499   ScopedPyObjectPtr key;
500   ScopedPyObjectPtr value;
501 
502   MapContainer* self = GetMap(_self);
503   Message* message = self->GetMutableMessage();
504   const Reflection* reflection = message->GetReflection();
505   for (google::protobuf::MapIterator it = reflection->MapBegin(
506            message, self->parent_field_descriptor);
507        it != reflection->MapEnd(message, self->parent_field_descriptor);
508        ++it) {
509     key.reset(MapKeyToPython(self->key_field_descriptor,
510                              it.GetKey()));
511     if (key == NULL) {
512       return NULL;
513     }
514     value.reset(MapValueRefToPython(self->value_field_descriptor,
515                                     it.GetValueRef()));
516     if (value == NULL) {
517       return NULL;
518     }
519     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
520       return NULL;
521     }
522   }
523   return PyObject_Repr(dict.get());
524 }
525 
ScalarMapDealloc(PyObject * _self)526 static void ScalarMapDealloc(PyObject* _self) {
527   MapContainer* self = GetMap(_self);
528   self->RemoveFromParentCache();
529   PyTypeObject *type = Py_TYPE(_self);
530   type->tp_free(_self);
531   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
532     // With Python3, the Map class is not static, and must be managed.
533     Py_DECREF(type);
534   }
535 }
536 
537 static PyMethodDef ScalarMapMethods[] = {
538     {"__contains__", MapReflectionFriend::Contains, METH_O,
539      "Tests whether a key is a member of the map."},
540     {"clear", (PyCFunction)Clear, METH_NOARGS,
541      "Removes all elements from the map."},
542     {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
543      "Gets the value for the given key if present, or otherwise a default"},
544     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
545      "Return the class used to build Entries of (key, value) pairs."},
546     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
547      "Merges a map into the current map."},
548     /*
549     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
550       "Makes a deep copy of the class." },
551     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
552       "Outputs picklable representation of the repeated field." },
553     */
554     {NULL, NULL},
555 };
556 
557 PyTypeObject *ScalarMapContainer_Type;
558 #if PY_MAJOR_VERSION >= 3
559   static PyType_Slot ScalarMapContainer_Type_slots[] = {
560       {Py_tp_dealloc, (void *)ScalarMapDealloc},
561       {Py_mp_length, (void *)MapReflectionFriend::Length},
562       {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
563       {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
564       {Py_tp_methods, (void *)ScalarMapMethods},
565       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
566       {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
567       {0, 0},
568   };
569 
570   PyType_Spec ScalarMapContainer_Type_spec = {
571       FULL_MODULE_NAME ".ScalarMapContainer",
572       sizeof(MapContainer),
573       0,
574       Py_TPFLAGS_DEFAULT,
575       ScalarMapContainer_Type_slots
576   };
577 #else
578   static PyMappingMethods ScalarMapMappingMethods = {
579     MapReflectionFriend::Length,             // mp_length
580     MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
581     MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
582   };
583 
584   PyTypeObject _ScalarMapContainer_Type = {
585     PyVarObject_HEAD_INIT(&PyType_Type, 0)
586     FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
587     sizeof(MapContainer),                //  tp_basicsize
588     0,                                   //  tp_itemsize
589     ScalarMapDealloc,                    //  tp_dealloc
590     0,                                   //  tp_print
591     0,                                   //  tp_getattr
592     0,                                   //  tp_setattr
593     0,                                   //  tp_compare
594     MapReflectionFriend::ScalarMapToStr,  //  tp_repr
595     0,                                   //  tp_as_number
596     0,                                   //  tp_as_sequence
597     &ScalarMapMappingMethods,            //  tp_as_mapping
598     0,                                   //  tp_hash
599     0,                                   //  tp_call
600     0,                                   //  tp_str
601     0,                                   //  tp_getattro
602     0,                                   //  tp_setattro
603     0,                                   //  tp_as_buffer
604     Py_TPFLAGS_DEFAULT,                  //  tp_flags
605     "A scalar map container",            //  tp_doc
606     0,                                   //  tp_traverse
607     0,                                   //  tp_clear
608     0,                                   //  tp_richcompare
609     0,                                   //  tp_weaklistoffset
610     MapReflectionFriend::GetIterator,    //  tp_iter
611     0,                                   //  tp_iternext
612     ScalarMapMethods,                    //  tp_methods
613     0,                                   //  tp_members
614     0,                                   //  tp_getset
615     0,                                   //  tp_base
616     0,                                   //  tp_dict
617     0,                                   //  tp_descr_get
618     0,                                   //  tp_descr_set
619     0,                                   //  tp_dictoffset
620     0,                                   //  tp_init
621   };
622 #endif
623 
624 
625 // MessageMap //////////////////////////////////////////////////////////////////
626 
GetMessageMap(PyObject * obj)627 static MessageMapContainer* GetMessageMap(PyObject* obj) {
628   return reinterpret_cast<MessageMapContainer*>(obj);
629 }
630 
GetCMessage(MessageMapContainer * self,Message * message)631 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
632   // Get or create the CMessage object corresponding to this message.
633   return self->parent
634       ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
635                                    self->message_class)
636       ->AsPyObject();
637 }
638 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)639 MessageMapContainer* NewMessageMapContainer(
640     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
641     CMessageClass* message_class) {
642   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
643     return NULL;
644   }
645 
646   PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
647   if (obj == NULL) {
648     PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
649     return NULL;
650   }
651 
652   MessageMapContainer* self = GetMessageMap(obj);
653 
654   Py_INCREF(parent);
655   self->parent = parent;
656   self->parent_field_descriptor = parent_field_descriptor;
657   self->version = 0;
658 
659   self->key_field_descriptor =
660       parent_field_descriptor->message_type()->FindFieldByName("key");
661   self->value_field_descriptor =
662       parent_field_descriptor->message_type()->FindFieldByName("value");
663 
664   Py_INCREF(message_class);
665   self->message_class = message_class;
666 
667   if (self->key_field_descriptor == NULL ||
668       self->value_field_descriptor == NULL) {
669     Py_DECREF(self);
670     PyErr_SetString(PyExc_KeyError,
671                     "Map entry descriptor did not have key/value fields");
672     return NULL;
673   }
674 
675   return self;
676 }
677 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)678 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
679                                            PyObject* v) {
680   if (v) {
681     PyErr_Format(PyExc_ValueError,
682                  "Direct assignment of submessage not allowed");
683     return -1;
684   }
685 
686   // Now we know that this is a delete, not a set.
687 
688   MessageMapContainer* self = GetMessageMap(_self);
689   Message* message = self->GetMutableMessage();
690   const Reflection* reflection = message->GetReflection();
691   MapKey map_key;
692   MapValueRef value;
693 
694   self->version++;
695 
696   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
697     return -1;
698   }
699 
700   // Delete key from map.
701   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
702                                  map_key)) {
703     // Delete key from CMessage dict.
704     MapValueRef value;
705     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
706                                        map_key, &value);
707     Message* sub_message = value.MutableMessageValue();
708     // If there is a living weak reference to an item, we "Release" it,
709     // otherwise we just discard the C++ value.
710     if (CMessage* released =
711             self->parent->MaybeReleaseSubMessage(sub_message)) {
712       Message* msg = released->message;
713       released->message = msg->New();
714       msg->GetReflection()->Swap(msg, released->message);
715     }
716 
717     // Delete key from map.
718     reflection->DeleteMapValue(message, self->parent_field_descriptor,
719                                map_key);
720     return 0;
721   } else {
722     PyErr_Format(PyExc_KeyError, "Key not present in map");
723     return -1;
724   }
725 }
726 
MessageMapGetItem(PyObject * _self,PyObject * key)727 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
728                                                  PyObject* key) {
729   MessageMapContainer* self = GetMessageMap(_self);
730 
731   Message* message = self->GetMutableMessage();
732   const Reflection* reflection = message->GetReflection();
733   MapKey map_key;
734   MapValueRef value;
735 
736   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
737     return NULL;
738   }
739 
740   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
741                                          map_key, &value)) {
742     self->version++;
743   }
744 
745   return GetCMessage(self, value.MutableMessageValue());
746 }
747 
MessageMapToStr(PyObject * _self)748 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
749   ScopedPyObjectPtr dict(PyDict_New());
750   if (dict == NULL) {
751     return NULL;
752   }
753   ScopedPyObjectPtr key;
754   ScopedPyObjectPtr value;
755 
756   MessageMapContainer* self = GetMessageMap(_self);
757   Message* message = self->GetMutableMessage();
758   const Reflection* reflection = message->GetReflection();
759   for (google::protobuf::MapIterator it = reflection->MapBegin(
760            message, self->parent_field_descriptor);
761        it != reflection->MapEnd(message, self->parent_field_descriptor);
762        ++it) {
763     key.reset(MapKeyToPython(self->key_field_descriptor,
764                              it.GetKey()));
765     if (key == NULL) {
766       return NULL;
767     }
768     value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
769     if (value == NULL) {
770       return NULL;
771     }
772     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
773       return NULL;
774     }
775   }
776   return PyObject_Repr(dict.get());
777 }
778 
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)779 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
780   static char* kwlist[] = {"key", "default", nullptr};
781   PyObject* key;
782   PyObject* default_value = NULL;
783   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, &key,
784                                    &default_value)) {
785     return NULL;
786   }
787 
788   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
789   if (is_present.get() == NULL) {
790     return NULL;
791   }
792 
793   if (PyObject_IsTrue(is_present.get())) {
794     return MapReflectionFriend::MessageMapGetItem(self, key);
795   } else {
796     if (default_value != NULL) {
797       Py_INCREF(default_value);
798       return default_value;
799     } else {
800       Py_RETURN_NONE;
801     }
802   }
803 }
804 
MessageMapDealloc(PyObject * _self)805 static void MessageMapDealloc(PyObject* _self) {
806   MessageMapContainer* self = GetMessageMap(_self);
807   self->RemoveFromParentCache();
808   Py_DECREF(self->message_class);
809   PyTypeObject *type = Py_TYPE(_self);
810   type->tp_free(_self);
811   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
812     // With Python3, the Map class is not static, and must be managed.
813     Py_DECREF(type);
814   }
815 }
816 
817 static PyMethodDef MessageMapMethods[] = {
818     {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
819      "Tests whether the map contains this element."},
820     {"clear", (PyCFunction)Clear, METH_NOARGS,
821      "Removes all elements from the map."},
822     {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
823      "Gets the value for the given key if present, or otherwise a default"},
824     {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
825      "Alias for getitem, useful to make explicit that the map is mutated."},
826     {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
827      "Return the class used to build Entries of (key, value) pairs."},
828     {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
829      "Merges a map into the current map."},
830     /*
831     { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
832       "Makes a deep copy of the class." },
833     { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
834       "Outputs picklable representation of the repeated field." },
835     */
836     {NULL, NULL},
837 };
838 
839 PyTypeObject *MessageMapContainer_Type;
840 #if PY_MAJOR_VERSION >= 3
841   static PyType_Slot MessageMapContainer_Type_slots[] = {
842       {Py_tp_dealloc, (void *)MessageMapDealloc},
843       {Py_mp_length, (void *)MapReflectionFriend::Length},
844       {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
845       {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
846       {Py_tp_methods, (void *)MessageMapMethods},
847       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
848       {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
849       {0, 0}
850   };
851 
852   PyType_Spec MessageMapContainer_Type_spec = {
853       FULL_MODULE_NAME ".MessageMapContainer",
854       sizeof(MessageMapContainer),
855       0,
856       Py_TPFLAGS_DEFAULT,
857       MessageMapContainer_Type_slots
858   };
859 #else
860   static PyMappingMethods MessageMapMappingMethods = {
861     MapReflectionFriend::Length,              // mp_length
862     MapReflectionFriend::MessageMapGetItem,   // mp_subscript
863     MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
864   };
865 
866   PyTypeObject _MessageMapContainer_Type = {
867     PyVarObject_HEAD_INIT(&PyType_Type, 0)
868     FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
869     sizeof(MessageMapContainer),         //  tp_basicsize
870     0,                                   //  tp_itemsize
871     MessageMapDealloc,                   //  tp_dealloc
872     0,                                   //  tp_print
873     0,                                   //  tp_getattr
874     0,                                   //  tp_setattr
875     0,                                   //  tp_compare
876     MapReflectionFriend::MessageMapToStr,  //  tp_repr
877     0,                                   //  tp_as_number
878     0,                                   //  tp_as_sequence
879     &MessageMapMappingMethods,           //  tp_as_mapping
880     0,                                   //  tp_hash
881     0,                                   //  tp_call
882     0,                                   //  tp_str
883     0,                                   //  tp_getattro
884     0,                                   //  tp_setattro
885     0,                                   //  tp_as_buffer
886     Py_TPFLAGS_DEFAULT,                  //  tp_flags
887     "A map container for message",       //  tp_doc
888     0,                                   //  tp_traverse
889     0,                                   //  tp_clear
890     0,                                   //  tp_richcompare
891     0,                                   //  tp_weaklistoffset
892     MapReflectionFriend::GetIterator,    //  tp_iter
893     0,                                   //  tp_iternext
894     MessageMapMethods,                   //  tp_methods
895     0,                                   //  tp_members
896     0,                                   //  tp_getset
897     0,                                   //  tp_base
898     0,                                   //  tp_dict
899     0,                                   //  tp_descr_get
900     0,                                   //  tp_descr_set
901     0,                                   //  tp_dictoffset
902     0,                                   //  tp_init
903   };
904 #endif
905 
906 // MapIterator /////////////////////////////////////////////////////////////////
907 
GetIter(PyObject * obj)908 static MapIterator* GetIter(PyObject* obj) {
909   return reinterpret_cast<MapIterator*>(obj);
910 }
911 
GetIterator(PyObject * _self)912 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
913   MapContainer* self = GetMap(_self);
914 
915   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
916   if (obj == NULL) {
917     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
918   }
919 
920   MapIterator* iter = GetIter(obj.get());
921 
922   Py_INCREF(self);
923   iter->container = self;
924   iter->version = self->version;
925   Py_INCREF(self->parent);
926   iter->parent = self->parent;
927 
928   if (MapReflectionFriend::Length(_self) > 0) {
929     Message* message = self->GetMutableMessage();
930     const Reflection* reflection = message->GetReflection();
931 
932     iter->iter.reset(new ::google::protobuf::MapIterator(
933         reflection->MapBegin(message, self->parent_field_descriptor)));
934   }
935 
936   return obj.release();
937 }
938 
IterNext(PyObject * _self)939 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
940   MapIterator* self = GetIter(_self);
941 
942   // This won't catch mutations to the map performed by MergeFrom(); no easy way
943   // to address that.
944   if (self->version != self->container->version) {
945     return PyErr_Format(PyExc_RuntimeError,
946                         "Map modified during iteration.");
947   }
948   if (self->parent != self->container->parent) {
949     return PyErr_Format(PyExc_RuntimeError,
950                         "Map cleared during iteration.");
951   }
952 
953   if (self->iter.get() == NULL) {
954     return NULL;
955   }
956 
957   Message* message = self->container->GetMutableMessage();
958   const Reflection* reflection = message->GetReflection();
959 
960   if (*self->iter ==
961       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
962     return NULL;
963   }
964 
965   PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
966                                  self->iter->GetKey());
967 
968   ++(*self->iter);
969 
970   return ret;
971 }
972 
DeallocMapIterator(PyObject * _self)973 static void DeallocMapIterator(PyObject* _self) {
974   MapIterator* self = GetIter(_self);
975   self->iter.reset();
976   Py_CLEAR(self->container);
977   Py_CLEAR(self->parent);
978   Py_TYPE(_self)->tp_free(_self);
979 }
980 
981 PyTypeObject MapIterator_Type = {
982   PyVarObject_HEAD_INIT(&PyType_Type, 0)
983   FULL_MODULE_NAME ".MapIterator",     //  tp_name
984   sizeof(MapIterator),                 //  tp_basicsize
985   0,                                   //  tp_itemsize
986   DeallocMapIterator,                  //  tp_dealloc
987   0,                                   //  tp_print
988   0,                                   //  tp_getattr
989   0,                                   //  tp_setattr
990   0,                                   //  tp_compare
991   0,                                   //  tp_repr
992   0,                                   //  tp_as_number
993   0,                                   //  tp_as_sequence
994   0,                                   //  tp_as_mapping
995   0,                                   //  tp_hash
996   0,                                   //  tp_call
997   0,                                   //  tp_str
998   0,                                   //  tp_getattro
999   0,                                   //  tp_setattro
1000   0,                                   //  tp_as_buffer
1001   Py_TPFLAGS_DEFAULT,                  //  tp_flags
1002   "A scalar map iterator",             //  tp_doc
1003   0,                                   //  tp_traverse
1004   0,                                   //  tp_clear
1005   0,                                   //  tp_richcompare
1006   0,                                   //  tp_weaklistoffset
1007   PyObject_SelfIter,                   //  tp_iter
1008   MapReflectionFriend::IterNext,       //  tp_iternext
1009   0,                                   //  tp_methods
1010   0,                                   //  tp_members
1011   0,                                   //  tp_getset
1012   0,                                   //  tp_base
1013   0,                                   //  tp_dict
1014   0,                                   //  tp_descr_get
1015   0,                                   //  tp_descr_set
1016   0,                                   //  tp_dictoffset
1017   0,                                   //  tp_init
1018 };
1019 
InitMapContainers()1020 bool InitMapContainers() {
1021   // ScalarMapContainer_Type derives from our MutableMapping type.
1022   ScopedPyObjectPtr containers(PyImport_ImportModule(
1023       "google.protobuf.internal.containers"));
1024   if (containers == NULL) {
1025     return false;
1026   }
1027 
1028   ScopedPyObjectPtr mutable_mapping(
1029       PyObject_GetAttrString(containers.get(), "MutableMapping"));
1030   if (mutable_mapping == NULL) {
1031     return false;
1032   }
1033 
1034   Py_INCREF(mutable_mapping.get());
1035 #if PY_MAJOR_VERSION >= 3
1036   ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
1037   if (bases == NULL) {
1038     return false;
1039   }
1040 
1041   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1042       PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
1043 #else
1044   _ScalarMapContainer_Type.tp_base =
1045       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1046 
1047   if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
1048     return false;
1049   }
1050 
1051   ScalarMapContainer_Type = &_ScalarMapContainer_Type;
1052 #endif
1053 
1054   if (PyType_Ready(&MapIterator_Type) < 0) {
1055     return false;
1056   }
1057 
1058 #if PY_MAJOR_VERSION >= 3
1059   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1060       PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
1061 #else
1062   Py_INCREF(mutable_mapping.get());
1063   _MessageMapContainer_Type.tp_base =
1064       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1065 
1066   if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
1067     return false;
1068   }
1069 
1070   MessageMapContainer_Type = &_MessageMapContainer_Type;
1071 #endif
1072   return true;
1073 }
1074 
1075 }  // namespace python
1076 }  // namespace protobuf
1077 }  // namespace google
1078