• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/empty.h>
10 #include <ATen/ops/miopen_batch_norm_native.h>
11 #include <ATen/ops/miopen_batch_norm_backward_native.h>
12 #endif
13 
14 // TODO: Remove the condition on AT_ROCM_ENABLED entirely,
15 // don't build this file as part of CPU build.
16 #include <ATen/cuda/CUDAConfig.h>
17 
18 #if !AT_ROCM_ENABLED()
19 
20 namespace at { namespace native {
21 
22 // See Note [ATen preprocessor philosophy]
23 
miopen_batch_norm(const Tensor & input,const Tensor & weight,const std::optional<Tensor> & bias_opt,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,bool training,double exponential_average_factor,double epsilon)24 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
25     const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
26     bool training, double exponential_average_factor, double epsilon) {
27   AT_ERROR("miopen_batch_norm: ATen not compiled with MIOpen support");
28 }
29 
miopen_batch_norm_backward(const Tensor & input,const Tensor & grad_output,const Tensor & weight,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,const std::optional<Tensor> & save_mean_opt,const std::optional<Tensor> & save_var_opt,double epsilon)30 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
31     const Tensor& input, const Tensor& grad_output, const Tensor& weight, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
32     double epsilon) {
33   AT_ERROR("miopen_batch_norm_backward: ATen not compiled with MIOpen support");
34 }
35 
36 }}  // namespace at::native
37 
38 #else // AT_ROCM_ENABLED
39 
40 #include <ATen/miopen/Descriptors.h>
41 #include <ATen/miopen/Types.h>
42 #include <ATen/miopen/Utils.h>
43 
44 #include <ATen/TensorUtils.h>
45 
46 namespace at { namespace native {
47 
48 namespace {
49 
// Views `t` with shape {1, t.numel(), 1, ..., 1}. The shape has `dim` entries,
// but never fewer than two, so the element count always sits in position 1.
Tensor expandScale(const Tensor& t, int64_t dim) {
  const int64_t rank = dim > 2 ? dim : 2;
  std::vector<int64_t> shape(static_cast<size_t>(rank), 1);
  shape[1] = t.numel();
  return t.view(shape);
}
57 
58 }  // namespace
59 
miopen_batch_norm(const Tensor & input_t,const Tensor & weight_t,const std::optional<Tensor> & bias_t_opt,const std::optional<Tensor> & running_mean_t_opt,const std::optional<Tensor> & running_var_t_opt,bool training,double exponential_average_factor,double epsilon)60 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
61     const Tensor& input_t, const Tensor& weight_t, const std::optional<Tensor>& bias_t_opt, const std::optional<Tensor>& running_mean_t_opt, const std::optional<Tensor>& running_var_t_opt,
62     bool training, double exponential_average_factor, double epsilon)
63 {
64   // See [Note: hacky wrapper removal for optional tensor]
65   c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
66   const Tensor& bias_t = *bias_t_maybe_owned;
67   const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();});
68   const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();});
69 
70   TensorArg input{ input_t, "input", 1 },
71             weight{ weight_t, "weight", 2 },
72             bias{ bias_t, "bias", 3 },
73             running_mean{ running_mean_t, "running_mean", 4 },
74             running_var{ running_var_t, "running_var", 5 };
75   CheckedFrom c = "miopen_batch_norm";
76 
77   checkAllDefined(c, {input, weight, bias});
78   if (!training) {
79     checkAllDefined(c, {running_mean, running_var});
80   }
81   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
82   if (input->scalar_type() != ScalarType::Half) {
83     checkAllSameType(c, {input, weight});
84   }
85   checkAllSameType(c, {weight, bias, running_mean, running_var});
86   checkAllContiguous(c, {weight, bias, running_mean, running_var});
87   TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
88   checkDimRange(c, input, 2, 6 /* exclusive */);
89   auto num_features = input->size(1);
90   for (auto t : {weight, bias, running_mean, running_var}) {
91     if (t->defined()) {
92       checkNumel(c, t, num_features);
93     }
94   }
95 
96   miopenBatchNormMode_t mode;
97   if (input->dim() == 2) {
98     mode = miopenBNPerActivation;
99   } else {
100     mode = miopenBNSpatial;
101   }
102 
103   auto output_t = at::empty(input->sizes(), input->options());
104   TensorArg output{ output_t, "output", 0 };
105 
106   auto handle = getMiopenHandle();
107   auto dataType = getMiopenDataType(*input);
108   TensorDescriptor idesc{ *input, 4 };  // input descriptor
109   TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, bias, running_mean, etc.
110 
111   Constant one(dataType, 1);
112   Constant zero(dataType, 0);
113   Tensor save_mean, save_var;
114 
115   if (training) {
116     int64_t num_features = input_t.size(1);
117     save_mean = at::empty({ num_features }, weight_t.options());
118     save_var = at::empty({ num_features }, weight_t.options());
119     MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
120       handle, mode, &one, &zero,
121       idesc.desc(), input->const_data_ptr(),
122       idesc.desc(), output->data_ptr(),
123       wdesc.desc(),
124       // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs,
125       // not outputs. However, unfortunately the function signature only takes
126       // non-const pointers, presumably by accident
127       const_cast<void*>(weight->const_data_ptr()),
128       const_cast<void*>(bias->const_data_ptr()),
129       exponential_average_factor,
130       at::maybe_data_ptr(running_mean),
131       at::maybe_data_ptr(running_var),
132       epsilon,
133       save_mean.mutable_data_ptr(),
134       save_var.mutable_data_ptr()));
135   } else {
136     save_mean = at::empty({0}, weight_t.options());
137     save_var = at::empty({0}, weight_t.options());
138     MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
139       handle, mode, &one, &zero,
140       idesc.desc(), input->const_data_ptr(),
141       idesc.desc(), output->data_ptr(),
142       wdesc.desc(),
143       // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs,
144       // not outputs. However, unfortunately the function signature only takes
145       // non-const pointers, presumably by accident
146       const_cast<void*>(weight->const_data_ptr()),
147       const_cast<void*>(bias->const_data_ptr()),
148       running_mean->data_ptr(),
149       running_var->data_ptr(),
150       epsilon));
151   }
152 
153   // save_mean and save_var can be undefined
154   // If this causes problems, we can initialize them to empty tensors
155   // of the correct type
156   return std::tuple<Tensor, Tensor, Tensor>{output_t, save_mean, save_var};
157 }
158 
miopen_batch_norm_backward(const Tensor & input_t,const Tensor & grad_output_t,const Tensor & weight_t,const std::optional<Tensor> & running_mean_opt,const std::optional<Tensor> & running_var_opt,const std::optional<Tensor> & save_mean_t_opt,const std::optional<Tensor> & save_var_t_opt,double epsilon)159 std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
160     const Tensor& input_t,
161     const Tensor& grad_output_t,
162     const Tensor& weight_t,
163     // Unused: but we require them to be passed so that double backwards
164     // has access
165     const std::optional<Tensor>& running_mean_opt,
166     const std::optional<Tensor>& running_var_opt,
167     const std::optional<Tensor>& save_mean_t_opt,
168     const std::optional<Tensor>& save_var_t_opt,
169     double epsilon) {
170   // See [Note: hacky wrapper removal for optional tensor]
171   const Tensor& running_mean =
172       c10::value_or_else(running_mean_opt, [] { return Tensor(); });
173   const Tensor& running_var =
174       c10::value_or_else(running_var_opt, [] { return Tensor(); });
175   const Tensor& save_mean_t =
176       c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
177   const Tensor& save_var_t =
178       c10::value_or_else(save_var_t_opt, [] { return Tensor(); });
179 
180   TensorArg input{ input_t, "input", 1 },
181             grad_output{ grad_output_t, "grad_output", 2 },
182             weight{ weight_t, "weight", 3 },
183             save_mean{ save_mean_t, "save_mean", 4 },
184             save_var{ save_var_t, "save_var", 5 };
185   CheckedFrom c = "miopen_batch_norm_backward";
186 
187   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
188   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
189   if (input->scalar_type() == ScalarType::Half) {
190     checkScalarType(c, weight, ScalarType::Float);
191   } else {
192     checkAllSameType(c, {input, weight});
193   }
194   checkAllSameType(c, {input, grad_output});
195   checkAllSameType(c, {weight, save_mean, save_var});
196   checkAllContiguous(c, {input, grad_output, save_mean, save_var});
197   checkDimRange(c, input, 2, 6 /* exclusive */);
198   checkSameSize(c, input, grad_output);
199   auto num_features = input->size(1);
200   for (auto t : {weight, save_mean, save_var}) {
201     checkNumel(c, t, num_features);
202   }
203 
204   miopenBatchNormMode_t mode;
205   if (input->dim() == 2) {
206     mode = miopenBNPerActivation;
207   } else {
208     mode = miopenBNSpatial;
209   }
210 
211   auto grad_input_t  = at::empty(input->sizes(), input->options());
212   auto grad_weight_t = at::empty(weight->sizes(), weight->options());
213   auto grad_bias_t   = at::empty(weight->sizes(), weight->options());
214 
215   auto handle = getMiopenHandle();
216   auto dataType = getMiopenDataType(*input);
217 
218   TensorDescriptor idesc{ *input, 4 };  // input, output, grad_output descriptor
219   TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, bias, save_mean, etc.
220 
221   Constant one(dataType, 1);
222   Constant zero(dataType, 0);
223 
224   MIOPEN_CHECK(miopenBatchNormalizationBackward(
225     handle, mode, &one, &zero, &one, &zero,
226     idesc.desc(), input->const_data_ptr(),
227     idesc.desc(), grad_output->const_data_ptr(),
228     idesc.desc(), grad_input_t.data_ptr(),
229     wdesc.desc(), weight->const_data_ptr(),
230     grad_weight_t.data_ptr(),
231     grad_bias_t.data_ptr(),
232     epsilon,
233     save_mean->const_data_ptr(),
234     save_var->const_data_ptr()));
235 
236   return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
237 }
238 
239 }}  // namespace at::native
240 
241 #endif
242