/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class StepStatsCollectorInterface;

// Executor runs a graph computation.
// Example:
//   Graph* graph = ...;
//      ... construct graph ...
//   Executor* executor;
//   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
//   Rendezvous* rendezvous = NewNaiveRendezvous();
//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
//   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
//   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
//   ... ...
//
// Multiple threads can call Executor::Run concurrently.
class Executor {
 public:
  virtual ~Executor() {}

  // RunAsync() executes the graph computation. "done" is run when the
  // graph computation completes. If any error happens during the
  // computation, "done" is run and the error is passed to "done".
  //
  // RunAsync() is given a few arguments in Args. The caller must
  // ensure objects passed in Args (rendezvous, stats_collector, etc.)
  // are alive at least until "done" is invoked. All pointers to the
  // argument objects can be nullptr.
  //
  // "step_id" is a process-wide unique identifier for the step being
  // run. Executors on different devices may receive the same step_id
  // in the case that a step runs Ops on more than one device. The
  // step_id is used for tracking resource usage of a given step.
  //
  // RunAsync() uses the given "rendezvous", if not null, as the
  // mechanism to communicate inputs and outputs of the underlying
  // graph computation.
  //
  // RunAsync() calls "stats_collector", if not null, to keep track of
  // stats. This allows us to collect statistics and traces on demand.
  //
  // If the executor is used for executing a function, RunAsync() uses
  // the given "call_frame", if not null, to pass arguments and return
  // values between the caller and the callee.
  //
  // RunAsync() uses "cancellation_manager", if not nullptr, to
  // register callbacks that should be called if the graph computation
  // is canceled. Note that the callbacks merely unblock any
  // long-running computation; a canceled step still terminates by
  // returning/calling the DoneCallback as usual.
  //
  // RunAsync() dispatches closures to "runner". Typically, "runner"
  // is backed by a bounded threadpool.
  struct Args {
    int64 step_id = 0;
    Rendezvous* rendezvous = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CallFrameInterface* call_frame = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    SessionState* session_state = nullptr;
    // Unique session identifier. Can be empty.
    string session_handle;
    TensorStore* tensor_store = nullptr;
    ScopedStepContainer* step_container = nullptr;
    CollectiveExecutor* collective_executor = nullptr;

    // If true, calls Sync() on the device.
    bool sync_on_finish = false;

    typedef std::function<void()> Closure;
    typedef std::function<void(Closure)> Runner;
    Runner runner = nullptr;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void RunAsync(const Args& args, DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync().
  Status Run(const Args& args) {
    Status ret;
    Notification n;
    RunAsync(args, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};

// Creates an Executor that computes the given "graph".
//
// If successful, returns the constructed executor in "*executor". Otherwise,
// returns an error status.
//
// "params" provides a set of context for the executor. We expect that
// different contexts would provide different implementations.
struct LocalExecutorParams {
  Device* device;

  // The library runtime support.
  FunctionLibraryRuntime* function_library = nullptr;

  // create_kernel returns an instance of op kernel based on NodeDef.
  // delete_kernel is called for every kernel used by the executor
  // when the executor is deleted.
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
  std::function<void(OpKernel*)> delete_kernel;
};
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
                                      std::unique_ptr<const Graph> graph,
                                      Executor** executor);
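
// Example (a minimal sketch, not part of this API): wiring up
// LocalExecutorParams using the non-cached kernel helpers declared at the
// bottom of this file. "my_device", "my_flib", and "my_graph" are
// hypothetical placeholders assumed to be owned by the caller.
//
//   Device* my_device = ...;
//   FunctionLibraryRuntime* my_flib = ...;
//   std::unique_ptr<const Graph> my_graph = ...;
//
//   LocalExecutorParams params;
//   params.device = my_device;
//   params.function_library = my_flib;
//   // Capture the version before my_graph is moved into NewLocalExecutor.
//   const int graph_def_version = my_graph->versions().producer();
//   params.create_kernel = [my_device, my_flib, graph_def_version](
//                              const NodeDef& ndef, OpKernel** kernel) {
//     return CreateNonCachedKernel(my_device, my_flib, ndef,
//                                  graph_def_version, kernel);
//   };
//   params.delete_kernel = [](OpKernel* kernel) {
//     DeleteNonCachedKernel(kernel);
//   };
//   Executor* executor = nullptr;
//   TF_CHECK_OK(NewLocalExecutor(params, std::move(my_graph), &executor));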

// A class to help run multiple executors in parallel and wait until
// all of them are complete.
//
// ExecutorBarrier deletes itself after the function returned by Get()
// is called.
class ExecutorBarrier {
 public:
  typedef std::function<void(const Status&)> StatusCallback;

  // Creates an ExecutorBarrier for 'num' different executors.
  //
  // 'r' is the shared Rendezvous object that is used to communicate
  // state. If any of the executors experiences an error, the
  // rendezvous object will be aborted exactly once.
  //
  // 'done' is called after the last executor completes, and
  // ExecutorBarrier is deleted.
  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
      : rendez_(r), done_cb_(done), pending_(num) {}

  ~ExecutorBarrier() {}

  // Returns a closure that Executors must call when they are done
  // computing, passing the status of their execution as an argument.
  StatusCallback Get() {
    return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
  }

 private:
  Rendezvous* rendez_ = nullptr;
  StatusCallback done_cb_ = nullptr;

  mutable mutex mu_;
  int pending_ GUARDED_BY(mu_) = 0;
  StatusGroup status_group_ GUARDED_BY(mu_);

  void WhenDone(const Status& s) {
    Rendezvous* error_rendez = nullptr;
    StatusCallback done = nullptr;
    Status status;

    {
      mutex_lock l(mu_);

      // If we are the first error encountered, trigger an abort of the
      // Rendezvous object by this thread only.
      if (status_group_.ok() && !s.ok()) {
        error_rendez = rendez_;
        error_rendez->Ref();
      }

      status_group_.Update(s);

      // If this is the last call to WhenDone, call the final callback
      // below.
      if (--pending_ == 0) {
        CHECK(done_cb_ != nullptr);
        std::swap(done, done_cb_);
        status = status_group_.as_status();
      }
    }

    if (error_rendez != nullptr) {
      error_rendez->StartAbort(
          errors::Aborted("Stopping remaining executors."));
      error_rendez->Unref();
    }

    if (done != nullptr) {
      delete this;
      done(status);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
};
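
// Example (a minimal sketch, not part of this API): running several
// executors in parallel and blocking until all of them complete.
// "executors", "rendezvous", and "args" are hypothetical placeholders.
// The barrier frees itself once the last executor invokes the closure
// returned by Get(), so it must be heap-allocated.
//
//   std::vector<Executor*> executors = ...;
//   Executor::Args args;
//   args.rendezvous = rendezvous;
//   Notification all_done;
//   Status status;
//   ExecutorBarrier* barrier = new ExecutorBarrier(
//       executors.size(), rendezvous,
//       [&status, &all_done](const Status& s) {
//         status = s;
//         all_done.Notify();
//       });
//   for (Executor* executor : executors) {
//     executor->RunAsync(args, barrier->Get());
//   }
//   all_done.WaitForNotification();
//   TF_CHECK_OK(status);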

// A few helpers to facilitate create/delete kernels.

// Creates a kernel based on "ndef" on device "device". The kernel can
// access the functions in the "flib". The caller takes ownership of
// returned "*kernel".
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, int graph_def_version,
                             OpKernel** kernel);

// Deletes "kernel" returned by CreateNonCachedKernel.
void DeleteNonCachedKernel(OpKernel* kernel);

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_