• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <vector>
25 
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29 #include "tensorflow/core/framework/rendezvous.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/refcount.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/gtl/map_util.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/public/version.h"
38 
39 namespace tensorflow {
40 
41 class AsyncEagerNode;
42 class AsyncRemoteExecuteNode;
43 namespace eager {
44 class EagerClient;
45 }
46 
47 // A unit of execution for the EagerExecutor class below. Example subclasses
48 // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
49 // device to another.
50 class EagerNode {
51  public:
EagerNode()52   EagerNode() {}
53 
~EagerNode()54   virtual ~EagerNode() {}
55 
56   // Prepares the node when adding it into EagerExecutor. If any errors happens,
57   // EagerExecutor will abort the node immediately.
Prepare()58   virtual Status Prepare() { return OkStatus(); }
59 
60   // Runs the computation corresponding to this node and blocks till the
61   // execution is done.
62   virtual Status Run() = 0;
63 
64   // Called when this node will not be run due to some error contained in
65   // `status`. `status` must not be OK.
66   // For example, if the node would have computed some tensors in the Run(),
67   // it should poison the corresponding tensor handles in this method.
68   virtual void Abort(Status status) = 0;
69 
70   // Returns nullptr iff this Eager node is synchronous.
AsAsync()71   virtual AsyncEagerNode* AsAsync() { return nullptr; }
AsAsyncRemoteExecuteNode()72   virtual AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() { return nullptr; }
73 
74   virtual string DebugString() const = 0;
75 
76   // Indicates whether a node failure should make the executor unusable.
Fatal()77   virtual bool Fatal() const { return true; }
78 };
79 
80 class AsyncEagerNode : public EagerNode {
81  public:
82   using EagerNode::EagerNode;  // Lift EagerNode constructors.
83 
84   // This node will be cleaned up once the done callback is called.
85   virtual void RunAsync(StatusCallback done) = 0;
86 
AsAsync()87   AsyncEagerNode* AsAsync() final { return this; }
88 
Run()89   Status Run() final {
90     return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
91   }
92 };
93 
94 class AsyncRemoteExecuteNode : public AsyncEagerNode {
95  public:
AsAsyncRemoteExecuteNode()96   AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() final { return this; }
97 
98   virtual const eager::EagerClient* eager_client() const = 0;
99   virtual bool needs_remote_inputs() const = 0;
100   virtual bool allow_multiple_pending_requests() const = 0;
101   virtual Status SyncExecutors() = 0;
102 };
103 
104 // A class for handling async execution (see TFE_ContextSetAsync).
105 // Note that this class is thread-safe.
106 // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
107 // device of the input handle. Fix that.
108 // TODO(agarwal): Implement support for control dependencies.
109 // TODO(agarwal): Support out-of-order execution and dispatching multiple
110 // EagerNode in parallel.
111 // TODO(agarwal): Implement optimizations over EagerNode traces.
112 class EagerExecutor {
113  public:
114   explicit EagerExecutor(bool async, bool enable_streaming_enqueue = true);
115 
116   ~EagerExecutor();
117 
118   // Puts this in a shutdown state. In this state, AddOrExecute() will return an
119   // error and not add new EagerNodes. After putting this in the shutdown state,
120   // blocks until all pendings nodes have finished running.
121   // Returns the status of executing pending nodes.
122   // If async was not enabled, aborts and destroys all pending nodes.
123   Status ShutDown();
124 
125   bool Async() const;
126 
127   bool StreamingEnqueue() const;
128 
129   // Inline execute node if executor is in sync mode.
130   Status SyncExecute(EagerNode* node);
131 
132   // - Async Mode: schedules `node` for execution.
133   // - Sync Mode: inline execute the 'node' directly.
134   // If an error occurs (e.g. EagerExecutor has already been shut down), the
135   // `node` is not added to this executor and its Abort() method is called.
136   Status AddOrExecute(std::unique_ptr<EagerNode> node);
137 
138   // Blocks till all currently pending ops are done.
139   // In particular, if EnableAsync() has not beed called, it will not return
140   // until that happens (and pendings, at the time of call, nodes finish
141   // running). If this executor has already been shut down, its final status is
142   // returned.
143   Status WaitForAllPendingNodes();
144 
145   // Clears all currently set errors which re-enables async execution.
146   void ClearError();
147 
148   // Returns Status based on any errors that occurred during async execution.
status()149   Status status() const {
150     if (ok()) return OkStatus();
151 
152     tf_shared_lock l(node_queue_mutex_);
153     return status_;
154   }
155 
ok()156   bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
157 
158   // On destruction, runs `callback`. Used by the EagerContext for clearing
159   // thread-local executors.
160   void AddCleanup(intptr_t key, std::function<void()> callback);
161   // If `key` (e.g. a context) is destroyed before the executor, the associated
162   // callbacks are no longer safe to run.
163   void RemoveCleanups(intptr_t key);
164 
165  private:
166   // Possible states for this executor.
167   // Executor starts in kActive state. When Shutdown() is called, Executor
168   // is put in the kShuttingDown state. In this state, the executor thread
169   // continues to run, but no new nodes are accepted. Finally, when all nodes
170   // are drained, the executor is put in the kShutDown state, which causes the
171   // thread to exit.
172   // If this executor is destroyed without calling shutdown first, it
173   // transitions to kShutDown state immediately which causes the thread to exit
174   // without running pending nodes.
175   enum class ExecutorState {
176     kActive,
177     kShuttingDown,
178     kShutDown,
179   };
180 
181   enum class NodeState {
182     kPENDING,
183     kSCHEDULED,
184     kDONE,
185   };
186 
187   struct NodeItem : core::RefCounted {
188     // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
189     // means item1.node is added before item2.node.
190     uint64 id;
191     std::unique_ptr<EagerNode> node;
192     NodeState state;
193   };
194 
195   const char* StateStringLocked()
196       TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
197 
198   void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status,
199                 bool from_queue);
200   void NotifyWaiters(uint64 id) TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
201 
202   // Starts execution of pending EagerNodes. This function loops till executor
203   // state_ is set to kShutDown. If any errors are encountered, these are set
204   // inside `status_`. The loop blocks anytime there are no pending nodes, or if
205   // `status_` is not ok.
206   void Run();
207 
208   Status RunItem(core::RefCountPtr<NodeItem> item, bool from_queue);
209   Status MoveToUnfinished(core::RefCountPtr<NodeItem> item, bool from_queue);
210 
211   // The impl of WaitForAllPendingNodes
212   // `lock` is the lock that holds node_queue_mutex_.
213   Status WaitForAllPendingNodesLocked(mutex_lock* lock)
214       TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
215 
216   Status WaitImpl(bool wait_all, uint64 node_id);
217 
218   std::atomic<uint64> next_node_id_;
219 
220   mutable mutex node_queue_mutex_;
221 
222   // Used to signal that some EagerNodes are pending execution.
223   condition_variable nodes_pending_ TF_GUARDED_BY(node_queue_mutex_);
224 
225   // Queue of pending NodeItems. Ordered by NodeItem::id.
226   std::queue<core::RefCountPtr<NodeItem>> node_queue_
227       TF_GUARDED_BY(node_queue_mutex_);
228 
229   // Ordered by NodeItem::id.
230   std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
231       unfinished_nodes_ TF_GUARDED_BY(node_queue_mutex_);
232 
233   // `status_` is set based on any errors raised during execution of a
234   // EagerNode.  It remains set until ClearError is called.
235   Status status_ TF_GUARDED_BY(node_queue_mutex_);
236   std::atomic<bool> ok_ TF_GUARDED_BY(node_queue_mutex_);
237 
238   // Map from id of a EagerNode to condition_variables (not owned by the map).
239   // These condition_variables are notified and removed when that EagerNode is
240   // done executing, or if an error is found in execution of any EagerNode.
241   // The map is ordered by id.
242   std::multimap<uint64, condition_variable*, std::less<uint64>>
243       node_done_notifications_ TF_GUARDED_BY(node_queue_mutex_);
244 
245   // thread_exited_notification_ is notified by the `thread_` right before it
246   // exits.
247   Notification thread_exited_notification_;
248 
249   // When state_ is set to kShutDown, it indicates that `thread_` should stop as
250   // soon as it is done executing the current EagerNode.
251   ExecutorState state_ TF_GUARDED_BY(node_queue_mutex_) =
252       ExecutorState::kActive;
253 
254   // Thread object that calls the `Run` method in async mode.This thread runs
255   // until state_ is set to kShuttingDown. It is `nullptr` in sync mode.
256   const std::unique_ptr<Thread> thread_;
257 
258   // Last device where remote function with remote inputs was executed.
259   const eager::EagerClient* last_eager_client_;
260 
261   const bool enable_async_wait_for_remote_function_;
262 
263   // Enable sending remote executions through streaming enqueue.
264   const bool enable_streaming_enqueue_;
265 
266   // Callbacks to run on destruction.
267   std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
268 };
269 
Async()270 inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
271 
StreamingEnqueue()272 inline bool EagerExecutor::StreamingEnqueue() const {
273   return enable_streaming_enqueue_;
274 }
275 
276 }  // namespace tensorflow
277 
278 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
279