/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"

namespace tensorflow {

// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
  if (var->copy_on_read_mode.load()) {
    return Status::OK();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is true the refcount is guaranteed to be 1. The
  // refcount can also be 1 if there are no concurrent reads of the variable
  // while copy-on-read mode is false, in which case we can flip the mode
  // without copying the buffer.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return Status::OK();
  }
  PersistentTensor unused;
  Tensor* tmp;
  if (std::is_same<T, Variant>::value) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                 const_cast<const Tensor*>(var->tensor())->flat<T>());
  }
  *var->tensor() = *tmp;
  var->copy_on_read_mode.store(true);
  return Status::OK();
}
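
// For example (an illustrative sketch, not an op from this file), a sparse
// update path might switch the variable to copy-on-read mode before touching
// individual rows; `var` is assumed to come from `LookupResource`:
//
//   TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
//   mutex_lock ml(*var->mu());
//   // Elements of var->tensor() can now be updated in place: concurrent
//   // dense reads copy the buffer instead of aliasing it.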

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
 public:
  VariableInputLockHolder(
      std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
      std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
      : vars_(std::move(vars)),
        locks_(std::move(locks)),
        shared_locks_(std::move(shared_locks)) {}

  VariableInputLockHolder(VariableInputLockHolder&& other)
      : vars_(std::move(other.vars_)),
        locks_(std::move(other.locks_)),
        shared_locks_(std::move(other.shared_locks_)) {}

  ~VariableInputLockHolder() {
    // Release the locks before unreffing the Vars, because each lock
    // is potentially borrowed from a Var in vars_.
    locks_.reset();
    shared_locks_.reset();
    for (Var* var : vars_) {
      var->Unref();
    }
  }

 private:
  std::vector<Var*> vars_;
  // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
  // because a `std::vector<mutex_lock>` is not movable on all platforms.
  std::unique_ptr<std::vector<mutex_lock>> locks_;
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};
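
// The holder is meant to live on the stack for the duration of an update;
// destroying it releases the locks and then unrefs the borrowed Vars. A
// minimal sketch (assuming a kernel guarding variable inputs 0 and 1):
//
//   {
//     VariableInputLockHolder holder =
//         MaybeLockVariableInputMutexesInOrder<Device, T>(
//             ctx, /*do_lock=*/true, /*sparse=*/false, {0, 1});
//     // ... mutate the variables ...
//   }  // Locks released and Vars unreffed here.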

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and
// the caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
                                Var** maybe_resource) {
  *maybe_resource = nullptr;
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
      if (sparse) {
        EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
            .IgnoreError();
      }
      return (*maybe_resource)->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order, to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. It is safe to pass duplicate
// inputs: each distinct mutex is locked only once. If `sparse` is true, the
// variables are switched to copy-on-read mode before the locks are acquired.
// If `do_lock` is false, this returns immediately for reference variables;
// for resource variables in copy-on-read mode it grabs a shared lock if
// `do_lock` is false, and an exclusive lock otherwise. Note that this
// silently doesn't lock mutexes for invalid variable references; in all
// usages this is followed by a call to GetInputTensorFromVariable, which will
// signal the failure.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, bool sparse,
    const std::vector<int>& input_ids) {
  bool any_resource = false;
  for (auto i : input_ids) {
    if (ctx->input_dtype(i) == DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    return VariableInputLockHolder({}, {}, {});
  }
  std::vector<Var*> vars;
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    Var* var;
    mutex* mutex =
        GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  std::unique_ptr<std::vector<mutex_lock>> locks =
      absl::make_unique<std::vector<mutex_lock>>();
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
      absl::make_unique<std::vector<tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  for (auto acquire : acquire_order) {
    // `acquire_order` holds indices into `mutexes`, so index into it directly
    // rather than re-resolving the variable; the Vars collected above keep the
    // mutexes alive.
    mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (!sparse || do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  return VariableInputLockHolder(std::move(vars), std::move(locks),
                                 std::move(shared_locks));
}
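
// For instance (a hedged sketch; `use_exclusive_lock` stands in for an op
// attribute such as `use_locking`), passing duplicate or aliased inputs is
// fine because each distinct mutex is locked only once:
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, use_exclusive_lock, /*sparse=*/false, {0, 0, 1});
//   // Only two locks are taken: one per distinct variable mutex.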

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

// This is for use with ResourceVariables to ensure *tensor has a
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
                               bool copy_on_read_mode) {
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    PersistentTensor unused;
    Tensor* tmp;
    if (std::is_same<T, Variant>::value) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp->flat<Variant>();
      for (int64 i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                   const_cast<const Tensor*>(tensor)->flat<T>());
    }
    *tensor = *tmp;
  }
  return Status::OK();
}
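
// For example (a minimal sketch; `var` is assumed to be a Var* obtained via
// `LookupResource`), callers mutating `var->tensor()` directly hold the
// variable's mutex across the copy-on-write check and the update:
//
//   mutex_lock ml(*var->mu());
//   TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
//       ctx, var->tensor(), var->copy_on_read_mode.load()));
//   // var->tensor() now has a refcount of 1 and is safe to update in place.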

// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. This handles the differences between
// reference and resource variables. For reference variables we can just grab
// the tensor, grabbing the lock if `lock_held` is false.
//
// For resource variables, if `sparse` is true we ensure the variable is in
// copy-on-read mode, and then, regardless of the value of `sparse`, ensure its
// refcount is 1 (by potentially copying its contents). In this case
// `lock_held` is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, bool sparse, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    Var* var;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
    core::ScopedUnref unref_var(var);
    if (sparse) {
      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
      *out = *var->tensor();
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
        ctx, var->tensor(), var->copy_on_read_mode.load()));
    *out = *var->tensor();
    return Status::OK();
  }
  *out = ctx->mutable_input(input, lock_held);
  return Status::OK();
}
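
// Putting the helpers together, a dense training kernel's Compute method
// might look roughly like the following sketch (`use_exclusive_lock_` is a
// hypothetical member set from an op attribute):
//
//   void Compute(OpKernelContext* ctx) override {
//     auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//         ctx, use_exclusive_lock_, /*sparse=*/false, {0});
//     Tensor var;
//     OP_REQUIRES_OK(ctx,
//                    GetInputTensorFromVariable<Device, T>(
//                        ctx, 0, use_exclusive_lock_, /*sparse=*/false, &var));
//     // ... apply the update to `var` ...
//     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
//   }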

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_