/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// TODO(ezhulenev): Use CUB reductions on GPU.
template <typename T, typename U>
struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;
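    // The 4D tensors are processed below as 2D [rest_size, depth] views, so
    // per-channel reductions become a single sum over the first (rest) axis.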

    // Allocate two temporary workspaces of [depth] shape.
    Tensor scratch1_vec, scratch2_vec;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch1_vec));
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch2_vec));

    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());

    const GPUDevice& d = context->eigen_device<GPUDevice>();

    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduction_axis{0};
    Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
#endif
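    // one_by_depth and rest_by_one broadcast per-channel [depth] vectors
    // across the rest_size dimension of the [rest_size, depth] views.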

    // offset_backprop = sum(y_backprop)
    // scale_backprop =
    //     sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth *
         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
};

template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;

template <class T>
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
                                  typename TTypes<T>::Flat out) {
  To32Bit(out).device(d) =
      To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
}

template class SetNanFunctor<float>;

// -------------------------------------------------------------------------- //
// FusedBatchNormInferenceFunctor implementation.                             //
// -------------------------------------------------------------------------- //

// Generic kernel that does all computations by converting the input to the U
// data type. We use it when the CUDA architecture doesn't have fast arithmetic
// for the T data type (e.g. no fp16 support on older GPU generations).
template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode,
          bool is_generic_kernel>
struct FusedBatchNormInferenceKernel {
  static_assert(tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW,
                "Unsupported data format");

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ side_input, float epsilon,
                             T* __restrict__ out) {
    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    while (index < count) {
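      // Channel index: the innermost dimension for NHWC, or the dimension
      // preceding the inner spatial block of size inner_dim_size for NCHW.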
      const int channel = (tensor_format == FORMAT_NHWC)
                              ? index % channels_size
                              : (index / inner_dim_size) % channels_size;

      U in_v = U(in[index]);
      U scale_v = scale[channel];
      U offset_v = offset[channel];
      U mean_v = mean[channel];
      U var_v = var[channel];

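      // Inference-mode batch norm: out = (x - mean) * rsqrt(var + epsilon) *
      // scale + offset, with the scale and shift fused into a multiply-add.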
      U scaling_factor_v = rsqrt(var_v + epsilon) * scale_v;
      static_assert(std::is_same<U, float>::value, "U data type must be float");
      U shifted_v = fmaf(in_v - mean_v, scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v += U(side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = T(shifted_v);
      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        out[index] = T(shifted_v < U(0) ? U(0) : shifted_v);
      }

      index += total_device_threads;
    }
  }
};

// Specialization for T=Eigen::half and U=float.
template <TensorFormat tensor_format, bool add_side_input,
          FusedBatchNormActivationMode activation_mode>
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
                                     add_side_input, activation_mode,
                                     /*is_generic_kernel=*/false> {
#if TENSORFLOW_USE_ROCM
  using IT = __half;
#else
  using IT = Eigen::half;
#endif
  using T = Eigen::half;
  using U = float;

  // If the CUDA architecture doesn't support fast fp16 computation, we fall
  // back to the generic kernel defined above.
  using GenericKernel =
      FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                    activation_mode,
                                    /*is_generic_kernel=*/true>;

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ _in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ _side_input, float epsilon,
                             T* __restrict__ _out) {
    // Old GPUs do not have (or have very slow) fp16 arithmetic.
#if (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
    const IT* in = reinterpret_cast<const IT*>(_in);
    const IT* side_input = reinterpret_cast<const IT*>(_side_input);
    IT* out = reinterpret_cast<IT*>(_out);

    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    int32 half2_count = count >> 1;

    half epsilon_h = __float2half(epsilon);
    half2 epsilon_h2 = __float2half2_rn(epsilon);

    const int32 max_channel_size = channels_size - 1;

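    // Process two consecutive fp16 values per iteration with half2 vector
    // arithmetic; an odd trailing element is handled separately below.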
    while (index < half2_count) {
      int32 channel[2];
      if (tensor_format == FORMAT_NHWC) {
        channel[0] = (2 * index) % channels_size;
        channel[1] = channel[0] == max_channel_size ? 0 : channel[0] + 1;
      } else {
        channel[0] = ((2 * index) / inner_dim_size) % channels_size;
        channel[1] = ((2 * index + 1) / inner_dim_size) % channels_size;
      }

      half2 in_v = reinterpret_cast<const half2*>(in)[index];
      half2 scale_v = __floats2half2_rn(scale[channel[0]], scale[channel[1]]);
      half2 offset_v =
          __floats2half2_rn(offset[channel[0]], offset[channel[1]]);
      half2 mean_v = __floats2half2_rn(mean[channel[0]], mean[channel[1]]);
      half2 var_v = __floats2half2_rn(var[channel[0]], var[channel[1]]);

      half2 scaling_factor_v =
          __hmul2(h2rsqrt(__hadd2(var_v, epsilon_h2)), scale_v);
      half2 shifted_v =
          __hfma2(__hsub2(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd2(shifted_v,
                            reinterpret_cast<const half2*>(side_input)[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        reinterpret_cast<half2*>(out)[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
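        // __hgt2 yields 1.0 or 0.0 per half lane, so multiplying by the mask
        // implements ReLU without branching.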
        const half2 kZeroH = __float2half2_rn(0.f);
        const half2 mask_h = __hgt2(shifted_v, kZeroH);
        reinterpret_cast<half2*>(out)[index] = __hmul2(mask_h, shifted_v);
      }

      index += total_device_threads;
    }

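    // If count is odd, the last element cannot be packed into a half2; the
    // thread whose index lands exactly on half2_count handles it with scalar
    // half arithmetic.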
    if ((count & 0x1) == 1 && index == half2_count) {
      index = count - 1;

      const int32 channel = (tensor_format == FORMAT_NHWC)
                                ? index % channels_size
                                : (index / inner_dim_size) % channels_size;

      half in_v = in[index];
      half scale_v = __float2half(scale[channel]);
      half offset_v = __float2half(offset[channel]);
      half mean_v = __float2half(mean[channel]);
      half var_v = __float2half(var[channel]);

      half scaling_factor_v = __hmul(hrsqrt(__hadd(var_v, epsilon_h)), scale_v);
      half shifted_v = __hfma(__hsub(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd(shifted_v, side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        const half kZeroH = __float2half(0.f);
        const half mask_h = __hgt(shifted_v, kZeroH);
        out[index] = __hmul(mask_h, shifted_v);
      }
    }

#else
    GenericKernel::run(count, channels_size, inner_dim_size, _in, scale, offset,
                       mean, var, _side_input, epsilon, _out);
#endif  // (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
  }
};

template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode>
__global__ void FusedBatchNormInferenceMetaKernel(
    int32 count, int32 channels_size, int32 inner_dim_size, const T* in,
    const U* scale, const U* offset, const U* mean, const U* var,
    const T* side_input, float epsilon, T* out) {
  // We prefer to run the non-generic specialization for the given types T
  // and U.
  FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                activation_mode,
#if TENSORFLOW_USE_ROCM
                                false
#else
                                // TODO(b/135435976): Temporarily disable the
                                // non-generic kernel implementation.
                                /*is_generic_kernel=*/true
#endif
                                >::run(count, channels_size, inner_dim_size, in,
                                       scale, offset, mean, var, side_input,
                                       epsilon, out);
}

template <typename T, typename U>
struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, TensorFormat tensor_format,
                  typename TTypes<T, 4>::ConstTensor in,
                  typename TTypes<U>::ConstVec scale,
                  typename TTypes<U>::ConstVec offset,
                  typename TTypes<U>::ConstVec estimated_mean,
                  typename TTypes<U>::ConstVec estimated_variance,
                  typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
                  FusedBatchNormActivationMode activation_mode,
                  typename TTypes<T, 4>::Tensor out) {
    const auto& d = context->eigen_device<GPUDevice>();

    const int32 count = out.size();
    if (count == 0) return;

    bool launched = false;
#if TENSORFLOW_USE_ROCM
    constexpr int32 kThreadInBlock = 1024;
#else
    constexpr int32 kThreadInBlock = 512;
#endif

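    // LAUNCH picks a fixed-block-size launch configuration (halving the
    // element count for Eigen::half, since that kernel packs two values per
    // thread) and launches the meta kernel for one combination of data
    // format, side input, and activation.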
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE,          \
               INNER_DIM_SIZE)                                                 \
  launched = true;                                                             \
                                                                               \
  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(                   \
      std::is_same<T, Eigen::half>::value ? Eigen::divup(count, 2) : count, d, \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      0, kThreadInBlock);                                                      \
                                                                               \
  TF_CHECK_OK(GpuLaunchKernel(                                                 \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      config.block_count, config.thread_per_block, 0, d.stream(), count,       \
      CHANNEL_SIZE, INNER_DIM_SIZE, in.data(), scale.data(), offset.data(),    \
      estimated_mean.data(), estimated_variance.data(), side_input.data(),     \
      epsilon, out.data()));

    const bool no_side_input = side_input.dimensions().TotalSize() == 0;
    const bool add_side_input = side_input.dimensions().TotalSize() != 0;

    using Activation = FusedBatchNormActivationMode;
    const bool no_activation = activation_mode == Activation::kIdentity;
    const bool relu_activation = activation_mode == Activation::kRelu;

    if (tensor_format == FORMAT_NHWC) {
      const int c = in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kIdentity, c, 1);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kRelu, c, 1);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kIdentity, c, 1);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kRelu, c, 1);
      }

    } else if (tensor_format == FORMAT_NCHW) {
      const int c = in.dimensions()[1];
      const int inner = in.dimensions()[2] * in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kIdentity, c, inner);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kRelu, c, inner);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kIdentity, c, inner);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kRelu, c, inner);
      }
    }
#undef LAUNCH

    OP_REQUIRES(context, launched,
                errors::InvalidArgument("Unsupported launch configuration"));
  }
};

template struct FusedBatchNormInferenceFunctor<GPUDevice, float, float>;
template struct FusedBatchNormInferenceFunctor<GPUDevice, Eigen::half, float>;

}  // namespace functor
}  // namespace tensorflow

#else

#include "tensorflow/core/kernels/fused_batch_norm_op.h"

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM