• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <vector>
25 
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29 #include "tensorflow/core/framework/rendezvous.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/refcount.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/gtl/map_util.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/public/version.h"
38 
39 namespace tensorflow {
40 
41 class AsyncEagerNode;
42 
43 // A unit of execution for the EagerExecutor class below. Example subclasses
44 // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
45 // device to another.
46 class EagerNode {
47  public:
EagerNode()48   EagerNode() {}
49 
~EagerNode()50   virtual ~EagerNode() {}
51 
52   // Prepares the node when adding it into EagerExecutor. If any errors happens,
53   // EagerExecutor will abort the node immediately.
Prepare()54   virtual Status Prepare() { return Status::OK(); }
55 
56   // Runs the computation corresponding to this node and blocks till the
57   // execution is done.
58   virtual Status Run() = 0;
59 
60   // Called when this node will not be run due to some error contained in
61   // `status`. `status` must not be OK.
62   // For example, if the node would have computed some tensors in the Run(),
63   // it should poison the corresponding tensor handles in this method.
64   virtual void Abort(Status status) = 0;
65 
66   // Returns nullptr iff this Eager node is synchronous.
AsAsync()67   virtual AsyncEagerNode* AsAsync() { return nullptr; }
68 
69   virtual string DebugString() const = 0;
70 };
71 
72 class AsyncEagerNode : public EagerNode {
73  public:
74   using EagerNode::EagerNode;  // Lift EagerNode constructors.
75 
76   // This node will be cleaned up once the done callback is called.
77   virtual void RunAsync(StatusCallback done) = 0;
78 
AsAsync()79   AsyncEagerNode* AsAsync() final { return this; }
80 
Run()81   Status Run() final {
82     return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
83   }
84 };
85 
86 // A class for handling async execution (see TFE_ContextSetAsync).
87 // Note that this class is thread-safe.
88 // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
89 // device of the input handle. Fix that.
90 // TODO(agarwal): Implement support for control dependencies.
91 // TODO(agarwal): Support out-of-order execution and dispatching multiple
92 // EagerNode in parallel.
93 // TODO(agarwal): Implement optimizations over EagerNode traces.
94 class EagerExecutor {
95  public:
96   explicit EagerExecutor(bool async);
97 
98   ~EagerExecutor();
99 
100   // Puts this in a shutdown state. In this state, AddOrExecute() will return an
101   // error and not add new EagerNodes. After putting this in the shutdown state,
102   // blocks until all pendings nodes have finished running.
103   // Returns the status of executing pending nodes.
104   // If async was not enabled, aborts and destroys all pending nodes.
105   Status ShutDown();
106 
107   bool Async() const;
108 
109   // Inline execute node if executor is in sync mode.
110   Status SyncExecute(EagerNode* node);
111 
112   // - Async Mode: schedules `node` for execution.
113   // - Sync Mode: inline execute the 'node' directly.
114   // If an error occurs (e.g. EagerExecutor has already been shut down), the
115   // `node` is not added to this executor and its Abort() method is called.
116   Status AddOrExecute(std::unique_ptr<EagerNode> node);
117 
118   // Blocks till all currently pending ops are done.
119   // In particular, if EnableAsync() has not beed called, it will not return
120   // until that happens (and pendings, at the time of call, nodes finish
121   // running). If this executor has already been shut down, its final status is
122   // returned.
123   Status WaitForAllPendingNodes();
124 
125   // Clears all currently set errors which re-enables async execution.
126   void ClearError();
127 
128   // Returns Status based on any errors that occurred during async execution.
status()129   Status status() const {
130     if (ok()) return Status::OK();
131 
132     tf_shared_lock l(node_queue_mutex_);
133     return status_;
134   }
135 
ok()136   bool ok() const NO_THREAD_SAFETY_ANALYSIS { return ok_; }
137 
138  private:
139   // Possible states for this executor.
140   // Executor starts in kActive state. When Shutdown() is called, Executor
141   // is put in the kShuttingDown state. In this state, the executor thread
142   // continues to run, but no new nodes are accepted. Finally, when all nodes
143   // are drained, the executor is put in the kShutDown state, which causes the
144   // thread to exit.
145   // If this executor is destroyed without calling shutdown first, it
146   // transitions to kShutDown state immediately which causes the thread to exit
147   // without running pending nodes.
148   enum class ExecutorState {
149     kActive,
150     kShuttingDown,
151     kShutDown,
152   };
153 
154   enum class NodeState {
155     kPENDING,
156     kSCHEDULED,
157     kDONE,
158   };
159 
160   struct NodeItem : core::RefCounted {
161     // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
162     // means item1.node is added before item2.node.
163     uint64 id;
164     std::unique_ptr<EagerNode> node;
165     NodeState state;
166   };
167 
168   const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
169 
170   void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status,
171                 bool from_queue);
172   void NotifyWaiters(uint64 id) EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
173 
174   // Starts execution of pending EagerNodes. This function loops till executor
175   // state_ is set to kShutDown. If any errors are encontered, these are set
176   // inside `status_`. The loop blocks anytime there are no pending nodes, or if
177   // `status_` is not ok.
178   void Run();
179 
180   Status RunItem(core::RefCountPtr<NodeItem> item, bool from_queue);
181   Status MoveToUnfinished(core::RefCountPtr<NodeItem> item, bool from_queue);
182 
183   // The impl of WaitForAllPendingNodes
184   // `lock` is the lock that holds node_queue_mutex_.
185   Status WaitForAllPendingNodesLocked(mutex_lock* lock)
186       EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
187 
188   Status WaitImpl(bool wait_all, uint64 node_id);
189 
190   std::atomic<uint64> next_node_id_;
191 
192   mutable mutex node_queue_mutex_;
193 
194   // Used to signal that some EagerNodes are pending execution.
195   condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
196 
197   // Queue of pending NodeItems. Ordered by NodeItem::id.
198   std::queue<core::RefCountPtr<NodeItem>> node_queue_
199       GUARDED_BY(node_queue_mutex_);
200 
201   // Ordered by NodeItem::id.
202   std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
203       unfinished_nodes_ GUARDED_BY(node_queue_mutex_);
204 
205   // `status_` is set based on any errors raised during execution of a
206   // EagerNode.  It remains set until ClearError is called.
207   Status status_ GUARDED_BY(node_queue_mutex_);
208   std::atomic<bool> ok_ GUARDED_BY(node_queue_mutex_);
209 
210   // Map from id of a EagerNode to condition_variables (not owned by the map).
211   // These condition_variables are notified and removed when that EagerNode is
212   // done executing, or if an error is found in execution of any EagerNode.
213   // The map is ordered by id.
214   std::multimap<uint64, condition_variable*, std::less<uint64>>
215       node_done_notifications_ GUARDED_BY(node_queue_mutex_);
216 
217   // thread_exited_notification_ is notified by the `thread_` right before it
218   // exits.
219   Notification thread_exited_notification_;
220 
221   // When state_ is set to kShutDown, it indicates that `thread_` should stop as
222   // soon as it is done executing the current EagerNode.
223   ExecutorState state_ GUARDED_BY(node_queue_mutex_) = ExecutorState::kActive;
224 
225   // Thread object that calls the `Run` method in async mode.This thread runs
226   // until state_ is set to kShuttingDown. It is `nullptr` in sync mode.
227   const std::unique_ptr<Thread> thread_;
228 };
229 
Async()230 inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
231 
232 }  // namespace tensorflow
233 
234 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
235