/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
18 
#include <atomic>
#include <functional>
#include <memory>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 // A token that can be used to register and deregister a
33 // CancelCallback with a CancellationManager.
34 //
35 // CancellationToken values must be created by a call to
36 // CancellationManager::get_cancellation_token.
37 typedef int64 CancellationToken;
38 
// Signature of the function that is invoked when a step is canceled: it takes
// no arguments and returns nothing.
//
// NOTE(mrry): See caveats about CancelCallback implementations in the
// comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;
44 
45 // This class should never simultaneously be used as the cancellation manager
46 // for two separate sets of executions (i.e two separate steps, or two separate
47 // function executions).
48 class CancellationManager {
49  public:
50   // A value that won't be returned by get_cancellation_token().
51   static const CancellationToken kInvalidToken;
52 
53   CancellationManager();
54 
55   // Constructs a new CancellationManager that is a "child" of `*parent`.
56   //
57   // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must
58   // outlive the created CancellationManager.
59   explicit CancellationManager(CancellationManager* parent);
60 
61   ~CancellationManager();
62 
63   // Run all callbacks associated with this manager.
64   void StartCancel();
65 
66   // Returns true iff StartCancel() has been called.
IsCancelled()67   bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
68 
69   // Returns a token that must be used in calls to RegisterCallback
70   // and DeregisterCallback.
get_cancellation_token()71   CancellationToken get_cancellation_token() {
72     return next_cancellation_token_.fetch_add(1);
73   }
74 
75   // Attempts to register the given callback to be invoked when this
76   // manager is cancelled. Returns true if the callback was
77   // registered; returns false if this manager was already cancelled,
78   // and the callback was not registered.
79   //
80   // If this method returns false, it is the caller's responsibility
81   // to perform any cancellation cleanup.
82   //
83   // This method is tricky to use correctly. The following usage pattern
84   // is recommended:
85   //
86   // class ObjectWithCancellableOperation {
87   //   mutex mu_;
88   //   void CancellableOperation(CancellationManager* cm,
89   //                             std::function<void(Status)> callback) {
90   //     bool already_cancelled;
91   //     CancellationToken token = cm->get_cancellation_token();
92   //     {
93   //       mutex_lock(mu_);
94   //       already_cancelled = !cm->RegisterCallback(
95   //           [this, token]() { Cancel(token); });
96   //       if (!already_cancelled) {
97   //         // Issue asynchronous operation. Associate the pending operation
98   //         // with `token` in some object state, or provide another way for
99   //         // the Cancel method to look up the operation for cancellation.
100   //         // Ensure that `cm->DeregisterCallback(token)` is called without
101   //         // holding `mu_`, before `callback` is invoked.
102   //         // ...
103   //       }
104   //     }
105   //     if (already_cancelled) {
106   //       callback(errors::Cancelled("Operation was cancelled"));
107   //     }
108   //   }
109   //
110   //   void Cancel(CancellationToken token) {
111   //     mutex_lock(mu_);
112   //     // Take action to cancel the operation with the given cancellation
113   //     // token.
114   //   }
115   //
116   // NOTE(mrry): The caller should take care that (i) the calling code
117   // is robust to `callback` being invoked asynchronously (e.g. from
118   // another thread), (ii) `callback` is deregistered by a call to
119   // this->DeregisterCallback(token) when the operation completes
120   // successfully, and (iii) `callback` does not invoke any method
121   // on this cancellation manager. Furthermore, it is important that
122   // the eventual caller of the complementary DeregisterCallback does not
123   // hold any mutexes that are required by `callback`.
124   bool RegisterCallback(CancellationToken token, CancelCallback callback);
125 
126   // Deregister the callback that, when registered, was associated
127   // with the given cancellation token. Returns true iff the callback
128   // was deregistered and will not be invoked; otherwise returns false
129   // after the callback has been invoked, blocking if necessary.
130   //
131   // NOTE(mrry): This method may block if cancellation is in progress.
132   // The caller of this method must not hold any mutexes that are required
133   // to invoke any cancellation callback that has been registered with this
134   // cancellation manager.
135   bool DeregisterCallback(CancellationToken token);
136 
137   // Deregister the callback that, when registered, was associated
138   // with the given cancellation token. Returns true iff the callback
139   // was deregistered and will not be invoked; otherwise returns false
140   // immediately, with no guarantee that the callback has completed.
141   //
142   // This method is guaranteed to return true if StartCancel has not been
143   // called.
144   bool TryDeregisterCallback(CancellationToken token);
145 
146   // Returns true iff cancellation is in progress.
147   bool IsCancelling();
148 
149  private:
150   struct State {
151     Notification cancelled_notification;
152     gtl::FlatMap<CancellationToken, CancelCallback> callbacks;
153 
154     // If this CancellationManager has any children, this member points to the
155     // head of a doubly-linked list of its children.
156     CancellationManager* first_child = nullptr;  // Not owned.
157   };
158 
159   bool RegisterChild(CancellationManager* child);
160   void DeregisterChild(CancellationManager* child);
161 
162   bool is_cancelling_;
163   std::atomic_bool is_cancelled_;
164   std::atomic<CancellationToken> next_cancellation_token_;
165 
166   CancellationManager* const parent_ = nullptr;  // Not owned.
167 
168   // If this CancellationManager is associated with a parent, this member will
169   // be set to `true` after this is removed from the parent's list of children.
170   bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false;
171 
172   // If this CancellationManager is associated with a parent, these members form
173   // a doubly-linked list of that parent's children.
174   //
175   // These fields are valid only when `this->is_removed_from_parent_` is false.
176   CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) =
177       nullptr;  // Not owned.
178   CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) =
179       nullptr;  // Not owned.
180 
181   mutex mu_;
182   std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
183 };
184 
185 // Registers the given cancellation callback, returning a function that can be
186 // used to deregister the callback. If `cancellation_manager` is NULL, no
187 // registration occurs and `deregister_fn` will be a no-op.
188 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
189                                     std::function<void()> callback,
190                                     std::function<void()>* deregister_fn);
191 
192 }  // namespace tensorflow
193 
194 #endif  // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
195