1 #include <c10/core/impl/TorchDispatchModeTLS.h>
2 #include <c10/core/impl/PythonDispatcherTLS.h>
3 #include <ATen/core/PythonFallbackKernel.h>
4 #include <c10/core/SafePyObject.h>
5
6 namespace {
7
8 // This TLS is used to track the state of the dispatcher to be able to restore
9 // it when calling back into python.
10 // It has the following invariant:
11 // - It must be empty while python code is executed.
12 // - It should only be set once even for multiple dispatcher calls that do not come
13 // back to python.
// To achieve this, we ensure that the tls is empty by default and emptied again both when
// we call into the user's torch_dispatch and when we return back to python after this call.
16
// Snapshot of the local dispatch key set taken when the dispatcher was
// entered from python; std::nullopt while python code is executing
// (see the invariant described above).
thread_local std::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
18
safe_get_tls_on_entry()19 c10::impl::LocalDispatchKeySet safe_get_tls_on_entry() {
20 TORCH_CHECK(tls_on_entry.has_value(), "Accessing torch dispatch state outside of '__torch_dispatch__' "
21 "is not allowed.");
22 return tls_on_entry.value();
23 }
24
// All the keys below the Python key
// (computed as the complement of {Python} ∪ the keys after Python).
constexpr c10::DispatchKeySet after_Python_keyset = c10::DispatchKeySet(c10::DispatchKeySet::FULL) ^
  (c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) |
   c10::DispatchKeySet(c10::DispatchKey::Python));
29
30
31 // This guard assumes that tls_on_entry has a value.
32 struct StashTLSOnEntryGuard {
33 public:
34 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
StashTLSOnEntryGuard__anonf92e42200111::StashTLSOnEntryGuard35 StashTLSOnEntryGuard(): saved_(tls_on_entry.value()) {
36 tls_on_entry = std::nullopt;
37 }
38
~StashTLSOnEntryGuard__anonf92e42200111::StashTLSOnEntryGuard39 ~StashTLSOnEntryGuard() {
40 TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
41 tls_on_entry = saved_;
42 }
43
44 private:
45 c10::impl::LocalDispatchKeySet saved_;
46 };
47
pythonFallback(const c10::OperatorHandle & op,torch::jit::Stack * stack)48 void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
49 TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
50 // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
51 // StashTLSOnEntryGuard stash_guard;
52 c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
53
54
55 // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
56 const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
57 if (mode_stack_len > 0) {
58 const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
59 cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
60 return;
61 }
62
63 // Otherwise, find a PyInterpreter on a Tensor
64 const auto& schema = op.schema();
65 const auto num_arguments = schema.arguments().size();
66 // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
67 // without checking the interpreters of any of the arguments, because when
68 // we actually run dispatch(), we will take out PyObjects in the context
69 // of that interpreter, and this will ensure that everyone is on the same
70 // interpreter.
71 for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
72 if (ivalue.isTensor()) {
73 auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
74 if (interpreter) {
75 (*interpreter)->dispatch(op, stack);
76 return;
77 }
78 } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
79 // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
80 // is not a thing)
81 for (const auto& nv : ivalue.toListRef()) {
82 if (nv.isNone()) {
83 continue;
84 }
85 auto* interpreter = nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
86 if (interpreter) {
87 (*interpreter)->dispatch(op, stack);
88 return;
89 }
90 }
91 }
92 }
93 TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
94 }
95
pythonDispatcherFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)96 void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
97 auto* state = c10::impl::PythonDispatcherTLS::get_state();
98 TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");
99 (*state)->python_dispatcher(op, dispatch_keys.remove(c10::DispatchKey::PythonDispatcher), stack);
100 }
101
pythonTLSSnapshotFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)102 void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
103 // It is ok for the tls to be already set here.
104 // It means that there are multiple calls into the dispatcher not originating from python code.
105 // The guard below will properly ignore such calls.
106 at::impl::MaybeSetTLSOnEntryGuard guard;
107
108 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
109 }
110
111 // The PreDispatch key gets a no-op fallback that just redispatches.
112 // The main way this key is used is that we can register a mode to it from python (e.g. TorchProxyDispatchMode, for pre_dispatch tracing)
113 // Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
114 // Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
115 // Alternatively, we could have hardcoded this kernel (in C++) to directly call in TorchProxyDispatchMode.
116 // Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
preDispatchFallback(const c10::OperatorHandle & op,c10::DispatchKeySet dispatch_keys,torch::jit::Stack * stack)117 void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
118 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
119 }
120
121 } // anonymous namespace
122
123
124 namespace at::impl {
125
RestorePythonTLSSnapshot()126 RestorePythonTLSSnapshot::RestorePythonTLSSnapshot() : saved_(safe_get_tls_on_entry()), guard_(safe_get_tls_on_entry()) {
127 tls_on_entry = std::nullopt;
128 }
129
~RestorePythonTLSSnapshot()130 RestorePythonTLSSnapshot::~RestorePythonTLSSnapshot() {
131 TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
132 tls_on_entry = saved_;
133 }
134
MaybeSetTLSOnEntryGuard()135 MaybeSetTLSOnEntryGuard::MaybeSetTLSOnEntryGuard() {
136 if (tls_on_entry.has_value()) {
137 value_set_ = false;
138 } else {
139 value_set_ = true;
140 tls_on_entry = c10::impl::tls_local_dispatch_key_set();
141 }
142 }
~MaybeSetTLSOnEntryGuard()143 MaybeSetTLSOnEntryGuard::~MaybeSetTLSOnEntryGuard() {
144 if (value_set_) {
145 TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
146 tls_on_entry = std::nullopt;
147 }
148 }
149
150
151 } // namespace at::impl
152
// Register pythonFallback as the boxed fallback for the Python key in all
// namespaces ('_').
TORCH_LIBRARY_IMPL(_, Python, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}
156
// Register pythonDispatcherFallback as the boxed fallback for the
// PythonDispatcher key in all namespaces ('_').
TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
}
160
// Register pythonTLSSnapshotFallback as the boxed fallback for the
// PythonTLSSnapshot key in all namespaces ('_').
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}
164
// Register preDispatchFallback as the boxed fallback for the PreDispatch key
// in all namespaces ('_').
TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
}
168