1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_ 17 #define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_ 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/framework/rendezvous.h" 21 #include "tensorflow/core/framework/session_state.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/lib/core/notification.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace tensorflow { 30 31 class StepStatsCollector; 32 33 // Executor runs a graph computation. 34 // Example: 35 // Graph* graph = ...; 36 // ... construct graph ... 37 // Executor* executor; 38 // TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor)); 39 // Rendezvous* rendezvous = NewNaiveRendezvous(); 40 // TF_CHECK_OK(rendezvous->Send("input", some_input_tensor)); 41 // TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr})); 42 // TF_CHECK_OK(rendezvous->Recv("output", &output_tensor)); 43 // ... ... 44 // 45 // Multiple threads can call Executor::Run concurrently. 46 class Executor { 47 public: ~Executor()48 virtual ~Executor() {} 49 50 // RunAsync() executes the graph computation. "done" is run when the 51 // graph computation completes. If any error happens during the 52 // computation, "done" is run and the error is passed to "done". 53 // 54 // RunAsync() is given a few arguments in Args. The caller must 55 // ensure objects passed in Args (rendezvous, stats_collector, etc.) 56 // are alive at least until done is invoked. All pointers to the 57 // argument objects can be nullptr. 58 // 59 // "step_id" is a process-wide unique identifier for the step being 60 // run. Executors on different devices may receive the same step_id 61 // in the case that a step runs Ops on more than one device. The 62 // step_id is used for tracking resource usage of a given step. 63 // 64 // RunAsync() uses the given "rendezvous", if not null, as the 65 // mechanism to communicate inputs and outputs of the underlying 66 // graph computation. 67 // 68 // RunAsync() calls "stats_collector", if not null, to keep track of 69 // stats. This allows us to collect statistics and traces on demand. 70 // 71 // RunAsync() is provided a "call_frame", if the executor is used 72 // for executing a function, is used to pass arguments and return 73 // values between the caller and the callee. 74 // 75 // RunAsync() uses "cancellation_manager", if not nullptr, to 76 // register callbacks that should be called if the graph computation 77 // is canceled. Note that the callbacks merely unblock any 78 // long-running computation, and a canceled step will terminate by 79 // returning/calling the DoneCallback as usual. 80 // 81 // RunAsync() dispatches closures to "runner". Typically, "runner" 82 // is backed up by a bounded threadpool. 83 struct Args { 84 int64 step_id = 0; 85 Rendezvous* rendezvous = nullptr; 86 StepStatsCollector* stats_collector = nullptr; 87 CallFrameInterface* call_frame = nullptr; 88 CancellationManager* cancellation_manager = nullptr; 89 SessionState* session_state = nullptr; 90 TensorStore* tensor_store = nullptr; 91 ScopedStepContainer* step_container = nullptr; 92 93 // If true, calls Sync() on the device. 94 bool sync_on_finish = false; 95 96 typedef std::function<void()> Closure; 97 typedef std::function<void(Closure)> Runner; 98 Runner runner = nullptr; 99 100 // A callback that is invoked each time a node has finished executing. 101 typedef std::function<Status(const string& node_name, const int output_slot, 102 const Tensor* tensor, const bool is_ref, 103 OpKernelContext* ctx)> 104 NodeOutputsCallback; 105 NodeOutputsCallback node_outputs_cb = nullptr; 106 }; 107 typedef std::function<void(const Status&)> DoneCallback; 108 virtual void RunAsync(const Args& args, DoneCallback done) = 0; 109 110 // Synchronous wrapper for RunAsync(). Run(const Args & args)111 Status Run(const Args& args) { 112 Status ret; 113 Notification n; 114 RunAsync(args, [&ret, &n](const Status& s) { 115 ret = s; 116 n.Notify(); 117 }); 118 n.WaitForNotification(); 119 return ret; 120 } 121 }; 122 123 // Creates an Executor that computes the given "graph". 124 // 125 // If successful, returns the constructed executor in "*executor". Otherwise, 126 // returns an error status. 127 // 128 // "params" provides a set of context for the executor. We expect that 129 // different context would provide different implementations. 130 struct LocalExecutorParams { 131 Device* device; 132 133 // The library runtime support. 134 FunctionLibraryRuntime* function_library = nullptr; 135 136 // create_kernel returns an instance of op kernel based on NodeDef. 137 // delete_kernel is called for every kernel used by the executor 138 // when the executor is deleted. 139 std::function<Status(const NodeDef&, OpKernel**)> create_kernel; 140 std::function<void(OpKernel*)> delete_kernel; 141 142 Executor::Args::NodeOutputsCallback node_outputs_cb; 143 }; 144 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, 145 std::unique_ptr<const Graph> graph, 146 Executor** executor); 147 148 // A class to help run multiple executors in parallel and wait until 149 // all of them are complete. 150 // 151 // ExecutorBarrier deletes itself after the function returned by Get() 152 // is called. 153 class ExecutorBarrier { 154 public: 155 typedef std::function<void(const Status&)> StatusCallback; 156 157 // Create an ExecutorBarrier for 'num' different executors. 158 // 159 // 'r' is the shared Rendezvous object that is used to communicate 160 // state. If any of the executors experiences an error, the 161 // rendezvous object will be aborted exactly once. 162 // 163 // 'done' is called after the last executor completes, and 164 // ExecutorBarrier is deleted. ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)165 ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done) 166 : rendez_(r), done_cb_(done), pending_(num) {} 167 ~ExecutorBarrier()168 ~ExecutorBarrier() {} 169 170 // Returns a closure that Executors must call when they are done 171 // computing, passing the status of their execution as an argument. Get()172 StatusCallback Get() { 173 return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1); 174 } 175 176 private: 177 Rendezvous* rendez_ = nullptr; 178 StatusCallback done_cb_ = nullptr; 179 180 mutable mutex mu_; 181 int pending_ GUARDED_BY(mu_) = 0; 182 Status status_ GUARDED_BY(mu_); 183 WhenDone(const Status & s)184 void WhenDone(const Status& s) { 185 bool error = false; 186 Rendezvous* error_rendez = nullptr; 187 StatusCallback done = nullptr; 188 Status status; 189 { 190 mutex_lock l(mu_); 191 // If we are the first error encountered, mark the status 192 // appropriately and later trigger an abort of the Rendezvous 193 // object by this thread only. 194 if (status_.ok() && !s.ok()) { 195 error = true; 196 error_rendez = rendez_; 197 error_rendez->Ref(); 198 status_ = s; 199 } 200 201 // If this is the last call to WhenDone, call the final callback 202 // below. 203 if (--pending_ == 0) { 204 CHECK(done_cb_ != nullptr); 205 std::swap(done, done_cb_); 206 } 207 208 if (!status_.ok()) { 209 status = status_; 210 } 211 } 212 213 if (error) { 214 error_rendez->StartAbort(status); 215 error_rendez->Unref(); 216 } 217 if (done != nullptr) { 218 delete this; 219 done(status); 220 } 221 } 222 223 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier); 224 }; 225 226 // A few helpers to facilitate create/delete kernels. 227 228 // Creates a kernel based on "ndef" on device "device". The kernel can 229 // access the functions in the "flib". The caller takes ownership of 230 // returned "*kernel". 231 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, 232 const NodeDef& ndef, int graph_def_version, 233 OpKernel** kernel); 234 235 // Deletes "kernel" returned by CreateKernel. 236 void DeleteNonCachedKernel(OpKernel* kernel); 237 238 } // end namespace tensorflow 239 240 #endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_ 241