/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class StepStatsCollector;

// If this is called, we will sample the execution cost of "inexpensive"
// kernels and switch them to "expensive" when the estimated cost exceeds the
// expensiveness threshold.
//
// This is a temporary flag for validating the performance impact of this
// feature. For simplicity, a global flag is used, and once the flag is turned
// on, it cannot be turned off. We will remove this flag once this feature is
// validated.
void EnableAlwaysTrackKernelExecutionCost();

// Executor runs a graph computation.
// Example:
//   Graph* graph = ...;
//      ... construct graph ...
//   Executor* executor;
//   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
//   Rendezvous* rendezvous = NewNaiveRendezvous();
//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
//   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
//   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
//   ... ...
//
// Multiple threads can call Executor::Run concurrently.
class Executor {
 public:
  virtual ~Executor() {}

  // RunAsync() executes the graph computation. "done" is run when the
  // graph computation completes. If any error happens during the
  // computation, "done" is run and the error is passed to "done".
  //
  // RunAsync() is given a few arguments in Args. The caller must
  // ensure objects passed in Args (rendezvous, stats_collector, etc.)
  // are alive at least until "done" is invoked. All pointers to the
  // argument objects can be nullptr.
  //
  // "step_id" is a process-wide unique identifier for the step being
  // run. Executors on different devices may receive the same step_id
  // in the case that a step runs Ops on more than one device. The
  // step_id is used for tracking resource usage of a given step.
  //
  // RunAsync() uses the given "rendezvous", if not null, as the
  // mechanism to communicate inputs and outputs of the underlying
  // graph computation.
  //
  // RunAsync() calls "stats_collector", if not null, to keep track of
  // stats. This allows us to collect statistics and traces on demand.
  //
  // RunAsync() is provided a "call_frame"; if the executor is used to
  // execute a function, the call frame is used to pass arguments and
  // return values between the caller and the callee.
  //
  // RunAsync() uses "cancellation_manager", if not nullptr, to
  // register callbacks that should be called if the graph computation
  // is canceled. Note that the callbacks merely unblock any
  // long-running computation, and a canceled step will terminate by
  // returning/calling the DoneCallback as usual.
  //
  // RunAsync() dispatches closures to "runner". Typically, "runner"
  // is backed by a bounded threadpool.
  struct Args {
    int64 step_id = 0;
    RendezvousInterface* rendezvous = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CallFrameInterface* call_frame = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    SessionState* session_state = nullptr;
    // Unique session identifier. Can be empty.
    string session_handle;
    TensorStore* tensor_store = nullptr;
    ScopedStepContainer* step_container = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;

    // If true, calls Sync() on the device.
    bool sync_on_finish = false;

    typedef std::function<void()> Closure;
    typedef std::function<void(Closure)> Runner;
    Runner runner = nullptr;

    // If true, all kernels will be treated as "inexpensive", and hence
    // executed on the scheduling thread.
    bool run_all_kernels_inline = false;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void RunAsync(const Args& args, DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync().
  virtual Status Run(const Args& args) {
    Status ret;
    Notification n;
    RunAsync(args, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};
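// For illustration only, a minimal sketch of driving RunAsync() directly
// rather than through the synchronous Run() wrapper. The "executor",
// "rendezvous", and "pool" (a thread::ThreadPool) are assumed to be
// constructed elsewhere, e.g. as in the example above:
//
//   Executor::Args args;
//   args.step_id = 1;
//   args.rendezvous = rendezvous;
//   args.runner = [pool](Executor::Args::Closure c) { pool->Schedule(c); };
//
//   Notification done;
//   Status run_status;
//   executor->RunAsync(args, [&run_status, &done](const Status& s) {
//     run_status = s;  // May run on a "runner" thread, not the caller's.
//     done.Notify();
//   });
//   done.WaitForNotification();  // Args objects must outlive this point.
//   TF_CHECK_OK(run_status);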
// Creates an Executor that computes the given "graph".
//
// If successful, returns the constructed executor in "*executor". Otherwise,
// returns an error status.
//
// "params" provides a set of context for the executor. We expect that
// different contexts would provide different implementations.
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
                                      const Graph& graph, Executor** executor);

// A class to help run multiple executors in parallel and wait until
// all of them are complete.
//
// ExecutorBarrier deletes itself after the function returned by Get()
// is called.
class ExecutorBarrier {
 public:
  typedef std::function<void(const Status&)> StatusCallback;

  // Creates an ExecutorBarrier for 'num' different executors.
  //
  // 'r' is the shared Rendezvous object that is used to communicate
  // state. If any of the executors experiences an error, the
  // rendezvous object will be aborted exactly once.
  //
  // 'done' is called after the last executor completes, and
  // ExecutorBarrier is deleted.
  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
      : rendez_(r), done_cb_(done), pending_(num) {}

  ~ExecutorBarrier() {}

  // Returns a closure that Executors must call when they are done
  // computing, passing the status of their execution as an argument.
  StatusCallback Get() {
    return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
  }

 private:
  Rendezvous* rendez_ = nullptr;
  StatusCallback done_cb_ = nullptr;

  mutable mutex mu_;
  int pending_ TF_GUARDED_BY(mu_) = 0;
  StatusGroup status_group_ TF_GUARDED_BY(mu_);

  void WhenDone(const Status& s) {
    Rendezvous* error_rendez = nullptr;
    StatusCallback done = nullptr;
    Status status;

    {
      mutex_lock l(mu_);

      // If we are the first error encountered, trigger an abort of the
      // Rendezvous object by this thread only.
      if (status_group_.ok() && !s.ok()) {
        error_rendez = rendez_;
        error_rendez->Ref();
      }

      if (!s.ok() && !StatusGroup::IsDerived(s) &&
          !status_group_.HasLogMessages()) {
        status_group_.AttachLogMessages();
      }

      status_group_.Update(s);

      // If this is the last call to WhenDone, call the final callback
      // below.
      if (--pending_ == 0) {
        CHECK(done_cb_ != nullptr);
        std::swap(done, done_cb_);
        status = status_group_.as_summary_status();
      }
    }

    if (error_rendez != nullptr) {
      error_rendez->StartAbort(
          errors::Aborted("Stopping remaining executors."));
      error_rendez->Unref();
    }

    if (done != nullptr) {
      delete this;
      if (!status.ok()) {
        VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
      }
      done(status);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
};
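// For illustration only, a sketch of how an ExecutorBarrier might coordinate
// several executors. The per-device "executors" vector and the shared
// "rendezvous" are assumptions, not part of this header:
//
//   std::vector<Executor*> executors = ...;  // e.g. one per device
//   Notification all_done;
//   Status barrier_status;
//   ExecutorBarrier* barrier = new ExecutorBarrier(
//       executors.size(), rendezvous,
//       [&barrier_status, &all_done](const Status& s) {
//         barrier_status = s;
//         all_done.Notify();
//       });
//   Executor::Args args;
//   args.rendezvous = rendezvous;
//   for (Executor* e : executors) {
//     e->RunAsync(args, barrier->Get());
//   }
//   all_done.WaitForNotification();  // Barrier deletes itself after 'done'.
//   TF_CHECK_OK(barrier_status);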
// A few helpers to facilitate creating and deleting kernels.

// Creates a kernel based on "props" on device "device". The kernel can
// access the functions in "flib". The caller takes ownership of the
// returned "*kernel".
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const std::shared_ptr<const NodeProperties>& props,
                             int graph_def_version, OpKernel** kernel);

// Deletes "kernel" returned by CreateNonCachedKernel.
void DeleteNonCachedKernel(OpKernel* kernel);

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_