/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/batch_norm_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

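// BatchNormOp implements the forward pass of BatchNormWithGlobalNormalization:
// roughly output = (input - mean) * rsqrt(var + variance_epsilon) * gamma + beta
// when scale_after_normalization is true, and the same expression without the
// gamma factor otherwise. See ../ops/nn_ops.cc for the authoritative op spec.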
template <typename Device, typename T>
class BatchNormOp : public OpKernel {
 public:
  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& beta = context->input(3);
    const Tensor& gamma = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

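    // Delegate the element-wise normalization to the device-specific Eigen
    // functor declared in batch_norm_op.h.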
    functor::BatchNorm<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
        scale_after_normalization_, output->tensor<T, 4>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

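// BatchNormGradOp computes the gradients of BatchNormWithGlobalNormalization
// with respect to the input (dx), mean (dm), variance (dv), beta (db), and
// gamma (dg), given the incoming gradient out_backprop.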
template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
 public:
  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& gamma = context->input(3);
    const Tensor& out_backprop = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional",
                                        out_backprop.shape().DebugString()));

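    // Where possible, reuse one of the forwarded input buffers for the
    // gradient outputs instead of allocating fresh memory.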
    Tensor* dx = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0, 4}, 0, input.shape(), &dx));
    Tensor* dm = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 1, mean.shape(), &dm));
    Tensor* dv = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 2, var.shape(), &dv));
    Tensor* db = nullptr;
    if (scale_after_normalization_) {
      OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
    } else {
      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {3}, 3, mean.shape(), &db));
    }
    Tensor* dg = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

    // Scratch buffer of [depth] dimension, aka the 4th dimension of input,
    // which is dim_size(3), for calculating various combinations of
    // (var + epsilon).
    Tensor scratch1;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch1));

    // Scratch buffer of [depth] dimension for saving intermediate calculation
    // values.
    Tensor scratch2;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch2));

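    // Compute all five gradients (plus the scratch intermediates) in a single
    // fused call to the device-specific Eigen functor.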
    functor::BatchNormGrad<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
        scratch1.vec<T>(), scratch2.vec<T>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

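// Registration of the CPU implementations.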
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  void BatchNorm<GPUDevice, T>::operator()(                                  \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,          \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,   \
      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \
      T variance_epsilon, bool scale_after_normalization,                    \
      typename TTypes<T, 4>::Tensor output);                                 \
  extern template struct BatchNorm<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_SYCL)                 \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

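// Registration of the CPU implementation of the gradient op.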
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                \
  template <>                                                              \
  void BatchNormGrad<GPUDevice, T>::operator()(                            \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,        \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \
      typename TTypes<T>::ConstVec gamma,                                  \
      typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon, \
      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,    \
      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,              \
      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,              \
      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2); \
  extern template struct BatchNormGrad<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_SYCL)                     \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow