1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ 17 #define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_set> 22 #include <vector> 23 24 #include "tensorflow/cc/training/coordinator.h" 25 #include "tensorflow/core/lib/core/blocking_counter.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/core/threadpool.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/protobuf/config.pb.h" 30 #include "tensorflow/core/protobuf/error_codes.pb.h" 31 #include "tensorflow/core/protobuf/queue_runner.pb.h" 32 #include "tensorflow/core/public/session.h" 33 34 namespace tensorflow { 35 36 /// QueueRunner class imitates the behavior of the python version of QueueRunner 37 /// which creates a thread for each enqueue op, runs close op on completion. 38 class QueueRunner : public RunnerInterface { 39 public: 40 /// Creates a new QueueRunner from proto. 41 // TODO(yuefengz): we may want to initialize from queues and ops in the 42 // future. 43 static Status New(const QueueRunnerDef& queue_runner_def, 44 std::unique_ptr<QueueRunner>* result); 45 46 /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage. 47 static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord, 48 std::unique_ptr<QueueRunner>* result); 49 50 /// Adds a callback that the queue runner will call when it detects an error. 51 void AddErrorCallback(const std::function<void(Status)>& cb); 52 53 /// Delete the previously registered callbacks. 54 void ClearErrorCallbacks(); 55 56 /// The destructor would join all the threads. 57 ~QueueRunner(); 58 59 /// Starts the queue runner with the given session. 60 Status Start(Session* sess); 61 62 /// Starts the queue runner with the given session and sets the run arguments 63 /// for sess->Run. It also collects and stores the cost model. 64 Status StartAndCollectCostGraph(Session* sess, 65 const RunOptions& run_options = RunOptions()); 66 67 /// Starts the queue runner with the given session, and wait for up to the 68 /// specified time (in milliseconds) for the queues to start to fill up. 69 Status Start(Session* sess, int wait_for_ms); 70 Status StartAndCollectCostGraph(Session* session, int wait_for_ms, 71 const RunOptions& run_options = RunOptions()); 72 73 /// Requests to stop and runs the cancel op. It would be called in a separate 74 /// thread when coordinator is set. If there is no coordinator it should be 75 /// called before calling Join. 76 void Stop(Session* sess); 77 78 /// Joins all the threads. Returns okay if all threads run successfully; 79 /// otherwise returns the first captured failure status. 80 Status Join() final; 81 82 /// Returns the latest status. 83 Status GetStatus(); 84 85 // Returns the stored cost model. 86 Status ExportCostGraph(CostGraphDef* cost_graph) const override; 87 88 private: QueueRunner()89 QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} 90 91 // Initializes the instance with the QueueRunnerDef proto. 92 Status Init(const QueueRunnerDef& queue_runner_def); 93 94 // The Run function for each thread. 95 void Run(Session* sess, const string& enqueue_op); 96 97 // Updates the internal status; it only keeps OK or the first unexpected error 98 // status. 99 void UpdateStatus(const Status& status); 100 IsQueueClosed(Status status)101 bool IsQueueClosed(Status status) const { 102 return queue_closed_exception_types_.count( 103 static_cast<int>(status.code())) > 0; 104 } 105 IsRunning()106 bool IsRunning() const override { return !stopped_; } 107 108 void SetRunArgumentsAndCostGraph(const RunOptions& run_options); 109 110 Status RealRun(Session* sess, const string& op, bool update_costs); 111 112 string queue_name_; 113 std::vector<string> enqueue_op_names_; 114 string close_op_name_; 115 string cancel_op_name_; 116 // code::Code casted to int to avoid a hash function. 117 std::unordered_set<int> queue_closed_exception_types_; 118 119 std::unique_ptr<thread::ThreadPool> thread_pool_; 120 mutex mu_; 121 int runs_ = 0; 122 Status status_ TF_GUARDED_BY(mu_); 123 Status enqueue_status_ TF_GUARDED_BY(mu_); 124 std::unique_ptr<BlockingCounter> counter_; 125 126 Coordinator* coord_; 127 128 std::atomic<bool> stopped_; 129 130 mutex cb_mu_; 131 std::vector<std::function<void(Status)>> callbacks_; 132 133 mutable std::unique_ptr<mutex> cg_mu_; 134 std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_); 135 RunOptions run_options_; 136 }; 137 138 } // namespace tensorflow 139 140 #endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ 141