• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
17 #define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/kernels/dense_update_functor.h"
23 #include "tensorflow/core/kernels/variable_ops.h"
24 #include "tensorflow/core/lib/core/refcount.h"
25 
26 namespace tensorflow {
27 
28 // Must be called before performing a sparse operation on a variable. Ensures
29 // that no concurrent dense operations can happen while holding the variable's
30 // lock.
31 template <typename Device, typename T>
EnsureSparseVariableAccess(OpKernelContext * ctx,Var * var)32 Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
33   if (var->copy_on_read_mode.load()) {
34     return Status::OK();
35   }
36   mutex_lock ml(*var->mu());
37   // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can
38   // also happen if there are no concurrent reads of the variable and
39   // copy-on-read mode is false.
40   if (var->tensor()->RefCountIsOne()) {
41     var->copy_on_read_mode.store(true);
42     return Status::OK();
43   }
44   Tensor tmp;
45   if (std::is_same<T, Variant>::value) {
46     AllocatorAttributes attr;
47     attr.set_on_host(true);
48     TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
49                                           var->tensor()->shape(), &tmp, attr));
50 
51     const auto elements_in = var->tensor()->flat<Variant>();
52     auto elements_out = tmp.flat<Variant>();
53     for (int64_t i = 0; i < elements_in.size(); ++i) {
54       elements_out(i) = elements_in(i);
55     }
56   } else {
57     AllocatorAttributes attr;
58     attr.set_gpu_compatible(true);
59     attr.set_nic_compatible(true);
60     TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
61                                           var->tensor()->shape(), &tmp, attr));
62     functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
63     copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
64                  const_cast<const Tensor*>(var->tensor())->flat<T>());
65   }
66   *var->tensor() = tmp;
67   var->copy_on_read_mode.store(true);
68   return Status::OK();
69 }
70 
71 // Utility structure that releases a sequence of borrowed mutexes when it is
72 // deleted.
73 struct VariableInputLockHolder {
74  public:
VariableInputLockHolderVariableInputLockHolder75   VariableInputLockHolder(
76       std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
77       std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
78       : vars_(std::move(vars)),
79         locks_(std::move(locks)),
80         shared_locks_(std::move(shared_locks)) {}
81 
VariableInputLockHolderVariableInputLockHolder82   VariableInputLockHolder(VariableInputLockHolder&& other)
83       : vars_(std::move(other.vars_)),
84         locks_(std::move(other.locks_)),
85         shared_locks_(std::move(other.shared_locks_)) {}
86 
~VariableInputLockHolderVariableInputLockHolder87   ~VariableInputLockHolder() {
88     // Release the locks before unreffing the Vars, because each lock
89     // is potentially borrowed from a Var in vars_.
90     locks_.reset();
91     for (Var* var : vars_) {
92       var->Unref();
93     }
94   }
95 
96  private:
97   std::vector<Var*> vars_;
98   // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
99   // because a `std::vector<mutex_lock>` is not movable on all platforms.
100   std::unique_ptr<std::vector<mutex_lock>> locks_;
101   std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
102 };
103 
104 // Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
105 //
106 // If `input` corresponds to a `DT_RESOURCE`-type variable input,
107 // `*maybe_resource` will be updated to contain the underlying resource, and the
108 // caller will be responsible for calling `Unref()` on that resource.
109 template <typename Device, typename T>
GetTrainingVariableMutex(OpKernelContext * ctx,int input,bool sparse,Var ** maybe_resource)110 mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
111                                 Var** maybe_resource) {
112   *maybe_resource = nullptr;
113   if (ctx->input_dtype(input) == DT_RESOURCE) {
114     if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
115       if (sparse) {
116         EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
117             .IgnoreError();
118       }
119       return (*maybe_resource)->mu();
120     } else {
121       ctx->CtxFailureWithWarning(
122           errors::Internal("Invalid variable reference."));
123       return nullptr;
124     }
125   }
126   return ctx->input_ref_mutex(input);
127 }
128 
129 // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
130 // in address order to mitigate deadlock.  Returns a structure that, when
131 // deleted, will release the acquired mutexes. Safe to pass duplicates - will
132 // only lock each distinct mutex once. If sparse is true will ensure the
133 // variable gets switched to copy-on-read mode before trying to acquire the
134 // locks. If do_lock is false, returns immediately for reference variables. For
135 // resource variables in copy-on-read-mode it will grab a shared lock if do_lock
136 // is false, exclusive lock otherwise.  Note that this silently doesn't lock
137 // mutexes for invalid variable references; in all usages this is followed by
138 // GetInputTensor which will signal a failure.
139 template <typename Device, typename T>
MaybeLockVariableInputMutexesInOrder(OpKernelContext * ctx,bool do_lock,bool sparse,const std::vector<int> & input_ids)140 VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
141     OpKernelContext* ctx, bool do_lock, bool sparse,
142     const std::vector<int>& input_ids) {
143   bool any_resource = false;
144   for (auto i : input_ids) {
145     if (ctx->input_dtype(i) == DT_RESOURCE) {
146       any_resource = true;
147       break;
148     }
149   }
150   if (!do_lock && !any_resource) {
151     return VariableInputLockHolder({}, {}, {});
152   }
153   std::vector<Var*> vars;
154   std::vector<mutex*> mutexes;
155   std::vector<int> acquire_order;
156   for (auto input : input_ids) {
157     Var* var;
158     mutex* mutex =
159         GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
160     if (var) vars.push_back(var);
161     // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
162     if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
163       acquire_order.push_back(mutexes.size());
164       mutexes.push_back(mutex);
165     }
166   }
167   std::sort(acquire_order.begin(), acquire_order.end(),
168             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
169 
170   auto locks = absl::make_unique<std::vector<mutex_lock>>();
171   auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
172   locks->reserve(acquire_order.size());
173 
174   for (auto input : acquire_order) {
175     Var* var;
176     mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
177     core::ScopedUnref scoped_unref(var);
178     if (mu != nullptr) {
179       if (!sparse || do_lock) {
180         locks->emplace_back(*mu);
181       } else {
182         shared_locks->emplace_back(*mu);
183       }
184     }
185   }
186   return VariableInputLockHolder(std::move(vars), std::move(locks),
187                                  std::move(shared_locks));
188 }
189 
190 void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
191                                      int output);
192 
193 // This is for use with ResourceVariables to ensure *tensor has a
194 // reference count of 1 before you update it.
195 // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
196 template <typename Device, typename T>
PrepareToUpdateVariable(OpKernelContext * ctx,Tensor * tensor,bool copy_on_read_mode)197 Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
198                                bool copy_on_read_mode) {
199   if (copy_on_read_mode || !tensor->RefCountIsOne()) {
200     // Tensor's buffer is in use by some read, so we need to copy before
201     // updating.
202     Tensor tmp;
203     if (std::is_same<T, Variant>::value) {
204       AllocatorAttributes attr;
205       attr.set_on_host(true);
206       TF_RETURN_IF_ERROR(
207           ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
208 
209       const auto elements_in = tensor->flat<Variant>();
210       auto elements_out = tmp.flat<Variant>();
211       for (int64_t i = 0; i < elements_in.size(); ++i) {
212         elements_out(i) = elements_in(i);
213       }
214     } else {
215       AllocatorAttributes attr;
216       attr.set_gpu_compatible(true);
217       attr.set_nic_compatible(true);
218       TF_RETURN_IF_ERROR(
219           ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
220       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
221       copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
222                    const_cast<const Tensor*>(tensor)->flat<T>());
223     }
224     *tensor = tmp;
225   }
226   return Status::OK();
227 }
228 
229 // This gives you `*out`, a tensor you can update, corresponding to a variable
230 // passed as input index `input`.  This handles the differences between
231 // reference and resource variables. For reference variables we can just grab
232 // the tensor, grabbing the lock if lock_held is False.
233 //
234 // For resource variables we, if sparse is true, ensure it's in copy-on-read
235 // mode, and then, regardless of the value of sparse, ensure its refcount is 1
236 // (by potentially copying its contents). In this case lock_held is ignored.
237 template <typename Device, typename T>
GetInputTensorFromVariable(OpKernelContext * ctx,int input,bool lock_held,bool sparse,Tensor * out)238 Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
239                                   bool lock_held, bool sparse, Tensor* out) {
240   if (ctx->input_dtype(input) == DT_RESOURCE) {
241     core::RefCountPtr<Var> var;
242     TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
243     if (sparse) {
244       TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
245       *out = *var->tensor();
246       return Status::OK();
247     }
248     TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
249         ctx, var->tensor(), var->copy_on_read_mode.load()));
250     *out = *var->tensor();
251     return Status::OK();
252   }
253   *out = ctx->mutable_input(input, lock_held);
254   return Status::OK();
255 }
256 
257 }  // end namespace tensorflow
258 
259 #endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
260