1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/state_ops.cc.
17 #define EIGEN_USE_THREADS
18
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #define EIGEN_USE_GPU
21 #include "tensorflow/core/platform/stream_executor.h"
22 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23
24 #include "tensorflow/core/framework/bounds_check.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/kernels/dense_update_functor.h"
31 #include "tensorflow/core/kernels/fill_functor.h"
32 #include "tensorflow/core/kernels/inplace_ops_functor.h"
33 #include "tensorflow/core/kernels/scatter_nd_op.h"
34 #include "tensorflow/core/kernels/scatter_nd_util.h"
35 #include "tensorflow/core/kernels/training_op_helpers.h"
36 #include "tensorflow/core/kernels/variable_ops.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/util/determinism.h"
41 #include "tensorflow/core/util/util.h"
42
43 namespace tensorflow {
44
45 typedef Eigen::ThreadPoolDevice CPUDevice;
46 typedef Eigen::GpuDevice GPUDevice;
47
48 // Returns true if the three tensors have a valid number of elements.
49 // If shape_input has 0 elements, then indices and updates must also have
50 // exactly 0 elements; otherwise we should error. If indices has 0 elements,
51 // then updates must also have 0 elements; otherwise we should error.
52 bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices,
53 int64_t num_updates) {
54 if (num_indices == 0 && num_updates == 0) {
55 return true;  // valid regardless of whether num_inputs == 0; covers both cases
56 }
57 // Otherwise, all three tensors must be non-empty.
58 return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
59 }
60
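// Implements the ScatterNd op: allocates a zero-initialized output of the
// requested shape and scatters `updates` into it at `indices`. Duplicate
// indices are accumulated, since the functor is invoked with UpdateOp::ADD.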
61 template <typename Device, typename T, typename Index>
62 class ScatterNdOp : public OpKernel {
63 public:
64 explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
65 const DataType dt = DataTypeToEnum<T>::v();
66 const DataType index_t = DataTypeToEnum<Index>::v();
67 OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
68 }
69
70 void Compute(OpKernelContext* c) override {
71 const Tensor& indices = c->input(0);
72 const Tensor& updates = c->input(1);
73 const Tensor& shape_input = c->input(2);
74
75 OP_REQUIRES(c, indices.shape().dims() >= 1,
76 errors::InvalidArgument(
77 "Indices shape must have rank at least one. Found:",
78 indices.shape().DebugString()));
79 OP_REQUIRES(c, updates.shape().dims() >= 1,
80 errors::InvalidArgument(
81 "Updates shape must have rank at least one. Found:",
82 updates.shape().DebugString()));
83
84 auto vec = shape_input.flat<Index>();
85 TensorShape shape;
86 OP_REQUIRES_OK(c,
87 TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
88
89 OP_REQUIRES(c,
90 ValidEmptyOutputShape(shape_input.NumElements(),
91 indices.shape().num_elements(),
92 updates.shape().num_elements()),
93 errors::InvalidArgument(
94 "Indices and updates specified for empty output shape"));
95
96 const int64_t outer_dims = indices.shape().dims() - 1;
97
98 for (int i = 0; i < outer_dims; ++i) {
99 OP_REQUIRES(
100 c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
101 errors::InvalidArgument(
102 "Dimensions [0,", outer_dims,
103 ") of indices[shape=", indices.shape().DebugString(),
104 "] must match dimensions [0,", outer_dims,
105 ") of updates[shape=", updates.shape().DebugString(), "]"));
106 }
107
108 const int64_t ix = indices.shape().dim_size(outer_dims);
109 OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix,
110 errors::InvalidArgument(
111 "Dimensions [", ix, ",", shape.dims(), ") of input[shape=",
112 shape.DebugString(), "] must match dimensions [",
113 outer_dims, ",", updates.shape().dims(),
114 ") of updates[shape=", updates.shape().DebugString(), "]"));
115
116 for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
117 OP_REQUIRES(
118 c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
119 errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(),
120 ") of input[shape=", shape.DebugString(),
121 "] must match dimensions [", outer_dims, ",",
122 updates.shape().dims(), ") of updates[shape=",
123 updates.shape().DebugString(), "]"));
124 }
125 OP_REQUIRES(c, shape_input.dims() == 1,
126 errors::InvalidArgument("Shape must be a vector"));
127
128 Tensor out;
129 OP_REQUIRES_OK(
130 c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
131 c, indices, updates, shape, &out, true /*allocate*/));
132 c->set_output(0, out);
133 }
134 };
135
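// Implements the TensorScatter* family of ops (Update/Add/Sub/Min/Max,
// selected via the `op` template parameter): the input tensor is forwarded or
// deep-copied to the output, and `updates` are then scattered into it at
// `indices`.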
136 template <typename Device, typename T, typename Index,
137 scatter_nd_op::UpdateOp op>
138 class TensorScatterOp : public OpKernel {
139 public:
140 explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
141 const DataType dt = DataTypeToEnum<T>::v();
142 const DataType index_t = DataTypeToEnum<Index>::v();
143 OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
144 }
145
146 void Compute(OpKernelContext* c) override {
147 const Tensor& input = c->input(0);
148 const Tensor& indices = c->input(1);
149 const Tensor& updates = c->input(2);
150
151 OP_REQUIRES(c, indices.shape().dims() >= 1,
152 errors::InvalidArgument(
153 "Indices shape must have rank at least one. Found:",
154 indices.shape().DebugString()));
155 OP_REQUIRES(c, updates.shape().dims() >= 1,
156 errors::InvalidArgument(
157 "Updates shape must have rank at least one. Found:",
158 updates.shape().DebugString()));
159
160 TensorShape shape = input.shape();
161
162 OP_REQUIRES(c,
163 ValidEmptyOutputShape(shape.num_elements(),
164 indices.shape().num_elements(),
165 updates.shape().num_elements()),
166 errors::InvalidArgument(
167 "Indices and updates specified for empty output shape"));
168
169 const int64_t outer_dims = indices.shape().dims() - 1;
170
171 for (int i = 0; i < outer_dims; ++i) {
172 OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
173 errors::InvalidArgument(
174 "Outer dimensions of indices and update must match. "
175 "Indices shape: ",
176 indices.shape().DebugString(),
177 ", updates shape:", updates.shape().DebugString()));
178 }
179
180 const int64_t ix = indices.shape().dim_size(outer_dims);
181 OP_REQUIRES(
182 c, updates.shape().dims() - outer_dims == shape.dims() - ix,
183 errors::InvalidArgument("Inner dimensions of output shape must match "
184 "inner dimensions of updates shape. Output: ",
185 shape.DebugString(),
186 " updates: ", updates.shape().DebugString()));
187 for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
188 OP_REQUIRES(
189 c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
190 errors::InvalidArgument(
191 "The inner ", shape.dims() - ix,
192 " dimensions of output.shape=", shape.DebugString(),
193 " must match the inner ", updates.shape().dims() - outer_dims,
194 " dimensions of updates.shape=", updates.shape().DebugString()));
195 }
196
197 AllocatorAttributes alloc_attr;
198 MemoryType memory_type = DEVICE_MEMORY;
199 if (std::is_same<Device, CPUDevice>::value) {
200 alloc_attr.set_on_host(true);
201 memory_type = HOST_MEMORY;
202 } else {
203 memory_type = DEVICE_MEMORY;
204 }
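// Try to reuse the input buffer for the output so the scatter can be done
// in place; if forwarding fails, fall back to the deep copy below.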
205 std::unique_ptr<Tensor> forwarded_input =
206 c->forward_input(0, 0, input.dtype(), shape, memory_type, alloc_attr);
207
208 if (forwarded_input == nullptr) {
209 // We were not able to forward the input, so we deep copy the tensor and
210 // set the output.
211 Tensor* out;
212 OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out));
213
214 OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
215 input, out));
216 OP_REQUIRES_OK(c,
217 functor::DoScatterNd<Device, T, Index, op>(
218 c, indices, updates, shape, out, false /*allocate*/));
219 } else {
220 // Output forwarded, so simply perform the scatter.
221 OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
222 c, indices, updates, shape, forwarded_input.get(),
223 false /*allocate*/));
224
225 c->set_output(0, *forwarded_input);
226 }
227 }
228 };
229
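// Implements the in-place scatter ops (ScatterNdUpdate/Add/Sub/Min/Max and
// their Resource* variants): applies `op` with `updates` at `indices` to a
// ref, resource variable, or dense input, copying the dense input first when
// it cannot be forwarded to the output.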
230 template <typename Device, typename T, typename Index,
231 scatter_nd_op::UpdateOp op>
232 class ScatterNdUpdateOp : public OpKernel {
233 public:
234 explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
235 const DataType dt = DataTypeToEnum<T>::v();
236 const DataType dt_ref = DataTypeToEnum<T>::ref();
237 const DataType index_t = DataTypeToEnum<Index>::v();
238 dtype_ = c->input_type(0);
239 // If we are updating a resource, we always use the exclusive lock.
240 // For ref types, we lock based on the use_locking parameter.
241 // Otherwise, we don't mutate the input tensor (we copy-on-write if needed).
242 if (c->input_type(0) == DT_RESOURCE) {
243 // TODO(apassos): what to validate here?
244 } else if (IsRefType(c->input_type(0))) {
245 OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
246 OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
247 } else {
248 OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
249 use_exclusive_lock_ = false;
250 }
251 }
252
253 void Compute(OpKernelContext* c) override {
254 if (dtype_ == DT_RESOURCE) {
255 core::RefCountPtr<Var> v;
256 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
257 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
258 mutex_lock m(*v->mu());
259 DoCompute(c);
260 } else if (use_exclusive_lock_) {
261 // If we're here, it means the input type is a ref.
262 DCHECK(IsRefType(c->input_dtype(0)));
263 // Hold mutex while we apply updates
264 mutex_lock l(*c->input_ref_mutex(0));
265 DoCompute(c);
266 } else {
267 DoCompute(c);
268 }
269 }
270
271 private:
272 DataType dtype_;
273 bool use_exclusive_lock_;
274
275 void DoCompute(OpKernelContext* c) {
276 const Tensor& indices = c->input(1);
277 const Tensor& updates = c->input(2);
278 Tensor params;
279 TensorShape params_shape;
280
281 if (dtype_ == DT_RESOURCE) {
282 core::RefCountPtr<Var> v;
283 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
284 Tensor* t = v->tensor();
285 params = *t;
286 params_shape = params.shape();
287 } else if (IsRefType(c->input_dtype(0))) {
288 params = c->mutable_input(0, use_exclusive_lock_);
289 params_shape = params.shape();
290 c->forward_ref_input_to_ref_output(0, 0);
291 OP_REQUIRES(c, params.IsInitialized(),
292 errors::FailedPrecondition("Null ref for params"));
293 } else {
294 Tensor* params_ptr;
295 params_shape = c->input(0).shape();
296 if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
297 &params_ptr)) {
298 // We weren't able to forward the input to output, so just
299 // allocate a new output tensor and copy the values over.
300 OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
301 params = *params_ptr;
302 functor::DenseUpdate<Device, T, ASSIGN> copy;
303 const Tensor& input_copy = c->input(0);
304 copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
305 } else {
306 params = *params_ptr;
307 }
308 }
309
310 OP_REQUIRES_OK(
311 c, functor::DoScatterNd<Device, T, Index, op>(
312 c, indices, updates, params_shape, &params, false /*allocate*/));
313 }
314 };
315
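// Kernel registration macros. The *_INT32_GPU variants register int32 kernels
// on DEVICE_DEFAULT with all tensors pinned to host memory, so they run the
// CPU implementation (int32 tensors are kept in host memory on GPU devices).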
316 #define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
317 REGISTER_KERNEL_BUILDER(Name(name) \
318 .Device(DEVICE_##dev) \
319 .TypeConstraint<type>("T") \
320 .TypeConstraint<index_type>("Tindices") \
321 .HostMemory("shape"), \
322 ScatterNdOp<dev##Device, type, index_type>)
323
324 #define REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(index_type, name) \
325 REGISTER_KERNEL_BUILDER(Name(name) \
326 .Device(DEVICE_DEFAULT) \
327 .TypeConstraint<int32>("T") \
328 .TypeConstraint<index_type>("Tindices") \
329 .HostMemory("indices") \
330 .HostMemory("updates") \
331 .HostMemory("shape") \
332 .HostMemory("output"), \
333 ScatterNdOp<CPUDevice, int32, index_type>)
334
335 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
336 op) \
337 REGISTER_KERNEL_BUILDER( \
338 Name(name) \
339 .Device(DEVICE_##dev) \
340 .TypeConstraint<type>("T") \
341 .TypeConstraint<index_type>("Tindices"), \
342 ScatterNdUpdateOp<dev##Device, type, index_type, op>)
343
344 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, name, \
345 op) \
346 REGISTER_KERNEL_BUILDER(Name(name) \
347 .Device(DEVICE_DEFAULT) \
348 .TypeConstraint<int32>("T") \
349 .TypeConstraint<index_type>("Tindices") \
350 .HostMemory("ref") \
351 .HostMemory("indices") \
352 .HostMemory("updates") \
353 .HostMemory("output_ref"), \
354 ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
355
356 #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU( \
357 index_type, name, op) \
358 REGISTER_KERNEL_BUILDER(Name(name) \
359 .Device(DEVICE_DEFAULT) \
360 .TypeConstraint<int32>("T") \
361 .TypeConstraint<index_type>("Tindices") \
362 .HostMemory("input") \
363 .HostMemory("indices") \
364 .HostMemory("updates") \
365 .HostMemory("output"), \
366 ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
367
368 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
369 dev, name, op) \
370 REGISTER_KERNEL_BUILDER( \
371 Name(name) \
372 .Device(DEVICE_##dev) \
373 .TypeConstraint<type>("T") \
374 .TypeConstraint<index_type>("Tindices") \
375 .HostMemory("ref"), \
376 ScatterNdUpdateOp<dev##Device, type, index_type, op>)
377
378 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, \
379 name, op) \
380 REGISTER_KERNEL_BUILDER(Name(name) \
381 .Device(DEVICE_DEFAULT) \
382 .TypeConstraint<int32>("T") \
383 .TypeConstraint<index_type>("Tindices") \
384 .HostMemory("ref") \
385 .HostMemory("indices") \
386 .HostMemory("updates"), \
387 ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
388
389 #define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
390 REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
391 REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64_t, dev, name)
392
393 #define REGISTER_SCATTER_ND_KERNEL_INT32_GPU(name) \
394 REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int32, name); \
395 REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int64_t, name)
396
397 #define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
398 REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
399 REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)
400
401 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \
402 REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
403 REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)
404
405 #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU(name, op) \
406 REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, \
407 op); \
408 REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, \
409 name, op)
410
411 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
412 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
413 op); \
414 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)
415
416 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op) \
417 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
418 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)
419
420 #define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
421 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
422 scatter_nd_op::UpdateOp::ADD); \
423 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
424 scatter_nd_op::UpdateOp::ADD); \
425 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
426 scatter_nd_op::UpdateOp::SUB); \
427 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
428 type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \
429 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
430 type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);
431
432 #define REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU() \
433 REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU( \
434 "ScatterNdNonAliasingAdd", scatter_nd_op::UpdateOp::ADD); \
435 REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdAdd", \
436 scatter_nd_op::UpdateOp::ADD); \
437 REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdSub", \
438 scatter_nd_op::UpdateOp::SUB); \
439 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
440 "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD); \
441 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
442 "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);
443
444 #define REGISTER_SCATTER_ND(type, dev) \
445 REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
446
447 #define REGISTER_SCATTER_ND_INT32_GPU() \
448 REGISTER_SCATTER_ND_KERNEL_INT32_GPU("ScatterNd");
449
450 #define REGISTER_SCATTER_ND_UPDATE(type, dev) \
451 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
452 scatter_nd_op::UpdateOp::ASSIGN); \
453 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
454 type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
455
456 #define REGISTER_SCATTER_ND_UPDATE_INT32_GPU() \
457 REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
458 "ScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); \
459 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
460 "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
461
462 #define REGISTER_SCATTER_ND_MIN_MAX(type, dev) \
463 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax", \
464 scatter_nd_op::UpdateOp::MAX); \
465 REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin", \
466 scatter_nd_op::UpdateOp::MIN); \
467 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
468 type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
469 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
470 type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
471
472 #define REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU() \
473 REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMax", \
474 scatter_nd_op::UpdateOp::MAX); \
475 REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMin", \
476 scatter_nd_op::UpdateOp::MIN); \
477 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
478 "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
479 REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU( \
480 "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
481
482 // Registers CPU kernels.
483 #define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
484 REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
485
486 #define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
487 REGISTER_SCATTER_ND_UPDATE(type, CPU);
488
489 #define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \
490 REGISTER_SCATTER_ND_MIN_MAX(type, CPU);
491
492 #define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
493 #define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);
494
495 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
496 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
497 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
498 TF_CALL_tstring(REGISTER_SCATTER_ND_CPU);
499 TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
500 TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
501 TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
502 TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
503 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU);
504
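// Registrations for the TensorScatter* ops, which scatter into a copy of the
// input tensor rather than updating it in place.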
505 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
506 dev) \
507 REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \
508 .Device(DEVICE_##dev) \
509 .TypeConstraint<type>("T") \
510 .TypeConstraint<index_type>("Tindices"), \
511 TensorScatterOp<dev##Device, type, index_type, \
512 scatter_nd_op::UpdateOp::ASSIGN>)
513
514 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(index_type) \
515 REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate") \
516 .Device(DEVICE_DEFAULT) \
517 .TypeConstraint<int32>("T") \
518 .TypeConstraint<index_type>("Tindices") \
519 .HostMemory("tensor") \
520 .HostMemory("indices") \
521 .HostMemory("updates") \
522 .HostMemory("output"), \
523 TensorScatterOp<CPUDevice, int32, index_type, \
524 scatter_nd_op::UpdateOp::ASSIGN>)
525
526 #define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \
527 REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \
528 .Device(DEVICE_##dev) \
529 .TypeConstraint<type>("T") \
530 .TypeConstraint<index_type>("Tindices"), \
531 TensorScatterOp<dev##Device, type, index_type, \
532 scatter_nd_op::UpdateOp::ADD>)
533
534 #define REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(index_type) \
535 REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd") \
536 .Device(DEVICE_DEFAULT) \
537 .TypeConstraint<int32>("T") \
538 .TypeConstraint<index_type>("Tindices") \
539 .HostMemory("tensor") \
540 .HostMemory("indices") \
541 .HostMemory("updates") \
542 .HostMemory("output"), \
543 TensorScatterOp<CPUDevice, int32, index_type, \
544 scatter_nd_op::UpdateOp::ADD>)
545
546 #define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \
547 REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \
548 .Device(DEVICE_##dev) \
549 .TypeConstraint<type>("T") \
550 .TypeConstraint<index_type>("Tindices"), \
551 TensorScatterOp<dev##Device, type, index_type, \
552 scatter_nd_op::UpdateOp::SUB>)
553
554 #define REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(index_type) \
555 REGISTER_KERNEL_BUILDER(Name("TensorScatterSub") \
556 .Device(DEVICE_DEFAULT) \
557 .TypeConstraint<int32>("T") \
558 .TypeConstraint<index_type>("Tindices") \
559 .HostMemory("tensor") \
560 .HostMemory("indices") \
561 .HostMemory("updates") \
562 .HostMemory("output"), \
563 TensorScatterOp<CPUDevice, int32, index_type, \
564 scatter_nd_op::UpdateOp::SUB>)
565
566 #define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \
567 REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
568 .Device(DEVICE_##dev) \
569 .TypeConstraint<type>("T") \
570 .TypeConstraint<index_type>("Tindices"), \
571 TensorScatterOp<dev##Device, type, index_type, \
572 scatter_nd_op::UpdateOp::MIN>)
573
574 #define REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(index_type) \
575 REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
576 .Device(DEVICE_DEFAULT) \
577 .TypeConstraint<int32>("T") \
578 .TypeConstraint<index_type>("Tindices") \
579 .HostMemory("tensor") \
580 .HostMemory("indices") \
581 .HostMemory("updates") \
582 .HostMemory("output"), \
583 TensorScatterOp<CPUDevice, int32, index_type, \
584 scatter_nd_op::UpdateOp::MIN>)
585
586 #define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \
587 REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
588 .Device(DEVICE_##dev) \
589 .TypeConstraint<type>("T") \
590 .TypeConstraint<index_type>("Tindices"), \
591 TensorScatterOp<dev##Device, type, index_type, \
592 scatter_nd_op::UpdateOp::MAX>)
593
594 #define REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(index_type) \
595 REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
596 .Device(DEVICE_DEFAULT) \
597 .TypeConstraint<int32>("T") \
598 .TypeConstraint<index_type>("Tindices") \
599 .HostMemory("tensor") \
600 .HostMemory("indices") \
601 .HostMemory("updates") \
602 .HostMemory("output"), \
603 TensorScatterOp<CPUDevice, int32, index_type, \
604 scatter_nd_op::UpdateOp::MAX>)
605
606 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type) \
607 REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
608 REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, CPU);
609
610 #define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type) \
611 REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \
612 REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, CPU);
613
614 #define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type) \
615 REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
616 REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, CPU);
617
618 #define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type) \
619 REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \
620 REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, CPU);
621
622 #define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type) \
623 REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \
624 REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, CPU);
625
626 #define REGISTER_SCATTER_ND_TENSOR_CPU(type) \
627 REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
628 REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type); \
629 REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type);
630
631 // Register TensorScatterUpdate/Add/Sub for all number types.
632 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
633 // Register min/max operations only for real number types.
634 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU);
635 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU);
636 // Register only TensorScatterUpdate for string/bool types as well.
637 TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
638 TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
639
640 #undef REGISTER_SCATTER_ND_TENSOR_CPU
641
642 // Registers GPU kernels.
643 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
644
645 #define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
646 REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
647
648 #define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
649 REGISTER_SCATTER_ND_UPDATE(type, GPU);
650
651 #define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
652 REGISTER_SCATTER_ND_MIN_MAX(type, GPU);
653
654 #define REGISTER_SCATTER_ND_ALL_GPU(type) \
655 REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \
656 REGISTER_SCATTER_ND_UPDATE_GPU(type); \
657 REGISTER_SCATTER_ND_GPU(type);
658
659 #define REGISTER_SCATTER_ND_ALL_INT32_GPU() \
660 REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU(); \
661 REGISTER_SCATTER_ND_UPDATE_INT32_GPU(); \
662 REGISTER_SCATTER_ND_INT32_GPU();
663
664 REGISTER_SCATTER_ND_ALL_INT32_GPU();
665 REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU();
666
667 TF_CALL_int64(REGISTER_SCATTER_ND_ALL_GPU);
668 TF_CALL_int64(REGISTER_SCATTER_ND_MIN_MAX_GPU);
669 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
670 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
671 TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
672
673 #undef REGISTER_SCATTER_ND_ALL_GPU
674
675 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type) \
676 REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \
677 REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, GPU);
678
679 #define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type) \
680 REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \
681 REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, GPU);
682
683 #define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type) \
684 REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
685 REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, GPU);
686
687 #define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type) \
688 REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \
689 REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, GPU);
690
691 #define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type) \
692 REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \
693 REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, GPU);
694
695 #define REGISTER_SCATTER_ND_TENSOR_GPU(type) \
696 REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type); \
697 REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
698 REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);
699
700 #define REGISTER_SCATTER_ND_TENSOR_INT32_GPU() \
701 REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int32); \
702 REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int64_t); \
703 REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int32); \
704 REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int64_t); \
705 REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int32); \
706 REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int64_t);
707
708 #define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \
709 REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type); \
710 REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type);
711
712 #define REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU() \
713 REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int32); \
714 REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int64_t); \
715 REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int32); \
716 REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int64_t);
717
718 REGISTER_SCATTER_ND_TENSOR_INT32_GPU();
719 REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU();
720
721 TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
722 TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
723 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
724 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
725 TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU);
726
727 #undef REGISTER_SCATTER_ND_ADD
728 #undef REGISTER_SCATTER_ND_ADD_SUB
729 #undef REGISTER_SCATTER_ND_ADD_SUB_CPU
730 #undef REGISTER_SCATTER_ND_ADD_SUB_GPU
731 #undef REGISTER_SCATTER_ND_MIN_MAX
732 #undef REGISTER_SCATTER_ND_MIN_MAX_CPU
733 #undef REGISTER_SCATTER_ND_MIN_MAX_GPU
734 #undef REGISTER_SCATTER_ND_UPDATE
735 #undef REGISTER_SCATTER_ND_UPDATE_CPU
736 #undef REGISTER_SCATTER_ND_UPDATE_GPU
737 #undef REGISTER_SCATTER_ND_KERNEL
738 #undef REGISTER_SCATTER_ND_KERNEL_INDEX
739 #undef REGISTER_SCATTER_ND_TENSOR_TYPE_INDEX_TYPE
740 #undef REGISTER_SCATTER_ND_TENSOR_CPU
741 #undef REGISTER_SCATTER_ND_TENSOR_GPU
742 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
743 #undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
744 #undef REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE
745 #undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
746 #undef REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE
747 #undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE
748 #undef REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE
749 #undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE
750 #undef REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE
751 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
752 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE
753 #undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
754 #undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
755 #undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU
756 #undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU
757 #undef REGISTER_SCATTER_ND_TENSOR_GPU
758 #undef REGISTER_SCATTER_ND_TENSOR_INT32_GPU
759 #undef REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU
760 #undef REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU
761 #undef REGISTER_SCATTER_ND_ALL_INT32_GPU
762 #undef REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU
763 #undef REGISTER_SCATTER_ND_INT32_GPU
764 #undef REGISTER_SCATTER_ND_UPDATE_INT32_GPU
765 #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
766 #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
767 #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
768 #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
769 #undef REGISTER_SCATTER_ND_KERNEL_INT32_GPU
770 #undef REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU
771
772 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
773
774 namespace functor {
775
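// Validates the shapes of `indices` and `updates` against `params_shape` and
// computes the flattened view used by the scatter functors: `slice_dim` is the
// size of the last dimension of `indices` (the number of index components per
// update), `num_updates` is the number of index rows, and `slice_size` is the
// number of elements in each updated slice of the output.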
776 template <typename Index>
777 Status PrepareAndValidateInputs(const TensorShape& params_shape,
778 const Tensor& indices, const Tensor& updates,
779 int64_t* slice_dim, Index* num_updates,
780 Index* slice_size) {
781 const TensorShape& indices_shape(indices.shape());
782 const TensorShape& updates_shape(updates.shape());
783
784 if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
785 return errors::InvalidArgument("Output must be at least 1-D, ",
786 "got shape: ", params_shape.DebugString());
787 }
788
789 if (!ValidEmptyOutputShape(params_shape.num_elements(),
790 indices_shape.num_elements(),
791 updates_shape.num_elements())) {
792 return errors::InvalidArgument(
793 "Indices and updates specified for empty output. indices shape: ",
794 indices.shape().DebugString());
795 }
796
797 if (updates.dim_size(0) != indices.dim_size(0)) {
798 return errors::InvalidArgument(
799 "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(),
800 "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
801 "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
802 }
803 TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(),
804 updates.shape()));
805
806 // Check that we have enough index space
807 const int64_t N_big = indices.NumElements();
808 if (N_big > std::numeric_limits<Index>::max()) {
809 return errors::InvalidArgument("indices has too many elements for ",
810 DataTypeString(DataTypeToEnum<Index>::v()),
811 " indexing: ", N_big, " > ",
812 std::numeric_limits<Index>::max());
813 }
814 if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
815 return errors::InvalidArgument("params_shape[0] too large for ",
816 DataTypeString(DataTypeToEnum<Index>::v()),
817 " indexing: ", params_shape.dim_size(0),
818 " > ", std::numeric_limits<Index>::max());
819 }
820
821 // Calculate the index depth: the size of the last dimension of indices.
822 *slice_dim = (indices_shape.dims() > 1)
823 ? indices_shape.dim_size(indices_shape.dims() - 1)
824 : 1;
825
826 // Calculate the number of elements that make up each slice of our updated
827 // tensor. This allows us to work with flattened tensors and copy over whole
828 // slices at a time.
829 Index total_nd = params_shape.dims();
830
831 int64_t slice_size_big = 1;
832 for (int64_t i = *slice_dim; i < total_nd; ++i) {
833 slice_size_big *= params_shape.dim_size(i);
834 }
835
836 if (slice_size_big > std::numeric_limits<Index>::max()) {
837 return errors::InvalidArgument(
838 "slice size is too large for indexing: ", slice_size_big, " > ",
839 std::numeric_limits<Index>::max());
840 }
841
842 *slice_size = static_cast<Index>(slice_size_big);
843
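// Guard against division by zero when the last dimension of indices is 0.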
844 const int64_t safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
845 *num_updates = indices_shape.num_elements() / safe_slice_dim;
846
847 return OkStatus();
848 }
849
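// Presents `indices` as a rank-2 tensor view (all outer dimensions collapsed,
// last dimension kept) so the scatter functors can iterate over index rows.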
850 template <typename Device, typename Index>
851 class IndexFlattener {
852 public:
853 inline typename TTypes<Index, 2>::ConstTensor operator()(
854 OpKernelContext*, const Tensor& indices) {
855 return indices.flat_inner_dims<Index>();
856 }
857 };
858
859 namespace {
860
861 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
862
863 // Copies inputs to the CPU, runs DoScatterNd on the CPU, then copies output
864 // back to GPU. This is useful because the CPU implementation is deterministic
865 // and the GPU implementation is not. Tensor inputs to this function must be on
866 // the GPU.
867 template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
868 Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
869 const Tensor& updates, const TensorShape& shape,
870 Tensor* out, bool allocate) {
871 AllocatorAttributes alloc_attr;
872 alloc_attr.set_on_host(true);
873 alloc_attr.set_gpu_compatible(true);
874 auto stream = c->op_device_context()->stream();
875
876 // Copy 'indices' to host.
877 Tensor host_indices;
878 TF_RETURN_IF_ERROR(c->allocate_temp(indices.dtype(), indices.shape(),
879 &host_indices, alloc_attr));
880 se::DeviceMemoryBase indices_ptr(
881 const_cast<Tensor&>(indices).flat<Index>().data(),
882 indices.flat<Index>().size() * sizeof(Index));
883 stream->ThenMemcpy(host_indices.flat<Index>().data(), indices_ptr,
884 indices.NumElements() * sizeof(Index));
885 if (!stream) {
886 return errors::Internal("Failed to copy indices to host");
887 }
888
889 // Copy 'updates' to host.
890 Tensor host_updates;
891 TF_RETURN_IF_ERROR(c->allocate_temp(updates.dtype(), updates.shape(),
892 &host_updates, alloc_attr));
893 se::DeviceMemoryBase updates_ptr(
894 const_cast<Tensor&>(updates).flat<T>().data(),
895 updates.flat<T>().size() * sizeof(T));
896 stream->ThenMemcpy(host_updates.flat<T>().data(), updates_ptr,
897 updates.NumElements() * sizeof(T));
898 if (!stream) {
899 return errors::Internal("Failed to copy updates to host");
900 }
901
902 // Create 'out' on host, copying from device if 'allocate' is false.
903 Tensor host_out;
904 TF_RETURN_IF_ERROR(
905 c->allocate_temp(updates.dtype(), shape, &host_out, alloc_attr));
906 if (allocate) {
907 TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
908 functor::SetZeroFunctor<CPUDevice, T> fill;
909 fill(c->eigen_device<CPUDevice>(), host_out.flat<T>());
910 } else {
911 CHECK_NOTNULL(out); // Crash OK
912 se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
913 out->flat<T>().size() * sizeof(T));
914 stream->ThenMemcpy(host_out.flat<T>().data(), out_ptr,
915 host_out.NumElements() * sizeof(T));
916 if (!stream) {
917 return errors::Internal("Failed to copy output to host");
918 }
919 }
920
921 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
922 TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
923 c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));
924
925 // Copy 'host_out' to device.
926 se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
927 out->flat<T>().size() * sizeof(T));
928 stream->ThenMemcpy(&out_ptr, host_out.flat<T>().data(),
929 host_out.NumElements() * sizeof(T));
930 if (!stream) {
931 return errors::Internal("Failed to copy output to device");
932 }
933 // Block host, since 'host_out' cannot be destructed until the copy is done.
934 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
935 return OkStatus();
936 }
937
938 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
939
940 } // namespace
941
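// Scatters `updates` into `out` at `indices` using update op `Op`. If
// `allocate` is true, `out` is allocated and zero-initialized here; otherwise
// it must already hold the tensor to be updated. When deterministic ops are
// required on GPU, the computation is routed through the CPU implementation.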
942 template <typename Device, typename T, typename Index,
943 scatter_nd_op::UpdateOp Op>
944 Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
945 const Tensor& updates, const TensorShape& shape, Tensor* out,
946 bool allocate) {
947 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
948 if (std::is_same<Device, GPUDevice>::value &&
949 tensorflow::OpDeterminismRequired()) {
950 return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
951 allocate);
952 }
953 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
954 int64_t slice_dim;
955 Index num_updates;
956 Index slice_size;
957 TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
958 shape, indices, updates, &slice_dim, &num_updates, &slice_size));
959
960 IndexFlattener<Device, Index> index_flattener;
961 auto indices_flat = index_flattener(c, indices);
962 auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
963
964 if (allocate) {
965 AllocatorAttributes alloc_attr;
966 if (std::is_same<Device, CPUDevice>::value) {
967 alloc_attr.set_on_host(true);
968 }
969 TF_RETURN_IF_ERROR(
970 c->allocate_temp(DataTypeToEnum<T>::value, shape, out, alloc_attr));
971 } else {
972 CHECK_NOTNULL(out);
973 }
974
975 if (shape.num_elements() == 0) {
976 return OkStatus();
977 }
978
979 if (allocate) {
980 // Brand new tensor, zero it out.
981 functor::SetZeroFunctor<Device, T> fill;
982 fill(c->eigen_device<Device>(), out->flat<T>());
983 }
984 auto output_matrix =
985 out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});
986
987 Index bad_i = -1;
988
989 if (shape.num_elements() > 0) {
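// Dispatch on slice_dim (the size of the last dimension of indices); each
// case instantiates a ScatterNdFunctor specialized for that index depth.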
990 switch (slice_dim) {
991 #define PARAMS_CASE(IXDIM) \
992 case IXDIM: { \
993 typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
994 for (int i = 0; i < IXDIM; ++i) { \
995 output_shape_prefix[i] = shape.dim_size(i); \
996 } \
997 functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \
998 bad_i = \
999 functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
1000 output_matrix, indices_flat, updates_flat, output_matrix); \
1001 } break
1002 // TODO(simister): Re-enable this once binary size is under control.
1003 // PARAMS_CASE(0);
1004 PARAMS_CASE(1);
1005 PARAMS_CASE(2);
1006 PARAMS_CASE(3);
1007 PARAMS_CASE(4);
1008 PARAMS_CASE(5);
1009 PARAMS_CASE(6);
1010 PARAMS_CASE(7);
1011 #undef PARAMS_CASE
1012 default:
1013 return errors::InvalidArgument(
1014 "Only indices.shape[-1] values between 1 and 5 "
1015 "are currently supported. Requested rank: ",
1016 slice_dim);
1017 }
1018 }
1019 if (bad_i >= 0) {
1020 auto slice_shape = indices.shape();
1021 slice_shape.RemoveLastDims(1);
1022 return errors::InvalidArgument(
1023 "indices", SliceDebugString(slice_shape, bad_i), " = [",
1024 absl::StrJoin(
1025 gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
1026 "] does not index into shape ", shape.DebugString());
1027 }
1028 return OkStatus();
1029 }
1030 } // namespace functor
1031
1032 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1033 // Forward declarations of the functor specializations for GPU.
1034 namespace functor {
1035 #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
1036 template <> \
1037 Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \
1038 const GPUDevice& d, const Index slice_size, \
1039 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
1040 typename TTypes<T, 2>::Tensor Tparams, \
1041 typename TTypes<Index, 2>::ConstTensor Tindices, \
1042 typename TTypes<T, 2>::ConstTensor Tupdates, \
1043 typename TTypes<T, 2>::Tensor Toutput); \
1044 extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
1045
1046 #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
1047 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
1048 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
1049 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
1050 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
1051 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
1052 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
1053 DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);
1054
1055 #define DECLARE_GPU_SPECS_INDEX(T, Index) \
1056 DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
1057 DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
1058 DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)
1059
1060 #define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index) \
1061 DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \
1062 DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX)
1063
1064 #define DECLARE_GPU_SPECS(T) \
1065 DECLARE_GPU_SPECS_INDEX(T, int32); \
1066 DECLARE_GPU_SPECS_INDEX(T, int64_t)
1067
1068 #define DECLARE_GPU_SPECS_MIN_MAX(T) \
1069 DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \
1070 DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64_t)
1071
1072 TF_CALL_int32(DECLARE_GPU_SPECS);
1073 TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
1074 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
1075 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
1076 TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);
1077
1078 #undef DECLARE_GPU_SPECS_MIN_MAX
1079 #undef DECLARE_GPU_SPECS
1080 #undef DECLARE_GPU_SPECS_INDEX_MIN_MAX
1081 #undef DECLARE_GPU_SPECS_INDEX
1082 #undef DECLARE_GPU_SPECS_INDEX_OP
1083
1084 } // namespace functor
1085
1086 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1087
1088 } // namespace tensorflow
1089