1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/empty.h>
10 #include <ATen/ops/miopen_batch_norm_native.h>
11 #include <ATen/ops/miopen_batch_norm_backward_native.h>
12 #endif
13
14 // TODO: Remove the condition on AT_ROCM_ENABLED entirely,
15 // don't build this file as part of CPU build.
16 #include <ATen/cuda/CUDAConfig.h>
17
18 #if !AT_ROCM_ENABLED()
19
20 namespace at { namespace native {
21
22 // See Note [ATen preprocessor philosophy]
23
miopen_batch_norm(const Tensor & input,const Tensor & weight,const std::optional<Tensor> & bias_opt,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,bool training,double exponential_average_factor,double epsilon)24 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
25 const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
26 bool training, double exponential_average_factor, double epsilon) {
27 AT_ERROR("miopen_batch_norm: ATen not compiled with MIOpen support");
28 }
29
miopen_batch_norm_backward(const Tensor & input,const Tensor & grad_output,const Tensor & weight,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,const std::optional<Tensor> & save_mean_opt,const std::optional<Tensor> & save_var_opt,double epsilon)30 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
31 const Tensor& input, const Tensor& grad_output, const Tensor& weight, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
32 double epsilon) {
33 AT_ERROR("miopen_batch_norm_backward: ATen not compiled with MIOpen support");
34 }
35
36 }} // namespace at::native
37
38 #else // AT_ROCM_ENABLED
39
40 #include <ATen/miopen/Descriptors.h>
41 #include <ATen/miopen/Types.h>
42 #include <ATen/miopen/Utils.h>
43
44 #include <ATen/TensorUtils.h>
45
46 namespace at { namespace native {
47
48 namespace {
49
// Views `t` as a tensor of shape {1, t.numel(), 1, ..., 1}.
// The result always has at least two dimensions; when `dim` is larger than
// two, trailing singleton dimensions are added until it has `dim` of them.
Tensor expandScale(const Tensor& t, int64_t dim) {
  std::vector<int64_t> shape{1, t.numel()};
  const auto current_rank = static_cast<int64_t>(shape.size());
  if (current_rank < dim) {
    // Pad on the right with ones up to the requested rank.
    shape.resize(static_cast<std::size_t>(dim), 1);
  }
  return t.view(shape);
}
57
58 } // namespace
59
miopen_batch_norm(const Tensor & input_t,const Tensor & weight_t,const std::optional<Tensor> & bias_t_opt,const std::optional<Tensor> & running_mean_t_opt,const std::optional<Tensor> & running_var_t_opt,bool training,double exponential_average_factor,double epsilon)60 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
61 const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt, const std::optional<Tensor>& running_mean_t_opt, const std::optional<Tensor>& running_var_t_opt,
62 bool training, double exponential_average_factor, double epsilon)
63 {
64 // See [Note: hacky wrapper removal for optional tensor]
65 c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
66 const Tensor& bias_t = *bias_t_maybe_owned;
67 const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();});
68 const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();});
69
70 TensorArg input{ input_t, "input", 1 },
71 weight{ weight_t, "weight", 2 },
72 bias{ bias_t, "bias", 3 },
73 running_mean{ running_mean_t, "running_mean", 4 },
74 running_var{ running_var_t, "running_var", 5 };
75 CheckedFrom c = "miopen_batch_norm";
76
77 checkAllDefined(c, {input, weight, bias});
78 if (!training) {
79 checkAllDefined(c, {running_mean, running_var});
80 }
81 checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
82 if (input->scalar_type() != ScalarType::Half) {
83 checkAllSameType(c, {input, weight});
84 }
85 checkAllSameType(c, {weight, bias, running_mean, running_var});
86 checkAllContiguous(c, {weight, bias, running_mean, running_var});
87 TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
88 checkDimRange(c, input, 2, 6 /* exclusive */);
89 auto num_features = input->size(1);
90 for (auto t : {weight, bias, running_mean, running_var}) {
91 if (t->defined()) {
92 checkNumel(c, t, num_features);
93 }
94 }
95
96 miopenBatchNormMode_t mode;
97 if (input->dim() == 2) {
98 mode = miopenBNPerActivation;
99 } else {
100 mode = miopenBNSpatial;
101 }
102
103 auto output_t = at::empty(input->sizes(), input->options());
104 TensorArg output{ output_t, "output", 0 };
105
106 auto handle = getMiopenHandle();
107 auto dataType = getMiopenDataType(*input);
108 TensorDescriptor idesc{ *input, 4 }; // input descriptor
109 TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, running_mean, etc.
110
111 Constant one(dataType, 1);
112 Constant zero(dataType, 0);
113 Tensor save_mean, save_var;
114
115 if (training) {
116 int64_t num_features = input_t.size(1);
117 save_mean = at::empty({ num_features }, weight_t.options());
118 save_var = at::empty({ num_features }, weight_t.options());
119 MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
120 handle, mode, &one, &zero,
121 idesc.desc(), input->const_data_ptr(),
122 idesc.desc(), output->data_ptr(),
123 wdesc.desc(),
124 // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs,
125 // not outputs. However, unfortunately the function signature only takes
126 // non-const pointers, presumably by accident
127 const_cast<void*>(weight->const_data_ptr()),
128 const_cast<void*>(bias->const_data_ptr()),
129 exponential_average_factor,
130 at::maybe_data_ptr(running_mean),
131 at::maybe_data_ptr(running_var),
132 epsilon,
133 save_mean.mutable_data_ptr(),
134 save_var.mutable_data_ptr()));
135 } else {
136 save_mean = at::empty({0}, weight_t.options());
137 save_var = at::empty({0}, weight_t.options());
138 MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
139 handle, mode, &one, &zero,
140 idesc.desc(), input->const_data_ptr(),
141 idesc.desc(), output->data_ptr(),
142 wdesc.desc(),
143 // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs,
144 // not outputs. However, unfortunately the function signature only takes
145 // non-const pointers, presumably by accident
146 const_cast<void*>(weight->const_data_ptr()),
147 const_cast<void*>(bias->const_data_ptr()),
148 running_mean->data_ptr(),
149 running_var->data_ptr(),
150 epsilon));
151 }
152
153 // save_mean and save_var can be undefined
154 // If this causes problems, we can initialize them to empty tensors
155 // of the correct type
156 return std::tuple<Tensor, Tensor, Tensor>{output_t, save_mean, save_var};
157 }
158
miopen_batch_norm_backward(const Tensor & input_t,const Tensor & grad_output_t,const Tensor & weight_t,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,const std::optional<Tensor> & save_mean_t_opt,const std::optional<Tensor> & save_var_t_opt,double epsilon)159 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
160 const Tensor& input_t,
161 const Tensor& grad_output_t,
162 const Tensor& weight_t,
163 // Unused: but we require them to be passed so that double backwards
164 // has access
165 const std::optional<Tensor>& running_mean_opt,
166 const std::optional<Tensor>& running_var_opt,
167 const std::optional<Tensor>& save_mean_t_opt,
168 const std::optional<Tensor>& save_var_t_opt,
169 double epsilon) {
170 // See [Note: hacky wrapper removal for optional tensor]
171 const Tensor& running_mean =
172 c10::value_or_else(running_mean_opt, [] { return Tensor(); });
173 const Tensor& running_var =
174 c10::value_or_else(running_var_opt, [] { return Tensor(); });
175 const Tensor& save_mean_t =
176 c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
177 const Tensor& save_var_t =
178 c10::value_or_else(save_var_t_opt, [] { return Tensor(); });
179
180 TensorArg input{ input_t, "input", 1 },
181 grad_output{ grad_output_t, "grad_output", 2 },
182 weight{ weight_t, "weight", 3 },
183 save_mean{ save_mean_t, "save_mean", 4 },
184 save_var{ save_var_t, "save_var", 5 };
185 CheckedFrom c = "miopen_batch_norm_backward";
186
187 checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
188 checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
189 if (input->scalar_type() == ScalarType::Half) {
190 checkScalarType(c, weight, ScalarType::Float);
191 } else {
192 checkAllSameType(c, {input, weight});
193 }
194 checkAllSameType(c, {input, grad_output});
195 checkAllSameType(c, {weight, save_mean, save_var});
196 checkAllContiguous(c, {input, grad_output, save_mean, save_var});
197 checkDimRange(c, input, 2, 6 /* exclusive */);
198 checkSameSize(c, input, grad_output);
199 auto num_features = input->size(1);
200 for (auto t : {weight, save_mean, save_var}) {
201 checkNumel(c, t, num_features);
202 }
203
204 miopenBatchNormMode_t mode;
205 if (input->dim() == 2) {
206 mode = miopenBNPerActivation;
207 } else {
208 mode = miopenBNSpatial;
209 }
210
211 auto grad_input_t = at::empty(input->sizes(), input->options());
212 auto grad_weight_t = at::empty(weight->sizes(), weight->options());
213 auto grad_bias_t = at::empty(weight->sizes(), weight->options());
214
215 auto handle = getMiopenHandle();
216 auto dataType = getMiopenDataType(*input);
217
218 TensorDescriptor idesc{ *input, 4 }; // input, output, grad_output descriptor
219 TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc.
220
221 Constant one(dataType, 1);
222 Constant zero(dataType, 0);
223
224 MIOPEN_CHECK(miopenBatchNormalizationBackward(
225 handle, mode, &one, &zero, &one, &zero,
226 idesc.desc(), input->const_data_ptr(),
227 idesc.desc(), grad_output->const_data_ptr(),
228 idesc.desc(), grad_input_t.data_ptr(),
229 wdesc.desc(), weight->const_data_ptr(),
230 grad_weight_t.data_ptr(),
231 grad_bias_t.data_ptr(),
232 epsilon,
233 save_mean->const_data_ptr(),
234 save_var->const_data_ptr()));
235
236 return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
237 }
238
239 }} // namespace native
240
241 #endif
242