/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/gather_nd_op.h"

#include <limits>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

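// GatherNdOp gathers slices of `params` into an output tensor, treating each
// row along the innermost dimension of `indices` as a coordinate tuple into
// `params`.  A hypothetical example of the shapes involved (illustration
// only, not taken from this file):
//   params:  shape [4, 5, 6]
//   indices: shape [2, 2]   -- each row is a (d0, d1) coordinate
//   output:  shape [2, 6]   -- output[i, :] == params[indices(i, 0), indices(i, 1), :]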
template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// TODO(ebrevdo): This is a pure data-movement kernel. It shouldn't be
// instantiated for all different types. Instead, all the types should
// be coalesced. So we should only have int8, int16, int32, int64 support.
// And float is redirected to int32, double is redirected to int64,
// and complex<float> is redirected to int32 with twice the number of
// entries, similarly for complex<double>.
//
// Same for the GPU kernel.
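//
// A hypothetical sketch of the coalescing described above: since this kernel
// only moves bytes, a gather over complex64 params could be lowered to a
// gather over int32 params with each slice twice as long (one complex64
// occupies the space of two int32 entries).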
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU

namespace functor {
template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
                  const Tensor& indices, Tensor* out) {
  if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
    return errors::InvalidArgument("params must be at least a vector");
  }
  if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) {
    return errors::InvalidArgument("indices must be at least a vector");
  }
  if (indices.dim_size(indices.dims() - 1) > params.dims()) {
    return errors::InvalidArgument(
        "index innermost dimension length must be <= params rank; saw: ",
        indices.dim_size(indices.dims() - 1), " vs. ", params.dims());
  }

  const TensorShape& indices_shape(indices.shape());
  const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);

  // Check that we have enough index space
  int64 N_big = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_big *= indices_shape.dim_size(i);
  }
  if (N_big > std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(
        "indices has too many elements for int indexing: ", N_big, " > ",
        std::numeric_limits<int>::max());
  }
  if (params.NumElements() > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params.NumElements() too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params.NumElements(), " > ",
                                   std::numeric_limits<Index>::max());
  }

  // The result shape is
  //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
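  //
  // For example, with hypothetical shapes (not taken from any caller):
  //   indices.shape = [2, 3, 2], params.shape = [4, 5, 6]
  //   => indices_nd == 2, result shape == [2, 3, 6]
  //      (6 result rows, each a slice of 6 elements).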
  Index N_result = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_result *= indices_shape.dim_size(i);
  }

  const TensorShape& params_shape(params.shape());
  Index total_nd = params_shape.dims();

  TensorShape result_shape(indices_shape);
  result_shape.RemoveLastDims(1);

  int64 slice_size_big = 1;
  for (Index i = indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
    result_shape.AddDim(params_shape.dim_size(i));
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  const Index slice_size = static_cast<Index>(slice_size_big);

  TF_RETURN_IF_ERROR(
      c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out));

  if (N_result > 0) {
    if (params_shape.num_elements() == 0) {
      return errors::InvalidArgument(
          "Requested more than 0 entries, but "
          "params is empty.  Params shape: ",
          params_shape.DebugString());
    }

    auto indices_mat = indices.flat_inner_dims<Index>();

    Index bad_i = -1;

    // Request to copy slices / subtensors.
    // View the output as a matrix whose rows are slices of slice_size
    // elements each.
    auto out_mat = out->shaped<T, 2>({N_result, slice_size});
    Tensor scratch;
    TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch));
    auto scratch_scalar = scratch.scalar<int32>();

    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                              \
  case IXDIM: {                                                         \
    functor::GatherNdSlice<Device, T, Index, IXDIM> func;               \
    auto params_flat = params.flat_outer_dims<T, IXDIM + 1>();          \
    bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                 params_flat, indices_mat, out_mat);                    \
  } break
      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported.  Requested rank: ",
            indices_nd);
    }

    // bad_i will only return >= 0 on CPUs right now.
    if (bad_i >= 0) {
      auto shape = indices.shape();
      shape.RemoveLastDims(1);
      return errors::InvalidArgument(
          "indices", SliceDebugString(shape, bad_i), " = [",
          str_util::Join(
              gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "),
          "] does not index into param shape ", params.shape().DebugString());
    }
  }
  return Status::OK();
}

}  // namespace functor

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)          \
  template <>                                                 \
  Index GatherNdSlice<GPUDevice, T, Index, NDIM>::operator()( \
      const GPUDevice& d, const Index slice_size,             \
      typename TTypes<int32>::Scalar Tscratch,                \
      typename TTypes<T, NDIM + 1>::ConstTensor Tparams,      \
      typename TTypes<Index>::ConstMatrix Tindices,           \
      typename TTypes<T>::Matrix Tout);                       \
  extern template struct GatherNdSlice<GPUDevice, T, Index, NDIM>;

#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_int32(REGISTER_GATHER_ND_GPU);
TF_CALL_int64(REGISTER_GATHER_ND_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
TF_CALL_complex128(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

}  // namespace tensorflow