• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
17 #define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
18 
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
33 
34 namespace tensorflow {
35 
36 /// QueueRunner class imitates the behavior of the python version of QueueRunner
37 /// which creates a thread for each enqueue op, runs close op on completion.
38 class QueueRunner : public RunnerInterface {
39  public:
40   /// Creates a new QueueRunner from proto.
41   // TODO(yuefengz): we may want to initialize from queues and ops in the
42   // future.
43   static Status New(const QueueRunnerDef& queue_runner_def,
44                     std::unique_ptr<QueueRunner>* result);
45 
46   /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
47   static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
48                     std::unique_ptr<QueueRunner>* result);
49 
50   /// Adds a callback that the queue runner will call when it detects an error.
51   void AddErrorCallback(const std::function<void(Status)>& cb);
52 
53   /// Delete the previously registered callbacks.
54   void ClearErrorCallbacks();
55 
56   /// The destructor would join all the threads.
57   ~QueueRunner();
58 
59   /// Starts the queue runner with the given session.
60   Status Start(Session* sess);
61 
62   /// Starts the queue runner with the given session and sets the run arguments
63   /// for sess->Run. It also collects and stores the cost model.
64   Status StartAndCollectCostGraph(Session* sess,
65                                   const RunOptions& run_options = RunOptions());
66 
67   /// Starts the queue runner with the given session, and wait for up to the
68   /// specified time (in milliseconds) for the queues to start to fill up.
69   Status Start(Session* sess, int wait_for_ms);
70   Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
71                                   const RunOptions& run_options = RunOptions());
72 
73   /// Requests to stop and runs the cancel op. It would be called in a separate
74   /// thread when coordinator is set. If there is no coordinator it should be
75   /// called before calling Join.
76   void Stop(Session* sess);
77 
78   /// Joins all the threads. Returns okay if all threads run successfully;
79   /// otherwise returns the first captured failure status.
80   Status Join() final;
81 
82   /// Returns the latest status.
83   Status GetStatus();
84 
85   // Returns the stored cost model.
86   Status ExportCostGraph(CostGraphDef* cost_graph) const override;
87 
88  private:
QueueRunner()89   QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}
90 
91   // Initializes the instance with the QueueRunnerDef proto.
92   Status Init(const QueueRunnerDef& queue_runner_def);
93 
94   // The Run function for each thread.
95   void Run(Session* sess, const string& enqueue_op);
96 
97   // Updates the internal status; it only keeps OK or the first unexpected error
98   // status.
99   void UpdateStatus(const Status& status);
100 
IsQueueClosed(Status status)101   bool IsQueueClosed(Status status) const {
102     return queue_closed_exception_types_.count(
103                static_cast<int>(status.code())) > 0;
104   }
105 
IsRunning()106   bool IsRunning() const override { return !stopped_; }
107 
108   void SetRunArgumentsAndCostGraph(const RunOptions& run_options);
109 
110   Status RealRun(Session* sess, const string& op, bool update_costs);
111 
112   string queue_name_;
113   std::vector<string> enqueue_op_names_;
114   string close_op_name_;
115   string cancel_op_name_;
116   // code::Code casted to int to avoid a hash function.
117   std::unordered_set<int> queue_closed_exception_types_;
118 
119   std::unique_ptr<thread::ThreadPool> thread_pool_;
120   mutex mu_;
121   int runs_ = 0;
122   Status status_ GUARDED_BY(mu_);
123   Status enqueue_status_ GUARDED_BY(mu_);
124   std::unique_ptr<BlockingCounter> counter_;
125 
126   Coordinator* coord_;
127 
128   std::atomic<bool> stopped_;
129 
130   mutex cb_mu_;
131   std::vector<std::function<void(Status)>> callbacks_;
132 
133   mutable std::unique_ptr<mutex> cg_mu_;
134   std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
135   RunOptions run_options_;
136 };
137 
138 }  // namespace tensorflow
139 
140 #endif  // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
141