/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is only made if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations no write will
//   see the reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy, even if the variable's
//   reference count is >1.
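//
// As a rough illustration only (no kernel below calls these helpers),
// the dense read and write paths described above look approximately
// like the following sketch. `ReadSnapshot` and `WriteDense` are
// hypothetical names; `Var`, `tf_shared_lock`, and `mutex_lock` are the
// types actually used throughout this file:
//
//   // Read: shallow-copy the Tensor under a shared lock. The copy bumps
//   // the TensorBuffer's refcount, so the buffer stays alive after the
//   // lock is released.
//   Tensor ReadSnapshot(Var* var) {
//     tf_shared_lock l(*var->mu());
//     return *var->tensor();  // shallow copy; shares the TensorBuffer
//   }
//
//   // Write: under an exclusive lock, deep-copy first if some reader
//   // still holds a reference to the buffer, then mutate in place.
//   void WriteDense(Var* var, const Tensor& value) {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // ... allocate a fresh buffer and copy the old contents ...
//     }
//     // ... mutate *var->tensor() using `value` ...
//   }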

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_RESOURCE_HANDLE_KERNEL(Var);
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);

ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {

Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
  Tensor* output;
  Notification n;
  Status status;
  AllocatorAttributes attr;
  if (t->dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_idx, t->shape(), &output, attr));
  if (t->dtype() == DT_VARIANT) {
    output->flat<Variant>() = t->flat<Variant>();
  } else if (ctx->op_device_context() != nullptr) {
    // TODO(apassos): remove the down_cast by just returning Device* from
    // OpKernelContext
    Device* device = static_cast<Device*>(ctx->device());
    ctx->op_device_context()->CopyTensorInSameDevice(
        t, device, output, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return status;
  } else {
    switch (t->dtype()) {
#define HANDLER(type)                       \
  case DataTypeToEnum<type>::value:         \
    output->flat<type>() = t->flat<type>(); \
    break;
      TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
      default:
        return errors::Internal("Unsupported dtype", t->dtype());
    }
  }
  return Status::OK();
}

}  // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
  Var* variable = nullptr;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  const auto status = LookupResource(ctx, handle, &variable);
  OP_REQUIRES(ctx, status.ok(),
              errors::FailedPrecondition(
                  "Error while reading resource variable ", handle.name(),
                  " from Container: ", handle.container(),
                  ". This could mean that the variable was uninitialized. ",
                  status.ToString()));

  core::ScopedUnref s(variable);
  // We're acquiring a reference to the underlying buffer while
  // holding a shared lock to guarantee ordering of reads and
  // writes.
  tf_shared_lock ml(*variable->mu());
  const Tensor* t = variable->tensor();
  OP_REQUIRES(ctx, dtype_ == t->dtype(),
              errors::InvalidArgument(
                  "Trying to read variable with wrong dtype. Expected ",
                  DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
  if (variable->copy_on_read_mode.load()) {
    OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
  } else {
    ctx->set_output(0, *t);
  }
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
  int n;
  OP_REQUIRES_OK(c, c->GetAttr("N", &n));
  OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(c, n == dtypes_.size(),
              errors::InvalidArgument(
                  "Mismatched number of arguments to ReadVariablesOp (", n,
                  " vs. ", dtypes_.size(), ")"));
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
      dtypes_.size());
  std::vector<const ResourceHandle*> handles(dtypes_.size());
  for (size_t i = 0; i < dtypes_.size(); ++i) {
    handles[i] = &HandleFromInput(ctx, i);
  }

  OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));

  std::vector<string> uninitialized_vars;
  for (int64 i = 0; i < variables.size(); i++) {
    if (variables[i] == nullptr) {
      uninitialized_vars.push_back(handles[i]->name());
    }
  }

  OP_REQUIRES(
      ctx, uninitialized_vars.empty(),
      errors::InvalidArgument("In ReadVariableOp the following variables were "
                              "found uninitialized: ",
                              absl::StrJoin(uninitialized_vars, ", ")));

  for (size_t i = 0; i < dtypes_.size(); ++i) {
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variables[i]->mu());
    OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
                errors::InvalidArgument(
                    "Trying to read variable ", handles[i]->name(),
                    " from Container: ", handles[i]->container(),
                    " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
                    " got ", DataTypeString(variables[i]->tensor()->dtype())));
    if (variables[i]->copy_on_read_mode.load()) {
      OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
    } else {
      const Tensor& t = *variables[i]->tensor();
      ctx->set_output(i, t);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
    ReadVariablesOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          ResourceHandleOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

#endif  // GOOGLE_CUDA

template <typename T>
class VariableShapeOp : public OpKernel {
 public:
  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
    core::ScopedUnref s(variable);
    variable->mu()->lock_shared();
    TensorShape shape = variable->tensor()->shape();
    variable->mu()->unlock_shared();
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
    for (int i = 0; i < shape.dims(); ++i) {
      output->flat<T>()(i) = shape.dim_size(i);
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA

DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}

void DestroyResourceOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& p = HandleFromInput(ctx, 0);
  Status status = DeleteResource(ctx, p);
  if (ignore_lookup_error_ && errors::IsNotFound(status)) {
    return;
  }
  OP_REQUIRES_OK(ctx, status);
}

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    if (!c->GetAttr("_grappler_relax_allocator_constraints",
                    &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    DataTypeString(dtype_), " and ",
                    DataTypeString(context->input(1).dtype())));
    Var* variable = nullptr;
    const Tensor& value = context->input(1);
    // Note: every resource-variable-manipulating op assumes copy-on-write
    // semantics, and creates a copy of the variable's Tensor if its refcount is
    // bigger than 1 when we try to modify it. This means we never need to copy
    // the original tensor for AssignVariableOp; even if there are other live
    // users of it we know none can modify it so this is always safe (even in
    // esoteric cases where the same tensor is used to initialize multiple
    // variables or the tensor is a constant this is safe, as future writes will
    // trigger copies).
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, &value](Var** ptr) {
                                  *ptr = new Var(dtype_);
                                  *(*ptr)->tensor() = value;
                                  (*ptr)->is_initialized = true;
                                  return Status::OK();
                                }));
    core::ScopedUnref s(variable);
    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));
    if (variable->copy_on_read_mode.load()) {
      PersistentTensor unused;
      Tensor* tmp;
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(value.dtype(), value.shape(),
                                                  &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
                   value.flat<T>());
      *variable->tensor() = *tmp;
    } else {
      *variable->tensor() = value;
    }
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
  bool relax_constraints_;
};

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));
    core::ScopedUnref s(variable);

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor: in that case we can just adopt the input tensor's buffer
    // instead. Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU. This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    variable->is_initialized = true;
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      PersistentTensor unused;
      Tensor* tmp;
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(DT_VARIANT, value.shape(),
                                                  &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }

    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));
    core::ScopedUnref s(variable);

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
                errors::InvalidArgument("Cannot update variable with shape ",
                                        var_tensor->shape().DebugString(),
                                        " using a Tensor with shape ",
                                        value.shape().DebugString(),
                                        ", shapes must be equal."));
    OP_REQUIRES_OK(
        context, PrepareToUpdateVariable<Device, T>(
                     context, var_tensor, variable->copy_on_read_mode.load()));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

class VarIsInitializedOp : public OpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    Var* variable = nullptr;
    Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
    if (!s.ok()) {
      output_tensor() = false;
      return;
    }
    core::ScopedUnref su(variable);
    mutex_lock ml(*variable->mu());
    output_tensor() = variable->is_initialized;
  }
};

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 private:
  int32 batch_dims_ = 0;

  // Add the batch offset derived from params to each batch of indices.
  // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]].
  // If indexing into a params dimension of size 4, then the indices will
  // become [0, 1, 2, 4, 5, 6].
  void AddBatchOffsets(Tensor* indices, const Tensor& params) {
    int64 batch_size = 1;  // The size of all batch dimensions.
    for (int idx = 0; idx < batch_dims_; ++idx) {
      batch_size *= params.dim_size(idx);
    }

    auto indices_flat = indices->flat<Index>();
    int64 const index_inner_size = indices->NumElements() / batch_size;
    int64 const batch_offset = params.dim_size(batch_dims_);
    for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
         ++batch_idx) {
      for (int64 idx = 0; idx < index_inner_size; ++idx) {
        indices_flat(dest_idx++) += batch_offset * batch_idx;
      }
    }
  }

 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
  }

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref su(v);
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is params.shape[:batch_dims] +
    // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
    TensorShape result_shape;
    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    Tensor tmp;
    if (params.dtype() == DT_VARIANT) {
      tmp = Tensor(DT_VARIANT, result_shape);
      c->set_output(0, tmp);
      out = &tmp;
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    }

    if (N > 0) {
      Tensor tmp_indices;

      // Points to the original or updated (if batch_dims is set) indices.
      const Tensor* op_indices = &indices;
      if (batch_dims_ > 0) {
        OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
                                           &tmp_indices));
        functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
        copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
                     indices.flat<Index>());

        AddBatchOffsets(&tmp_indices, params);
        op_indices = &tmp_indices;
      }

      int64 gather_dim_size = 1;
      for (int idx = 0; idx <= batch_dims_; ++idx) {
        gather_dim_size *= params.dim_size(idx);
      }
      int64 inner_size = 1;
      for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      const auto indices_flat = op_indices->flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);

// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int64>)

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
    tf_shared_lock ml(*v->mu());
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(N_big);
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      } else {
        int64 num_updates = updates.NumElements();
        OP_REQUIRES(c, num_updates % N == 0,
                    errors::InvalidArgument(
                        "shape of indices (", indices.shape().DebugString(),
                        ") is not compatible with the shape of updates (",
                        updates.shape().DebugString(), ")"));
        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      }
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
                          scatter_op::UpdateOp::SUB);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
                          scatter_op::UpdateOp::MUL);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
                          scatter_op::UpdateOp::DIV);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev)                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
                          scatter_op::UpdateOp::MIN);      \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
                          scatter_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);

REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);

REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)

#endif  // GOOGLE_CUDA

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow