• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // types.UnionType -- used to represent e.g. Union[int, str], int | str
2 #include "Python.h"
3 #include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
4 #include "pycore_typevarobject.h"  // _PyTypeAlias_Type
5 #include "pycore_unionobject.h"
6 
7 
8 
9 static PyObject *make_union(PyObject *);
10 
11 
12 typedef struct {
13     PyObject_HEAD
14     PyObject *args;
15     PyObject *parameters;
16 } unionobject;
17 
18 static void
unionobject_dealloc(PyObject * self)19 unionobject_dealloc(PyObject *self)
20 {
21     unionobject *alias = (unionobject *)self;
22 
23     _PyObject_GC_UNTRACK(self);
24 
25     Py_XDECREF(alias->args);
26     Py_XDECREF(alias->parameters);
27     Py_TYPE(self)->tp_free(self);
28 }
29 
30 static int
union_traverse(PyObject * self,visitproc visit,void * arg)31 union_traverse(PyObject *self, visitproc visit, void *arg)
32 {
33     unionobject *alias = (unionobject *)self;
34     Py_VISIT(alias->args);
35     Py_VISIT(alias->parameters);
36     return 0;
37 }
38 
39 static Py_hash_t
union_hash(PyObject * self)40 union_hash(PyObject *self)
41 {
42     unionobject *alias = (unionobject *)self;
43     PyObject *args = PyFrozenSet_New(alias->args);
44     if (args == NULL) {
45         return (Py_hash_t)-1;
46     }
47     Py_hash_t hash = PyObject_Hash(args);
48     Py_DECREF(args);
49     return hash;
50 }
51 
52 static PyObject *
union_richcompare(PyObject * a,PyObject * b,int op)53 union_richcompare(PyObject *a, PyObject *b, int op)
54 {
55     if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
56         Py_RETURN_NOTIMPLEMENTED;
57     }
58 
59     PyObject *a_set = PySet_New(((unionobject*)a)->args);
60     if (a_set == NULL) {
61         return NULL;
62     }
63     PyObject *b_set = PySet_New(((unionobject*)b)->args);
64     if (b_set == NULL) {
65         Py_DECREF(a_set);
66         return NULL;
67     }
68     PyObject *result = PyObject_RichCompare(a_set, b_set, op);
69     Py_DECREF(b_set);
70     Py_DECREF(a_set);
71     return result;
72 }
73 
74 static int
is_same(PyObject * left,PyObject * right)75 is_same(PyObject *left, PyObject *right)
76 {
77     int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
78     return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
79 }
80 
81 static int
contains(PyObject ** items,Py_ssize_t size,PyObject * obj)82 contains(PyObject **items, Py_ssize_t size, PyObject *obj)
83 {
84     for (int i = 0; i < size; i++) {
85         int is_duplicate = is_same(items[i], obj);
86         if (is_duplicate) {  // -1 or 1
87             return is_duplicate;
88         }
89     }
90     return 0;
91 }
92 
93 static PyObject *
merge(PyObject ** items1,Py_ssize_t size1,PyObject ** items2,Py_ssize_t size2)94 merge(PyObject **items1, Py_ssize_t size1,
95       PyObject **items2, Py_ssize_t size2)
96 {
97     PyObject *tuple = NULL;
98     Py_ssize_t pos = 0;
99 
100     for (int i = 0; i < size2; i++) {
101         PyObject *arg = items2[i];
102         int is_duplicate = contains(items1, size1, arg);
103         if (is_duplicate < 0) {
104             Py_XDECREF(tuple);
105             return NULL;
106         }
107         if (is_duplicate) {
108             continue;
109         }
110 
111         if (tuple == NULL) {
112             tuple = PyTuple_New(size1 + size2 - i);
113             if (tuple == NULL) {
114                 return NULL;
115             }
116             for (; pos < size1; pos++) {
117                 PyObject *a = items1[pos];
118                 PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
119             }
120         }
121         PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
122         pos++;
123     }
124 
125     if (tuple) {
126         (void) _PyTuple_Resize(&tuple, pos);
127     }
128     return tuple;
129 }
130 
131 static PyObject **
get_types(PyObject ** obj,Py_ssize_t * size)132 get_types(PyObject **obj, Py_ssize_t *size)
133 {
134     if (*obj == Py_None) {
135         *obj = (PyObject *)&_PyNone_Type;
136     }
137     if (_PyUnion_Check(*obj)) {
138         PyObject *args = ((unionobject *) *obj)->args;
139         *size = PyTuple_GET_SIZE(args);
140         return &PyTuple_GET_ITEM(args, 0);
141     }
142     else {
143         *size = 1;
144         return obj;
145     }
146 }
147 
148 static int
is_unionable(PyObject * obj)149 is_unionable(PyObject *obj)
150 {
151     if (obj == Py_None ||
152         PyType_Check(obj) ||
153         _PyGenericAlias_Check(obj) ||
154         _PyUnion_Check(obj) ||
155         Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
156         return 1;
157     }
158     return 0;
159 }
160 
161 PyObject *
_Py_union_type_or(PyObject * self,PyObject * other)162 _Py_union_type_or(PyObject* self, PyObject* other)
163 {
164     if (!is_unionable(self) || !is_unionable(other)) {
165         Py_RETURN_NOTIMPLEMENTED;
166     }
167 
168     Py_ssize_t size1, size2;
169     PyObject **items1 = get_types(&self, &size1);
170     PyObject **items2 = get_types(&other, &size2);
171     PyObject *tuple = merge(items1, size1, items2, size2);
172     if (tuple == NULL) {
173         if (PyErr_Occurred()) {
174             return NULL;
175         }
176         return Py_NewRef(self);
177     }
178 
179     PyObject *new_union = make_union(tuple);
180     Py_DECREF(tuple);
181     return new_union;
182 }
183 
184 static int
union_repr_item(_PyUnicodeWriter * writer,PyObject * p)185 union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
186 {
187     PyObject *qualname = NULL;
188     PyObject *module = NULL;
189     PyObject *r = NULL;
190     int rc;
191 
192     if (p == (PyObject *)&_PyNone_Type) {
193         return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
194     }
195 
196     if ((rc = PyObject_HasAttrWithError(p, &_Py_ID(__origin__))) > 0 &&
197         (rc = PyObject_HasAttrWithError(p, &_Py_ID(__args__))) > 0)
198     {
199         // It looks like a GenericAlias
200         goto use_repr;
201     }
202     if (rc < 0) {
203         goto exit;
204     }
205 
206     if (PyObject_GetOptionalAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
207         goto exit;
208     }
209     if (qualname == NULL) {
210         goto use_repr;
211     }
212     if (PyObject_GetOptionalAttr(p, &_Py_ID(__module__), &module) < 0) {
213         goto exit;
214     }
215     if (module == NULL || module == Py_None) {
216         goto use_repr;
217     }
218 
219     // Looks like a class
220     if (PyUnicode_Check(module) &&
221         _PyUnicode_EqualToASCIIString(module, "builtins"))
222     {
223         // builtins don't need a module name
224         r = PyObject_Str(qualname);
225         goto exit;
226     }
227     else {
228         r = PyUnicode_FromFormat("%S.%S", module, qualname);
229         goto exit;
230     }
231 
232 use_repr:
233     r = PyObject_Repr(p);
234 exit:
235     Py_XDECREF(qualname);
236     Py_XDECREF(module);
237     if (r == NULL) {
238         return -1;
239     }
240     rc = _PyUnicodeWriter_WriteStr(writer, r);
241     Py_DECREF(r);
242     return rc;
243 }
244 
245 static PyObject *
union_repr(PyObject * self)246 union_repr(PyObject *self)
247 {
248     unionobject *alias = (unionobject *)self;
249     Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
250 
251     _PyUnicodeWriter writer;
252     _PyUnicodeWriter_Init(&writer);
253      for (Py_ssize_t i = 0; i < len; i++) {
254         if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
255             goto error;
256         }
257         PyObject *p = PyTuple_GET_ITEM(alias->args, i);
258         if (union_repr_item(&writer, p) < 0) {
259             goto error;
260         }
261     }
262     return _PyUnicodeWriter_Finish(&writer);
263 error:
264     _PyUnicodeWriter_Dealloc(&writer);
265     return NULL;
266 }
267 
268 static PyMemberDef union_members[] = {
269         {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY},
270         {0}
271 };
272 
273 static PyObject *
union_getitem(PyObject * self,PyObject * item)274 union_getitem(PyObject *self, PyObject *item)
275 {
276     unionobject *alias = (unionobject *)self;
277     // Populate __parameters__ if needed.
278     if (alias->parameters == NULL) {
279         alias->parameters = _Py_make_parameters(alias->args);
280         if (alias->parameters == NULL) {
281             return NULL;
282         }
283     }
284 
285     PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
286     if (newargs == NULL) {
287         return NULL;
288     }
289 
290     PyObject *res;
291     Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
292     if (nargs == 0) {
293         res = make_union(newargs);
294     }
295     else {
296         res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
297         for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
298             PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
299             Py_SETREF(res, PyNumber_Or(res, arg));
300             if (res == NULL) {
301                 break;
302             }
303         }
304     }
305     Py_DECREF(newargs);
306     return res;
307 }
308 
309 static PyMappingMethods union_as_mapping = {
310     .mp_subscript = union_getitem,
311 };
312 
313 static PyObject *
union_parameters(PyObject * self,void * Py_UNUSED (unused))314 union_parameters(PyObject *self, void *Py_UNUSED(unused))
315 {
316     unionobject *alias = (unionobject *)self;
317     if (alias->parameters == NULL) {
318         alias->parameters = _Py_make_parameters(alias->args);
319         if (alias->parameters == NULL) {
320             return NULL;
321         }
322     }
323     return Py_NewRef(alias->parameters);
324 }
325 
326 static PyGetSetDef union_properties[] = {
327     {"__parameters__", union_parameters, (setter)NULL,
328      PyDoc_STR("Type variables in the types.UnionType."), NULL},
329     {0}
330 };
331 
332 static PyNumberMethods union_as_number = {
333         .nb_or = _Py_union_type_or, // Add __or__ function
334 };
335 
336 static const char* const cls_attrs[] = {
337         "__module__",  // Required for compatibility with typing module
338         NULL,
339 };
340 
341 static PyObject *
union_getattro(PyObject * self,PyObject * name)342 union_getattro(PyObject *self, PyObject *name)
343 {
344     unionobject *alias = (unionobject *)self;
345     if (PyUnicode_Check(name)) {
346         for (const char * const *p = cls_attrs; ; p++) {
347             if (*p == NULL) {
348                 break;
349             }
350             if (_PyUnicode_EqualToASCIIString(name, *p)) {
351                 return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
352             }
353         }
354     }
355     return PyObject_GenericGetAttr(self, name);
356 }
357 
358 PyObject *
_Py_union_args(PyObject * self)359 _Py_union_args(PyObject *self)
360 {
361     assert(_PyUnion_Check(self));
362     return ((unionobject *) self)->args;
363 }
364 
365 PyTypeObject _PyUnion_Type = {
366     PyVarObject_HEAD_INIT(&PyType_Type, 0)
367     .tp_name = "types.UnionType",
368     .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
369               "\n"
370               "E.g. for int | str"),
371     .tp_basicsize = sizeof(unionobject),
372     .tp_dealloc = unionobject_dealloc,
373     .tp_alloc = PyType_GenericAlloc,
374     .tp_free = PyObject_GC_Del,
375     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
376     .tp_traverse = union_traverse,
377     .tp_hash = union_hash,
378     .tp_getattro = union_getattro,
379     .tp_members = union_members,
380     .tp_richcompare = union_richcompare,
381     .tp_as_mapping = &union_as_mapping,
382     .tp_as_number = &union_as_number,
383     .tp_repr = union_repr,
384     .tp_getset = union_properties,
385 };
386 
387 static PyObject *
make_union(PyObject * args)388 make_union(PyObject *args)
389 {
390     assert(PyTuple_CheckExact(args));
391 
392     unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
393     if (result == NULL) {
394         return NULL;
395     }
396 
397     result->parameters = NULL;
398     result->args = Py_NewRef(args);
399     _PyObject_GC_TRACK(result);
400     return (PyObject*)result;
401 }
402