• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
17 #define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/kernels/dense_update_functor.h"
23 #include "tensorflow/core/kernels/variable_ops.h"
24 
25 namespace tensorflow {
26 
27 // Must be called before performing a sparse operation on a variable. Ensures
28 // that no concurrent dense operations can happen while holding the variable's
29 // lock.
30 template <typename Device, typename T>
EnsureSparseVariableAccess(OpKernelContext * ctx,Var * var)31 Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
32   if (var->copy_on_read_mode.load()) {
33     return Status::OK();
34   }
35   mutex_lock ml(*var->mu());
36   // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can
37   // also happen if there are no concurrent reads of the variable and
38   // copy-on-read mode is false.
39   if (var->tensor()->RefCountIsOne()) {
40     var->copy_on_read_mode.store(true);
41     return Status::OK();
42   }
43   PersistentTensor unused;
44   Tensor* tmp;
45   if (std::is_same<T, Variant>::value) {
46     AllocatorAttributes attr;
47     attr.set_on_host(true);
48     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
49         var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
50 
51     const auto elements_in = var->tensor()->flat<Variant>();
52     auto elements_out = tmp->flat<Variant>();
53     for (int64 i = 0; i < elements_in.size(); ++i) {
54       elements_out(i) = elements_in(i);
55     }
56   } else {
57     AllocatorAttributes attr;
58     attr.set_gpu_compatible(true);
59     attr.set_nic_compatible(true);
60     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
61         var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
62     functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
63     copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
64                  const_cast<const Tensor*>(var->tensor())->flat<T>());
65   }
66   *var->tensor() = *tmp;
67   var->copy_on_read_mode.store(true);
68   return Status::OK();
69 }
70 
71 // Utility structure that releases a sequence of borrowed mutexes when it is
72 // deleted.
73 struct VariableInputLockHolder {
74  public:
VariableInputLockHolderVariableInputLockHolder75   VariableInputLockHolder(
76       std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
77       std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
78       : vars_(std::move(vars)),
79         locks_(std::move(locks)),
80         shared_locks_(std::move(shared_locks)) {}
81 
VariableInputLockHolderVariableInputLockHolder82   VariableInputLockHolder(VariableInputLockHolder&& other)
83       : vars_(std::move(other.vars_)),
84         locks_(std::move(other.locks_)),
85         shared_locks_(std::move(other.shared_locks_)) {}
86 
~VariableInputLockHolderVariableInputLockHolder87   ~VariableInputLockHolder() {
88     // Release the locks before unreffing the Vars, because each lock
89     // is potentially borrowed from a Var in vars_.
90     locks_.reset();
91     for (Var* var : vars_) {
92       var->Unref();
93     }
94   }
95 
96  private:
97   std::vector<Var*> vars_;
98   // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
99   // because a `std::vector<mutex_lock>` is not movable on all platforms.
100   std::unique_ptr<std::vector<mutex_lock>> locks_;
101   std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
102 };
103 
104 // Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
105 //
106 // If `input` corresponds to a `DT_RESOURCE`-type variable input,
107 // `*maybe_resource` will be updated to contain the underlying resource, and the
108 // caller will be responsible for calling `Unref()` on that resource.
109 template <typename Device, typename T>
GetTrainingVariableMutex(OpKernelContext * ctx,int input,bool sparse,Var ** maybe_resource)110 mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
111                                 Var** maybe_resource) {
112   *maybe_resource = nullptr;
113   if (ctx->input_dtype(input) == DT_RESOURCE) {
114     if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
115       if (sparse) {
116         EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
117             .IgnoreError();
118       }
119       return (*maybe_resource)->mu();
120     } else {
121       ctx->CtxFailureWithWarning(
122           errors::Internal("Invalid variable reference."));
123       return nullptr;
124     }
125   }
126   return ctx->input_ref_mutex(input);
127 }
128 
129 // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
130 // in address order to mitigate deadlock.  Returns a structure that, when
131 // deleted, will release the acquired mutexes. Safe to pass duplicates - will
132 // only lock each distinct mutex once. If sparse is true will ensure the
133 // variable gets switched to copy-on-read mode before trying to acquire the
134 // locks. If do_lock is false, returns immediately for reference variables. For
135 // resource variables in copy-on-read-mode it will grab a shared lock if do_lock
136 // is false, exclusive lock otherwise.  Note that this silently doesn't lock
137 // mutexes for invalid variable references; in all usages this is followed by
138 // GetInputTensor which will signal a failure.
139 template <typename Device, typename T>
MaybeLockVariableInputMutexesInOrder(OpKernelContext * ctx,bool do_lock,bool sparse,const std::vector<int> & input_ids)140 VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
141     OpKernelContext* ctx, bool do_lock, bool sparse,
142     const std::vector<int>& input_ids) {
143   bool any_resource = false;
144   for (auto i : input_ids) {
145     if (ctx->input_dtype(i) == DT_RESOURCE) {
146       any_resource = true;
147       break;
148     }
149   }
150   if (!do_lock && !any_resource) {
151     return VariableInputLockHolder({}, {}, {});
152   }
153   std::vector<Var*> vars;
154   std::vector<mutex*> mutexes;
155   std::vector<int> acquire_order;
156   for (auto input : input_ids) {
157     Var* var;
158     mutex* mutex =
159         GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
160     if (var) vars.push_back(var);
161     // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
162     if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
163       acquire_order.push_back(mutexes.size());
164       mutexes.push_back(mutex);
165     }
166   }
167   std::sort(acquire_order.begin(), acquire_order.end(),
168             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
169 
170   std::unique_ptr<std::vector<mutex_lock>> locks =
171       absl::make_unique<std::vector<mutex_lock>>();
172   std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
173       absl::make_unique<std::vector<tf_shared_lock>>();
174   locks->reserve(acquire_order.size());
175 
176   for (auto input : acquire_order) {
177     Var* var;
178     mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
179     core::ScopedUnref scoped_unref(var);
180     if (mu != nullptr) {
181       if (!sparse || do_lock) {
182         locks->emplace_back(*mu);
183       } else {
184         shared_locks->emplace_back(*mu);
185       }
186     }
187   }
188   return VariableInputLockHolder(std::move(vars), std::move(locks),
189                                  std::move(shared_locks));
190 }
191 
192 void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
193                                      int output);
194 
195 // This is for use with ResourceVariables to ensure *tensor has a
196 // reference count of 1 before you update it.
197 // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
198 template <typename Device, typename T>
PrepareToUpdateVariable(OpKernelContext * ctx,Tensor * tensor,bool copy_on_read_mode)199 Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
200                                bool copy_on_read_mode) {
201   if (copy_on_read_mode || !tensor->RefCountIsOne()) {
202     // Tensor's buffer is in use by some read, so we need to copy before
203     // updating.
204     PersistentTensor unused;
205     Tensor* tmp;
206     if (std::is_same<T, Variant>::value) {
207       AllocatorAttributes attr;
208       attr.set_on_host(true);
209       TF_RETURN_IF_ERROR(ctx->allocate_persistent(
210           tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
211 
212       const auto elements_in = tensor->flat<Variant>();
213       auto elements_out = tmp->flat<Variant>();
214       for (int64 i = 0; i < elements_in.size(); ++i) {
215         elements_out(i) = elements_in(i);
216       }
217     } else {
218       AllocatorAttributes attr;
219       attr.set_gpu_compatible(true);
220       attr.set_nic_compatible(true);
221       TF_RETURN_IF_ERROR(ctx->allocate_persistent(
222           tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
223       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
224       copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
225                    const_cast<const Tensor*>(tensor)->flat<T>());
226     }
227     *tensor = *tmp;
228   }
229   return Status::OK();
230 }
231 
232 // This gives you `*out`, a tensor you can update, corresponding to a variable
233 // passed as input index `input`.  This handles the differences between
234 // reference and resource variables. For reference variables we can just grab
235 // the tensor, grabbing the lock if lock_held is False.
236 //
237 // For resource variables we, if sparse is true, ensure it's in copy-on-read
238 // mode, and then, regardless of the value of sparse, ensure its refcount is 1
239 // (by potentially copying its contents). In this case lock_held is ignored.
240 template <typename Device, typename T>
GetInputTensorFromVariable(OpKernelContext * ctx,int input,bool lock_held,bool sparse,Tensor * out)241 Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
242                                   bool lock_held, bool sparse, Tensor* out) {
243   if (ctx->input_dtype(input) == DT_RESOURCE) {
244     Var* var;
245     TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
246     core::ScopedUnref unref_var(var);
247     if (sparse) {
248       TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
249       *out = *var->tensor();
250       return Status::OK();
251     }
252     TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
253         ctx, var->tensor(), var->copy_on_read_mode.load()));
254     *out = *var->tensor();
255     return Status::OK();
256   }
257   *out = ctx->mutable_input(input, lock_held);
258   return Status::OK();
259 }
260 
261 }  // end namespace tensorflow
262 
263 #endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
264