1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17 #include "tensorflow/core/kernels/variable_ops.h"
18
19 #include "tensorflow/core/framework/control_flow.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/strcat.h"
24 #include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27
28 namespace {
29
30 // Makes a unique name for a temporary variable inside a while loop body,
31 // because loop can be executed in multiple iterations in parallel.
TemporaryVariableName(const string & var_name,const FrameAndIter & control_frame)32 string TemporaryVariableName(const string& var_name,
33 const FrameAndIter& control_frame) {
34 if (control_frame.frame_id != kIllegalFrameId &&
35 control_frame.iter_id != kIllegalIterId) {
36 return strings::StrCat(var_name, "/frame:", control_frame.frame_id,
37 "/iter:", control_frame.iter_id);
38 }
39 return var_name;
40 }
41
42 } // namespace
43
44 // Resource stored by variables in the resource manager
45 // (legacy, ref-style version).
46 class LegacyVar : public ResourceBase {
47 public:
LegacyVar(DataType dtype)48 explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
49 // Not copyable or movable.
50 LegacyVar(const LegacyVar&) = delete;
51 LegacyVar& operator=(const LegacyVar&) = delete;
52
mu()53 mutex* mu() { return &mu_; }
tensor()54 Tensor* tensor() { return &tensor_; }
55
DebugString() const56 string DebugString() const override {
57 return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
58 tensor_.shape().DebugString());
59 }
60
61 private:
62 mutex mu_;
63 Tensor tensor_;
64
~LegacyVar()65 ~LegacyVar() override {}
66 };
67
VariableOp(OpKernelConstruction * context)68 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
69 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
70 dtype_ = RemoveRefType(context->output_type(0));
71 OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
72 true /* use name() */));
73 }
74
Compute(OpKernelContext * ctx)75 void VariableOp::Compute(OpKernelContext* ctx) {
76 auto creator = [this](LegacyVar** var) {
77 *var = new LegacyVar(dtype_);
78 (*var)->tensor()->set_shape(shape_);
79 return Status::OK();
80 };
81 LegacyVar* var;
82 OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
83 cinfo_.container(), cinfo_.name(), &var, creator));
84 // Output a reference to our tensor, so it may be updated.
85 //
86 // As long as the resource manager hasn't been cleared the ref we return
87 // here is valid because it owns a ref on var.
88 ctx->set_output_ref(0, var->mu(), var->tensor());
89 if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
90 ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
91 }
92 var->Unref();
93 }
94
95 class TemporaryVariableOp : public OpKernel {
96 public:
TemporaryVariableOp(OpKernelConstruction * context)97 explicit TemporaryVariableOp(OpKernelConstruction* context)
98 : OpKernel(context) {
99 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
100 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
101 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
102 // Variable name defaults to op name if not specified explicitly.
103 if (var_name_.empty()) var_name_ = name();
104 }
105
Compute(OpKernelContext * context)106 void Compute(OpKernelContext* context) override {
107 Status s;
108 ResourceMgr* rm = context->resource_manager();
109 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
110 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
111 auto* tmp_var = new TmpVar;
112 OP_REQUIRES(context, tmp_var,
113 errors::ResourceExhausted("Could not allocate TmpVar."));
114 tmp_var->name = unique_name;
115 s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
116 if (!s.ok()) tmp_var->Unref();
117 OP_REQUIRES_OK(context, s);
118 OP_REQUIRES_OK(context,
119 context->step_container()->Create(rm, unique_name, tmp_var));
120 context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
121 if (context->track_allocations()) {
122 context->record_persistent_memory_allocation(
123 tmp_var->val.AllocatedBytes());
124 }
125 }
126
127 private:
128 // Refcounted temporary variable resource.
129 friend class DestroyTemporaryVariableOp;
130 struct TmpVar : public ResourceBase {
131 mutex mu;
132 Tensor val;
133 string name;
DebugStringtensorflow::TemporaryVariableOp::TmpVar134 string DebugString() const override { return name; }
~TmpVartensorflow::TemporaryVariableOp::TmpVar135 ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
136 };
137
138 TensorShape shape_;
139 DataType dtype_;
140 string var_name_;
141 };
142
143 class DestroyTemporaryVariableOp : public OpKernel {
144 public:
DestroyTemporaryVariableOp(OpKernelConstruction * context)145 explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
146 : OpKernel(context) {
147 OP_REQUIRES(context, IsRefType(context->input_type(0)),
148 errors::InvalidArgument("lhs input needs to be a ref type"));
149 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
150 OP_REQUIRES(context, !var_name_.empty(),
151 errors::InvalidArgument("Missing var_name attribute"));
152 }
153
Compute(OpKernelContext * context)154 void Compute(OpKernelContext* context) override {
155 // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
156 // their execution before this DestroyTemporaryVariable op executes.
157 // This is typically achieved using control dependencies.
158 CHECK(IsRefType(context->input_dtype(0)));
159 Tensor tmpvar = context->mutable_input(0, false);
160 context->set_output(0, tmpvar);
161 ResourceMgr* rm = context->resource_manager();
162 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
163 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
164 OP_REQUIRES_OK(
165 context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
166 rm, unique_name));
167 if (context->track_allocations()) {
168 context->record_persistent_memory_allocation(
169 -static_cast<int64>(tmpvar.AllocatedBytes()));
170 }
171 }
172
173 private:
174 string var_name_;
175 };
176
177 class IsVariableInitializedOp : public OpKernel {
178 public:
IsVariableInitializedOp(OpKernelConstruction * context)179 explicit IsVariableInitializedOp(OpKernelConstruction* context)
180 : OpKernel(context) {}
181
Compute(OpKernelContext * context)182 void Compute(OpKernelContext* context) override {
183 // Get a mutable input tensor of the Ref input.
184 const Tensor& input_tensor = context->mutable_input(0, false);
185 Tensor* output = nullptr;
186 OP_REQUIRES_OK(context,
187 context->allocate_output(0, TensorShape({}), &output));
188 auto output_tensor = output->tensor<bool, 0>();
189 bool result = input_tensor.IsInitialized();
190 output_tensor() = result;
191 }
192 };
193
// CPU kernels. No TypeConstraint is given, so these registrations are not
// restricted to particular dtypes.
// "Variable" and "VariableV2" share the same kernel implementation.
REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);
202
203
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc).
//
// For each dtype, REGISTER_GPU_KERNELS registers GPU variants of all five
// kernels above. IsVariableInitialized keeps its scalar bool output
// "is_initialized" in host memory.
#define REGISTER_GPU_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype"),              \
                          TemporaryVariableOp);                            \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T"),                  \
                          DestroyTemporaryVariableOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype")               \
                              .HostMemory("is_initialized"),               \
                          IsVariableInitializedOp);

// int64 and uint32 are registered explicitly in addition to the GPU type
// list (presumably because TF_CALL_GPU_ALL_TYPES does not include them).
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
233
234 } // namespace tensorflow
235