• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
18 
19 #include "absl/time/time.h"
20 #include "absl/types/optional.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/local_executor_params.h"
23 #include "tensorflow/core/framework/rendezvous.h"
24 #include "tensorflow/core/framework/session_state.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/notification.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/threadpool_interface.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/util/managed_stack_trace.h"
34 
35 namespace tensorflow {
36 
37 class StepStatsCollector;
38 
39 // Executor runs a graph computation.
40 // Example:
41 //   Graph* graph = ...;
42 //      ... construct graph ...
43 //   Executor* executor;
44 //   TF_CHECK_OK(NewLocalExecutor(params, *graph, &executor));
45 //   Rendezvous* rendezvous = NewNaiveRendezvous();
46 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
47 //   TF_CHECK_OK(executor->Run(args));  // with args.rendezvous = rendezvous
48 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
49 //   ... ...
50 //
51 // Multiple threads can call Executor::Run concurrently.
52 class Executor {
53  public:
~Executor()54   virtual ~Executor() {}
55 
56   // RunAsync() executes the graph computation. "done" is run when the
57   // graph computation completes. If any error happens during the
58   // computation, "done" is run and the error is passed to "done".
59   //
60   // RunAsync() is given a few arguments in Args. The caller must
61   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
62   // are alive at least until done is invoked. All pointers to the
63   // argument objects can be nullptr.
64   //
65   // "step_id" is a process-wide unique identifier for the step being
66   // run. Executors on different devices may receive the same step_id
67   // in the case that a step runs Ops on more than one device. The
68   // step_id is used for tracking resource usage of a given step.
69   //
70   // RunAsync() uses the given "rendezvous", if not null, as the
71   // mechanism to communicate inputs and outputs of the underlying
72   // graph computation.
73   //
74   // RunAsync() calls "stats_collector", if not null, to keep track of
75   // stats. This allows us to collect statistics and traces on demand.
76   //
77   // RunAsync() is provided a "call_frame", if the executor is used
78   // for executing a function, is used to pass arguments and return
79   // values between the caller and the callee.
80   //
81   // RunAsync() uses "cancellation_manager", if not nullptr, to
82   // register callbacks that should be called if the graph computation
83   // is canceled. Note that the callbacks merely unblock any
84   // long-running computation, and a canceled step will terminate by
85   // returning/calling the DoneCallback as usual.
86   //
87   // RunAsync() dispatches closures to "runner". Typically, "runner"
88   // is backed up by a bounded threadpool.
89   //
90   // "start_time_usecs" is a timestamp for the start of RunAsync()
91   // execution. Used for system-wide latency metrics.
92   struct Args {
93     int64_t step_id = 0;
94     RendezvousInterface* rendezvous = nullptr;
95     StepStatsCollectorInterface* stats_collector = nullptr;
96     CallFrameInterface* call_frame = nullptr;
97     CancellationManager* cancellation_manager = nullptr;
98     SessionState* session_state = nullptr;
99     // Unique session identifier. Can be empty.
100     string session_handle;
101     TensorStore* tensor_store = nullptr;
102     ScopedStepContainer* step_container = nullptr;
103     CollectiveExecutor* collective_executor = nullptr;
104     thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
105     CoordinationServiceAgent* coordination_service_agent = nullptr;
106     int64_t start_time_usecs = 0;
107     // The deadline for the kernel to complete by. Empty if unspecified.
108     absl::optional<absl::Time> deadline;
109     absl::optional<ManagedStackTrace> stack_trace = absl::nullopt;
110 
111     // If true, calls Sync() on the device.
112     bool sync_on_finish = false;
113 
114     typedef std::function<void()> Closure;
115     typedef std::function<void(Closure)> Runner;
116     Runner runner = nullptr;
117 
118     // If true, all kernels will be treated as "inexpensive", and hence executed
119     // on the scheduling thread.
120     bool run_all_kernels_inline = false;
121   };
122   typedef std::function<void(const Status&)> DoneCallback;
123   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
124 
125   // Synchronous wrapper for RunAsync().
Run(const Args & args)126   virtual Status Run(const Args& args) {
127     Status ret;
128     Notification n;
129     RunAsync(args, [&ret, &n](const Status& s) {
130       ret = s;
131       n.Notify();
132     });
133     n.WaitForNotification();
134     return ret;
135   }
136 };
137 
138 // Creates an Executor that computes the given "graph".
139 //
140 // On success, the constructed executor is returned in "*executor".
141 // Otherwise, an error status is returned.
142 //
143 // "params" supplies the context in which the executor runs. Different
144 // contexts are expected to provide different implementations.
145 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
146                                       const Graph& graph, Executor** executor);
147 
148 // A class to help run multiple executors in parallel and wait until
149 // all of them are complete.
150 //
151 // ExecutorBarrier deletes itself after the function returned by Get()
152 // is called.
153 class ExecutorBarrier {
154  public:
155   typedef std::function<void(const Status&)> StatusCallback;
156 
157   // Create an ExecutorBarrier for 'num' different executors.
158   //
159   // 'r' is the shared Rendezvous object that is used to communicate
160   // state.  If any of the executors experiences an error, the
161   // rendezvous object will be aborted exactly once.
162   //
163   // 'done' is called after the last executor completes, and
164   // ExecutorBarrier is deleted.
ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)165   ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
166       : rendez_(r), done_cb_(done), pending_(num) {}
167 
~ExecutorBarrier()168   ~ExecutorBarrier() {}
169 
170   // Returns a closure that Executors must call when they are done
171   // computing, passing the status of their execution as an argument.
Get()172   StatusCallback Get() {
173     return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
174   }
175 
176  private:
177   Rendezvous* rendez_ = nullptr;
178   StatusCallback done_cb_ = nullptr;
179 
180   mutable mutex mu_;
181   int pending_ TF_GUARDED_BY(mu_) = 0;
182   StatusGroup status_group_ TF_GUARDED_BY(mu_);
183 
WhenDone(const Status & s)184   void WhenDone(const Status& s) {
185     Rendezvous* error_rendez = nullptr;
186     StatusCallback done = nullptr;
187     Status status;
188 
189     {
190       mutex_lock l(mu_);
191 
192       // If we are the first error encountered, trigger an abort of the
193       // Rendezvous object by this thread only.
194       if (status_group_.ok() && !s.ok()) {
195         error_rendez = rendez_;
196         error_rendez->Ref();
197       }
198 
199       if (!s.ok() && !StatusGroup::IsDerived(s) &&
200           !status_group_.HasLogMessages()) {
201         status_group_.AttachLogMessages();
202       }
203 
204       status_group_.Update(s);
205 
206       // If this is the last call to WhenDone, call the final callback
207       // below.
208       if (--pending_ == 0) {
209         CHECK(done_cb_ != nullptr);
210         std::swap(done, done_cb_);
211         status = status_group_.as_summary_status();
212       }
213     }
214 
215     if (error_rendez != nullptr) {
216       error_rendez->StartAbort(
217           errors::Aborted("Stopping remaining executors."));
218       error_rendez->Unref();
219     }
220 
221     if (done != nullptr) {
222       delete this;
223       if (!status.ok()) {
224         VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
225       }
226       done(status);
227     }
228   }
229 
230   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
231 };
232 
233 // A few helpers for creating and deleting kernels.
234 
235 // Creates a kernel described by "props" on device "device". The kernel can
236 // access the functions in "flib". The caller takes ownership of the returned
237 // "*kernel". NOTE(review): "graph_def_version" is undocumented here - confirm.
238 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
239                              const std::shared_ptr<const NodeProperties>& props,
240                              int graph_def_version, OpKernel** kernel);
241 
242 // Deletes "kernel" returned by CreateNonCachedKernel.
243 void DeleteNonCachedKernel(OpKernel* kernel);
244 
245 }  // end namespace tensorflow
246 
247 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
248