1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17 #include "tensorflow/core/kernels/variable_ops.h"
18
19 #include "tensorflow/core/framework/control_flow.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/strcat.h"
24 #include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27
28 namespace {
29
30 // Makes a unique name for a temporary variable inside a while loop body,
31 // because loop can be executed in multiple iterations in parallel.
TemporaryVariableName(const string & var_name,const FrameAndIter & control_frame)32 string TemporaryVariableName(const string& var_name,
33 const FrameAndIter& control_frame) {
34 if (control_frame.frame_id != kIllegalFrameId &&
35 control_frame.iter_id != kIllegalIterId) {
36 return strings::StrCat(var_name, "/frame:", control_frame.frame_id,
37 "/iter:", control_frame.iter_id);
38 }
39 return var_name;
40 }
41
42 } // namespace
43
44 // Resource stored by variables in the resource manager
45 // (legacy, ref-style version).
46 class LegacyVar : public ResourceBase {
47 public:
LegacyVar(DataType dtype)48 explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
49 // Not copyable or movable.
50 LegacyVar(const LegacyVar&) = delete;
51 LegacyVar& operator=(const LegacyVar&) = delete;
52
mu()53 mutex* mu() { return &mu_; }
tensor()54 Tensor* tensor() { return &tensor_; }
55
DebugString() const56 string DebugString() const override {
57 return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
58 tensor_.shape().DebugString());
59 }
60
61 private:
62 mutex mu_;
63 Tensor tensor_;
64
~LegacyVar()65 ~LegacyVar() override {}
66 };
67
VariableOp(OpKernelConstruction * context)68 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
69 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
70 dtype_ = RemoveRefType(context->output_type(0));
71 OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
72 true /* use name() */));
73 }
74
Compute(OpKernelContext * ctx)75 void VariableOp::Compute(OpKernelContext* ctx) {
76 auto creator = [this](LegacyVar** var) {
77 *var = new LegacyVar(dtype_);
78 (*var)->tensor()->set_shape(shape_);
79 return OkStatus();
80 };
81 LegacyVar* var;
82 OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
83 cinfo_.container(), cinfo_.name(), &var, creator));
84 // Output a reference to our tensor, so it may be updated.
85 //
86 // As long as the resource manager hasn't been cleared the ref we return
87 // here is valid because it owns a ref on var.
88 ctx->set_output_ref(0, var->mu(), var->tensor());
89 if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
90 ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
91 }
92 var->Unref();
93 }
94
95 class TemporaryVariableOp : public OpKernel {
96 public:
TemporaryVariableOp(OpKernelConstruction * context)97 explicit TemporaryVariableOp(OpKernelConstruction* context)
98 : OpKernel(context) {
99 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
100 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
101 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
102 // Variable name defaults to op name if not specified explicitly.
103 if (var_name_.empty()) var_name_ = name();
104 }
105
Compute(OpKernelContext * context)106 void Compute(OpKernelContext* context) override {
107 Status s;
108 ResourceMgr* rm = context->resource_manager();
109 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
110 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
111 auto* tmp_var = new TmpVar;
112 OP_REQUIRES(context, tmp_var,
113 errors::ResourceExhausted("Could not allocate TmpVar."));
114 tmp_var->name = unique_name;
115 s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
116 if (!s.ok()) tmp_var->Unref();
117 OP_REQUIRES_OK(context, s);
118 OP_REQUIRES_OK(context,
119 context->step_container()->Create(rm, unique_name, tmp_var));
120 context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
121 if (context->track_allocations()) {
122 context->record_persistent_memory_allocation(
123 tmp_var->val.AllocatedBytes());
124 }
125 }
126
127 private:
128 // Refcounted temporary variable resource.
129 friend class DestroyTemporaryVariableOp;
130 struct TmpVar : public ResourceBase {
131 mutex mu;
132 Tensor val;
133 string name;
DebugStringtensorflow::TemporaryVariableOp::TmpVar134 string DebugString() const override { return name; }
~TmpVartensorflow::TemporaryVariableOp::TmpVar135 ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
136 };
137
138 TensorShape shape_;
139 DataType dtype_;
140 string var_name_;
141 };
142
143 class DestroyTemporaryVariableOp : public OpKernel {
144 public:
DestroyTemporaryVariableOp(OpKernelConstruction * context)145 explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
146 : OpKernel(context) {
147 OP_REQUIRES(context, IsRefType(context->input_type(0)),
148 errors::InvalidArgument("lhs input needs to be a ref type"));
149 OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
150 OP_REQUIRES(context, !var_name_.empty(),
151 errors::InvalidArgument("Missing var_name attribute"));
152 }
153
Compute(OpKernelContext * context)154 void Compute(OpKernelContext* context) override {
155 // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
156 // their execution before this DestroyTemporaryVariable op executes.
157 // This is typically achieved using control dependencies.
158 CHECK(IsRefType(context->input_dtype(0)));
159 Tensor tmpvar = context->mutable_input(0, false);
160 context->set_output(0, tmpvar);
161 ResourceMgr* rm = context->resource_manager();
162 OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
163 auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
164 OP_REQUIRES_OK(
165 context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
166 rm, unique_name));
167 if (context->track_allocations()) {
168 context->record_persistent_memory_allocation(
169 -static_cast<int64_t>(tmpvar.AllocatedBytes()));
170 }
171 }
172
173 private:
174 string var_name_;
175 };
176
177 class IsVariableInitializedOp : public OpKernel {
178 public:
IsVariableInitializedOp(OpKernelConstruction * context)179 explicit IsVariableInitializedOp(OpKernelConstruction* context)
180 : OpKernel(context) {}
181
Compute(OpKernelContext * context)182 void Compute(OpKernelContext* context) override {
183 // Get a mutable input tensor of the Ref input.
184 const Tensor& input_tensor = context->mutable_input(0, false);
185 Tensor* output = nullptr;
186 OP_REQUIRES_OK(context,
187 context->allocate_output(0, TensorShape({}), &output));
188 auto output_tensor = output->tensor<bool, 0>();
189 bool result = input_tensor.IsInitialized();
190 output_tensor() = result;
191 }
192 };
193
// CPU kernels. No TypeConstraint is given here, so each kernel class is
// registered for every dtype on DEVICE_CPU. "Variable" and "VariableV2" share
// the same implementation.
REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);
202
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc.)
//
// For each `type`, registers all five variable kernels on DEVICE_GPU,
// constrained on the "dtype" attr (or "T" for DestroyTemporaryVariable).
// IsVariableInitialized's "is_initialized" output is placed in host memory.
#define REGISTER_GPU_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype"),              \
                          TemporaryVariableOp);                            \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T"),                  \
                          DestroyTemporaryVariableOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype")               \
                              .HostMemory("is_initialized"),               \
                          IsVariableInitializedOp);

// int64 and uint32 are registered in addition to the GPU_ALL_TYPES list.
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
232
// Registers the same five kernels, with the same attr constraints and type
// lists as the GPU block above, for DEVICE_DEFAULT. Unlike the GPU block, this
// is not guarded by GOOGLE_CUDA / TENSORFLOW_USE_ROCM.
#define REGISTER_DEFAULT_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Variable").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VariableV2").Device(DEVICE_DEFAULT).TypeConstraint<type>("dtype"), \
      VariableOp);                                                             \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                            \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("dtype"),                  \
                          TemporaryVariableOp);                                \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                     \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("T"),                      \
                          DestroyTemporaryVariableOp);                         \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                        \
                              .Device(DEVICE_DEFAULT)                          \
                              .TypeConstraint<type>("dtype")                   \
                              .HostMemory("is_initialized"),                   \
                          IsVariableInitializedOp);

// int64 and uint32 are registered in addition to the GPU_ALL_TYPES list.
TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
#undef REGISTER_DEFAULT_KERNELS
258
259 } // namespace tensorflow
260