• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
17 #define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
18 
19 // Specialization of GatherNdSlice to CPU
20 
21 #define EIGEN_USE_THREADS
22 
23 #include <atomic>
24 
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/kernels/gather_nd_op.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/mem.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/util.h"
34 
35 namespace tensorflow {
36 
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 
39 namespace generator {
40 
41 template <typename T, typename Index, int IXDIM>
42 class GatherNdSliceGenerator {
43  public:
GatherNdSliceGenerator(const Index slice_size,typename TTypes<Index>::ConstMatrix Tindices,typename TTypes<T,IXDIM+1>::ConstTensor Tparams,typename TTypes<T>::Matrix Tout,std::atomic<Index> * error_loc)44   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
45       const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
46       typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
47       typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
48       : slice_size_(slice_size),
49         Tindices_(Tindices),
50         Tparams_(Tparams),
51         Tout_(Tout),
52         error_loc_(error_loc) {}
53 
GenerateIndices(const Index loc,Eigen::array<Eigen::DenseIndex,IXDIM+1> * ix)54   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
55       const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
56     (*ix)[IXDIM] = 0;
57     bool out_of_bounds = false;
58     for (int i = 0; i < IXDIM; ++i) {
59       const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
60       (*ix)[i] = ix_i;
61       out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
62     }
63     return out_of_bounds;
64   }
65 
66   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
operator()67   operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
68     const Index loc = loc_array[0];
69     Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix;
70     Eigen::array<Eigen::DenseIndex, 2> ix_out;
71     ix_out[0] = loc;
72     ix_out[1] = 0;
73     const bool out_of_bounds = GenerateIndices(loc, &ix);
74     if (TF_PREDICT_FALSE(out_of_bounds)) {
75       error_loc_->store(loc);
76       std::fill_n(&Tout_(ix_out), slice_size_, T());
77     } else {
78       std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out));
79     }
80 
81     return static_cast<int32>(0);  // Return something...
82   }
83 
84  private:
85   const Index slice_size_;
86   const typename TTypes<Index>::ConstMatrix Tindices_;
87   const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_;
88   mutable typename TTypes<T>::Matrix Tout_;
89   std::atomic<Index>* error_loc_;
90 };
91 
92 }  // namespace generator
93 
94 namespace functor {
95 
96 template <typename T, typename Index, int IXDIM>
97 struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
98   Index operator()(const CPUDevice& d, const Index slice_size,
99                    typename TTypes<int32>::Scalar Tscratch,
100                    typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
101                    typename TTypes<Index>::ConstMatrix Tindices,
102                    typename TTypes<T>::Matrix Tout) {
103     std::atomic<Index> error_loc(-1);
104 
105     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
106 #if !defined(EIGEN_HAS_INDEX_LIST)
107     Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
108     Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
109 #else
110     Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
111     Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
112     broadcast_dims.set(0, batch_size);
113 #endif
114     generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
115         slice_size, Tindices, Tparams, Tout, &error_loc);
116 
117 #if defined(INTEL_MKL) && defined(ENABLE_MKL)
118 // Eigen implementation below is not highly performant. gather_nd_generator
119 // does not seem to be called in parallel, leading to very poor performance.
120 // Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
121 // needs to go through redundant operations like 'reshape', 'broadcast' and
122 // 'sum'. OpenMP loop below essentially does same thing as Eigen code, but
123 // is considerably more efficient.
124 #pragma omp parallel for
125     for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
126       const Eigen::array<Eigen::DenseIndex, 1> loc{i};
127       gather_nd_generator(loc);
128     }
129 #else   // INTEL_MKL && ENABLE_MKL
130     Tscratch.device(d) = Tscratch.reshape(reshape_dims)
131                              .broadcast(broadcast_dims)
132                              .generate(gather_nd_generator)
133                              .sum();
134 #endif  // INTEL_MKL && ENABLE_MKL
135 
136     // error_loc() returns -1 if there's no out-of-bounds index,
137     // otherwise it returns the location of an OOB index in Tindices.
138     return error_loc.load();
139   }
140 };
141 
142 #define REGISTER_GATHER_ND_FULL(T, Index)                                     \
143   template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>::     \
144   operator()(const CPUDevice& d, const Index slice_size,                      \
145              typename TTypes<int32>::Scalar Tscratch,                         \
146              typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \
147              typename TTypes<Index>::ConstMatrix Tindices,                    \
148              typename TTypes<T>::Matrix Tout);
149 
150 #define REGISTER_GATHER_ND_CPU(type)    \
151   REGISTER_GATHER_ND_FULL(type, int32); \
152   REGISTER_GATHER_ND_FULL(type, int64)
153 
154 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
155 
156 }  // namespace functor
157 
158 }  // namespace tensorflow
159 
160 #endif  // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
161