• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/training/coordinator.h"
17 
18 #include "tensorflow/cc/training/queue_runner.h"
19 #include "tensorflow/core/lib/core/blocking_counter.h"
20 #include "tensorflow/core/lib/core/error_codes.pb.h"
21 #include "tensorflow/core/lib/core/notification.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/public/session.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 using error::Code;
31 
WaitForStopThread(Coordinator * coord,Notification * about_to_wait,Notification * done)32 void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
33                        Notification* done) {
34   about_to_wait->Notify();
35   coord->WaitForStop();
36   done->Notify();
37 }
38 
// Checks that a thread parked in WaitForStop() stays parked until
// RequestStop() is called, and that ShouldStop() flips to true afterwards.
TEST(CoordinatorTest, TestStopAndWaitOnStop) {
  Coordinator coord;
  EXPECT_FALSE(coord.ShouldStop());

  Notification about_to_wait;
  Notification done;
  Env* const env = Env::Default();
  env->SchedClosure([&coord, &about_to_wait, &done]() {
    WaitForStopThread(&coord, &about_to_wait, &done);
  });

  // Once the waiter has announced itself, give it a full second in which it
  // must NOT complete, since nobody has requested a stop yet.
  about_to_wait.WaitForNotification();
  env->SleepForMicroseconds(1000 * 1000);
  EXPECT_FALSE(done.HasBeenNotified());

  // Requesting the stop must release the waiter.
  TF_EXPECT_OK(coord.RequestStop());
  done.WaitForNotification();
  EXPECT_TRUE(coord.ShouldStop());
}
55 
56 class MockQueueRunner : public RunnerInterface {
57  public:
MockQueueRunner(Coordinator * coord)58   explicit MockQueueRunner(Coordinator* coord) {
59     coord_ = coord;
60     join_counter_ = nullptr;
61     thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
62     stopped_ = false;
63   }
64 
MockQueueRunner(Coordinator * coord,int * join_counter)65   MockQueueRunner(Coordinator* coord, int* join_counter)
66       : MockQueueRunner(coord) {
67     join_counter_ = join_counter;
68   }
69 
StartCounting(std::atomic<int> * counter,int until,Notification * start=nullptr)70   void StartCounting(std::atomic<int>* counter, int until,
71                      Notification* start = nullptr) {
72     thread_pool_->Schedule(
73         std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
74   }
75 
StartSettingStatus(const Status & status,BlockingCounter * counter,Notification * start)76   void StartSettingStatus(const Status& status, BlockingCounter* counter,
77                           Notification* start) {
78     thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
79                                      status, counter, start));
80   }
81 
Join()82   Status Join() override {
83     if (join_counter_ != nullptr) {
84       (*join_counter_)++;
85     }
86     thread_pool_.reset();
87     return status_;
88   }
89 
GetStatus()90   Status GetStatus() { return status_; }
91 
SetStatus(const Status & status)92   void SetStatus(const Status& status) { status_ = status; }
93 
IsRunning() const94   bool IsRunning() const override { return !stopped_; };
95 
Stop()96   void Stop() { stopped_ = true; }
97 
98  private:
CountThread(std::atomic<int> * counter,int until,Notification * start)99   void CountThread(std::atomic<int>* counter, int until, Notification* start) {
100     if (start != nullptr) start->WaitForNotification();
101     while (!coord_->ShouldStop() && counter->load() < until) {
102       (*counter)++;
103       Env::Default()->SleepForMicroseconds(10 * 1000);
104     }
105     coord_->RequestStop().IgnoreError();
106   }
SetStatusThread(const Status & status,BlockingCounter * counter,Notification * start)107   void SetStatusThread(const Status& status, BlockingCounter* counter,
108                        Notification* start) {
109     start->WaitForNotification();
110     SetStatus(status);
111     counter->DecrementCount();
112   }
113   std::unique_ptr<thread::ThreadPool> thread_pool_;
114   Status status_;
115   Coordinator* coord_;
116   int* join_counter_;
117   bool stopped_;
118 };
119 
// Verifies that RequestStop() actually halts the runners' counting loops: once
// a stop has been requested, the shared counter must stop changing.
TEST(CoordinatorTest, TestRealStop) {
  std::atomic<int> counter(0);
  Coordinator coord;

  std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
  qr1->StartCounting(&counter, 100);
  TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));

  std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
  qr2->StartCounting(&counter, 100);
  TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));

  // Wait until the counting has started. Sleep between polls rather than
  // spinning, so this thread does not compete with the workers for CPU.
  while (counter.load() == 0) {
    Env::Default()->SleepForMicroseconds(1000);
  }
  TF_EXPECT_OK(coord.RequestStop());

  // A worker that passed its ShouldStop() check just before the stop request
  // may still perform one more increment. Workers sleep 10ms per iteration,
  // so a 100ms settle period lets any in-flight iteration finish before the
  // counter is sampled.
  Env::Default()->SleepForMicroseconds(100 * 1000);

  const int temp_counter = counter.load();
  Env::Default()->SleepForMicroseconds(1000 * 1000);
  EXPECT_EQ(temp_counter, counter.load());
  TF_EXPECT_OK(coord.Join());
}
142 
// Ten runners share one counter and are released together; the first one to
// leave its counting loop requests a stop, which must unblock WaitForStop().
TEST(CoordinatorTest, TestRequestStop) {
  Coordinator coord;
  std::atomic<int> counter(0);
  Notification start;

  const int kNumRunners = 10;
  for (int runner_index = 0; runner_index < kNumRunners; ++runner_index) {
    std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
    runner->StartCounting(&counter, 10, &start);
    TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));
  }

  // Release every counting thread at once.
  start.Notify();

  coord.WaitForStop();
  EXPECT_TRUE(coord.ShouldStop());
  EXPECT_EQ(counter.load(), 10);
  TF_EXPECT_OK(coord.Join());
}
160 
// Coordinator::Join() must call Join() on every registered runner; each mock
// runner bumps `join_counter` when joined.
TEST(CoordinatorTest, TestJoin) {
  Coordinator coord;
  int join_counter = 0;

  const int kNumRunners = 2;
  for (int runner_index = 0; runner_index < kNumRunners; ++runner_index) {
    std::unique_ptr<MockQueueRunner> runner(
        new MockQueueRunner(&coord, &join_counter));
    TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));
  }

  TF_EXPECT_OK(coord.RequestStop());
  TF_EXPECT_OK(coord.Join());
  EXPECT_EQ(join_counter, 2);
}
175 
// The coordinator is constructed with CANCELLED and OUT_OF_RANGE; of the three
// runner statuses set below, Join() is expected to report INVALID_ARGUMENT.
TEST(CoordinatorTest, StatusReporting) {
  Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
  Notification start;
  BlockingCounter counter(3);

  // One runner per status code, registered in this order.
  const Code runner_codes[] = {Code::CANCELLED, Code::INVALID_ARGUMENT,
                               Code::OUT_OF_RANGE};
  for (const Code code : runner_codes) {
    std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
    runner->StartSettingStatus(Status(code, ""), &counter, &start);
    TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));
  }

  // Release the status-setting threads and block until all three have run.
  start.Notify();
  counter.Wait();

  TF_EXPECT_OK(coord.RequestStop());
  EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
}
198 
// Calling Join() before any stop has been requested must fail with
// FAILED_PRECONDITION.
TEST(CoordinatorTest, JoinWithoutStop) {
  Coordinator coord;
  std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
  TF_ASSERT_OK(coord.RegisterRunner(std::move(runner)));

  const Status join_status = coord.Join();
  EXPECT_EQ(join_status.code(), Code::FAILED_PRECONDITION);
}
206 
// AllRunnersStopped() must be false while the registered runner reports
// IsRunning(), and true once the runner has been stopped.
TEST(CoordinatorTest, AllRunnersStopped) {
  Coordinator coord;

  // Keep a non-owning handle so the runner can still be stopped after its
  // ownership has been transferred to the coordinator.
  std::unique_ptr<RunnerInterface> owned_runner;
  MockQueueRunner* qr = new MockQueueRunner(&coord);
  owned_runner.reset(qr);
  TF_ASSERT_OK(coord.RegisterRunner(std::move(owned_runner)));

  EXPECT_FALSE(coord.AllRunnersStopped());
  qr->Stop();
  EXPECT_TRUE(coord.AllRunnersStopped());
}
216 
217 }  // namespace
218 }  // namespace tensorflow
219