#include <ATen/functorch/DynamicLayer.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <c10/util/irange.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/xnnpack/Engine.h>

namespace at::functorch {

// NOTE: [functorch's PyTorch Operator Hacks]
//
// This file contains hacks for composite PyTorch operators that are problematic
// for functorch. For example, a composite op might perform in-place operations
// or call data_ptr(). We have some idea of how to fix these things in the long
// term, e.g., by upstreaming the changes to PyTorch.
//
// TODO: all of these should be fixed in a more blessed way. In particular,
// it is bad if any of these go out-of-sync with the implementations in
// pytorch/pytorch.

// TODO: upstream into core

namespace {
Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
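  // Gradient of index_select: scatter `grad` back into a zero tensor of the
  // original input's shape, accumulating (via index_add) any slices that were
  // selected more than once.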
  return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad);
}

// TODO: linear is pretty important for performance, but I'm not sure how to work
// around the in-place.
Tensor linear_hack(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(std::in_place);

  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  if (at::native::xnnpack::use_linear(input, weight, *bias)) {
    return at::native::xnnpack::linear(input, weight, *bias);
  }
#endif
  if (input.dim() == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm(*bias, input, weight.t());
  }
  if (input.dim() == 3 && bias->defined() && input.is_contiguous()) {
    // Also hit the fused path for contiguous 3D input: flatten the leading two
    // dimensions so a single addmm covers the whole batch, then restore the shape.
    const auto input_sizes = input.sizes();
    const auto result = at::addmm(*bias, input.view({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t());
    return result.view({input_sizes[0], input_sizes[1], result.size(1)});
  }
  auto output = at::matmul(input, weight.t());
  if (bias->defined()) {
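    // If a vmap transform is live anywhere on the dynamic layer stack, avoid the
    // in-place add_: under vmap it can be invalid (e.g., when bias carries a
    // batch dimension that output does not), so use out-of-place add instead.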
    const auto& stack = getDynamicLayerStack();
    bool any_vmap_layers = std::any_of(
        stack.begin(), stack.end(),
        [](const DynamicLayer& dl){ return dl.key() == TransformType::Vmap; });
    if (any_vmap_layers) {
      return output.add(*bias);
    }
    return output.add_(*bias);
  }
  return output;
}

static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  if (reduction == at::Reduction::Mean) {
    return unreduced.mean();
  } else if (reduction == at::Reduction::Sum) {
    return unreduced.sum();
  }
  return unreduced;
}

Tensor binary_cross_entropy_with_logits_hack(
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& pos_weight_opt,
    int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] { return Tensor(); });

  Tensor loss;
  auto max_val = (-input).clamp_min(0);
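  // Both branches below use the numerically stable log-sum-exp form:
  //   log(1 + exp(-input)) = max_val + log(exp(-max_val) + exp(-input - max_val)),
  // with max_val = max(-input, 0), which avoids overflow for large |input|.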
  if (pos_weight.defined()) {
    // pos_weight needs to be broadcast, so mul(target) cannot be done in-place.
    auto log_weight = (pos_weight - 1).mul(target).add_(1);
    loss = (1 - target).mul(input).add(log_weight.mul(((-max_val).exp_().add((-input - max_val).exp_())).log_().add_(max_val)));
  } else {
    loss = (1 - target).mul(input).add_(max_val).add_((-max_val).exp_().add((-input - max_val).exp_()).log_());
  }

  if (weight.defined()) {
    loss = loss * weight;
  }

  return apply_loss_reduction(loss, reduction);
}

Tensor trace_backward_decomp(const Tensor& grad, IntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }
  auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
  auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
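  // arange with step (sizes[1] + 1) enumerates the flattened (row-major) positions
  // of the diagonal; the gradient of trace is nonzero only at those entries.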
  // Workaround using index_put instead of yet unsupported index_fill_
  grad_input = grad_input.index_put({indices}, grad);
  return grad_input.view(sizes);
}
} // anonymous namespace

// dropout hack
// TODO: make the following changes in pytorch/pytorch
namespace dropout_hack {

namespace {

template<bool inplace>
using Ctype = std::conditional_t<inplace, Tensor&, Tensor>;

static Tensor make_feature_noise(const Tensor& input) {
  auto input_sizes = input.sizes();
  TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
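  // Build a noise shape of (N, C, 1, 1, ...): one Bernoulli sample per channel,
  // broadcast over the remaining (e.g. spatial) dimensions, so feature dropout
  // zeroes entire channels at once.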
  std::vector<int64_t> sizes;
  sizes.reserve(input.dim());
  sizes.push_back(input_sizes[0]);
  sizes.push_back(input_sizes[1]);
  for (C10_UNUSED const auto i : c10::irange(2, input.dim())) {
    sizes.push_back(1);
  }
  // NB: THIS WAS CHANGED FROM THE ORIGINAL
  return at::empty(sizes, input.options());
}

static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
  return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
}

// NB: sure, we could have used different overloads here, but I would feel insecure
// knowing that this dispatch depends only on the constness of the references
template<bool inplace>
Tensor& multiply(Tensor& input, const Tensor& noise) {
  static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul_(noise);
}

template<bool inplace>
Tensor multiply(const Tensor& input, const Tensor& noise) {
  static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul(noise);
}

template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
  TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
  if (p == 0 || !train || input.numel() == 0) {
    return input;
  }

  if (p == 1) {
    return multiply<inplace>(input, at::zeros({}, input.options()));
  }

  at::Tensor b; // used for alpha_dropout only

  // NB: THIS WAS CHANGED FROM THE ORIGINAL
  Tensor noise;
  if (feature_dropout) {
    auto empty = make_feature_noise(input);
    noise = at::bernoulli(empty, 1 - p);
  } else {
    // NB: it is important that this is at::empty and not at::empty_like
    auto empty = at::empty({}, input.options()).expand(input.sizes());
    noise = at::bernoulli(empty, 1 - p);
  }

  if (alpha_dropout) {
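    // Alpha dropout (Klambauer et al., "Self-Normalizing Neural Networks"):
    // dropped units are mapped to the SELU saturation value (the constant below
    // matches alpha * scale from that paper), and the affine correction (a, b) is
    // chosen so the output keeps approximately zero mean and unit variance.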
    constexpr double alpha = 1.7580993408473766;
    double a = 1. / std::sqrt((alpha * alpha * p + 1) * (1 - p));
    b = noise.add(-1).mul_(alpha * a).add_(alpha * a * p);
    noise.mul_(a);
  } else {
    noise.div_(1 - p);
  }

  if (!alpha_dropout) {
    return multiply<inplace>(input, noise);
  } else {
    return multiply<inplace>(input, noise).add_(b);
  }
}

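// ALIAS_SPECIALIZATION stamps out thin wrappers that bind the feature/alpha
// template flags and forward everything else to _dropout_impl.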
#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                      \
template <bool inplace, typename... Args>                                           \
Ctype<inplace> ALIAS_NAME(Args&&... args) {                                         \
  return _dropout_impl<IS_FEATURE, IS_ALPHA, inplace>(std::forward<Args>(args)...); \
}

ALIAS_SPECIALIZATION(_dropout,               false, false)
ALIAS_SPECIALIZATION(_feature_dropout,       true,  false)
ALIAS_SPECIALIZATION(_alpha_dropout,         false, true )
ALIAS_SPECIALIZATION(_feature_alpha_dropout, true,  true )

static Tensor dropout(const Tensor& input, double p, bool train) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (train && is_fused_kernel_acceptable(input, p)) {
      return std::get<0>(at::native_dropout(input, p, train));
    }
    return _dropout<false>(input, p, train);
  }();
  namedinference::propagate_names(result, input);
  return result;
}

Tensor& dropout_(Tensor& input, double p, bool train) {
  return _dropout<true>(input, p, train);
}

Tensor feature_dropout(const Tensor& input, double p, bool train) {
  return _feature_dropout<false>(input, p, train);
}

Tensor& feature_dropout_(Tensor& input, double p, bool train) {
  return _feature_dropout<true>(input, p, train);
}

Tensor alpha_dropout(const Tensor& input, double p, bool train) {
  return _alpha_dropout<false>(input, p, train);
}

Tensor& alpha_dropout_(Tensor& input, double p, bool train) {
  return _alpha_dropout<true>(input, p, train);
}

Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<false>(input, p, train);
}

Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<true>(input, p, train);
}

} // anonymous namespace
} // dropout_hack

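// Registering on FuncTorchDynamicLayerFrontMode means these replacements intercept
// the listed ops before functorch's transform (vmap/grad) layers process them, so
// the problematic composite implementations are swapped out while a transform is
// active.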
TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerFrontMode, m) {
  m.impl("index_select_backward", index_select_backward_hack);
  m.impl("linear", linear_hack);
  m.impl("binary_cross_entropy_with_logits", binary_cross_entropy_with_logits_hack);
  m.impl("trace_backward", trace_backward_decomp);

  m.impl("dropout", dropout_hack::dropout);
  m.impl("feature_dropout", dropout_hack::feature_dropout);
  m.impl("alpha_dropout", dropout_hack::alpha_dropout);
  m.impl("feature_alpha_dropout", dropout_hack::feature_alpha_dropout);

  m.impl("dropout_", dropout_hack::dropout_);
  m.impl("feature_dropout_", dropout_hack::feature_dropout_);
  m.impl("alpha_dropout_", dropout_hack::alpha_dropout_);
  m.impl("feature_alpha_dropout_", dropout_hack::feature_alpha_dropout_);
}

} // namespace at::functorch