1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/cancellation.h"
17
18 #include <forward_list>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25
26 const CancellationToken CancellationManager::kInvalidToken = -1;
27
CancellationManager()28 CancellationManager::CancellationManager()
29 : is_cancelling_(false),
30 is_cancelled_(false),
31 next_cancellation_token_(0) {}
32
CancellationManager(CancellationManager * parent)33 CancellationManager::CancellationManager(CancellationManager* parent)
34 : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) {
35 is_cancelled_ = parent->RegisterChild(this);
36 }
37
StartCancel()38 void CancellationManager::StartCancel() {
39 gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
40 std::forward_list<CancellationManager*> children_to_cancel;
41 Notification* cancelled_notification = nullptr;
42 {
43 mutex_lock l(mu_);
44 if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
45 return;
46 }
47 is_cancelling_ = true;
48 if (state_) {
49 std::swap(state_->callbacks, callbacks_to_run);
50
51 // Remove all children from the list of children.
52 CancellationManager* child = state_->first_child;
53 while (child != nullptr) {
54 children_to_cancel.push_front(child);
55 child->is_removed_from_parent_ = true;
56 child = child->next_sibling_;
57 }
58 state_->first_child = nullptr;
59
60 cancelled_notification = &state_->cancelled_notification;
61 }
62 }
63 // We call these callbacks without holding mu_, so that concurrent
64 // calls to DeregisterCallback, which can happen asynchronously, do
65 // not block. The callbacks remain valid because any concurrent call
66 // to DeregisterCallback will block until the
67 // cancelled_notification_ is notified.
68 for (auto key_and_value : callbacks_to_run) {
69 key_and_value.second();
70 }
71 for (CancellationManager* child : children_to_cancel) {
72 child->StartCancel();
73 }
74 {
75 mutex_lock l(mu_);
76 is_cancelling_ = false;
77 is_cancelled_.store(true, std::memory_order_release);
78 }
79 if (cancelled_notification) {
80 cancelled_notification->Notify();
81 }
82 }
83
RegisterCallback(CancellationToken token,CancelCallback callback)84 bool CancellationManager::RegisterCallback(CancellationToken token,
85 CancelCallback callback) {
86 DCHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
87 mutex_lock l(mu_);
88 bool should_register = !is_cancelled_ && !is_cancelling_;
89 if (should_register) {
90 if (!state_) {
91 state_ = absl::make_unique<State>();
92 }
93 std::swap(state_->callbacks[token], callback);
94 }
95 return should_register;
96 }
97
DeregisterCallback(CancellationToken token)98 bool CancellationManager::DeregisterCallback(CancellationToken token) {
99 mu_.lock();
100 if (is_cancelled_) {
101 mu_.unlock();
102 return false;
103 } else if (is_cancelling_) {
104 Notification* cancelled_notification =
105 state_ ? &state_->cancelled_notification : nullptr;
106 mu_.unlock();
107 // Wait for all of the cancellation callbacks to be called. This
108 // wait ensures that the caller of DeregisterCallback does not
109 // return immediately and free objects that may be used in the
110 // execution of any currently pending callbacks in StartCancel.
111 if (cancelled_notification) {
112 cancelled_notification->WaitForNotification();
113 }
114 return false;
115 } else {
116 if (state_) {
117 state_->callbacks.erase(token);
118 }
119 mu_.unlock();
120 return true;
121 }
122 }
123
RegisterChild(CancellationManager * child)124 bool CancellationManager::RegisterChild(CancellationManager* child) {
125 mutex_lock l(mu_);
126 if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
127 child->is_removed_from_parent_ = true;
128 return true;
129 }
130
131 if (!state_) {
132 state_ = absl::make_unique<State>();
133 }
134
135 // Push `child` onto the front of the list of children.
136 CancellationManager* current_head = state_->first_child;
137 state_->first_child = child;
138 child->prev_sibling_ = nullptr;
139 child->next_sibling_ = current_head;
140 if (current_head) {
141 current_head->prev_sibling_ = child;
142 }
143
144 return false;
145 }
146
DeregisterChild(CancellationManager * child)147 void CancellationManager::DeregisterChild(CancellationManager* child) {
148 DCHECK_EQ(child->parent_, this);
149 Notification* cancelled_notification = nullptr;
150 {
151 mutex_lock l(mu_);
152 if (!child->is_removed_from_parent_) {
153 // Remove the child from this manager's list of children.
154 DCHECK(state_);
155
156 if (child->prev_sibling_ == nullptr) {
157 // The child was at the head of the list.
158 DCHECK_EQ(state_->first_child, child);
159 state_->first_child = child->next_sibling_;
160 } else {
161 child->prev_sibling_->next_sibling_ = child->next_sibling_;
162 }
163
164 if (child->next_sibling_ != nullptr) {
165 child->next_sibling_->prev_sibling_ = child->prev_sibling_;
166 }
167
168 child->is_removed_from_parent_ = true;
169 }
170 if (is_cancelling_) {
171 cancelled_notification = &state_->cancelled_notification;
172 }
173 }
174
175 // Wait for an ongoing call to StartCancel() to finish. This wait ensures that
176 // the caller of DeregisterChild does not return immediately and free a child
177 // that may currently be being cancelled by StartCancel().
178 if (cancelled_notification) {
179 cancelled_notification->WaitForNotification();
180 }
181 }
182
TryDeregisterCallback(CancellationToken token)183 bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
184 mutex_lock lock(mu_);
185 if (is_cancelled_ || is_cancelling_) {
186 return false;
187 } else {
188 if (state_) {
189 state_->callbacks.erase(token);
190 }
191 return true;
192 }
193 }
194
~CancellationManager()195 CancellationManager::~CancellationManager() {
196 if (parent_) {
197 parent_->DeregisterChild(this);
198 }
199 if (state_) {
200 StartCancel();
201 }
202 }
203
IsCancelling()204 bool CancellationManager::IsCancelling() {
205 mutex_lock lock(mu_);
206 return is_cancelling_;
207 }
208
RegisterCancellationCallback(CancellationManager * cancellation_manager,std::function<void ()> callback,std::function<void ()> * deregister_fn)209 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
210 std::function<void()> callback,
211 std::function<void()>* deregister_fn) {
212 if (cancellation_manager) {
213 CancellationToken token = cancellation_manager->get_cancellation_token();
214 if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
215 return errors::Cancelled("Operation was cancelled");
216 }
217 *deregister_fn = [cancellation_manager, token]() {
218 cancellation_manager->DeregisterCallback(token);
219 };
220 } else {
221 VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
222 "not be registered.";
223 *deregister_fn = []() {};
224 }
225 return Status::OK();
226 }
227
228 } // end namespace tensorflow
229