1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/training/coordinator.h"
17
18 #include "tensorflow/cc/training/queue_runner.h"
19 #include "tensorflow/core/lib/core/blocking_counter.h"
20 #include "tensorflow/core/lib/core/error_codes.pb.h"
21 #include "tensorflow/core/lib/core/notification.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/public/session.h"
26
27 namespace tensorflow {
28 namespace {
29
30 using error::Code;
31
WaitForStopThread(Coordinator * coord,Notification * about_to_wait,Notification * done)32 void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
33 Notification* done) {
34 about_to_wait->Notify();
35 coord->WaitForStop();
36 done->Notify();
37 }
38
// A thread parked in WaitForStop() must stay parked until RequestStop() is
// called, after which it is released and ShouldStop() reports true.
TEST(CoordinatorTest, TestStopAndWaitOnStop) {
  Coordinator coord;
  EXPECT_FALSE(coord.ShouldStop());

  Notification about_to_wait;
  Notification done;
  Env::Default()->SchedClosure([&coord, &about_to_wait, &done]() {
    WaitForStopThread(&coord, &about_to_wait, &done);
  });

  // Once the waiter has announced itself, give it a full second; it must not
  // have finished because nobody has requested a stop yet.
  about_to_wait.WaitForNotification();
  Env::Default()->SleepForMicroseconds(1000 * 1000);
  EXPECT_FALSE(done.HasBeenNotified());

  // Requesting the stop releases the waiter.
  TF_EXPECT_OK(coord.RequestStop());
  done.WaitForNotification();
  EXPECT_TRUE(coord.ShouldStop());
}
55
56 class MockQueueRunner : public RunnerInterface {
57 public:
MockQueueRunner(Coordinator * coord)58 explicit MockQueueRunner(Coordinator* coord) {
59 coord_ = coord;
60 join_counter_ = nullptr;
61 thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
62 stopped_ = false;
63 }
64
MockQueueRunner(Coordinator * coord,int * join_counter)65 MockQueueRunner(Coordinator* coord, int* join_counter)
66 : MockQueueRunner(coord) {
67 join_counter_ = join_counter;
68 }
69
StartCounting(std::atomic<int> * counter,int until,Notification * start=nullptr)70 void StartCounting(std::atomic<int>* counter, int until,
71 Notification* start = nullptr) {
72 thread_pool_->Schedule(
73 std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
74 }
75
StartSettingStatus(const Status & status,BlockingCounter * counter,Notification * start)76 void StartSettingStatus(const Status& status, BlockingCounter* counter,
77 Notification* start) {
78 thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
79 status, counter, start));
80 }
81
Join()82 Status Join() override {
83 if (join_counter_ != nullptr) {
84 (*join_counter_)++;
85 }
86 thread_pool_.reset();
87 return status_;
88 }
89
GetStatus()90 Status GetStatus() { return status_; }
91
SetStatus(const Status & status)92 void SetStatus(const Status& status) { status_ = status; }
93
IsRunning() const94 bool IsRunning() const override { return !stopped_; };
95
Stop()96 void Stop() { stopped_ = true; }
97
98 private:
CountThread(std::atomic<int> * counter,int until,Notification * start)99 void CountThread(std::atomic<int>* counter, int until, Notification* start) {
100 if (start != nullptr) start->WaitForNotification();
101 while (!coord_->ShouldStop() && counter->load() < until) {
102 (*counter)++;
103 Env::Default()->SleepForMicroseconds(10 * 1000);
104 }
105 coord_->RequestStop().IgnoreError();
106 }
SetStatusThread(const Status & status,BlockingCounter * counter,Notification * start)107 void SetStatusThread(const Status& status, BlockingCounter* counter,
108 Notification* start) {
109 start->WaitForNotification();
110 SetStatus(status);
111 counter->DecrementCount();
112 }
113 std::unique_ptr<thread::ThreadPool> thread_pool_;
114 Status status_;
115 Coordinator* coord_;
116 int* join_counter_;
117 bool stopped_;
118 };
119
// Two counting runners share one counter. After RequestStop() the counter
// must stop changing, and Join() must succeed.
TEST(CoordinatorTest, TestRealStop) {
  std::atomic<int> counter(0);
  Coordinator coord;

  std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
  qr1->StartCounting(&counter, 100);
  TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));

  std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
  qr2->StartCounting(&counter, 100);
  TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));

  // Wait until the counting has started. Sleep between polls rather than
  // spinning so the worker threads are not starved of CPU on a loaded or
  // single-core machine.
  while (counter.load() == 0) {
    Env::Default()->SleepForMicroseconds(1000);
  }
  TF_EXPECT_OK(coord.RequestStop());

  // A worker that passed its ShouldStop() check just before the stop request
  // may still perform one final increment. Let such in-flight iterations land
  // before taking the snapshot, otherwise the comparison below can flake.
  Env::Default()->SleepForMicroseconds(100 * 1000);

  const int temp_counter = counter.load();
  Env::Default()->SleepForMicroseconds(1000 * 1000);
  EXPECT_EQ(temp_counter, counter.load());
  TF_EXPECT_OK(coord.Join());
}
142
// Ten runners count on a shared counter up to 10 and then request a stop;
// the main thread waits for that stop and checks the final state.
TEST(CoordinatorTest, TestRequestStop) {
  Coordinator coord;
  std::atomic<int> counter(0);
  Notification start;

  // Register every runner first; they all block on `start` until released.
  for (int runner_index = 0; runner_index < 10; ++runner_index) {
    std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
    runner->StartCounting(&counter, 10, &start);
    TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));
  }
  start.Notify();

  coord.WaitForStop();
  EXPECT_TRUE(coord.ShouldStop());
  EXPECT_EQ(counter.load(), 10);
  TF_EXPECT_OK(coord.Join());
}
160
// Coordinator::Join() must call Join() on every registered runner exactly
// once; each mock runner bumps `join_counter` when joined.
TEST(CoordinatorTest, TestJoin) {
  Coordinator coord;
  int join_counter = 0;

  std::unique_ptr<MockQueueRunner> first_runner(
      new MockQueueRunner(&coord, &join_counter));
  TF_ASSERT_OK(coord.RegisterRunner(std::move(first_runner)));

  std::unique_ptr<MockQueueRunner> second_runner(
      new MockQueueRunner(&coord, &join_counter));
  TF_ASSERT_OK(coord.RegisterRunner(std::move(second_runner)));

  TF_EXPECT_OK(coord.RequestStop());
  TF_EXPECT_OK(coord.Join());
  EXPECT_EQ(join_counter, 2);
}
175
// The coordinator is constructed with CANCELLED and OUT_OF_RANGE. Three
// runners report CANCELLED, INVALID_ARGUMENT and OUT_OF_RANGE respectively,
// and Join() is expected to surface INVALID_ARGUMENT.
TEST(CoordinatorTest, StatusReporting) {
  Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
  Notification start;
  BlockingCounter counter(3);

  // One runner per status code, registered in this order.
  for (Code code :
       {Code::CANCELLED, Code::INVALID_ARGUMENT, Code::OUT_OF_RANGE}) {
    std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
    runner->StartSettingStatus(Status(code, ""), &counter, &start);
    TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));
  }

  // Release the runners and wait until all three have stored their status.
  start.Notify();
  counter.Wait();

  TF_EXPECT_OK(coord.RequestStop());
  EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
}
198
// Calling Join() before any stop has been requested must fail with
// FAILED_PRECONDITION.
TEST(CoordinatorTest, JoinWithoutStop) {
  Coordinator coord;
  std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
  TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));

  const Status join_status = coord.Join();
  EXPECT_EQ(join_status.code(), Code::FAILED_PRECONDITION);
}
206
// AllRunnersStopped() must report false while the registered runner is
// running and true once that runner has been stopped.
TEST(CoordinatorTest, AllRunnersStopped) {
  Coordinator coord;

  // Keep a non-owning pointer so the runner can still be driven after its
  // ownership has been handed to the coordinator.
  std::unique_ptr<MockQueueRunner> owned_runner(new MockQueueRunner(&coord));
  MockQueueRunner* const qr = owned_runner.get();
  TF_ASSERT_OK(coord.RegisterRunner(
      std::unique_ptr<RunnerInterface>(std::move(owned_runner))));

  EXPECT_FALSE(coord.AllRunnersStopped());
  qr->Stop();
  EXPECT_TRUE(coord.AllRunnersStopped());
}
216
217 } // namespace
218 } // namespace tensorflow
219