1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Our general strategy for preventing conflicts between concurrent
17 // reads and writes of resource variables is to:
18 // * For read operations, we:
19 //   - acquire the variable's mutex (in "shared" mode);
20 //   - make a (shallow) copy of the Tensor object, which increments
21 //     the reference count on the variable's TensorBuffer;
22 //   - release the variable's mutex;
23 //   - use the copy of the Tensor object to do the read.
24 // * For write operations, we:
25 //   - acquire the variable's mutex (in "exclusive" mode);
26 //   - check the reference count of the variable's TensorBuffer and
27 //     if it is >1, make a deep copy of the variable's Tensor;
28 //   - mutate the variable's Tensor;
29 //   - and release the variable's mutex.
30 // This allows several read operations to all use the same
31 // TensorBuffer without needing to copy. When it comes time to write,
32 // a copy is made only if there is an outstanding read still using the
33 // buffer. Write operations are serialized by the variable's mutex.
34 //
35 // For sparse operations (scatter, gather, sparse optimizer updates),
36 // we need to avoid copies, since there may not be enough memory for
37 // two copies of the whole tensor. To support this, we make two
38 // modifications to the above strategy:
39 // * For sparse reads (gather), we hold the variable's mutex (still in
40 //   "shared" mode) for the duration of the whole read. This means
41 //   that, as long as you only do sparse read operations, no write will
42 //   see the reference count >1.
43 // * For sparse write operations where the user explicitly specifies
44 //   that they want to perform the write without locks held
45 //   (use_locking=false), we never copy even if the variable's
46 //   reference count is >1.
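//
// As an illustration for readers (not part of this file's logic), the
// copy-on-write half of the scheme above can be sketched with nothing but
// the standard library. The names mu, buffer, Read, and Write are invented
// for the sketch, and std::shared_ptr stands in for the reference-counted
// TensorBuffer:
//
//   #include <memory>
//   #include <shared_mutex>
//   #include <vector>
//
//   std::shared_mutex mu;                         // the variable's mutex
//   std::shared_ptr<std::vector<float>> buffer =  // the "TensorBuffer"
//       std::make_shared<std::vector<float>>(16, 0.0f);
//
//   // Read: hold the shared lock only long enough to copy the handle,
//   // which bumps the reference count; the actual read happens after the
//   // lock is released.
//   std::shared_ptr<const std::vector<float>> Read() {
//     std::shared_lock<std::shared_mutex> l(mu);
//     return buffer;
//   }
//
//   // Write: hold the exclusive lock; if a reader still holds a
//   // reference, deep-copy the buffer before mutating it (copy-on-write).
//   void Write(int index, float value) {
//     std::unique_lock<std::shared_mutex> l(mu);
//     if (buffer.use_count() > 1) {
//       buffer = std::make_shared<std::vector<float>>(*buffer);
//     }
//     (*buffer)[index] = value;
//   }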
47 
48 #define EIGEN_USE_THREADS
49 
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 #define EIGEN_USE_GPU
52 #endif
53 
54 #include "tensorflow/core/kernels/resource_variable_ops.h"
55 
56 #include <memory>
57 #include <vector>
58 
59 #include "absl/strings/str_join.h"
60 #include "tensorflow/core/common_runtime/device.h"
61 #include "tensorflow/core/framework/bounds_check.h"
62 #include "tensorflow/core/framework/op_kernel.h"
63 #include "tensorflow/core/framework/register_types.h"
64 #include "tensorflow/core/framework/resource_mgr.h"
65 #include "tensorflow/core/framework/tensor_shape.h"
66 #include "tensorflow/core/framework/tensor_types.h"
67 #include "tensorflow/core/framework/variant_op_registry.h"
68 #include "tensorflow/core/kernels/dense_update_functor.h"
69 #include "tensorflow/core/kernels/gather_functor.h"
70 #include "tensorflow/core/kernels/gather_nd_op.h"
71 #include "tensorflow/core/kernels/scatter_functor.h"
72 #include "tensorflow/core/kernels/training_op_helpers.h"
73 #include "tensorflow/core/kernels/variable_ops.h"
74 #include "tensorflow/core/lib/core/errors.h"
75 #include "tensorflow/core/lib/core/refcount.h"
76 #include "tensorflow/core/platform/casts.h"
77 #include "tensorflow/core/platform/mem.h"
78 #include "tensorflow/core/platform/mutex.h"
79 #include "tensorflow/core/platform/types.h"
80 #include "tensorflow/core/util/util.h"
81 
82 namespace tensorflow {
83 
84 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
85                         ResourceHandlesOp<Var>);
86 
87 ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
88   OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
89 }
90 
91 namespace {
92 
93 Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
94   Tensor* output;
95   Notification n;
96   Status status;
97   AllocatorAttributes attr;
98   if (t->dtype() == DT_VARIANT) {
99     attr.set_on_host(true);
100   }
101   TF_RETURN_IF_ERROR(
102       ctx->allocate_output(output_idx, t->shape(), &output, attr));
103   if (t->dtype() == DT_VARIANT) {
104     output->flat<Variant>() = t->flat<Variant>();
105   } else if (ctx->op_device_context() != nullptr) {
106     // TODO(apassos): remove the down_cast by just returning Device* from
107     // OpKernelContext
108     Device* device = down_cast<Device*>(ctx->device());
109     ctx->op_device_context()->CopyTensorInSameDevice(
110         t, device, output, [&n, &status](const Status& s) {
111           status = s;
112           n.Notify();
113         });
114     n.WaitForNotification();
115     return status;
116   } else {
117     switch (t->dtype()) {
118 #define HANDLER(type)                       \
119   case DataTypeToEnum<type>::value:         \
120     output->flat<type>() = t->flat<type>(); \
121     break;
122       TF_CALL_ALL_TYPES(HANDLER);
123 #undef HANDLER
124       default:
125         return errors::Internal("Unsupported dtype ", t->dtype());
126     }
127   }
128   return Status::OK();
129 }
130 
131 }  // namespace
132 
133 void ReadVariableOp::Compute(OpKernelContext* ctx) {
134   core::RefCountPtr<Var> variable;
135   const ResourceHandle& handle = HandleFromInput(ctx, 0);
136   const auto status = LookupResource(ctx, handle, &variable);
137   OP_REQUIRES(ctx, status.ok(),
138               errors::FailedPrecondition(
139                   "Could not find variable ", handle.name(), ". ",
140                   "This could mean that the variable has been deleted. ",
141                   "In TF1, it can also mean the variable is uninitialized. ",
142                   "Debug info: container=", handle.container(),
143                   ", status=", status.ToString()));
144 
145   tf_shared_lock ml(*variable->mu());
146   // We're acquiring a reference to the underlying buffer while
147   // holding a shared lock to guarantee ordering of reads and
148   // writes when in copy-on-write mode.
149   const Tensor* t = variable->tensor();
150   if (!variable->copy_on_read_mode.load()) {
151     OP_REQUIRES(
152         ctx, dtype_ == t->dtype(),
153         errors::InvalidArgument(
154             "Trying to read variable with wrong dtype. Expected ",
155             DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
156     ctx->set_output(0, *t);
157   } else {
158     OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
159   }
160 }
161 
162 ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
163   int n;
164   OP_REQUIRES_OK(c, c->GetAttr("N", &n));
165   OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
166   OP_REQUIRES(c, n == dtypes_.size(),
167               errors::InvalidArgument(
168                   "Mismatched number of arguments to ReadVariablesOp (", n,
169                   " vs. ", dtypes_.size(), ")"));
170 }
171 
172 void ReadVariablesOp::Compute(OpKernelContext* ctx) {
173   std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
174   std::vector<const ResourceHandle*> handles(dtypes_.size());
175   for (size_t i = 0; i < dtypes_.size(); ++i) {
176     handles[i] = &HandleFromInput(ctx, i);
177   }
178 
179   OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
180 
181   std::vector<string> uninitialized_vars;
182   for (int64_t i = 0; i < variables.size(); i++) {
183     if (variables[i] == nullptr) {
184       uninitialized_vars.push_back(handles[i]->name());
185     }
186   }
187 
188   OP_REQUIRES(ctx, uninitialized_vars.empty(),
189               errors::FailedPrecondition(
190                   "In ReadVariablesOp the following variables were "
191                   "found uninitialized: ",
192                   absl::StrJoin(uninitialized_vars, ", ")));
193 
194   for (size_t i = 0; i < dtypes_.size(); ++i) {
195     // We're acquiring a reference to the underlying buffer while
196     // holding a shared lock to guarantee ordering of reads and
197     // writes.
198     tf_shared_lock ml(*variables[i]->mu());
199     OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
200                 errors::InvalidArgument(
201                     "Trying to read variable ", handles[i]->name(),
202                     " from Container: ", handles[i]->container(),
203                     " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
204                     " got ", DataTypeString(variables[i]->tensor()->dtype())));
205     if (variables[i]->copy_on_read_mode.load()) {
206       OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
207     } else {
208       const Tensor& t = *variables[i]->tensor();
209       ctx->set_output(i, t);
210     }
211   }
212 }
213 
214 REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
215                         ReadVariableOp);
216 REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
217                         ReadVariablesOp);
218 
219 REGISTER_KERNEL_BUILDER(
220     Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
221     ReadVariableOp);
222 REGISTER_KERNEL_BUILDER(
223     Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
224     ReadVariablesOp);
225 
226 VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
227   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
228   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
229 
230   OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
231   OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
232 
233   is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
234 
235   if (!is_anonymous_) {
236     AllocatorAttributes attr;
237     attr.set_on_host(true);
238     OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
239                                                    &resource_, attr));
240     resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
241         context, container_, name_,
242         std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
243   }
244 }
245 
246 void VarHandleOp::Compute(OpKernelContext* ctx) {
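  // An anonymous variable gets a fresh handle (including the current stack
  // trace) on every execution; a named variable reuses the handle built
  // once in the constructor.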
247   if (is_anonymous_) {
248     AllocatorAttributes attr;
249     attr.set_on_host(true);
250     Tensor handle;
251     OP_REQUIRES_OK(
252         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
253     handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
254         ctx, container_, name_,
255         std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
256         ctx->stack_trace());
257     ctx->set_output(0, handle);
258   } else {
259     ctx->set_output(0, resource_);
260   }
261 }
262 
263 REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);
264 
265 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
266 REGISTER_KERNEL_BUILDER(
267     Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
268     ReadVariableOp);
269 REGISTER_KERNEL_BUILDER(
270     Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
271     ReadVariablesOp);
272 
273 #define REGISTER_GPU_KERNELS(type)                             \
274   namespace functor {                                          \
275   template <>                                                  \
276   void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
277       const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
278       typename TTypes<type>::ConstFlat rhs);                   \
279   extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
280   }                                                            \
281   REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
282                               .Device(DEVICE_GPU)              \
283                               .HostMemory("resource")          \
284                               .TypeConstraint<type>("dtype"),  \
285                           VarHandleOp)
286 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
287 TF_CALL_int64(REGISTER_GPU_KERNELS);
288 TF_CALL_variant(REGISTER_GPU_KERNELS);
289 TF_CALL_uint32(REGISTER_GPU_KERNELS);
290 #undef REGISTER_GPU_KERNELS
291 
292 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
293                             .Device(DEVICE_GPU)
294                             .HostMemory("resources")
295                             .TypeConstraint("dtypes",
296                                             {DT_INT64, DT_COMPLEX64,
297                                              DT_COMPLEX128, DT_HALF, DT_FLOAT,
298                                              DT_DOUBLE, DT_BOOL, DT_VARIANT}),
299                         ResourceHandlesOp<Var>);
300 
301 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
302 
303 #define REGISTER_DEFAULT_KERNELS(type)                        \
304   REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                 \
305                               .Device(DEVICE_DEFAULT)         \
306                               .HostMemory("resource")         \
307                               .TypeConstraint<type>("dtype"), \
308                           VarHandleOp)
309 TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
310 TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
311 TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
312 TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
313 #undef REGISTER_DEFAULT_KERNELS
314 
315 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
316                             .Device(DEVICE_DEFAULT)
317                             .HostMemory("resources")
318                             .TypeConstraint("dtypes",
319                                             {DT_INT64, DT_COMPLEX64,
320                                              DT_COMPLEX128, DT_HALF, DT_FLOAT,
321                                              DT_DOUBLE, DT_BOOL, DT_VARIANT}),
322                         ResourceHandlesOp<Var>);
323 
324 REGISTER_KERNEL_BUILDER(
325     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
326     VariableShapeOp<int32>);
327 REGISTER_KERNEL_BUILDER(
328     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
329     VariableShapeOp<int64>);
330 
331 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
332 
333 REGISTER_KERNEL_BUILDER(Name("VariableShape")
334                             .Device(DEVICE_GPU)
335                             .TypeConstraint<int32>("out_type")
336                             .HostMemory("output")
337                             .HostMemory("input"),
338                         VariableShapeOp<int32>);
339 REGISTER_KERNEL_BUILDER(Name("VariableShape")
340                             .Device(DEVICE_GPU)
341                             .TypeConstraint<int64>("out_type")
342                             .HostMemory("output")
343                             .HostMemory("input"),
344                         VariableShapeOp<int64>);
345 
346 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
347 
348 DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
349     : OpKernel(ctx) {
350   OP_REQUIRES_OK(ctx,
351                  ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
352 }
353 
354 void DestroyResourceOp::Compute(OpKernelContext* ctx) {
355   const ResourceHandle& p = HandleFromInput(ctx, 0);
356   Status status = DeleteResource(ctx, p);
357   if (ignore_lookup_error_ && errors::IsNotFound(status)) {
358     return;
359   }
360   OP_REQUIRES_OK(ctx, status);
361 }
362 
363 REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
364                         DestroyResourceOp);
365 REGISTER_KERNEL_BUILDER(
366     Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
367     DestroyResourceOp);
368 
369 template <typename Device, typename T>
370 class AssignVariableOp : public OpKernel {
371  public:
372   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
373     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
374     if (!c->GetAttr("_grappler_relax_allocator_constraints",
375                     &relax_constraints_)
376              .ok()) {
377       relax_constraints_ = false;
378     }
379   }
380 
381   void Compute(OpKernelContext* context) override {
382     OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
383                 errors::InvalidArgument(
384                     "Variable and value dtypes don't match; respectively, ",
385                     DataTypeString(dtype_), " and ",
386                     DataTypeString(context->input(1).dtype())));
387     core::RefCountPtr<Var> variable;
388     const Tensor& value = context->input(1);
389     // Note: every resource-variable-manipulating op assumes copy-on-write
390     // semantics, and creates a copy of the variable's Tensor if its refcount is
391     // bigger than 1 when we try to modify it. This means we never need to copy
392     // the original tensor for AssignVariableOp; even if there are other live
393     // users of it we know none can modify it so this is always safe (even in
394     // esoteric cases where the same tensor is used to initialize multiple
395     // variables or the tensor is a constant this is safe, as future writes will
396     // trigger copies).
397     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
398                                 context, HandleFromInput(context, 0), &variable,
399                                 [this, &value](Var** ptr) {
400                                   *ptr = new Var(dtype_);
401                                   *(*ptr)->tensor() = value;
402                                   (*ptr)->is_initialized = true;
403                                   return Status::OK();
404                                 }));
405     mutex_lock ml(*variable->mu());
406     // The (variable->tensor()->dtype() == DT_INVALID &&
407     // !variable->is_initialized) check below allows an XLA-specific
408     // situation in which the first update to the variable happens via this
409     // AssignVariableOp, i.e. the variable is still uninitialized.
410     // When using TF-XLA, this scenario is possible when the execution uses
411     // the 'fallback' path (which essentially invokes TensorFlow ops via
412     // partitioned_call).
413     OP_REQUIRES(context,
414                 (variable->tensor()->dtype() == DT_INVALID &&
415                  !variable->is_initialized) ||
416                     variable->tensor()->dtype() == dtype_,
417                 errors::InvalidArgument(
418                     "Trying to assign variable with wrong dtype. Expected ",
419                     DataTypeString(variable->tensor()->dtype()), " got ",
420                     DataTypeString(dtype_)));
421     if (variable->copy_on_read_mode.load()) {
422       AllocatorAttributes attr;
423       attr.set_gpu_compatible(true);
424       attr.set_nic_compatible(true);
425       OP_REQUIRES_OK(context,
426                      context->allocate_temp(value.dtype(), value.shape(),
427                                             variable->tensor(), attr));
428       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
429       copy_functor(context->eigen_device<Device>(),
430                    variable->tensor()->flat<T>(), value.flat<T>());
431     } else {
432       *variable->tensor() = value;
433     }
434     variable->is_initialized = true;
435   }
436 
437  private:
438   DataType dtype_;
439   bool relax_constraints_;
440 };
441 
442 template <typename Device>
443 class AssignVariableOp<Device, Variant> : public OpKernel {
444  public:
445   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
446     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
447     OP_REQUIRES(c, dtype_ == DT_VARIANT,
448                 errors::Internal("Variant kernel called with dtype: ",
449                                  DataTypeString(dtype_)));
450   }
451 
452   void Compute(OpKernelContext* context) override {
453     const Tensor& value = context->input(1);
454     core::RefCountPtr<Var> variable;
455     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
456                                 context, HandleFromInput(context, 0), &variable,
457                                 [](Var** ptr) {
458                                   // Created on host.
459                                   *ptr = new Var(DT_VARIANT);
460                                   return Status::OK();
461                                 }));
462 
463     // For purposes of forwarding DT_VARIANT, we want the least
464     // restrictive attr; we already know the input is on host.
465     AllocatorAttributes attr;
466 
467     // If we are the last user of the value tensor, copying is unnecessary;
468     // we can just adopt the input tensor's buffer instead.
469     // Note that Variant objects themselves always reside on host.
470     //
471     // We nevertheless want to signal to the runtime that the tensor
472     // should reside in memory of the associated device, as Variant
473     // tensors may be marked as sitting on either CPU or GPU.  This
474     // helps to elide one or more copies.
475     std::unique_ptr<Tensor> input_alias = context->forward_input(
476         1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
477         value.shape(),
478         DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
479         attr);
480 
481     mutex_lock ml(*variable->mu());
482     OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
483                 errors::InvalidArgument(
484                     "Trying to assign variable with wrong dtype. Expected ",
485                     DataTypeString(variable->tensor()->dtype()), " got ",
486                     DataTypeString(DT_VARIANT)));
487     variable->is_initialized = true;
488     *variable->tensor() = Tensor(DT_VARIANT, value.shape());
489 
490     if (input_alias) {
491       *variable->tensor() = *input_alias;
492       return;
493     }
494 
495     // Need to copy, but maybe we can re-use the variable's buffer?
496     if (!variable->tensor()->RefCountIsOne() ||
497         !variable->tensor()->shape().IsSameSize(value.shape())) {
498       // Allocation of DT_VARIANT is always on host.
499       attr.set_on_host(true);
500       OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
501                                                      variable->tensor(), attr));
502     }
503 
504     const auto elements_in = value.flat<Variant>();
505     auto elements_out = variable->tensor()->flat<Variant>();
506     for (int64_t i = 0; i < elements_in.size(); ++i) {
507       elements_out(i) = elements_in(i);
508     }
509   }
510 
511  private:
512   DataType dtype_;
513 };
514 
515 #define REGISTER_KERNELS(type)                                \
516   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
517                               .Device(DEVICE_CPU)             \
518                               .TypeConstraint<type>("dtype"), \
519                           AssignVariableOp<Eigen::ThreadPoolDevice, type>);
520 
521 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
522 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
523 #undef REGISTER_KERNELS
524 
525 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
526 #define REGISTER_GPU_KERNELS(type)                           \
527   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
528                               .Device(DEVICE_GPU)            \
529                               .TypeConstraint<type>("dtype") \
530                               .HostMemory("resource"),       \
531                           AssignVariableOp<GPUDevice, type>);
532 
533 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
534 TF_CALL_int64(REGISTER_GPU_KERNELS);
535 TF_CALL_variant(REGISTER_GPU_KERNELS);
536 TF_CALL_uint32(REGISTER_GPU_KERNELS);
537 #undef REGISTER_GPU_KERNELS
538 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
539 
540 template <typename Device, typename T, DenseUpdateType Op>
541 class AssignUpdateVariableOp : public OpKernel {
542  public:
543   explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
544 
545   void Compute(OpKernelContext* context) override {
546     core::RefCountPtr<Var> variable;
547     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
548                                            &variable));
549 
550     const Tensor& value = context->input(1);
551     // TODO(apassos): We could possibly avoid the copy done by
552     // PrepareToUpdateVariable() for commutative operations like Op ==
553     // ADD if value's refcount was 1.
554     mutex_lock ml(*variable->mu());
555     Tensor* var_tensor = variable->tensor();
556     OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
557                 errors::InvalidArgument("Cannot update variable with shape ",
558                                         var_tensor->shape().DebugString(),
559                                         " using a Tensor with shape ",
560                                         value.shape().DebugString(),
561                                         ", shapes must be equal."));
562     OP_REQUIRES_OK(
563         context, PrepareToUpdateVariable<Device, T>(
564                      context, var_tensor, variable->copy_on_read_mode.load()));
565     functor::DenseUpdate<Device, T, Op> update_functor;
566     update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
567                    value.flat<T>());
568   }
569 };
570 
571 #define REGISTER_KERNELS(type)                                     \
572   REGISTER_KERNEL_BUILDER(                                         \
573       Name("AssignAddVariableOp")                                  \
574           .Device(DEVICE_CPU)                                      \
575           .TypeConstraint<type>("dtype"),                          \
576       AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
577   REGISTER_KERNEL_BUILDER(                                         \
578       Name("AssignSubVariableOp")                                  \
579           .Device(DEVICE_CPU)                                      \
580           .TypeConstraint<type>("dtype"),                          \
581       AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);
582 
583 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
584 #undef REGISTER_KERNELS
585 
586 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
587 #define REGISTER_GPU_KERNELS(type)                                       \
588   REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
589                               .Device(DEVICE_GPU)                        \
590                               .HostMemory("resource")                    \
591                               .TypeConstraint<type>("dtype"),            \
592                           AssignUpdateVariableOp<GPUDevice, type, ADD>); \
593   REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
594                               .Device(DEVICE_GPU)                        \
595                               .HostMemory("resource")                    \
596                               .TypeConstraint<type>("dtype"),            \
597                           AssignUpdateVariableOp<GPUDevice, type, SUB>);
598 
599 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
600 TF_CALL_int64(REGISTER_GPU_KERNELS);
601 #undef REGISTER_GPU_KERNELS
602 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
603 
604 class VarIsInitializedOp : public OpKernel {
605  public:
606   explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
607 
608   void Compute(OpKernelContext* context) override {
609     Tensor* output = nullptr;
610     OP_REQUIRES_OK(context,
611                    context->allocate_output(0, TensorShape({}), &output));
612     auto output_tensor = output->tensor<bool, 0>();
613     core::RefCountPtr<Var> variable;
614     Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
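    // A handle that no longer resolves to a variable is reported as
    // uninitialized rather than as an error.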
615     if (!s.ok()) {
616       output_tensor() = false;
617       return;
618     }
619     mutex_lock ml(*variable->mu());
620     output_tensor() = variable->is_initialized;
621   }
622 };
623 
624 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
625                         VarIsInitializedOp);
626 
627 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
628 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
629                             .Device(DEVICE_GPU)
630                             .HostMemory("resource")
631                             .HostMemory("is_initialized"),
632                         IsResourceInitialized<Var>);
633 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
634 
635 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
636                             .Device(DEVICE_DEFAULT)
637                             .HostMemory("resource")
638                             .HostMemory("is_initialized"),
639                         IsResourceInitialized<Var>);
640 
641 template <typename Device, typename T, typename Index>
642 class ResourceGatherOp : public OpKernel {
643  public:
644   explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
645     OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
646   }
647 
648   void Compute(OpKernelContext* c) override {
649     core::RefCountPtr<Var> v;
650     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
651     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
652     // NOTE: We hold the lock for the whole gather operation instead
653     // of increasing the reference count of v->tensor() to avoid a
654     // situation where a write to the same variable will see a
655     // reference count greater than one and make a copy of the
656     // (potentially very large) tensor buffer.
657     tf_shared_lock ml(*v->mu());
658     const Tensor& params = *v->tensor();
659     const Tensor& indices = c->input(1);
660     OP_REQUIRES(
661         c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
662         errors::InvalidArgument("params must be at least 1 dimensional"));
663     OP_REQUIRES(
664         c, params.shape().dims() >= batch_dims_,
665         errors::InvalidArgument("params must have at least ", batch_dims_,
666                                 " (batch_dims) dimensions but it has shape ",
667                                 params.shape().DebugString()));
668 
669     // Check that we have enough index space
670     const int64_t N = indices.NumElements();
671     OP_REQUIRES(
672         c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
673         errors::InvalidArgument("params.shape[0] too large for ",
674                                 DataTypeString(DataTypeToEnum<Index>::v()),
675                                 " indexing: ", params.dim_size(0), " > ",
676                                 std::numeric_limits<Index>::max()));
677 
678     // The result shape is params.shape[:batch_dims] +
679     // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
680     TensorShape result_shape;
681     for (int i = 0; i < batch_dims_; ++i) {
682       result_shape.AddDim(params.dim_size(i));
683     }
684     for (int i = batch_dims_; i < indices.dims(); ++i) {
685       result_shape.AddDim(indices.dim_size(i));
686     }
687     for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
688       result_shape.AddDim(params.dim_size(i));
689     }
690 
691     Tensor* out = nullptr;
692     Tensor tmp;
693     if (params.dtype() == DT_VARIANT) {
694       tmp = Tensor(DT_VARIANT, result_shape);
695       c->set_output(0, tmp);
696       out = &tmp;
697     } else {
698       OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
699     }
700 
701     if (N > 0) {
702       Tensor tmp_indices;
703 
704       // Points to the original or updated (if batch_dims is set) indices.
705       const Tensor* op_indices = &indices;
706       if (batch_dims_ > 0) {
707         OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
708                                            &tmp_indices));
709         functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
710         copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
711                      indices.flat<Index>());
712 
713         AddBatchOffsets(c, &tmp_indices, params);
714         if (!c->status().ok()) return;
715         op_indices = &tmp_indices;
716       }
717 
718       int64_t gather_dim_size = 1;
719       for (int idx = 0; idx <= batch_dims_; ++idx) {
720         gather_dim_size *= params.dim_size(idx);
721       }
722       int64_t inner_size = 1;
723       for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
724         inner_size *= params.dim_size(i);
725       }
726       auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
727       const auto indices_flat = op_indices->flat<Index>();
728       auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
729 
730       functor::GatherFunctor<Device, T, Index> functor;
731       int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);
732 
733       OP_REQUIRES(
734           c, bad_i < 0,
735           errors::InvalidArgument(
736               "indices", SliceDebugString(indices.shape(), bad_i), " = ",
737               indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
738     }
739   }
740 
741  private:
742   // Add the batch offset derived from params to each batch of indices.
743   // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
744   // If indexing into a params dimension of size 4, then the indices will become
745   // [0, 1, 2, 4, 5, 6]
746   void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
747                        const Tensor& params) {
748     int64_t batch_size = 1;  // The size of all batch dimensions.
749     for (int idx = 0; idx < batch_dims_; ++idx) {
750       batch_size *= params.dim_size(idx);
751     }
752     OP_REQUIRES(
753         ctx, batch_size != 0,
754         errors::InvalidArgument(
755             "Inner size of indices would result in batch_size of 0 and a ",
756             "division by 0 in the implementation. This is illegal"));
757 
758     auto indices_flat = indices->flat<Index>();
759     int64_t const index_inner_size = indices->NumElements() / batch_size;
760     int64_t const batch_offset = params.dim_size(batch_dims_);
761     for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
762          ++batch_idx) {
763       for (int64_t idx = 0; idx < index_inner_size; ++idx) {
764         indices_flat(dest_idx++) += batch_offset * batch_idx;
765       }
766     }
767   }
768 
769   int32 batch_dims_ = 0;
770 };
771 
772 #define REGISTER_GATHER_FULL(dev, type, index_type)                    \
773   REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
774                               .Device(DEVICE_##dev)                    \
775                               .HostMemory("resource")                  \
776                               .TypeConstraint<type>("dtype")           \
777                               .TypeConstraint<index_type>("Tindices"), \
778                           ResourceGatherOp<dev##Device, type, index_type>)
779 
780 #define REGISTER_GATHER_ALL_INDICES(dev, type) \
781   REGISTER_GATHER_FULL(dev, type, int32);      \
782   REGISTER_GATHER_FULL(dev, type, int64)
783 
784 #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
785 
786 // Registration of the CPU implementations.
787 TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
788 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
789 
790 // Registers GPU kernels.
791 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
792 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
793 
794 TF_CALL_int64(REGISTER_GATHER_GPU);
795 TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);
796 
797 // Variant objects themselves sit on CPU, even if they contain data
798 // pointing to a device.
799 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
800                             .Device(DEVICE_GPU)
801                             .HostMemory("resource")
802                             .HostMemory("indices")
803                             .TypeConstraint<Variant>("dtype")
804                             .TypeConstraint<int32>("Tindices"),
805                         ResourceGatherOp<GPUDevice, Variant, int32>)
806 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
807                             .Device(DEVICE_GPU)
808                             .HostMemory("resource")
809                             .HostMemory("indices")
810                             .TypeConstraint<Variant>("dtype")
811                             .TypeConstraint<int64>("Tindices"),
812                         ResourceGatherOp<GPUDevice, Variant, int64>)
813 
814 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
815 
816 #undef REGISTER_GATHER_CPU
817 #undef REGISTER_GATHER_GPU
818 #undef REGISTER_GATHER_ALL_INDICES
819 #undef REGISTER_GATHER_FULL
820 
821 template <typename Device, typename T, typename Index>
822 class ResourceGatherNdOp : public OpKernel {
823  public:
824   explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
825 
826   void Compute(OpKernelContext* c) override {
827     core::RefCountPtr<Var> v;
828     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
829     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
830     // NOTE: We hold the lock for the whole gather operation instead
831     // of increasing the reference count of v->tensor() to avoid a
832     // situation where a write to the same variable will see a
833     // reference count greater than one and make a copy of the
834     // (potentially very large) tensor buffer.
835     tf_shared_lock ml(*v->mu());
836     const Tensor& params = *v->tensor();
837     const Tensor& indices = c->input(1);
838 
839     Tensor out;
840     OP_REQUIRES_OK(
841         c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
842     c->set_output(0, out);
843   }
844 };
845 
846 #define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
847   REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd")                     \
848                               .Device(DEVICE_##dev)                    \
849                               .HostMemory("resource")                  \
850                               .TypeConstraint<type>("dtype")           \
851                               .TypeConstraint<index_type>("Tindices"), \
852                           ResourceGatherNdOp<dev##Device, type, index_type>)
853 
854 #define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
855   REGISTER_GATHER_ND_FULL(dev, type, int32);      \
856   REGISTER_GATHER_ND_FULL(dev, type, int64)
857 
858 #define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
859 
860 // Registration of the CPU implementations.
861 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
862 
863 // Registers GPU kernels.
864 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
865 #define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
866 
867 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
868 
869 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
870 
871 #undef REGISTER_GATHER_ND_CPU
872 #undef REGISTER_GATHER_ND_GPU
873 #undef REGISTER_GATHER_ND_ALL_INDICES
874 #undef REGISTER_GATHER_ND_FULL
875 
876 namespace {
877 
878 template <typename Device>
879 bool isCPUDevice() {
880   return false;
881 }
882 
883 template <>
884 bool isCPUDevice<CPUDevice>() {
885   return true;
886 }
887 
888 template <typename T>
889 bool ValidateInput(const Tensor& updates) {
890   const auto updates_flat = updates.flat<T>();
891   const T zero(0);
892   for (int i = 0; i < updates.NumElements(); i++) {
893     if (updates_flat(i) == zero) return false;
894   }
895   return true;
896 }
897 
898 template <>
899 bool ValidateInput<Variant>(const Tensor& updates) {
900   return true;
901 }
902 
903 }  // namespace
904 
905 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
906 class ResourceScatterUpdateOp : public OpKernel {
907  public:
908   explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
909     // We use the same kernel for many operations.
910     // Each operation has a different set of attributes defined in its nodes.
911     Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
912     if (!s.ok()) {
913       use_exclusive_lock_ = false;
914     }
915   }
916 
917   void Compute(OpKernelContext* c) override {
918     core::RefCountPtr<Var> v;
919     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
920     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
921     const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
922                                   c->input_dtype(0) == DT_STRING ||
923                                   c->input_dtype(0) == DT_VARIANT;
924     if (is_non_pod_dtype || use_exclusive_lock_) {
925       mutex_lock ml(*v->mu());
926       DoCompute(c);
927     } else {
928       // For POD dtypes, we can safely run the update under a shared lock.
929       tf_shared_lock ml(*v->mu());
930       DoCompute(c);
931     }
932   }
933 
934  private:
935   bool use_exclusive_lock_;
936 
937   void DoCompute(OpKernelContext* c) {
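    // Compute() has already done the lookup, applied
    // EnsureSparseVariableAccess, and still holds the variable's mutex; the
    // lookup below simply re-acquires a reference for local use.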
938     core::RefCountPtr<Var> v;
939     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
940     Tensor* params = v->tensor();
941     const Tensor& indices = c->input(1);
942     const Tensor& updates = c->input(2);
943 
944     // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
945     OP_REQUIRES(c,
946                 updates.dims() == 0 ||
947                     updates.dims() == indices.dims() + params->dims() - 1,
948                 errors::InvalidArgument(
949                     "Must have updates.shape = indices.shape + "
950                     "params.shape[1:] or updates.shape = [], got ",
951                     "updates.shape ", updates.shape().DebugString(),
952                     ", indices.shape ", indices.shape().DebugString(),
953                     ", params.shape ", params->shape().DebugString()));
954 
955     // Check that we have enough index space
956     const int64_t N_big = indices.NumElements();
957     OP_REQUIRES(
958         c, N_big <= std::numeric_limits<Index>::max(),
959         errors::InvalidArgument("indices has too many elements for ",
960                                 DataTypeString(DataTypeToEnum<Index>::v()),
961                                 " indexing: ", N_big, " > ",
962                                 std::numeric_limits<Index>::max()));
963     const Index N = static_cast<Index>(N_big);
964     OP_REQUIRES(
965         c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
966         errors::InvalidArgument("params.shape[0] too large for ",
967                                 DataTypeString(DataTypeToEnum<Index>::v()),
968                                 " indexing: ", params->dim_size(0), " > ",
969                                 std::numeric_limits<Index>::max()));
970 
971     // Prevent division by 0
972     if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
973       OP_REQUIRES(c, ValidateInput<T>(updates),
974                   errors::InvalidArgument("updates must not contain 0"));
975     }
976 
977     if (N > 0) {
978       auto indices_flat = indices.flat<Index>();
979       auto params_flat = params->flat_outer_dims<T>();
980       if (TensorShapeUtils::IsScalar(updates.shape())) {
981         const auto update = updates.scalar<T>();
982 
983         functor::ScatterScalarFunctor<Device, T, Index, op> functor;
984         const Index bad_i = functor(c, c->template eigen_device<Device>(),
985                                     params_flat, update, indices_flat);
986         OP_REQUIRES(c, bad_i < 0,
987                     errors::InvalidArgument(
988                         "indices", SliceDebugString(indices.shape(), bad_i),
989                         " = ", indices_flat(bad_i), " is not in [0, ",
990                         params->dim_size(0), ")"));
991       } else {
992         int64_t num_updates = updates.NumElements();
993         OP_REQUIRES(
994             c, TensorShapeUtils::StartsWith(updates.shape(), indices.shape()),
995             errors::InvalidArgument(
996                 "The shape of indices (", indices.shape().DebugString(),
997                 ") must be a prefix of the shape of updates (",
998                 updates.shape().DebugString(), ")"));
999         auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
1000 
1001         functor::ScatterFunctor<Device, T, Index, op> functor;
1002         const Index bad_i = functor(c, c->template eigen_device<Device>(),
1003                                     params_flat, updates_flat, indices_flat);
1004         OP_REQUIRES(c, bad_i < 0,
1005                     errors::InvalidArgument(
1006                         "indices", SliceDebugString(indices.shape(), bad_i),
1007                         " = ", indices_flat(bad_i), " is not in [0, ",
1008                         params->dim_size(0), ")"));
1009       }
1010     }
1011   }
1012 };
1013 
1014 #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
1015   REGISTER_KERNEL_BUILDER(                                             \
1016       Name(name)                                                       \
1017           .Device(DEVICE_##dev)                                        \
1018           .HostMemory("resource")                                      \
1019           .TypeConstraint<type>("dtype")                               \
1020           .TypeConstraint<index_type>("Tindices"),                     \
1021       ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
1022 
1023 #define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
1024   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
1025   REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
1026 
1027 #define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
1028   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
1029                           scatter_op::UpdateOp::ADD);         \
1030   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
1031                           scatter_op::UpdateOp::SUB);         \
1032   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
1033                           scatter_op::UpdateOp::MUL);         \
1034   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
1035                           scatter_op::UpdateOp::DIV);         \
1036   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
1037                           scatter_op::UpdateOp::ASSIGN);
1038 #define REGISTER_SCATTER_MINMAX(type, dev)                 \
1039   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
1040                           scatter_op::UpdateOp::MIN);      \
1041   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
1042                           scatter_op::UpdateOp::MAX);
1043 
1044 // Registers CPU kernels.
1045 #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
1046   REGISTER_SCATTER_ARITHMETIC(type, CPU);
1047 #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
1048 
1049 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
1050 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
1051 
1052 REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
1053                         scatter_op::UpdateOp::ASSIGN);
1054 REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
1055                         scatter_op::UpdateOp::ASSIGN);
1056 REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
1057                         scatter_op::UpdateOp::ASSIGN);
1058 
1059 // Registers GPU kernels.
1060 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1061 #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
1062   REGISTER_SCATTER_ARITHMETIC(type, GPU);
1063 #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
1064 
1065 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
1066 
1067 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
1068 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
1069 
1070 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1071                             .Device(DEVICE_GPU)
1072                             .HostMemory("resource")
1073                             .HostMemory("indices")
1074                             .TypeConstraint<Variant>("dtype")
1075                             .TypeConstraint<int32>("Tindices"),
1076                         ResourceScatterUpdateOp<GPUDevice, Variant, int32,
1077                                                 scatter_op::UpdateOp::ASSIGN>)
1078 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1079                             .Device(DEVICE_GPU)
1080                             .HostMemory("resource")
1081                             .TypeConstraint<bool>("dtype")
1082                             .TypeConstraint<int32>("Tindices"),
1083                         ResourceScatterUpdateOp<GPUDevice, bool, int32,
1084                                                 scatter_op::UpdateOp::ASSIGN>)
1085 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1086                             .Device(DEVICE_GPU)
1087                             .HostMemory("resource")
1088                             .HostMemory("indices")
1089                             .TypeConstraint<Variant>("dtype")
1090                             .TypeConstraint<int64>("Tindices"),
1091                         ResourceScatterUpdateOp<GPUDevice, Variant, int64,
1092                                                 scatter_op::UpdateOp::ASSIGN>)
1093 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1094                             .Device(DEVICE_GPU)
1095                             .HostMemory("resource")
1096                             .TypeConstraint<int64>("dtype")
1097                             .TypeConstraint<int64>("Tindices"),
1098                         ResourceScatterUpdateOp<GPUDevice, int64, int64,
1099                                                 scatter_op::UpdateOp::ASSIGN>)
1100 
1101 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1102 
1103 #undef REGISTER_SCATTER_ARITHMETIC
1104 #undef REGISTER_SCATTER_ARITHMETIC_CPU
1105 #undef REGISTER_SCATTER_MINMAX
1106 #undef REGISTER_SCATTER_MINMAX_CPU
1107 #undef REGISTER_SCATTER_KERNEL
1108 #undef REGISTER_SCATTER_KERNEL_INDEX
1109 
1110 }  // namespace tensorflow
1111