• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/cancellation.h"
17 
18 #include <vector>
19 #include "tensorflow/core/lib/core/notification.h"
20 #include "tensorflow/core/lib/core/threadpool.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 
// A callback that is registered and then deregistered must never run: the
// flag it sets is still expected to be false after the manager is destroyed.
TEST(Cancellation, SimpleNoCancel) {
  bool is_cancelled = false;
  {
    // Automatic storage replaces the manual new/delete pair. The inner scope
    // ends the manager's lifetime before the final expectation, so the test
    // still checks the flag *after* destruction.
    CancellationManager manager;
    auto token = manager.get_cancellation_token();
    bool registered = manager.RegisterCallback(
        token, [&is_cancelled]() { is_cancelled = true; });
    EXPECT_TRUE(registered);
    bool deregistered = manager.DeregisterCallback(token);
    EXPECT_TRUE(deregistered);
  }
  EXPECT_FALSE(is_cancelled);
}
37 
// A registered callback is expected to have run by the time StartCancel()
// returns.
TEST(Cancellation, SimpleCancel) {
  bool is_cancelled = false;
  // Automatic storage: the manager is released at the end of the test body
  // without an explicit delete.
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  bool registered = manager.RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  manager.StartCancel();
  EXPECT_TRUE(is_cancelled);
}
49 
// Registering a callback after StartCancel() has been called is expected to
// be rejected (RegisterCallback returns false).
TEST(Cancellation, CancelBeforeRegister) {
  // Automatic storage: no explicit delete is needed at the end of the test.
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  manager.StartCancel();
  // A null callback is passed; the test only inspects the return value.
  bool registered = manager.RegisterCallback(token, nullptr);
  EXPECT_FALSE(registered);
}
58 
// Once cancellation has run the callback, DeregisterCallback() for the same
// token is expected to return false.
TEST(Cancellation, DeregisterAfterCancel) {
  bool is_cancelled = false;
  // Automatic storage: no explicit delete is needed at the end of the test.
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  bool registered = manager.RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  manager.StartCancel();
  EXPECT_TRUE(is_cancelled);
  bool deregistered = manager.DeregisterCallback(token);
  EXPECT_FALSE(deregistered);
}
72 
// Two callbacks registered before StartCancel() are both expected to run; a
// third one offered afterwards is expected to be rejected and never invoked.
TEST(Cancellation, CancelMultiple) {
  bool is_cancelled_1 = false;
  bool is_cancelled_2 = false;
  bool is_cancelled_3 = false;
  // Automatic storage: no explicit delete is needed at the end of the test.
  CancellationManager manager;

  auto token_1 = manager.get_cancellation_token();
  bool registered_1 = manager.RegisterCallback(
      token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
  EXPECT_TRUE(registered_1);

  auto token_2 = manager.get_cancellation_token();
  bool registered_2 = manager.RegisterCallback(
      token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
  EXPECT_TRUE(registered_2);

  // Nothing may fire before cancellation starts.
  EXPECT_FALSE(is_cancelled_1);
  EXPECT_FALSE(is_cancelled_2);

  manager.StartCancel();
  EXPECT_TRUE(is_cancelled_1);
  EXPECT_TRUE(is_cancelled_2);
  EXPECT_FALSE(is_cancelled_3);

  // Late registration: expected to fail and to leave the third flag untouched.
  auto token_3 = manager.get_cancellation_token();
  bool registered_3 = manager.RegisterCallback(
      token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
  EXPECT_FALSE(registered_3);
  EXPECT_FALSE(is_cancelled_3);
}
97 
// Eight tasks are scheduled on a four-thread pool; each one spins on
// IsCancelled() and then signals its own Notification. The test completes
// only after every task has signalled following StartCancel().
TEST(Cancellation, IsCancelled) {
  CancellationManager* manager = new CancellationManager();
  thread::ThreadPool pool(Env::Default(), "test", 4);
  std::vector<Notification> observed(8);

  for (Notification& slot : observed) {
    Notification* signal = &slot;
    pool.Schedule([signal, manager]() {
      // Busy-wait until the manager reports cancellation.
      for (;;) {
        if (manager->IsCancelled()) break;
      }
      signal->Notify();
    });
  }

  // Let the workers spin for a while before cancelling.
  Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
  manager->StartCancel();

  for (Notification& slot : observed) {
    slot.WaitForNotification();
  }
  // Every task has signalled, so none of them reads the manager any more.
  delete manager;
}
117 
// With no cancellation in progress, TryDeregisterCallback() is expected to
// succeed, and the callback's flag is still expected to be false after the
// manager is destroyed.
TEST(Cancellation, TryDeregisterWithoutCancel) {
  bool is_cancelled = false;
  {
    // Automatic storage replaces the manual new/delete pair. The inner scope
    // ends the manager's lifetime before the final expectation, so the test
    // still checks the flag *after* destruction.
    CancellationManager manager;
    auto token = manager.get_cancellation_token();
    bool registered = manager.RegisterCallback(
        token, [&is_cancelled]() { is_cancelled = true; });
    EXPECT_TRUE(registered);
    bool deregistered = manager.TryDeregisterCallback(token);
    EXPECT_TRUE(deregistered);
  }
  EXPECT_FALSE(is_cancelled);
}
130 
// Once cancellation has run the callback, TryDeregisterCallback() for the
// same token is expected to return false.
TEST(Cancellation, TryDeregisterAfterCancel) {
  bool is_cancelled = false;
  // Automatic storage: no explicit delete is needed at the end of the test.
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  bool registered = manager.RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  manager.StartCancel();
  EXPECT_TRUE(is_cancelled);
  bool deregistered = manager.TryDeregisterCallback(token);
  EXPECT_FALSE(deregistered);
}
144 
// Calls TryDeregisterCallback() while the registered callback is still
// executing inside StartCancel() on another thread. The call is expected to
// return false, and it has to return without waiting for the callback:
// the callback is only released afterwards, via `finish_callback`.
TEST(Cancellation, TryDeregisterDuringCancel) {
  Notification cancel_started, finish_callback, cancel_complete;
  CancellationManager* manager = new CancellationManager();
  auto token = manager->get_cancellation_token();

  // The callback announces that it is running, then parks until released.
  auto parked_callback = [&]() {
    cancel_started.Notify();
    finish_callback.WaitForNotification();
  };
  bool did_register = manager->RegisterCallback(token, parked_callback);
  EXPECT_TRUE(did_register);

  // Run the cancellation on a worker thread so this thread stays free.
  thread::ThreadPool pool(Env::Default(), "test", 1);
  auto cancel_task = [&]() {
    manager->StartCancel();
    cancel_complete.Notify();
  };
  pool.Schedule(cancel_task);
  cancel_started.WaitForNotification();

  // The callback is now in flight.
  bool did_deregister = manager->TryDeregisterCallback(token);
  EXPECT_FALSE(did_deregister);

  // Release the callback and wait for StartCancel() to return before the
  // manager is destroyed.
  finish_callback.Notify();
  cancel_complete.WaitForNotification();
  delete manager;
}
169 
170 }  // namespace tensorflow
171