• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <deque>
19 #include <utility>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/shared_ptr_variant.h"
24 #include "tensorflow/core/framework/variant.h"
25 #include "tensorflow/core/framework/variant_encode_decode.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/threadpool.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 
35 namespace {
36 
37 class Mutex : public ResourceBase {
38  public:
Mutex(OpKernelContext * c,const string & name)39   explicit Mutex(OpKernelContext* c, const string& name)
40       : locked_(false),
41         thread_pool_(new thread::ThreadPool(
42             c->env(), ThreadOptions(),
43             strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
44             1 /* num_threads */, false /* low_latency_hint */)),
45         name_(name) {
46     VLOG(2) << "Creating mutex with name " << name << ": " << this;
47   }
48 
DebugString() const49   string DebugString() const override {
50     return strings::StrCat("Mutex ", name_);
51   }
52 
53   class LockReleaser {
54    public:
LockReleaser(Mutex * mutex)55     explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}
56 
57     LockReleaser(const LockReleaser&) = delete;
58     LockReleaser& operator=(const LockReleaser&) = delete;
59 
~LockReleaser()60     virtual ~LockReleaser() {
61       VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
62       if (mutex_) {
63         mutex_lock lock(mutex_->mu_);
64         mutex_->locked_ = false;
65         mutex_->cv_.notify_all();
66         VLOG(3) << "Destroying LockReleaser " << this
67                 << ": sent notifications.";
68       }
69     }
70 
71    private:
72     Mutex* mutex_;
73   };
74 
75   typedef SharedPtrVariant<LockReleaser> SharedLockReleaser;
76 
AcquireAsync(OpKernelContext * c,std::function<void (const Status & s,SharedLockReleaser lock)> fn)77   void AcquireAsync(
78       OpKernelContext* c,
79       std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
80     CancellationManager* cm = c->cancellation_manager();
81     CancellationToken token{};
82     bool* cancelled = nullptr;
83     if (cm) {
84       cancelled = new bool(false);  // TF_GUARDED_BY(mu_);
85       token = cm->get_cancellation_token();
86       const bool already_cancelled =
87           !cm->RegisterCallback(token, [this, cancelled]() {
88             mutex_lock lock(mu_);
89             *cancelled = true;
90             cv_.notify_all();
91           });
92       if (already_cancelled) {
93         delete cancelled;
94         fn(errors::Cancelled("Lock acquisition cancelled."),
95            SharedLockReleaser{nullptr});
96         return;
97       }
98     }
99     thread_pool_->Schedule(std::bind(
100         [this, cm, cancelled,
101          token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
102                     fn_) {
103           bool local_locked;
104           {
105             mutex_lock lock(mu_);
106             while (locked_ && !(cancelled && *cancelled)) {
107               cv_.wait(lock);
108             }
109             local_locked = locked_ = !(cancelled && *cancelled);
110           }
111           if (cm) {
112             cm->DeregisterCallback(token);
113             delete cancelled;
114           }
115           if (local_locked) {  // Not cancelled.
116             fn_(OkStatus(),
117                 SharedLockReleaser{std::make_shared<LockReleaser>(this)});
118           } else {
119             fn_(errors::Cancelled("Lock acquisition cancelled."),
120                 SharedLockReleaser{nullptr});
121           }
122         },
123         std::move(fn)));
124   }
125 
126  private:
127   mutex mu_;
128   condition_variable cv_ TF_GUARDED_BY(mu_);
129   bool locked_ TF_GUARDED_BY(mu_);
130   std::unique_ptr<thread::ThreadPool> thread_pool_;
131   string name_;
132 };
133 
134 }  // namespace
135 
136 class MutexLockOp : public AsyncOpKernel {
137  public:
MutexLockOp(OpKernelConstruction * c)138   explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
139 
140  public:
ComputeAsync(OpKernelContext * c,DoneCallback done)141   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
142     Mutex* mutex = nullptr;
143     OP_REQUIRES_OK_ASYNC(
144         c,
145         LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
146                                       [c](Mutex** ptr) {
147                                         *ptr = new Mutex(
148                                             c, HandleFromInput(c, 0).name());
149                                         return OkStatus();
150                                       }),
151         done);
152 
153     Tensor* variant;
154     OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
155                          done);
156 
157     mutex->AcquireAsync(
158         c, std::bind(
159                [c, variant, mutex](DoneCallback done_,
160                                    // End of bound arguments.
161                                    const Status& s,
162                                    Mutex::SharedLockReleaser&& lock) {
163                  VLOG(2) << "Finished locking mutex " << mutex
164                          << " with lock: " << lock.shared_ptr.get()
165                          << " status: " << s.ToString();
166                  if (s.ok()) {
167                    variant->scalar<Variant>()() = std::move(lock);
168                  } else {
169                    c->SetStatus(s);
170                  }
171                  mutex->Unref();
172                  done_();
173                },
174                std::move(done), std::placeholders::_1, std::placeholders::_2));
175   }
176 };
177 
178 class ConsumeMutexLockOp : public OpKernel {
179  public:
ConsumeMutexLockOp(OpKernelConstruction * context)180   explicit ConsumeMutexLockOp(OpKernelConstruction* context)
181       : OpKernel(context) {}
182 
Compute(OpKernelContext * c)183   void Compute(OpKernelContext* c) override {
184     VLOG(2) << "Executing ConsumeMutexLockOp";
185     const Tensor& lock_t = c->input(0);
186     OP_REQUIRES(
187         c, lock_t.dims() == 0,
188         errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
189                                 lock_t.shape().DebugString()));
190     OP_REQUIRES(
191         c, lock_t.dtype() == DT_VARIANT,
192         errors::InvalidArgument("Expected input to be a variant, saw type: ",
193                                 DataTypeString(lock_t.dtype())));
194     const auto* lock =
195         lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
196     OP_REQUIRES(c, lock,
197                 errors::InvalidArgument(
198                     "Expected input to contain a SharedLockReleaser "
199                     "object, but saw variant: '",
200                     lock_t.scalar<Variant>()().DebugString(), "'"));
201     const int use_count = lock->shared_ptr.use_count();
202     OP_REQUIRES(
203         c, use_count == 1,
204         errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
205                                 use_count));
206   }
207 
IsExpensive()208   bool IsExpensive() override { return false; }
209 };
210 
211 REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);
212 
213 REGISTER_KERNEL_BUILDER(Name("MutexLock")
214                             .Device(DEVICE_DEFAULT)
215                             .HostMemory("mutex_lock")
216                             .HostMemory("mutex"),
217                         MutexLockOp);
218 
219 REGISTER_KERNEL_BUILDER(
220     Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"),
221     ResourceHandleOp<Mutex>);
222 
223 REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_DEFAULT),
224                         ResourceHandleOp<Mutex>);
225 
226 REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
227                         ConsumeMutexLockOp);
228 
229 REGISTER_KERNEL_BUILDER(
230     Name("ConsumeMutexLock").Device(DEVICE_DEFAULT).HostMemory("mutex_lock"),
231     ConsumeMutexLockOp);
232 
233 }  // namespace tensorflow
234