• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "Python.h"
2 
3 #include "pycore_context.h"
4 #include "pycore_gc.h"            // _PyObject_GC_MAY_BE_TRACKED()
5 #include "pycore_hamt.h"
6 #include "pycore_object.h"
7 #include "pycore_pyerrors.h"
8 #include "pycore_pystate.h"       // _PyThreadState_GET()
9 #include "structmember.h"         // PyMemberDef
10 
11 
/* Upper bound on how many deallocated Context objects context_tp_dealloc()
   keeps on the per-interpreter freelist for reuse by _context_alloc(). */
#define CONTEXT_FREELIST_MAXLEN 255
13 
14 
15 #include "clinic/context.c.h"
16 /*[clinic input]
17 module _contextvars
18 [clinic start generated code]*/
19 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
20 
21 
/* Argument guards for the public C API.

   Each ENSURE_* macro expands to a bare `if` statement, so call sites do
   not add a trailing semicolon.  When `o` is not exactly of the expected
   type, a TypeError carrying `msg` is set and the *enclosing* function
   returns `err_ret`. */
#define ENSURE_EXACT_TYPE(check, o, msg, err_ret)                       \
    if (!check(o)) {                                                    \
        PyErr_SetString(PyExc_TypeError, msg);                          \
        return err_ret;                                                 \
    }

#define ENSURE_Context(o, err_ret)                                      \
    ENSURE_EXACT_TYPE(PyContext_CheckExact, o,                          \
                      "an instance of Context was expected", err_ret)

#define ENSURE_ContextVar(o, err_ret)                                   \
    ENSURE_EXACT_TYPE(PyContextVar_CheckExact, o,                       \
                      "an instance of ContextVar was expected", err_ret)

#define ENSURE_ContextToken(o, err_ret)                                 \
    ENSURE_EXACT_TYPE(PyContextToken_CheckExact, o,                     \
                      "an instance of Token was expected", err_ret)
42 
43 
44 /////////////////////////// Context API
45 
46 
47 static PyContext *
48 context_new_empty(void);
49 
50 static PyContext *
51 context_new_from_vars(PyHamtObject *vars);
52 
53 static inline PyContext *
54 context_get(void);
55 
56 static PyContextToken *
57 token_new(PyContext *ctx, PyContextVar *var, PyObject *val);
58 
59 static PyContextVar *
60 contextvar_new(PyObject *name, PyObject *def);
61 
62 static int
63 contextvar_set(PyContextVar *var, PyObject *val);
64 
65 static int
66 contextvar_del(PyContextVar *var);
67 
68 
69 static struct _Py_context_state *
get_context_state(void)70 get_context_state(void)
71 {
72     PyInterpreterState *interp = _PyInterpreterState_GET();
73     return &interp->context;
74 }
75 
76 
77 PyObject *
_PyContext_NewHamtForTests(void)78 _PyContext_NewHamtForTests(void)
79 {
80     return (PyObject *)_PyHamt_New();
81 }
82 
83 
84 PyObject *
PyContext_New(void)85 PyContext_New(void)
86 {
87     return (PyObject *)context_new_empty();
88 }
89 
90 
91 PyObject *
PyContext_Copy(PyObject * octx)92 PyContext_Copy(PyObject * octx)
93 {
94     ENSURE_Context(octx, NULL)
95     PyContext *ctx = (PyContext *)octx;
96     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
97 }
98 
99 
100 PyObject *
PyContext_CopyCurrent(void)101 PyContext_CopyCurrent(void)
102 {
103     PyContext *ctx = context_get();
104     if (ctx == NULL) {
105         return NULL;
106     }
107 
108     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
109 }
110 
111 
112 static int
_PyContext_Enter(PyThreadState * ts,PyObject * octx)113 _PyContext_Enter(PyThreadState *ts, PyObject *octx)
114 {
115     ENSURE_Context(octx, -1)
116     PyContext *ctx = (PyContext *)octx;
117 
118     if (ctx->ctx_entered) {
119         _PyErr_Format(ts, PyExc_RuntimeError,
120                       "cannot enter context: %R is already entered", ctx);
121         return -1;
122     }
123 
124     ctx->ctx_prev = (PyContext *)ts->context;  /* borrow */
125     ctx->ctx_entered = 1;
126 
127     Py_INCREF(ctx);
128     ts->context = (PyObject *)ctx;
129     ts->context_ver++;
130 
131     return 0;
132 }
133 
134 
135 int
PyContext_Enter(PyObject * octx)136 PyContext_Enter(PyObject *octx)
137 {
138     PyThreadState *ts = _PyThreadState_GET();
139     assert(ts != NULL);
140     return _PyContext_Enter(ts, octx);
141 }
142 
143 
144 static int
_PyContext_Exit(PyThreadState * ts,PyObject * octx)145 _PyContext_Exit(PyThreadState *ts, PyObject *octx)
146 {
147     ENSURE_Context(octx, -1)
148     PyContext *ctx = (PyContext *)octx;
149 
150     if (!ctx->ctx_entered) {
151         PyErr_Format(PyExc_RuntimeError,
152                      "cannot exit context: %R has not been entered", ctx);
153         return -1;
154     }
155 
156     if (ts->context != (PyObject *)ctx) {
157         /* Can only happen if someone misuses the C API */
158         PyErr_SetString(PyExc_RuntimeError,
159                         "cannot exit context: thread state references "
160                         "a different context object");
161         return -1;
162     }
163 
164     Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
165     ts->context_ver++;
166 
167     ctx->ctx_prev = NULL;
168     ctx->ctx_entered = 0;
169 
170     return 0;
171 }
172 
173 int
PyContext_Exit(PyObject * octx)174 PyContext_Exit(PyObject *octx)
175 {
176     PyThreadState *ts = _PyThreadState_GET();
177     assert(ts != NULL);
178     return _PyContext_Exit(ts, octx);
179 }
180 
181 
182 PyObject *
PyContextVar_New(const char * name,PyObject * def)183 PyContextVar_New(const char *name, PyObject *def)
184 {
185     PyObject *pyname = PyUnicode_FromString(name);
186     if (pyname == NULL) {
187         return NULL;
188     }
189     PyContextVar *var = contextvar_new(pyname, def);
190     Py_DECREF(pyname);
191     return (PyObject *)var;
192 }
193 
194 
195 int
PyContextVar_Get(PyObject * ovar,PyObject * def,PyObject ** val)196 PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
197 {
198     ENSURE_ContextVar(ovar, -1)
199     PyContextVar *var = (PyContextVar *)ovar;
200 
201     PyThreadState *ts = _PyThreadState_GET();
202     assert(ts != NULL);
203     if (ts->context == NULL) {
204         goto not_found;
205     }
206 
207     if (var->var_cached != NULL &&
208             var->var_cached_tsid == ts->id &&
209             var->var_cached_tsver == ts->context_ver)
210     {
211         *val = var->var_cached;
212         goto found;
213     }
214 
215     assert(PyContext_CheckExact(ts->context));
216     PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
217 
218     PyObject *found = NULL;
219     int res = _PyHamt_Find(vars, (PyObject*)var, &found);
220     if (res < 0) {
221         goto error;
222     }
223     if (res == 1) {
224         assert(found != NULL);
225         var->var_cached = found;  /* borrow */
226         var->var_cached_tsid = ts->id;
227         var->var_cached_tsver = ts->context_ver;
228 
229         *val = found;
230         goto found;
231     }
232 
233 not_found:
234     if (def == NULL) {
235         if (var->var_default != NULL) {
236             *val = var->var_default;
237             goto found;
238         }
239 
240         *val = NULL;
241         goto found;
242     }
243     else {
244         *val = def;
245         goto found;
246    }
247 
248 found:
249     Py_XINCREF(*val);
250     return 0;
251 
252 error:
253     *val = NULL;
254     return -1;
255 }
256 
257 
258 PyObject *
PyContextVar_Set(PyObject * ovar,PyObject * val)259 PyContextVar_Set(PyObject *ovar, PyObject *val)
260 {
261     ENSURE_ContextVar(ovar, NULL)
262     PyContextVar *var = (PyContextVar *)ovar;
263 
264     if (!PyContextVar_CheckExact(var)) {
265         PyErr_SetString(
266             PyExc_TypeError, "an instance of ContextVar was expected");
267         return NULL;
268     }
269 
270     PyContext *ctx = context_get();
271     if (ctx == NULL) {
272         return NULL;
273     }
274 
275     PyObject *old_val = NULL;
276     int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
277     if (found < 0) {
278         return NULL;
279     }
280 
281     Py_XINCREF(old_val);
282     PyContextToken *tok = token_new(ctx, var, old_val);
283     Py_XDECREF(old_val);
284 
285     if (contextvar_set(var, val)) {
286         Py_DECREF(tok);
287         return NULL;
288     }
289 
290     return (PyObject *)tok;
291 }
292 
293 
294 int
PyContextVar_Reset(PyObject * ovar,PyObject * otok)295 PyContextVar_Reset(PyObject *ovar, PyObject *otok)
296 {
297     ENSURE_ContextVar(ovar, -1)
298     ENSURE_ContextToken(otok, -1)
299     PyContextVar *var = (PyContextVar *)ovar;
300     PyContextToken *tok = (PyContextToken *)otok;
301 
302     if (tok->tok_used) {
303         PyErr_Format(PyExc_RuntimeError,
304                      "%R has already been used once", tok);
305         return -1;
306     }
307 
308     if (var != tok->tok_var) {
309         PyErr_Format(PyExc_ValueError,
310                      "%R was created by a different ContextVar", tok);
311         return -1;
312     }
313 
314     PyContext *ctx = context_get();
315     if (ctx != tok->tok_ctx) {
316         PyErr_Format(PyExc_ValueError,
317                      "%R was created in a different Context", tok);
318         return -1;
319     }
320 
321     tok->tok_used = 1;
322 
323     if (tok->tok_oldval == NULL) {
324         return contextvar_del(var);
325     }
326     else {
327         return contextvar_set(var, tok->tok_oldval);
328     }
329 }
330 
331 
332 /////////////////////////// PyContext
333 
334 /*[clinic input]
335 class _contextvars.Context "PyContext *" "&PyContext_Type"
336 [clinic start generated code]*/
337 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
338 
339 
340 static inline PyContext *
_context_alloc(void)341 _context_alloc(void)
342 {
343     struct _Py_context_state *state = get_context_state();
344     PyContext *ctx;
345 #ifdef Py_DEBUG
346     // _context_alloc() must not be called after _PyContext_Fini()
347     assert(state->numfree != -1);
348 #endif
349     if (state->numfree) {
350         state->numfree--;
351         ctx = state->freelist;
352         state->freelist = (PyContext *)ctx->ctx_weakreflist;
353         ctx->ctx_weakreflist = NULL;
354         _Py_NewReference((PyObject *)ctx);
355     }
356     else {
357         ctx = PyObject_GC_New(PyContext, &PyContext_Type);
358         if (ctx == NULL) {
359             return NULL;
360         }
361     }
362 
363     ctx->ctx_vars = NULL;
364     ctx->ctx_prev = NULL;
365     ctx->ctx_entered = 0;
366     ctx->ctx_weakreflist = NULL;
367 
368     return ctx;
369 }
370 
371 
372 static PyContext *
context_new_empty(void)373 context_new_empty(void)
374 {
375     PyContext *ctx = _context_alloc();
376     if (ctx == NULL) {
377         return NULL;
378     }
379 
380     ctx->ctx_vars = _PyHamt_New();
381     if (ctx->ctx_vars == NULL) {
382         Py_DECREF(ctx);
383         return NULL;
384     }
385 
386     _PyObject_GC_TRACK(ctx);
387     return ctx;
388 }
389 
390 
391 static PyContext *
context_new_from_vars(PyHamtObject * vars)392 context_new_from_vars(PyHamtObject *vars)
393 {
394     PyContext *ctx = _context_alloc();
395     if (ctx == NULL) {
396         return NULL;
397     }
398 
399     Py_INCREF(vars);
400     ctx->ctx_vars = vars;
401 
402     _PyObject_GC_TRACK(ctx);
403     return ctx;
404 }
405 
406 
407 static inline PyContext *
context_get(void)408 context_get(void)
409 {
410     PyThreadState *ts = _PyThreadState_GET();
411     assert(ts != NULL);
412     PyContext *current_ctx = (PyContext *)ts->context;
413     if (current_ctx == NULL) {
414         current_ctx = context_new_empty();
415         if (current_ctx == NULL) {
416             return NULL;
417         }
418         ts->context = (PyObject *)current_ctx;
419     }
420     return current_ctx;
421 }
422 
423 static int
context_check_key_type(PyObject * key)424 context_check_key_type(PyObject *key)
425 {
426     if (!PyContextVar_CheckExact(key)) {
427         // abort();
428         PyErr_Format(PyExc_TypeError,
429                      "a ContextVar key was expected, got %R", key);
430         return -1;
431     }
432     return 0;
433 }
434 
435 static PyObject *
context_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)436 context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
437 {
438     if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
439         PyErr_SetString(
440             PyExc_TypeError, "Context() does not accept any arguments");
441         return NULL;
442     }
443     return PyContext_New();
444 }
445 
446 static int
context_tp_clear(PyContext * self)447 context_tp_clear(PyContext *self)
448 {
449     Py_CLEAR(self->ctx_prev);
450     Py_CLEAR(self->ctx_vars);
451     return 0;
452 }
453 
454 static int
context_tp_traverse(PyContext * self,visitproc visit,void * arg)455 context_tp_traverse(PyContext *self, visitproc visit, void *arg)
456 {
457     Py_VISIT(self->ctx_prev);
458     Py_VISIT(self->ctx_vars);
459     return 0;
460 }
461 
462 static void
context_tp_dealloc(PyContext * self)463 context_tp_dealloc(PyContext *self)
464 {
465     _PyObject_GC_UNTRACK(self);
466 
467     if (self->ctx_weakreflist != NULL) {
468         PyObject_ClearWeakRefs((PyObject*)self);
469     }
470     (void)context_tp_clear(self);
471 
472     struct _Py_context_state *state = get_context_state();
473 #ifdef Py_DEBUG
474     // _context_alloc() must not be called after _PyContext_Fini()
475     assert(state->numfree != -1);
476 #endif
477     if (state->numfree < CONTEXT_FREELIST_MAXLEN) {
478         state->numfree++;
479         self->ctx_weakreflist = (PyObject *)state->freelist;
480         state->freelist = self;
481     }
482     else {
483         Py_TYPE(self)->tp_free(self);
484     }
485 }
486 
487 static PyObject *
context_tp_iter(PyContext * self)488 context_tp_iter(PyContext *self)
489 {
490     return _PyHamt_NewIterKeys(self->ctx_vars);
491 }
492 
493 static PyObject *
context_tp_richcompare(PyObject * v,PyObject * w,int op)494 context_tp_richcompare(PyObject *v, PyObject *w, int op)
495 {
496     if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
497             (op != Py_EQ && op != Py_NE))
498     {
499         Py_RETURN_NOTIMPLEMENTED;
500     }
501 
502     int res = _PyHamt_Eq(
503         ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
504     if (res < 0) {
505         return NULL;
506     }
507 
508     if (op == Py_NE) {
509         res = !res;
510     }
511 
512     if (res) {
513         Py_RETURN_TRUE;
514     }
515     else {
516         Py_RETURN_FALSE;
517     }
518 }
519 
520 static Py_ssize_t
context_tp_len(PyContext * self)521 context_tp_len(PyContext *self)
522 {
523     return _PyHamt_Len(self->ctx_vars);
524 }
525 
526 static PyObject *
context_tp_subscript(PyContext * self,PyObject * key)527 context_tp_subscript(PyContext *self, PyObject *key)
528 {
529     if (context_check_key_type(key)) {
530         return NULL;
531     }
532     PyObject *val = NULL;
533     int found = _PyHamt_Find(self->ctx_vars, key, &val);
534     if (found < 0) {
535         return NULL;
536     }
537     if (found == 0) {
538         PyErr_SetObject(PyExc_KeyError, key);
539         return NULL;
540     }
541     Py_INCREF(val);
542     return val;
543 }
544 
545 static int
context_tp_contains(PyContext * self,PyObject * key)546 context_tp_contains(PyContext *self, PyObject *key)
547 {
548     if (context_check_key_type(key)) {
549         return -1;
550     }
551     PyObject *val = NULL;
552     return _PyHamt_Find(self->ctx_vars, key, &val);
553 }
554 
555 
556 /*[clinic input]
557 _contextvars.Context.get
558     key: object
559     default: object = None
560     /
561 
562 Return the value for `key` if `key` has the value in the context object.
563 
564 If `key` does not exist, return `default`. If `default` is not given,
565 return None.
566 [clinic start generated code]*/
567 
568 static PyObject *
_contextvars_Context_get_impl(PyContext * self,PyObject * key,PyObject * default_value)569 _contextvars_Context_get_impl(PyContext *self, PyObject *key,
570                               PyObject *default_value)
571 /*[clinic end generated code: output=0c54aa7664268189 input=c8eeb81505023995]*/
572 {
573     if (context_check_key_type(key)) {
574         return NULL;
575     }
576 
577     PyObject *val = NULL;
578     int found = _PyHamt_Find(self->ctx_vars, key, &val);
579     if (found < 0) {
580         return NULL;
581     }
582     if (found == 0) {
583         Py_INCREF(default_value);
584         return default_value;
585     }
586     Py_INCREF(val);
587     return val;
588 }
589 
590 
591 /*[clinic input]
592 _contextvars.Context.items
593 
594 Return all variables and their values in the context object.
595 
596 The result is returned as a list of 2-tuples (variable, value).
597 [clinic start generated code]*/
598 
599 static PyObject *
_contextvars_Context_items_impl(PyContext * self)600 _contextvars_Context_items_impl(PyContext *self)
601 /*[clinic end generated code: output=fa1655c8a08502af input=00db64ae379f9f42]*/
602 {
603     return _PyHamt_NewIterItems(self->ctx_vars);
604 }
605 
606 
607 /*[clinic input]
608 _contextvars.Context.keys
609 
610 Return a list of all variables in the context object.
611 [clinic start generated code]*/
612 
613 static PyObject *
_contextvars_Context_keys_impl(PyContext * self)614 _contextvars_Context_keys_impl(PyContext *self)
615 /*[clinic end generated code: output=177227c6b63ec0e2 input=114b53aebca3449c]*/
616 {
617     return _PyHamt_NewIterKeys(self->ctx_vars);
618 }
619 
620 
621 /*[clinic input]
622 _contextvars.Context.values
623 
624 Return a list of all variables' values in the context object.
625 [clinic start generated code]*/
626 
627 static PyObject *
_contextvars_Context_values_impl(PyContext * self)628 _contextvars_Context_values_impl(PyContext *self)
629 /*[clinic end generated code: output=d286dabfc8db6dde input=ce8075d04a6ea526]*/
630 {
631     return _PyHamt_NewIterValues(self->ctx_vars);
632 }
633 
634 
635 /*[clinic input]
636 _contextvars.Context.copy
637 
638 Return a shallow copy of the context object.
639 [clinic start generated code]*/
640 
641 static PyObject *
_contextvars_Context_copy_impl(PyContext * self)642 _contextvars_Context_copy_impl(PyContext *self)
643 /*[clinic end generated code: output=30ba8896c4707a15 input=ebafdbdd9c72d592]*/
644 {
645     return (PyObject *)context_new_from_vars(self->ctx_vars);
646 }
647 
648 
649 static PyObject *
context_run(PyContext * self,PyObject * const * args,Py_ssize_t nargs,PyObject * kwnames)650 context_run(PyContext *self, PyObject *const *args,
651             Py_ssize_t nargs, PyObject *kwnames)
652 {
653     PyThreadState *ts = _PyThreadState_GET();
654 
655     if (nargs < 1) {
656         _PyErr_SetString(ts, PyExc_TypeError,
657                          "run() missing 1 required positional argument");
658         return NULL;
659     }
660 
661     if (_PyContext_Enter(ts, (PyObject *)self)) {
662         return NULL;
663     }
664 
665     PyObject *call_result = _PyObject_VectorcallTstate(
666         ts, args[0], args + 1, nargs - 1, kwnames);
667 
668     if (_PyContext_Exit(ts, (PyObject *)self)) {
669         return NULL;
670     }
671 
672     return call_result;
673 }
674 
675 
676 static PyMethodDef PyContext_methods[] = {
677     _CONTEXTVARS_CONTEXT_GET_METHODDEF
678     _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
679     _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
680     _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
681     _CONTEXTVARS_CONTEXT_COPY_METHODDEF
682     {"run", (PyCFunction)(void(*)(void))context_run, METH_FASTCALL | METH_KEYWORDS, NULL},
683     {NULL, NULL}
684 };
685 
686 static PySequenceMethods PyContext_as_sequence = {
687     0,                                   /* sq_length */
688     0,                                   /* sq_concat */
689     0,                                   /* sq_repeat */
690     0,                                   /* sq_item */
691     0,                                   /* sq_slice */
692     0,                                   /* sq_ass_item */
693     0,                                   /* sq_ass_slice */
694     (objobjproc)context_tp_contains,     /* sq_contains */
695     0,                                   /* sq_inplace_concat */
696     0,                                   /* sq_inplace_repeat */
697 };
698 
699 static PyMappingMethods PyContext_as_mapping = {
700     (lenfunc)context_tp_len,             /* mp_length */
701     (binaryfunc)context_tp_subscript,    /* mp_subscript */
702 };
703 
704 PyTypeObject PyContext_Type = {
705     PyVarObject_HEAD_INIT(&PyType_Type, 0)
706     "_contextvars.Context",
707     sizeof(PyContext),
708     .tp_methods = PyContext_methods,
709     .tp_as_mapping = &PyContext_as_mapping,
710     .tp_as_sequence = &PyContext_as_sequence,
711     .tp_iter = (getiterfunc)context_tp_iter,
712     .tp_dealloc = (destructor)context_tp_dealloc,
713     .tp_getattro = PyObject_GenericGetAttr,
714     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
715     .tp_richcompare = context_tp_richcompare,
716     .tp_traverse = (traverseproc)context_tp_traverse,
717     .tp_clear = (inquiry)context_tp_clear,
718     .tp_new = context_tp_new,
719     .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
720     .tp_hash = PyObject_HashNotImplemented,
721 };
722 
723 
724 /////////////////////////// ContextVar
725 
726 
727 static int
contextvar_set(PyContextVar * var,PyObject * val)728 contextvar_set(PyContextVar *var, PyObject *val)
729 {
730     var->var_cached = NULL;
731     PyThreadState *ts = PyThreadState_Get();
732 
733     PyContext *ctx = context_get();
734     if (ctx == NULL) {
735         return -1;
736     }
737 
738     PyHamtObject *new_vars = _PyHamt_Assoc(
739         ctx->ctx_vars, (PyObject *)var, val);
740     if (new_vars == NULL) {
741         return -1;
742     }
743 
744     Py_SETREF(ctx->ctx_vars, new_vars);
745 
746     var->var_cached = val;  /* borrow */
747     var->var_cached_tsid = ts->id;
748     var->var_cached_tsver = ts->context_ver;
749     return 0;
750 }
751 
752 static int
contextvar_del(PyContextVar * var)753 contextvar_del(PyContextVar *var)
754 {
755     var->var_cached = NULL;
756 
757     PyContext *ctx = context_get();
758     if (ctx == NULL) {
759         return -1;
760     }
761 
762     PyHamtObject *vars = ctx->ctx_vars;
763     PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
764     if (new_vars == NULL) {
765         return -1;
766     }
767 
768     if (vars == new_vars) {
769         Py_DECREF(new_vars);
770         PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
771         return -1;
772     }
773 
774     Py_SETREF(ctx->ctx_vars, new_vars);
775     return 0;
776 }
777 
778 static Py_hash_t
contextvar_generate_hash(void * addr,PyObject * name)779 contextvar_generate_hash(void *addr, PyObject *name)
780 {
781     /* Take hash of `name` and XOR it with the object's addr.
782 
783        The structure of the tree is encoded in objects' hashes, which
784        means that sufficiently similar hashes would result in tall trees
785        with many Collision nodes.  Which would, in turn, result in slower
786        get and set operations.
787 
788        The XORing helps to ensure that:
789 
790        (1) sequentially allocated ContextVar objects have
791            different hashes;
792 
793        (2) context variables with equal names have
794            different hashes.
795     */
796 
797     Py_hash_t name_hash = PyObject_Hash(name);
798     if (name_hash == -1) {
799         return -1;
800     }
801 
802     Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
803     return res == -1 ? -2 : res;
804 }
805 
806 static PyContextVar *
contextvar_new(PyObject * name,PyObject * def)807 contextvar_new(PyObject *name, PyObject *def)
808 {
809     if (!PyUnicode_Check(name)) {
810         PyErr_SetString(PyExc_TypeError,
811                         "context variable name must be a str");
812         return NULL;
813     }
814 
815     PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
816     if (var == NULL) {
817         return NULL;
818     }
819 
820     var->var_hash = contextvar_generate_hash(var, name);
821     if (var->var_hash == -1) {
822         Py_DECREF(var);
823         return NULL;
824     }
825 
826     Py_INCREF(name);
827     var->var_name = name;
828 
829     Py_XINCREF(def);
830     var->var_default = def;
831 
832     var->var_cached = NULL;
833     var->var_cached_tsid = 0;
834     var->var_cached_tsver = 0;
835 
836     if (_PyObject_GC_MAY_BE_TRACKED(name) ||
837             (def != NULL && _PyObject_GC_MAY_BE_TRACKED(def)))
838     {
839         PyObject_GC_Track(var);
840     }
841     return var;
842 }
843 
844 
845 /*[clinic input]
846 class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
847 [clinic start generated code]*/
848 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
849 
850 
851 static PyObject *
contextvar_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)852 contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
853 {
854     static char *kwlist[] = {"", "default", NULL};
855     PyObject *name;
856     PyObject *def = NULL;
857 
858     if (!PyArg_ParseTupleAndKeywords(
859             args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
860     {
861         return NULL;
862     }
863 
864     return (PyObject *)contextvar_new(name, def);
865 }
866 
867 static int
contextvar_tp_clear(PyContextVar * self)868 contextvar_tp_clear(PyContextVar *self)
869 {
870     Py_CLEAR(self->var_name);
871     Py_CLEAR(self->var_default);
872     self->var_cached = NULL;
873     self->var_cached_tsid = 0;
874     self->var_cached_tsver = 0;
875     return 0;
876 }
877 
878 static int
contextvar_tp_traverse(PyContextVar * self,visitproc visit,void * arg)879 contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
880 {
881     Py_VISIT(self->var_name);
882     Py_VISIT(self->var_default);
883     return 0;
884 }
885 
886 static void
contextvar_tp_dealloc(PyContextVar * self)887 contextvar_tp_dealloc(PyContextVar *self)
888 {
889     PyObject_GC_UnTrack(self);
890     (void)contextvar_tp_clear(self);
891     Py_TYPE(self)->tp_free(self);
892 }
893 
894 static Py_hash_t
contextvar_tp_hash(PyContextVar * self)895 contextvar_tp_hash(PyContextVar *self)
896 {
897     return self->var_hash;
898 }
899 
900 static PyObject *
contextvar_tp_repr(PyContextVar * self)901 contextvar_tp_repr(PyContextVar *self)
902 {
903     _PyUnicodeWriter writer;
904 
905     _PyUnicodeWriter_Init(&writer);
906 
907     if (_PyUnicodeWriter_WriteASCIIString(
908             &writer, "<ContextVar name=", 17) < 0)
909     {
910         goto error;
911     }
912 
913     PyObject *name = PyObject_Repr(self->var_name);
914     if (name == NULL) {
915         goto error;
916     }
917     if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
918         Py_DECREF(name);
919         goto error;
920     }
921     Py_DECREF(name);
922 
923     if (self->var_default != NULL) {
924         if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
925             goto error;
926         }
927 
928         PyObject *def = PyObject_Repr(self->var_default);
929         if (def == NULL) {
930             goto error;
931         }
932         if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
933             Py_DECREF(def);
934             goto error;
935         }
936         Py_DECREF(def);
937     }
938 
939     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
940     if (addr == NULL) {
941         goto error;
942     }
943     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
944         Py_DECREF(addr);
945         goto error;
946     }
947     Py_DECREF(addr);
948 
949     return _PyUnicodeWriter_Finish(&writer);
950 
951 error:
952     _PyUnicodeWriter_Dealloc(&writer);
953     return NULL;
954 }
955 
956 
957 /*[clinic input]
958 _contextvars.ContextVar.get
959     default: object = NULL
960     /
961 
962 Return a value for the context variable for the current context.
963 
964 If there is no value for the variable in the current context, the method will:
965  * return the value of the default argument of the method, if provided; or
966  * return the default value for the context variable, if it was created
967    with one; or
968  * raise a LookupError.
969 [clinic start generated code]*/
970 
971 static PyObject *
_contextvars_ContextVar_get_impl(PyContextVar * self,PyObject * default_value)972 _contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
973 /*[clinic end generated code: output=0746bd0aa2ced7bf input=30aa2ab9e433e401]*/
974 {
975     if (!PyContextVar_CheckExact(self)) {
976         PyErr_SetString(
977             PyExc_TypeError, "an instance of ContextVar was expected");
978         return NULL;
979     }
980 
981     PyObject *val;
982     if (PyContextVar_Get((PyObject *)self, default_value, &val) < 0) {
983         return NULL;
984     }
985 
986     if (val == NULL) {
987         PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
988         return NULL;
989     }
990 
991     return val;
992 }
993 
994 /*[clinic input]
995 _contextvars.ContextVar.set
996     value: object
997     /
998 
999 Call to set a new value for the context variable in the current context.
1000 
1001 The required value argument is the new value for the context variable.
1002 
1003 Returns a Token object that can be used to restore the variable to its previous
1004 value via the `ContextVar.reset()` method.
1005 [clinic start generated code]*/
1006 
1007 static PyObject *
_contextvars_ContextVar_set(PyContextVar * self,PyObject * value)1008 _contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
1009 /*[clinic end generated code: output=446ed5e820d6d60b input=c0a6887154227453]*/
1010 {
1011     return PyContextVar_Set((PyObject *)self, value);
1012 }
1013 
1014 /*[clinic input]
1015 _contextvars.ContextVar.reset
1016     token: object
1017     /
1018 
1019 Reset the context variable.
1020 
1021 The variable is reset to the value it had before the `ContextVar.set()` that
1022 created the token was used.
1023 [clinic start generated code]*/
1024 
1025 static PyObject *
_contextvars_ContextVar_reset(PyContextVar * self,PyObject * token)1026 _contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
1027 /*[clinic end generated code: output=d4ee34d0742d62ee input=ebe2881e5af4ffda]*/
1028 {
1029     if (!PyContextToken_CheckExact(token)) {
1030         PyErr_Format(PyExc_TypeError,
1031                      "expected an instance of Token, got %R", token);
1032         return NULL;
1033     }
1034 
1035     if (PyContextVar_Reset((PyObject *)self, token)) {
1036         return NULL;
1037     }
1038 
1039     Py_RETURN_NONE;
1040 }
1041 
1042 
1043 static PyMemberDef PyContextVar_members[] = {
1044     {"name", T_OBJECT, offsetof(PyContextVar, var_name), READONLY},
1045     {NULL}
1046 };
1047 
1048 static PyMethodDef PyContextVar_methods[] = {
1049     _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
1050     _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
1051     _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
1052     {"__class_getitem__", (PyCFunction)Py_GenericAlias,
1053     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1054     {NULL, NULL}
1055 };
1056 
1057 PyTypeObject PyContextVar_Type = {
1058     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1059     "_contextvars.ContextVar",
1060     sizeof(PyContextVar),
1061     .tp_methods = PyContextVar_methods,
1062     .tp_members = PyContextVar_members,
1063     .tp_dealloc = (destructor)contextvar_tp_dealloc,
1064     .tp_getattro = PyObject_GenericGetAttr,
1065     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1066     .tp_traverse = (traverseproc)contextvar_tp_traverse,
1067     .tp_clear = (inquiry)contextvar_tp_clear,
1068     .tp_new = contextvar_tp_new,
1069     .tp_free = PyObject_GC_Del,
1070     .tp_hash = (hashfunc)contextvar_tp_hash,
1071     .tp_repr = (reprfunc)contextvar_tp_repr,
1072 };
1073 
1074 
1075 /////////////////////////// Token
1076 
1077 static PyObject * get_token_missing(void);
1078 
1079 
1080 /*[clinic input]
1081 class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
1082 [clinic start generated code]*/
1083 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
1084 
1085 
1086 static PyObject *
token_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)1087 token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
1088 {
1089     PyErr_SetString(PyExc_RuntimeError,
1090                     "Tokens can only be created by ContextVars");
1091     return NULL;
1092 }
1093 
1094 static int
token_tp_clear(PyContextToken * self)1095 token_tp_clear(PyContextToken *self)
1096 {
1097     Py_CLEAR(self->tok_ctx);
1098     Py_CLEAR(self->tok_var);
1099     Py_CLEAR(self->tok_oldval);
1100     return 0;
1101 }
1102 
1103 static int
token_tp_traverse(PyContextToken * self,visitproc visit,void * arg)1104 token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
1105 {
1106     Py_VISIT(self->tok_ctx);
1107     Py_VISIT(self->tok_var);
1108     Py_VISIT(self->tok_oldval);
1109     return 0;
1110 }
1111 
1112 static void
token_tp_dealloc(PyContextToken * self)1113 token_tp_dealloc(PyContextToken *self)
1114 {
1115     PyObject_GC_UnTrack(self);
1116     (void)token_tp_clear(self);
1117     Py_TYPE(self)->tp_free(self);
1118 }
1119 
1120 static PyObject *
token_tp_repr(PyContextToken * self)1121 token_tp_repr(PyContextToken *self)
1122 {
1123     _PyUnicodeWriter writer;
1124 
1125     _PyUnicodeWriter_Init(&writer);
1126 
1127     if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
1128         goto error;
1129     }
1130 
1131     if (self->tok_used) {
1132         if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
1133             goto error;
1134         }
1135     }
1136 
1137     if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
1138         goto error;
1139     }
1140 
1141     PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
1142     if (var == NULL) {
1143         goto error;
1144     }
1145     if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
1146         Py_DECREF(var);
1147         goto error;
1148     }
1149     Py_DECREF(var);
1150 
1151     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
1152     if (addr == NULL) {
1153         goto error;
1154     }
1155     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
1156         Py_DECREF(addr);
1157         goto error;
1158     }
1159     Py_DECREF(addr);
1160 
1161     return _PyUnicodeWriter_Finish(&writer);
1162 
1163 error:
1164     _PyUnicodeWriter_Dealloc(&writer);
1165     return NULL;
1166 }
1167 
1168 static PyObject *
token_get_var(PyContextToken * self,void * Py_UNUSED (ignored))1169 token_get_var(PyContextToken *self, void *Py_UNUSED(ignored))
1170 {
1171     Py_INCREF(self->tok_var);
1172     return (PyObject *)self->tok_var;
1173 }
1174 
1175 static PyObject *
token_get_old_value(PyContextToken * self,void * Py_UNUSED (ignored))1176 token_get_old_value(PyContextToken *self, void *Py_UNUSED(ignored))
1177 {
1178     if (self->tok_oldval == NULL) {
1179         return get_token_missing();
1180     }
1181 
1182     Py_INCREF(self->tok_oldval);
1183     return self->tok_oldval;
1184 }
1185 
1186 static PyGetSetDef PyContextTokenType_getsetlist[] = {
1187     {"var", (getter)token_get_var, NULL, NULL},
1188     {"old_value", (getter)token_get_old_value, NULL, NULL},
1189     {NULL}
1190 };
1191 
1192 static PyMethodDef PyContextTokenType_methods[] = {
1193     {"__class_getitem__",    (PyCFunction)Py_GenericAlias,
1194     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1195     {NULL}
1196 };
1197 
1198 PyTypeObject PyContextToken_Type = {
1199     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1200     "_contextvars.Token",
1201     sizeof(PyContextToken),
1202     .tp_methods = PyContextTokenType_methods,
1203     .tp_getset = PyContextTokenType_getsetlist,
1204     .tp_dealloc = (destructor)token_tp_dealloc,
1205     .tp_getattro = PyObject_GenericGetAttr,
1206     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1207     .tp_traverse = (traverseproc)token_tp_traverse,
1208     .tp_clear = (inquiry)token_tp_clear,
1209     .tp_new = token_tp_new,
1210     .tp_free = PyObject_GC_Del,
1211     .tp_hash = PyObject_HashNotImplemented,
1212     .tp_repr = (reprfunc)token_tp_repr,
1213 };
1214 
1215 static PyContextToken *
token_new(PyContext * ctx,PyContextVar * var,PyObject * val)1216 token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
1217 {
1218     PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
1219     if (tok == NULL) {
1220         return NULL;
1221     }
1222 
1223     Py_INCREF(ctx);
1224     tok->tok_ctx = ctx;
1225 
1226     Py_INCREF(var);
1227     tok->tok_var = var;
1228 
1229     Py_XINCREF(val);
1230     tok->tok_oldval = val;
1231 
1232     tok->tok_used = 0;
1233 
1234     PyObject_GC_Track(tok);
1235     return tok;
1236 }
1237 
1238 
1239 /////////////////////////// Token.MISSING
1240 
1241 
1242 static PyObject *_token_missing;
1243 
1244 
1245 typedef struct {
1246     PyObject_HEAD
1247 } PyContextTokenMissing;
1248 
1249 
1250 static PyObject *
context_token_missing_tp_repr(PyObject * self)1251 context_token_missing_tp_repr(PyObject *self)
1252 {
1253     return PyUnicode_FromString("<Token.MISSING>");
1254 }
1255 
1256 
1257 PyTypeObject PyContextTokenMissing_Type = {
1258     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1259     "Token.MISSING",
1260     sizeof(PyContextTokenMissing),
1261     .tp_getattro = PyObject_GenericGetAttr,
1262     .tp_flags = Py_TPFLAGS_DEFAULT,
1263     .tp_repr = context_token_missing_tp_repr,
1264 };
1265 
1266 
1267 static PyObject *
get_token_missing(void)1268 get_token_missing(void)
1269 {
1270     if (_token_missing != NULL) {
1271         Py_INCREF(_token_missing);
1272         return _token_missing;
1273     }
1274 
1275     _token_missing = (PyObject *)PyObject_New(
1276         PyContextTokenMissing, &PyContextTokenMissing_Type);
1277     if (_token_missing == NULL) {
1278         return NULL;
1279     }
1280 
1281     Py_INCREF(_token_missing);
1282     return _token_missing;
1283 }
1284 
1285 
1286 ///////////////////////////
1287 
1288 
1289 void
_PyContext_ClearFreeList(PyInterpreterState * interp)1290 _PyContext_ClearFreeList(PyInterpreterState *interp)
1291 {
1292     struct _Py_context_state *state = &interp->context;
1293     for (; state->numfree; state->numfree--) {
1294         PyContext *ctx = state->freelist;
1295         state->freelist = (PyContext *)ctx->ctx_weakreflist;
1296         ctx->ctx_weakreflist = NULL;
1297         PyObject_GC_Del(ctx);
1298     }
1299 }
1300 
1301 
1302 void
_PyContext_Fini(PyInterpreterState * interp)1303 _PyContext_Fini(PyInterpreterState *interp)
1304 {
1305     if (_Py_IsMainInterpreter(interp)) {
1306         Py_CLEAR(_token_missing);
1307     }
1308     _PyContext_ClearFreeList(interp);
1309 #ifdef Py_DEBUG
1310     struct _Py_context_state *state = &interp->context;
1311     state->numfree = -1;
1312 #endif
1313     _PyHamt_Fini();
1314 }
1315 
1316 
1317 int
_PyContext_Init(void)1318 _PyContext_Init(void)
1319 {
1320     if (!_PyHamt_Init()) {
1321         return 0;
1322     }
1323 
1324     if ((PyType_Ready(&PyContext_Type) < 0) ||
1325         (PyType_Ready(&PyContextVar_Type) < 0) ||
1326         (PyType_Ready(&PyContextToken_Type) < 0) ||
1327         (PyType_Ready(&PyContextTokenMissing_Type) < 0))
1328     {
1329         return 0;
1330     }
1331 
1332     PyObject *missing = get_token_missing();
1333     if (PyDict_SetItemString(
1334         PyContextToken_Type.tp_dict, "MISSING", missing))
1335     {
1336         Py_DECREF(missing);
1337         return 0;
1338     }
1339     Py_DECREF(missing);
1340 
1341     return 1;
1342 }
1343