// Boxed fallback kernels for the Python, PythonDispatcher, PythonTLSSnapshot,
// and PreDispatch dispatch keys.
1 #include <c10/core/impl/TorchDispatchModeTLS.h>
2 #include <c10/core/impl/PythonDispatcherTLS.h>
3 #include <ATen/core/PythonFallbackKernel.h>
4 #include <c10/core/SafePyObject.h>
5 
6 namespace {
7 
// This TLS is used to track the state of the dispatcher to be able to restore
// it when calling back into python.
// It has the following invariant:
//  - It must be empty while python code is executed.
//  - It should only be set once even for multiple dispatcher calls that do not come
//    back to python.
// To achieve this, we ensure that the tls is empty by default and emptied again both when
// we call into user torch_dispatch or returning back to python after this call.

// Snapshot of the dispatcher's LocalDispatchKeySet taken on entry from Python;
// std::nullopt whenever Python code is running (see invariant above).
thread_local std::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
18 
safe_get_tls_on_entry()19 c10::impl::LocalDispatchKeySet safe_get_tls_on_entry() {
20   TORCH_CHECK(tls_on_entry.has_value(), "Accessing torch dispatch state outside of '__torch_dispatch__' "
21               "is not allowed.");
22   return tls_on_entry.value();
23 }
24 
// All the keys below the Python key
// (constructed as the complement, within FULL, of the Python key plus
// everything DispatchKeySet(FULL_AFTER, Python) covers).
constexpr c10::DispatchKeySet after_Python_keyset = c10::DispatchKeySet(c10::DispatchKeySet::FULL) ^
  (c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) |
   c10::DispatchKeySet(c10::DispatchKey::Python));
29 
30 
31 // This guard assumes that tls_on_entry has a value.
32 struct StashTLSOnEntryGuard {
33 public:
34   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
StashTLSOnEntryGuard__anonf92e42200111::StashTLSOnEntryGuard35   StashTLSOnEntryGuard(): saved_(tls_on_entry.value()) {
36     tls_on_entry = std::nullopt;
37   }
38 
~StashTLSOnEntryGuard__anonf92e42200111::StashTLSOnEntryGuard39   ~StashTLSOnEntryGuard() {
40     TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
41     tls_on_entry = saved_;
42   }
43 
44 private:
45   c10::impl::LocalDispatchKeySet saved_;
46 };
47 
pythonFallback(const c10::OperatorHandle & op,torch::jit::Stack * stack)48 void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
49   TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
50   // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
51   // StashTLSOnEntryGuard stash_guard;
52   c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
53 
54 
55   // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
56   const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
57   if (mode_stack_len > 0) {
58     const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
59     cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
60     return;
61   }
62 
63   // Otherwise, find a PyInterpreter on a Tensor
64   const auto& schema = op.schema();
65   const auto num_arguments = schema.arguments().size();
66   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
67   // without checking the interpreters of any of the arguments, because when
68   // we actually run dispatch(), we will take out PyObjects in the context
69   // of that interpreter, and this will ensure that everyone is on the same
70   // interpreter.
71   for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
72     if (ivalue.isTensor()) {
73       auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
74       if (interpreter) {
75         (*interpreter)->dispatch(op, stack);
76         return;
77       }
78     } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
79       // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
80       // is not a thing)
81       for (const auto& nv : ivalue.toListRef()) {
82         if (nv.isNone()) {
83           continue;
84         }
85         auto* interpreter = nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
86         if (interpreter) {
87           (*interpreter)->dispatch(op, stack);
88           return;
89         }
90       }
91     }
92   }
93   TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
94 }
95 
pythonDispatcherFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)96 void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
97   auto* state = c10::impl::PythonDispatcherTLS::get_state();
98   TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");
99   (*state)->python_dispatcher(op, dispatch_keys.remove(c10::DispatchKey::PythonDispatcher), stack);
100 }
101 
pythonTLSSnapshotFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)102 void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
103   // It is ok for the tls to be already set here.
104   // It means that there are multiple calls into the dispatcher not originating from python code.
105   // The guard below will properly ignore such calls.
106   at::impl::MaybeSetTLSOnEntryGuard guard;
107 
108   op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
109 }
110 
111 // The PreDispatch key gets a no-op fallback that just redispatches.
112 // The main way this key is used is that we can register a mode to it from python (e.g. TorchProxyDispatchMode, for pre_dispatch tracing)
113 // Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
114 // Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
115 // Alternatively, we could have hardcoded this kernel (in C++) to directly call in TorchProxyDispatchMode.
116 // Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
preDispatchFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)117 void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
118   op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
119 }
120 
121 } // anonymous namespace
122 
123 
124 namespace at::impl {
125 
// Both members are initialized from the snapshot taken on entry from Python;
// safe_get_tls_on_entry() throws if no snapshot is present. Note it is called
// twice, once per member, in member-declaration order (saved_ then guard_ —
// guard_'s type is declared in the header, outside this file). The TLS is
// then cleared to uphold the invariant that it is empty while Python runs.
RestorePythonTLSSnapshot::RestorePythonTLSSnapshot() : saved_(safe_get_tls_on_entry()), guard_(safe_get_tls_on_entry()) {
  tls_on_entry = std::nullopt;
}
129 
// Restore the snapshot on destruction. The TLS must still be empty here,
// i.e. nothing re-set it while this object was alive.
RestorePythonTLSSnapshot::~RestorePythonTLSSnapshot() {
  TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
  tls_on_entry = saved_;
}
134 
MaybeSetTLSOnEntryGuard()135 MaybeSetTLSOnEntryGuard::MaybeSetTLSOnEntryGuard() {
136   if (tls_on_entry.has_value()) {
137     value_set_ = false;
138   } else {
139     value_set_ = true;
140     tls_on_entry = c10::impl::tls_local_dispatch_key_set();
141   }
142 }
~MaybeSetTLSOnEntryGuard()143 MaybeSetTLSOnEntryGuard::~MaybeSetTLSOnEntryGuard() {
144   if (value_set_) {
145     TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
146     tls_on_entry = std::nullopt;
147   }
148 }
149 
150 
151 } // namespace at::impl
152 
// Register each boxed function above as the fallback kernel for its dispatch
// key; the `_` namespace applies the registration to all operators.
TORCH_LIBRARY_IMPL(_, Python, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}

TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
}

TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}

TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
}
168