• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
34 class StepStatsCollector;
35 
36 // Executor runs a graph computation.
37 // Example:
38 //   Graph* graph = ...;
39 //      ... construct graph ...
40 //   Executor* executor;
41 //   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
42 //   Rendezvous* rendezvous = NewNaiveRendezvous();
43 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
44 //   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
45 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
46 //   ... ...
47 //
48 // Multiple threads can call Executor::Run concurrently.
49 class Executor {
50  public:
~Executor()51   virtual ~Executor() {}
52 
53   // RunAsync() executes the graph computation. "done" is run when the
54   // graph computation completes. If any error happens during the
55   // computation, "done" is run and the error is passed to "done".
56   //
57   // RunAsync() is given a few arguments in Args. The caller must
58   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
59   // are alive at least until done is invoked. All pointers to the
60   // argument objects can be nullptr.
61   //
62   // "step_id" is a process-wide unique identifier for the step being
63   // run. Executors on different devices may receive the same step_id
64   // in the case that a step runs Ops on more than one device. The
65   // step_id is used for tracking resource usage of a given step.
66   //
67   // RunAsync() uses the given "rendezvous", if not null, as the
68   // mechanism to communicate inputs and outputs of the underlying
69   // graph computation.
70   //
71   // RunAsync() calls "stats_collector", if not null, to keep track of
72   // stats. This allows us to collect statistics and traces on demand.
73   //
74   // RunAsync() is provided a "call_frame", if the executor is used
75   // for executing a function, is used to pass arguments and return
76   // values between the caller and the callee.
77   //
78   // RunAsync() uses "cancellation_manager", if not nullptr, to
79   // register callbacks that should be called if the graph computation
80   // is canceled. Note that the callbacks merely unblock any
81   // long-running computation, and a canceled step will terminate by
82   // returning/calling the DoneCallback as usual.
83   //
84   // RunAsync() dispatches closures to "runner". Typically, "runner"
85   // is backed up by a bounded threadpool.
86   //
87   // "start_time_usecs" is a timestamp for the start of RunAsync()
88   // execution. Used for system-wide latency metrics.
89   struct Args {
90     int64 step_id = 0;
91     RendezvousInterface* rendezvous = nullptr;
92     StepStatsCollectorInterface* stats_collector = nullptr;
93     CallFrameInterface* call_frame = nullptr;
94     CancellationManager* cancellation_manager = nullptr;
95     SessionState* session_state = nullptr;
96     // Unique session identifier. Can be empty.
97     string session_handle;
98     TensorStore* tensor_store = nullptr;
99     ScopedStepContainer* step_container = nullptr;
100     CollectiveExecutor* collective_executor = nullptr;
101     thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
102     CoordinationServiceAgent* coordination_service_agent = nullptr;
103     int64 start_time_usecs = 0;
104 
105     // If true, calls Sync() on the device.
106     bool sync_on_finish = false;
107 
108     typedef std::function<void()> Closure;
109     typedef std::function<void(Closure)> Runner;
110     Runner runner = nullptr;
111 
112     // If true, all kernels will be treated as "inexpensive", and hence executed
113     // on the scheduling thread.
114     bool run_all_kernels_inline = false;
115   };
116   typedef std::function<void(const Status&)> DoneCallback;
117   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
118 
119   // Synchronous wrapper for RunAsync().
Run(const Args & args)120   virtual Status Run(const Args& args) {
121     Status ret;
122     Notification n;
123     RunAsync(args, [&ret, &n](const Status& s) {
124       ret = s;
125       n.Notify();
126     });
127     n.WaitForNotification();
128     return ret;
129   }
130 };
131 
132 // Creates an Executor that computes the given "graph".
133 //
134 // If successful, returns the constructed executor in "*executor". Otherwise,
135 // returns an error status.
136 //
137 // "params" provides a set of context for the executor. We expect that
138 // different context would provide different implementations.
139 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
140                                       const Graph& graph, Executor** executor);
141 
142 // A class to help run multiple executors in parallel and wait until
143 // all of them are complete.
144 //
145 // ExecutorBarrier deletes itself after the function returned by Get()
146 // is called.
147 class ExecutorBarrier {
148  public:
149   typedef std::function<void(const Status&)> StatusCallback;
150 
151   // Create an ExecutorBarrier for 'num' different executors.
152   //
153   // 'r' is the shared Rendezvous object that is used to communicate
154   // state.  If any of the executors experiences an error, the
155   // rendezvous object will be aborted exactly once.
156   //
157   // 'done' is called after the last executor completes, and
158   // ExecutorBarrier is deleted.
ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)159   ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
160       : rendez_(r), done_cb_(done), pending_(num) {}
161 
~ExecutorBarrier()162   ~ExecutorBarrier() {}
163 
164   // Returns a closure that Executors must call when they are done
165   // computing, passing the status of their execution as an argument.
Get()166   StatusCallback Get() {
167     return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
168   }
169 
170  private:
171   Rendezvous* rendez_ = nullptr;
172   StatusCallback done_cb_ = nullptr;
173 
174   mutable mutex mu_;
175   int pending_ TF_GUARDED_BY(mu_) = 0;
176   StatusGroup status_group_ TF_GUARDED_BY(mu_);
177 
WhenDone(const Status & s)178   void WhenDone(const Status& s) {
179     Rendezvous* error_rendez = nullptr;
180     StatusCallback done = nullptr;
181     Status status;
182 
183     {
184       mutex_lock l(mu_);
185 
186       // If we are the first error encountered, trigger an abort of the
187       // Rendezvous object by this thread only.
188       if (status_group_.ok() && !s.ok()) {
189         error_rendez = rendez_;
190         error_rendez->Ref();
191       }
192 
193       if (!s.ok() && !StatusGroup::IsDerived(s) &&
194           !status_group_.HasLogMessages()) {
195         status_group_.AttachLogMessages();
196       }
197 
198       status_group_.Update(s);
199 
200       // If this is the last call to WhenDone, call the final callback
201       // below.
202       if (--pending_ == 0) {
203         CHECK(done_cb_ != nullptr);
204         std::swap(done, done_cb_);
205         status = status_group_.as_summary_status();
206       }
207     }
208 
209     if (error_rendez != nullptr) {
210       error_rendez->StartAbort(
211           errors::Aborted("Stopping remaining executors."));
212       error_rendez->Unref();
213     }
214 
215     if (done != nullptr) {
216       delete this;
217       if (!status.ok()) {
218         VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
219       }
220       done(status);
221     }
222   }
223 
224   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
225 };
226 
227 // A few helpers to facilitate create/delete kernels.
228 
229 // Creates a kernel based on "props" on device "device". The kernel can
230 // access the functions in the "flib". The caller takes ownership of
231 // returned "*kernel".
232 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
233                              const std::shared_ptr<const NodeProperties>& props,
234                              int graph_def_version, OpKernel** kernel);
235 
236 // Deletes "kernel" returned by CreateKernel.
237 void DeleteNonCachedKernel(OpKernel* kernel);
238 
239 }  // end namespace tensorflow
240 
241 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
242