1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <deque> 19 #include <utility> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/framework/shared_ptr_variant.h" 24 #include "tensorflow/core/framework/variant.h" 25 #include "tensorflow/core/framework/variant_encode_decode.h" 26 #include "tensorflow/core/kernels/ops_util.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/lib/core/threadpool.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/platform/mutex.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace tensorflow { 34 35 namespace { 36 37 class Mutex : public ResourceBase { 38 public: Mutex(OpKernelContext * c,const string & name)39 explicit Mutex(OpKernelContext* c, const string& name) 40 : locked_(false), 41 thread_pool_(new thread::ThreadPool( 42 c->env(), ThreadOptions(), 43 strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)), 44 1 /* num_threads */, false /* low_latency_hint */)), 45 name_(name) { 46 VLOG(2) << "Creating mutex with name " << name << ": " << this; 47 } 48 DebugString() const49 string DebugString() const override { 50 return strings::StrCat("Mutex ", name_); 51 } 52 53 class LockReleaser { 54 public: LockReleaser(Mutex * mutex)55 explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {} 56 57 LockReleaser(const LockReleaser&) = delete; 58 LockReleaser& operator=(const LockReleaser&) = delete; 59 ~LockReleaser()60 virtual ~LockReleaser() { 61 VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_; 62 if (mutex_) { 63 mutex_lock lock(mutex_->mu_); 64 mutex_->locked_ = false; 65 mutex_->cv_.notify_all(); 66 VLOG(3) << "Destroying LockReleaser " << this 67 << ": sent notifications."; 68 } 69 } 70 71 private: 72 Mutex* mutex_; 73 }; 74 75 typedef SharedPtrVariant<LockReleaser> SharedLockReleaser; 76 AcquireAsync(OpKernelContext * c,std::function<void (const Status & s,SharedLockReleaser lock)> fn)77 void AcquireAsync( 78 OpKernelContext* c, 79 std::function<void(const Status& s, SharedLockReleaser lock)> fn) { 80 CancellationManager* cm = c->cancellation_manager(); 81 CancellationToken token{}; 82 bool* cancelled = nullptr; 83 if (cm) { 84 cancelled = new bool(false); // TF_GUARDED_BY(mu_); 85 token = cm->get_cancellation_token(); 86 const bool already_cancelled = 87 !cm->RegisterCallback(token, [this, cancelled]() { 88 mutex_lock lock(mu_); 89 *cancelled = true; 90 cv_.notify_all(); 91 }); 92 if (already_cancelled) { 93 delete cancelled; 94 fn(errors::Cancelled("Lock acquisition cancelled."), 95 SharedLockReleaser{nullptr}); 96 return; 97 } 98 } 99 thread_pool_->Schedule(std::bind( 100 [this, cm, cancelled, 101 token](std::function<void(const Status& s, SharedLockReleaser&& lock)> 102 fn_) { 103 bool local_locked; 104 { 105 mutex_lock lock(mu_); 106 while (locked_ && !(cancelled && *cancelled)) { 107 cv_.wait(lock); 108 } 109 local_locked = locked_ = !(cancelled && *cancelled); 110 } 111 if (cm) { 112 cm->DeregisterCallback(token); 113 delete cancelled; 114 } 115 if (local_locked) { // Not cancelled. 116 fn_(OkStatus(), 117 SharedLockReleaser{std::make_shared<LockReleaser>(this)}); 118 } else { 119 fn_(errors::Cancelled("Lock acquisition cancelled."), 120 SharedLockReleaser{nullptr}); 121 } 122 }, 123 std::move(fn))); 124 } 125 126 private: 127 mutex mu_; 128 condition_variable cv_ TF_GUARDED_BY(mu_); 129 bool locked_ TF_GUARDED_BY(mu_); 130 std::unique_ptr<thread::ThreadPool> thread_pool_; 131 string name_; 132 }; 133 134 } // namespace 135 136 class MutexLockOp : public AsyncOpKernel { 137 public: MutexLockOp(OpKernelConstruction * c)138 explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {} 139 140 public: ComputeAsync(OpKernelContext * c,DoneCallback done)141 void ComputeAsync(OpKernelContext* c, DoneCallback done) override { 142 Mutex* mutex = nullptr; 143 OP_REQUIRES_OK_ASYNC( 144 c, 145 LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex, 146 [c](Mutex** ptr) { 147 *ptr = new Mutex( 148 c, HandleFromInput(c, 0).name()); 149 return OkStatus(); 150 }), 151 done); 152 153 Tensor* variant; 154 OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant), 155 done); 156 157 mutex->AcquireAsync( 158 c, std::bind( 159 [c, variant, mutex](DoneCallback done_, 160 // End of bound arguments. 161 const Status& s, 162 Mutex::SharedLockReleaser&& lock) { 163 VLOG(2) << "Finished locking mutex " << mutex 164 << " with lock: " << lock.shared_ptr.get() 165 << " status: " << s.ToString(); 166 if (s.ok()) { 167 variant->scalar<Variant>()() = std::move(lock); 168 } else { 169 c->SetStatus(s); 170 } 171 mutex->Unref(); 172 done_(); 173 }, 174 std::move(done), std::placeholders::_1, std::placeholders::_2)); 175 } 176 }; 177 178 class ConsumeMutexLockOp : public OpKernel { 179 public: ConsumeMutexLockOp(OpKernelConstruction * context)180 explicit ConsumeMutexLockOp(OpKernelConstruction* context) 181 : OpKernel(context) {} 182 Compute(OpKernelContext * c)183 void Compute(OpKernelContext* c) override { 184 VLOG(2) << "Executing ConsumeMutexLockOp"; 185 const Tensor& lock_t = c->input(0); 186 OP_REQUIRES( 187 c, lock_t.dims() == 0, 188 errors::InvalidArgument("Expected input to be a scalar, saw shape: ", 189 lock_t.shape().DebugString())); 190 OP_REQUIRES( 191 c, lock_t.dtype() == DT_VARIANT, 192 errors::InvalidArgument("Expected input to be a variant, saw type: ", 193 DataTypeString(lock_t.dtype()))); 194 const auto* lock = 195 lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>(); 196 OP_REQUIRES(c, lock, 197 errors::InvalidArgument( 198 "Expected input to contain a SharedLockReleaser " 199 "object, but saw variant: '", 200 lock_t.scalar<Variant>()().DebugString(), "'")); 201 const int use_count = lock->shared_ptr.use_count(); 202 OP_REQUIRES( 203 c, use_count == 1, 204 errors::InvalidArgument("Expected use count of lock to be 1, but saw: ", 205 use_count)); 206 } 207 IsExpensive()208 bool IsExpensive() override { return false; } 209 }; 210 211 REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp); 212 213 REGISTER_KERNEL_BUILDER(Name("MutexLock") 214 .Device(DEVICE_DEFAULT) 215 .HostMemory("mutex_lock") 216 .HostMemory("mutex"), 217 MutexLockOp); 218 219 REGISTER_KERNEL_BUILDER( 220 Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"), 221 ResourceHandleOp<Mutex>); 222 223 REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_DEFAULT), 224 ResourceHandleOp<Mutex>); 225 226 REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU), 227 ConsumeMutexLockOp); 228 229 REGISTER_KERNEL_BUILDER( 230 Name("ConsumeMutexLock").Device(DEVICE_DEFAULT).HostMemory("mutex_lock"), 231 ConsumeMutexLockOp); 232 233 } // namespace tensorflow 234