/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_

// Specialization of GatherNdSlice to CPU

#define EIGEN_USE_THREADS

#include <atomic>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace generator {

// For each row `loc` of Tindices, copies a contiguous slice of `slice_size`
// elements from Tparams at the indexed location into row `loc` of Tout. If
// any index component is out of bounds, the output row is zero-filled and
// `loc` is recorded in `error_loc`.
template <typename T, typename Index, int IXDIM>
class GatherNdSliceGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
      const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
      typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
      typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
      : slice_size_(slice_size),
        Tindices_(Tindices),
        Tparams_(Tparams),
        Tout_(Tout),
        error_loc_(error_loc) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
      const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
    (*ix)[IXDIM] = 0;
    bool out_of_bounds = false;
    for (int i = 0; i < IXDIM; ++i) {
      const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
      (*ix)[i] = ix_i;
      out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
    }
    return out_of_bounds;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
  operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
    const Index loc = loc_array[0];
    Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix;
    Eigen::array<Eigen::DenseIndex, 2> ix_out;
    ix_out[0] = loc;
    ix_out[1] = 0;
    const bool out_of_bounds = GenerateIndices(loc, &ix);
    if (TF_PREDICT_FALSE(out_of_bounds)) {
      error_loc_->store(loc);
      std::fill_n(&Tout_(ix_out), slice_size_, T());
    } else {
      std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out));
    }

    return static_cast<int32>(0);  // Return something; the value is unused.
  }

 private:
  const Index slice_size_;
  const typename TTypes<Index>::ConstMatrix Tindices_;
  const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_;
  mutable typename TTypes<T>::Matrix Tout_;
  std::atomic<Index>* error_loc_;
};

}  // namespace generator

namespace functor {

template <typename T, typename Index, int IXDIM>
struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
  Index operator()(const CPUDevice& d, const Index slice_size,
                   typename TTypes<int32>::Scalar Tscratch,
                   typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Matrix Tout) {
    std::atomic<Index> error_loc(-1);

    const Eigen::DenseIndex batch_size = Tindices.dimension(0);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
    Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
#else
    Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
    Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(0, batch_size);
#endif
    generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
        slice_size, Tindices, Tparams, Tout, &error_loc);

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
    // The Eigen implementation below is not highly performant: the
    // gather_nd_generator does not appear to be called in parallel, leading to
    // very poor performance. Additionally, since it uses a scalar (Tscratch)
    // to invoke 'generate', it has to go through redundant operations like
    // 'reshape', 'broadcast' and 'sum'. The OpenMP loop below does essentially
    // the same thing as the Eigen code, but is considerably more efficient.
#pragma omp parallel for
    for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
      const Eigen::array<Eigen::DenseIndex, 1> loc{i};
      gather_nd_generator(loc);
    }
#else   // INTEL_MKL && ENABLE_MKL
    Tscratch.device(d) = Tscratch.reshape(reshape_dims)
                             .broadcast(broadcast_dims)
                             .generate(gather_nd_generator)
                             .sum();
#endif  // INTEL_MKL && ENABLE_MKL

    // error_loc is -1 if there's no out-of-bounds index; otherwise it holds
    // the location of an out-of-bounds index row in Tindices.
    return error_loc.load();
  }
};

#define REGISTER_GATHER_ND_FULL(T, Index)                                     \
  template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>::     \
  operator()(const CPUDevice& d, const Index slice_size,                      \
             typename TTypes<int32>::Scalar Tscratch,                         \
             typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \
             typename TTypes<Index>::ConstMatrix Tindices,                    \
             typename TTypes<T>::Matrix Tout);

#define REGISTER_GATHER_ND_CPU(type)    \
  REGISTER_GATHER_ND_FULL(type, int32); \
  REGISTER_GATHER_ND_FULL(type, int64)

TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_