/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is as follows:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// it will only make a copy if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations no write will
//   see the reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy even if the variable's
//   reference count is >1.
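//
// As a rough illustration of the dense protocol above (a sketch only; the
// real logic lives in ReadVariableOp, AssignVariableOp, and the training-op
// helpers used below), a read and a write look approximately like:
//
//   // Dense read: shallow-copy the Tensor object under a shared lock, then
//   // read outside the lock. The shallow copy bumps the TensorBuffer's
//   // reference count.
//   Tensor read_view;
//   {
//     tf_shared_lock l(*var->mu());
//     read_view = *var->tensor();
//   }
//   ConsumeForRead(read_view);  // hypothetical consumer
//
//   // Dense write: under an exclusive lock, deep-copy first if a reader
//   // still holds a reference to the buffer, then mutate in place.
//   {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // ... allocate a fresh buffer and copy the old contents into it ...
//     }
//     MutateInPlace(var->tensor());  // hypothetical in-place update
//   }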

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/kernels/resource_variable_ops.h"

#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);

ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {

Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
  Tensor* output;
  Notification n;
  Status status;
  AllocatorAttributes attr;
  if (t->dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_idx, t->shape(), &output, attr));
  if (t->dtype() == DT_VARIANT) {
    output->flat<Variant>() = t->flat<Variant>();
  } else if (ctx->op_device_context() != nullptr) {
    // TODO(apassos): remove the down_cast by just returning Device* from
    // OpKernelContext
    Device* device = down_cast<Device*>(ctx->device());
    ctx->op_device_context()->CopyTensorInSameDevice(
        t, device, output, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return status;
  } else {
    switch (t->dtype()) {
#define HANDLER(type)                       \
  case DataTypeToEnum<type>::value:         \
    output->flat<type>() = t->flat<type>(); \
    break;
      TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
      default:
        return errors::Internal("Unsupported dtype", t->dtype());
    }
  }
  return Status::OK();
}

}  // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
  core::RefCountPtr<Var> variable;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  const auto status = LookupResource(ctx, handle, &variable);
  OP_REQUIRES(ctx, status.ok(),
              errors::FailedPrecondition(
                  "Could not find variable ", handle.name(), ". ",
                  "This could mean that the variable has been deleted. ",
                  "In TF1, it can also mean the variable is uninitialized. ",
                  "Debug info: container=", handle.container(),
                  ", status=", status.ToString()));

  tf_shared_lock ml(*variable->mu());
  // We're acquiring a reference to the underlying buffer while
  // holding a shared lock to guarantee ordering of reads and
  // writes when in copy-on-write mode.
  const Tensor* t = variable->tensor();
  if (!variable->copy_on_read_mode.load()) {
    OP_REQUIRES(
        ctx, dtype_ == t->dtype(),
        errors::InvalidArgument(
            "Trying to read variable with wrong dtype. Expected ",
            DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
    ctx->set_output(0, *t);
  } else {
    OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
  }
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
  int n;
  OP_REQUIRES_OK(c, c->GetAttr("N", &n));
  OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(c, n == dtypes_.size(),
              errors::InvalidArgument(
                  "Mismatched number of arguments to ReadVariablesOp (", n,
                  " vs. ", dtypes_.size(), ")"));
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
  std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
  std::vector<const ResourceHandle*> handles(dtypes_.size());
  for (size_t i = 0; i < dtypes_.size(); ++i) {
    handles[i] = &HandleFromInput(ctx, i);
  }

  OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));

  std::vector<string> uninitialized_vars;
  for (int64_t i = 0; i < variables.size(); i++) {
    if (variables[i] == nullptr) {
      uninitialized_vars.push_back(handles[i]->name());
    }
  }

  OP_REQUIRES(ctx, uninitialized_vars.empty(),
              errors::FailedPrecondition(
                  "In ReadVariablesOp the following variables were "
                  "found uninitialized: ",
                  absl::StrJoin(uninitialized_vars, ", ")));

  for (size_t i = 0; i < dtypes_.size(); ++i) {
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variables[i]->mu());
    OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
                errors::InvalidArgument(
                    "Trying to read variable ", handles[i]->name(),
                    " from Container: ", handles[i]->container(),
                    " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
                    " got ", DataTypeString(variables[i]->tensor()->dtype())));
    if (variables[i]->copy_on_read_mode.load()) {
      OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
    } else {
      const Tensor& t = *variables[i]->tensor();
      ctx->set_output(i, t);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
    ReadVariablesOp);

VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
  OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));

  OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
  OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));

  is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;

  if (!is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
                                                   &resource_, attr));
    resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        context, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
  }
}

void VarHandleOp::Compute(OpKernelContext* ctx) {
  if (is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    Tensor handle;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
    handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        ctx, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
        ctx->stack_trace());
    ctx->set_output(0, handle);
  } else {
    ctx->set_output(0, resource_);
  }
}

REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
    ReadVariablesOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNELS(type)                        \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                 \
                              .Device(DEVICE_DEFAULT)         \
                              .HostMemory("resource")         \
                              .TypeConstraint<type>("dtype"), \
                          VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
#undef REGISTER_DEFAULT_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}

void DestroyResourceOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& p = HandleFromInput(ctx, 0);
  Status status = DeleteResource(ctx, p);
  if (ignore_lookup_error_ && errors::IsNotFound(status)) {
    return;
  }
  OP_REQUIRES_OK(ctx, status);
}

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    if (!c->GetAttr("_grappler_relax_allocator_constraints",
                    &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    DataTypeString(dtype_), " and ",
                    DataTypeString(context->input(1).dtype())));
    core::RefCountPtr<Var> variable;
    const Tensor& value = context->input(1);
    // Note: every resource-variable-manipulating op assumes copy-on-write
    // semantics, and creates a copy of the variable's Tensor if its refcount is
    // bigger than 1 when we try to modify it. This means we never need to copy
    // the original tensor for AssignVariableOp; even if there are other live
    // users of it we know none can modify it so this is always safe (even in
    // esoteric cases where the same tensor is used to initialize multiple
    // variables or the tensor is a constant this is safe, as future writes will
    // trigger copies).
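    // For example, if one input tensor feeds two AssignVariableOp nodes, both
    // variables end up aliasing the same buffer; the first subsequent
    // in-place update (e.g. an AssignAddVariableOp on either variable) sees a
    // refcount > 1 and copies before mutating, so the other variable never
    // observes that write.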
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, &value](Var** ptr) {
                                  *ptr = new Var(dtype_);
                                  *(*ptr)->tensor() = value;
                                  (*ptr)->is_initialized = true;
                                  return Status::OK();
                                }));
    mutex_lock ml(*variable->mu());
    // The (variable->tensor()->dtype() == DT_INVALID &&
    // !variable->is_initialized) check below is to allow an XLA-specific
    // situation wherein the update can happen first by the AssignVariableOp,
    // in which case the variable is still uninitialized.
    // When using TF-XLA, this scenario is possible when the execution uses the
    // 'fallback' path (which essentially invokes Tensorflow ops via
    // partitioned_call).
    OP_REQUIRES(context,
                (variable->tensor()->dtype() == DT_INVALID &&
                 !variable->is_initialized) ||
                    variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));
    if (variable->copy_on_read_mode.load()) {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      OP_REQUIRES_OK(context,
                     context->allocate_temp(value.dtype(), value.shape(),
                                            variable->tensor(), attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(context->eigen_device<Device>(),
                   variable->tensor()->flat<T>(), value.flat<T>());
    } else {
      *variable->tensor() = value;
    }
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
  bool relax_constraints_;
};

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor; in that case we can just adopt the input tensor's buffer.
    // Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU. This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    variable->is_initialized = true;
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
                                                     variable->tensor(), attr));
    }

    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_GPU)             \
                              .TypeConstraint<type>("dtype")  \
                              .HostMemory("resource"),        \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
                errors::InvalidArgument("Cannot update variable with shape ",
                                        var_tensor->shape().DebugString(),
                                        " using a Tensor with shape ",
                                        value.shape().DebugString(),
                                        ", shapes must be equal."));
    OP_REQUIRES_OK(
        context, PrepareToUpdateVariable<Device, T>(
                     context, var_tensor, variable->copy_on_read_mode.load()));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("AssignAddVariableOp")                                    \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("dtype"),                            \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>);   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("AssignSubVariableOp")                                    \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("dtype"),                            \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class VarIsInitializedOp : public OpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    core::RefCountPtr<Var> variable;
    Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
    if (!s.ok()) {
      output_tensor() = false;
      return;
    }
    mutex_lock ml(*variable->mu());
    output_tensor() = variable->is_initialized;
  }
};

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));
    OP_REQUIRES(
        c, params.shape().dims() >= batch_dims_,
        errors::InvalidArgument("params must have at least ", batch_dims_,
                                " (batch_dims) dimensions but it has shape ",
                                params.shape().DebugString()));

    // Check that we have enough index space
    const int64_t N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is params.shape[:batch_dims] +
    // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
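    // For example, with batch_dims = 1, params.shape = [B, P, D] and
    // indices.shape = [B, I], the gathered result has shape [B, I, D].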
    TensorShape result_shape;
    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    Tensor tmp;
    if (params.dtype() == DT_VARIANT) {
      tmp = Tensor(DT_VARIANT, result_shape);
      c->set_output(0, tmp);
      out = &tmp;
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    }

    if (N > 0) {
      Tensor tmp_indices;

      // Points to the original or updated (if batch_dims is set) indices.
      const Tensor* op_indices = &indices;
      if (batch_dims_ > 0) {
        OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
                                           &tmp_indices));
        functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
        copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
                     indices.flat<Index>());

        AddBatchOffsets(c, &tmp_indices, params);
        if (!c->status().ok()) return;
        op_indices = &tmp_indices;
      }

      int64_t gather_dim_size = 1;
      for (int idx = 0; idx <= batch_dims_; ++idx) {
        gather_dim_size *= params.dim_size(idx);
      }
      int64_t inner_size = 1;
      for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      const auto indices_flat = op_indices->flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }

 private:
  // Add the batch offset derived from params to each batch of indices.
  // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
  // If indexing into a params dimension of size 4, then the indices will become
  // [0, 1, 2, 4, 5, 6]
  void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
                       const Tensor& params) {
    int64_t batch_size = 1;  // The size of all batch dimensions.
    for (int idx = 0; idx < batch_dims_; ++idx) {
      batch_size *= params.dim_size(idx);
    }
    OP_REQUIRES(
        ctx, batch_size != 0,
        errors::InvalidArgument(
            "Inner size of indices would result in batch_size of 0 and a ",
            "division by 0 in the implementation. This is illegal"));

    auto indices_flat = indices->flat<Index>();
    int64_t const index_inner_size = indices->NumElements() / batch_size;
    int64_t const batch_offset = params.dim_size(batch_dims_);
    for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
         ++batch_idx) {
      for (int64_t idx = 0; idx < index_inner_size; ++idx) {
        indices_flat(dest_idx++) += batch_offset * batch_idx;
      }
    }
  }

  int32 batch_dims_ = 0;
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_int64(REGISTER_GATHER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);

// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int64>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index>
class ResourceGatherNdOp : public OpKernel {
 public:
  explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd")                     \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_ND_CPU
#undef REGISTER_GATHER_ND_GPU
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

namespace {

template <typename Device>
bool isCPUDevice() {
  return false;
}

template <>
bool isCPUDevice<CPUDevice>() {
  return true;
}

template <typename T>
bool ValidateInput(const Tensor& updates) {
  const auto updates_flat = updates.flat<T>();
  const T zero(0);
  for (int i = 0; i < updates.NumElements(); i++) {
    if (updates_flat(i) == zero) return false;
  }
  return true;
}

template <>
bool ValidateInput<Variant>(const Tensor& updates) {
  return true;
}

}  // namespace

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    // We use the same kernel for many operations.
    // Each operation has a different set of attributes defined in its nodes.
    Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
    if (!s.ok()) {
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
                                  c->input_dtype(0) == DT_STRING ||
                                  c->input_dtype(0) == DT_VARIANT;
    if (is_non_pod_dtype || use_exclusive_lock_) {
      mutex_lock ml(*v->mu());
      DoCompute(c);
    } else {
      // For POD dtypes, we can safely run the update under a shared lock
      // only (no exclusive lock), letting sparse updates proceed concurrently.
      tf_shared_lock ml(*v->mu());
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
    OP_REQUIRES(c,
                updates.dims() == 0 ||
                    updates.dims() == indices.dims() + params->dims() - 1,
                errors::InvalidArgument(
                    "Must have updates.shape = indices.shape + "
                    "params.shape[1:] or updates.shape = [], got ",
                    "updates.shape ", updates.shape().DebugString(),
                    ", indices.shape ", indices.shape().DebugString(),
                    ", params.shape ", params->shape().DebugString()));

    // Check that we have enough index space
    const int64_t N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(N_big);
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // Prevent division by 0
    if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
      OP_REQUIRES(c, ValidateInput<T>(updates),
                  errors::InvalidArgument("updates must not contain 0"));
    }

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      } else {
        int64_t num_updates = updates.NumElements();
        OP_REQUIRES(
            c, TensorShapeUtils::StartsWith(updates.shape(), indices.shape()),
            errors::InvalidArgument(
                "The shape of indices (", indices.shape().DebugString(),
                ") must be a prefix of the shape of updates (",
                updates.shape().DebugString(), ")"));
        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      }
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
                          scatter_op::UpdateOp::SUB);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
                          scatter_op::UpdateOp::MUL);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
                          scatter_op::UpdateOp::DIV);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev)                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
                          scatter_op::UpdateOp::MIN);      \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
                          scatter_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);

REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);

REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<int64>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, int64, int64,
                                                scatter_op::UpdateOp::ASSIGN>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow