1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/state_ops.cc.
17 #define EIGEN_USE_THREADS
18 
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #define EIGEN_USE_GPU
21 #include "tensorflow/core/platform/stream_executor.h"
22 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23 
24 #include "tensorflow/core/framework/bounds_check.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/kernels/dense_update_functor.h"
31 #include "tensorflow/core/kernels/fill_functor.h"
32 #include "tensorflow/core/kernels/inplace_ops_functor.h"
33 #include "tensorflow/core/kernels/scatter_nd_op.h"
34 #include "tensorflow/core/kernels/scatter_nd_util.h"
35 #include "tensorflow/core/kernels/training_op_helpers.h"
36 #include "tensorflow/core/kernels/variable_ops.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/util/determinism.h"
41 #include "tensorflow/core/util/util.h"
42 
43 namespace tensorflow {
44 
45 typedef Eigen::ThreadPoolDevice CPUDevice;
46 typedef Eigen::GpuDevice GPUDevice;
47 
48 // Returns true if the three tensors have a valid number of elements.
49 // If shape_input has 0 elements, then indices and updates must also have
50 // exactly 0 elements; otherwise it is an error. Likewise, if indices has 0
51 // elements, then updates must have 0 elements as well.
52 bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices,
53                            int64_t num_updates) {
54   if (num_indices == 0 && num_updates == 0) {
55     return true;  // valid regardless of whether num_inputs is 0; covers both cases
56   }
57   // Otherwise, all three tensors must be non-empty.
58   return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
59 }
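// A few sample calls for illustration (argument values are hypothetical, not
// taken from any particular op invocation):
//   ValidEmptyOutputShape(/*num_inputs=*/0, /*num_indices=*/0, /*num_updates=*/0) -> true
//   ValidEmptyOutputShape(/*num_inputs=*/8, /*num_indices=*/0, /*num_updates=*/0) -> true
//   ValidEmptyOutputShape(/*num_inputs=*/0, /*num_indices=*/2, /*num_updates=*/2) -> false
//   ValidEmptyOutputShape(/*num_inputs=*/8, /*num_indices=*/2, /*num_updates=*/0) -> false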
60 
61 template <typename Device, typename T, typename Index>
62 class ScatterNdOp : public OpKernel {
63  public:
64   explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
65     const DataType dt = DataTypeToEnum<T>::v();
66     const DataType index_t = DataTypeToEnum<Index>::v();
67     OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
68   }
69 
70   void Compute(OpKernelContext* c) override {
71     const Tensor& indices = c->input(0);
72     const Tensor& updates = c->input(1);
73     const Tensor& shape_input = c->input(2);
74 
75     OP_REQUIRES(c, indices.shape().dims() >= 1,
76                 errors::InvalidArgument(
77                     "Indices shape must have rank at least one. Found:",
78                     indices.shape().DebugString()));
79     OP_REQUIRES(c, updates.shape().dims() >= 1,
80                 errors::InvalidArgument(
81                     "Updates shape must have rank at least one. Found:",
82                     updates.shape().DebugString()));
83 
84     auto vec = shape_input.flat<Index>();
85     TensorShape shape;
86     OP_REQUIRES_OK(c,
87                    TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
88 
89     OP_REQUIRES(c,
90                 ValidEmptyOutputShape(shape_input.NumElements(),
91                                       indices.shape().num_elements(),
92                                       updates.shape().num_elements()),
93                 errors::InvalidArgument(
94                     "Indices and updates specified for empty output shape"));
95 
96     const int64_t outer_dims = indices.shape().dims() - 1;
97 
98     for (int i = 0; i < outer_dims; ++i) {
99       OP_REQUIRES(
100           c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
101           errors::InvalidArgument(
102               "Dimensions [0,", outer_dims,
103               ") of indices[shape=", indices.shape().DebugString(),
104               "] must match dimensions [0,", outer_dims,
105               ") of updates[shape=", updates.shape().DebugString(), "]"));
106     }
107 
108     const int64_t ix = indices.shape().dim_size(outer_dims);
109     OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix,
110                 errors::InvalidArgument(
111                     "Dimensions [", ix, ",", shape.dims(), ") of input[shape=",
112                     shape.DebugString(), "] must match dimensions [",
113                     outer_dims, ",", updates.shape().dims(),
114                     ") of updates[shape=", updates.shape().DebugString(), "]"));
115 
116     for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
117       OP_REQUIRES(
118           c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
119           errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(),
120                                   ") of input[shape=", shape.DebugString(),
121                                   "] must match dimensions [", outer_dims, ",",
122                                   updates.shape().dims(), ") of updates[shape=",
123                                   updates.shape().DebugString(), "]"));
124     }
125     OP_REQUIRES(c, shape_input.dims() == 1,
126                 errors::InvalidArgument("Shape must be a vector"));
127 
128     Tensor out;
129     OP_REQUIRES_OK(
130         c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
131                c, indices, updates, shape, &out, true /*allocate*/));
132     c->set_output(0, out);
133   }
134 };
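// To make the shape checks in ScatterNdOp::Compute concrete, consider a
// hypothetical call with indices.shape = [4, 2], updates.shape = [4, 7] and
// shape = [5, 6, 7]: outer_dims = 1 and ix = 2, so dimensions [0,1) of indices
// and updates must agree (both 4), and the trailing updates dimensions [1,2)
// must equal the output dimensions [2,3) (both 7). The op then scatter-adds
// each 7-element update row into the [5, 6, 7] output at the position given by
// the corresponding 2-element index.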
135 
136 template <typename Device, typename T, typename Index,
137           scatter_nd_op::UpdateOp op>
138 class TensorScatterOp : public OpKernel {
139  public:
140   explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
141     const DataType dt = DataTypeToEnum<T>::v();
142     const DataType index_t = DataTypeToEnum<Index>::v();
143     OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
144   }
145 
146   void Compute(OpKernelContext* c) override {
147     const Tensor& input = c->input(0);
148     const Tensor& indices = c->input(1);
149     const Tensor& updates = c->input(2);
150 
151     OP_REQUIRES(c, indices.shape().dims() >= 1,
152                 errors::InvalidArgument(
153                     "Indices shape must have rank at least one. Found:",
154                     indices.shape().DebugString()));
155     OP_REQUIRES(c, updates.shape().dims() >= 1,
156                 errors::InvalidArgument(
157                     "Updates shape must have rank at least one. Found:",
158                     updates.shape().DebugString()));
159 
160     TensorShape shape = input.shape();
161 
162     OP_REQUIRES(c,
163                 ValidEmptyOutputShape(shape.num_elements(),
164                                       indices.shape().num_elements(),
165                                       updates.shape().num_elements()),
166                 errors::InvalidArgument(
167                     "Indices and updates specified for empty output shape"));
168 
169     const int64_t outer_dims = indices.shape().dims() - 1;
170 
171     for (int i = 0; i < outer_dims; ++i) {
172       OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
173                   errors::InvalidArgument(
174                       "Outer dimensions of indices and update must match. "
175                       "Indices shape: ",
176                       indices.shape().DebugString(),
177                       ", updates shape:", updates.shape().DebugString()));
178     }
179 
180     const int64_t ix = indices.shape().dim_size(outer_dims);
181     OP_REQUIRES(
182         c, updates.shape().dims() - outer_dims == shape.dims() - ix,
183         errors::InvalidArgument("Inner dimensions of output shape must match "
184                                 "inner dimensions of updates shape. Output: ",
185                                 shape.DebugString(),
186                                 " updates: ", updates.shape().DebugString()));
187     for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
188       OP_REQUIRES(
189           c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
190           errors::InvalidArgument(
191               "The inner ", shape.dims() - ix,
192               " dimensions of output.shape=", shape.DebugString(),
193               " must match the inner ", updates.shape().dims() - outer_dims,
194               " dimensions of updates.shape=", updates.shape().DebugString()));
195     }
196 
197     AllocatorAttributes alloc_attr;
198     MemoryType memory_type = DEVICE_MEMORY;
199     if (std::is_same<Device, CPUDevice>::value) {
200       alloc_attr.set_on_host(true);
201       memory_type = HOST_MEMORY;
202     } else {
203       memory_type = DEVICE_MEMORY;
204     }
205     std::unique_ptr<Tensor> forwarded_input =
206         c->forward_input(0, 0, input.dtype(), shape, memory_type, alloc_attr);
207 
208     if (forwarded_input == nullptr) {
209       // We were not able to forward the input, so we deep copy the tensor and
210       // set the output.
211       Tensor* out;
212       OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out));
213 
214       OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
215                                                     input, out));
216       OP_REQUIRES_OK(c,
217                      functor::DoScatterNd<Device, T, Index, op>(
218                          c, indices, updates, shape, out, false /*allocate*/));
219     } else {
220       // Output forwarded, so simply perform the scatter.
221       OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
222                             c, indices, updates, shape, forwarded_input.get(),
223                             false /*allocate*/));
224 
225       c->set_output(0, *forwarded_input);
226     }
227   }
228 };
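// Note on buffer reuse in TensorScatterOp::Compute above: when the input
// buffer can be forwarded (forward_input succeeds), the scatter is applied in
// place on the forwarded tensor, which then becomes the output; otherwise the
// input is first deep-copied into a freshly allocated output, so the original
// input tensor is never mutated in either path.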
229 
230 template <typename Device, typename T, typename Index,
231           scatter_nd_op::UpdateOp op>
232 class ScatterNdUpdateOp : public OpKernel {
233  public:
234   explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
235     const DataType dt = DataTypeToEnum<T>::v();
236     const DataType dt_ref = DataTypeToEnum<T>::ref();
237     const DataType index_t = DataTypeToEnum<Index>::v();
238     dtype_ = c->input_type(0);
239     // If we are updating a resource, we always use the exclusive lock.
240     // For ref types, we lock based on the use_locking parameter
241     // Otherwise, we don't mutate the input tensor (we copy-on-write if needed).
242     if (c->input_type(0) == DT_RESOURCE) {
243       // TODO(apassos): what to validate here?
244     } else if (IsRefType(c->input_type(0))) {
245       OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
246       OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
247     } else {
248       OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
249       use_exclusive_lock_ = false;
250     }
251   }
252 
253   void Compute(OpKernelContext* c) override {
254     if (dtype_ == DT_RESOURCE) {
255       core::RefCountPtr<Var> v;
256       OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
257       OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
258       mutex_lock m(*v->mu());
259       DoCompute(c);
260     } else if (use_exclusive_lock_) {
261       // If we're here, it means the input type is a ref.
262       DCHECK(IsRefType(c->input_dtype(0)));
263       // Hold mutex while we apply updates
264       mutex_lock l(*c->input_ref_mutex(0));
265       DoCompute(c);
266     } else {
267       DoCompute(c);
268     }
269   }
270 
271  private:
272   DataType dtype_;
273   bool use_exclusive_lock_;
274 
275   void DoCompute(OpKernelContext* c) {
276     const Tensor& indices = c->input(1);
277     const Tensor& updates = c->input(2);
278     Tensor params;
279     TensorShape params_shape;
280 
281     if (dtype_ == DT_RESOURCE) {
282       core::RefCountPtr<Var> v;
283       OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
284       Tensor* t = v->tensor();
285       params = *t;
286       params_shape = params.shape();
287     } else if (IsRefType(c->input_dtype(0))) {
288       params = c->mutable_input(0, use_exclusive_lock_);
289       params_shape = params.shape();
290       c->forward_ref_input_to_ref_output(0, 0);
291       OP_REQUIRES(c, params.IsInitialized(),
292                   errors::FailedPrecondition("Null ref for params"));
293     } else {
294       Tensor* params_ptr;
295       params_shape = c->input(0).shape();
296       if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
297                                                  &params_ptr)) {
298         // We weren't able to forward the input to output, so just
299         // allocate a new output tensor and copy the values over.
300         OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
301         params = *params_ptr;
302         functor::DenseUpdate<Device, T, ASSIGN> copy;
303         const Tensor& input_copy = c->input(0);
304         copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
305       } else {
306         params = *params_ptr;
307       }
308     }
309 
310     OP_REQUIRES_OK(
311         c, functor::DoScatterNd<Device, T, Index, op>(
312                c, indices, updates, params_shape, &params, false /*allocate*/));
313   }
314 };
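// Locking recap for ScatterNdUpdateOp: resource variables always take the
// variable's mutex; ref inputs take the ref mutex only when use_locking is
// set; plain tensors take no lock and instead copy-on-write into the output
// when the input buffer cannot be forwarded (see DoCompute above).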
315 
316 #define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
317   REGISTER_KERNEL_BUILDER(Name(name)                                  \
318                               .Device(DEVICE_##dev)                   \
319                               .TypeConstraint<type>("T")              \
320                               .TypeConstraint<index_type>("Tindices") \
321                               .HostMemory("shape"),                   \
322                           ScatterNdOp<dev##Device, type, index_type>)
323 
324 #define REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(index_type, name)  \
325   REGISTER_KERNEL_BUILDER(Name(name)                                  \
326                               .Device(DEVICE_DEFAULT)                 \
327                               .TypeConstraint<int32>("T")             \
328                               .TypeConstraint<index_type>("Tindices") \
329                               .HostMemory("indices")                  \
330                               .HostMemory("updates")                  \
331                               .HostMemory("shape")                    \
332                               .HostMemory("output"),                  \
333                           ScatterNdOp<CPUDevice, int32, index_type>)
334 
335 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
336                                                 op)                          \
337   REGISTER_KERNEL_BUILDER(                                                   \
338       Name(name)                                                             \
339           .Device(DEVICE_##dev)                                              \
340           .TypeConstraint<type>("T")                                         \
341           .TypeConstraint<index_type>("Tindices"),                           \
342       ScatterNdUpdateOp<dev##Device, type, index_type, op>)
343 
344 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, name, \
345                                                           op)               \
346   REGISTER_KERNEL_BUILDER(Name(name)                                        \
347                               .Device(DEVICE_DEFAULT)                       \
348                               .TypeConstraint<int32>("T")                   \
349                               .TypeConstraint<index_type>("Tindices")       \
350                               .HostMemory("ref")                            \
351                               .HostMemory("indices")                        \
352                               .HostMemory("updates")                        \
353                               .HostMemory("output_ref"),                    \
354                           ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
355 
356 #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU( \
357     index_type, name, op)                                               \
358   REGISTER_KERNEL_BUILDER(Name(name)                                    \
359                               .Device(DEVICE_DEFAULT)                   \
360                               .TypeConstraint<int32>("T")               \
361                               .TypeConstraint<index_type>("Tindices")   \
362                               .HostMemory("input")                      \
363                               .HostMemory("indices")                    \
364                               .HostMemory("updates")                    \
365                               .HostMemory("output"),                    \
366                           ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
367 
368 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
369                                                          dev, name, op)    \
370   REGISTER_KERNEL_BUILDER(                                                 \
371       Name(name)                                                           \
372           .Device(DEVICE_##dev)                                            \
373           .TypeConstraint<type>("T")                                       \
374           .TypeConstraint<index_type>("Tindices")                          \
375           .HostMemory("ref"),                                              \
376       ScatterNdUpdateOp<dev##Device, type, index_type, op>)
377 
378 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(index_type, \
379                                                                    name, op)   \
380   REGISTER_KERNEL_BUILDER(Name(name)                                           \
381                               .Device(DEVICE_DEFAULT)                          \
382                               .TypeConstraint<int32>("T")                      \
383                               .TypeConstraint<index_type>("Tindices")          \
384                               .HostMemory("ref")                               \
385                               .HostMemory("indices")                           \
386                               .HostMemory("updates"),                          \
387                           ScatterNdUpdateOp<CPUDevice, int32, index_type, op>)
388 
389 #define REGISTER_SCATTER_ND_KERNEL(type, dev, name)         \
390   REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
391   REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64_t, dev, name)
392 
393 #define REGISTER_SCATTER_ND_KERNEL_INT32_GPU(name)         \
394   REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int32, name); \
395   REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU(int64_t, name)
396 
397 #define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)         \
398   REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
399   REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)
400 
401 #define REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op)         \
402   REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
403   REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)
404 
405 #define REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU(name, op)    \
406   REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, \
407                                                                  op);         \
408   REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t,     \
409                                                                  name, op)
410 
411 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)    \
412   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
413                                                    op);                    \
414   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64_t, dev, name, op)
415 
416 #define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(name, op)         \
417   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int32, name, op); \
418   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU(int64_t, name, op)
419 
420 #define REGISTER_SCATTER_ND_ADD_SUB(type, dev)                            \
421   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",            \
422                                     scatter_nd_op::UpdateOp::ADD);        \
423   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
424                                     scatter_nd_op::UpdateOp::ADD);        \
425   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",            \
426                                     scatter_nd_op::UpdateOp::SUB);        \
427   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
428       type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);   \
429   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
430       type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);
431 
432 #define REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU()                              \
433   REGISTER_SCATTER_ND_NON_ALIASING_UPDATE_KERNEL_INT32_GPU(                  \
434       "ScatterNdNonAliasingAdd", scatter_nd_op::UpdateOp::ADD);              \
435   REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdAdd",                \
436                                               scatter_nd_op::UpdateOp::ADD); \
437   REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdSub",                \
438                                               scatter_nd_op::UpdateOp::SUB); \
439   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(                      \
440       "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);                 \
441   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(                      \
442       "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);
443 
444 #define REGISTER_SCATTER_ND(type, dev) \
445   REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
446 
447 #define REGISTER_SCATTER_ND_INT32_GPU() \
448   REGISTER_SCATTER_ND_KERNEL_INT32_GPU("ScatterNd");
449 
450 #define REGISTER_SCATTER_ND_UPDATE(type, dev)                         \
451   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate",     \
452                                     scatter_nd_op::UpdateOp::ASSIGN); \
453   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                         \
454       type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
455 
456 #define REGISTER_SCATTER_ND_UPDATE_INT32_GPU()             \
457   REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(             \
458       "ScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN); \
459   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(    \
460       "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
461 
462 #define REGISTER_SCATTER_ND_MIN_MAX(type, dev)                          \
463   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax",          \
464                                     scatter_nd_op::UpdateOp::MAX);      \
465   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin",          \
466                                     scatter_nd_op::UpdateOp::MIN);      \
467   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                           \
468       type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
469   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                           \
470       type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
471 
472 #define REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU()                              \
473   REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMax",                \
474                                               scatter_nd_op::UpdateOp::MAX); \
475   REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU("ScatterNdMin",                \
476                                               scatter_nd_op::UpdateOp::MIN); \
477   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(                      \
478       "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN);                 \
479   REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU(                      \
480       "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
481 
482 // Registers CPU kernels.
483 #define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
484   REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
485 
486 #define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
487   REGISTER_SCATTER_ND_UPDATE(type, CPU);
488 
489 #define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \
490   REGISTER_SCATTER_ND_MIN_MAX(type, CPU);
491 
492 #define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
493 #define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);
494 
495 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
496 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
497 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
498 TF_CALL_tstring(REGISTER_SCATTER_ND_CPU);
499 TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
500 TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
501 TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
502 TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
503 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU);
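// For reference, one expansion of the macros above: REGISTER_SCATTER_ND_CPU(float)
// registers the "ScatterNd" kernel on DEVICE_CPU with T = float and with both
// int32 and int64 "Tindices" (via REGISTER_SCATTER_ND_KERNEL and
// REGISTER_SCATTER_ND_KERNEL_INDEX), keeping the "shape" input in host memory.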
504 
505 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
506                                                           dev)              \
507   REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate")                       \
508                               .Device(DEVICE_##dev)                         \
509                               .TypeConstraint<type>("T")                    \
510                               .TypeConstraint<index_type>("Tindices"),      \
511                           TensorScatterOp<dev##Device, type, index_type,    \
512                                           scatter_nd_op::UpdateOp::ASSIGN>)
513 
514 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(index_type) \
515   REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate")                      \
516                               .Device(DEVICE_DEFAULT)                      \
517                               .TypeConstraint<int32>("T")                  \
518                               .TypeConstraint<index_type>("Tindices")      \
519                               .HostMemory("tensor")                        \
520                               .HostMemory("indices")                       \
521                               .HostMemory("updates")                       \
522                               .HostMemory("output"),                       \
523                           TensorScatterOp<CPUDevice, int32, index_type,    \
524                                           scatter_nd_op::UpdateOp::ASSIGN>)
525 
526 #define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \
527   REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd")                            \
528                               .Device(DEVICE_##dev)                           \
529                               .TypeConstraint<type>("T")                      \
530                               .TypeConstraint<index_type>("Tindices"),        \
531                           TensorScatterOp<dev##Device, type, index_type,      \
532                                           scatter_nd_op::UpdateOp::ADD>)
533 
534 #define REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(index_type) \
535   REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd")                      \
536                               .Device(DEVICE_DEFAULT)                   \
537                               .TypeConstraint<int32>("T")               \
538                               .TypeConstraint<index_type>("Tindices")   \
539                               .HostMemory("tensor")                     \
540                               .HostMemory("indices")                    \
541                               .HostMemory("updates")                    \
542                               .HostMemory("output"),                    \
543                           TensorScatterOp<CPUDevice, int32, index_type, \
544                                           scatter_nd_op::UpdateOp::ADD>)
545 
546 #define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \
547   REGISTER_KERNEL_BUILDER(Name("TensorScatterSub")                            \
548                               .Device(DEVICE_##dev)                           \
549                               .TypeConstraint<type>("T")                      \
550                               .TypeConstraint<index_type>("Tindices"),        \
551                           TensorScatterOp<dev##Device, type, index_type,      \
552                                           scatter_nd_op::UpdateOp::SUB>)
553 
554 #define REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(index_type) \
555   REGISTER_KERNEL_BUILDER(Name("TensorScatterSub")                      \
556                               .Device(DEVICE_DEFAULT)                   \
557                               .TypeConstraint<int32>("T")               \
558                               .TypeConstraint<index_type>("Tindices")   \
559                               .HostMemory("tensor")                     \
560                               .HostMemory("indices")                    \
561                               .HostMemory("updates")                    \
562                               .HostMemory("output"),                    \
563                           TensorScatterOp<CPUDevice, int32, index_type, \
564                                           scatter_nd_op::UpdateOp::SUB>)
565 
566 #define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \
567   REGISTER_KERNEL_BUILDER(Name("TensorScatterMin")                            \
568                               .Device(DEVICE_##dev)                           \
569                               .TypeConstraint<type>("T")                      \
570                               .TypeConstraint<index_type>("Tindices"),        \
571                           TensorScatterOp<dev##Device, type, index_type,      \
572                                           scatter_nd_op::UpdateOp::MIN>)
573 
574 #define REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(index_type) \
575   REGISTER_KERNEL_BUILDER(Name("TensorScatterMin")                      \
576                               .Device(DEVICE_DEFAULT)                   \
577                               .TypeConstraint<int32>("T")               \
578                               .TypeConstraint<index_type>("Tindices")   \
579                               .HostMemory("tensor")                     \
580                               .HostMemory("indices")                    \
581                               .HostMemory("updates")                    \
582                               .HostMemory("output"),                    \
583                           TensorScatterOp<CPUDevice, int32, index_type, \
584                                           scatter_nd_op::UpdateOp::MIN>)
585 
586 #define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \
587   REGISTER_KERNEL_BUILDER(Name("TensorScatterMax")                            \
588                               .Device(DEVICE_##dev)                           \
589                               .TypeConstraint<type>("T")                      \
590                               .TypeConstraint<index_type>("Tindices"),        \
591                           TensorScatterOp<dev##Device, type, index_type,      \
592                                           scatter_nd_op::UpdateOp::MAX>)
593 
594 #define REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(index_type) \
595   REGISTER_KERNEL_BUILDER(Name("TensorScatterMax")                      \
596                               .Device(DEVICE_DEFAULT)                   \
597                               .TypeConstraint<int32>("T")               \
598                               .TypeConstraint<index_type>("Tindices")   \
599                               .HostMemory("tensor")                     \
600                               .HostMemory("indices")                    \
601                               .HostMemory("updates")                    \
602                               .HostMemory("output"),                    \
603                           TensorScatterOp<CPUDevice, int32, index_type, \
604                                           scatter_nd_op::UpdateOp::MAX>)
605 
606 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type)                    \
607   REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
608   REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, CPU);
609 
610 #define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type)                    \
611   REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \
612   REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, CPU);
613 
614 #define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type)                    \
615   REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
616   REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, CPU);
617 
618 #define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type)                    \
619   REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \
620   REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, CPU);
621 
622 #define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type)                    \
623   REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \
624   REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, CPU);
625 
626 #define REGISTER_SCATTER_ND_TENSOR_CPU(type)   \
627   REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
628   REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type);    \
629   REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type);
630 
631 // Register TensorScatterUpdate/Add/Sub for all number types.
632 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
633 // Register min/max operations only for real number types.
634 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU);
635 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU);
636 // For string and bool types, register only TensorScatterUpdate.
637 TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
638 TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
639 
640 #undef REGISTER_SCATTER_ND_TENSOR_CPU
641 
642 // Registers GPU kernels.
643 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
644 
645 #define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
646   REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
647 
648 #define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
649   REGISTER_SCATTER_ND_UPDATE(type, GPU);
650 
651 #define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
652   REGISTER_SCATTER_ND_MIN_MAX(type, GPU);
653 
654 #define REGISTER_SCATTER_ND_ALL_GPU(type) \
655   REGISTER_SCATTER_ND_ADD_SUB_GPU(type);  \
656   REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
657   REGISTER_SCATTER_ND_GPU(type);
658 
659 #define REGISTER_SCATTER_ND_ALL_INT32_GPU() \
660   REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU();  \
661   REGISTER_SCATTER_ND_UPDATE_INT32_GPU();   \
662   REGISTER_SCATTER_ND_INT32_GPU();
663 
664 REGISTER_SCATTER_ND_ALL_INT32_GPU();
665 REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU();
666 
667 TF_CALL_int64(REGISTER_SCATTER_ND_ALL_GPU);
668 TF_CALL_int64(REGISTER_SCATTER_ND_MIN_MAX_GPU);
669 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
670 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
671 TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
672 
673 #undef REGISTER_SCATTER_ND_ALL_GPU
674 
675 #define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type)                    \
676   REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \
677   REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64_t, GPU);
678 
679 #define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type)                    \
680   REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \
681   REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64_t, GPU);
682 
683 #define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type)                    \
684   REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
685   REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64_t, GPU);
686 
687 #define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type)                    \
688   REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \
689   REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64_t, GPU);
690 
691 #define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type)                    \
692   REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \
693   REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64_t, GPU);
694 
695 #define REGISTER_SCATTER_ND_TENSOR_GPU(type)   \
696   REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type);    \
697   REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
698   REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);
699 
700 #define REGISTER_SCATTER_ND_TENSOR_INT32_GPU()                   \
701   REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int32);    \
702   REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE(int64_t);  \
703   REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int32);    \
704   REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE(int64_t);  \
705   REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int32); \
706   REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE(int64_t);
707 
708 #define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \
709   REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type);          \
710   REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type);
711 
712 #define REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU()          \
713   REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int32);   \
714   REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE(int64_t); \
715   REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int32);   \
716   REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE(int64_t);
717 
718 REGISTER_SCATTER_ND_TENSOR_INT32_GPU();
719 REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU();
720 
721 TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
722 TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
723 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
724 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
725 TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU);
726 
727 #undef REGISTER_SCATTER_ND_ADD
728 #undef REGISTER_SCATTER_ND_ADD_SUB
729 #undef REGISTER_SCATTER_ND_ADD_SUB_CPU
730 #undef REGISTER_SCATTER_ND_ADD_SUB_GPU
731 #undef REGISTER_SCATTER_ND_MIN_MAX
732 #undef REGISTER_SCATTER_ND_MIN_MAX_CPU
733 #undef REGISTER_SCATTER_ND_MIN_MAX_GPU
734 #undef REGISTER_SCATTER_ND_UPDATE
735 #undef REGISTER_SCATTER_ND_UPDATE_CPU
736 #undef REGISTER_SCATTER_ND_UPDATE_GPU
737 #undef REGISTER_SCATTER_ND_KERNEL
738 #undef REGISTER_SCATTER_ND_KERNEL_INDEX
739 #undef REGISTER_SCATTER_ND_TENSOR_TYPE_INDEX_TYPE
740 #undef REGISTER_SCATTER_ND_TENSOR_CPU
741 #undef REGISTER_SCATTER_ND_TENSOR_GPU
742 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
743 #undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
744 #undef REGISTER_SCATTER_ND_TENSOR_ADD_INT32_GPU_INDEX_TYPE
745 #undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
746 #undef REGISTER_SCATTER_ND_TENSOR_SUB_INT32_GPU_INDEX_TYPE
747 #undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE
748 #undef REGISTER_SCATTER_ND_TENSOR_MIN_INT32_GPU_INDEX_TYPE
749 #undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE
750 #undef REGISTER_SCATTER_ND_TENSOR_MAX_INT32_GPU_INDEX_TYPE
751 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
752 #undef REGISTER_SCATTER_ND_TENSOR_UPDATE_INT32_GPU_INDEX_TYPE
753 #undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
754 #undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
755 #undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU
756 #undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU
757 #undef REGISTER_SCATTER_ND_TENSOR_GPU
758 #undef REGISTER_SCATTER_ND_TENSOR_INT32_GPU
759 #undef REGISTER_SCATTER_ND_TENSOR_MIN_MAX_INT32_GPU
760 #undef REGISTER_SCATTER_ND_ADD_SUB_INT32_GPU
761 #undef REGISTER_SCATTER_ND_ALL_INT32_GPU
762 #undef REGISTER_SCATTER_ND_MIN_MAX_INT32_GPU
763 #undef REGISTER_SCATTER_ND_INT32_GPU
764 #undef REGISTER_SCATTER_ND_UPDATE_INT32_GPU
765 #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
766 #undef REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
767 #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INT32_GPU
768 #undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX_INT32_GPU
769 #undef REGISTER_SCATTER_ND_KERNEL_INT32_GPU
770 #undef REGISTER_SCATTER_ND_KERNEL_INDEX_INT32_GPU
771 
772 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
773 
774 namespace functor {
775 
776 template <typename Index>
777 Status PrepareAndValidateInputs(const TensorShape& params_shape,
778                                 const Tensor& indices, const Tensor& updates,
779                                 int64_t* slice_dim, Index* num_updates,
780                                 Index* slice_size) {
781   const TensorShape& indices_shape(indices.shape());
782   const TensorShape& updates_shape(updates.shape());
783 
784   if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
785     return errors::InvalidArgument("Output must be at least 1-D, ",
786                                    "got shape: ", params_shape.DebugString());
787   }
788 
789   if (!ValidEmptyOutputShape(params_shape.num_elements(),
790                              indices_shape.num_elements(),
791                              updates_shape.num_elements())) {
792     return errors::InvalidArgument(
793         "Indices and updates specified for empty output.  indices shape: ",
794         indices.shape().DebugString());
795   }
796 
797   if (updates.dim_size(0) != indices.dim_size(0)) {
798     return errors::InvalidArgument(
799         "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(),
800         "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
801         "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
802   }
803   TF_RETURN_IF_ERROR(ValidateScatterNdUpdateShape(params_shape, indices.shape(),
804                                                   updates.shape()));
805 
806   // Check that we have enough index space
807   const int64_t N_big = indices.NumElements();
808   if (N_big > std::numeric_limits<Index>::max()) {
809     return errors::InvalidArgument("indices has too many elements for ",
810                                    DataTypeString(DataTypeToEnum<Index>::v()),
811                                    " indexing: ", N_big, " > ",
812                                    std::numeric_limits<Index>::max());
813   }
814   if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
815     return errors::InvalidArgument("params_shape[0] too large for ",
816                                    DataTypeString(DataTypeToEnum<Index>::v()),
817                                    " indexing: ", params_shape.dim_size(0),
818                                    " > ", std::numeric_limits<Index>::max());
819   }
820 
821   // Compute slice_dim: the size of the last dimension of indices (1 if indices is 1-D).
822   *slice_dim = (indices_shape.dims() > 1)
823                    ? indices_shape.dim_size(indices_shape.dims() - 1)
824                    : 1;
825 
826   // Calculate the number of elements that make up each slice of our updated
827   // tensor. This allows us to work with flattened tensors and copy over whole
828   // slices at a time.
829   Index total_nd = params_shape.dims();
830 
831   int64_t slice_size_big = 1;
832   for (int64_t i = *slice_dim; i < total_nd; ++i) {
833     slice_size_big *= params_shape.dim_size(i);
834   }
835 
836   if (slice_size_big > std::numeric_limits<Index>::max()) {
837     return errors::InvalidArgument(
838         "slice size is too large for indexing: ", slice_size_big, " > ",
839         std::numeric_limits<Index>::max());
840   }
841 
842   *slice_size = static_cast<Index>(slice_size_big);
843 
844   const int64_t safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
845   *num_updates = indices_shape.num_elements() / safe_slice_dim;
846 
847   return OkStatus();
848 }
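// Worked example for PrepareAndValidateInputs (shapes are hypothetical):
//   params_shape = [5, 6, 7], indices.shape = [4, 2], updates.shape = [4, 7]
//   => slice_dim = 2 (last dimension of indices),
//      slice_size = 7 (product of params dims from index 2 onward),
//      num_updates = indices.NumElements() / slice_dim = 8 / 2 = 4.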
849 
850 template <typename Device, typename Index>
851 class IndexFlattener {
852  public:
853   inline typename TTypes<Index, 2>::ConstTensor operator()(
854       OpKernelContext*, const Tensor& indices) {
855     return indices.flat_inner_dims<Index>();
856   }
857 };
858 
859 namespace {
860 
861 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
862 
863 // Copies inputs to the CPU, runs DoScatterNd on the CPU, then copies output
864 // back to GPU. This is useful because the CPU implementation is deterministic
865 // and the GPU implementation is not. Tensor inputs to this function must be on
866 // the GPU.
867 template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
868 Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
869                         const Tensor& updates, const TensorShape& shape,
870                         Tensor* out, bool allocate) {
871   AllocatorAttributes alloc_attr;
872   alloc_attr.set_on_host(true);
873   alloc_attr.set_gpu_compatible(true);
874   auto stream = c->op_device_context()->stream();
875 
876   // Copy 'indices' to host.
877   Tensor host_indices;
878   TF_RETURN_IF_ERROR(c->allocate_temp(indices.dtype(), indices.shape(),
879                                       &host_indices, alloc_attr));
880   se::DeviceMemoryBase indices_ptr(
881       const_cast<Tensor&>(indices).flat<Index>().data(),
882       indices.flat<Index>().size() * sizeof(Index));
883   stream->ThenMemcpy(host_indices.flat<Index>().data(), indices_ptr,
884                      indices.NumElements() * sizeof(Index));
885   if (!stream) {
886     return errors::Internal("Failed to copy indices to host");
887   }
888 
889   // Copy 'updates' to host.
890   Tensor host_updates;
891   TF_RETURN_IF_ERROR(c->allocate_temp(updates.dtype(), updates.shape(),
892                                       &host_updates, alloc_attr));
893   se::DeviceMemoryBase updates_ptr(
894       const_cast<Tensor&>(updates).flat<T>().data(),
895       updates.flat<T>().size() * sizeof(T));
896   stream->ThenMemcpy(host_updates.flat<T>().data(), updates_ptr,
897                      updates.NumElements() * sizeof(T));
898   if (!stream) {
899     return errors::Internal("Failed to copy updates to host");
900   }
901 
902   // Create 'out' on host, copying from device if 'allocate' is false.
903   Tensor host_out;
904   TF_RETURN_IF_ERROR(
905       c->allocate_temp(updates.dtype(), shape, &host_out, alloc_attr));
906   if (allocate) {
907     TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
908     functor::SetZeroFunctor<CPUDevice, T> fill;
909     fill(c->eigen_device<CPUDevice>(), host_out.flat<T>());
910   } else {
911     CHECK_NOTNULL(out);  // Crash OK
912     se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
913                                  out->flat<T>().size() * sizeof(T));
914     stream->ThenMemcpy(host_out.flat<T>().data(), out_ptr,
915                        host_out.NumElements() * sizeof(T));
916     if (!stream) {
917       return errors::Internal("Failed to copy output to host");
918     }
919   }
920 
921   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
922   TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
923       c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));
924 
925   // Copy 'host_out' to device.
926   se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
927                                out->flat<T>().size() * sizeof(T));
928   stream->ThenMemcpy(&out_ptr, host_out.flat<T>().data(),
929                      host_out.NumElements() * sizeof(T));
930   if (!stream) {
931     return errors::Internal("Failed to copy output to device");
932   }
933   // Block host, since 'host_out' cannot be destructed until the copy is done.
934   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
935   return OkStatus();
936 }
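// DoScatterNd (below) routes to this CPU fallback whenever the device is a
// GPU and OpDeterminismRequired() is set, trading copy overhead and speed for
// a reproducible accumulation order.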
937 
938 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
939 
940 }  // namespace
941 
942 template <typename Device, typename T, typename Index,
943           scatter_nd_op::UpdateOp Op>
944 Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
945                    const Tensor& updates, const TensorShape& shape, Tensor* out,
946                    bool allocate) {
947 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
948   if (std::is_same<Device, GPUDevice>::value &&
949       tensorflow::OpDeterminismRequired()) {
950     return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
951                                           allocate);
952   }
953 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
954   int64_t slice_dim;
955   Index num_updates;
956   Index slice_size;
957   TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
958       shape, indices, updates, &slice_dim, &num_updates, &slice_size));
959 
960   IndexFlattener<Device, Index> index_flattener;
961   auto indices_flat = index_flattener(c, indices);
962   auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
963 
964   if (allocate) {
965     AllocatorAttributes alloc_attr;
966     if (std::is_same<Device, CPUDevice>::value) {
967       alloc_attr.set_on_host(true);
968     }
969     TF_RETURN_IF_ERROR(
970         c->allocate_temp(DataTypeToEnum<T>::value, shape, out, alloc_attr));
971   } else {
972     CHECK_NOTNULL(out);
973   }
974 
975   if (shape.num_elements() == 0) {
976     return OkStatus();
977   }
978 
979   if (allocate) {
980     // Brand new tensor, zero it out.
981     functor::SetZeroFunctor<Device, T> fill;
982     fill(c->eigen_device<Device>(), out->flat<T>());
983   }
984   auto output_matrix =
985       out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});
986 
987   Index bad_i = -1;
988 
989   if (shape.num_elements() > 0) {
990     switch (slice_dim) {
991 #define PARAMS_CASE(IXDIM)                                                  \
992   case IXDIM: {                                                             \
993     typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
994     for (int i = 0; i < IXDIM; ++i) {                                       \
995       output_shape_prefix[i] = shape.dim_size(i);                           \
996     }                                                                       \
997     functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
998     bad_i =                                                                 \
999         functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
1000                 output_matrix, indices_flat, updates_flat, output_matrix);  \
1001   } break
1002       // TODO(simister): Re-enable this once binary size is under control.
1003       //      PARAMS_CASE(0);
1004       PARAMS_CASE(1);
1005       PARAMS_CASE(2);
1006       PARAMS_CASE(3);
1007       PARAMS_CASE(4);
1008       PARAMS_CASE(5);
1009       PARAMS_CASE(6);
1010       PARAMS_CASE(7);
1011 #undef PARAMS_CASE
1012       default:
1013         return errors::InvalidArgument(
1014             "Only indices.shape[-1] values between 1 and 5 "
1015             "are currently supported.  Requested rank: ",
1016             slice_dim);
1017     }
1018   }
1019   if (bad_i >= 0) {
1020     auto slice_shape = indices.shape();
1021     slice_shape.RemoveLastDims(1);
1022     return errors::InvalidArgument(
1023         "indices", SliceDebugString(slice_shape, bad_i), " = [",
1024         absl::StrJoin(
1025             gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
1026         "] does not index into shape ", shape.DebugString());
1027   }
1028   return OkStatus();
1029 }
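// Flattened view used by DoScatterNd, continuing the illustrative shapes from
// PrepareAndValidateInputs: with shape = [5, 6, 7] and slice_size = 7, the
// output is viewed as a [30, 7] matrix and updates as a [4, 7] matrix; each of
// the 4 flattened index rows selects one of the 30 output rows to combine with
// the matching update row according to Op.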
1030 }  // namespace functor
1031 
1032 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1033 // Forward declarations of the functor specializations for GPU.
1034 namespace functor {
1035 #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)           \
1036   template <>                                                           \
1037   Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(   \
1038       const GPUDevice& d, const Index slice_size,                       \
1039       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
1040       typename TTypes<T, 2>::Tensor Tparams,                            \
1041       typename TTypes<Index, 2>::ConstTensor Tindices,                  \
1042       typename TTypes<T, 2>::ConstTensor Tupdates,                      \
1043       typename TTypes<T, 2>::Tensor Toutput);                           \
1044   extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
1045 
1046 #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
1047   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
1048   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
1049   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
1050   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
1051   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
1052   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
1053   DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);
1054 
1055 #define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
1056   DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
1057   DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
1058   DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)
1059 
1060 #define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index)                     \
1061   DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \
1062   DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX)
1063 
1064 #define DECLARE_GPU_SPECS(T)         \
1065   DECLARE_GPU_SPECS_INDEX(T, int32); \
1066   DECLARE_GPU_SPECS_INDEX(T, int64_t)
1067 
1068 #define DECLARE_GPU_SPECS_MIN_MAX(T)         \
1069   DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \
1070   DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64_t)
1071 
1072 TF_CALL_int32(DECLARE_GPU_SPECS);
1073 TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
1074 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
1075 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
1076 TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);
1077 
1078 #undef DECLARE_GPU_SPECS_MIN_MAX
1079 #undef DECLARE_GPU_SPECS
1080 #undef DECLARE_GPU_SPECS_INDEX_MIN_MAX
1081 #undef DECLARE_GPU_SPECS_INDEX
1082 #undef DECLARE_GPU_SPECS_INDEX_OP
1083 
1084 }  // namespace functor
1085 
1086 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1087 
1088 }  // namespace tensorflow
1089