• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
17 #define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/framework/rendezvous.h"
21 #include "tensorflow/core/framework/session_state.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/core/notification.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/macros.h"
28 
29 namespace tensorflow {
30 
31 class StepStatsCollector;
32 
33 // Executor runs a graph computation.
34 // Example:
35 //   Graph* graph = ...;
36 //      ... construct graph ...
37 //   Executor* executor;
38 //   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
39 //   Rendezvous* rendezvous = NewNaiveRendezvous();
40 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
41 //   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
42 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
43 //   ... ...
44 //
45 // Multiple threads can call Executor::Run concurrently.
46 class Executor {
47  public:
~Executor()48   virtual ~Executor() {}
49 
50   // RunAsync() executes the graph computation. "done" is run when the
51   // graph computation completes. If any error happens during the
52   // computation, "done" is run and the error is passed to "done".
53   //
54   // RunAsync() is given a few arguments in Args. The caller must
55   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
56   // are alive at least until done is invoked. All pointers to the
57   // argument objects can be nullptr.
58   //
59   // "step_id" is a process-wide unique identifier for the step being
60   // run. Executors on different devices may receive the same step_id
61   // in the case that a step runs Ops on more than one device. The
62   // step_id is used for tracking resource usage of a given step.
63   //
64   // RunAsync() uses the given "rendezvous", if not null, as the
65   // mechanism to communicate inputs and outputs of the underlying
66   // graph computation.
67   //
68   // RunAsync() calls "stats_collector", if not null, to keep track of
69   // stats. This allows us to collect statistics and traces on demand.
70   //
71   // RunAsync() is provided a "call_frame", if the executor is used
72   // for executing a function, is used to pass arguments and return
73   // values between the caller and the callee.
74   //
75   // RunAsync() uses "cancellation_manager", if not nullptr, to
76   // register callbacks that should be called if the graph computation
77   // is canceled. Note that the callbacks merely unblock any
78   // long-running computation, and a canceled step will terminate by
79   // returning/calling the DoneCallback as usual.
80   //
81   // RunAsync() dispatches closures to "runner". Typically, "runner"
82   // is backed up by a bounded threadpool.
83   struct Args {
84     int64 step_id = 0;
85     Rendezvous* rendezvous = nullptr;
86     StepStatsCollector* stats_collector = nullptr;
87     CallFrameInterface* call_frame = nullptr;
88     CancellationManager* cancellation_manager = nullptr;
89     SessionState* session_state = nullptr;
90     TensorStore* tensor_store = nullptr;
91     ScopedStepContainer* step_container = nullptr;
92 
93     // If true, calls Sync() on the device.
94     bool sync_on_finish = false;
95 
96     typedef std::function<void()> Closure;
97     typedef std::function<void(Closure)> Runner;
98     Runner runner = nullptr;
99 
100     // A callback that is invoked each time a node has finished executing.
101     typedef std::function<Status(const string& node_name, const int output_slot,
102                                  const Tensor* tensor, const bool is_ref,
103                                  OpKernelContext* ctx)>
104         NodeOutputsCallback;
105     NodeOutputsCallback node_outputs_cb = nullptr;
106   };
107   typedef std::function<void(const Status&)> DoneCallback;
108   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
109 
110   // Synchronous wrapper for RunAsync().
Run(const Args & args)111   Status Run(const Args& args) {
112     Status ret;
113     Notification n;
114     RunAsync(args, [&ret, &n](const Status& s) {
115       ret = s;
116       n.Notify();
117     });
118     n.WaitForNotification();
119     return ret;
120   }
121 };
122 
123 // Creates an Executor that computes the given "graph".
124 //
125 // If successful, returns the constructed executor in "*executor". Otherwise,
126 // returns an error status.
127 //
128 // "params" provides a set of context for the executor. We expect that
129 // different context would provide different implementations.
130 struct LocalExecutorParams {
131   Device* device;
132 
133   // The library runtime support.
134   FunctionLibraryRuntime* function_library = nullptr;
135 
136   // create_kernel returns an instance of op kernel based on NodeDef.
137   // delete_kernel is called for every kernel used by the executor
138   // when the executor is deleted.
139   std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
140   std::function<void(OpKernel*)> delete_kernel;
141 
142   Executor::Args::NodeOutputsCallback node_outputs_cb;
143 };
144 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
145                                       std::unique_ptr<const Graph> graph,
146                                       Executor** executor);
147 
148 // A class to help run multiple executors in parallel and wait until
149 // all of them are complete.
150 //
151 // ExecutorBarrier deletes itself after the function returned by Get()
152 // is called.
153 class ExecutorBarrier {
154  public:
155   typedef std::function<void(const Status&)> StatusCallback;
156 
157   // Create an ExecutorBarrier for 'num' different executors.
158   //
159   // 'r' is the shared Rendezvous object that is used to communicate
160   // state.  If any of the executors experiences an error, the
161   // rendezvous object will be aborted exactly once.
162   //
163   // 'done' is called after the last executor completes, and
164   // ExecutorBarrier is deleted.
ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)165   ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
166       : rendez_(r), done_cb_(done), pending_(num) {}
167 
~ExecutorBarrier()168   ~ExecutorBarrier() {}
169 
170   // Returns a closure that Executors must call when they are done
171   // computing, passing the status of their execution as an argument.
Get()172   StatusCallback Get() {
173     return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
174   }
175 
176  private:
177   Rendezvous* rendez_ = nullptr;
178   StatusCallback done_cb_ = nullptr;
179 
180   mutable mutex mu_;
181   int pending_ GUARDED_BY(mu_) = 0;
182   Status status_ GUARDED_BY(mu_);
183 
WhenDone(const Status & s)184   void WhenDone(const Status& s) {
185     bool error = false;
186     Rendezvous* error_rendez = nullptr;
187     StatusCallback done = nullptr;
188     Status status;
189     {
190       mutex_lock l(mu_);
191       // If we are the first error encountered, mark the status
192       // appropriately and later trigger an abort of the Rendezvous
193       // object by this thread only.
194       if (status_.ok() && !s.ok()) {
195         error = true;
196         error_rendez = rendez_;
197         error_rendez->Ref();
198         status_ = s;
199       }
200 
201       // If this is the last call to WhenDone, call the final callback
202       // below.
203       if (--pending_ == 0) {
204         CHECK(done_cb_ != nullptr);
205         std::swap(done, done_cb_);
206       }
207 
208       if (!status_.ok()) {
209         status = status_;
210       }
211     }
212 
213     if (error) {
214       error_rendez->StartAbort(status);
215       error_rendez->Unref();
216     }
217     if (done != nullptr) {
218       delete this;
219       done(status);
220     }
221   }
222 
223   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
224 };
225 
226 // A few helpers to facilitate create/delete kernels.
227 
228 // Creates a kernel based on "ndef" on device "device". The kernel can
229 // access the functions in the "flib". The caller takes ownership of
230 // returned "*kernel".
231 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
232                              const NodeDef& ndef, int graph_def_version,
233                              OpKernel** kernel);
234 
235 // Deletes "kernel" returned by CreateKernel.
236 void DeleteNonCachedKernel(OpKernel* kernel);
237 
238 }  // end namespace tensorflow
239 
240 #endif  // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
241