1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/cancellation.h"
17
18 #include <vector>
19 #include "tensorflow/core/lib/core/notification.h"
20 #include "tensorflow/core/lib/core/threadpool.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24
// A callback that is registered and then deregistered must not run, not even
// when the manager that held it is destroyed afterwards.
TEST(Cancellation, SimpleNoCancel) {
  bool is_cancelled = false;
  {
    CancellationManager manager;
    const auto token = manager.get_cancellation_token();
    EXPECT_TRUE(manager.RegisterCallback(
        token, [&is_cancelled]() { is_cancelled = true; }));
    EXPECT_TRUE(manager.DeregisterCallback(token));
  }  // `manager` is destroyed here, before the check below.
  EXPECT_FALSE(is_cancelled);
}
37
// A callback registered before cancellation has run by the time
// StartCancel() returns.
TEST(Cancellation, SimpleCancel) {
  bool callback_ran = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();
  const bool registered = manager.RegisterCallback(
      token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(registered);

  manager.StartCancel();
  EXPECT_TRUE(callback_ran);
}
49
// Registering a callback after StartCancel() must be rejected, and the
// rejected callback must never be invoked: not during registration and not
// when the manager is destroyed.
TEST(Cancellation, CancelBeforeRegister) {
  bool is_cancelled = false;
  {
    CancellationManager manager;
    auto token = manager.get_cancellation_token();
    manager.StartCancel();
    // Use a real callback rather than nullptr. With a real callback, an
    // implementation that runs or retains late-registered callbacks is caught
    // by the final check instead of going unnoticed.
    bool registered = manager.RegisterCallback(
        token, [&is_cancelled]() { is_cancelled = true; });
    EXPECT_FALSE(registered);
  }  // `manager` is destroyed here.
  EXPECT_FALSE(is_cancelled);
}
58
// Once cancellation has run a callback, deregistering that callback's token
// reports failure.
TEST(Cancellation, DeregisterAfterCancel) {
  bool fired = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(token, [&fired]() { fired = true; }));

  manager.StartCancel();
  EXPECT_TRUE(fired);

  const bool deregistered = manager.DeregisterCallback(token);
  EXPECT_FALSE(deregistered);
}
72
// StartCancel() runs every callback registered so far. A callback registered
// afterwards is rejected and is not run.
TEST(Cancellation, CancelMultiple) {
  bool first_ran = false;
  bool second_ran = false;
  bool late_ran = false;
  CancellationManager manager;

  // Two callbacks registered before cancellation; neither runs yet.
  const auto first_token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(first_token,
                                       [&first_ran]() { first_ran = true; }));
  const auto second_token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(
      second_token, [&second_ran]() { second_ran = true; }));
  EXPECT_FALSE(first_ran);
  EXPECT_FALSE(second_ran);

  // Cancelling runs both of them.
  manager.StartCancel();
  EXPECT_TRUE(first_ran);
  EXPECT_TRUE(second_ran);
  EXPECT_FALSE(late_ran);

  // Registering after cancellation fails and does not invoke the callback.
  const auto late_token = manager.get_cancellation_token();
  EXPECT_FALSE(manager.RegisterCallback(late_token,
                                        [&late_ran]() { late_ran = true; }));
  EXPECT_FALSE(late_ran);
}
97
// Eight tasks, run on a 4-thread pool, poll IsCancelled() until the main
// thread's StartCancel() becomes visible to them. Each task then signals its
// own Notification, and the main thread waits for all eight.
TEST(Cancellation, IsCancelled) {
  // Everything the tasks touch (the manager and the notifications) is declared
  // BEFORE the pool `w`. Automatic objects are destroyed in reverse order, so
  // the pool is torn down first. The original declared `done` after `w` and
  // deleted `cm` before `w` was destroyed. That let those objects die while a
  // worker could still be inside Notify(), because the main thread may return
  // from WaitForNotification() as soon as the flag is set.
  // NOTE(review): this relies on ThreadPool's destructor waiting for
  // outstanding work; confirm in threadpool.h.
  CancellationManager cm;
  std::vector<Notification> done(8);
  thread::ThreadPool w(Env::Default(), "test", 4);
  for (Notification& n : done) {
    Notification* notification = &n;
    w.Schedule([notification, &cm]() {
      // Spin until the cancellation performed on the main thread is observed.
      while (!cm.IsCancelled()) {
      }
      notification->Notify();
    });
  }
  // Let the workers start polling before cancelling.
  Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
  cm.StartCancel();
  for (Notification& n : done) {
    n.WaitForNotification();
  }
}
117
// TryDeregisterCallback() succeeds for a registered, uncancelled token. The
// removed callback is not run when the manager is later destroyed.
TEST(Cancellation, TryDeregisterWithoutCancel) {
  bool callback_ran = false;
  {
    CancellationManager manager;
    const auto token = manager.get_cancellation_token();
    const bool registered = manager.RegisterCallback(
        token, [&callback_ran]() { callback_ran = true; });
    EXPECT_TRUE(registered);
    EXPECT_TRUE(manager.TryDeregisterCallback(token));
  }  // `manager` is destroyed here, before the check below.
  EXPECT_FALSE(callback_ran);
}
130
// After StartCancel() has run a callback, TryDeregisterCallback() on that
// callback's token reports failure.
TEST(Cancellation, TryDeregisterAfterCancel) {
  bool callback_ran = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(
      token, [&callback_ran]() { callback_ran = true; }));

  manager.StartCancel();
  EXPECT_TRUE(callback_ran);

  const bool deregistered = manager.TryDeregisterCallback(token);
  EXPECT_FALSE(deregistered);
}
144
// While StartCancel() is running a callback on another thread, calling
// TryDeregisterCallback() for that callback must return false without
// blocking. The callback stays parked on `finish_callback`, which is only
// signalled after TryDeregisterCallback() returns, so blocking would deadlock.
TEST(Cancellation, TryDeregisterDuringCancel) {
  // Declaration order matters: notifications, then manager, then pool.
  // Reverse-order destruction therefore tears down the pool before the
  // manager its task uses, and the manager before the notifications its
  // callback references.
  Notification cancel_started;
  Notification finish_callback;
  Notification cancel_complete;
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  bool registered = manager.RegisterCallback(token, [&]() {
    cancel_started.Notify();
    finish_callback.WaitForNotification();
  });
  // Fatal on purpose: if registration failed, `cancel_started` would never be
  // notified and the wait below would hang forever.
  ASSERT_TRUE(registered);

  thread::ThreadPool w(Env::Default(), "test", 1);
  w.Schedule([&]() {
    manager.StartCancel();
    cancel_complete.Notify();
  });
  // From here on the callback is running inside StartCancel() on the worker.
  cancel_started.WaitForNotification();

  bool deregistered = manager.TryDeregisterCallback(token);
  EXPECT_FALSE(deregistered);

  finish_callback.Notify();
  cancel_complete.WaitForNotification();
}
169
170 } // namespace tensorflow
171