• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "Python.h"
2 
3 #include "pycore_context.h"
4 #include "pycore_hamt.h"
5 #include "pycore_object.h"
6 #include "pycore_pystate.h"
7 #include "structmember.h"
8 
9 
10 #define CONTEXT_FREELIST_MAXLEN 255
11 static PyContext *ctx_freelist = NULL;
12 static int ctx_freelist_len = 0;
13 
14 
15 #include "clinic/context.c.h"
16 /*[clinic input]
17 module _contextvars
18 [clinic start generated code]*/
19 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
20 
21 
/* Make the calling function return `err_ret` with a TypeError set unless
   `o` is exactly a Context.  The expansion is a bare `if` block, so call
   sites in this file use it without a trailing semicolon. */
#define ENSURE_Context(o, err_ret)                                  \
    if (!PyContext_CheckExact(o)) {                                 \
        PyErr_SetString(                                            \
            PyExc_TypeError,                                        \
            "an instance of Context was expected");                 \
        return err_ret;                                             \
    }
28 
/* Make the calling function return `err_ret` with a TypeError set unless
   `o` is exactly a ContextVar.  The expansion is a bare `if` block, so call
   sites in this file use it without a trailing semicolon. */
#define ENSURE_ContextVar(o, err_ret)                               \
    if (!PyContextVar_CheckExact(o)) {                              \
        PyErr_SetString(                                            \
            PyExc_TypeError,                                        \
            "an instance of ContextVar was expected");              \
        return err_ret;                                             \
    }
35 
/* Make the calling function return `err_ret` with a TypeError set unless
   `o` is exactly a Token.  The expansion is a bare `if` block, so call
   sites in this file use it without a trailing semicolon. */
#define ENSURE_ContextToken(o, err_ret)                             \
    if (!PyContextToken_CheckExact(o)) {                            \
        PyErr_SetString(                                            \
            PyExc_TypeError,                                        \
            "an instance of Token was expected");                   \
        return err_ret;                                             \
    }
42 
43 
44 /////////////////////////// Context API
45 
46 
47 static PyContext *
48 context_new_empty(void);
49 
50 static PyContext *
51 context_new_from_vars(PyHamtObject *vars);
52 
53 static inline PyContext *
54 context_get(void);
55 
56 static PyContextToken *
57 token_new(PyContext *ctx, PyContextVar *var, PyObject *val);
58 
59 static PyContextVar *
60 contextvar_new(PyObject *name, PyObject *def);
61 
62 static int
63 contextvar_set(PyContextVar *var, PyObject *val);
64 
65 static int
66 contextvar_del(PyContextVar *var);
67 
68 
69 PyObject *
_PyContext_NewHamtForTests(void)70 _PyContext_NewHamtForTests(void)
71 {
72     return (PyObject *)_PyHamt_New();
73 }
74 
75 
76 PyObject *
PyContext_New(void)77 PyContext_New(void)
78 {
79     return (PyObject *)context_new_empty();
80 }
81 
82 
83 PyObject *
PyContext_Copy(PyObject * octx)84 PyContext_Copy(PyObject * octx)
85 {
86     ENSURE_Context(octx, NULL)
87     PyContext *ctx = (PyContext *)octx;
88     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
89 }
90 
91 
92 PyObject *
PyContext_CopyCurrent(void)93 PyContext_CopyCurrent(void)
94 {
95     PyContext *ctx = context_get();
96     if (ctx == NULL) {
97         return NULL;
98     }
99 
100     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
101 }
102 
103 
104 int
PyContext_Enter(PyObject * octx)105 PyContext_Enter(PyObject *octx)
106 {
107     ENSURE_Context(octx, -1)
108     PyContext *ctx = (PyContext *)octx;
109 
110     if (ctx->ctx_entered) {
111         PyErr_Format(PyExc_RuntimeError,
112                      "cannot enter context: %R is already entered", ctx);
113         return -1;
114     }
115 
116     PyThreadState *ts = _PyThreadState_GET();
117     assert(ts != NULL);
118 
119     ctx->ctx_prev = (PyContext *)ts->context;  /* borrow */
120     ctx->ctx_entered = 1;
121 
122     Py_INCREF(ctx);
123     ts->context = (PyObject *)ctx;
124     ts->context_ver++;
125 
126     return 0;
127 }
128 
129 
130 int
PyContext_Exit(PyObject * octx)131 PyContext_Exit(PyObject *octx)
132 {
133     ENSURE_Context(octx, -1)
134     PyContext *ctx = (PyContext *)octx;
135 
136     if (!ctx->ctx_entered) {
137         PyErr_Format(PyExc_RuntimeError,
138                      "cannot exit context: %R has not been entered", ctx);
139         return -1;
140     }
141 
142     PyThreadState *ts = _PyThreadState_GET();
143     assert(ts != NULL);
144 
145     if (ts->context != (PyObject *)ctx) {
146         /* Can only happen if someone misuses the C API */
147         PyErr_SetString(PyExc_RuntimeError,
148                         "cannot exit context: thread state references "
149                         "a different context object");
150         return -1;
151     }
152 
153     Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
154     ts->context_ver++;
155 
156     ctx->ctx_prev = NULL;
157     ctx->ctx_entered = 0;
158 
159     return 0;
160 }
161 
162 
163 PyObject *
PyContextVar_New(const char * name,PyObject * def)164 PyContextVar_New(const char *name, PyObject *def)
165 {
166     PyObject *pyname = PyUnicode_FromString(name);
167     if (pyname == NULL) {
168         return NULL;
169     }
170     PyContextVar *var = contextvar_new(pyname, def);
171     Py_DECREF(pyname);
172     return (PyObject *)var;
173 }
174 
175 
176 int
PyContextVar_Get(PyObject * ovar,PyObject * def,PyObject ** val)177 PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
178 {
179     ENSURE_ContextVar(ovar, -1)
180     PyContextVar *var = (PyContextVar *)ovar;
181 
182     PyThreadState *ts = _PyThreadState_GET();
183     assert(ts != NULL);
184     if (ts->context == NULL) {
185         goto not_found;
186     }
187 
188     if (var->var_cached != NULL &&
189             var->var_cached_tsid == ts->id &&
190             var->var_cached_tsver == ts->context_ver)
191     {
192         *val = var->var_cached;
193         goto found;
194     }
195 
196     assert(PyContext_CheckExact(ts->context));
197     PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
198 
199     PyObject *found = NULL;
200     int res = _PyHamt_Find(vars, (PyObject*)var, &found);
201     if (res < 0) {
202         goto error;
203     }
204     if (res == 1) {
205         assert(found != NULL);
206         var->var_cached = found;  /* borrow */
207         var->var_cached_tsid = ts->id;
208         var->var_cached_tsver = ts->context_ver;
209 
210         *val = found;
211         goto found;
212     }
213 
214 not_found:
215     if (def == NULL) {
216         if (var->var_default != NULL) {
217             *val = var->var_default;
218             goto found;
219         }
220 
221         *val = NULL;
222         goto found;
223     }
224     else {
225         *val = def;
226         goto found;
227    }
228 
229 found:
230     Py_XINCREF(*val);
231     return 0;
232 
233 error:
234     *val = NULL;
235     return -1;
236 }
237 
238 
239 PyObject *
PyContextVar_Set(PyObject * ovar,PyObject * val)240 PyContextVar_Set(PyObject *ovar, PyObject *val)
241 {
242     ENSURE_ContextVar(ovar, NULL)
243     PyContextVar *var = (PyContextVar *)ovar;
244 
245     if (!PyContextVar_CheckExact(var)) {
246         PyErr_SetString(
247             PyExc_TypeError, "an instance of ContextVar was expected");
248         return NULL;
249     }
250 
251     PyContext *ctx = context_get();
252     if (ctx == NULL) {
253         return NULL;
254     }
255 
256     PyObject *old_val = NULL;
257     int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
258     if (found < 0) {
259         return NULL;
260     }
261 
262     Py_XINCREF(old_val);
263     PyContextToken *tok = token_new(ctx, var, old_val);
264     Py_XDECREF(old_val);
265 
266     if (contextvar_set(var, val)) {
267         Py_DECREF(tok);
268         return NULL;
269     }
270 
271     return (PyObject *)tok;
272 }
273 
274 
275 int
PyContextVar_Reset(PyObject * ovar,PyObject * otok)276 PyContextVar_Reset(PyObject *ovar, PyObject *otok)
277 {
278     ENSURE_ContextVar(ovar, -1)
279     ENSURE_ContextToken(otok, -1)
280     PyContextVar *var = (PyContextVar *)ovar;
281     PyContextToken *tok = (PyContextToken *)otok;
282 
283     if (tok->tok_used) {
284         PyErr_Format(PyExc_RuntimeError,
285                      "%R has already been used once", tok);
286         return -1;
287     }
288 
289     if (var != tok->tok_var) {
290         PyErr_Format(PyExc_ValueError,
291                      "%R was created by a different ContextVar", tok);
292         return -1;
293     }
294 
295     PyContext *ctx = context_get();
296     if (ctx != tok->tok_ctx) {
297         PyErr_Format(PyExc_ValueError,
298                      "%R was created in a different Context", tok);
299         return -1;
300     }
301 
302     tok->tok_used = 1;
303 
304     if (tok->tok_oldval == NULL) {
305         return contextvar_del(var);
306     }
307     else {
308         return contextvar_set(var, tok->tok_oldval);
309     }
310 }
311 
312 
313 /////////////////////////// PyContext
314 
315 /*[clinic input]
316 class _contextvars.Context "PyContext *" "&PyContext_Type"
317 [clinic start generated code]*/
318 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
319 
320 
321 static inline PyContext *
_context_alloc(void)322 _context_alloc(void)
323 {
324     PyContext *ctx;
325     if (ctx_freelist_len) {
326         ctx_freelist_len--;
327         ctx = ctx_freelist;
328         ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
329         ctx->ctx_weakreflist = NULL;
330         _Py_NewReference((PyObject *)ctx);
331     }
332     else {
333         ctx = PyObject_GC_New(PyContext, &PyContext_Type);
334         if (ctx == NULL) {
335             return NULL;
336         }
337     }
338 
339     ctx->ctx_vars = NULL;
340     ctx->ctx_prev = NULL;
341     ctx->ctx_entered = 0;
342     ctx->ctx_weakreflist = NULL;
343 
344     return ctx;
345 }
346 
347 
348 static PyContext *
context_new_empty(void)349 context_new_empty(void)
350 {
351     PyContext *ctx = _context_alloc();
352     if (ctx == NULL) {
353         return NULL;
354     }
355 
356     ctx->ctx_vars = _PyHamt_New();
357     if (ctx->ctx_vars == NULL) {
358         Py_DECREF(ctx);
359         return NULL;
360     }
361 
362     _PyObject_GC_TRACK(ctx);
363     return ctx;
364 }
365 
366 
367 static PyContext *
context_new_from_vars(PyHamtObject * vars)368 context_new_from_vars(PyHamtObject *vars)
369 {
370     PyContext *ctx = _context_alloc();
371     if (ctx == NULL) {
372         return NULL;
373     }
374 
375     Py_INCREF(vars);
376     ctx->ctx_vars = vars;
377 
378     _PyObject_GC_TRACK(ctx);
379     return ctx;
380 }
381 
382 
383 static inline PyContext *
context_get(void)384 context_get(void)
385 {
386     PyThreadState *ts = _PyThreadState_GET();
387     assert(ts != NULL);
388     PyContext *current_ctx = (PyContext *)ts->context;
389     if (current_ctx == NULL) {
390         current_ctx = context_new_empty();
391         if (current_ctx == NULL) {
392             return NULL;
393         }
394         ts->context = (PyObject *)current_ctx;
395     }
396     return current_ctx;
397 }
398 
399 static int
context_check_key_type(PyObject * key)400 context_check_key_type(PyObject *key)
401 {
402     if (!PyContextVar_CheckExact(key)) {
403         // abort();
404         PyErr_Format(PyExc_TypeError,
405                      "a ContextVar key was expected, got %R", key);
406         return -1;
407     }
408     return 0;
409 }
410 
411 static PyObject *
context_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)412 context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
413 {
414     if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
415         PyErr_SetString(
416             PyExc_TypeError, "Context() does not accept any arguments");
417         return NULL;
418     }
419     return PyContext_New();
420 }
421 
422 static int
context_tp_clear(PyContext * self)423 context_tp_clear(PyContext *self)
424 {
425     Py_CLEAR(self->ctx_prev);
426     Py_CLEAR(self->ctx_vars);
427     return 0;
428 }
429 
430 static int
context_tp_traverse(PyContext * self,visitproc visit,void * arg)431 context_tp_traverse(PyContext *self, visitproc visit, void *arg)
432 {
433     Py_VISIT(self->ctx_prev);
434     Py_VISIT(self->ctx_vars);
435     return 0;
436 }
437 
438 static void
context_tp_dealloc(PyContext * self)439 context_tp_dealloc(PyContext *self)
440 {
441     _PyObject_GC_UNTRACK(self);
442 
443     if (self->ctx_weakreflist != NULL) {
444         PyObject_ClearWeakRefs((PyObject*)self);
445     }
446     (void)context_tp_clear(self);
447 
448     if (ctx_freelist_len < CONTEXT_FREELIST_MAXLEN) {
449         ctx_freelist_len++;
450         self->ctx_weakreflist = (PyObject *)ctx_freelist;
451         ctx_freelist = self;
452     }
453     else {
454         Py_TYPE(self)->tp_free(self);
455     }
456 }
457 
458 static PyObject *
context_tp_iter(PyContext * self)459 context_tp_iter(PyContext *self)
460 {
461     return _PyHamt_NewIterKeys(self->ctx_vars);
462 }
463 
464 static PyObject *
context_tp_richcompare(PyObject * v,PyObject * w,int op)465 context_tp_richcompare(PyObject *v, PyObject *w, int op)
466 {
467     if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
468             (op != Py_EQ && op != Py_NE))
469     {
470         Py_RETURN_NOTIMPLEMENTED;
471     }
472 
473     int res = _PyHamt_Eq(
474         ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
475     if (res < 0) {
476         return NULL;
477     }
478 
479     if (op == Py_NE) {
480         res = !res;
481     }
482 
483     if (res) {
484         Py_RETURN_TRUE;
485     }
486     else {
487         Py_RETURN_FALSE;
488     }
489 }
490 
491 static Py_ssize_t
context_tp_len(PyContext * self)492 context_tp_len(PyContext *self)
493 {
494     return _PyHamt_Len(self->ctx_vars);
495 }
496 
497 static PyObject *
context_tp_subscript(PyContext * self,PyObject * key)498 context_tp_subscript(PyContext *self, PyObject *key)
499 {
500     if (context_check_key_type(key)) {
501         return NULL;
502     }
503     PyObject *val = NULL;
504     int found = _PyHamt_Find(self->ctx_vars, key, &val);
505     if (found < 0) {
506         return NULL;
507     }
508     if (found == 0) {
509         PyErr_SetObject(PyExc_KeyError, key);
510         return NULL;
511     }
512     Py_INCREF(val);
513     return val;
514 }
515 
516 static int
context_tp_contains(PyContext * self,PyObject * key)517 context_tp_contains(PyContext *self, PyObject *key)
518 {
519     if (context_check_key_type(key)) {
520         return -1;
521     }
522     PyObject *val = NULL;
523     return _PyHamt_Find(self->ctx_vars, key, &val);
524 }
525 
526 
527 /*[clinic input]
528 _contextvars.Context.get
529     key: object
530     default: object = None
531     /
532 
533 Return the value for `key` if `key` has the value in the context object.
534 
535 If `key` does not exist, return `default`. If `default` is not given,
536 return None.
537 [clinic start generated code]*/
538 
539 static PyObject *
_contextvars_Context_get_impl(PyContext * self,PyObject * key,PyObject * default_value)540 _contextvars_Context_get_impl(PyContext *self, PyObject *key,
541                               PyObject *default_value)
542 /*[clinic end generated code: output=0c54aa7664268189 input=c8eeb81505023995]*/
543 {
544     if (context_check_key_type(key)) {
545         return NULL;
546     }
547 
548     PyObject *val = NULL;
549     int found = _PyHamt_Find(self->ctx_vars, key, &val);
550     if (found < 0) {
551         return NULL;
552     }
553     if (found == 0) {
554         Py_INCREF(default_value);
555         return default_value;
556     }
557     Py_INCREF(val);
558     return val;
559 }
560 
561 
562 /*[clinic input]
563 _contextvars.Context.items
564 
565 Return all variables and their values in the context object.
566 
567 The result is returned as a list of 2-tuples (variable, value).
568 [clinic start generated code]*/
569 
570 static PyObject *
_contextvars_Context_items_impl(PyContext * self)571 _contextvars_Context_items_impl(PyContext *self)
572 /*[clinic end generated code: output=fa1655c8a08502af input=00db64ae379f9f42]*/
573 {
574     return _PyHamt_NewIterItems(self->ctx_vars);
575 }
576 
577 
578 /*[clinic input]
579 _contextvars.Context.keys
580 
581 Return a list of all variables in the context object.
582 [clinic start generated code]*/
583 
584 static PyObject *
_contextvars_Context_keys_impl(PyContext * self)585 _contextvars_Context_keys_impl(PyContext *self)
586 /*[clinic end generated code: output=177227c6b63ec0e2 input=114b53aebca3449c]*/
587 {
588     return _PyHamt_NewIterKeys(self->ctx_vars);
589 }
590 
591 
592 /*[clinic input]
593 _contextvars.Context.values
594 
595 Return a list of all variables' values in the context object.
596 [clinic start generated code]*/
597 
598 static PyObject *
_contextvars_Context_values_impl(PyContext * self)599 _contextvars_Context_values_impl(PyContext *self)
600 /*[clinic end generated code: output=d286dabfc8db6dde input=ce8075d04a6ea526]*/
601 {
602     return _PyHamt_NewIterValues(self->ctx_vars);
603 }
604 
605 
606 /*[clinic input]
607 _contextvars.Context.copy
608 
609 Return a shallow copy of the context object.
610 [clinic start generated code]*/
611 
612 static PyObject *
_contextvars_Context_copy_impl(PyContext * self)613 _contextvars_Context_copy_impl(PyContext *self)
614 /*[clinic end generated code: output=30ba8896c4707a15 input=ebafdbdd9c72d592]*/
615 {
616     return (PyObject *)context_new_from_vars(self->ctx_vars);
617 }
618 
619 
620 static PyObject *
context_run(PyContext * self,PyObject * const * args,Py_ssize_t nargs,PyObject * kwnames)621 context_run(PyContext *self, PyObject *const *args,
622             Py_ssize_t nargs, PyObject *kwnames)
623 {
624     if (nargs < 1) {
625         PyErr_SetString(PyExc_TypeError,
626                         "run() missing 1 required positional argument");
627         return NULL;
628     }
629 
630     if (PyContext_Enter((PyObject *)self)) {
631         return NULL;
632     }
633 
634     PyObject *call_result = _PyObject_Vectorcall(
635         args[0], args + 1, nargs - 1, kwnames);
636 
637     if (PyContext_Exit((PyObject *)self)) {
638         return NULL;
639     }
640 
641     return call_result;
642 }
643 
644 
645 static PyMethodDef PyContext_methods[] = {
646     _CONTEXTVARS_CONTEXT_GET_METHODDEF
647     _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
648     _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
649     _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
650     _CONTEXTVARS_CONTEXT_COPY_METHODDEF
651     {"run", (PyCFunction)(void(*)(void))context_run, METH_FASTCALL | METH_KEYWORDS, NULL},
652     {NULL, NULL}
653 };
654 
655 static PySequenceMethods PyContext_as_sequence = {
656     0,                                   /* sq_length */
657     0,                                   /* sq_concat */
658     0,                                   /* sq_repeat */
659     0,                                   /* sq_item */
660     0,                                   /* sq_slice */
661     0,                                   /* sq_ass_item */
662     0,                                   /* sq_ass_slice */
663     (objobjproc)context_tp_contains,     /* sq_contains */
664     0,                                   /* sq_inplace_concat */
665     0,                                   /* sq_inplace_repeat */
666 };
667 
668 static PyMappingMethods PyContext_as_mapping = {
669     (lenfunc)context_tp_len,             /* mp_length */
670     (binaryfunc)context_tp_subscript,    /* mp_subscript */
671 };
672 
673 PyTypeObject PyContext_Type = {
674     PyVarObject_HEAD_INIT(&PyType_Type, 0)
675     "Context",
676     sizeof(PyContext),
677     .tp_methods = PyContext_methods,
678     .tp_as_mapping = &PyContext_as_mapping,
679     .tp_as_sequence = &PyContext_as_sequence,
680     .tp_iter = (getiterfunc)context_tp_iter,
681     .tp_dealloc = (destructor)context_tp_dealloc,
682     .tp_getattro = PyObject_GenericGetAttr,
683     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
684     .tp_richcompare = context_tp_richcompare,
685     .tp_traverse = (traverseproc)context_tp_traverse,
686     .tp_clear = (inquiry)context_tp_clear,
687     .tp_new = context_tp_new,
688     .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
689     .tp_hash = PyObject_HashNotImplemented,
690 };
691 
692 
693 /////////////////////////// ContextVar
694 
695 
696 static int
contextvar_set(PyContextVar * var,PyObject * val)697 contextvar_set(PyContextVar *var, PyObject *val)
698 {
699     var->var_cached = NULL;
700     PyThreadState *ts = PyThreadState_Get();
701 
702     PyContext *ctx = context_get();
703     if (ctx == NULL) {
704         return -1;
705     }
706 
707     PyHamtObject *new_vars = _PyHamt_Assoc(
708         ctx->ctx_vars, (PyObject *)var, val);
709     if (new_vars == NULL) {
710         return -1;
711     }
712 
713     Py_SETREF(ctx->ctx_vars, new_vars);
714 
715     var->var_cached = val;  /* borrow */
716     var->var_cached_tsid = ts->id;
717     var->var_cached_tsver = ts->context_ver;
718     return 0;
719 }
720 
721 static int
contextvar_del(PyContextVar * var)722 contextvar_del(PyContextVar *var)
723 {
724     var->var_cached = NULL;
725 
726     PyContext *ctx = context_get();
727     if (ctx == NULL) {
728         return -1;
729     }
730 
731     PyHamtObject *vars = ctx->ctx_vars;
732     PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
733     if (new_vars == NULL) {
734         return -1;
735     }
736 
737     if (vars == new_vars) {
738         Py_DECREF(new_vars);
739         PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
740         return -1;
741     }
742 
743     Py_SETREF(ctx->ctx_vars, new_vars);
744     return 0;
745 }
746 
747 static Py_hash_t
contextvar_generate_hash(void * addr,PyObject * name)748 contextvar_generate_hash(void *addr, PyObject *name)
749 {
750     /* Take hash of `name` and XOR it with the object's addr.
751 
752        The structure of the tree is encoded in objects' hashes, which
753        means that sufficiently similar hashes would result in tall trees
754        with many Collision nodes.  Which would, in turn, result in slower
755        get and set operations.
756 
757        The XORing helps to ensure that:
758 
759        (1) sequentially allocated ContextVar objects have
760            different hashes;
761 
762        (2) context variables with equal names have
763            different hashes.
764     */
765 
766     Py_hash_t name_hash = PyObject_Hash(name);
767     if (name_hash == -1) {
768         return -1;
769     }
770 
771     Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
772     return res == -1 ? -2 : res;
773 }
774 
775 static PyContextVar *
contextvar_new(PyObject * name,PyObject * def)776 contextvar_new(PyObject *name, PyObject *def)
777 {
778     if (!PyUnicode_Check(name)) {
779         PyErr_SetString(PyExc_TypeError,
780                         "context variable name must be a str");
781         return NULL;
782     }
783 
784     PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
785     if (var == NULL) {
786         return NULL;
787     }
788 
789     var->var_hash = contextvar_generate_hash(var, name);
790     if (var->var_hash == -1) {
791         Py_DECREF(var);
792         return NULL;
793     }
794 
795     Py_INCREF(name);
796     var->var_name = name;
797 
798     Py_XINCREF(def);
799     var->var_default = def;
800 
801     var->var_cached = NULL;
802     var->var_cached_tsid = 0;
803     var->var_cached_tsver = 0;
804 
805     if (_PyObject_GC_MAY_BE_TRACKED(name) ||
806             (def != NULL && _PyObject_GC_MAY_BE_TRACKED(def)))
807     {
808         PyObject_GC_Track(var);
809     }
810     return var;
811 }
812 
813 
814 /*[clinic input]
815 class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
816 [clinic start generated code]*/
817 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
818 
819 
820 static PyObject *
contextvar_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)821 contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
822 {
823     static char *kwlist[] = {"", "default", NULL};
824     PyObject *name;
825     PyObject *def = NULL;
826 
827     if (!PyArg_ParseTupleAndKeywords(
828             args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
829     {
830         return NULL;
831     }
832 
833     return (PyObject *)contextvar_new(name, def);
834 }
835 
836 static int
contextvar_tp_clear(PyContextVar * self)837 contextvar_tp_clear(PyContextVar *self)
838 {
839     Py_CLEAR(self->var_name);
840     Py_CLEAR(self->var_default);
841     self->var_cached = NULL;
842     self->var_cached_tsid = 0;
843     self->var_cached_tsver = 0;
844     return 0;
845 }
846 
847 static int
contextvar_tp_traverse(PyContextVar * self,visitproc visit,void * arg)848 contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
849 {
850     Py_VISIT(self->var_name);
851     Py_VISIT(self->var_default);
852     return 0;
853 }
854 
855 static void
contextvar_tp_dealloc(PyContextVar * self)856 contextvar_tp_dealloc(PyContextVar *self)
857 {
858     PyObject_GC_UnTrack(self);
859     (void)contextvar_tp_clear(self);
860     Py_TYPE(self)->tp_free(self);
861 }
862 
863 static Py_hash_t
contextvar_tp_hash(PyContextVar * self)864 contextvar_tp_hash(PyContextVar *self)
865 {
866     return self->var_hash;
867 }
868 
869 static PyObject *
contextvar_tp_repr(PyContextVar * self)870 contextvar_tp_repr(PyContextVar *self)
871 {
872     _PyUnicodeWriter writer;
873 
874     _PyUnicodeWriter_Init(&writer);
875 
876     if (_PyUnicodeWriter_WriteASCIIString(
877             &writer, "<ContextVar name=", 17) < 0)
878     {
879         goto error;
880     }
881 
882     PyObject *name = PyObject_Repr(self->var_name);
883     if (name == NULL) {
884         goto error;
885     }
886     if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
887         Py_DECREF(name);
888         goto error;
889     }
890     Py_DECREF(name);
891 
892     if (self->var_default != NULL) {
893         if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
894             goto error;
895         }
896 
897         PyObject *def = PyObject_Repr(self->var_default);
898         if (def == NULL) {
899             goto error;
900         }
901         if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
902             Py_DECREF(def);
903             goto error;
904         }
905         Py_DECREF(def);
906     }
907 
908     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
909     if (addr == NULL) {
910         goto error;
911     }
912     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
913         Py_DECREF(addr);
914         goto error;
915     }
916     Py_DECREF(addr);
917 
918     return _PyUnicodeWriter_Finish(&writer);
919 
920 error:
921     _PyUnicodeWriter_Dealloc(&writer);
922     return NULL;
923 }
924 
925 
926 /*[clinic input]
927 _contextvars.ContextVar.get
928     default: object = NULL
929     /
930 
931 Return a value for the context variable for the current context.
932 
933 If there is no value for the variable in the current context, the method will:
934  * return the value of the default argument of the method, if provided; or
935  * return the default value for the context variable, if it was created
936    with one; or
937  * raise a LookupError.
938 [clinic start generated code]*/
939 
940 static PyObject *
_contextvars_ContextVar_get_impl(PyContextVar * self,PyObject * default_value)941 _contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
942 /*[clinic end generated code: output=0746bd0aa2ced7bf input=30aa2ab9e433e401]*/
943 {
944     if (!PyContextVar_CheckExact(self)) {
945         PyErr_SetString(
946             PyExc_TypeError, "an instance of ContextVar was expected");
947         return NULL;
948     }
949 
950     PyObject *val;
951     if (PyContextVar_Get((PyObject *)self, default_value, &val) < 0) {
952         return NULL;
953     }
954 
955     if (val == NULL) {
956         PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
957         return NULL;
958     }
959 
960     return val;
961 }
962 
963 /*[clinic input]
964 _contextvars.ContextVar.set
965     value: object
966     /
967 
968 Call to set a new value for the context variable in the current context.
969 
970 The required value argument is the new value for the context variable.
971 
972 Returns a Token object that can be used to restore the variable to its previous
973 value via the `ContextVar.reset()` method.
974 [clinic start generated code]*/
975 
976 static PyObject *
_contextvars_ContextVar_set(PyContextVar * self,PyObject * value)977 _contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
978 /*[clinic end generated code: output=446ed5e820d6d60b input=c0a6887154227453]*/
979 {
980     return PyContextVar_Set((PyObject *)self, value);
981 }
982 
983 /*[clinic input]
984 _contextvars.ContextVar.reset
985     token: object
986     /
987 
988 Reset the context variable.
989 
990 The variable is reset to the value it had before the `ContextVar.set()` that
991 created the token was used.
992 [clinic start generated code]*/
993 
994 static PyObject *
_contextvars_ContextVar_reset(PyContextVar * self,PyObject * token)995 _contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
996 /*[clinic end generated code: output=d4ee34d0742d62ee input=ebe2881e5af4ffda]*/
997 {
998     if (!PyContextToken_CheckExact(token)) {
999         PyErr_Format(PyExc_TypeError,
1000                      "expected an instance of Token, got %R", token);
1001         return NULL;
1002     }
1003 
1004     if (PyContextVar_Reset((PyObject *)self, token)) {
1005         return NULL;
1006     }
1007 
1008     Py_RETURN_NONE;
1009 }
1010 
1011 
1012 static PyObject *
contextvar_cls_getitem(PyObject * self,PyObject * arg)1013 contextvar_cls_getitem(PyObject *self, PyObject *arg)
1014 {
1015     Py_INCREF(self);
1016     return self;
1017 }
1018 
1019 static PyMemberDef PyContextVar_members[] = {
1020     {"name", T_OBJECT, offsetof(PyContextVar, var_name), READONLY},
1021     {NULL}
1022 };
1023 
1024 static PyMethodDef PyContextVar_methods[] = {
1025     _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
1026     _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
1027     _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
1028     {"__class_getitem__", contextvar_cls_getitem,
1029         METH_O | METH_CLASS, NULL},
1030     {NULL, NULL}
1031 };
1032 
1033 PyTypeObject PyContextVar_Type = {
1034     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1035     "ContextVar",
1036     sizeof(PyContextVar),
1037     .tp_methods = PyContextVar_methods,
1038     .tp_members = PyContextVar_members,
1039     .tp_dealloc = (destructor)contextvar_tp_dealloc,
1040     .tp_getattro = PyObject_GenericGetAttr,
1041     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1042     .tp_traverse = (traverseproc)contextvar_tp_traverse,
1043     .tp_clear = (inquiry)contextvar_tp_clear,
1044     .tp_new = contextvar_tp_new,
1045     .tp_free = PyObject_GC_Del,
1046     .tp_hash = (hashfunc)contextvar_tp_hash,
1047     .tp_repr = (reprfunc)contextvar_tp_repr,
1048 };
1049 
1050 
1051 /////////////////////////// Token
1052 
1053 static PyObject * get_token_missing(void);
1054 
1055 
1056 /*[clinic input]
1057 class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
1058 [clinic start generated code]*/
1059 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
1060 
1061 
1062 static PyObject *
token_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)1063 token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
1064 {
1065     PyErr_SetString(PyExc_RuntimeError,
1066                     "Tokens can only be created by ContextVars");
1067     return NULL;
1068 }
1069 
1070 static int
token_tp_clear(PyContextToken * self)1071 token_tp_clear(PyContextToken *self)
1072 {
1073     Py_CLEAR(self->tok_ctx);
1074     Py_CLEAR(self->tok_var);
1075     Py_CLEAR(self->tok_oldval);
1076     return 0;
1077 }
1078 
1079 static int
token_tp_traverse(PyContextToken * self,visitproc visit,void * arg)1080 token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
1081 {
1082     Py_VISIT(self->tok_ctx);
1083     Py_VISIT(self->tok_var);
1084     Py_VISIT(self->tok_oldval);
1085     return 0;
1086 }
1087 
1088 static void
token_tp_dealloc(PyContextToken * self)1089 token_tp_dealloc(PyContextToken *self)
1090 {
1091     PyObject_GC_UnTrack(self);
1092     (void)token_tp_clear(self);
1093     Py_TYPE(self)->tp_free(self);
1094 }
1095 
1096 static PyObject *
token_tp_repr(PyContextToken * self)1097 token_tp_repr(PyContextToken *self)
1098 {
1099     _PyUnicodeWriter writer;
1100 
1101     _PyUnicodeWriter_Init(&writer);
1102 
1103     if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
1104         goto error;
1105     }
1106 
1107     if (self->tok_used) {
1108         if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
1109             goto error;
1110         }
1111     }
1112 
1113     if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
1114         goto error;
1115     }
1116 
1117     PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
1118     if (var == NULL) {
1119         goto error;
1120     }
1121     if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
1122         Py_DECREF(var);
1123         goto error;
1124     }
1125     Py_DECREF(var);
1126 
1127     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
1128     if (addr == NULL) {
1129         goto error;
1130     }
1131     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
1132         Py_DECREF(addr);
1133         goto error;
1134     }
1135     Py_DECREF(addr);
1136 
1137     return _PyUnicodeWriter_Finish(&writer);
1138 
1139 error:
1140     _PyUnicodeWriter_Dealloc(&writer);
1141     return NULL;
1142 }
1143 
1144 static PyObject *
token_get_var(PyContextToken * self,void * Py_UNUSED (ignored))1145 token_get_var(PyContextToken *self, void *Py_UNUSED(ignored))
1146 {
1147     Py_INCREF(self->tok_var);
1148     return (PyObject *)self->tok_var;
1149 }
1150 
1151 static PyObject *
token_get_old_value(PyContextToken * self,void * Py_UNUSED (ignored))1152 token_get_old_value(PyContextToken *self, void *Py_UNUSED(ignored))
1153 {
1154     if (self->tok_oldval == NULL) {
1155         return get_token_missing();
1156     }
1157 
1158     Py_INCREF(self->tok_oldval);
1159     return self->tok_oldval;
1160 }
1161 
1162 static PyGetSetDef PyContextTokenType_getsetlist[] = {
1163     {"var", (getter)token_get_var, NULL, NULL},
1164     {"old_value", (getter)token_get_old_value, NULL, NULL},
1165     {NULL}
1166 };
1167 
1168 PyTypeObject PyContextToken_Type = {
1169     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1170     "Token",
1171     sizeof(PyContextToken),
1172     .tp_getset = PyContextTokenType_getsetlist,
1173     .tp_dealloc = (destructor)token_tp_dealloc,
1174     .tp_getattro = PyObject_GenericGetAttr,
1175     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1176     .tp_traverse = (traverseproc)token_tp_traverse,
1177     .tp_clear = (inquiry)token_tp_clear,
1178     .tp_new = token_tp_new,
1179     .tp_free = PyObject_GC_Del,
1180     .tp_hash = PyObject_HashNotImplemented,
1181     .tp_repr = (reprfunc)token_tp_repr,
1182 };
1183 
1184 static PyContextToken *
token_new(PyContext * ctx,PyContextVar * var,PyObject * val)1185 token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
1186 {
1187     PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
1188     if (tok == NULL) {
1189         return NULL;
1190     }
1191 
1192     Py_INCREF(ctx);
1193     tok->tok_ctx = ctx;
1194 
1195     Py_INCREF(var);
1196     tok->tok_var = var;
1197 
1198     Py_XINCREF(val);
1199     tok->tok_oldval = val;
1200 
1201     tok->tok_used = 0;
1202 
1203     PyObject_GC_Track(tok);
1204     return tok;
1205 }
1206 
1207 
1208 /////////////////////////// Token.MISSING
1209 
1210 
1211 static PyObject *_token_missing;
1212 
1213 
1214 typedef struct {
1215     PyObject_HEAD
1216 } PyContextTokenMissing;
1217 
1218 
1219 static PyObject *
context_token_missing_tp_repr(PyObject * self)1220 context_token_missing_tp_repr(PyObject *self)
1221 {
1222     return PyUnicode_FromString("<Token.MISSING>");
1223 }
1224 
1225 
1226 PyTypeObject PyContextTokenMissing_Type = {
1227     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1228     "Token.MISSING",
1229     sizeof(PyContextTokenMissing),
1230     .tp_getattro = PyObject_GenericGetAttr,
1231     .tp_flags = Py_TPFLAGS_DEFAULT,
1232     .tp_repr = context_token_missing_tp_repr,
1233 };
1234 
1235 
1236 static PyObject *
get_token_missing(void)1237 get_token_missing(void)
1238 {
1239     if (_token_missing != NULL) {
1240         Py_INCREF(_token_missing);
1241         return _token_missing;
1242     }
1243 
1244     _token_missing = (PyObject *)PyObject_New(
1245         PyContextTokenMissing, &PyContextTokenMissing_Type);
1246     if (_token_missing == NULL) {
1247         return NULL;
1248     }
1249 
1250     Py_INCREF(_token_missing);
1251     return _token_missing;
1252 }
1253 
1254 
1255 ///////////////////////////
1256 
1257 
1258 int
PyContext_ClearFreeList(void)1259 PyContext_ClearFreeList(void)
1260 {
1261     int size = ctx_freelist_len;
1262     while (ctx_freelist_len) {
1263         PyContext *ctx = ctx_freelist;
1264         ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
1265         ctx->ctx_weakreflist = NULL;
1266         PyObject_GC_Del(ctx);
1267         ctx_freelist_len--;
1268     }
1269     return size;
1270 }
1271 
1272 
1273 void
_PyContext_Fini(void)1274 _PyContext_Fini(void)
1275 {
1276     Py_CLEAR(_token_missing);
1277     (void)PyContext_ClearFreeList();
1278     (void)_PyHamt_Fini();
1279 }
1280 
1281 
1282 int
_PyContext_Init(void)1283 _PyContext_Init(void)
1284 {
1285     if (!_PyHamt_Init()) {
1286         return 0;
1287     }
1288 
1289     if ((PyType_Ready(&PyContext_Type) < 0) ||
1290         (PyType_Ready(&PyContextVar_Type) < 0) ||
1291         (PyType_Ready(&PyContextToken_Type) < 0) ||
1292         (PyType_Ready(&PyContextTokenMissing_Type) < 0))
1293     {
1294         return 0;
1295     }
1296 
1297     PyObject *missing = get_token_missing();
1298     if (PyDict_SetItemString(
1299         PyContextToken_Type.tp_dict, "MISSING", missing))
1300     {
1301         Py_DECREF(missing);
1302         return 0;
1303     }
1304     Py_DECREF(missing);
1305 
1306     return 1;
1307 }
1308