#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <bitset>

namespace at::functorch {

constexpr size_t default_bitset_size = 64;

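// In-place ops under a grad transform may only mutate Tensors that are
// wrapped at the current level and not marked immutable; mutating a captured
// (immutable or unwrapped) Tensor would be silently incorrect, so error out.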
static void checkForInvalidMutationOnCaptures(
    const c10::OperatorHandle& op,
    const torch::jit::Stack* stack,
    int64_t cur_level) {
  if (!isInplaceOp(op.schema())) {
    return;
  }
  auto args = torch::jit::last(stack, op.schema().arguments().size());
  auto mutated_arg = unwrapIfDead(args[0].toTensor());
  auto* wrapper = maybeGetTensorWrapper(mutated_arg);
  if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level && !(wrapper->is_immutable())) {
    return;
  }
  TORCH_CHECK(false,
      "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
      "attempted to call an in-place operation (", op.schema().operator_name(), ") ",
      "that would mutate a captured Tensor. This is not supported; please rewrite ",
      "the function being transformed to explicitly accept the mutated Tensor(s) ",
      "as inputs.");
}

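// Ensure `tensor` is wrapped in a TensorWrapper at `current_level`.
// Wrappers created here are marked immutable: the underlying tensor was
// captured/lifted rather than produced inside this level of the transform,
// so mutating it in-place is disallowed (see the check above).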
static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) {
  if (!tensor.defined()) {
    return tensor;
  }
  // TensorWrapper creation may call dispatcher ops (e.g. aten.sym_storage_offset).
  // We need to ensure that they pass through the functorch stack properly.
  // In order to do that, we want to call those dispatcher ops at the next layer,
  // hence we disable DynamicLayerFrontMode so the call to the op automatically
  // goes to DynamicLayerBackMode which will then send it to the next layer.
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::FuncTorchDynamicLayerFrontMode);
  auto* wrapper = maybeGetTensorWrapper(tensor);
  if (!wrapper) {
    return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true);
  }
  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= current_level, "escaped?");
  if (wrapper->level().value() == current_level) {
    TORCH_INTERNAL_ASSERT(tensor.defined());
    return tensor;
  }
  return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true);
}

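// lift(): wrap a plain Tensor at this interpreter's level so that ops
// dispatched under the transform see a TensorWrapper rather than a bare Tensor.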
Tensor GradInterpreterPtr::lift(const Tensor& tensor) const {
  return materializeGradWrappers(tensor, level());
}

Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const {
  return materializeGradWrappers(tensor, level());
}

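// Shared "process" step for the Grad and Jvp interpreters: validate in-place
// mutations, materialize wrappers for the op's Tensor arguments, set up the
// dispatch key TLS for this transform, and dispatch the op.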
static void autogradBasedTransformProcess(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    int64_t current_level,
    TransformType transform_type) {
  // If this is a grad transform, the operation is in-place, and the mutated
  // argument is not currently wrapped in a TensorWrapper, then we need to
  // error out; otherwise the result is silently incorrect.
  checkForInvalidMutationOnCaptures(op, stack, current_level);

  // materialize live GradWrappers
  auto maybeTransformGradWrappers = [&](const Tensor& tensor) {
    return materializeGradWrappers(tensor, current_level);
  };
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - num_args), static_cast<int64_t>(stack->size()), maybeTransformGradWrappers);

  setup_dispatch_key_tls(transform_type, {});
  op.callBoxed(stack);
}

static void autogradBasedTransformSendToNext(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    const Interpreter& interpreter,
    TransformType transform_type,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    bool grad_special_case) {
  auto current_level = interpreter.level();
  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
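  // unwrap: peel off a TensorWrapper belonging to the current level so the
  // next interpreter (or the backend) sees the underlying value. Wrappers
  // from lower levels are left intact.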
  auto unwrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
    if (!maybe_tensor_wrapper) {
      return tensor;
    }
    auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= current_level);
    if (tensor_wrapper_level == current_level) {
      return maybe_tensor_wrapper->value();
    }
    return tensor;
  };
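  // wrap: re-wrap an output in a TensorWrapper at this interpreter's level.
  // `is_immutable` comes from the aliasing analysis below: outputs that alias
  // an immutable input must themselves be marked immutable.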
  auto wrap = [&](const Tensor& tensor, bool is_immutable) {
    if (!tensor.defined()) {
      return tensor;
    }
    // if (c10::show_dispatch_trace_enabled()) {
    //   std::cout << "wrap " << current_level << std::endl;
    // }
    return makeTensorWrapper(tensor, interpreter, is_immutable);
  };

  // TODO: we only need to do the following (marked with !) on in-place functions
  // that modify sizes or strides. There aren't many of them.
  // If autograd dispatch key:
  // 1. (!) Put a copy of all of the args onto the stack
  // 2. Unwrap all the args in the copy set
  // 3. Call the operator
  // 4. Wrap the output
  // 5. (!) refreshMetadata for all the args in the original set
  // 6. (!) Pop those args off.

  // Step 1 & 2
  auto args_size = op.schema().arguments().size();
  const auto ret_size = op.schema().returns().size();
  // Step 1
  auto front = static_cast<int64_t>(stack->size()) - args_size;
  for (const auto arg_idx : c10::irange(0, args_size)) {
    stack->push_back((*stack)[front + arg_idx]);
  }

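  // Determine which outputs alias an immutable (or unwrapped) input. Those
  // outputs must not be re-wrapped as mutable in Step 4, so we record their
  // positions in a bitset consumed by foreachTensorInplaceWithFlag below.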
  std::bitset<default_bitset_size> outputs_aliasing_immutable; // all bits start at 0; flipped to 1 for outputs that alias an immutable input
  if (!grad_special_case) {
    for (auto idx = stack->size() - args_size; idx < stack->size(); idx++) {
      const auto ivalue = (*stack)[idx];
      if (!ivalue.isTensor()) {
        continue; // the only input that can be aliased is a tensor, not a tensor list (except in ops without returns)
      }
      const auto& tensor = ivalue.toTensor();
      auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
      if (!maybe_tensor_wrapper || maybe_tensor_wrapper->is_immutable()) {
        // if the input is immutable, we find if it aliases anything, noting that
        // args are in reverse order on stack, so the last arg is at the top of the stack
        const auto relative_pos = idx - (stack->size() - args_size);
        const auto aliased_out = findAliasedOutput(op.schema(), static_cast<int64_t>(relative_pos));
        if (aliased_out.has_value()) {
          outputs_aliasing_immutable.flip(*aliased_out); // each output aliases at most one input, so we can only hit this once
        }
      }
    }
  }

  // Step 2
  foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - args_size), static_cast<int64_t>(stack->size()), unwrap);

  // See NOTE [grad and vjp interaction with no_grad]
  std::optional<c10::AutoGradMode> grad_guard;
  if (transform_type == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
    grad_guard.emplace(*prev_grad_mode);
  }
  std::optional<c10::AutoFwGradMode> fw_grad_guard;
  if (transform_type == TransformType::Jvp &&
      prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
    fw_grad_guard.emplace(*prev_fwd_grad_mode);
  }

  // Re-dispatch
  if (getDynamicLayerStack().empty()) {
    sanityCheckStack(op, stack);
  }

  // Step 3
  op.callBoxed(stack);

  // Step 4
  foreachTensorInplaceWithFlag(*stack, static_cast<int64_t>(stack->size() - ret_size), static_cast<int64_t>(stack->size()), outputs_aliasing_immutable, wrap);

  // Step 5
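  // In-place ops may have changed the sizes/strides of the mutated arguments;
  // sync that metadata back onto the original (still-wrapped) args, which sit
  // below the returns on the stack.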
  auto args_front = stack->size() - args_size - ret_size;
  for (const auto arg_idx : c10::irange(0, args_size)) {
    auto& ivalue = (*stack)[args_front + arg_idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
    if (!maybe_tensor_wrapper) {
      continue;
    }
    maybe_tensor_wrapper->refreshMetadata();
  }

  // Step 6
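  // Drop the original args (they sit directly below the returns), leaving only
  // the wrapped returns on the stack.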
  stack->erase(stack->end() - std::ptrdiff_t(args_size + ret_size), stack->end() - std::ptrdiff_t(ret_size));
}

void GradInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  autogradBasedTransformProcess(op, stack, level(), TransformType::Grad);
}

void GradInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    bool grad_special_case) {
  autogradBasedTransformSendToNext(
      op, stack, *base_,
      TransformType::Grad,
      prevGradMode(),
      std::nullopt,
      grad_special_case);
}

void JvpInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  autogradBasedTransformProcess(op, stack, level(), TransformType::Jvp);
}

void JvpInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    bool grad_special_case) {
  autogradBasedTransformSendToNext(
      op, stack, *base_,
      TransformType::Jvp,
      std::nullopt,
      prevFwdGradMode(),
      grad_special_case);
}

} // namespace at::functorch