• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Our general strategy for preventing conflicts between concurrent
17 // reads and writes of resource variables is to:
18 // * For read operations, we:
19 //   - acquire the variable's mutex (in "shared" mode);
20 //   - make a (shallow) copy of the Tensor object, which increments
21 //     the reference count on the variable's TensorBuffer;
22 //   - release the variable's mutex;
23 //   - use the copy of the Tensor object to do the read.
24 // * For write operations, we:
25 //   - acquire the variable's mutex (in "exclusive" mode);
26 //   - check the reference count of variable's TensorBuffer and
27 //     if it is >1, make a deep copy of the variable's Tensor;
28 //   - mutate the variable's Tensor;
29 //   - and release the variable's mutex.
30 // This allows several read operations to all use the same
31 // TensorBuffer without needing to copy. When it comes time to write
32 // it will only make a copy if there is an outstanding read using the
33 // buffer. Write operations are serialized by the variable's mutex.
34 //
35 // For sparse operations (scatter, gather, sparse optimizer updates),
36 // we need to avoid copies, since there may not be enough memory for
37 // two copies of the whole tensor. To support this, we make two
38 // modifications to the above strategy:
39 // * For sparse reads (gather), we hold the variable's mutex (still in
40 //   "shared" mode) for the duration of the whole read. This means
41 //   that as long as you only do sparse read operations no write will
42 //   see the reference count >1.
43 // * For sparse write operations where the user explicitly specifies
44 //   that they want to perform the write without locks held
45 //   (use_locking=false), we never copy even if the variable's
46 //   reference count is >1.
47 
48 #define EIGEN_USE_THREADS
49 
50 #if GOOGLE_CUDA
51 #define EIGEN_USE_GPU
52 #endif
53 
54 #include <memory>
55 #include <vector>
56 
57 #include "absl/strings/str_join.h"
58 #include "tensorflow/core/common_runtime/device.h"
59 #include "tensorflow/core/framework/bounds_check.h"
60 #include "tensorflow/core/framework/op_kernel.h"
61 #include "tensorflow/core/framework/register_types.h"
62 #include "tensorflow/core/framework/resource_mgr.h"
63 #include "tensorflow/core/framework/tensor_types.h"
64 #include "tensorflow/core/framework/variant_op_registry.h"
65 #include "tensorflow/core/kernels/dense_update_functor.h"
66 #include "tensorflow/core/kernels/gather_functor.h"
67 #include "tensorflow/core/kernels/resource_variable_ops.h"
68 #include "tensorflow/core/kernels/scatter_functor.h"
69 #include "tensorflow/core/kernels/training_op_helpers.h"
70 #include "tensorflow/core/kernels/variable_ops.h"
71 #include "tensorflow/core/lib/core/errors.h"
72 #include "tensorflow/core/lib/core/refcount.h"
73 #include "tensorflow/core/platform/mem.h"
74 #include "tensorflow/core/platform/mutex.h"
75 #include "tensorflow/core/platform/types.h"
76 #include "tensorflow/core/util/util.h"
77 
78 namespace tensorflow {
79 
// CPU handle kernels for resource variables. REGISTER_RESOURCE_HANDLE_KERNEL
// is defined elsewhere (presumably it registers VarHandleOp for Var on CPU —
// confirm); _VarHandlesOp creates several handles at once.
80 REGISTER_RESOURCE_HANDLE_KERNEL(Var);
81 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
82                         ResourceHandlesOp<Var>);
83 
ReadVariableOp(OpKernelConstruction * c)84 ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
85   OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
86 }
87 
88 namespace {
89 
CopyVariable(int output_idx,OpKernelContext * ctx,const Tensor * t)90 Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
91   Tensor* output;
92   Notification n;
93   Status status;
94   AllocatorAttributes attr;
95   if (t->dtype() == DT_VARIANT) {
96     attr.set_on_host(true);
97   }
98   TF_RETURN_IF_ERROR(
99       ctx->allocate_output(output_idx, t->shape(), &output, attr));
100   if (t->dtype() == DT_VARIANT) {
101     output->flat<Variant>() = t->flat<Variant>();
102   } else if (ctx->op_device_context() != nullptr) {
103     // TODO(apassos): remove the down_cast by just returning Device* from
104     // OpKernelContext
105     Device* device = static_cast<Device*>(ctx->device());
106     ctx->op_device_context()->CopyTensorInSameDevice(
107         t, device, output, [&n, &status](const Status& s) {
108           status = s;
109           n.Notify();
110         });
111     n.WaitForNotification();
112     return status;
113   } else {
114     switch (t->dtype()) {
115 #define HANDLER(type)                       \
116   case DataTypeToEnum<type>::value:         \
117     output->flat<type>() = t->flat<type>(); \
118     break;
119       TF_CALL_ALL_TYPES(HANDLER);
120 #undef HANDLER
121       default:
122         return errors::Internal("Unsupported dtype", t->dtype());
123     }
124   }
125   return Status::OK();
126 }
127 
128 }  // namespace
129 
Compute(OpKernelContext * ctx)130 void ReadVariableOp::Compute(OpKernelContext* ctx) {
131   Var* variable = nullptr;
132   const ResourceHandle& handle = HandleFromInput(ctx, 0);
133   const auto status = LookupResource(ctx, handle, &variable);
134   OP_REQUIRES(ctx, status.ok(),
135               errors::FailedPrecondition(
136                   "Error while reading resource variable ", handle.name(),
137                   " from Container: ", handle.container(),
138                   ". This could mean that the variable was uninitialized. ",
139                   status.ToString()));
140 
141   core::ScopedUnref s(variable);
142   // We're acquiring a reference to the underlying buffer while
143   // holding a shared lock to guarantee ordering of reads and
144   // writes.
145   tf_shared_lock ml(*variable->mu());
146   const Tensor* t = variable->tensor();
147   OP_REQUIRES(ctx, dtype_ == t->dtype(),
148               errors::InvalidArgument(
149                   "Trying to read variable with wrong dtype. Expected ",
150                   DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
151   if (variable->copy_on_read_mode.load()) {
152     OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
153   } else {
154     ctx->set_output(0, *t);
155   }
156 }
157 
ReadVariablesOp(OpKernelConstruction * c)158 ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
159   int n;
160   OP_REQUIRES_OK(c, c->GetAttr("N", &n));
161   OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
162   OP_REQUIRES(c, n == dtypes_.size(),
163               errors::InvalidArgument(
164                   "Mismatched number of arguments to ReadVariablesOp (", n,
165                   " vs. ", dtypes_.size(), ")"));
166 }
167 
Compute(OpKernelContext * ctx)168 void ReadVariablesOp::Compute(OpKernelContext* ctx) {
169   std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
170       dtypes_.size());
171   std::vector<const ResourceHandle*> handles(dtypes_.size());
172   for (size_t i = 0; i < dtypes_.size(); ++i) {
173     handles[i] = &HandleFromInput(ctx, i);
174   }
175 
176   OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
177 
178   std::vector<string> uninitialized_vars;
179   for (int64 i = 0; i < variables.size(); i++) {
180     if (variables[i] == nullptr) {
181       uninitialized_vars.push_back(handles[i]->name());
182     }
183   }
184 
185   OP_REQUIRES(
186       ctx, uninitialized_vars.empty(),
187       errors::InvalidArgument("In ReadVariableOp the following variables were "
188                               "found uninitialized: ",
189                               absl::StrJoin(uninitialized_vars, ", ")));
190 
191   for (size_t i = 0; i < dtypes_.size(); ++i) {
192     // We're acquiring a reference to the underlying buffer while
193     // holding a shared lock to guarantee ordering of reads and
194     // writes.
195     tf_shared_lock ml(*variables[i]->mu());
196     OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
197                 errors::InvalidArgument(
198                     "Trying to read variable ", handles[i]->name(),
199                     " from Container: ", handles[i]->container(),
200                     " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
201                     " got ", DataTypeString(variables[i]->tensor()->dtype())));
202     if (variables[i]->copy_on_read_mode.load()) {
203       OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
204     } else {
205       const Tensor& t = *variables[i]->tensor();
206       ctx->set_output(i, t);
207     }
208   }
209 }
210 
// CPU kernels for reading a single resource variable (ReadVariableOp) and
// several at once (_ReadVariablesOp).
211 REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
212                         ReadVariableOp);
213 REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
214                         ReadVariablesOp);
215 
// GPU registrations for the read kernels and the handle kernels. The
// resource handle inputs/outputs are pinned to host memory.
216 #if GOOGLE_CUDA
217 REGISTER_KERNEL_BUILDER(
218     Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
219     ReadVariableOp);
220 REGISTER_KERNEL_BUILDER(
221     Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
222     ReadVariablesOp);
223 
// For each listed type, this macro does two things:
//  * It declares the GPU DenseUpdate<ASSIGN> specialization. The extern
//    template indicates it is instantiated in another translation unit.
//  * It registers VarHandleOp for that dtype, with the handle in host memory.
224 #define REGISTER_GPU_KERNELS(type)                             \
225   namespace functor {                                          \
226   template <>                                                  \
227   void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
228       const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
229       typename TTypes<type>::ConstFlat rhs);                   \
230   extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
231   }                                                            \
232   REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
233                               .Device(DEVICE_GPU)              \
234                               .HostMemory("resource")          \
235                               .TypeConstraint<type>("dtype"),  \
236                           ResourceHandleOp<Var>)
237 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
238 TF_CALL_int64(REGISTER_GPU_KERNELS);
239 TF_CALL_variant(REGISTER_GPU_KERNELS);
240 #undef REGISTER_GPU_KERNELS
241 
// Multi-handle variant. NOTE(review): this explicit dtype list appears
// intended to match the VarHandleOp registrations above, but the set behind
// TF_CALL_GPU_ALL_TYPES is defined elsewhere — confirm they stay in sync.
242 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
243                             .Device(DEVICE_GPU)
244                             .HostMemory("resources")
245                             .TypeConstraint("dtypes",
246                                             {DT_INT64, DT_COMPLEX64,
247                                              DT_COMPLEX128, DT_HALF, DT_FLOAT,
248                                              DT_DOUBLE, DT_BOOL, DT_VARIANT}),
249                         ResourceHandlesOp<Var>);
250 
251 #endif  // GOOGLE_CUDA
252 
253 template <typename T>
254 class VariableShapeOp : public OpKernel {
255  public:
VariableShapeOp(OpKernelConstruction * c)256   explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
257 
Compute(OpKernelContext * ctx)258   void Compute(OpKernelContext* ctx) override {
259     Var* variable = nullptr;
260     OP_REQUIRES_OK(ctx,
261                    LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
262     core::ScopedUnref s(variable);
263     variable->mu()->lock_shared();
264     TensorShape shape = variable->tensor()->shape();
265     variable->mu()->unlock_shared();
266     Tensor* output;
267     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
268     for (int i = 0; i < shape.dims(); ++i) {
269       output->flat<T>()(i) = shape.dim_size(i);
270     }
271   }
272 };
273 
// CPU VariableShape kernels for int32 and int64 "out_type".
274 REGISTER_KERNEL_BUILDER(
275     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
276     VariableShapeOp<int32>);
277 REGISTER_KERNEL_BUILDER(
278     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
279     VariableShapeOp<int64>);
280 
// GPU VariableShape kernels. Both the resource input and the shape output are
// kept in host memory, so the kernel runs its host-side loop on host data.
281 #if GOOGLE_CUDA
282 
283 REGISTER_KERNEL_BUILDER(Name("VariableShape")
284                             .Device(DEVICE_GPU)
285                             .TypeConstraint<int32>("out_type")
286                             .HostMemory("output")
287                             .HostMemory("input"),
288                         VariableShapeOp<int32>);
289 REGISTER_KERNEL_BUILDER(Name("VariableShape")
290                             .Device(DEVICE_GPU)
291                             .TypeConstraint<int64>("out_type")
292                             .HostMemory("output")
293                             .HostMemory("input"),
294                         VariableShapeOp<int64>);
295 
296 #endif  // GOOGLE_CUDA
297 
DestroyResourceOp(OpKernelConstruction * ctx)298 DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
299     : OpKernel(ctx) {
300   OP_REQUIRES_OK(ctx,
301                  ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
302 }
303 
Compute(OpKernelContext * ctx)304 void DestroyResourceOp::Compute(OpKernelContext* ctx) {
305   const ResourceHandle& p = HandleFromInput(ctx, 0);
306   Status status = DeleteResource(ctx, p);
307   if (ignore_lookup_error_ && errors::IsNotFound(status)) {
308     return;
309   }
310   OP_REQUIRES_OK(ctx, status);
311 }
312 
// DestroyResourceOp on CPU and GPU; on GPU the resource handle is kept in
// host memory. NOTE(review): unlike the other GPU registrations in this file,
// this one is not guarded by #if GOOGLE_CUDA — confirm that is intentional.
313 REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
314                         DestroyResourceOp);
315 REGISTER_KERNEL_BUILDER(
316     Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
317     DestroyResourceOp);
318 
319 template <typename Device, typename T>
320 class AssignVariableOp : public OpKernel {
321  public:
AssignVariableOp(OpKernelConstruction * c)322   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
323     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
324     if (!c->GetAttr("_grappler_relax_allocator_constraints",
325                     &relax_constraints_)
326              .ok()) {
327       relax_constraints_ = false;
328     }
329   }
330 
Compute(OpKernelContext * context)331   void Compute(OpKernelContext* context) override {
332     OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
333                 errors::InvalidArgument(
334                     "Variable and value dtypes don't match; respectively, ",
335                     DataTypeString(dtype_), " and ",
336                     DataTypeString(context->input(1).dtype())));
337     Var* variable = nullptr;
338     const Tensor& value = context->input(1);
339     // Note: every resource-variable-manipulating op assumes copy-on-write
340     // semantics, and creates a copy of the variable's Tensor if its refcount is
341     // bigger than 1 when we try to modify it. This means we never need to copy
342     // the original tensor for AssignVariableOp; even if there are other live
343     // users of it we know none can modify it so this is always safe (even in
344     // esoteric cases where the same tensor is used to initialize multiple
345     // variables or the tensor is a constant this is safe, as future writes will
346     // trigger copies).
347     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
348                                 context, HandleFromInput(context, 0), &variable,
349                                 [this, &value](Var** ptr) {
350                                   *ptr = new Var(dtype_);
351                                   *(*ptr)->tensor() = value;
352                                   (*ptr)->is_initialized = true;
353                                   return Status::OK();
354                                 }));
355     core::ScopedUnref s(variable);
356     mutex_lock ml(*variable->mu());
357     OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
358                 errors::InvalidArgument(
359                     "Trying to assign variable with wrong dtype. Expected ",
360                     DataTypeString(variable->tensor()->dtype()), " got ",
361                     DataTypeString(dtype_)));
362     if (variable->copy_on_read_mode.load()) {
363       PersistentTensor unused;
364       Tensor* tmp;
365       AllocatorAttributes attr;
366       attr.set_gpu_compatible(true);
367       attr.set_nic_compatible(true);
368       OP_REQUIRES_OK(context,
369                      context->allocate_persistent(value.dtype(), value.shape(),
370                                                   &unused, &tmp, attr));
371       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
372       copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
373                    value.flat<T>());
374       *variable->tensor() = *tmp;
375     } else {
376       *variable->tensor() = value;
377     }
378     variable->is_initialized = true;
379   }
380 
381  private:
382   DataType dtype_;
383   bool relax_constraints_;
384 };
385 
// Specialization of AssignVariableOp for DT_VARIANT variables.
//
// The creation callback only constructs an empty DT_VARIANT Var. The value is
// installed later, under the variable's exclusive lock, in one of two ways:
//  * by adopting input 1's buffer, when forward_input() succeeds; or
//  * by copying the Variant elements one at a time into the variable's
//    tensor.
386 template <typename Device>
387 class AssignVariableOp<Device, Variant> : public OpKernel {
388  public:
  // Rejects any "dtype" attribute other than DT_VARIANT with an Internal error.
AssignVariableOp(OpKernelConstruction * c)389   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
390     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
391     OP_REQUIRES(c, dtype_ == DT_VARIANT,
392                 errors::Internal("Variant kernel called with dtype: ",
393                                  DataTypeString(dtype_)));
394   }
395 
Compute(OpKernelContext * context)396   void Compute(OpKernelContext* context) override {
397     const Tensor& value = context->input(1);
398     Var* variable = nullptr;
399     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
400                                 context, HandleFromInput(context, 0), &variable,
401                                 [](Var** ptr) {
402                                   // Created on host.
403                                   *ptr = new Var(DT_VARIANT);
404                                   return Status::OK();
405                                 }));
406     core::ScopedUnref s(variable);
407 
408     // For purposes of forwarding DT_VARIANT, we want the least
409     // restrictive attr; we already know the input is on host.
410     AllocatorAttributes attr;
411 
412     // Copying is unnecessary if we are the last user of the value
413     // tensor, we can just adopt the input tensor's buffer instead.
414     // Note that Variant objects themselves always reside on host.
415     //
416     // We nevertheless want to signal to the runtime that the tensor
417     // should reside in memory of the associated device, as Variant
418     // tensors may be marked as sitting on either CPU or GPU.  This
419     // helps to elide one or more copies.
420     std::unique_ptr<Tensor> input_alias = context->forward_input(
421         1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
422         value.shape(),
423         DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
424         attr);
425 
426     mutex_lock ml(*variable->mu());
427     OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
428                 errors::InvalidArgument(
429                     "Trying to assign variable with wrong dtype. Expected ",
430                     DataTypeString(variable->tensor()->dtype()), " got ",
431                     DataTypeString(DT_VARIANT)));
    // NOTE(review): the flag is set before the allocate_persistent() below,
    // which can still fail through OP_REQUIRES_OK. Confirm that leaving the
    // variable marked initialized on that path is acceptable.
432     variable->is_initialized = true;
    // NOTE(review): the variable's tensor is replaced by a freshly constructed
    // Tensor here, before the "re-use variable's buffer" check below. That
    // check therefore never sees the variable's previous buffer. Confirm this
    // unconditional allocation is intended.
433     *variable->tensor() = Tensor(DT_VARIANT, value.shape());
434 
435     if (input_alias) {
436       *variable->tensor() = *input_alias;
437       return;
438     }
439 
440     // Need to copy, but maybe we can re-use variable's buffer?
441     if (!variable->tensor()->RefCountIsOne() ||
442         !variable->tensor()->shape().IsSameSize(value.shape())) {
443       PersistentTensor unused;
444       Tensor* tmp;
445       // Allocation of DT_VARIANT is always on host.
446       attr.set_on_host(true);
447       OP_REQUIRES_OK(context,
448                      context->allocate_persistent(DT_VARIANT, value.shape(),
449                                                   &unused, &tmp, attr));
450       *variable->tensor() = *tmp;
451     }
452 
    // Element-wise Variant copy from the input into the variable's tensor.
453     const auto elements_in = value.flat<Variant>();
454     auto elements_out = variable->tensor()->flat<Variant>();
455     for (int64 i = 0; i < elements_in.size(); ++i) {
456       elements_out(i) = elements_in(i);
457     }
458   }
459 
460  private:
461   DataType dtype_;
462 };
463 
// CPU AssignVariableOp for all standard and quantized types. For Variant,
// AssignVariableOp<Eigen::ThreadPoolDevice, Variant> selects the Variant
// partial specialization.
464 #define REGISTER_KERNELS(type)                                \
465   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
466                               .Device(DEVICE_CPU)             \
467                               .TypeConstraint<type>("dtype"), \
468                           AssignVariableOp<Eigen::ThreadPoolDevice, type>);
469 
470 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
471 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
472 #undef REGISTER_KERNELS
473 
// GPU AssignVariableOp, with the resource handle in host memory. The
// TF_CALL_variant entry instantiates the Variant partial specialization.
474 #if GOOGLE_CUDA
475 #define REGISTER_GPU_KERNELS(type)                           \
476   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
477                               .Device(DEVICE_GPU)            \
478                               .TypeConstraint<type>("dtype") \
479                               .HostMemory("resource"),       \
480                           AssignVariableOp<GPUDevice, type>);
481 
482 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
483 TF_CALL_int64(REGISTER_GPU_KERNELS);
484 TF_CALL_variant(REGISTER_GPU_KERNELS);
485 #undef REGISTER_GPU_KERNELS
486 #endif  // GOOGLE_CUDA
487 
488 template <typename Device, typename T, DenseUpdateType Op>
489 class AssignUpdateVariableOp : public OpKernel {
490  public:
AssignUpdateVariableOp(OpKernelConstruction * c)491   explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
492 
Compute(OpKernelContext * context)493   void Compute(OpKernelContext* context) override {
494     Var* variable = nullptr;
495     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
496                                            &variable));
497     core::ScopedUnref s(variable);
498 
499     const Tensor& value = context->input(1);
500     // TODO(apassos): We could possibly avoid the copy done by
501     // PrepareToUpdateVariable() for commutative operations like Op ==
502     // ADD if value's refcount was 1.
503     mutex_lock ml(*variable->mu());
504     Tensor* var_tensor = variable->tensor();
505     OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
506                 errors::InvalidArgument("Cannot update variable with shape ",
507                                         var_tensor->shape().DebugString(),
508                                         " using a Tensor with shape ",
509                                         value.shape().DebugString(),
510                                         ", shapes must be equal."));
511     OP_REQUIRES_OK(
512         context, PrepareToUpdateVariable<Device, T>(
513                      context, var_tensor, variable->copy_on_read_mode.load()));
514     functor::DenseUpdate<Device, T, Op> update_functor;
515     update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
516                    value.flat<T>());
517   }
518 };
519 
// CPU AssignAddVariableOp / AssignSubVariableOp for every number type.
520 #define REGISTER_KERNELS(type)                                     \
521   REGISTER_KERNEL_BUILDER(                                         \
522       Name("AssignAddVariableOp")                                  \
523           .Device(DEVICE_CPU)                                      \
524           .TypeConstraint<type>("dtype"),                          \
525       AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
526   REGISTER_KERNEL_BUILDER(                                         \
527       Name("AssignSubVariableOp")                                  \
528           .Device(DEVICE_CPU)                                      \
529           .TypeConstraint<type>("dtype"),                          \
530       AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);
531 
532 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
533 #undef REGISTER_KERNELS
534 
// GPU AssignAddVariableOp / AssignSubVariableOp for GPU number types plus
// int64, with the resource handle in host memory.
535 #if GOOGLE_CUDA
536 #define REGISTER_GPU_KERNELS(type)                                       \
537   REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
538                               .Device(DEVICE_GPU)                        \
539                               .HostMemory("resource")                    \
540                               .TypeConstraint<type>("dtype"),            \
541                           AssignUpdateVariableOp<GPUDevice, type, ADD>); \
542   REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
543                               .Device(DEVICE_GPU)                        \
544                               .HostMemory("resource")                    \
545                               .TypeConstraint<type>("dtype"),            \
546                           AssignUpdateVariableOp<GPUDevice, type, SUB>);
547 
548 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
549 TF_CALL_int64(REGISTER_GPU_KERNELS);
550 #undef REGISTER_GPU_KERNELS
551 #endif  // GOOGLE_CUDA
552 
553 class VarIsInitializedOp : public OpKernel {
554  public:
VarIsInitializedOp(OpKernelConstruction * c)555   explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
556 
Compute(OpKernelContext * context)557   void Compute(OpKernelContext* context) override {
558     Tensor* output = nullptr;
559     OP_REQUIRES_OK(context,
560                    context->allocate_output(0, TensorShape({}), &output));
561     auto output_tensor = output->tensor<bool, 0>();
562     Var* variable = nullptr;
563     Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
564     if (!s.ok()) {
565       output_tensor() = false;
566       return;
567     }
568     core::ScopedUnref su(variable);
569     mutex_lock ml(*variable->mu());
570     output_tensor() = variable->is_initialized;
571   }
572 };
573 
// VarIsInitializedOp registrations. NOTE(review): the GPU kernel is
// IsResourceInitialized<Var> (defined elsewhere), not VarIsInitializedOp.
// Confirm the two report the same thing, e.g. for a variable that exists but
// has is_initialized == false.
574 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
575                         VarIsInitializedOp);
576 
577 #if GOOGLE_CUDA
578 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
579                             .Device(DEVICE_GPU)
580                             .HostMemory("resource")
581                             .HostMemory("is_initialized"),
582                         IsResourceInitialized<Var>);
583 #endif  // GOOGLE_CUDA
584 
585 template <typename Device, typename T, typename Index>
586 class ResourceGatherOp : public OpKernel {
587  private:
588   int32 batch_dims_ = 0;
589 
590   // Add the batch offset derrived from params to each batch of indices.
591   // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
592   // If indexing into a params dimension of size 4, then the indices will become
593   // [0, 1, 2, 4, 5, 6]
AddBatchOffsets(Tensor * indices,const Tensor & params)594   void AddBatchOffsets(Tensor* indices, const Tensor& params) {
595     int64 batch_size = 1;  // The size of all batch dimensions.
596     for (int idx = 0; idx < batch_dims_; ++idx) {
597       batch_size *= params.dim_size(idx);
598     }
599 
600     auto indices_flat = indices->flat<Index>();
601     int64 const index_inner_size = indices->NumElements() / batch_size;
602     int64 const batch_offset = params.dim_size(batch_dims_);
603     for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
604          ++batch_idx) {
605       for (int64 idx = 0; idx < index_inner_size; ++idx) {
606         indices_flat(dest_idx++) += batch_offset * batch_idx;
607       }
608     }
609   }
610 
611  public:
ResourceGatherOp(OpKernelConstruction * c)612   explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
613     OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
614   }
615 
Compute(OpKernelContext * c)616   void Compute(OpKernelContext* c) override {
617     Var* v = nullptr;
618     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
619     core::ScopedUnref su(v);
620     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
621     // NOTE: We hold the lock for the whole gather operation instead
622     // of increasing the reference count of v->tensor() to avoid a
623     // situation where a write to the same variable will see a
624     // reference count greater than one and make a copy of the
625     // (potentially very large) tensor buffer.
626     tf_shared_lock ml(*v->mu());
627     const Tensor& params = *v->tensor();
628     const Tensor& indices = c->input(1);
629     OP_REQUIRES(
630         c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
631         errors::InvalidArgument("params must be at least 1 dimensional"));
632 
633     // Check that we have enough index space
634     const int64 N = indices.NumElements();
635     OP_REQUIRES(
636         c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
637         errors::InvalidArgument("params.shape[0] too large for ",
638                                 DataTypeString(DataTypeToEnum<Index>::v()),
639                                 " indexing: ", params.dim_size(0), " > ",
640                                 std::numeric_limits<Index>::max()));
641 
642     // The result shape is params.shape[:batch_dims] +
643     // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
644     TensorShape result_shape;
645     for (int i = 0; i < batch_dims_; ++i) {
646       result_shape.AddDim(params.dim_size(i));
647     }
648     for (int i = batch_dims_; i < indices.dims(); ++i) {
649       result_shape.AddDim(indices.dim_size(i));
650     }
651     for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
652       result_shape.AddDim(params.dim_size(i));
653     }
654 
655     Tensor* out = nullptr;
656     Tensor tmp;
657     if (params.dtype() == DT_VARIANT) {
658       tmp = Tensor(DT_VARIANT, result_shape);
659       c->set_output(0, tmp);
660       out = &tmp;
661     } else {
662       OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
663     }
664 
665     if (N > 0) {
666       Tensor tmp_indices;
667 
668       // Points to the original or updated (if batch_dims is set) indices.
669       const Tensor* op_indices = &indices;
670       if (batch_dims_ > 0) {
671         OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
672                                            &tmp_indices));
673         functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
674         copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
675                      indices.flat<Index>());
676 
677         AddBatchOffsets(&tmp_indices, params);
678         op_indices = &tmp_indices;
679       }
680 
681       int64 gather_dim_size = 1;
682       for (int idx = 0; idx <= batch_dims_; ++idx) {
683         gather_dim_size *= params.dim_size(idx);
684       }
685       int64 inner_size = 1;
686       for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
687         inner_size *= params.dim_size(i);
688       }
689       auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
690       const auto indices_flat = op_indices->flat<Index>();
691       auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
692 
693       functor::GatherFunctor<Device, T, Index> functor;
694       int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
695 
696       OP_REQUIRES(
697           c, bad_i < 0,
698           errors::InvalidArgument(
699               "indices", SliceDebugString(indices.shape(), bad_i), " = ",
700               indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
701     }
702   }
703 };
704 
705 #define REGISTER_GATHER_FULL(dev, type, index_type)                    \
706   REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
707                               .Device(DEVICE_##dev)                    \
708                               .HostMemory("resource")                  \
709                               .TypeConstraint<type>("dtype")           \
710                               .TypeConstraint<index_type>("Tindices"), \
711                           ResourceGatherOp<dev##Device, type, index_type>)
712 
713 #define REGISTER_GATHER_ALL_INDICES(dev, type) \
714   REGISTER_GATHER_FULL(dev, type, int32);      \
715   REGISTER_GATHER_FULL(dev, type, int64)
716 
717 #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
718 
719 // Registration of the CPU implementations.
720 TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
721 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
722 
723 // Registers GPU kernels.
724 #if GOOGLE_CUDA
725 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
726 
727 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
728 
729 // Variant objects themselves sit on CPU, even if they contain data
730 // pointing to a device.
731 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
732                             .Device(DEVICE_GPU)
733                             .HostMemory("resource")
734                             .HostMemory("indices")
735                             .TypeConstraint<Variant>("dtype")
736                             .TypeConstraint<int32>("Tindices"),
737                         ResourceGatherOp<GPUDevice, Variant, int32>)
738 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
739                             .Device(DEVICE_GPU)
740                             .HostMemory("resource")
741                             .HostMemory("indices")
742                             .TypeConstraint<Variant>("dtype")
743                             .TypeConstraint<int64>("Tindices"),
744                         ResourceGatherOp<GPUDevice, Variant, int64>)
745 
746 #endif  // GOOGLE_CUDA
747 
748 #undef REGISTER_GATHER_CPU
749 #undef REGISTER_GATHER_GPU
750 #undef REGISTER_GATHER_ALL_INDICES
751 #undef REGISTER_GATHER_FULL
752 
753 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
754 class ResourceScatterUpdateOp : public OpKernel {
755  public:
ResourceScatterUpdateOp(OpKernelConstruction * c)756   explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}
757 
Compute(OpKernelContext * c)758   void Compute(OpKernelContext* c) override {
759     Var* v = nullptr;
760     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
761     core::ScopedUnref unref_v(v);
762     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
763     tf_shared_lock ml(*v->mu());
764     Tensor* params = v->tensor();
765     const Tensor& indices = c->input(1);
766     const Tensor& updates = c->input(2);
767 
768     // Check that we have enough index space
769     const int64 N_big = indices.NumElements();
770     OP_REQUIRES(
771         c, N_big <= std::numeric_limits<Index>::max(),
772         errors::InvalidArgument("indices has too many elements for ",
773                                 DataTypeString(DataTypeToEnum<Index>::v()),
774                                 " indexing: ", N_big, " > ",
775                                 std::numeric_limits<Index>::max()));
776     const Index N = static_cast<Index>(N_big);
777     OP_REQUIRES(
778         c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
779         errors::InvalidArgument("params.shape[0] too large for ",
780                                 DataTypeString(DataTypeToEnum<Index>::v()),
781                                 " indexing: ", params->dim_size(0), " > ",
782                                 std::numeric_limits<Index>::max()));
783 
784     if (N > 0) {
785       auto indices_flat = indices.flat<Index>();
786       auto params_flat = params->flat_outer_dims<T>();
787       if (TensorShapeUtils::IsScalar(updates.shape())) {
788         const auto update = updates.scalar<T>();
789 
790         functor::ScatterScalarFunctor<Device, T, Index, op> functor;
791         const Index bad_i = functor(c, c->template eigen_device<Device>(),
792                                     params_flat, update, indices_flat);
793         OP_REQUIRES(c, bad_i < 0,
794                     errors::InvalidArgument(
795                         "indices", SliceDebugString(indices.shape(), bad_i),
796                         " = ", indices_flat(bad_i), " is not in [0, ",
797                         params->dim_size(0), ")"));
798       } else {
799         int64 num_updates = updates.NumElements();
800         OP_REQUIRES(c, num_updates % N == 0,
801                     errors::InvalidArgument(
802                         "shape of indices (", indices.shape().DebugString(),
803                         ") is not compatible with the shape of updates (",
804                         updates.shape().DebugString(), ")"));
805         auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
806 
807         functor::ScatterFunctor<Device, T, Index, op> functor;
808         const Index bad_i = functor(c, c->template eigen_device<Device>(),
809                                     params_flat, updates_flat, indices_flat);
810         OP_REQUIRES(c, bad_i < 0,
811                     errors::InvalidArgument(
812                         "indices", SliceDebugString(indices.shape(), bad_i),
813                         " = ", indices_flat(bad_i), " is not in [0, ",
814                         params->dim_size(0), ")"));
815       }
816     }
817   }
818 };
819 
820 #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
821   REGISTER_KERNEL_BUILDER(                                             \
822       Name(name)                                                       \
823           .Device(DEVICE_##dev)                                        \
824           .HostMemory("resource")                                      \
825           .TypeConstraint<type>("dtype")                               \
826           .TypeConstraint<index_type>("Tindices"),                     \
827       ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
828 
829 #define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
830   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
831   REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
832 
833 #define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
834   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
835                           scatter_op::UpdateOp::ADD);         \
836   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
837                           scatter_op::UpdateOp::SUB);         \
838   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
839                           scatter_op::UpdateOp::MUL);         \
840   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
841                           scatter_op::UpdateOp::DIV);         \
842   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
843                           scatter_op::UpdateOp::ASSIGN);
844 #define REGISTER_SCATTER_MINMAX(type, dev)                 \
845   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
846                           scatter_op::UpdateOp::MIN);      \
847   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
848                           scatter_op::UpdateOp::MAX);
849 
850 // Registers CPU kernels.
851 #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
852   REGISTER_SCATTER_ARITHMETIC(type, CPU);
853 #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
854 
855 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
856 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
857 
858 REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
859                         scatter_op::UpdateOp::ASSIGN);
860 REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
861                         scatter_op::UpdateOp::ASSIGN);
862 REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
863                         scatter_op::UpdateOp::ASSIGN);
864 
865 // Registers GPU kernels.
866 #if GOOGLE_CUDA
867 #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
868   REGISTER_SCATTER_ARITHMETIC(type, GPU);
869 #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
870 
871 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
872 
873 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
874 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
875 
876 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
877                             .Device(DEVICE_GPU)
878                             .HostMemory("resource")
879                             .HostMemory("indices")
880                             .TypeConstraint<Variant>("dtype")
881                             .TypeConstraint<int32>("Tindices"),
882                         ResourceScatterUpdateOp<GPUDevice, Variant, int32,
883                                                 scatter_op::UpdateOp::ASSIGN>)
884 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
885                             .Device(DEVICE_GPU)
886                             .HostMemory("resource")
887                             .TypeConstraint<bool>("dtype")
888                             .TypeConstraint<int32>("Tindices"),
889                         ResourceScatterUpdateOp<GPUDevice, bool, int32,
890                                                 scatter_op::UpdateOp::ASSIGN>)
891 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
892                             .Device(DEVICE_GPU)
893                             .HostMemory("resource")
894                             .HostMemory("indices")
895                             .TypeConstraint<Variant>("dtype")
896                             .TypeConstraint<int64>("Tindices"),
897                         ResourceScatterUpdateOp<GPUDevice, Variant, int64,
898                                                 scatter_op::UpdateOp::ASSIGN>)
899 
900 #endif  // GOOGLE_CUDA
901 
902 #undef REGISTER_SCATTER_ARITHMETIC
903 #undef REGISTER_SCATTER_ARITHMETIC_CPU
904 #undef REGISTER_SCATTER_MINMAX
905 #undef REGISTER_SCATTER_MINMAX_CPU
906 #undef REGISTER_SCATTER_KERNEL
907 #undef REGISTER_SCATTER_KERNEL_INDEX
908 
909 }  // namespace tensorflow
910