#include <ATen/functorch/DynamicLayer.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <c10/util/irange.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/xnnpack/Engine.h>

namespace at::functorch {

// NOTE: [functorch's PyTorch Operator Hacks]
//
// This file contains hacks for composite PyTorch operators that are problematic.
// For example, the composite op might have in-place operations,
// or call data_ptr. We have some idea of how to fix these things in the long term
// e.g., upstream the changes to PyTorch.
//
// TODO: all of these should be fixed in a more blessed way. In particular,
// it is bad if any of these go out-of-sync with the implementations in
// pytorch/pytorch.

// TODO: upstream into core

namespace {
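// Uses the out-of-place index_add (rather than an in-place scatter) so the
// gradient accumulation composes with functorch transforms; see the NOTE above
// about avoiding in-place operations.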
Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
  return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad);
}

// TODO: linear is pretty important for performance, but I'm not sure how to work
// around the in-place.
Tensor linear_hack(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(std::in_place);

  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  if (at::native::xnnpack::use_linear(input, weight, *bias)) {
    return at::native::xnnpack::linear(input, weight, *bias);
  }
#endif
  if (input.dim() == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm(*bias, input, weight.t());
  }
  if (input.dim() == 3 && bias->defined() && input.is_contiguous()) {
    // Also hit the fused path for contiguous 3D input.
    const auto input_sizes = input.sizes();
    const auto result = at::addmm(*bias, input.view({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t());
    return result.view({input_sizes[0], input_sizes[1], result.size(1)});
  }
  auto output = at::matmul(input, weight.t());
  if (bias->defined()) {
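    // If any vmap layer is active, use the out-of-place add: mutating the
    // output in place under a batching transform is exactly the kind of
    // hazard this file works around (see NOTE above).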
    const auto& stack = getDynamicLayerStack();
    bool any_vmap_layers = std::any_of(
        stack.begin(), stack.end(),
        [](const DynamicLayer& dl){ return dl.key() == TransformType::Vmap; });
    if (any_vmap_layers) {
      return output.add(*bias);
    }
    return output.add_(*bias);
  }
  return output;
}

static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  if (reduction == at::Reduction::Mean) {
    return unreduced.mean();
  } else if (reduction == at::Reduction::Sum) {
    return unreduced.sum();
  }
  return unreduced;
}

Tensor binary_cross_entropy_with_logits_hack(
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& pos_weight_opt,
    int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();});

  Tensor loss;
  auto max_val = (-input).clamp_min(0);
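  // Numerically stable evaluation of log(1 + exp(-x)):
  //   log(1 + exp(-x)) = max_val + log(exp(-max_val) + exp(-x - max_val)),
  // with max_val = max(-x, 0), so neither exponential can overflow.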
  if (pos_weight.defined()) {
    // pos_weight needs to be broadcast, so mul(target) is not in-place.
    auto log_weight = (pos_weight - 1).mul(target).add_(1);
    loss = (1 - target).mul(input).add(log_weight.mul(((-max_val).exp_().add((-input - max_val).exp_())).log_().add_(max_val)));
  } else {
    loss = (1 - target).mul(input).add_(max_val).add_((-max_val).exp_().add((-input - max_val).exp_()).log_());
  }

  if (weight.defined()) {
    loss = loss * weight;
  }

  return apply_loss_reduction(loss, reduction);
}

Tensor trace_backward_decomp(const Tensor& grad, IntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }
  auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
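  // In the flattened (n, m) matrix, diagonal entries sit at positions
  // 0, m + 1, 2 * (m + 1), ..., hence the arange step of sizes[1] + 1.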
  auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
  // Workaround: use index_put instead of the not-yet-supported index_fill_
  grad_input = grad_input.index_put({indices}, grad);
  return grad_input.view(sizes);
}
}

// dropout hack
// TODO: make the following changes in pytorch/pytorch
namespace dropout_hack {

namespace {

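// Selects the return type of the dropout implementations: in-place variants
// hand back Tensor&, functional variants return a fresh Tensor.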
template<bool inplace>
using Ctype = std::conditional_t<inplace, Tensor&, Tensor>;

static Tensor make_feature_noise(const Tensor& input) {
  auto input_sizes = input.sizes();
  TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
  std::vector<int64_t> sizes;
  sizes.reserve(input.dim());
  sizes.push_back(input_sizes[0]);
  sizes.push_back(input_sizes[1]);
  for (C10_UNUSED const auto i : c10::irange(2, input.dim())) {
    sizes.push_back(1);
  }
  // NB: THIS WAS CHANGED FROM THE ORIGINAL
  return at::empty(sizes, input.options());
}

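// Only these backends have a fused native_dropout kernel; everything else
// (and p == 0 / p == 1 / empty inputs) takes the composite path below.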
static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
  return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
}

// NB: sure, we could have used different overloads here, but I would feel insecure
// knowing that this dispatch depends only on the constness of the references
template<bool inplace>
Tensor& multiply(Tensor& input, const Tensor& noise) {
  static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul_(noise);
}

template<bool inplace>
Tensor multiply(const Tensor& input, const Tensor& noise) {
  static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul(noise);
}

template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
  TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
  if (p == 0 || !train || input.numel() == 0) {
    return input;
  }

  if (p == 1) {
    return multiply<inplace>(input, at::zeros({}, input.options()));
  }

  at::Tensor b; // used for alpha_dropout only

  // NB: THIS WAS CHANGED FROM THE ORIGINAL
  Tensor noise;
  if (feature_dropout) {
    auto empty = make_feature_noise(input);
    noise = at::bernoulli(empty, 1 - p);
  } else {
    // NB: it is important that this is at::empty and not at::empty_like
    auto empty = at::empty({}, input.options()).expand(input.sizes());
    noise = at::bernoulli(empty, 1 - p);
  }

  if (alpha_dropout) {
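    // alpha is the magnitude of SELU's negative saturation value
    // (scale * alpha_selu ~ 1.758); a and b form the affine correction that
    // keeps the output at zero mean and unit variance, as in alpha dropout.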
    constexpr double alpha = 1.7580993408473766;
    double a = 1. / std::sqrt((alpha * alpha * p + 1) * (1 - p));
    b = noise.add(-1).mul_(alpha * a).add_(alpha * a * p);
    noise.mul_(a);
  } else {
    noise.div_(1 - p);
  }

  if (!alpha_dropout) {
    return multiply<inplace>(input, noise);
  } else {
    return multiply<inplace>(input, noise).add_(b);
  }
}

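// Generates thin aliases that pin the feature/alpha template flags; the
// inplace flag is still chosen at each call site below.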
#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                      \
template <bool inplace, typename... Args>                                           \
Ctype<inplace> ALIAS_NAME(Args&&... args) {                                         \
  return _dropout_impl<IS_FEATURE, IS_ALPHA, inplace>(std::forward<Args>(args)...); \
}

ALIAS_SPECIALIZATION(_dropout,               false, false)
ALIAS_SPECIALIZATION(_feature_dropout,       true,  false)
ALIAS_SPECIALIZATION(_alpha_dropout,         false, true )
ALIAS_SPECIALIZATION(_feature_alpha_dropout, true,  true )

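// Plain dropout: take the fused native_dropout kernel when the backend
// supports it, otherwise fall back to the composite implementation above;
// tensor names are suppressed during the computation and re-attached at the end.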
static Tensor dropout(const Tensor& input, double p, bool train) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (train && is_fused_kernel_acceptable(input, p)) {
      return std::get<0>(at::native_dropout(input, p, train));
    }
    return _dropout<false>(input, p, train);
  }();
  namedinference::propagate_names(result, input);
  return result;
}

Tensor& dropout_(Tensor& input, double p, bool train) {
  return _dropout<true>(input, p, train);
}

Tensor feature_dropout(const Tensor& input, double p, bool train) {
  return _feature_dropout<false>(input, p, train);
}

Tensor& feature_dropout_(Tensor& input, double p, bool train) {
  return _feature_dropout<true>(input, p, train);
}

Tensor alpha_dropout(const Tensor& input, double p, bool train) {
  return _alpha_dropout<false>(input, p, train);
}

Tensor& alpha_dropout_(Tensor& input, double p, bool train) {
  return _alpha_dropout<true>(input, p, train);
}

Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<false>(input, p, train);
}

Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<true>(input, p, train);
}

}
} // dropout_hack

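// Register the hacks at the FuncTorchDynamicLayerFrontMode key so they are
// picked up (instead of the problematic composite implementations) whenever a
// functorch transform is on the dynamic layer stack.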
TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerFrontMode, m) {
  m.impl("index_select_backward", index_select_backward_hack);
  m.impl("linear", linear_hack);
  m.impl("binary_cross_entropy_with_logits", binary_cross_entropy_with_logits_hack);
  m.impl("trace_backward", trace_backward_decomp);

  m.impl("dropout", dropout_hack::dropout);
  m.impl("feature_dropout", dropout_hack::feature_dropout);
  m.impl("alpha_dropout", dropout_hack::alpha_dropout);
  m.impl("feature_alpha_dropout", dropout_hack::feature_alpha_dropout);

  m.impl("dropout_", dropout_hack::dropout_);
  m.impl("feature_dropout_", dropout_hack::feature_dropout_);
  m.impl("alpha_dropout_", dropout_hack::alpha_dropout_);
  m.impl("feature_alpha_dropout_", dropout_hack::feature_alpha_dropout_);
}

} // namespace at::functorch