1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ 18 19 // Functor definitions for Reduction ops, must be compilable by nvcc. 20 21 #include <iostream> 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 26 namespace tensorflow { 27 namespace functor { 28 29 // Dummy class used for template specialization for mean reduction, which is 30 // accomplished by SumReducer and on-the-fly division by the reduction factor. 31 template <typename Scalar> 32 struct MeanReducer { initializeMeanReducer33 Scalar initialize() const { return Scalar(0); } 34 }; 35 36 // Dummy class used for template specialization for l2-norm reduction. 37 template <typename Scalar> 38 struct EuclideanNormReducer { initializeEuclideanNormReducer39 Scalar initialize() const { return Scalar(0); } 40 }; 41 42 template <typename Device, typename OUT_T, typename IN_T, 43 typename ReductionAxes, typename Reducer> 44 struct ReduceEigenImpl { operatorReduceEigenImpl45 void operator()(const Device& d, OUT_T out, IN_T in, 46 const ReductionAxes& reduction_axes, const Reducer& reducer) { 47 out.device(d) = in.reduce(reduction_axes, reducer); 48 } 49 }; 50 51 template <typename Device, typename OUT_T, typename IN_T, 52 typename ReductionAxes, typename Scalar> 53 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, 54 functor::MeanReducer<Scalar>> { 55 void operator()(const Device& d, OUT_T out, IN_T in, 56 const ReductionAxes& reduction_axes, 57 const functor::MeanReducer<Scalar>& reducer) { 58 static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, ""); 59 Eigen::internal::SumReducer<Scalar> sum_reducer; 60 out.device(d) = in.reduce(reduction_axes, sum_reducer) / 61 static_cast<Scalar>(in.size() / out.size()); 62 } 63 }; 64 65 // TODO(rmlarsen): Refactor this such that taking the sqrt can be optional 66 // controlled by an attribute. 67 template <typename Device, typename OUT_T, typename IN_T, 68 typename ReductionAxes, typename Scalar> 69 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, 70 functor::EuclideanNormReducer<Scalar>> { 71 void operator()(const Device& d, OUT_T out, IN_T in, 72 const ReductionAxes& reduction_axes, 73 const functor::EuclideanNormReducer<Scalar>& reducer) { 74 static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, ""); 75 Eigen::internal::SumReducer<Scalar> sum_reducer; 76 out.device(d) = 77 (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt(); 78 } 79 }; 80 81 template <typename Device, typename OUT_T, typename IN_T, 82 typename ReductionAxes> 83 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, 84 functor::EuclideanNormReducer<bfloat16>> { 85 void operator()(const Device& d, OUT_T out, IN_T in, 86 const ReductionAxes& reduction_axes, 87 const functor::EuclideanNormReducer<bfloat16>& reducer) { 88 static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, ""); 89 Eigen::internal::SumReducer<float> sum_reducer; 90 auto in_as_float = in.template cast<float>(); 91 out.device(d) = (in_as_float * in_as_float.conjugate()) 92 .reduce(reduction_axes, sum_reducer) 93 .sqrt() 94 .template cast<bfloat16>(); 95 } 96 }; 97 98 // For most reducers, the identity is Reducer::initialize() 99 template <typename Reducer> 100 struct Identity { 101 static auto identity(const Reducer& reducer) 102 -> decltype(reducer.initialize()) { 103 return reducer.initialize(); 104 } 105 }; 106 107 // MeanReducer is a special case, since it doesn't technically have an identity. 108 // Thus, ideally we'd return nan. However, mean is instantiated for integer 109 // types as well, so we do the nan override only for floating point types. 110 #define FIX_MEAN_IDENTITY(T) \ 111 template <> \ 112 struct Identity<functor::MeanReducer<T>> { \ 113 static T identity(const functor::MeanReducer<T>&) { \ 114 return Eigen::NumTraits<T>::quiet_NaN(); \ 115 } \ 116 }; 117 FIX_MEAN_IDENTITY(Eigen::half) 118 FIX_MEAN_IDENTITY(float) 119 FIX_MEAN_IDENTITY(double) 120 FIX_MEAN_IDENTITY(complex64) 121 FIX_MEAN_IDENTITY(complex128) 122 #undef FIX_MEAN_IDENTITY 123 124 template <typename Device, typename OUT_T, typename Reducer> 125 void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) { 126 out.device(d) = out.constant(Identity<Reducer>::identity(reducer)); 127 } 128 129 template <typename Device, typename Reducer> 130 struct ReduceFunctor { 131 template <typename OUT_T, typename IN_T, typename ReductionAxes> 132 static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, 133 const ReductionAxes& reduction_axes, 134 const Reducer& reducer); 135 136 template <typename OUT_T> 137 static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer); 138 }; 139 140 } // namespace functor 141 } // namespace tensorflow 142 143 #endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ 144