• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
18 
19 #include <atomic>
20 #include <functional>
21 
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/gtl/flatmap.h"
25 #include "tensorflow/core/lib/hash/hash.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 // A token that can be used to register and deregister a
33 // CancelCallback with a CancellationManager.
34 //
35 // CancellationToken values must be created by a call to
36 // CancellationManager::get_cancellation_token.
37 typedef int64 CancellationToken;
38 
// Nullary callback that is run when a step is canceled.
//
// NOTE(mrry): The comment on CancellationManager::RegisterCallback lists
// the caveats that CancelCallback implementations must respect.
using CancelCallback = std::function<void()>;
44 
45 class CancellationManager {
46  public:
47   // A value that won't be returned by get_cancellation_token().
48   static const CancellationToken kInvalidToken;
49 
50   CancellationManager();
51   ~CancellationManager();
52 
53   // Run all callbacks associated with this manager.
54   void StartCancel();
55 
56   // Returns true iff StartCancel() has been called.
IsCancelled()57   bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
58 
59   // Resets the cancellation manager to its original pre-cancelled state.
60   void Reset();
61 
62   // Returns a token that must be used in calls to RegisterCallback
63   // and DeregisterCallback.
64   CancellationToken get_cancellation_token();
65 
66   // Attempts to register the given callback to be invoked when this
67   // manager is cancelled. Returns true if the callback was
68   // registered; returns false if this manager was already cancelled,
69   // and the callback was not registered.
70   //
71   // If this method returns false, it is the caller's responsibility
72   // to perform any cancellation cleanup.
73   //
74   // This method is tricky to use correctly. The following usage pattern
75   // is recommended:
76   //
77   // class ObjectWithCancellableOperation {
78   //   mutex mu_;
79   //   void CancellableOperation(CancellationManager* cm,
80   //                             std::function<void(Status)> callback) {
81   //     bool already_cancelled;
82   //     CancellationToken token = cm->get_cancellation_token();
83   //     {
84   //       mutex_lock(mu_);
85   //       already_cancelled = !cm->RegisterCallback(
86   //           [this, token]() { Cancel(token); });
87   //       if (!already_cancelled) {
88   //         // Issue asynchronous operation. Associate the pending operation
89   //         // with `token` in some object state, or provide another way for
90   //         // the Cancel method to look up the operation for cancellation.
91   //         // Ensure that `cm->DeregisterCallback(token)` is called without
92   //         // holding `mu_`, before `callback` is invoked.
93   //         // ...
94   //       }
95   //     }
96   //     if (already_cancelled) {
97   //       callback(errors::Cancelled("Operation was cancelled"));
98   //     }
99   //   }
100   //
101   //   void Cancel(CancellationToken token) {
102   //     mutex_lock(mu_);
103   //     // Take action to cancel the operation with the given cancellation
104   //     // token.
105   //   }
106   //
107   // NOTE(mrry): The caller should take care that (i) the calling code
108   // is robust to `callback` being invoked asynchronously (e.g. from
109   // another thread), (ii) `callback` is deregistered by a call to
110   // this->DeregisterCallback(token) when the operation completes
111   // successfully, and (iii) `callback` does not invoke any method
112   // on this cancellation manager. Furthermore, it is important that
113   // the eventual caller of the complementary DeregisterCallback does not
114   // hold any mutexes that are required by `callback`.
115   bool RegisterCallback(CancellationToken token, CancelCallback callback);
116 
117   // Deregister the callback that, when registered, was associated
118   // with the given cancellation token. Returns true iff the callback
119   // was deregistered and will not be invoked; otherwise returns false
120   // after the callback has been invoked, blocking if necessary.
121   //
122   // NOTE(mrry): This method may block if cancellation is in progress.
123   // The caller of this method must not hold any mutexes that are required
124   // to invoke any cancellation callback that has been registered with this
125   // cancellation manager.
126   bool DeregisterCallback(CancellationToken token);
127 
128   // Deregister the callback that, when registered, was associated
129   // with the given cancellation token. Returns true iff the callback
130   // was deregistered and will not be invoked; otherwise returns false
131   // immediately, with no guarantee that the callback has completed.
132   //
133   // This method is guaranteed to return true if StartCancel has not been
134   // called.
135   bool TryDeregisterCallback(CancellationToken token);
136 
137  private:
138   bool is_cancelling_;
139   std::atomic_bool is_cancelled_;
140 
141   mutex mu_;
142   Notification cancelled_notification_;
143   CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
144   gtl::FlatMap<CancellationToken, CancelCallback> callbacks_ GUARDED_BY(mu_);
145 };
146 
147 }  // namespace tensorflow
148 
149 #endif  // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
150