/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

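// GatherNd gathers slices from params according to the index tuples in the
// innermost dimension of indices.  As an illustrative example (values not
// taken from this file): with params = [[a, b], [c, d]] and
// indices = [[1, 0], [0, 1]], the output is [c, b]; with indices = [[1]],
// the output is the full row [c, d].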
template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
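// For example, REGISTER_GATHER_ND_CPU(float) expands into two kernel
// registrations, GatherNdOp<CPUDevice, float, int32> and
// GatherNdOp<CPUDevice, float, int64>, one per supported index type.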

// TODO(ebrevdo): This is a pure data-movement kernel. It shouldn't be
// instantiated for all different types. Instead, all the types should
// be coalesced. So we should only have int8, int16, int32, int64 support.
// And float is redirected to int32, double is redirected to int64,
// and complex<float> is redirected to int32 with twice the number of
// entries, similarly for complex<double>.
//
// Same for the GPU kernel.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU

namespace functor {
template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
                  const Tensor& indices, Tensor* out) {
  if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
    return errors::InvalidArgument("params must be at least a vector");
  }
  if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) {
    return errors::InvalidArgument("indices must be at least a vector");
  }
  if (indices.dim_size(indices.dims() - 1) > params.dims()) {
    return errors::InvalidArgument(
        "index innermost dimension length must be <= params rank; saw: ",
        indices.dim_size(indices.dims() - 1), " vs. ", params.dims());
  }

  const TensorShape& indices_shape(indices.shape());
  const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);

  // Check that we have enough index space
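  // N_big is the number of index rows, i.e. the product of all but the last
  // dimension of indices.  It is accumulated in 64 bits and then checked
  // against the 32-bit int limit, since int indexing is used below.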
  int64 N_big = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_big *= indices_shape.dim_size(i);
  }
  if (N_big > std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(
        "indices has too many elements for int indexing: ", N_big, " > ",
        std::numeric_limits<int>::max());
  }
  if (params.NumElements() > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params.NumElements() too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params.NumElements(), " > ",
                                   std::numeric_limits<Index>::max());
  }

  // The result shape is
  //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
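  // For example (illustrative shapes only): params.shape = [P0, P1, P2] and
  // indices.shape = [N, 2] produce a result of shape [N, P2]; each of the N
  // index pairs selects one length-P2 slice of params.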
  Index N_result = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_result *= indices_shape.dim_size(i);
  }

  const TensorShape& params_shape(params.shape());
  Index total_nd = params_shape.dims();

  TensorShape result_shape(indices_shape);
  result_shape.RemoveLastDims(1);

  int64 slice_size_big = 1;
  for (Index i = indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
    result_shape.AddDim(params_shape.dim_size(i));
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  const Index slice_size = static_cast<Index>(slice_size_big);

  TF_RETURN_IF_ERROR(
      c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out));

  if (N_result > 0) {
    if (params_shape.num_elements() == 0) {
      return errors::InvalidArgument(
          "Requested more than 0 entries, but "
          "params is empty. Params shape: ",
          params_shape.DebugString());
    }

    auto indices_mat = indices.flat_inner_dims<Index>();

    Index bad_i = -1;

    // We are copying slices / subtensors.  View the output as a matrix with
    // one row per gathered slice and slice_size columns.
    auto out_mat = out->shaped<T, 2>({N_result, slice_size});
    Tensor scratch;
    TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch));
    auto scratch_scalar = scratch.scalar<int32>();

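    // Dispatch on indices_nd so that GatherNdSlice is instantiated with a
    // compile-time index depth IXDIM; only depths 0 through 7 have
    // specializations.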
    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                              \
  case IXDIM: {                                                         \
    functor::GatherNdSlice<Device, T, Index, IXDIM> func;               \
    auto params_flat = params.flat_outer_dims<T, IXDIM + 1>();          \
    bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                 params_flat, indices_mat, out_mat);                    \
  } break
      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 0 and 7 "
            "are currently supported. Requested rank: ",
            indices_nd);
    }

    // bad_i will only return >= 0 on CPUs right now.
    if (bad_i >= 0) {
      auto shape = indices.shape();
      shape.RemoveLastDims(1);
      return errors::InvalidArgument(
          "indices", SliceDebugString(shape, bad_i), " = [",
          str_util::Join(
              gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "),
          "] does not index into param shape ", params.shape().DebugString());
    }
  }
  return Status::OK();
}

}  // namespace functor

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)          \
  template <>                                                 \
  Index GatherNdSlice<GPUDevice, T, Index, NDIM>::operator()( \
      const GPUDevice& d, const Index slice_size,             \
      typename TTypes<int32>::Scalar Tscratch,                \
      typename TTypes<T, NDIM + 1>::ConstTensor Tparams,      \
      typename TTypes<Index>::ConstMatrix Tindices,           \
      typename TTypes<T>::Matrix Tout);                       \
  extern template struct GatherNdSlice<GPUDevice, T, Index, NDIM>;

#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_int32(REGISTER_GATHER_ND_GPU);
TF_CALL_int64(REGISTER_GATHER_ND_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
TF_CALL_complex128(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

}  // namespace tensorflow