• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/framework/rendezvous.h"
21 #include "tensorflow/core/framework/session_state.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 
30 namespace tensorflow {
31 
32 class StepStatsCollector;
33 
34 // Executor runs a graph computation.
35 // Example:
36 //   Graph* graph = ...;
37 //      ... construct graph ...
38 //   Executor* executor;
39 //   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
40 //   Rendezvous* rendezvous = NewNaiveRendezvous();
41 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
42 //   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
43 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
44 //   ... ...
45 //
46 // Multiple threads can call Executor::Run concurrently.
47 class Executor {
48  public:
~Executor()49   virtual ~Executor() {}
50 
51   // RunAsync() executes the graph computation. "done" is run when the
52   // graph computation completes. If any error happens during the
53   // computation, "done" is run and the error is passed to "done".
54   //
55   // RunAsync() is given a few arguments in Args. The caller must
56   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
57   // are alive at least until done is invoked. All pointers to the
58   // argument objects can be nullptr.
59   //
60   // "step_id" is a process-wide unique identifier for the step being
61   // run. Executors on different devices may receive the same step_id
62   // in the case that a step runs Ops on more than one device. The
63   // step_id is used for tracking resource usage of a given step.
64   //
65   // RunAsync() uses the given "rendezvous", if not null, as the
66   // mechanism to communicate inputs and outputs of the underlying
67   // graph computation.
68   //
69   // RunAsync() calls "stats_collector", if not null, to keep track of
70   // stats. This allows us to collect statistics and traces on demand.
71   //
72   // RunAsync() is provided a "call_frame", if the executor is used
73   // for executing a function, is used to pass arguments and return
74   // values between the caller and the callee.
75   //
76   // RunAsync() uses "cancellation_manager", if not nullptr, to
77   // register callbacks that should be called if the graph computation
78   // is canceled. Note that the callbacks merely unblock any
79   // long-running computation, and a canceled step will terminate by
80   // returning/calling the DoneCallback as usual.
81   //
82   // RunAsync() dispatches closures to "runner". Typically, "runner"
83   // is backed up by a bounded threadpool.
84   struct Args {
85     int64 step_id = 0;
86     Rendezvous* rendezvous = nullptr;
87     StepStatsCollectorInterface* stats_collector = nullptr;
88     CallFrameInterface* call_frame = nullptr;
89     CancellationManager* cancellation_manager = nullptr;
90     SessionState* session_state = nullptr;
91     // Unique session identifier. Can be empty.
92     string session_handle;
93     TensorStore* tensor_store = nullptr;
94     ScopedStepContainer* step_container = nullptr;
95     CollectiveExecutor* collective_executor = nullptr;
96 
97     // If true, calls Sync() on the device.
98     bool sync_on_finish = false;
99 
100     typedef std::function<void()> Closure;
101     typedef std::function<void(Closure)> Runner;
102     Runner runner = nullptr;
103   };
104   typedef std::function<void(const Status&)> DoneCallback;
105   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
106 
107   // Synchronous wrapper for RunAsync().
Run(const Args & args)108   Status Run(const Args& args) {
109     Status ret;
110     Notification n;
111     RunAsync(args, [&ret, &n](const Status& s) {
112       ret = s;
113       n.Notify();
114     });
115     n.WaitForNotification();
116     return ret;
117   }
118 };
119 
120 // Creates an Executor that computes the given "graph".
121 //
122 // If successful, returns the constructed executor in "*executor". Otherwise,
123 // returns an error status.
124 //
125 // "params" provides a set of context for the executor. We expect that
126 // different context would provide different implementations.
127 struct LocalExecutorParams {
128   Device* device;
129 
130   // The library runtime support.
131   FunctionLibraryRuntime* function_library = nullptr;
132 
133   // create_kernel returns an instance of op kernel based on NodeDef.
134   // delete_kernel is called for every kernel used by the executor
135   // when the executor is deleted.
136   std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
137   std::function<void(OpKernel*)> delete_kernel;
138 };
139 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
140                                       std::unique_ptr<const Graph> graph,
141                                       Executor** executor);
142 
143 // A class to help run multiple executors in parallel and wait until
144 // all of them are complete.
145 //
146 // ExecutorBarrier deletes itself after the function returned by Get()
147 // is called.
148 class ExecutorBarrier {
149  public:
150   typedef std::function<void(const Status&)> StatusCallback;
151 
152   // Create an ExecutorBarrier for 'num' different executors.
153   //
154   // 'r' is the shared Rendezvous object that is used to communicate
155   // state.  If any of the executors experiences an error, the
156   // rendezvous object will be aborted exactly once.
157   //
158   // 'done' is called after the last executor completes, and
159   // ExecutorBarrier is deleted.
ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)160   ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
161       : rendez_(r), done_cb_(done), pending_(num) {}
162 
~ExecutorBarrier()163   ~ExecutorBarrier() {}
164 
165   // Returns a closure that Executors must call when they are done
166   // computing, passing the status of their execution as an argument.
Get()167   StatusCallback Get() {
168     return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
169   }
170 
171  private:
172   Rendezvous* rendez_ = nullptr;
173   StatusCallback done_cb_ = nullptr;
174 
175   mutable mutex mu_;
176   int pending_ GUARDED_BY(mu_) = 0;
177   StatusGroup status_group_ GUARDED_BY(mu_);
178 
WhenDone(const Status & s)179   void WhenDone(const Status& s) {
180     Rendezvous* error_rendez = nullptr;
181     StatusCallback done = nullptr;
182     Status status;
183 
184     {
185       mutex_lock l(mu_);
186 
187       // If we are the first error encountered, trigger an abort of the
188       // Rendezvous object by this thread only.
189       if (status_group_.ok() && !s.ok()) {
190         error_rendez = rendez_;
191         error_rendez->Ref();
192       }
193 
194       status_group_.Update(s);
195 
196       // If this is the last call to WhenDone, call the final callback
197       // below.
198       if (--pending_ == 0) {
199         CHECK(done_cb_ != nullptr);
200         std::swap(done, done_cb_);
201         status = status_group_.as_status();
202       }
203     }
204 
205     if (error_rendez != nullptr) {
206       error_rendez->StartAbort(
207           errors::Aborted("Stopping remaining executors."));
208       error_rendez->Unref();
209     }
210 
211     if (done != nullptr) {
212       delete this;
213       done(status);
214     }
215   }
216 
217   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
218 };
219 
220 // A few helpers to facilitate create/delete kernels.
221 
222 // Creates a kernel based on "ndef" on device "device". The kernel can
223 // access the functions in the "flib". The caller takes ownership of
224 // returned "*kernel".
225 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
226                              const NodeDef& ndef, int graph_def_version,
227                              OpKernel** kernel);
228 
229 // Deletes "kernel" returned by CreateKernel.
230 void DeleteNonCachedKernel(OpKernel* kernel);
231 
232 }  // end namespace tensorflow
233 
234 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
235