#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <bitset>

namespace at::functorch {

constexpr size_t default_bitset_size = 64;

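// Rejects in-place ops that would mutate a Tensor captured by the function
// being transformed: mutation is only allowed when the mutated argument is
// wrapped at the current grad/jvp level and its wrapper is mutable. Anything
// else would produce silently incorrect results, so we error out instead.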
static void checkForInvalidMutationOnCaptures(
    const c10::OperatorHandle& op,
    const torch::jit::Stack* stack,
    int64_t cur_level) {
  if (!isInplaceOp(op.schema())) {
    return;
  }
  auto args = torch::jit::last(stack, op.schema().arguments().size());
  auto mutated_arg = unwrapIfDead(args[0].toTensor());
  auto* wrapper = maybeGetTensorWrapper(mutated_arg);
  if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level && !(wrapper->is_immutable())) {
    return;
  }
  TORCH_CHECK(false,
      "During a grad (vjp, jvp, grad, etc) transform, the function provided ",
      "attempted to call in-place operation (", op.schema().operator_name(), ") ",
      "that would mutate a captured Tensor. This is not supported; please rewrite ",
      "the function being transformed to explicitly accept the mutated Tensor(s) ",
      "as inputs.");
}
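// Illustrative sketch of what trips the check above (assumed user-level usage,
// not taken from this file): under torch.func.grad, a function that calls
// `captured.add_(x)` on a tensor created outside the transform reaches the
// TORCH_CHECK, since `captured` carries no TensorWrapper at the current level.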

static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) {
  if (!tensor.defined()) {
    return tensor;
  }
  // TensorWrapper creation may call dispatcher ops (e.g. aten.sym_storage_offset).
  // We need to ensure that they pass through the functorch stack properly.
  // In order to do that, we want to call those dispatcher ops at the next layer,
  // hence we disable DynamicLayerFrontMode so the call to the op automatically
  // goes to DynamicLayerBackMode which will then send it to the next layer.
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::FuncTorchDynamicLayerFrontMode);
  auto* wrapper = maybeGetTensorWrapper(tensor);
  if (!wrapper) {
    return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true);
  }
  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= current_level, "escaped?");
  if (wrapper->level().value() == current_level) {
    TORCH_INTERNAL_ASSERT(tensor.defined());
    return tensor;
  }
  return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true);
}

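// lift() is how plain tensors enter a grad/jvp transform: materializeGradWrappers
// wraps anything not already wrapped at this interpreter's level as an immutable
// TensorWrapper, so later in-place mutation of such captures can be detected.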
Tensor GradInterpreterPtr::lift(const Tensor& tensor) const {
  return materializeGradWrappers(tensor, level());
}

Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const {
  return materializeGradWrappers(tensor, level());
}

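// The shared "process" step for grad/jvp transforms: validates in-place mutation
// of captures, materializes grad wrappers at the current level, sets up the
// dispatch-key TLS for the transform, and re-dispatches the op.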
static void autogradBasedTransformProcess(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    int64_t current_level,
    TransformType transform_type) {
  // If this is a grad transform, the operation is in-place, and the mutated
  // argument is not currently wrapped in a TensorWrapper, then we need to
  // error out; otherwise the result would be silently incorrect.
  checkForInvalidMutationOnCaptures(op, stack, current_level);

  // materialize live GradWrappers
  auto maybeTransformGradWrappers = [&](const Tensor& tensor) {
    return materializeGradWrappers(tensor, current_level);
  };
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - num_args), static_cast<int64_t>(stack->size()), maybeTransformGradWrappers);

  setup_dispatch_key_tls(transform_type, {});
  op.callBoxed(stack);
}

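// The shared "send to next interpreter" step: arguments wrapped at this level are
// unwrapped, the op is re-dispatched to the layer below, and the outputs are
// wrapped back at this interpreter's level. For in-place ops the original
// (wrapped) arguments also get their metadata refreshed afterwards (see the
// numbered steps below).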
static void autogradBasedTransformSendToNext(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    const Interpreter& interpreter,
    TransformType transform_type,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    bool grad_special_case) {
  auto current_level = interpreter.level();
  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
  auto unwrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
    if (!maybe_tensor_wrapper) {
      return tensor;
    }
    auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= current_level);
    if (tensor_wrapper_level == current_level) {
      return maybe_tensor_wrapper->value();
    }
    return tensor;
  };
  auto wrap = [&](const Tensor& tensor, bool is_immutable) {
    if (!tensor.defined()) {
      return tensor;
    }
    // if (c10::show_dispatch_trace_enabled()) {
    //   std::cout << "wrap " << current_level << std::endl;
    // }
    return makeTensorWrapper(tensor, interpreter, is_immutable);
  };

  // TODO: we only need to do the following (marked with !) on in-place functions
  // that modify sizes or strides. There aren't many of them.
  // If autograd dispatch key:
  // 1. (!) Put a copy of all of the args onto the stack
  // 2. Unwrap all the args in the copy set
  // 3. Call the operator
  // 4. Wrap the output
  // 5. (!) refreshMetadata for all the args in the original set
  // 6. (!) Pop those args off.
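  //
  // Illustrative stack layout for the steps above (a sketch, assuming an op with
  // two tensor args a0, a1 and one return r0):
  //   before Step 1:      [..., a0, a1]
  //   after Steps 1 & 2:  [..., a0, a1, unwrap(a0), unwrap(a1)]
  //   after Steps 3 & 4:  [..., a0, a1, wrap(r0)]
  //   after Step 6:       [..., wrap(r0)]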

  // Step 1 & 2
  auto args_size = op.schema().arguments().size();
  const auto ret_size = op.schema().returns().size();
  // Step 1
  auto front = static_cast<int64_t>(stack->size()) - args_size;
  for (const auto arg_idx : c10::irange(0, args_size)) {
    stack->push_back((*stack)[front + arg_idx]);
  }

  std::bitset<default_bitset_size> outputs_aliasing_immutable; // bit i is set below iff output i aliases an unwrapped or immutable input
  if (!grad_special_case) {
    for (auto idx = stack->size() - args_size; idx < stack->size(); idx++) {
      const auto ivalue = (*stack)[idx];
      if (!ivalue.isTensor()) {
        continue; // the only input that can be aliased is a tensor, not a tensor list (except in ops without returns)
      }
      const auto& tensor = ivalue.toTensor();
      auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
      if (!maybe_tensor_wrapper || maybe_tensor_wrapper->is_immutable()) {
        // if the input is unwrapped or immutable, check whether some output aliases it;
        // the args sit at the top of the stack (last arg topmost), so relative_pos is
        // the argument's index within the schema
        const auto relative_pos = idx - (stack->size() - args_size);
        const auto aliased_out = findAliasedOutput(op.schema(), static_cast<int64_t>(relative_pos));
        if (aliased_out.has_value()) {
          outputs_aliasing_immutable.flip(*aliased_out); // each output aliases at most one input, so we can only hit this once
        }
      }
    }
  }
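  // Outputs that alias an unwrapped or immutable input get wrapped as immutable
  // in Step 4 below; the intent (as far as this code shows) is that a later
  // attempt to mutate such a view still trips checkForInvalidMutationOnCaptures
  // instead of silently mutating a capture.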

  // Step 2
  foreachTensorInplace(*stack, static_cast<int64_t>(stack->size() - args_size), static_cast<int64_t>(stack->size()), unwrap);

  // See NOTE [grad and vjp interaction with no_grad]
  std::optional<c10::AutoGradMode> grad_guard;
  if (transform_type == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) {
    grad_guard.emplace(*prev_grad_mode);
  }
  std::optional<c10::AutoFwGradMode> fw_grad_guard;
  if (transform_type == TransformType::Jvp &&
      prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
    fw_grad_guard.emplace(*prev_fwd_grad_mode);
  }

  // Re-dispatch
  if (getDynamicLayerStack().empty()) {
    sanityCheckStack(op, stack);
  }

  // Step 3 (Steps 4, 5, 6 follow below)
  op.callBoxed(stack);

  // Step 4
  foreachTensorInplaceWithFlag(*stack, static_cast<int64_t>(stack->size() - ret_size), static_cast<int64_t>(stack->size()), outputs_aliasing_immutable, wrap);

  // Step 5
  auto args_front = stack->size() - args_size - ret_size;
  for (const auto arg_idx : c10::irange(0, args_size)) {
    auto& ivalue = (*stack)[args_front + arg_idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
    if (!maybe_tensor_wrapper) {
      continue;
    }
    maybe_tensor_wrapper->refreshMetadata();
  }

  // Step 6
  stack->erase(stack->end() - std::ptrdiff_t(args_size + ret_size), stack->end() - std::ptrdiff_t(ret_size));
}

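// GradInterpreterPtr / JvpInterpreterPtr plug the shared helpers above into the
// Interpreter interface: processImpl handles an op while this transform is the
// topmost layer, and sendToNextInterpreterImpl unwraps/re-wraps at this level
// while letting the op fall through to the layer below.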
void GradInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  autogradBasedTransformProcess(op, stack, level(), TransformType::Grad);
}

void GradInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    bool grad_special_case) {
  autogradBasedTransformSendToNext(
      op, stack, *base_,
      TransformType::Grad,
      prevGradMode(),
      std::nullopt,
      grad_special_case);
}

void JvpInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  autogradBasedTransformProcess(op, stack, level(), TransformType::Jvp);
}

void JvpInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    bool grad_special_case) {
  autogradBasedTransformSendToNext(
      op, stack, *base_,
      TransformType::Jvp,
      std::nullopt,
      prevFwdGradMode(),
      grad_special_case);
}

} // namespace at::functorch