/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/scatter_nd_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Returns true if the three tensors have a valid number of elements.
// If shape_input has 0 elements, then indices and updates must also have
// exactly 0 elements; otherwise we should error. If indices has 0 elements,
// then updates should also have 0 elements; otherwise we should error.
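//
// Illustrative element-count combinations (derived from the rules above):
//   num_inputs == 0, num_indices == 0, num_updates == 0  -> true
//   num_inputs != 0, num_indices == 0, num_updates == 0  -> true (no-op)
//   num_inputs == 0, num_indices != 0                    -> false
//   num_indices == 0, num_updates != 0                   -> false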
bool ValidEmptyOutputShape(int64 num_inputs, int64 num_indices,
                           int64 num_updates) {
  if (num_indices == 0 && num_updates == 0) {
    return true;  // regardless of num_inputs ?= 0, covers both cases
  }
  // now we want all 3 tensors to have values
  return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
}

template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& updates = c->input(1);
    const Tensor& shape_input = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    auto vec = shape_input.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(c,
                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape_input.NumElements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

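    // The checks below enforce the ScatterNd shape contract. As an
    // illustrative example (not from the original source): with
    // indices.shape = [4, 2], updates.shape = [4, 3] and shape = [5, 6, 3],
    // outer_dims = 1 and ix = 2, so each of the 4 update rows is a slice of
    // size 3 written at the [i, j] position given by the matching index row.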
    const int64 outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
                  errors::InvalidArgument(
                      "Outer dimensions of indices and update must match. "
                      "Indices shape: ",
                      indices.shape().DebugString(),
                      ", updates shape:", updates.shape().DebugString()));
    }

    const int64 ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(
        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
        errors::InvalidArgument("Inner dimensions of output shape must match "
                                "inner dimensions of updates shape. Output: ",
                                shape.DebugString(),
                                " updates: ", updates.shape().DebugString()));
    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument(
              "The inner ", shape.dims() - ix,
              " dimensions of output.shape=", shape.DebugString(),
              " must match the inner ", updates.shape().dims() - outer_dims,
              " dimensions of updates.shape=", updates.shape().DebugString()));
    }
    OP_REQUIRES(c, shape_input.dims() == 1,
                errors::InvalidArgument("Shape must be a vector"));

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
               c, indices, updates, shape, &out, true /*allocate*/));
    c->set_output(0, out);
  }
};

template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class TensorScatterOp : public OpKernel {
 public:
  explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(0);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    TensorShape shape = input.shape();

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape.num_elements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

    const int64 outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
                  errors::InvalidArgument(
                      "Outer dimensions of indices and update must match. "
                      "Indices shape: ",
                      indices.shape().DebugString(),
                      ", updates shape:", updates.shape().DebugString()));
    }

    const int64 ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(
        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
        errors::InvalidArgument("Inner dimensions of output shape must match "
                                "inner dimensions of updates shape. Output: ",
                                shape.DebugString(),
                                " updates: ", updates.shape().DebugString()));
    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument(
              "The inner ", shape.dims() - ix,
              " dimensions of output.shape=", shape.DebugString(),
              " must match the inner ", updates.shape().dims() - outer_dims,
              " dimensions of updates.shape=", updates.shape().DebugString()));
    }

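    // Try to reuse the input buffer for the output. If the input cannot be
    // forwarded (e.g. it is still referenced elsewhere), fall back to
    // allocating a fresh output and copying the input into it before
    // scattering.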
    std::unique_ptr<Tensor> forwarded_input = c->forward_input(
        0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());

    if (forwarded_input == nullptr) {
      // We were not able to forward the input, so we deep copy the tensor and
      // set the output.
      Tensor* out;
      OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out));

      OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
                                                    input, out));
      OP_REQUIRES_OK(c,
                     functor::DoScatterNd<Device, T, Index, op>(
                         c, indices, updates, shape, out, false /*allocate*/));
    } else {
      // Output forwarded, so simply perform the scatter.
      OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
                            c, indices, updates, shape, forwarded_input.get(),
                            false /*allocate*/));

      c->set_output(0, *forwarded_input);
    }
  }
};

template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    dtype_ = c->input_type(0);
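    // The first input may be a resource variable handle, a legacy ref tensor,
    // or a plain dense tensor; only the ref path reads the "use_locking" attr
    // here.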
    if (c->input_type(0) == DT_RESOURCE) {
      // TODO(apassos): what to validate here?
    } else if (IsRefType(c->input_type(0))) {
      OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else {
      OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      core::ScopedUnref scoped_unref(v);
      OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
      mutex_lock m(*v->mu());
      DoCompute(c);
    } else if (use_exclusive_lock_) {
      // If we're here, it means the input type is a ref.
      DCHECK(IsRefType(c->input_dtype(0)));
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  DataType dtype_;
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    Tensor params;
    TensorShape params_shape;

    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      Tensor* t = v->tensor();
      params = *t;
      params_shape = params.shape();
    } else if (IsRefType(c->input_dtype(0))) {
      params = c->mutable_input(0, use_exclusive_lock_);
      params_shape = params.shape();
      c->forward_ref_input_to_ref_output(0, 0);
      OP_REQUIRES(c, params.IsInitialized(),
                  errors::FailedPrecondition("Null ref for params"));
    } else {
      Tensor* params_ptr;
      params_shape = c->input(0).shape();
      if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
                                                 &params_ptr)) {
        // We weren't able to forward the input to output, so just
        // allocate a new output tensor and copy the values over.
        OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
        params = *params_ptr;
        functor::DenseUpdate<Device, T, ASSIGN> copy;
        const Tensor& input_copy = c->input(0);
        copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
      } else {
        params = *params_ptr;
      }
    }

    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, op>(
               c, indices, updates, params_shape, &params, false /*allocate*/));
  }
};

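// The macros below stamp out kernel registrations for each (type, index type,
// device) combination; the *_KERNEL wrappers expand to one registration with
// int32 indices and one with int64 indices.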
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
  REGISTER_KERNEL_BUILDER(Name(name)                                  \
                              .Device(DEVICE_##dev)                   \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("shape"),                   \
                          ScatterNdOp<dev##Device, type, index_type>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
                                                op)                          \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_##dev)                                              \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
                                                         dev, name, op)    \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name(name)                                                            \
          .Device(DEVICE_##dev)                                             \
          .TypeConstraint<type>("T")                                        \
          .TypeConstraint<index_type>("Tindices")                           \
          .HostMemory("ref"),                                               \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_ND_KERNEL(type, dev, name)         \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
                                                   op);                    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_SCATTER_ND_ADD_SUB(type, dev)                            \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",            \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",            \
                                    scatter_nd_op::UpdateOp::SUB);        \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
      type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);   \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
      type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND(type, dev) \
  REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");

#define REGISTER_SCATTER_ND_UPDATE(type, dev)                         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate",     \
                                    scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                         \
      type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, CPU);

#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, CPU);

#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_string(REGISTER_SCATTER_ND_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
                                                          dev)              \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate")                       \
                              .Device(DEVICE_##dev)                         \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<index_type>("Tindices"),      \
                          TensorScatterOp<dev##Device, type, index_type,    \
                                          scatter_nd_op::UpdateOp::ASSIGN>)

#define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd")                            \
                              .Device(DEVICE_##dev)                           \
                              .TypeConstraint<type>("T")                      \
                              .TypeConstraint<index_type>("Tindices"),        \
                          TensorScatterOp<dev##Device, type, index_type,      \
                                          scatter_nd_op::UpdateOp::ADD>)

#define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterSub")                            \
                              .Device(DEVICE_##dev)                           \
                              .TypeConstraint<type>("T")                      \
                              .TypeConstraint<index_type>("Tindices"),        \
                          TensorScatterOp<dev##Device, type, index_type,      \
                                          scatter_nd_op::UpdateOp::SUB>)

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_CPU(type)   \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
  REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type);    \
  REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type);

// Register TensorScatterUpdate/Add/Sub for all number types.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
// Register only TensorScatterUpdate for string/bool types as well.
TF_CALL_string(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);

#undef REGISTER_SCATTER_ND_TENSOR_CPU

// Registers GPU kernels.
#if GOOGLE_CUDA

#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, GPU);

#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, GPU);

#define REGISTER_SCATTER_ND_ALL_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB_GPU(type);  \
  REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
  REGISTER_SCATTER_ND_GPU(type);

TF_CALL_int32(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);

#undef REGISTER_SCATTER_ND_ALL_GPU

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
  REGISTER_SCATTER_ND_UPDATE(type, SYCL);

TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_GPU(type)   \
  REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type);    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
  REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
#undef REGISTER_SCATTER_ND_TENSOR_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_CPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
#undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
#undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU

#endif  // GOOGLE_CUDA

namespace functor {
// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
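//
// For example (illustrative values only): indices.shape = [2, 3, 2] and
// params_shape = [4, 5, 6] give batch_dim = 2 and slice_dim = 2, so a valid
// updates.shape is indices.shape[:2] + params_shape[2:] = [2, 3, 6].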
Status ValidateUpdateShape(const TensorShape& params_shape,
                           const Tensor& indices, const Tensor& updates) {
  const int64 slice_dim =
      (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
  const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "params_shape[slice_dim:], got updates.shape: ",
        updates.shape().DebugString(),
        ", indices.shape: ", indices.shape().DebugString(),
        ", params_shape: ", params_shape.DebugString(),
        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
  };

  if (updates.dims() < batch_dim) return shape_err();
  if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
  }
  for (int d = 0; d < updates.dims() - batch_dim; ++d) {
    if (updates.dim_size(d + batch_dim) !=
        params_shape.dim_size(d + slice_dim)) {
      return shape_err();
    }
  }
  return Status::OK();
}

template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
                                const Tensor& indices, const Tensor& updates,
                                int64* slice_dim, Index* num_updates,
                                Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
    return errors::InvalidArgument("Output must be at least 1-D, ",
                                   "got shape: ", params_shape.DebugString());
  }

  if (!ValidEmptyOutputShape(params_shape.num_elements(),
                             indices_shape.num_elements(),
                             updates_shape.num_elements())) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output. indices shape: ",
        indices.shape().DebugString());
  }

  if (updates.dim_size(0) != indices.dim_size(0)) {
    return errors::InvalidArgument(
        "The outermost dimension of updates and indices ",
        "must match. Got indices.shape ", indices_shape.DebugString(),
        ", updates.shape ", updates_shape.DebugString());
  }
  TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));

  // Check that we have enough index space
  const int64 N_big = indices.NumElements();
  if (N_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("indices has too many elements for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", N_big, " > ",
                                   std::numeric_limits<Index>::max());
  }
  if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params_shape[0] too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params_shape.dim_size(0),
                                   " > ", std::numeric_limits<Index>::max());
  }

  // Calculate the number of dimensions in indices
  *slice_dim = (indices_shape.dims() > 1)
                   ? indices_shape.dim_size(indices_shape.dims() - 1)
                   : 1;

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
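  // For example (illustrative): with params_shape = [4, 5, 6] and
  // *slice_dim = 1, each update writes one slice of 5 * 6 = 30 elements.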
  Index total_nd = params_shape.dims();

  int64 slice_size_big = 1;
  for (int64 i = *slice_dim; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  *slice_size = static_cast<Index>(slice_size_big);

  const int64 safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
  *num_updates = indices_shape.num_elements() / safe_slice_dim;

  return Status::OK();
}

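// IndexFlattener produces a 2-D [num index rows, slice_dim] view of the
// indices tensor. The generic version is a zero-copy reinterpretation via
// flat_inner_dims; the SYCL specialization below first copies the indices
// from device to host memory before building the view.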
template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
 public:
  IndexFlattener() { indices_host_ = nullptr; }
  ~IndexFlattener() { delete[] indices_host_; }

  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext* c, const Tensor& indices) {
    size_t num_indices = indices.NumElements();
    indices_host_ = new Index[num_indices];
    auto device = c->eigen_sycl_device();
    auto size = sizeof(Index) * num_indices;
    auto src_ptr = GetBase(&indices);
    device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
                              size);
    return typename TTypes<Index, 2>::ConstTensor(
        indices_host_, indices.shape().AsEigenDSizes<2>());
  }

 private:
  Index* indices_host_;
};
#endif

template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape, Tensor* out,
                   bool allocate) {
  int64 slice_dim;
  Index num_updates;
  Index slice_size;
  TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
      shape, indices, updates, &slice_dim, &num_updates, &slice_size));

  IndexFlattener<Device, Index> index_flattener;
  auto indices_flat = index_flattener(c, indices);
  auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

  if (allocate) {
    TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
  } else {
    CHECK_NOTNULL(out);
  }

  if (shape.num_elements() == 0) {
    return Status::OK();
  }

  if (allocate) {
    // Brand new tensor, zero it out.
    functor::SetZeroFunctor<Device, T> fill;
    fill(c->eigen_device<Device>(), out->flat<T>());
  }
  auto output_matrix =
      out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});

  Index bad_i = -1;

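  // Dispatch on slice_dim (the size of indices.shape[-1]) to a functor
  // specialization with a compile-time IXDIM. The functor returns the
  // position of an out-of-range index row (or -1 if none was detected),
  // which is reported in the error below.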
  if (shape.num_elements() > 0) {
    switch (slice_dim) {
#define PARAMS_CASE(IXDIM)                                                  \
  case IXDIM: {                                                             \
    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
    for (int i = 0; i < IXDIM; ++i) {                                       \
      output_shape_prefix[i] = shape.dim_size(i);                           \
    }                                                                       \
    functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
    bad_i =                                                                 \
        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
                output_matrix, indices_flat, updates_flat, output_matrix);  \
  } break
      // TODO(simister): Re-enable this once binary size is under control.
      // PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported. Requested rank: ",
            slice_dim);
    }
  }
  if (bad_i >= 0) {
    auto slice_shape = indices.shape();
    slice_shape.RemoveLastDims(1);
    return errors::InvalidArgument(
        "indices", SliceDebugString(slice_shape, bad_i), " = [",
        str_util::Join(
            gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
        "] does not index into shape ", shape.DebugString());
  }
  return Status::OK();
}
}  // namespace functor

#ifdef GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)           \
  template <>                                                           \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(   \
      const GPUDevice& d, const Index slice_size,                       \
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
      typename TTypes<T, 2>::Tensor Tparams,                            \
      typename TTypes<Index, 2>::ConstTensor Tindices,                  \
      typename TTypes<T, 2>::ConstTensor Tupdates,                      \
      typename TTypes<T, 2>::Tensor Toutput);                           \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow