1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ 18 19 #include "absl/time/time.h" 20 #include "absl/types/optional.h" 21 #include "tensorflow/core/common_runtime/device.h" 22 #include "tensorflow/core/common_runtime/local_executor_params.h" 23 #include "tensorflow/core/framework/rendezvous.h" 24 #include "tensorflow/core/framework/session_state.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/graph/graph.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/lib/core/notification.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/core/threadpool_interface.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/util/managed_stack_trace.h" 34 35 namespace tensorflow { 36 37 class StepStatsCollector; 38 39 // Executor runs a graph computation. 40 // Example: 41 // Graph* graph = ...; 42 // ... construct graph ... 43 // Executor* executor; 44 // TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor)); 45 // Rendezvous* rendezvous = NewNaiveRendezvous(); 46 // TF_CHECK_OK(rendezvous->Send("input", some_input_tensor)); 47 // TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr})); 48 // TF_CHECK_OK(rendezvous->Recv("output", &output_tensor)); 49 // ... ... 50 // 51 // Multiple threads can call Executor::Run concurrently. 52 class Executor { 53 public: ~Executor()54 virtual ~Executor() {} 55 56 // RunAsync() executes the graph computation. "done" is run when the 57 // graph computation completes. If any error happens during the 58 // computation, "done" is run and the error is passed to "done". 59 // 60 // RunAsync() is given a few arguments in Args. The caller must 61 // ensure objects passed in Args (rendezvous, stats_collector, etc.) 62 // are alive at least until done is invoked. All pointers to the 63 // argument objects can be nullptr. 64 // 65 // "step_id" is a process-wide unique identifier for the step being 66 // run. Executors on different devices may receive the same step_id 67 // in the case that a step runs Ops on more than one device. The 68 // step_id is used for tracking resource usage of a given step. 69 // 70 // RunAsync() uses the given "rendezvous", if not null, as the 71 // mechanism to communicate inputs and outputs of the underlying 72 // graph computation. 73 // 74 // RunAsync() calls "stats_collector", if not null, to keep track of 75 // stats. This allows us to collect statistics and traces on demand. 76 // 77 // RunAsync() is provided a "call_frame", if the executor is used 78 // for executing a function, is used to pass arguments and return 79 // values between the caller and the callee. 80 // 81 // RunAsync() uses "cancellation_manager", if not nullptr, to 82 // register callbacks that should be called if the graph computation 83 // is canceled. Note that the callbacks merely unblock any 84 // long-running computation, and a canceled step will terminate by 85 // returning/calling the DoneCallback as usual. 86 // 87 // RunAsync() dispatches closures to "runner". Typically, "runner" 88 // is backed up by a bounded threadpool. 89 // 90 // "start_time_usecs" is a timestamp for the start of RunAsync() 91 // execution. Used for system-wide latency metrics. 92 struct Args { 93 int64_t step_id = 0; 94 RendezvousInterface* rendezvous = nullptr; 95 StepStatsCollectorInterface* stats_collector = nullptr; 96 CallFrameInterface* call_frame = nullptr; 97 CancellationManager* cancellation_manager = nullptr; 98 SessionState* session_state = nullptr; 99 // Unique session identifier. Can be empty. 100 string session_handle; 101 TensorStore* tensor_store = nullptr; 102 ScopedStepContainer* step_container = nullptr; 103 CollectiveExecutor* collective_executor = nullptr; 104 thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr; 105 CoordinationServiceAgent* coordination_service_agent = nullptr; 106 int64_t start_time_usecs = 0; 107 // The deadline for the kernel to complete by. Empty if unspecified. 108 absl::optional<absl::Time> deadline; 109 absl::optional<ManagedStackTrace> stack_trace = absl::nullopt; 110 111 // If true, calls Sync() on the device. 112 bool sync_on_finish = false; 113 114 typedef std::function<void()> Closure; 115 typedef std::function<void(Closure)> Runner; 116 Runner runner = nullptr; 117 118 // If true, all kernels will be treated as "inexpensive", and hence executed 119 // on the scheduling thread. 120 bool run_all_kernels_inline = false; 121 }; 122 typedef std::function<void(const Status&)> DoneCallback; 123 virtual void RunAsync(const Args& args, DoneCallback done) = 0; 124 125 // Synchronous wrapper for RunAsync(). Run(const Args & args)126 virtual Status Run(const Args& args) { 127 Status ret; 128 Notification n; 129 RunAsync(args, [&ret, &n](const Status& s) { 130 ret = s; 131 n.Notify(); 132 }); 133 n.WaitForNotification(); 134 return ret; 135 } 136 }; 137 138 // Creates an Executor that computes the given "graph". 139 // 140 // If successful, returns the constructed executor in "*executor". Otherwise, 141 // returns an error status. 142 // 143 // "params" provides a set of context for the executor. We expect that 144 // different context would provide different implementations. 145 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, 146 const Graph& graph, Executor** executor); 147 148 // A class to help run multiple executors in parallel and wait until 149 // all of them are complete. 150 // 151 // ExecutorBarrier deletes itself after the function returned by Get() 152 // is called. 153 class ExecutorBarrier { 154 public: 155 typedef std::function<void(const Status&)> StatusCallback; 156 157 // Create an ExecutorBarrier for 'num' different executors. 158 // 159 // 'r' is the shared Rendezvous object that is used to communicate 160 // state. If any of the executors experiences an error, the 161 // rendezvous object will be aborted exactly once. 162 // 163 // 'done' is called after the last executor completes, and 164 // ExecutorBarrier is deleted. ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)165 ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done) 166 : rendez_(r), done_cb_(done), pending_(num) {} 167 ~ExecutorBarrier()168 ~ExecutorBarrier() {} 169 170 // Returns a closure that Executors must call when they are done 171 // computing, passing the status of their execution as an argument. Get()172 StatusCallback Get() { 173 return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1); 174 } 175 176 private: 177 Rendezvous* rendez_ = nullptr; 178 StatusCallback done_cb_ = nullptr; 179 180 mutable mutex mu_; 181 int pending_ TF_GUARDED_BY(mu_) = 0; 182 StatusGroup status_group_ TF_GUARDED_BY(mu_); 183 WhenDone(const Status & s)184 void WhenDone(const Status& s) { 185 Rendezvous* error_rendez = nullptr; 186 StatusCallback done = nullptr; 187 Status status; 188 189 { 190 mutex_lock l(mu_); 191 192 // If we are the first error encountered, trigger an abort of the 193 // Rendezvous object by this thread only. 194 if (status_group_.ok() && !s.ok()) { 195 error_rendez = rendez_; 196 error_rendez->Ref(); 197 } 198 199 if (!s.ok() && !StatusGroup::IsDerived(s) && 200 !status_group_.HasLogMessages()) { 201 status_group_.AttachLogMessages(); 202 } 203 204 status_group_.Update(s); 205 206 // If this is the last call to WhenDone, call the final callback 207 // below. 208 if (--pending_ == 0) { 209 CHECK(done_cb_ != nullptr); 210 std::swap(done, done_cb_); 211 status = status_group_.as_summary_status(); 212 } 213 } 214 215 if (error_rendez != nullptr) { 216 error_rendez->StartAbort( 217 errors::Aborted("Stopping remaining executors.")); 218 error_rendez->Unref(); 219 } 220 221 if (done != nullptr) { 222 delete this; 223 if (!status.ok()) { 224 VLOG(1) << "ExecutorBarrier finished with bad status: " << status; 225 } 226 done(status); 227 } 228 } 229 230 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier); 231 }; 232 233 // A few helpers to facilitate create/delete kernels. 234 235 // Creates a kernel based on "props" on device "device". The kernel can 236 // access the functions in the "flib". The caller takes ownership of 237 // returned "*kernel". 238 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, 239 const std::shared_ptr<const NodeProperties>& props, 240 int graph_def_version, OpKernel** kernel); 241 242 // Deletes "kernel" returned by CreateKernel. 243 void DeleteNonCachedKernel(OpKernel* kernel); 244 245 } // end namespace tensorflow 246 247 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ 248