• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 #include "tensorflow/core/kernels/variable_ops.h"
18 
19 #include "tensorflow/core/framework/control_flow.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/strcat.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
30 // Makes a unique name for a temporary variable inside a while loop body,
31 // because loop can be executed in multiple iterations in parallel.
TemporaryVariableName(const string & var_name,const FrameAndIter & control_frame)32 string TemporaryVariableName(const string& var_name,
33                              const FrameAndIter& control_frame) {
34   if (control_frame.frame_id != kIllegalFrameId &&
35       control_frame.iter_id != kIllegalIterId) {
36     return strings::StrCat(var_name, "/frame:", control_frame.frame_id,
37                            "/iter:", control_frame.iter_id);
38   }
39   return var_name;
40 }
41 
42 }  // namespace
43 
44 // Resource stored by variables in the resource manager
45 // (legacy, ref-style version).
46 class LegacyVar : public ResourceBase {
47  public:
LegacyVar(DataType dtype)48   explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
49   // Not copyable or movable.
50   LegacyVar(const LegacyVar&) = delete;
51   LegacyVar& operator=(const LegacyVar&) = delete;
52 
mu()53   mutex* mu() { return &mu_; }
tensor()54   Tensor* tensor() { return &tensor_; }
55 
DebugString() const56   string DebugString() const override {
57     return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
58                            tensor_.shape().DebugString());
59   }
60 
61  private:
62   mutex mu_;
63   Tensor tensor_;
64 
~LegacyVar()65   ~LegacyVar() override {}
66 };
67 
VariableOp(OpKernelConstruction * context)68 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
69   OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
70   dtype_ = RemoveRefType(context->output_type(0));
71   OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
72                                       true /* use name() */));
73 }
74 
Compute(OpKernelContext * ctx)75 void VariableOp::Compute(OpKernelContext* ctx) {
76   auto creator = [this](LegacyVar** var) {
77     *var = new LegacyVar(dtype_);
78     (*var)->tensor()->set_shape(shape_);
79     return Status::OK();
80   };
81   LegacyVar* var;
82   OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
83                           cinfo_.container(), cinfo_.name(), &var, creator));
84   // Output a reference to our tensor, so it may be updated.
85   //
86   // As long as the resource manager hasn't been cleared the ref we return
87   // here is valid because it owns a ref on var.
88   ctx->set_output_ref(0, var->mu(), var->tensor());
89   if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
90     ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
91   }
92   var->Unref();
93 }
94 
95 class TemporaryVariableOp : public OpKernel {
96  public:
TemporaryVariableOp(OpKernelConstruction * context)97   explicit TemporaryVariableOp(OpKernelConstruction* context)
98       : OpKernel(context) {
99     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
100     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
101     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
102     // Variable name defaults to op name if not specified explicitly.
103     if (var_name_.empty()) var_name_ = name();
104   }
105 
Compute(OpKernelContext * context)106   void Compute(OpKernelContext* context) override {
107     Status s;
108     ResourceMgr* rm = context->resource_manager();
109     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
110     auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
111     auto* tmp_var = new TmpVar;
112     OP_REQUIRES(context, tmp_var,
113                 errors::ResourceExhausted("Could not allocate TmpVar."));
114     tmp_var->name = unique_name;
115     s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
116     if (!s.ok()) tmp_var->Unref();
117     OP_REQUIRES_OK(context, s);
118     OP_REQUIRES_OK(context,
119                    context->step_container()->Create(rm, unique_name, tmp_var));
120     context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
121     if (context->track_allocations()) {
122       context->record_persistent_memory_allocation(
123           tmp_var->val.AllocatedBytes());
124     }
125   }
126 
127  private:
128   // Refcounted temporary variable resource.
129   friend class DestroyTemporaryVariableOp;
130   struct TmpVar : public ResourceBase {
131     mutex mu;
132     Tensor val;
133     string name;
DebugStringtensorflow::TemporaryVariableOp::TmpVar134     string DebugString() const override { return name; }
~TmpVartensorflow::TemporaryVariableOp::TmpVar135     ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
136   };
137 
138   TensorShape shape_;
139   DataType dtype_;
140   string var_name_;
141 };
142 
143 class DestroyTemporaryVariableOp : public OpKernel {
144  public:
DestroyTemporaryVariableOp(OpKernelConstruction * context)145   explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
146       : OpKernel(context) {
147     OP_REQUIRES(context, IsRefType(context->input_type(0)),
148                 errors::InvalidArgument("lhs input needs to be a ref type"));
149     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
150     OP_REQUIRES(context, !var_name_.empty(),
151                 errors::InvalidArgument("Missing var_name attribute"));
152   }
153 
Compute(OpKernelContext * context)154   void Compute(OpKernelContext* context) override {
155     // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
156     // their execution before this DestroyTemporaryVariable op executes.
157     // This is typically achieved using control dependencies.
158     CHECK(IsRefType(context->input_dtype(0)));
159     Tensor tmpvar = context->mutable_input(0, false);
160     context->set_output(0, tmpvar);
161     ResourceMgr* rm = context->resource_manager();
162     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
163     auto unique_name = TemporaryVariableName(var_name_, context->frame_iter());
164     OP_REQUIRES_OK(
165         context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
166                      rm, unique_name));
167     if (context->track_allocations()) {
168       context->record_persistent_memory_allocation(
169           -static_cast<int64>(tmpvar.AllocatedBytes()));
170     }
171   }
172 
173  private:
174   string var_name_;
175 };
176 
177 class IsVariableInitializedOp : public OpKernel {
178  public:
IsVariableInitializedOp(OpKernelConstruction * context)179   explicit IsVariableInitializedOp(OpKernelConstruction* context)
180       : OpKernel(context) {}
181 
Compute(OpKernelContext * context)182   void Compute(OpKernelContext* context) override {
183     // Get a mutable input tensor of the Ref input.
184     const Tensor& input_tensor = context->mutable_input(0, false);
185     Tensor* output = nullptr;
186     OP_REQUIRES_OK(context,
187                    context->allocate_output(0, TensorShape({}), &output));
188     auto output_tensor = output->tensor<bool, 0>();
189     bool result = input_tensor.IsInitialized();
190     output_tensor() = result;
191   }
192 };
193 
194 REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
195 REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
196 REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
197                         TemporaryVariableOp);
198 REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
199                         DestroyTemporaryVariableOp);
200 REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
201                         IsVariableInitializedOp);
202 
203 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc.)

// Registers `kernel_class` for `op_name` on GPU, constrained on attr "dtype".
#define REGISTER_GPU_DTYPE_KERNEL(op_name, type, kernel_class)           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(op_name).Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      kernel_class)

// All GPU variable kernels for one element type. DestroyTemporaryVariable is
// constrained on attr "T", and IsVariableInitialized keeps its output in host
// memory.
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_GPU_DTYPE_KERNEL("Variable", type, VariableOp);                \
  REGISTER_GPU_DTYPE_KERNEL("VariableV2", type, VariableOp);              \
  REGISTER_GPU_DTYPE_KERNEL("TemporaryVariable", type,                    \
                            TemporaryVariableOp);                         \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                \
                              .Device(DEVICE_GPU)                         \
                              .TypeConstraint<type>("T"),                 \
                          DestroyTemporaryVariableOp);                    \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                   \
                              .Device(DEVICE_GPU)                         \
                              .TypeConstraint<type>("dtype")              \
                              .HostMemory("is_initialized"),              \
                          IsVariableInitializedOp);

TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#undef REGISTER_GPU_DTYPE_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
233 
234 }  // namespace tensorflow
235