/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/batch_norm_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
class BatchNormOp : public OpKernel {
 public:
  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& beta = context->input(3);
    const Tensor& gamma = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
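
    // The functor (see batch_norm_op.h) performs the normalization on the
    // selected device. Per channel it computes, roughly,
    //   output = (input - mean) * rsqrt(var + variance_epsilon) * gamma + beta
    // when scale_after_normalization is true, and omits the gamma factor
    // otherwise.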
    functor::BatchNorm<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
        scale_after_normalization_, output->tensor<T, 4>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
 public:
  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& gamma = context->input(3);
    const Tensor& out_backprop = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional",
                                        out_backprop.shape().DebugString()));
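
    // Five gradient outputs are produced: dx (w.r.t. the input), dm (mean),
    // dv (variance), db (beta), and dg (gamma).
    // forward_input_or_allocate_output lets the runtime reuse the buffer of
    // one of the listed forwardable inputs as the output where possible,
    // falling back to a fresh allocation otherwise.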
    Tensor* dx = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0, 4}, 0, input.shape(), &dx));
    Tensor* dm = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 1, mean.shape(), &dm));
    Tensor* dv = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 2, var.shape(), &dv));
    Tensor* db = nullptr;
    if (scale_after_normalization_) {
      OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
    } else {
      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {3}, 3, mean.shape(), &db));
    }
    Tensor* dg = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

    // Scratch buffer of [depth] dimension, aka the 4th dimension of input,
    // which is dim_size(3), for calculating various combinations of
    // (var + epsilon).
    Tensor scratch1;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch1));

    // Scratch buffer of [depth] dimension for saving intermediate calculation
    // values.
    Tensor scratch2;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch2));

    functor::BatchNormGrad<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
        scratch1.vec<T>(), scratch2.vec<T>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),              \
                          BatchNormOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
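
// The GPU code paths below follow the usual pattern: the functor
// specializations are only declared here and are defined in the op's GPU
// (.cu.cc) translation unit. Note that the GPU kernels are registered for
// half and float only; double is registered on CPU (and SYCL) but not on GPU.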
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void BatchNorm<GPUDevice, T>::operator()(                                   \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,           \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,    \
      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma,  \
      T variance_epsilon, bool scale_after_normalization,                     \
      typename TTypes<T, 4>::Tensor output);                                  \
  extern template struct BatchNorm<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                      \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T"),              \
                          BatchNormOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
                              .Device(DEVICE_SYCL)                  \
                              .TypeConstraint<T>("T"),              \
                          BatchNormOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")  \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<T>("T"),                  \
                          BatchNormGradOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  void BatchNormGrad<GPUDevice, T>::operator()(                              \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,          \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,   \
      typename TTypes<T>::ConstVec gamma,                                    \
      typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon,   \
      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,      \
      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,                \
      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,                \
      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2);   \
  extern template struct BatchNormGrad<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")  \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<T>("T"),                  \
                          BatchNormGradOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")  \
                              .Device(DEVICE_SYCL)                      \
                              .TypeConstraint<T>("T"),                  \
                          BatchNormGradOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow