/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
// Functor definition for BatchNormOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

25 // Functor used by BatchNormOp to do the computations.
26 template <typename Device, typename T>
27 struct BatchNorm {
operatorBatchNorm28   void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
29                   typename TTypes<T>::ConstVec mean,
30                   typename TTypes<T>::ConstVec var,
31                   typename TTypes<T>::ConstVec beta,
32                   typename TTypes<T>::ConstVec gamma, T variance_epsilon,
33                   bool scale_after_normalization,
34                   typename TTypes<T, 4>::Tensor output) {
35     const int depth = mean.dimension(0);
36     const int rest_size = input.size() / depth;
37 
38     Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
39 #if !defined(EIGEN_HAS_INDEX_LIST)
40     Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
41     Eigen::DSizes<int, 2> one_by_depth(1, depth);
42     Eigen::DSizes<int, 2> depth_by_one(depth, 1);
43 #else
44     Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
45     rest_by_one.set(0, rest_size);
46     Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
47     one_by_depth.set(1, depth);
48     Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
49     depth_by_one.set(0, depth);
50 #endif
51     if (scale_after_normalization) {
52       output.reshape(rest_by_depth).device(d) =
53           (input.reshape(rest_by_depth) -
54            mean.reshape(one_by_depth).broadcast(rest_by_one)) *
55               ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
56                   .eval()
57                   .reshape(one_by_depth)
58                   .broadcast(rest_by_one) +
59           beta.reshape(one_by_depth).broadcast(rest_by_one);
60     } else {
61       output.reshape(rest_by_depth).device(d) =
62           (input.reshape(rest_by_depth) -
63            mean.reshape(one_by_depth).broadcast(rest_by_one)) *
64               ((var + var.constant(variance_epsilon)).rsqrt())
65                   .eval()
66                   .reshape(one_by_depth)
67                   .broadcast(rest_by_one) +
68           beta.reshape(one_by_depth).broadcast(rest_by_one);
69     }
70   }
71 };
// Functor used by BatchNormGradOp to compute the gradients of batch
// normalization, y = gamma * (x - m) * rsqrt(v + epsilon) + beta,
// with respect to the input (dx), mean (dm), variance (dv), beta (db)
// and gamma (dg), given the incoming gradient out_backprop.
template <typename Device, typename T>
struct BatchNormGrad {
  // d: device (CPU/GPU) on which the Eigen expressions are evaluated.
  // input: 4-D forward-pass input x; total size a multiple of depth.
  // mean, var, gamma: per-channel vectors of length depth.
  // out_backprop: gradient of the loss w.r.t. the forward output, same
  //   total size as input.
  // variance_epsilon: small constant added to var for numerical stability.
  // scale_after_normalization: whether the forward pass multiplied by gamma.
  // dx, dm, dv, db, dg: outputs (gradients), written in full.
  // scratch1, scratch2: caller-provided temporaries of length depth;
  //   their contents are clobbered.
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  T variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;

    // All 4-D tensors are viewed as [rest_size, depth] matrices; per-channel
    // vectors broadcast along the first dimension and reductions collapse it.
    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<Index, 2> one_by_depth(1, depth);
    Eigen::array<Index, 1> reduction_axis;
    reduction_axis[0] = 0;  // Reduces on first dimension.
#else
    // IndexList variants encode the constant dimensions at compile time.
    Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#endif

    // db = sum_over_rest(out_backprop)
    //
    // dg = sum_over_rest(out_backprop * ((x - m) * rsqrt(v + epsilon)))
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-rsqrt(v + epsilon))
    //      (NOTE(review): the formula as implemented below; the gamma
    //      factor is dropped when !scale_after_normalization.)
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    if (scale_after_normalization) {
      // The .eval() computes the per-channel scale once before broadcasting.
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      // Uses db (already the sum over rest) instead of re-reducing.
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // Reuse scratch1 (previously rsqrt(v + epsilon)) for the dv factor:
    // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_