• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 #include "python/extension_dict.h"
9 
10 #include "python/message.h"
11 #include "python/protobuf.h"
12 #include "upb/reflection/def.h"
13 
14 // -----------------------------------------------------------------------------
15 // ExtensionDict
16 // -----------------------------------------------------------------------------
17 
18 typedef struct {
19   PyObject_HEAD;
20   PyObject* msg;  // Owning ref to our parent pessage.
21 } PyUpb_ExtensionDict;
22 
// Creates a new ExtensionDict wrapping the parent message `msg`.
//
// The new object takes its own (owning) reference to `msg`; that reference is
// released in PyUpb_ExtensionDict_Dealloc.  Returns a new reference, or NULL
// with a Python exception set if the object could not be allocated.
PyObject* PyUpb_ExtensionDict_New(PyObject* msg) {
  PyUpb_ModuleState* state = PyUpb_ModuleState_Get();
  PyUpb_ExtensionDict* ext_dict =
      (void*)PyType_GenericAlloc(state->extension_dict_type, 0);
  // Propagate allocation failure instead of dereferencing a NULL pointer.
  if (!ext_dict) return NULL;
  ext_dict->msg = msg;
  Py_INCREF(ext_dict->msg);
  return &ext_dict->ob_base;
}
31 
// Implements ExtensionDict._FindExtensionByName(name).
//
// Searches the DefPool that owns the parent message's file for an extension
// named `key`.  Returns that extension's FieldDescriptor, None when the pool
// has no such extension, or NULL with TypeError set when PyUpb_GetStrData
// cannot extract string data from `key`.
static PyObject* PyUpb_ExtensionDict_FindExtensionByName(PyObject* _self,
                                                         PyObject* key) {
  PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
  const char* ext_name = PyUpb_GetStrData(key);
  if (ext_name == NULL) {
    PyErr_Format(PyExc_TypeError, "_FindExtensionByName expect a str");
    return NULL;
  }

  const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(self->msg);
  const upb_DefPool* pool = upb_FileDef_Pool(upb_MessageDef_File(msgdef));
  const upb_FieldDef* found = upb_DefPool_FindExtensionByName(pool, ext_name);
  if (found == NULL) {
    Py_RETURN_NONE;
  }
  return PyUpb_FieldDescriptor_Get(found);
}
50 
// Implements ExtensionDict._FindExtensionByNumber(number).
//
// Looks up an extension of the parent message's type by field number in the
// extension registry of the DefPool that owns the message's file.
//
// Returns:
//   - the extension's FieldDescriptor when one is registered for `number`;
//   - None when none is registered, or when `number` is outside the valid
//     protobuf field-number range [1, 2^29 - 1];
//   - NULL with an exception set when `arg` cannot be converted to an integer
//     (or does not fit in a C long long).
static PyObject* PyUpb_ExtensionDict_FindExtensionByNumber(PyObject* _self,
                                                           PyObject* arg) {
  // Largest field number the protobuf encoding allows (2^29 - 1).
  const long long kMaxFieldNumber = (1LL << 29) - 1;
  PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;

  long long number = PyLong_AsLongLong(arg);
  if (number == -1 && PyErr_Occurred()) return NULL;
  // The registry is keyed by a uint32_t field number.  Anything outside the
  // legal range cannot name an extension; reject it here rather than letting
  // it wrap/truncate into a different, possibly registered, number.
  if (number < 1 || number > kMaxFieldNumber) {
    Py_RETURN_NONE;
  }

  const upb_MessageDef* m = PyUpb_Message_GetMsgdef(self->msg);
  const upb_MiniTable* l = upb_MessageDef_MiniTable(m);
  const upb_DefPool* symtab = upb_FileDef_Pool(upb_MessageDef_File(m));
  const upb_ExtensionRegistry* reg = upb_DefPool_ExtensionRegistry(symtab);
  const upb_MiniTableExtension* ext =
      (const upb_MiniTableExtension*)upb_ExtensionRegistry_Lookup(
          reg, l, (uint32_t)number);
  if (!ext) {
    Py_RETURN_NONE;
  }
  const upb_FieldDef* f = upb_DefPool_FindExtensionByMiniTable(symtab, ext);
  return PyUpb_FieldDescriptor_Get(f);
}
70 
// tp_dealloc for ExtensionDict.  Calls PyUpb_Message_ClearExtensionDict on the
// parent message, drops this object's owning reference to that parent, and
// finally frees the object itself.
static void PyUpb_ExtensionDict_Dealloc(PyUpb_ExtensionDict* self) {
  PyObject* parent = self->msg;
  PyUpb_Message_ClearExtensionDict(parent);
  Py_DECREF(parent);
  PyUpb_Dealloc(self);
}
76 
// tp_richcompare for ExtensionDict.
//
// Two ExtensionDicts are equal exactly when they wrap the very same parent
// message object (pointer identity, not message equality).  An operand that
// is not an ExtensionDict (per PyObject_TypeCheck) is never equal.  Only ==
// and != are supported; other operators return NotImplemented.
static PyObject* PyUpb_ExtensionDict_RichCompare(PyObject* _self,
                                                 PyObject* _other, int opid) {
  if (opid != Py_EQ && opid != Py_NE) {
    Py_RETURN_NOTIMPLEMENTED;
  }

  bool same_parent = false;
  if (PyObject_TypeCheck(_other, Py_TYPE(_self))) {
    PyObject* lhs_parent = ((PyUpb_ExtensionDict*)_self)->msg;
    PyObject* rhs_parent = ((PyUpb_ExtensionDict*)_other)->msg;
    same_parent = lhs_parent == rhs_parent;
  }

  // For Py_EQ the answer is `same_parent`; for Py_NE it is the negation.
  if (same_parent == (opid == Py_EQ)) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
93 
// sq_contains for ExtensionDict: implements `key in msg.Extensions`.
//
// Returns 1 if the extension identified by `key` is present on the parent
// message, 0 if it is not, and -1 when PyUpb_Message_GetExtensionDef cannot
// resolve `key` (presumably with an exception set by that helper).  A repeated
// extension counts as present only when it holds at least one element.
static int PyUpb_ExtensionDict_Contains(PyObject* _self, PyObject* key) {
  PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
  const upb_FieldDef* f = PyUpb_Message_GetExtensionDef(self->msg, key);
  if (!f) return -1;
  upb_Message* msg = PyUpb_Message_GetIfReified(self->msg);
  // A parent that has not been reified is treated as containing nothing.
  if (!msg) return 0;
  if (upb_FieldDef_IsRepeated(f)) {
    // When the extension is absent, upb_Message_GetFieldByDef falls back to
    // the field's default value, which for a repeated field is a NULL array.
    // upb_Array_Size() does not accept NULL, so guard before calling it.
    upb_MessageValue val = upb_Message_GetFieldByDef(msg, f);
    return val.array_val != NULL && upb_Array_Size(val.array_val) > 0;
  }
  return upb_Message_HasFieldByDef(msg, f);
}
107 
PyUpb_ExtensionDict_Length(PyObject * _self)108 static Py_ssize_t PyUpb_ExtensionDict_Length(PyObject* _self) {
109   PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self;
110   upb_Message* msg = PyUpb_Message_GetIfReified(self->msg);
111   return msg ? upb_Message_ExtensionCount(msg) : 0;
112 }
113 
// mp_subscript for ExtensionDict: implements `msg.Extensions[key]`.
// Resolves `key` to an extension FieldDef of the parent message and returns
// PyUpb_Message_GetFieldValue() for it; returns NULL if `key` does not
// resolve.
static PyObject* PyUpb_ExtensionDict_Subscript(PyObject* _self, PyObject* key) {
  PyObject* parent = ((PyUpb_ExtensionDict*)_self)->msg;
  const upb_FieldDef* ext_field = PyUpb_Message_GetExtensionDef(parent, key);
  return ext_field ? PyUpb_Message_GetFieldValue(parent, ext_field) : NULL;
}
120 
// mp_ass_subscript for ExtensionDict.
//
// `msg.Extensions[key] = val` stores `val` into the extension via
// PyUpb_Message_SetFieldValue (passing PyExc_TypeError as the error type to
// use).  `del msg.Extensions[key]` arrives with val == NULL and clears the
// extension.  Returns 0 on success and -1 on failure, including when `key`
// does not resolve to an extension of the parent message.
static int PyUpb_ExtensionDict_AssignSubscript(PyObject* _self, PyObject* key,
                                               PyObject* val) {
  PyObject* parent = ((PyUpb_ExtensionDict*)_self)->msg;
  const upb_FieldDef* ext_field = PyUpb_Message_GetExtensionDef(parent, key);
  if (ext_field == NULL) {
    return -1;
  }

  if (val == NULL) {
    // Deletion request.
    PyUpb_Message_DoClearField(parent, ext_field);
    return 0;
  }
  return PyUpb_Message_SetFieldValue(parent, ext_field, val, PyExc_TypeError);
}
133 
134 static PyObject* PyUpb_ExtensionIterator_New(PyObject* _ext_dict);
135 
136 static PyMethodDef PyUpb_ExtensionDict_Methods[] = {
137     {"_FindExtensionByName", PyUpb_ExtensionDict_FindExtensionByName, METH_O,
138      "Finds an extension by name."},
139     {"_FindExtensionByNumber", PyUpb_ExtensionDict_FindExtensionByNumber,
140      METH_O, "Finds an extension by number."},
141     {NULL, NULL},
142 };
143 
144 static PyType_Slot PyUpb_ExtensionDict_Slots[] = {
145     {Py_tp_dealloc, PyUpb_ExtensionDict_Dealloc},
146     {Py_tp_methods, PyUpb_ExtensionDict_Methods},
147     //{Py_tp_getset, PyUpb_ExtensionDict_Getters},
148     //{Py_tp_hash, PyObject_HashNotImplemented},
149     {Py_tp_richcompare, PyUpb_ExtensionDict_RichCompare},
150     {Py_tp_iter, PyUpb_ExtensionIterator_New},
151     {Py_sq_contains, PyUpb_ExtensionDict_Contains},
152     {Py_sq_length, PyUpb_ExtensionDict_Length},
153     {Py_mp_length, PyUpb_ExtensionDict_Length},
154     {Py_mp_subscript, PyUpb_ExtensionDict_Subscript},
155     {Py_mp_ass_subscript, PyUpb_ExtensionDict_AssignSubscript},
156     {0, NULL}};
157 
158 static PyType_Spec PyUpb_ExtensionDict_Spec = {
159     PYUPB_MODULE_NAME ".ExtensionDict",  // tp_name
160     sizeof(PyUpb_ExtensionDict),         // tp_basicsize
161     0,                                   // tp_itemsize
162     Py_TPFLAGS_DEFAULT,                  // tp_flags
163     PyUpb_ExtensionDict_Slots,
164 };
165 
166 // -----------------------------------------------------------------------------
167 // ExtensionIterator
168 // -----------------------------------------------------------------------------
169 
170 typedef struct {
171   PyObject_HEAD;
172   PyObject* msg;
173   size_t iter;
174 } PyUpb_ExtensionIterator;
175 
// tp_iter for ExtensionDict: creates an iterator over the extensions set on
// the dict's parent message.  The iterator holds its own reference to that
// message and starts from kUpb_Message_Begin.  Returns NULL if allocation
// fails.
static PyObject* PyUpb_ExtensionIterator_New(PyObject* _ext_dict) {
  PyObject* parent = ((PyUpb_ExtensionDict*)_ext_dict)->msg;
  PyTypeObject* iter_type = PyUpb_ModuleState_Get()->extension_iterator_type;
  PyUpb_ExtensionIterator* it = (void*)PyType_GenericAlloc(iter_type, 0);
  if (it == NULL) {
    return NULL;
  }
  it->iter = kUpb_Message_Begin;
  it->msg = parent;
  Py_INCREF(parent);
  return (PyObject*)it;
}
187 
PyUpb_ExtensionIterator_Dealloc(void * _self)188 static void PyUpb_ExtensionIterator_Dealloc(void* _self) {
189   PyUpb_ExtensionIterator* self = (PyUpb_ExtensionIterator*)_self;
190   Py_DECREF(self->msg);
191   PyUpb_Dealloc(_self);
192 }
193 
// tp_iternext for ExtensionIterator.
//
// Advances the upb_Message_Next() cursor, skipping non-extension fields, and
// returns the FieldDescriptor of the next extension found.  Returns NULL
// without setting an exception (which ends iteration) once the cursor is
// exhausted, or immediately if the parent message has not been reified.
PyObject* PyUpb_ExtensionIterator_IterNext(PyObject* _self) {
  PyUpb_ExtensionIterator* it = (PyUpb_ExtensionIterator*)_self;
  upb_Message* reified = PyUpb_Message_GetIfReified(it->msg);
  if (reified == NULL) {
    return NULL;
  }

  const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(it->msg);
  const upb_FileDef* file = upb_MessageDef_File(msgdef);
  const upb_DefPool* pool = upb_FileDef_Pool(file);

  const upb_FieldDef* field;
  upb_MessageValue value;
  while (upb_Message_Next(reified, msgdef, pool, &field, &value, &it->iter)) {
    if (upb_FieldDef_IsExtension(field)) {
      return PyUpb_FieldDescriptor_Get(field);
    }
  }
  return NULL;
}
207 
208 static PyType_Slot PyUpb_ExtensionIterator_Slots[] = {
209     {Py_tp_dealloc, PyUpb_ExtensionIterator_Dealloc},
210     {Py_tp_iter, PyObject_SelfIter},
211     {Py_tp_iternext, PyUpb_ExtensionIterator_IterNext},
212     {0, NULL}};
213 
214 static PyType_Spec PyUpb_ExtensionIterator_Spec = {
215     PYUPB_MODULE_NAME ".ExtensionIterator",  // tp_name
216     sizeof(PyUpb_ExtensionIterator),         // tp_basicsize
217     0,                                       // tp_itemsize
218     Py_TPFLAGS_DEFAULT,                      // tp_flags
219     PyUpb_ExtensionIterator_Slots,
220 };
221 
222 // -----------------------------------------------------------------------------
223 // Top Level
224 // -----------------------------------------------------------------------------
225 
// Creates the ExtensionDict and ExtensionIterator types, adds them to module
// `m`, and caches them in the module state.
//
// Returns true on success.  Returns false as soon as either PyUpb_AddClass
// call yields NULL (presumably with a Python exception set — confirm against
// PyUpb_AddClass).
bool PyUpb_InitExtensionDict(PyObject* m) {
  PyUpb_ModuleState* s = PyUpb_ModuleState_GetFromModule(m);

  s->extension_dict_type = PyUpb_AddClass(m, &PyUpb_ExtensionDict_Spec);
  // Stop at the first failure: continuing would call back into the Python C
  // API while the previous call's exception may still be pending.
  if (!s->extension_dict_type) return false;

  s->extension_iterator_type = PyUpb_AddClass(m, &PyUpb_ExtensionIterator_Spec);
  return s->extension_iterator_type != NULL;
}
234