1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/training/coordinator.h" 17 18 namespace tensorflow { 19 Coordinator()20Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {} 21 Coordinator(const std::vector<error::Code> & clean_stop_errors)22Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors) 23 : should_stop_(false) { 24 if (clean_stop_errors.empty()) { 25 clean_stop_errors_.insert(error::OUT_OF_RANGE); 26 } else { 27 for (const auto& code : clean_stop_errors) { 28 clean_stop_errors_.insert(static_cast<int>(code)); 29 } 30 } 31 } 32 ~Coordinator()33Coordinator::~Coordinator() { 34 RequestStop().IgnoreError(); 35 Join().IgnoreError(); 36 } 37 RegisterRunner(std::unique_ptr<RunnerInterface> runner)38Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) { 39 { 40 mutex_lock l(mu_); 41 if (should_stop_) { 42 return Status(error::FAILED_PRECONDITION, 43 "The coordinator has been stopped."); 44 } 45 } 46 mutex_lock l(runners_lock_); 47 runners_.push_back(std::move(runner)); 48 return Status::OK(); 49 } 50 AllRunnersStopped()51bool Coordinator::AllRunnersStopped() { 52 mutex_lock l(runners_lock_); 53 for (const auto& runner : runners_) { 54 if (runner->IsRunning()) { 55 return false; 56 } 57 } 58 return true; 59 } 60 RequestStop()61Status Coordinator::RequestStop() { 62 mutex_lock l(mu_); 63 if (should_stop_) { 64 return Status(error::FAILED_PRECONDITION, 65 "The Coordinator is not running."); 66 } 67 should_stop_ = true; 68 wait_for_stop_.notify_all(); 69 return Status::OK(); 70 } 71 ShouldStop()72bool Coordinator::ShouldStop() { 73 mutex_lock l(mu_); 74 return should_stop_; 75 } 76 Join()77Status Coordinator::Join() { 78 // TODO(yuefengz): deal with stragglers. 79 { 80 mutex_lock l(mu_); 81 if (!should_stop_) { 82 return Status(error::FAILED_PRECONDITION, 83 "Joining coordinator without requesting to stop."); 84 } 85 } 86 87 { 88 mutex_lock l(runners_lock_); 89 for (const auto& t : runners_) { 90 ReportStatus(t->Join()); 91 } 92 runners_.clear(); 93 } 94 return GetStatus(); 95 } 96 ReportStatus(const Status & status)97void Coordinator::ReportStatus(const Status& status) { 98 mutex_lock l(status_lock_); 99 if (status.ok() || !status_.ok() || 100 clean_stop_errors_.count(static_cast<int>(status.code())) > 0) { 101 return; 102 } 103 status_ = status; 104 } 105 GetStatus()106Status Coordinator::GetStatus() { 107 mutex_lock l(status_lock_); 108 return status_; 109 } 110 WaitForStop()111void Coordinator::WaitForStop() { 112 mutex_lock l(mu_); 113 while (!should_stop_) { 114 wait_for_stop_.wait(l); 115 } 116 } 117 ExportCostGraph(CostGraphDef * cost_graph) const118Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { 119 mutex_lock l(runners_lock_); 120 for (auto& t : runners_) { 121 Status s = t->ExportCostGraph(cost_graph); 122 if (!s.ok()) { 123 return s; 124 } 125 } 126 return Status::OK(); 127 } 128 129 } // namespace tensorflow 130