/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_

#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

class CostModel;
class DebugGateway;
class Device;
class DirectSessionFactory;

class DirectSession : public Session {
 public:
  typedef std::function<void(Session*)> CloseCallback;

  // Takes ownership of 'device_mgr'.
  // 'factory' is used to unregister the DirectSession with 'factory' when it's
  // closed. This ensures that Reset requests from the 'factory' don't get sent
  // to sessions that are already closed.
  DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
                DirectSessionFactory* factory);
  ~DirectSession() override;

  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;

  ::tensorflow::Status Create(const GraphDef& graph) override;
  ::tensorflow::Status Extend(const GraphDef& graph) override;
  ::tensorflow::Status Run(const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
                           std::vector<Tensor>* outputs) override;
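
  // Illustrative usage (a hedged sketch, not a contract of this header):
  // clients typically reach a DirectSession through the generic Session
  // interface via NewSession() with an empty target. The GraphDef contents
  // and the node names "x" and "y" below are assumptions for illustration
  // only.
  //
  //   GraphDef graph_def;
  //   // ... populate graph_def, e.g. a float Placeholder "x" and a node "y"
  //   // computed from it ...
  //   std::unique_ptr<Session> session(NewSession(SessionOptions()));
  //   TF_CHECK_OK(session->Create(graph_def));
  //
  //   Tensor x(DT_FLOAT, TensorShape({2}));
  //   x.vec<float>()(0) = 1.0f;
  //   x.vec<float>()(1) = 2.0f;
  //
  //   std::vector<Tensor> outputs;
  //   TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
  //   TF_CHECK_OK(session->Close());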
  // NOTE: Experimental and subject to change.
  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
                           const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
                           std::vector<Tensor>* outputs,
                           RunMetadata* run_metadata) override;

  // NOTE: PRunSetup and PRun are added to support partial execution. This
  // feature is experimental and subject to change.
  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
                                 const std::vector<string>& output_names,
                                 const std::vector<string>& target_nodes,
                                 string* handle) override;
  ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
                            const std::vector<string>& output_names,
                            std::vector<Tensor>* outputs) override;

  // Reset clears 'containers' from the device_mgr of the DirectSession.
  // If 'containers' is empty, then Reset clears the default container.
  ::tensorflow::Status Reset(const std::vector<string>& containers);

  ::tensorflow::Status ListDevices(
      std::vector<DeviceAttributes>* response) override;
  ::tensorflow::Status Close() override;
  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
    *output = device_mgr_.get();
    return ::tensorflow::Status::OK();
  }

  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
    cost_model_manager_.ExportCostModels(cost_models);
  }

  ::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
                                    CallableHandle* out_handle) override;
  ::tensorflow::Status RunCallable(CallableHandle handle,
                                   const std::vector<Tensor>& feed_tensors,
                                   std::vector<Tensor>* fetch_tensors,
                                   RunMetadata* run_metadata) override;
  ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
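
  // Illustrative usage of the callable API declared above (a hedged sketch
  // only). The feed/fetch names "x:0" and "y:0", the tensor `x`, and the
  // `session` pointer continue the illustrative example earlier in this
  // header and are assumptions, not part of this interface.
  //
  //   CallableOptions opts;
  //   opts.add_feed("x:0");
  //   opts.add_fetch("y:0");
  //   Session::CallableHandle handle;
  //   TF_CHECK_OK(session->MakeCallable(opts, &handle));
  //
  //   std::vector<Tensor> fetches;
  //   RunMetadata run_metadata;
  //   TF_CHECK_OK(session->RunCallable(handle, {x}, &fetches, &run_metadata));
  //   TF_CHECK_OK(session->ReleaseCallable(handle));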

 private:
  // For access to collective_graph_key_.
  friend class DirectSessionCollectiveTest;

  // We create one executor and its dependent library runtime for
  // every partition.
  struct PerPartitionExecutorsAndLib {
    Graph* graph = nullptr;                  // not owned.
    Device* device = nullptr;                // not owned.
    FunctionLibraryRuntime* flib = nullptr;  // not owned.
    std::unique_ptr<Executor> executor;
  };

  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
  // 'step_count' is the number of times this graph is executed.
  // 'graph' is the entire graph being executed. 'name_to_node'
  // maps node name to node. We keep 'graph' and 'name_to_node' only in
  // the case of partial runs. Each item in 'items' is the executor for
  // a partition of the graph bundled with its dependent library runtime.
  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
  // are rendezvous keys for the fetches.
  struct ExecutorsAndKeys {
    ExecutorsAndKeys() : step_count(0) {}

    std::atomic_int_fast64_t step_count;
    std::unique_ptr<Graph> graph;
    NameNodeMap name_to_node;
    std::vector<PerPartitionExecutorsAndLib> items;
    std::unordered_map<string, size_t> input_name_to_index;
    std::unordered_map<string, string> input_name_to_rendezvous_key;
    std::unordered_map<string, size_t> output_name_to_index;
    std::unordered_map<string, string> output_name_to_rendezvous_key;

    DataTypeVector input_types;
    DataTypeVector output_types;

    CallableOptions callable_options;

    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
  };

  // A FunctionInfo object is created for every unique set of feeds/fetches.
  // This info could be folded into the ExecutorsAndKeys object, but we would
  // like to maintain a deletion order in which the OpKernels (owned by the
  // executor) should be destroyed first, followed by the resources in the
  // device, and then followed by the function library structures.
  // TODO(rohanj): Consolidate function library definitions so that we can
  // instantiate only one ProcFLR and lib_def and make this just a member
  // variable and not a vector.
  // 'flib_def' is the function library used.
  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
  // device.
  struct FunctionInfo {
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
  };

  // For each live partial execution, the session maintains a RunState.
  // 'status' is the current status of this partial execution. 'executors_done'
  // is "notified" when all executors are done. 'pending_inputs' are the set
  // of pending feeds and 'pending_outputs' are the set of pending fetches.
  struct RunState {
    mutex mu_;
    Status status GUARDED_BY(mu_);
    IntraProcessRendezvous* rendez = nullptr;
    std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
    std::unique_ptr<StepStatsCollector> collector;
    Notification executors_done;
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    TensorStore tensor_store;
    ScopedStepContainer step_container;

    RunState(int64 step_id, const std::vector<Device*>* devices);

    RunState(const std::vector<string>& pending_input_names,
             const std::vector<string>& pending_output_names, int64 step_id,
             const std::vector<Device*>* devices);

    // Returns true if all pending inputs and outputs have been completed.
    bool PendingDone() const;

    ~RunState();
  };

  struct RunStateArgs {
    RunStateArgs(const DebugOptions& options) : debug_options(options) {}

    bool is_partial_run = false;
    string handle;
    std::unique_ptr<Graph> graph;
    const DebugOptions& debug_options;
    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
  };

  // Initializes the base execution state given the 'graph',
  // if not already initialized.
  Status MaybeInitializeExecutionState(const GraphDef& graph,
                                       bool* out_already_initialized)
      EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);

  // Retrieves an already existing set of executors to run 'inputs' and
  // 'outputs', or creates and caches them for future use.
  ::tensorflow::Status GetOrCreateExecutors(
      gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
      gtl::ArraySlice<string> target_nodes,
      ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);

  // Creates a set of executors to run the subgraph defined by
  // `callable_options`.
  ::tensorflow::Status CreateExecutors(
      const CallableOptions& callable_options,
      std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
      std::unique_ptr<FunctionInfo>* out_func_info,
      RunStateArgs* run_state_args);

  // Creates several graphs from the existing graph_def_ and the input feeds
  // and fetches, partitioned across 'devices'. The graphs share a common
  // function library 'flib_def'.
  ::tensorflow::Status CreateGraphs(
      const BuildGraphOptions& options,
      std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
      std::unique_ptr<FunctionLibraryDefinition>* flib_def,
      RunStateArgs* run_state_args, DataTypeVector* input_types,
      DataTypeVector* output_types, int64* collective_graph_key);

  ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
                                   CallFrameInterface* call_frame,
                                   ExecutorsAndKeys* executors_and_keys,
                                   RunMetadata* run_metadata);

  // Returns true if inter-op execution uses a global thread pool, or if the
  // given `run_options` requests running on inter_op_thread_pool = 0 when
  // multiple pools are configured.
  bool ShouldUseRunHandlerPool(const RunOptions& run_options) const;

  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
      EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);

  ::tensorflow::Status ResourceHandleToInputTensor(
      const Tensor& resource_tensor, Tensor* retrieved_tensor);

  // Feeds more inputs to the executors, triggering further execution.
  ::tensorflow::Status SendPRunInputs(
      const std::vector<std::pair<string, Tensor>>& inputs,
      const ExecutorsAndKeys* executors_and_keys,
      IntraProcessRendezvous* rendez);

  // Fetches more outputs from the executors. It waits until the output
  // tensors are computed.
  ::tensorflow::Status RecvPRunOutputs(
      const std::vector<string>& output_names,
      const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
      std::vector<Tensor>* outputs);

  // Checks whether the specified fetches can be computed from the feeds
  // that we have already provided.
  ::tensorflow::Status CheckFetch(
      const std::vector<std::pair<string, Tensor>>& feeds,
      const std::vector<string>& fetches,
      const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);

  // Use the appropriate WaitForNotification function based on whether
  // operation_timeout_in_ms is greater than 0.
  //
  // If the timeout expires, `cm->StartCancel()` will be called.
  ::tensorflow::Status WaitForNotification(Notification* n,
                                           int64 timeout_in_ms);
  void WaitForNotification(RunState* run_state, CancellationManager* cm,
                           int64 timeout_in_ms);

  ::tensorflow::Status CheckNotClosed() {
    mutex_lock l(closed_lock_);
    if (closed_) return errors::Cancelled("Session has been closed.");
    return ::tensorflow::Status::OK();
  }

  ::tensorflow::Status CheckGraphCreated(const char* method) {
    mutex_lock l(graph_state_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before ", method, "!");
    }
    return ::tensorflow::Status::OK();
  }

  ::tensorflow::Status CreateDebuggerState(
      const CallableOptions& options, int64 global_step,
      int64 session_run_index, int64 executor_step_index,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  ::tensorflow::Status DecorateAndPublishGraphForDebug(
      const DebugOptions& debug_options, Graph* graph, Device* device);

  const SessionOptions options_;

  // Device structures.
  const std::unique_ptr<const DeviceMgr> device_mgr_;
  std::vector<Device*> devices_;  // not owned
  DeviceSet device_set_;

  // Unique session identifier.
  string session_handle_;
  mutex graph_state_lock_;
  bool graph_created_ GUARDED_BY(graph_state_lock_) = false;

  // The thread-pools to use for running ops, with a bool indicating if the
  // pool is owned.
  std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;

  Status init_error_;  // Set to an error if construction failed.

  // If true, blocks until the device has finished all queued operations in a
  // step.
  bool sync_on_finish_ = true;

  // Schedules 'c' for execution on 'pool'.
  void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);

  std::vector<std::unique_ptr<FunctionInfo>> functions_
      GUARDED_BY(executor_lock_);

  mutex executor_lock_;  // protects executors_
  // Holds mappings from signature to the executors that process
  // it. The reason for a level of indirection around mapped_type is
  // to guarantee address stability.
  // The map value is a shared_ptr since multiple map keys can point to the
  // same ExecutorsAndKeys object.
  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
      GUARDED_BY(executor_lock_);

  class RunCallableCallFrame;
  struct Callable {
    std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
    std::shared_ptr<FunctionInfo> function_info;
    ~Callable();
  };
  mutex callables_lock_;
  int64 next_callable_handle_ GUARDED_BY(callables_lock_) = 0;
  std::unordered_map<int64, Callable> callables_ GUARDED_BY(callables_lock_);

  // Holds mappings from handle to partial run state.
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      GUARDED_BY(executor_lock_);

  // This holds all the tensors that are currently alive in the session.
  SessionState session_state_;

  DirectSessionFactory* const factory_;  // not owned
  CancellationManager* cancellation_manager_;
  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;

  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes. Once placed, these
  // nodes cannot be moved to a different device. Maps node names to
  // device names.
  std::unordered_map<string, string> stateful_placements_
      GUARDED_BY(graph_state_lock_);

  // Execution state; used when placing the entire graph.
  std::unique_ptr<GraphExecutionState> execution_state_
      GUARDED_BY(graph_state_lock_);

  // The function library, before any rewrites or optimizations have been
  // performed. In particular, CreateGraphs() may need to modify the function
  // library, so it makes and modifies its own copy.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;

  // True if the Session has been Closed.
  mutex closed_lock_;
  bool closed_ GUARDED_BY(closed_lock_) = false;

  // For generating unique names for this session instance.
  std::atomic<int64> edge_name_counter_ = {0};
  std::atomic<int64> handle_name_counter_ = {0};

  // For generating step ids that are unique across all sessions.
  static std::atomic_int_fast64_t step_id_counter_;

  // Global timeout for all blocking operations in this session.
  const int64 operation_timeout_in_ms_ = 0;

  // Manages all the cost models for the graphs executed in this session.
  CostModelManager cost_model_manager_;

  // For testing collective graph key generation.
  mutex collective_graph_key_lock_;
  int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;

  TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);

  // EXPERIMENTAL: debugger (tfdbg) related
  friend class DebugGateway;
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_