• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
18 
19 // Functor definitions for Reduction ops, must be compilable by nvcc.
20 
#include <iostream>
#include <limits>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
25 
26 namespace tensorflow {
27 namespace functor {
28 
// Tag type that routes a reduction to the mean specialization of
// ReduceEigenImpl. It carries no reduction logic itself: the mean is computed
// with Eigen's SumReducer and then divided by the reduction factor on the fly.
template <typename Scalar>
struct MeanReducer {
  // Starting value of the underlying sum: zero in the requested scalar type.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
35 
// Tag type that routes a reduction to the l2-norm specializations of
// ReduceEigenImpl, which compute sqrt(sum(x * conj(x))).
template <typename Scalar>
struct EuclideanNormReducer {
  // Starting value of the sum of squares: zero in the requested scalar type.
  Scalar initialize() const { return static_cast<Scalar>(0); }
};
41 
// Generic reduction: evaluates `in.reduce(reduction_axes, reducer)` on device
// `d` and assigns the result to `out`. Reducers that need post-processing of
// the reduced value have their own partial specializations.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    // Build the reduction expression first, then evaluate it into `out`.
    auto reduced = in.reduce(reduction_axes, reducer);
    out.device(d) = reduced;
  }
};
50 
51 template <typename Device, typename OUT_T, typename IN_T,
52           typename ReductionAxes, typename Scalar>
53 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
54                        functor::MeanReducer<Scalar>> {
55   void operator()(const Device& d, OUT_T out, IN_T in,
56                   const ReductionAxes& reduction_axes,
57                   const functor::MeanReducer<Scalar>& reducer) {
58     static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
59     Eigen::internal::SumReducer<Scalar> sum_reducer;
60     out.device(d) = in.reduce(reduction_axes, sum_reducer) /
61                     static_cast<Scalar>(in.size() / out.size());
62   }
63 };
64 
65 // TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
66 // controlled by an attribute.
67 template <typename Device, typename OUT_T, typename IN_T,
68           typename ReductionAxes, typename Scalar>
69 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
70                        functor::EuclideanNormReducer<Scalar>> {
71   void operator()(const Device& d, OUT_T out, IN_T in,
72                   const ReductionAxes& reduction_axes,
73                   const functor::EuclideanNormReducer<Scalar>& reducer) {
74     static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
75     Eigen::internal::SumReducer<Scalar> sum_reducer;
76     out.device(d) =
77         (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
78   }
79 };
80 
81 template <typename Device, typename OUT_T, typename IN_T,
82           typename ReductionAxes>
83 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
84                        functor::EuclideanNormReducer<bfloat16>> {
85   void operator()(const Device& d, OUT_T out, IN_T in,
86                   const ReductionAxes& reduction_axes,
87                   const functor::EuclideanNormReducer<bfloat16>& reducer) {
88     static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
89     Eigen::internal::SumReducer<float> sum_reducer;
90     auto in_as_float = in.template cast<float>();
91     out.device(d) = (in_as_float * in_as_float.conjugate())
92                         .reduce(reduction_axes, sum_reducer)
93                         .sqrt()
94                         .template cast<bfloat16>();
95   }
96 };
97 
// Identity element of a reducer: by default whatever Reducer::initialize()
// returns. Specific reducers override this via explicit specialization.
template <typename Reducer>
struct Identity {
  static auto identity(const Reducer& r) -> decltype(r.initialize()) {
    return r.initialize();
  }
};
106 
107 // MeanReducer is a special case, since it doesn't technically have an identity.
108 // Thus, ideally we'd return nan.  However, mean is instantiated for integer
109 // types as well, so we do the nan override only for floating point types.
110 #define FIX_MEAN_IDENTITY(T)                            \
111   template <>                                           \
112   struct Identity<functor::MeanReducer<T>> {            \
113     static T identity(const functor::MeanReducer<T>&) { \
114       return Eigen::NumTraits<T>::quiet_NaN();          \
115     }                                                   \
116   };
117 FIX_MEAN_IDENTITY(Eigen::half)
118 FIX_MEAN_IDENTITY(float)
119 FIX_MEAN_IDENTITY(double)
120 FIX_MEAN_IDENTITY(complex64)
121 FIX_MEAN_IDENTITY(complex128)
122 #undef FIX_MEAN_IDENTITY
123 
124 template <typename Device, typename OUT_T, typename Reducer>
125 void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
126   out.device(d) = out.constant(Identity<Reducer>::identity(reducer));
127 }
128 
129 template <typename Device, typename Reducer>
130 struct ReduceFunctor {
131   template <typename OUT_T, typename IN_T, typename ReductionAxes>
132   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
133                      const ReductionAxes& reduction_axes,
134                      const Reducer& reducer);
135 
136   template <typename OUT_T>
137   static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
138 };
139 
140 }  // namespace functor
141 }  // namespace tensorflow
142 
143 #endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
144