1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17 #include "tensorflow/core/kernels/variable_ops.h"
18
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/types.h"
23
24 namespace tensorflow {
25
26 // Resource stored by variables in the resource manager
27 // (legacy, ref-style version).
28 class LegacyVar : public ResourceBase {
29 public:
LegacyVar(DataType dtype)30 explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
31 // Not copyable or movable.
32 LegacyVar(const LegacyVar&) = delete;
33 LegacyVar& operator=(const LegacyVar&) = delete;
34
mu()35 mutex* mu() { return &mu_; }
tensor()36 Tensor* tensor() { return &tensor_; }
37
DebugString() const38 string DebugString() const override {
39 return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
40 tensor_.shape().DebugString());
41 }
42
43 private:
44 mutex mu_;
45 Tensor tensor_;
46
~LegacyVar()47 ~LegacyVar() override {}
48 };
49
VariableOp(OpKernelConstruction * context)50 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
51 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
52 dtype_ = RemoveRefType(context->output_type(0));
53 }
54
// Looks up (or lazily creates) the LegacyVar resource backing this variable
// and emits a mutable reference to its tensor as output 0.
void VariableOp::Compute(OpKernelContext* ctx) {
  // One-time initialization of the container/name info, guarded so that
  // concurrent invocations of Compute initialize it exactly once.
  mutex_lock l(init_mu_);
  if (!initialized_) {
    OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
                                    true /* use name() */));
    initialized_ = true;
  }
  // Runs only on first lookup. Only the shape is set here; the tensor may
  // remain uninitialized until assigned (note the IsInitialized check below).
  auto creator = [this](LegacyVar** var) {
    *var = new LegacyVar(dtype_);
    (*var)->tensor()->set_shape(shape_);
    return Status::OK();
  };
  LegacyVar* var;
  OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
                          cinfo_.container(), cinfo_.name(), &var, creator));
  // Output a reference to our tensor, so it may be updated.
  //
  // As long as the resource manager hasn't been cleared the ref we return
  // here is valid because it owns a ref on var.
  ctx->set_output_ref(0, var->mu(), var->tensor());
  if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
    ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
  }
  // Drop the ref acquired via LookupOrCreate; the resource manager still
  // holds its own ref, keeping *var alive for the output reference.
  var->Unref();
}
80
// Creates a per-step temporary variable resource and emits a mutable
// reference to its tensor. Intended to be paired with
// DestroyTemporaryVariableOp (below), which removes the resource.
class TemporaryVariableOp : public OpKernel {
 public:
  explicit TemporaryVariableOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
    // Variable name defaults to op name if not specified explicitly.
    if (var_name_.empty()) var_name_ = name();
  }

  void Compute(OpKernelContext* context) override {
    Status s;
    ResourceMgr* rm = context->resource_manager();
    OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
    // tmp_var starts with one ref, owned here until handed to the resource
    // manager below.
    auto* tmp_var = new TmpVar;
    OP_REQUIRES(context, tmp_var,
                errors::ResourceExhausted("Could not allocate TmpVar."));
    tmp_var->name = var_name_;
    s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
    // On allocation failure, release our ref before reporting the error.
    if (!s.ok()) tmp_var->Unref();
    OP_REQUIRES_OK(context, s);
    // Register tmp_var in this step's container. NOTE(review): this assumes
    // Create consumes our ref even on failure (e.g. name collision) —
    // verify against ResourceMgr::Create's ownership contract.
    OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
                                       var_name_, tmp_var));
    context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
    if (context->track_allocations()) {
      context->record_persistent_memory_allocation(
          tmp_var->val.AllocatedBytes());
    }
  }

 private:
  // Refcounted temporary variable resource.
  friend class DestroyTemporaryVariableOp;
  struct TmpVar : public ResourceBase {
    mutex mu;     // paired with val in set_output_ref to guard updates
    Tensor val;   // backing storage for the temporary variable
    string name;  // used only for debug output
    string DebugString() const override { return name; }
    ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
  };

  // Attributes captured at construction time.
  TensorShape shape_;
  DataType dtype_;
  string var_name_;
};
127
128 class DestroyTemporaryVariableOp : public OpKernel {
129 public:
DestroyTemporaryVariableOp(OpKernelConstruction * context)130 explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
131 : OpKernel(context) {
132 OP_REQUIRES(context, IsRefType(context->input_type(0)),
133 errors::InvalidArgument("lhs input needs to be a ref type"));
134 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
135 OP_REQUIRES(context, !var_name_.empty(),
136 errors::InvalidArgument("Missing var_name attribute"));
137 }
138
Compute(OpKernelContext * context)139 void Compute(OpKernelContext* context) override {
140 // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
141 // their execution before this DestroyTemporaryVariable op executes.
142 // This is typically achieved using control dependencies.
143 CHECK(IsRefType(context->input_dtype(0)));
144 Tensor tmpvar = context->mutable_input(0, false);
145 context->set_output(0, tmpvar);
146 ResourceMgr* rm = context->resource_manager();
147 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
148 OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
149 context->step_container()->name(), var_name_));
150 if (context->track_allocations()) {
151 context->record_persistent_memory_allocation(
152 -static_cast<int64>(tmpvar.AllocatedBytes()));
153 }
154 }
155
156 private:
157 string var_name_;
158 };
159
160 class IsVariableInitializedOp : public OpKernel {
161 public:
IsVariableInitializedOp(OpKernelConstruction * context)162 explicit IsVariableInitializedOp(OpKernelConstruction* context)
163 : OpKernel(context) {}
164
Compute(OpKernelContext * context)165 void Compute(OpKernelContext* context) override {
166 // Get a mutable input tensor of the Ref input.
167 const Tensor& input_tensor = context->mutable_input(0, false);
168 Tensor* output = nullptr;
169 OP_REQUIRES_OK(context,
170 context->allocate_output(0, TensorShape({}), &output));
171 auto output_tensor = output->tensor<bool, 0>();
172 bool result = input_tensor.IsInitialized();
173 output_tensor() = result;
174 }
175 };
176
// CPU kernel registrations. Variable and VariableV2 share the same kernel.
REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);
185
#ifdef TENSORFLOW_USE_SYCL
// Registers the variable-related kernels on SYCL devices for each supported
// element type. IsVariableInitialized pins its bool output to host memory.
#define REGISTER_SYCL_KERNEL(type)                                          \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                          \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
      VariableOp);                                                          \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                         \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("dtype"),               \
                          TemporaryVariableOp);                             \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                  \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("T"),                   \
                          DestroyTemporaryVariableOp);                      \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                     \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("dtype")                \
                              .HostMemory("is_initialized"),                \
                          IsVariableInitializedOp);

// NOTE(review): uses the _NO_HALF type list — presumably half is not
// supported by these SYCL kernels; confirm before widening.
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL
211
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc.)
#define REGISTER_GPU_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype"),              \
                          TemporaryVariableOp);                            \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T"),                  \
                          DestroyTemporaryVariableOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype")               \
                              .HostMemory("is_initialized"),               \
                          IsVariableInitializedOp);

// GPU float types plus int64 are covered; see note above about matching
// the 'Assign' kernel's supported types.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA
240
241 } // namespace tensorflow
242