• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <memory>
36 #ifndef _SHARED_PTR_H
37 #include <google/protobuf/stubs/shared_ptr.h>
38 #endif
39 
40 #include <google/protobuf/stubs/logging.h>
41 #include <google/protobuf/stubs/common.h>
42 #include <google/protobuf/map_field.h>
43 #include <google/protobuf/map.h>
44 #include <google/protobuf/message.h>
45 #include <google/protobuf/pyext/message.h>
46 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
47 
48 #if PY_MAJOR_VERSION >= 3
49   #define PyInt_FromLong PyLong_FromLong
50   #define PyInt_FromSize_t PyLong_FromSize_t
51 #endif
52 
53 namespace google {
54 namespace protobuf {
55 namespace python {
56 
57 // Functions that need access to map reflection functionality.
58 // They need to be contained in this class because it is friended.
59 class MapReflectionFriend {
60  public:
61   // Methods that are in common between the map types.
62   static PyObject* Contains(PyObject* _self, PyObject* key);
63   static Py_ssize_t Length(PyObject* _self);
64   static PyObject* GetIterator(PyObject *_self);
65   static PyObject* IterNext(PyObject* _self);
66 
67   // Methods that differ between the map types.
68   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
69   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
70   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
71   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72 };
73 
74 struct MapIterator {
75   PyObject_HEAD;
76 
77   google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter;
78 
79   // A pointer back to the container, so we can notice changes to the version.
80   // We own a ref on this.
81   MapContainer* container;
82 
83   // We need to keep a ref on the Message* too, because
84   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
85   // the ref on container (above) would guarantee outlive semantics.  However in
86   // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
87   // message pointer (and the owner) to a different message, a copy of the
88   // original.  But our iterator still points to the original, which could now
89   // get deleted before us.
90   //
91   // To prevent this, we ensure that the Message will always stay alive as long
92   // as this iterator does.  This is solely for the benefit of the MapIterator
93   // destructor -- we should never actually access the iterator in this state
94   // except to delete it.
95   shared_ptr<Message> owner;
96 
97   // The version of the map when we took the iterator to it.
98   //
99   // We store this so that if the map is modified during iteration we can throw
100   // an error.
101   uint64 version;
102 
103   // True if the container is empty.  We signal this separately to avoid calling
104   // any of the iteration methods, which are non-const.
105   bool empty;
106 };
107 
GetMutableMessage()108 Message* MapContainer::GetMutableMessage() {
109   cmessage::AssureWritable(parent);
110   return const_cast<Message*>(message);
111 }
112 
113 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,string * stl_string)114 static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
115   char *value;
116   Py_ssize_t value_len;
117 
118   if (!py_string) {
119     return false;
120   }
121   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
122     Py_DECREF(py_string);
123     return false;
124   } else {
125     stl_string->assign(value, value_len);
126     Py_DECREF(py_string);
127     return true;
128   }
129 }
130 
PythonToMapKey(PyObject * obj,const FieldDescriptor * field_descriptor,MapKey * key)131 static bool PythonToMapKey(PyObject* obj,
132                            const FieldDescriptor* field_descriptor,
133                            MapKey* key) {
134   switch (field_descriptor->cpp_type()) {
135     case FieldDescriptor::CPPTYPE_INT32: {
136       GOOGLE_CHECK_GET_INT32(obj, value, false);
137       key->SetInt32Value(value);
138       break;
139     }
140     case FieldDescriptor::CPPTYPE_INT64: {
141       GOOGLE_CHECK_GET_INT64(obj, value, false);
142       key->SetInt64Value(value);
143       break;
144     }
145     case FieldDescriptor::CPPTYPE_UINT32: {
146       GOOGLE_CHECK_GET_UINT32(obj, value, false);
147       key->SetUInt32Value(value);
148       break;
149     }
150     case FieldDescriptor::CPPTYPE_UINT64: {
151       GOOGLE_CHECK_GET_UINT64(obj, value, false);
152       key->SetUInt64Value(value);
153       break;
154     }
155     case FieldDescriptor::CPPTYPE_BOOL: {
156       GOOGLE_CHECK_GET_BOOL(obj, value, false);
157       key->SetBoolValue(value);
158       break;
159     }
160     case FieldDescriptor::CPPTYPE_STRING: {
161       string str;
162       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
163         return false;
164       }
165       key->SetStringValue(str);
166       break;
167     }
168     default:
169       PyErr_Format(
170           PyExc_SystemError, "Type %d cannot be a map key",
171           field_descriptor->cpp_type());
172       return false;
173   }
174   return true;
175 }
176 
MapKeyToPython(const FieldDescriptor * field_descriptor,const MapKey & key)177 static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
178                                 const MapKey& key) {
179   switch (field_descriptor->cpp_type()) {
180     case FieldDescriptor::CPPTYPE_INT32:
181       return PyInt_FromLong(key.GetInt32Value());
182     case FieldDescriptor::CPPTYPE_INT64:
183       return PyLong_FromLongLong(key.GetInt64Value());
184     case FieldDescriptor::CPPTYPE_UINT32:
185       return PyInt_FromSize_t(key.GetUInt32Value());
186     case FieldDescriptor::CPPTYPE_UINT64:
187       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
188     case FieldDescriptor::CPPTYPE_BOOL:
189       return PyBool_FromLong(key.GetBoolValue());
190     case FieldDescriptor::CPPTYPE_STRING:
191       return ToStringObject(field_descriptor, key.GetStringValue());
192     default:
193       PyErr_Format(
194           PyExc_SystemError, "Couldn't convert type %d to value",
195           field_descriptor->cpp_type());
196       return NULL;
197   }
198 }
199 
200 // This is only used for ScalarMap, so we don't need to handle the
201 // CPPTYPE_MESSAGE case.
MapValueRefToPython(const FieldDescriptor * field_descriptor,MapValueRef * value)202 PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
203                               MapValueRef* value) {
204   switch (field_descriptor->cpp_type()) {
205     case FieldDescriptor::CPPTYPE_INT32:
206       return PyInt_FromLong(value->GetInt32Value());
207     case FieldDescriptor::CPPTYPE_INT64:
208       return PyLong_FromLongLong(value->GetInt64Value());
209     case FieldDescriptor::CPPTYPE_UINT32:
210       return PyInt_FromSize_t(value->GetUInt32Value());
211     case FieldDescriptor::CPPTYPE_UINT64:
212       return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
213     case FieldDescriptor::CPPTYPE_FLOAT:
214       return PyFloat_FromDouble(value->GetFloatValue());
215     case FieldDescriptor::CPPTYPE_DOUBLE:
216       return PyFloat_FromDouble(value->GetDoubleValue());
217     case FieldDescriptor::CPPTYPE_BOOL:
218       return PyBool_FromLong(value->GetBoolValue());
219     case FieldDescriptor::CPPTYPE_STRING:
220       return ToStringObject(field_descriptor, value->GetStringValue());
221     case FieldDescriptor::CPPTYPE_ENUM:
222       return PyInt_FromLong(value->GetEnumValue());
223     default:
224       PyErr_Format(
225           PyExc_SystemError, "Couldn't convert type %d to value",
226           field_descriptor->cpp_type());
227       return NULL;
228   }
229 }
230 
231 // This is only used for ScalarMap, so we don't need to handle the
232 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(PyObject * obj,const FieldDescriptor * field_descriptor,bool allow_unknown_enum_values,MapValueRef * value_ref)233 static bool PythonToMapValueRef(PyObject* obj,
234                                 const FieldDescriptor* field_descriptor,
235                                 bool allow_unknown_enum_values,
236                                 MapValueRef* value_ref) {
237   switch (field_descriptor->cpp_type()) {
238     case FieldDescriptor::CPPTYPE_INT32: {
239       GOOGLE_CHECK_GET_INT32(obj, value, false);
240       value_ref->SetInt32Value(value);
241       return true;
242     }
243     case FieldDescriptor::CPPTYPE_INT64: {
244       GOOGLE_CHECK_GET_INT64(obj, value, false);
245       value_ref->SetInt64Value(value);
246       return true;
247     }
248     case FieldDescriptor::CPPTYPE_UINT32: {
249       GOOGLE_CHECK_GET_UINT32(obj, value, false);
250       value_ref->SetUInt32Value(value);
251       return true;
252     }
253     case FieldDescriptor::CPPTYPE_UINT64: {
254       GOOGLE_CHECK_GET_UINT64(obj, value, false);
255       value_ref->SetUInt64Value(value);
256       return true;
257     }
258     case FieldDescriptor::CPPTYPE_FLOAT: {
259       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
260       value_ref->SetFloatValue(value);
261       return true;
262     }
263     case FieldDescriptor::CPPTYPE_DOUBLE: {
264       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
265       value_ref->SetDoubleValue(value);
266       return true;
267     }
268     case FieldDescriptor::CPPTYPE_BOOL: {
269       GOOGLE_CHECK_GET_BOOL(obj, value, false);
270       value_ref->SetBoolValue(value);
271       return true;;
272     }
273     case FieldDescriptor::CPPTYPE_STRING: {
274       string str;
275       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
276         return false;
277       }
278       value_ref->SetStringValue(str);
279       return true;
280     }
281     case FieldDescriptor::CPPTYPE_ENUM: {
282       GOOGLE_CHECK_GET_INT32(obj, value, false);
283       if (allow_unknown_enum_values) {
284         value_ref->SetEnumValue(value);
285         return true;
286       } else {
287         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
288         const EnumValueDescriptor* enum_value =
289             enum_descriptor->FindValueByNumber(value);
290         if (enum_value != NULL) {
291           value_ref->SetEnumValue(value);
292           return true;
293         } else {
294           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
295           return false;
296         }
297       }
298       break;
299     }
300     default:
301       PyErr_Format(
302           PyExc_SystemError, "Setting value to a field of unknown type %d",
303           field_descriptor->cpp_type());
304       return false;
305   }
306 }
307 
308 // Map methods common to ScalarMap and MessageMap //////////////////////////////
309 
GetMap(PyObject * obj)310 static MapContainer* GetMap(PyObject* obj) {
311   return reinterpret_cast<MapContainer*>(obj);
312 }
313 
Length(PyObject * _self)314 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
315   MapContainer* self = GetMap(_self);
316   const google::protobuf::Message* message = self->message;
317   return message->GetReflection()->MapSize(*message,
318                                            self->parent_field_descriptor);
319 }
320 
Clear(PyObject * _self)321 PyObject* Clear(PyObject* _self) {
322   MapContainer* self = GetMap(_self);
323   Message* message = self->GetMutableMessage();
324   const Reflection* reflection = message->GetReflection();
325 
326   reflection->ClearField(message, self->parent_field_descriptor);
327 
328   Py_RETURN_NONE;
329 }
330 
Contains(PyObject * _self,PyObject * key)331 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
332   MapContainer* self = GetMap(_self);
333 
334   const Message* message = self->message;
335   const Reflection* reflection = message->GetReflection();
336   MapKey map_key;
337 
338   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
339     return NULL;
340   }
341 
342   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
343                                  map_key)) {
344     Py_RETURN_TRUE;
345   } else {
346     Py_RETURN_FALSE;
347   }
348 }
349 
350 // Initializes the underlying Message object of "to" so it becomes a new parent
351 // map container, and copies all the values from "from" to it. A child map
352 // container can be released by passing it as both from and to (e.g. making it
353 // the recipient of the new parent message and copying the values from itself).
354 // In fact, this is the only supported use at the moment.
InitializeAndCopyToParentContainer(MapContainer * from,MapContainer * to)355 static int InitializeAndCopyToParentContainer(MapContainer* from,
356                                               MapContainer* to) {
357   // For now we require from == to, re-evaluate if we want to support deep copy
358   // as in repeated_scalar_container.cc.
359   GOOGLE_DCHECK(from == to);
360   Message* new_message = from->message->New();
361 
362   if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
363     // A somewhat roundabout way of copying just one field from old_message to
364     // new_message.  This is the best we can do with what Reflection gives us.
365     Message* mutable_old = from->GetMutableMessage();
366     vector<const FieldDescriptor*> fields;
367     fields.push_back(from->parent_field_descriptor);
368 
369     // Move the map field into the new message.
370     mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);
371 
372     // If/when we support from != to, this will be required also to copy the
373     // map field back into the existing message:
374     // mutable_old->MergeFrom(*new_message);
375   }
376 
377   // If from == to this could delete old_message.
378   to->owner.reset(new_message);
379 
380   to->parent = NULL;
381   to->parent_field_descriptor = from->parent_field_descriptor;
382   to->message = new_message;
383 
384   // Invalidate iterators, since they point to the old copy of the field.
385   to->version++;
386 
387   return 0;
388 }
389 
Release()390 int MapContainer::Release() {
391   return InitializeAndCopyToParentContainer(this, this);
392 }
393 
394 
395 // ScalarMap ///////////////////////////////////////////////////////////////////
396 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)397 PyObject *NewScalarMapContainer(
398     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
399   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
400     return NULL;
401   }
402 
403 #if PY_MAJOR_VERSION >= 3
404   ScopedPyObjectPtr obj(PyType_GenericAlloc(
405         reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
406 #else
407   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
408 #endif
409   if (obj.get() == NULL) {
410     return PyErr_Format(PyExc_RuntimeError,
411                         "Could not allocate new container.");
412   }
413 
414   MapContainer* self = GetMap(obj.get());
415 
416   self->message = parent->message;
417   self->parent = parent;
418   self->parent_field_descriptor = parent_field_descriptor;
419   self->owner = parent->owner;
420   self->version = 0;
421 
422   self->key_field_descriptor =
423       parent_field_descriptor->message_type()->FindFieldByName("key");
424   self->value_field_descriptor =
425       parent_field_descriptor->message_type()->FindFieldByName("value");
426 
427   if (self->key_field_descriptor == NULL ||
428       self->value_field_descriptor == NULL) {
429     return PyErr_Format(PyExc_KeyError,
430                         "Map entry descriptor did not have key/value fields");
431   }
432 
433   return obj.release();
434 }
435 
ScalarMapGetItem(PyObject * _self,PyObject * key)436 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
437                                                 PyObject* key) {
438   MapContainer* self = GetMap(_self);
439 
440   Message* message = self->GetMutableMessage();
441   const Reflection* reflection = message->GetReflection();
442   MapKey map_key;
443   MapValueRef value;
444 
445   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
446     return NULL;
447   }
448 
449   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
450                                          map_key, &value)) {
451     self->version++;
452   }
453 
454   return MapValueRefToPython(self->value_field_descriptor, &value);
455 }
456 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)457 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
458                                           PyObject* v) {
459   MapContainer* self = GetMap(_self);
460 
461   Message* message = self->GetMutableMessage();
462   const Reflection* reflection = message->GetReflection();
463   MapKey map_key;
464   MapValueRef value;
465 
466   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
467     return -1;
468   }
469 
470   self->version++;
471 
472   if (v) {
473     // Set item to v.
474     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
475                                        map_key, &value);
476 
477     return PythonToMapValueRef(v, self->value_field_descriptor,
478                                reflection->SupportsUnknownEnumValues(), &value)
479                ? 0
480                : -1;
481   } else {
482     // Delete key from map.
483     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
484                                    map_key)) {
485       return 0;
486     } else {
487       PyErr_Format(PyExc_KeyError, "Key not present in map");
488       return -1;
489     }
490   }
491 }
492 
ScalarMapGet(PyObject * self,PyObject * args)493 static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
494   PyObject* key;
495   PyObject* default_value = NULL;
496   if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
497     return NULL;
498   }
499 
500   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
501   if (is_present.get() == NULL) {
502     return NULL;
503   }
504 
505   if (PyObject_IsTrue(is_present.get())) {
506     return MapReflectionFriend::ScalarMapGetItem(self, key);
507   } else {
508     if (default_value != NULL) {
509       Py_INCREF(default_value);
510       return default_value;
511     } else {
512       Py_RETURN_NONE;
513     }
514   }
515 }
516 
ScalarMapDealloc(PyObject * _self)517 static void ScalarMapDealloc(PyObject* _self) {
518   MapContainer* self = GetMap(_self);
519   self->owner.reset();
520   Py_TYPE(_self)->tp_free(_self);
521 }
522 
523 static PyMethodDef ScalarMapMethods[] = {
524   { "__contains__", MapReflectionFriend::Contains, METH_O,
525     "Tests whether a key is a member of the map." },
526   { "clear", (PyCFunction)Clear, METH_NOARGS,
527     "Removes all elements from the map." },
528   { "get", ScalarMapGet, METH_VARARGS,
529     "Gets the value for the given key if present, or otherwise a default" },
530   /*
531   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
532     "Makes a deep copy of the class." },
533   { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
534     "Outputs picklable representation of the repeated field." },
535   */
536   {NULL, NULL},
537 };
538 
539 #if PY_MAJOR_VERSION >= 3
540   static PyType_Slot ScalarMapContainer_Type_slots[] = {
541       {Py_tp_dealloc, (void *)ScalarMapDealloc},
542       {Py_mp_length, (void *)MapReflectionFriend::Length},
543       {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
544       {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
545       {Py_tp_methods, (void *)ScalarMapMethods},
546       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
547       {0, 0},
548   };
549 
550   PyType_Spec ScalarMapContainer_Type_spec = {
551       FULL_MODULE_NAME ".ScalarMapContainer",
552       sizeof(MapContainer),
553       0,
554       Py_TPFLAGS_DEFAULT,
555       ScalarMapContainer_Type_slots
556   };
557   PyObject *ScalarMapContainer_Type;
558 #else
559   static PyMappingMethods ScalarMapMappingMethods = {
560     MapReflectionFriend::Length,             // mp_length
561     MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
562     MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
563   };
564 
565   PyTypeObject ScalarMapContainer_Type = {
566     PyVarObject_HEAD_INIT(&PyType_Type, 0)
567     FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
568     sizeof(MapContainer),                //  tp_basicsize
569     0,                                   //  tp_itemsize
570     ScalarMapDealloc,                    //  tp_dealloc
571     0,                                   //  tp_print
572     0,                                   //  tp_getattr
573     0,                                   //  tp_setattr
574     0,                                   //  tp_compare
575     0,                                   //  tp_repr
576     0,                                   //  tp_as_number
577     0,                                   //  tp_as_sequence
578     &ScalarMapMappingMethods,            //  tp_as_mapping
579     0,                                   //  tp_hash
580     0,                                   //  tp_call
581     0,                                   //  tp_str
582     0,                                   //  tp_getattro
583     0,                                   //  tp_setattro
584     0,                                   //  tp_as_buffer
585     Py_TPFLAGS_DEFAULT,                  //  tp_flags
586     "A scalar map container",            //  tp_doc
587     0,                                   //  tp_traverse
588     0,                                   //  tp_clear
589     0,                                   //  tp_richcompare
590     0,                                   //  tp_weaklistoffset
591     MapReflectionFriend::GetIterator,    //  tp_iter
592     0,                                   //  tp_iternext
593     ScalarMapMethods,                    //  tp_methods
594     0,                                   //  tp_members
595     0,                                   //  tp_getset
596     0,                                   //  tp_base
597     0,                                   //  tp_dict
598     0,                                   //  tp_descr_get
599     0,                                   //  tp_descr_set
600     0,                                   //  tp_dictoffset
601     0,                                   //  tp_init
602   };
603 #endif
604 
605 
606 // MessageMap //////////////////////////////////////////////////////////////////
607 
GetMessageMap(PyObject * obj)608 static MessageMapContainer* GetMessageMap(PyObject* obj) {
609   return reinterpret_cast<MessageMapContainer*>(obj);
610 }
611 
GetCMessage(MessageMapContainer * self,Message * message)612 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
613   // Get or create the CMessage object corresponding to this message.
614   ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
615   PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
616 
617   if (ret == NULL) {
618     CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class);
619     ret = reinterpret_cast<PyObject*>(cmsg);
620 
621     if (cmsg == NULL) {
622       return NULL;
623     }
624     cmsg->owner = self->owner;
625     cmsg->message = message;
626     cmsg->parent = self->parent;
627 
628     if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
629       Py_DECREF(ret);
630       return NULL;
631     }
632   } else {
633     Py_INCREF(ret);
634   }
635 
636   return ret;
637 }
638 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)639 PyObject* NewMessageMapContainer(
640     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
641     CMessageClass* message_class) {
642   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
643     return NULL;
644   }
645 
646 #if PY_MAJOR_VERSION >= 3
647   PyObject* obj = PyType_GenericAlloc(
648         reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
649 #else
650   PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
651 #endif
652   if (obj == NULL) {
653     return PyErr_Format(PyExc_RuntimeError,
654                         "Could not allocate new container.");
655   }
656 
657   MessageMapContainer* self = GetMessageMap(obj);
658 
659   self->message = parent->message;
660   self->parent = parent;
661   self->parent_field_descriptor = parent_field_descriptor;
662   self->owner = parent->owner;
663   self->version = 0;
664 
665   self->key_field_descriptor =
666       parent_field_descriptor->message_type()->FindFieldByName("key");
667   self->value_field_descriptor =
668       parent_field_descriptor->message_type()->FindFieldByName("value");
669 
670   self->message_dict = PyDict_New();
671   if (self->message_dict == NULL) {
672     return PyErr_Format(PyExc_RuntimeError,
673                         "Could not allocate message dict.");
674   }
675 
676   Py_INCREF(message_class);
677   self->message_class = message_class;
678 
679   if (self->key_field_descriptor == NULL ||
680       self->value_field_descriptor == NULL) {
681     Py_DECREF(obj);
682     return PyErr_Format(PyExc_KeyError,
683                         "Map entry descriptor did not have key/value fields");
684   }
685 
686   return obj;
687 }
688 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)689 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
690                                            PyObject* v) {
691   if (v) {
692     PyErr_Format(PyExc_ValueError,
693                  "Direct assignment of submessage not allowed");
694     return -1;
695   }
696 
697   // Now we know that this is a delete, not a set.
698 
699   MessageMapContainer* self = GetMessageMap(_self);
700   Message* message = self->GetMutableMessage();
701   const Reflection* reflection = message->GetReflection();
702   MapKey map_key;
703   MapValueRef value;
704 
705   self->version++;
706 
707   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
708     return -1;
709   }
710 
711   // Delete key from map.
712   if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
713                                  map_key)) {
714     return 0;
715   } else {
716     PyErr_Format(PyExc_KeyError, "Key not present in map");
717     return -1;
718   }
719 }
720 
MessageMapGetItem(PyObject * _self,PyObject * key)721 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
722                                                  PyObject* key) {
723   MessageMapContainer* self = GetMessageMap(_self);
724 
725   Message* message = self->GetMutableMessage();
726   const Reflection* reflection = message->GetReflection();
727   MapKey map_key;
728   MapValueRef value;
729 
730   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
731     return NULL;
732   }
733 
734   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
735                                          map_key, &value)) {
736     self->version++;
737   }
738 
739   return GetCMessage(self, value.MutableMessageValue());
740 }
741 
MessageMapGet(PyObject * self,PyObject * args)742 PyObject* MessageMapGet(PyObject* self, PyObject* args) {
743   PyObject* key;
744   PyObject* default_value = NULL;
745   if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
746     return NULL;
747   }
748 
749   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
750   if (is_present.get() == NULL) {
751     return NULL;
752   }
753 
754   if (PyObject_IsTrue(is_present.get())) {
755     return MapReflectionFriend::MessageMapGetItem(self, key);
756   } else {
757     if (default_value != NULL) {
758       Py_INCREF(default_value);
759       return default_value;
760     } else {
761       Py_RETURN_NONE;
762     }
763   }
764 }
765 
MessageMapDealloc(PyObject * _self)766 static void MessageMapDealloc(PyObject* _self) {
767   MessageMapContainer* self = GetMessageMap(_self);
768   self->owner.reset();
769   Py_DECREF(self->message_dict);
770   Py_DECREF(self->message_class);
771   Py_TYPE(_self)->tp_free(_self);
772 }
773 
774 static PyMethodDef MessageMapMethods[] = {
775   { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
776     "Tests whether the map contains this element."},
777   { "clear", (PyCFunction)Clear, METH_NOARGS,
778     "Removes all elements from the map."},
779   { "get", MessageMapGet, METH_VARARGS,
780     "Gets the value for the given key if present, or otherwise a default" },
781   { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
782     "Alias for getitem, useful to make explicit that the map is mutated." },
783   /*
784   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
785     "Makes a deep copy of the class." },
786   { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
787     "Outputs picklable representation of the repeated field." },
788   */
789   {NULL, NULL},
790 };
791 
792 #if PY_MAJOR_VERSION >= 3
793   static PyType_Slot MessageMapContainer_Type_slots[] = {
794       {Py_tp_dealloc, (void *)MessageMapDealloc},
795       {Py_mp_length, (void *)MapReflectionFriend::Length},
796       {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
797       {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
798       {Py_tp_methods, (void *)MessageMapMethods},
799       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
800       {0, 0}
801   };
802 
803   PyType_Spec MessageMapContainer_Type_spec = {
804       FULL_MODULE_NAME ".MessageMapContainer",
805       sizeof(MessageMapContainer),
806       0,
807       Py_TPFLAGS_DEFAULT,
808       MessageMapContainer_Type_slots
809   };
810 
811   PyObject *MessageMapContainer_Type;
812 #else
813   static PyMappingMethods MessageMapMappingMethods = {
814     MapReflectionFriend::Length,              // mp_length
815     MapReflectionFriend::MessageMapGetItem,   // mp_subscript
816     MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
817   };
818 
819   PyTypeObject MessageMapContainer_Type = {
820     PyVarObject_HEAD_INIT(&PyType_Type, 0)
821     FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
822     sizeof(MessageMapContainer),         //  tp_basicsize
823     0,                                   //  tp_itemsize
824     MessageMapDealloc,                   //  tp_dealloc
825     0,                                   //  tp_print
826     0,                                   //  tp_getattr
827     0,                                   //  tp_setattr
828     0,                                   //  tp_compare
829     0,                                   //  tp_repr
830     0,                                   //  tp_as_number
831     0,                                   //  tp_as_sequence
832     &MessageMapMappingMethods,           //  tp_as_mapping
833     0,                                   //  tp_hash
834     0,                                   //  tp_call
835     0,                                   //  tp_str
836     0,                                   //  tp_getattro
837     0,                                   //  tp_setattro
838     0,                                   //  tp_as_buffer
839     Py_TPFLAGS_DEFAULT,                  //  tp_flags
840     "A map container for message",       //  tp_doc
841     0,                                   //  tp_traverse
842     0,                                   //  tp_clear
843     0,                                   //  tp_richcompare
844     0,                                   //  tp_weaklistoffset
845     MapReflectionFriend::GetIterator,    //  tp_iter
846     0,                                   //  tp_iternext
847     MessageMapMethods,                   //  tp_methods
848     0,                                   //  tp_members
849     0,                                   //  tp_getset
850     0,                                   //  tp_base
851     0,                                   //  tp_dict
852     0,                                   //  tp_descr_get
853     0,                                   //  tp_descr_set
854     0,                                   //  tp_dictoffset
855     0,                                   //  tp_init
856   };
857 #endif
858 
859 // MapIterator /////////////////////////////////////////////////////////////////
860 
GetIter(PyObject * obj)861 static MapIterator* GetIter(PyObject* obj) {
862   return reinterpret_cast<MapIterator*>(obj);
863 }
864 
GetIterator(PyObject * _self)865 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
866   MapContainer* self = GetMap(_self);
867 
868   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
869   if (obj == NULL) {
870     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
871   }
872 
873   MapIterator* iter = GetIter(obj.get());
874 
875   Py_INCREF(self);
876   iter->container = self;
877   iter->version = self->version;
878   iter->owner = self->owner;
879 
880   if (MapReflectionFriend::Length(_self) > 0) {
881     Message* message = self->GetMutableMessage();
882     const Reflection* reflection = message->GetReflection();
883 
884     iter->iter.reset(new ::google::protobuf::MapIterator(
885         reflection->MapBegin(message, self->parent_field_descriptor)));
886   }
887 
888   return obj.release();
889 }
890 
IterNext(PyObject * _self)891 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
892   MapIterator* self = GetIter(_self);
893 
894   // This won't catch mutations to the map performed by MergeFrom(); no easy way
895   // to address that.
896   if (self->version != self->container->version) {
897     return PyErr_Format(PyExc_RuntimeError,
898                         "Map modified during iteration.");
899   }
900 
901   if (self->iter.get() == NULL) {
902     return NULL;
903   }
904 
905   Message* message = self->container->GetMutableMessage();
906   const Reflection* reflection = message->GetReflection();
907 
908   if (*self->iter ==
909       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
910     return NULL;
911   }
912 
913   PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
914                                  self->iter->GetKey());
915 
916   ++(*self->iter);
917 
918   return ret;
919 }
920 
DeallocMapIterator(PyObject * _self)921 static void DeallocMapIterator(PyObject* _self) {
922   MapIterator* self = GetIter(_self);
923   self->iter.reset();
924   self->owner.reset();
925   Py_XDECREF(self->container);
926   Py_TYPE(_self)->tp_free(_self);
927 }
928 
929 PyTypeObject MapIterator_Type = {
930   PyVarObject_HEAD_INIT(&PyType_Type, 0)
931   FULL_MODULE_NAME ".MapIterator",     //  tp_name
932   sizeof(MapIterator),                 //  tp_basicsize
933   0,                                   //  tp_itemsize
934   DeallocMapIterator,                  //  tp_dealloc
935   0,                                   //  tp_print
936   0,                                   //  tp_getattr
937   0,                                   //  tp_setattr
938   0,                                   //  tp_compare
939   0,                                   //  tp_repr
940   0,                                   //  tp_as_number
941   0,                                   //  tp_as_sequence
942   0,                                   //  tp_as_mapping
943   0,                                   //  tp_hash
944   0,                                   //  tp_call
945   0,                                   //  tp_str
946   0,                                   //  tp_getattro
947   0,                                   //  tp_setattro
948   0,                                   //  tp_as_buffer
949   Py_TPFLAGS_DEFAULT,                  //  tp_flags
950   "A scalar map iterator",             //  tp_doc
951   0,                                   //  tp_traverse
952   0,                                   //  tp_clear
953   0,                                   //  tp_richcompare
954   0,                                   //  tp_weaklistoffset
955   PyObject_SelfIter,                   //  tp_iter
956   MapReflectionFriend::IterNext,       //  tp_iternext
957   0,                                   //  tp_methods
958   0,                                   //  tp_members
959   0,                                   //  tp_getset
960   0,                                   //  tp_base
961   0,                                   //  tp_dict
962   0,                                   //  tp_descr_get
963   0,                                   //  tp_descr_set
964   0,                                   //  tp_dictoffset
965   0,                                   //  tp_init
966 };
967 
968 }  // namespace python
969 }  // namespace protobuf
970 }  // namespace google
971