/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_

// IWYU: this header uses std::function and std::shared_ptr directly.
#include <functional>
#include <memory>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class StepStatsCollector;

// Executor runs a graph computation.
// Example:
//   Graph* graph = ...;
//      ... construct graph ...
//   Executor* executor;
//   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
//   Rendezvous* rendezvous = NewNaiveRendezvous();
//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
//   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
//   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
//   ... ...
47 // 48 // Multiple threads can call Executor::Run concurrently. 49 class Executor { 50 public: ~Executor()51 virtual ~Executor() {} 52 53 // RunAsync() executes the graph computation. "done" is run when the 54 // graph computation completes. If any error happens during the 55 // computation, "done" is run and the error is passed to "done". 56 // 57 // RunAsync() is given a few arguments in Args. The caller must 58 // ensure objects passed in Args (rendezvous, stats_collector, etc.) 59 // are alive at least until done is invoked. All pointers to the 60 // argument objects can be nullptr. 61 // 62 // "step_id" is a process-wide unique identifier for the step being 63 // run. Executors on different devices may receive the same step_id 64 // in the case that a step runs Ops on more than one device. The 65 // step_id is used for tracking resource usage of a given step. 66 // 67 // RunAsync() uses the given "rendezvous", if not null, as the 68 // mechanism to communicate inputs and outputs of the underlying 69 // graph computation. 70 // 71 // RunAsync() calls "stats_collector", if not null, to keep track of 72 // stats. This allows us to collect statistics and traces on demand. 73 // 74 // RunAsync() is provided a "call_frame", if the executor is used 75 // for executing a function, is used to pass arguments and return 76 // values between the caller and the callee. 77 // 78 // RunAsync() uses "cancellation_manager", if not nullptr, to 79 // register callbacks that should be called if the graph computation 80 // is canceled. Note that the callbacks merely unblock any 81 // long-running computation, and a canceled step will terminate by 82 // returning/calling the DoneCallback as usual. 83 // 84 // RunAsync() dispatches closures to "runner". Typically, "runner" 85 // is backed up by a bounded threadpool. 86 // 87 // "start_time_usecs" is a timestamp for the start of RunAsync() 88 // execution. Used for system-wide latency metrics. 
89 struct Args { 90 int64 step_id = 0; 91 RendezvousInterface* rendezvous = nullptr; 92 StepStatsCollectorInterface* stats_collector = nullptr; 93 CallFrameInterface* call_frame = nullptr; 94 CancellationManager* cancellation_manager = nullptr; 95 SessionState* session_state = nullptr; 96 // Unique session identifier. Can be empty. 97 string session_handle; 98 TensorStore* tensor_store = nullptr; 99 ScopedStepContainer* step_container = nullptr; 100 CollectiveExecutor* collective_executor = nullptr; 101 thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr; 102 CoordinationServiceAgent* coordination_service_agent = nullptr; 103 int64 start_time_usecs = 0; 104 105 // If true, calls Sync() on the device. 106 bool sync_on_finish = false; 107 108 typedef std::function<void()> Closure; 109 typedef std::function<void(Closure)> Runner; 110 Runner runner = nullptr; 111 112 // If true, all kernels will be treated as "inexpensive", and hence executed 113 // on the scheduling thread. 114 bool run_all_kernels_inline = false; 115 }; 116 typedef std::function<void(const Status&)> DoneCallback; 117 virtual void RunAsync(const Args& args, DoneCallback done) = 0; 118 119 // Synchronous wrapper for RunAsync(). Run(const Args & args)120 virtual Status Run(const Args& args) { 121 Status ret; 122 Notification n; 123 RunAsync(args, [&ret, &n](const Status& s) { 124 ret = s; 125 n.Notify(); 126 }); 127 n.WaitForNotification(); 128 return ret; 129 } 130 }; 131 132 // Creates an Executor that computes the given "graph". 133 // 134 // If successful, returns the constructed executor in "*executor". Otherwise, 135 // returns an error status. 136 // 137 // "params" provides a set of context for the executor. We expect that 138 // different context would provide different implementations. 
139 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, 140 const Graph& graph, Executor** executor); 141 142 // A class to help run multiple executors in parallel and wait until 143 // all of them are complete. 144 // 145 // ExecutorBarrier deletes itself after the function returned by Get() 146 // is called. 147 class ExecutorBarrier { 148 public: 149 typedef std::function<void(const Status&)> StatusCallback; 150 151 // Create an ExecutorBarrier for 'num' different executors. 152 // 153 // 'r' is the shared Rendezvous object that is used to communicate 154 // state. If any of the executors experiences an error, the 155 // rendezvous object will be aborted exactly once. 156 // 157 // 'done' is called after the last executor completes, and 158 // ExecutorBarrier is deleted. ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)159 ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done) 160 : rendez_(r), done_cb_(done), pending_(num) {} 161 ~ExecutorBarrier()162 ~ExecutorBarrier() {} 163 164 // Returns a closure that Executors must call when they are done 165 // computing, passing the status of their execution as an argument. Get()166 StatusCallback Get() { 167 return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1); 168 } 169 170 private: 171 Rendezvous* rendez_ = nullptr; 172 StatusCallback done_cb_ = nullptr; 173 174 mutable mutex mu_; 175 int pending_ TF_GUARDED_BY(mu_) = 0; 176 StatusGroup status_group_ TF_GUARDED_BY(mu_); 177 WhenDone(const Status & s)178 void WhenDone(const Status& s) { 179 Rendezvous* error_rendez = nullptr; 180 StatusCallback done = nullptr; 181 Status status; 182 183 { 184 mutex_lock l(mu_); 185 186 // If we are the first error encountered, trigger an abort of the 187 // Rendezvous object by this thread only. 
188 if (status_group_.ok() && !s.ok()) { 189 error_rendez = rendez_; 190 error_rendez->Ref(); 191 } 192 193 if (!s.ok() && !StatusGroup::IsDerived(s) && 194 !status_group_.HasLogMessages()) { 195 status_group_.AttachLogMessages(); 196 } 197 198 status_group_.Update(s); 199 200 // If this is the last call to WhenDone, call the final callback 201 // below. 202 if (--pending_ == 0) { 203 CHECK(done_cb_ != nullptr); 204 std::swap(done, done_cb_); 205 status = status_group_.as_summary_status(); 206 } 207 } 208 209 if (error_rendez != nullptr) { 210 error_rendez->StartAbort( 211 errors::Aborted("Stopping remaining executors.")); 212 error_rendez->Unref(); 213 } 214 215 if (done != nullptr) { 216 delete this; 217 if (!status.ok()) { 218 VLOG(1) << "ExecutorBarrier finished with bad status: " << status; 219 } 220 done(status); 221 } 222 } 223 224 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier); 225 }; 226 227 // A few helpers to facilitate create/delete kernels. 228 229 // Creates a kernel based on "props" on device "device". The kernel can 230 // access the functions in the "flib". The caller takes ownership of 231 // returned "*kernel". 232 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, 233 const std::shared_ptr<const NodeProperties>& props, 234 int graph_def_version, OpKernel** kernel); 235 236 // Deletes "kernel" returned by CreateKernel. 237 void DeleteNonCachedKernel(OpKernel* kernel); 238 239 } // end namespace tensorflow 240 241 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ 242