/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include <algorithm>
#include <memory>
#include <type_traits>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {
// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
  if (var->copy_on_read_mode.load()) {
    return Status::OK();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is true, the refcount is guaranteed to be 1. The
  // refcount can also be 1 if there are no concurrent reads of the variable
  // while copy-on-read mode is false.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return Status::OK();
  }
  Tensor tmp;
  if (std::is_same<T, Variant>::value) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
                                          var->tensor()->shape(), &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp.flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
                                          var->tensor()->shape(), &tmp, attr));
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
                 const_cast<const Tensor*>(var->tensor())->flat<T>());
  }
  *var->tensor() = tmp;
  var->copy_on_read_mode.store(true);
  return Status::OK();
}
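
// A minimal usage sketch (assumes a `CPUDevice` alias and a float variable;
// in practice `var` comes from LookupResource on a DT_RESOURCE input, with
// the caller holding a reference):
//
//   Var* var = ...;
//   TF_RETURN_IF_ERROR(
//       (EnsureSparseVariableAccess<CPUDevice, float>(ctx, var)));
//   // The variable is now in copy-on-read mode: concurrent dense reads copy
//   // the buffer instead of aliasing it, so sparse in-place updates are safe.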

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
 public:
  VariableInputLockHolder(
      std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
      std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
      : vars_(std::move(vars)),
        locks_(std::move(locks)),
        shared_locks_(std::move(shared_locks)) {}

  VariableInputLockHolder(VariableInputLockHolder&& other)
      : vars_(std::move(other.vars_)),
        locks_(std::move(other.locks_)),
        shared_locks_(std::move(other.shared_locks_)) {}

  ~VariableInputLockHolder() {
    // Release the locks before unreffing the Vars, because each lock
    // is potentially borrowed from a Var in vars_.
    locks_.reset();
    shared_locks_.reset();
    for (Var* var : vars_) {
      var->Unref();
    }
  }

 private:
  std::vector<Var*> vars_;
  // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
  // because a `std::vector<mutex_lock>` is not movable on all platforms.
  std::unique_ptr<std::vector<mutex_lock>> locks_;
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and
// the caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
                                Var** maybe_resource) {
  *maybe_resource = nullptr;
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
      if (sparse) {
        EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
            .IgnoreError();
      }
      return (*maybe_resource)->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}
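
// Direct use is rare; most kernels go through
// MaybeLockVariableInputMutexesInOrder below. A sketch of a single-input
// caller (assuming a `CPUDevice` alias and a float variable):
//
//   Var* resource = nullptr;
//   mutex* mu = GetTrainingVariableMutex<CPUDevice, float>(
//       ctx, /*input=*/0, /*sparse=*/false, &resource);
//   if (mu != nullptr) {
//     mutex_lock l(*mu);
//     // ... read or update the variable ...
//   }
//   if (resource != nullptr) resource->Unref();  // Ref owed for DT_RESOURCE.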

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order, to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. It is safe to pass duplicates;
// each distinct mutex is locked only once. If `sparse` is true, the variables
// are switched to copy-on-read mode before the locks are acquired. If
// `do_lock` is false, this returns immediately for reference variables; for
// resource variables in copy-on-read mode it grabs a shared lock if `do_lock`
// is false, and an exclusive lock otherwise. Note that this silently skips
// mutexes for invalid variable references; in all usages it is followed by
// GetInputTensorFromVariable, which will signal the failure. See the usage
// sketch after this function.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, bool sparse,
    const std::vector<int>& input_ids) {
  bool any_resource = false;
  for (auto i : input_ids) {
    if (ctx->input_dtype(i) == DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    return VariableInputLockHolder({}, {}, {});
  }
  std::vector<Var*> vars;
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    Var* var;
    mutex* mutex =
        GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = absl::make_unique<std::vector<mutex_lock>>();
  auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
  locks->reserve(acquire_order.size());
  for (auto acquire : acquire_order) {
    // `acquire_order` holds indices into `mutexes` (not input ids), so index
    // the mutex list directly instead of re-resolving the variable.
    mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (!sparse || do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  return VariableInputLockHolder(std::move(vars), std::move(locks),
                                 std::move(shared_locks));
}

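// Typical use in a training kernel's Compute() (a sketch; the input indices
// are op-specific, e.g. {0, 1} for a variable and an accumulator, and
// `use_exclusive_lock_` is a hypothetical kernel member):
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, use_exclusive_lock_, /*sparse=*/false, {0, 1});
//   // ... update the variables while `locks` is alive; the mutexes are
//   // released and the borrowed Vars unreffed when `locks` goes out of scope.

// For reference-type (non-DT_RESOURCE) variable inputs, forwards the ref from
// `input` to `output` so downstream ops can keep operating on the variable in
// place; for resource inputs this is a no-op.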
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

// This is for use with ResourceVariables to ensure *tensor has a
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
                               bool copy_on_read_mode) {
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    Tensor tmp;
    if (std::is_same<T, Variant>::value) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp.flat<Variant>();
      for (int64_t i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
                   const_cast<const Tensor*>(tensor)->flat<T>());
    }
    *tensor = tmp;
  }
  return Status::OK();
}
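
// A minimal sketch (per the REQUIRES above, the caller holds *var->mu() when
// passing var->tensor(); `Device` and `T` match the kernel's instantiation):
//
//   mutex_lock ml(*var->mu());
//   TF_RETURN_IF_ERROR((PrepareToUpdateVariable<Device, T>(
//       ctx, var->tensor(), var->copy_on_read_mode.load())));
//   // var->tensor() now has a refcount of 1 and may be updated in place.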

// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. It handles the differences between
// reference and resource variables. For reference variables we can just grab
// the tensor, grabbing the lock if `lock_held` is false.
//
// For resource variables, if `sparse` is true we first ensure the variable is
// in copy-on-read mode, and in either case we ensure the tensor's refcount is
// 1 (by potentially copying its contents); `lock_held` is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, bool sparse, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    core::RefCountPtr<Var> var;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
    if (sparse) {
      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
      *out = *var->tensor();
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
        ctx, var->tensor(), var->copy_on_read_mode.load()));
    *out = *var->tensor();
    return Status::OK();
  }
  *out = ctx->mutable_input(input, lock_held);
  return Status::OK();
}
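
// End-to-end sketch of how these helpers compose in a dense training op
// (hypothetical kernel; `DoUpdate` stands in for the op's actual math and
// `use_exclusive_lock_` for its locking attribute):
//
//   void Compute(OpKernelContext* ctx) override {
//     auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//         ctx, use_exclusive_lock_, /*sparse=*/false, {0});
//     Tensor var;
//     OP_REQUIRES_OK(ctx, (GetInputTensorFromVariable<Device, T>(
//                             ctx, 0, use_exclusive_lock_,
//                             /*sparse=*/false, &var)));
//     DoUpdate(ctx, &var);  // Hypothetical: apply the update in place.
//     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
//   }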

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_