• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 #include "tensorflow/core/kernels/variable_ops.h"
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 
26 // Resource stored by variables in the resource manager
27 // (legacy, ref-style version).
28 class LegacyVar : public ResourceBase {
29  public:
LegacyVar(DataType dtype)30   explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
31   // Not copyable or movable.
32   LegacyVar(const LegacyVar&) = delete;
33   LegacyVar& operator=(const LegacyVar&) = delete;
34 
mu()35   mutex* mu() { return &mu_; }
tensor()36   Tensor* tensor() { return &tensor_; }
37 
DebugString() const38   string DebugString() const override {
39     return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
40                            tensor_.shape().DebugString());
41   }
42 
43  private:
44   mutex mu_;
45   Tensor tensor_;
46 
~LegacyVar()47   ~LegacyVar() override {}
48 };
49 
VariableOp(OpKernelConstruction * context)50 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
51   OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
52   dtype_ = RemoveRefType(context->output_type(0));
53 }
54 
Compute(OpKernelContext * ctx)55 void VariableOp::Compute(OpKernelContext* ctx) {
56   mutex_lock l(init_mu_);
57   if (!initialized_) {
58     OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
59                                     true /* use name() */));
60     initialized_ = true;
61   }
62   auto creator = [this](LegacyVar** var) {
63     *var = new LegacyVar(dtype_);
64     (*var)->tensor()->set_shape(shape_);
65     return Status::OK();
66   };
67   LegacyVar* var;
68   OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
69                           cinfo_.container(), cinfo_.name(), &var, creator));
70   // Output a reference to our tensor, so it may be updated.
71   //
72   // As long as the resource manager hasn't been cleared the ref we return
73   // here is valid because it owns a ref on var.
74   ctx->set_output_ref(0, var->mu(), var->tensor());
75   if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
76     ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
77   }
78   var->Unref();
79 }
80 
81 class TemporaryVariableOp : public OpKernel {
82  public:
TemporaryVariableOp(OpKernelConstruction * context)83   explicit TemporaryVariableOp(OpKernelConstruction* context)
84       : OpKernel(context) {
85     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
86     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
87     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
88     // Variable name defaults to op name if not specified explicitly.
89     if (var_name_.empty()) var_name_ = name();
90   }
91 
Compute(OpKernelContext * context)92   void Compute(OpKernelContext* context) override {
93     Status s;
94     ResourceMgr* rm = context->resource_manager();
95     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
96     auto* tmp_var = new TmpVar;
97     OP_REQUIRES(context, tmp_var,
98                 errors::ResourceExhausted("Could not allocate TmpVar."));
99     tmp_var->name = var_name_;
100     s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
101     if (!s.ok()) tmp_var->Unref();
102     OP_REQUIRES_OK(context, s);
103     OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
104                                        var_name_, tmp_var));
105     context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
106     if (context->track_allocations()) {
107       context->record_persistent_memory_allocation(
108           tmp_var->val.AllocatedBytes());
109     }
110   }
111 
112  private:
113   // Refcounted temporary variable resource.
114   friend class DestroyTemporaryVariableOp;
115   struct TmpVar : public ResourceBase {
116     mutex mu;
117     Tensor val;
118     string name;
DebugStringtensorflow::TemporaryVariableOp::TmpVar119     string DebugString() const override { return name; }
~TmpVartensorflow::TemporaryVariableOp::TmpVar120     ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
121   };
122 
123   TensorShape shape_;
124   DataType dtype_;
125   string var_name_;
126 };
127 
128 class DestroyTemporaryVariableOp : public OpKernel {
129  public:
DestroyTemporaryVariableOp(OpKernelConstruction * context)130   explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
131       : OpKernel(context) {
132     OP_REQUIRES(context, IsRefType(context->input_type(0)),
133                 errors::InvalidArgument("lhs input needs to be a ref type"));
134     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
135     OP_REQUIRES(context, !var_name_.empty(),
136                 errors::InvalidArgument("Missing var_name attribute"));
137   }
138 
Compute(OpKernelContext * context)139   void Compute(OpKernelContext* context) override {
140     // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
141     // their execution before this DestroyTemporaryVariable op executes.
142     // This is typically achieved using control dependencies.
143     CHECK(IsRefType(context->input_dtype(0)));
144     Tensor tmpvar = context->mutable_input(0, false);
145     context->set_output(0, tmpvar);
146     ResourceMgr* rm = context->resource_manager();
147     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
148     OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
149                                 context->step_container()->name(), var_name_));
150     if (context->track_allocations()) {
151       context->record_persistent_memory_allocation(
152           -static_cast<int64>(tmpvar.AllocatedBytes()));
153     }
154   }
155 
156  private:
157   string var_name_;
158 };
159 
160 class IsVariableInitializedOp : public OpKernel {
161  public:
IsVariableInitializedOp(OpKernelConstruction * context)162   explicit IsVariableInitializedOp(OpKernelConstruction* context)
163       : OpKernel(context) {}
164 
Compute(OpKernelContext * context)165   void Compute(OpKernelContext* context) override {
166     // Get a mutable input tensor of the Ref input.
167     const Tensor& input_tensor = context->mutable_input(0, false);
168     Tensor* output = nullptr;
169     OP_REQUIRES_OK(context,
170                    context->allocate_output(0, TensorShape({}), &output));
171     auto output_tensor = output->tensor<bool, 0>();
172     bool result = input_tensor.IsInitialized();
173     output_tensor() = result;
174   }
175 };
176 
177 REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
178 REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
179 REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
180                         TemporaryVariableOp);
181 REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
182                         DestroyTemporaryVariableOp);
183 REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
184                         IsVariableInitializedOp);
185 
#ifdef TENSORFLOW_USE_SYCL
// SYCL kernels, registered once per type listed by
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF (per its name, the GPU number types
// without half). The bool produced by IsVariableInitialized is kept in host
// memory.
#define REGISTER_SYCL_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(Name("Variable")                                    \
                              .Device(DEVICE_SYCL)                            \
                              .TypeConstraint<type>("dtype"),                 \
                          VariableOp);                                        \
  REGISTER_KERNEL_BUILDER(Name("VariableV2")                                  \
                              .Device(DEVICE_SYCL)                            \
                              .TypeConstraint<type>("dtype"),                 \
                          VariableOp);                                        \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                           \
                              .Device(DEVICE_SYCL)                            \
                              .TypeConstraint<type>("dtype"),                 \
                          TemporaryVariableOp);                               \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                    \
                              .Device(DEVICE_SYCL)                            \
                              .TypeConstraint<type>("T"),                     \
                          DestroyTemporaryVariableOp);                        \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                       \
                              .Device(DEVICE_SYCL)                            \
                              .TypeConstraint<type>("dtype")                  \
                              .HostMemory("is_initialized"),                  \
                          IsVariableInitializedOp);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL
211 
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc). One set of kernels is registered per
// type from TF_CALL_GPU_NUMBER_TYPES, plus int64. The bool produced by
// IsVariableInitialized is kept in host memory.
#define REGISTER_GPU_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(Name("Variable")                                    \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<type>("dtype"),                 \
                          VariableOp);                                        \
  REGISTER_KERNEL_BUILDER(Name("VariableV2")                                  \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<type>("dtype"),                 \
                          VariableOp);                                        \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                           \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<type>("dtype"),                 \
                          TemporaryVariableOp);                               \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                    \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<type>("T"),                     \
                          DestroyTemporaryVariableOp);                        \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<type>("dtype")                  \
                              .HostMemory("is_initialized"),                  \
                          IsVariableInitializedOp);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA
240 
241 }  // namespace tensorflow
242