/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/scatter_nd_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Returns true if the three tensors have a valid number of elements.
// If shape_input has 0 elements, then indices and updates must also have
// exactly 0 elements; otherwise we should error. If indices has 0 elements,
// then updates must also have 0 elements; otherwise we should error.
bool ValidEmptyOutputShape(int64 num_inputs, int64 num_indices,
                           int64 num_updates) {
  if (num_indices == 0 && num_updates == 0) {
    return true;  // regardless of num_inputs ?= 0, covers both cases
  }
  // now we want all 3 tensors to have values
  return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
}

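// ScatterNdOp implements the "ScatterNd" op: it validates the shapes of
// indices, updates and the requested output shape, then scatters the updates
// into a freshly allocated, zero-initialized tensor of that shape using the
// ADD update mode, so duplicate indices accumulate.
// For example, with shape = [10, 12, 8] and indices of shape [4, 2]
// (outer_dims = 1, ix = 2), updates must have shape [4, 8]: four slices of
// size 8 written at four (row, col) positions.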
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& updates = c->input(1);
    const Tensor& shape_input = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    auto vec = shape_input.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(c,
                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape_input.NumElements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

    const int64 outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
                  errors::InvalidArgument(
                      "Outer dimensions of indices and update must match. "
                      "Indices shape: ",
                      indices.shape().DebugString(),
                      ", updates shape:", updates.shape().DebugString()));
    }

    const int64 ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(
        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
        errors::InvalidArgument("Inner dimensions of output shape must match "
                                "inner dimensions of updates shape. Output: ",
                                shape.DebugString(),
                                " updates: ", updates.shape().DebugString()));
    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument(
              "The inner ", shape.dims() - ix,
              " dimensions of output.shape=", shape.DebugString(),
              " must match the inner ", updates.shape().dims() - outer_dims,
              " dimensions of updates.shape=", updates.shape().DebugString()));
    }
    OP_REQUIRES(c, shape_input.dims() == 1,
                errors::InvalidArgument("Shape must be a vector"));

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
               c, indices, updates, shape, &out, true /*allocate*/));
    c->set_output(0, out);
  }
};

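// TensorScatterOp implements TensorScatterUpdate/TensorScatterAdd/
// TensorScatterSub: the output starts as a copy of the input tensor and the
// updates are then scattered into it with the templated update op. When the
// runtime allows it, the input buffer is forwarded to the output to avoid the
// copy.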
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class TensorScatterOp : public OpKernel {
 public:
  explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(0);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    OP_REQUIRES(c, indices.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Indices shape must have rank at least one. Found:",
                    indices.shape().DebugString()));
    OP_REQUIRES(c, updates.shape().dims() >= 1,
                errors::InvalidArgument(
                    "Updates shape must have rank at least one. Found:",
                    updates.shape().DebugString()));

    TensorShape shape = input.shape();

    OP_REQUIRES(c,
                ValidEmptyOutputShape(shape.num_elements(),
                                      indices.shape().num_elements(),
                                      updates.shape().num_elements()),
                errors::InvalidArgument(
                    "Indices and updates specified for empty output shape"));

    const int64 outer_dims = indices.shape().dims() - 1;

    for (int i = 0; i < outer_dims; ++i) {
      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
                  errors::InvalidArgument(
                      "Outer dimensions of indices and update must match. "
                      "Indices shape: ",
                      indices.shape().DebugString(),
                      ", updates shape:", updates.shape().DebugString()));
    }

    const int64 ix = indices.shape().dim_size(outer_dims);
    OP_REQUIRES(
        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
        errors::InvalidArgument("Inner dimensions of output shape must match "
                                "inner dimensions of updates shape. Output: ",
                                shape.DebugString(),
                                " updates: ", updates.shape().DebugString()));
    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
      OP_REQUIRES(
          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
          errors::InvalidArgument(
              "The inner ", shape.dims() - ix,
              " dimensions of output.shape=", shape.DebugString(),
              " must match the inner ", updates.shape().dims() - outer_dims,
              " dimensions of updates.shape=", updates.shape().DebugString()));
    }

    std::unique_ptr<Tensor> forwarded_input = c->forward_input(
        0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());

    if (forwarded_input == nullptr) {
      // We were not able to forward the input, so we deep copy the tensor and
      // set the output.
      Tensor* out;
      OP_REQUIRES_OK(c, c->allocate_output(0, input.shape(), &out));

      OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
                                                    input, out));
      OP_REQUIRES_OK(c,
                     functor::DoScatterNd<Device, T, Index, op>(
                         c, indices, updates, shape, out, false /*allocate*/));
    } else {
      // Output forwarded, so simply perform the scatter.
      OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
                            c, indices, updates, shape, forwarded_input.get(),
                            false /*allocate*/));

      c->set_output(0, *forwarded_input);
    }
  }
};

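// ScatterNdUpdateOp implements the in-place scatter ops (ScatterNdUpdate,
// ScatterNdAdd, ScatterNdSub, ScatterNdNonAliasingAdd and their Resource*
// variants). The first input may be a resource handle, a ref tensor, or a
// regular tensor; for ref inputs the "use_locking" attr guards the update with
// the input's mutex, and for resource variables the variable mutex is always
// held while the update is applied.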
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    dtype_ = c->input_type(0);
    if (c->input_type(0) == DT_RESOURCE) {
      // TODO(apassos): what to validate here?
    } else if (IsRefType(c->input_type(0))) {
      OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else {
      OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      core::ScopedUnref scoped_unref(v);
      OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
      mutex_lock m(*v->mu());
      DoCompute(c);
    } else if (use_exclusive_lock_) {
      // If we're here, it means the input type is a ref.
      DCHECK(IsRefType(c->input_dtype(0)));
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  DataType dtype_;
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    Tensor params;
    TensorShape params_shape;

    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      Tensor* t = v->tensor();
      params = *t;
      params_shape = params.shape();
    } else if (IsRefType(c->input_dtype(0))) {
      params = c->mutable_input(0, use_exclusive_lock_);
      params_shape = params.shape();
      c->forward_ref_input_to_ref_output(0, 0);
      OP_REQUIRES(c, params.IsInitialized(),
                  errors::FailedPrecondition("Null ref for params"));
    } else {
      Tensor* params_ptr;
      params_shape = c->input(0).shape();
      if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
                                                 &params_ptr)) {
        // We weren't able to forward the input to output, so just
        // allocate a new output tensor and copy the values over.
        OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
        params = *params_ptr;
        functor::DenseUpdate<Device, T, ASSIGN> copy;
        const Tensor& input_copy = c->input(0);
        copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
      } else {
        params = *params_ptr;
      }
    }

    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, op>(
               c, indices, updates, params_shape, &params, false /*allocate*/));
  }
};

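// Kernel registration. The macros below stamp out REGISTER_KERNEL_BUILDER
// calls for each (type, index type, device, op name) combination so that the
// same templated kernels back ScatterNd, ScatterNd{Add,Sub,Update} and their
// Resource* variants.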
#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
  REGISTER_KERNEL_BUILDER(Name(name)                                  \
                              .Device(DEVICE_##dev)                   \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("shape"),                   \
                          ScatterNdOp<dev##Device, type, index_type>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
                                                op)                          \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_##dev)                                              \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
                                                         dev, name, op)    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name(name)                                                           \
          .Device(DEVICE_##dev)                                            \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<index_type>("Tindices")                          \
          .HostMemory("ref"),                                              \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_ND_KERNEL(type, dev, name)         \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
                                                   op);                    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_SCATTER_ND_ADD_SUB(type, dev)                            \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",            \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",            \
                                    scatter_nd_op::UpdateOp::SUB);        \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
      type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);   \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                             \
      type, dev, "ResourceScatterNdSub", scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND(type, dev) \
  REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");

#define REGISTER_SCATTER_ND_UPDATE(type, dev)                         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate",     \
                                    scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                         \
      type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, CPU);

#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, CPU);

#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_string(REGISTER_SCATTER_ND_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
                                                          dev)              \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterUpdate")                       \
                              .Device(DEVICE_##dev)                         \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<index_type>("Tindices"),      \
                          TensorScatterOp<dev##Device, type, index_type,    \
                                          scatter_nd_op::UpdateOp::ASSIGN>)

#define REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterAdd")                            \
                              .Device(DEVICE_##dev)                           \
                              .TypeConstraint<type>("T")                      \
                              .TypeConstraint<index_type>("Tindices"),        \
                          TensorScatterOp<dev##Device, type, index_type,      \
                                          scatter_nd_op::UpdateOp::ADD>)

#define REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, index_type, dev) \
  REGISTER_KERNEL_BUILDER(Name("TensorScatterSub")                            \
                              .Device(DEVICE_##dev)                           \
                              .TypeConstraint<type>("T")                      \
                              .TypeConstraint<index_type>("Tindices"),        \
                          TensorScatterOp<dev##Device, type, index_type,      \
                                          scatter_nd_op::UpdateOp::SUB>)

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, CPU);

#define REGISTER_SCATTER_ND_TENSOR_CPU(type)   \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
  REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type);    \
  REGISTER_SCATTER_ND_TENSOR_SUB_CPU(type);

// Register TensorScatterUpdate/Add/Sub for all number types.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
// Register only TensorScatterUpdate for string/bool types as well.
TF_CALL_string(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);

#undef REGISTER_SCATTER_ND_TENSOR_CPU

// Registers GPU kernels.
#if GOOGLE_CUDA

#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, GPU);

#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, GPU);

#define REGISTER_SCATTER_ND_ALL_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB_GPU(type);  \
  REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
  REGISTER_SCATTER_ND_GPU(type);

TF_CALL_int32(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);

#undef REGISTER_SCATTER_ND_ALL_GPU

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
  REGISTER_SCATTER_ND_UPDATE(type, SYCL);

TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type)                    \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
  REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, GPU);

#define REGISTER_SCATTER_ND_TENSOR_GPU(type)   \
  REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type);    \
  REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
  REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
#undef REGISTER_SCATTER_ND_TENSOR_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_CPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
#undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
#undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU

#endif  // GOOGLE_CUDA

namespace functor {
// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
Status ValidateUpdateShape(const TensorShape& params_shape,
                           const Tensor& indices, const Tensor& updates) {
  const int64 slice_dim =
      (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
  const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "params_shape[slice_dim:], got updates.shape: ",
        updates.shape().DebugString(),
        ", indices.shape: ", indices.shape().DebugString(),
        ", params_shape: ", params_shape.DebugString(),
        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
  };

  if (updates.dims() < batch_dim) return shape_err();
  if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
  }
  for (int d = 0; d < updates.dims() - batch_dim; ++d) {
    if (updates.dim_size(d + batch_dim) !=
        params_shape.dim_size(d + slice_dim)) {
      return shape_err();
    }
  }
  return Status::OK();
}

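// Validates the params/indices/updates triple and precomputes the quantities
// used by DoScatterNd: slice_dim (the size of the last dimension of indices,
// i.e. the index depth), num_updates (how many slices are written) and
// slice_size (the number of elements per slice). Also checks that all sizes
// fit in the Index type.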
template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
                                const Tensor& indices, const Tensor& updates,
                                int64* slice_dim, Index* num_updates,
                                Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
    return errors::InvalidArgument("Output must be at least 1-D, ",
                                   "got shape: ", params_shape.DebugString());
  }

  if (!ValidEmptyOutputShape(params_shape.num_elements(),
                             indices_shape.num_elements(),
                             updates_shape.num_elements())) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output.  indices shape: ",
        indices.shape().DebugString());
  }

  if (updates.dim_size(0) != indices.dim_size(0)) {
    return errors::InvalidArgument(
        "The outermost dimension of updates and indices ",
        "must match. Got indices.shape ", indices_shape.DebugString(),
        ", updates.shape ", updates_shape.DebugString());
  }
  TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));

  // Check that we have enough index space
  const int64 N_big = indices.NumElements();
  if (N_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("indices has too many elements for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", N_big, " > ",
                                   std::numeric_limits<Index>::max());
  }
  if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params_shape[0] too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params_shape.dim_size(0),
                                   " > ", std::numeric_limits<Index>::max());
  }

  // Compute slice_dim: the size of the last dimension of indices (the index
  // depth).
  *slice_dim = (indices_shape.dims() > 1)
                   ? indices_shape.dim_size(indices_shape.dims() - 1)
                   : 1;

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
  Index total_nd = params_shape.dims();

  int64 slice_size_big = 1;
  for (int64 i = *slice_dim; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  *slice_size = static_cast<Index>(slice_size_big);

  const int64 safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
  *num_updates = indices_shape.num_elements() / safe_slice_dim;

  return Status::OK();
}

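// IndexFlattener returns a rank-2 view of the indices tensor with shape
// [num_updates, slice_dim]. The SYCL specialization below first copies the
// indices from device to host memory and wraps the host buffer instead.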
template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
 public:
  IndexFlattener() { indices_host_ = nullptr; }
  ~IndexFlattener() { delete[] indices_host_; }

  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext* c, const Tensor& indices) {
    size_t num_indices = indices.NumElements();
    indices_host_ = new Index[num_indices];
    auto device = c->eigen_sycl_device();
    auto size = sizeof(Index) * num_indices;
    auto src_ptr = GetBase(&indices);
    device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
                              size);
    return typename TTypes<Index, 2>::ConstTensor(
        indices_host_, indices.shape().AsEigenDSizes<2>());
  }

 private:
  Index* indices_host_;
};
#endif

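// DoScatterNd performs the actual scatter: it validates the inputs, optionally
// allocates and zero-fills the output, flattens the output, indices and
// updates to 2-D views, and dispatches to a ScatterNdFunctor specialized on
// the index depth. Returns InvalidArgument if any index is out of range.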
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape, Tensor* out,
                   bool allocate) {
  int64 slice_dim;
  Index num_updates;
  Index slice_size;
  TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
      shape, indices, updates, &slice_dim, &num_updates, &slice_size));

  IndexFlattener<Device, Index> index_flattener;
  auto indices_flat = index_flattener(c, indices);
  auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

  if (allocate) {
    TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
  } else {
    CHECK_NOTNULL(out);
  }

  if (shape.num_elements() == 0) {
    return Status::OK();
  }

  if (allocate) {
    // Brand new tensor, zero it out.
    functor::SetZeroFunctor<Device, T> fill;
    fill(c->eigen_device<Device>(), out->flat<T>());
  }
  auto output_matrix =
      out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});

  Index bad_i = -1;

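  // Dispatch on slice_dim (the compile-time index depth IXDIM) so that
  // ScatterNdFunctor can use a fixed-size Eigen::array for the output shape
  // prefix. The functor returns the position of the first out-of-range index,
  // or -1 if all indices were valid.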
  if (shape.num_elements() > 0) {
    switch (slice_dim) {
#define PARAMS_CASE(IXDIM)                                                  \
  case IXDIM: {                                                             \
    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
    for (int i = 0; i < IXDIM; ++i) {                                       \
      output_shape_prefix[i] = shape.dim_size(i);                           \
    }                                                                       \
    functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
    bad_i =                                                                 \
        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
                output_matrix, indices_flat, updates_flat, output_matrix);  \
  } break
      // TODO(simister): Re-enable this once binary size is under control.
      //      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported.  Requested rank: ",
            slice_dim);
    }
  }
  if (bad_i >= 0) {
    auto slice_shape = indices.shape();
    slice_shape.RemoveLastDims(1);
    return errors::InvalidArgument(
        "indices", SliceDebugString(slice_shape, bad_i), " = [",
        str_util::Join(
            gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
        "] does not index into shape ", shape.DebugString());
  }
  return Status::OK();
}
}  // namespace functor

#ifdef GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)           \
  template <>                                                           \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(   \
      const GPUDevice& d, const Index slice_size,                       \
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
      typename TTypes<T, 2>::Tensor Tparams,                            \
      typename TTypes<Index, 2>::ConstTensor Tindices,                  \
      typename TTypes<T, 2>::ConstTensor Tupdates,                      \
      typename TTypes<T, 2>::Tensor Toutput);                           \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow