1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ 18 19 #include <atomic> 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <vector> 25 26 #include "tensorflow/core/common_runtime/costmodel_manager.h" 27 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/common_runtime/device_set.h" 30 #include "tensorflow/core/common_runtime/executor.h" 31 #include "tensorflow/core/common_runtime/graph_execution_state.h" 32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 34 #include "tensorflow/core/common_runtime/session_factory.h" 35 #include "tensorflow/core/framework/cancellation.h" 36 #include "tensorflow/core/framework/collective.h" 37 #include "tensorflow/core/framework/graph.pb.h" 38 #include "tensorflow/core/framework/session_state.h" 39 #include "tensorflow/core/framework/tensor.h" 40 #include "tensorflow/core/lib/core/errors.h" 41 #include "tensorflow/core/lib/core/status.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/mutex.h" 44 #include "tensorflow/core/platform/thread_annotations.h" 45 #include "tensorflow/core/platform/types.h" 46 #include "tensorflow/core/public/session.h" 47 48 namespace tensorflow { 49 50 class CostModel; 51 class DebugGateway; 52 class Device; 53 class DirectSessionFactory; 54 55 class DirectSession : public Session { 56 public: 57 typedef std::function<void(Session*)> CloseCallback; 58 59 // Takes ownership of 'device_mgr'. 60 // 'factory' is used to unregister the DirectSession with 'factory' when its 61 // closed. This ensures that Reset requests from the 'factory' don't get sent 62 // to sessions that are already closed. 63 DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, 64 DirectSessionFactory* factory); 65 ~DirectSession() override; 66 67 typedef std::vector<std::pair<string, Tensor>> NamedTensorList; 68 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap; 69 70 ::tensorflow::Status Create(const GraphDef& graph) override; 71 ::tensorflow::Status Create(GraphDef&& graph) override; 72 ::tensorflow::Status Extend(const GraphDef& graph) override; 73 ::tensorflow::Status Extend(GraphDef&& graph) override; 74 ::tensorflow::Status Run(const NamedTensorList& inputs, 75 const std::vector<string>& output_names, 76 const std::vector<string>& target_nodes, 77 std::vector<Tensor>* outputs) override; 78 79 // NOTE: Experimental and subject to change. 80 ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options, 81 const NamedTensorList& inputs, 82 const std::vector<string>& output_names, 83 const std::vector<string>& target_nodes, 84 std::vector<Tensor>* outputs, 85 RunMetadata* run_metadata) override; 86 87 // NOTE: Experimental and subject to change. 88 ::tensorflow::Status Run( 89 const ::tensorflow::RunOptions& run_options, 90 const NamedTensorList& inputs, const std::vector<string>& output_names, 91 const std::vector<string>& target_nodes, std::vector<Tensor>* outputs, 92 RunMetadata* run_metadata, 93 const thread::ThreadPoolOptions& threadpool_options) override; 94 95 // NOTE: PRunSetup and PRun are added to support partial execution. This 96 // feature is experimental and subject to change. 97 ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, 98 const std::vector<string>& output_names, 99 const std::vector<string>& target_nodes, 100 string* handle) override; 101 ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs, 102 const std::vector<string>& output_names, 103 std::vector<Tensor>* outputs) override; 104 105 // Reset clears 'containers' from the device_mgr of the DirectSession. 106 // If 'containers' is empty, then Reset clears the default container. 107 ::tensorflow::Status Reset(const std::vector<string>& containers); 108 109 ::tensorflow::Status ListDevices( 110 std::vector<DeviceAttributes>* response) override; 111 ::tensorflow::Status Close() override; LocalDeviceManager(const DeviceMgr ** output)112 ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override { 113 *output = device_mgr_.get(); 114 return ::tensorflow::Status::OK(); 115 } 116 ExportCostModels(CostModelManager::CostModelMap * cost_models)117 void ExportCostModels(CostModelManager::CostModelMap* cost_models) { 118 cost_model_manager_.ExportCostModels(cost_models); 119 } 120 121 ::tensorflow::Status MakeCallable(const CallableOptions& callable_options, 122 CallableHandle* out_handle) override; 123 124 ::tensorflow::Status RunCallable(CallableHandle handle, 125 const std::vector<Tensor>& feed_tensors, 126 std::vector<Tensor>* fetch_tensors, 127 RunMetadata* run_metadata) override; 128 129 ::tensorflow::Status RunCallable( 130 CallableHandle handle, const std::vector<Tensor>& feed_tensors, 131 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata, 132 const thread::ThreadPoolOptions& threadpool_options) override; 133 134 ::tensorflow::Status ReleaseCallable(CallableHandle handle) override; 135 136 ::tensorflow::Status Finalize() override; 137 options()138 const SessionOptions& options() const { return options_; } 139 140 private: 141 // For access to collective_graph_key_. 142 friend class DirectSessionCollectiveTest; 143 144 // We create one executor and its dependent library runtime for 145 // every partition. 146 struct PerPartitionExecutorsAndLib { 147 std::unique_ptr<Graph> graph = nullptr; 148 Device* device = nullptr; // not owned. 149 FunctionLibraryRuntime* flib = nullptr; // not owned. 150 std::unique_ptr<Executor> executor; 151 }; 152 153 // An ExecutorsAndKeys is created for a given set of feeds/fetches. 154 // 'step_count' is the number of times this graph is executed. 155 // 'graph' is the entire graph being executed. 'name_to_node' 156 // maps node name to node. We keep 'graph' and 'name_to_node' only in 157 // the case of partial runs. Each item in 'items' is the executor for 158 // a partition of the graph bundled with its dependent library runtime. 159 // 'input_keys' are the rendezvous keys for the feeds and 'output_keys' 160 // are rendezvous keys for the fetches. 161 struct ExecutorsAndKeys { ExecutorsAndKeysExecutorsAndKeys162 ExecutorsAndKeys() : step_count(0) {} 163 164 std::atomic_int_fast64_t step_count; 165 std::unique_ptr<Graph> graph; 166 NameNodeMap name_to_node; 167 std::vector<PerPartitionExecutorsAndLib> items; 168 std::unordered_map<string, size_t> input_name_to_index; 169 std::unordered_map<string, string> input_name_to_rendezvous_key; 170 std::unordered_map<string, size_t> output_name_to_index; 171 std::unordered_map<string, string> output_name_to_rendezvous_key; 172 173 DataTypeVector input_types; 174 DataTypeVector output_types; 175 176 CallableOptions callable_options; 177 178 int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; 179 }; 180 181 // A FunctionInfo object is created for every unique set of feeds/fetches. 182 // This info could be folded into the ExecutorsAndKeys object but we would 183 // like to maintain a deletion order in which the OpKernels (owned by the 184 // executor) should be destroyed first, followed by the resources in the 185 // device and then followed by the function stuff. 186 // TODO(rohanj): Consolidate function library definitions so that we can 187 // instantiate only one ProcFLR and lib_def and make this just a member 188 // variable and not a vector. 189 // 'flib_def' is the function library used. 190 // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per 191 // device. 192 struct FunctionInfo { 193 std::unique_ptr<FunctionLibraryDefinition> flib_def; 194 std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr; 195 }; 196 197 // For each live Run() call, the session maintains a RunState. 198 // 'status' is the current status of the execution. 199 struct RunState { 200 mutex mu; 201 Status status TF_GUARDED_BY(mu); 202 std::unique_ptr<CollectiveExecutor::Handle> collective_executor; 203 std::unique_ptr<StepStatsCollector> collector; 204 TensorStore tensor_store; 205 ScopedStepContainer step_container; 206 207 RunState(int64 step_id, const std::vector<Device*>* devices); 208 }; 209 210 // For each live partial execution, the session maintains a PartialRunState. 211 // 'executor_done' is "notified" when all executors are done. 'pending_inputs' 212 // are the set of pending feeds and 'pending_outputs' are the set of pending 213 // fetches. 214 struct PartialRunState : public RunState { 215 Notification executors_done; 216 std::unordered_map<string, bool> pending_inputs; // true if fed 217 std::unordered_map<string, bool> pending_outputs; // true if fetched 218 core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr; 219 220 PartialRunState(const std::vector<string>& pending_input_names, 221 const std::vector<string>& pending_output_names, 222 int64 step_id, const std::vector<Device*>* devices); 223 224 // Returns true if all pending inputs and outputs have been completed. 225 bool PendingDone() const; 226 227 ~PartialRunState(); 228 }; 229 230 struct RunStateArgs { RunStateArgsRunStateArgs231 explicit RunStateArgs(const DebugOptions& options) 232 : debug_options(options) {} 233 234 bool is_partial_run = false; 235 string handle; 236 std::unique_ptr<Graph> graph; 237 const DebugOptions& debug_options; 238 int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; 239 }; 240 241 // Retrieves an already existing set of executors to run 'inputs' and 242 // 'outputs', or creates and caches them for future use. 243 ::tensorflow::Status GetOrCreateExecutors( 244 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, 245 gtl::ArraySlice<string> target_nodes, 246 ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); 247 248 // Creates a set of executors to run the subgraph defined by 249 // `callable_options`. 250 ::tensorflow::Status CreateExecutors( 251 const CallableOptions& callable_options, 252 std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys, 253 std::unique_ptr<FunctionInfo>* out_func_info, 254 RunStateArgs* run_state_args); 255 256 // Creates several graphs given the existing graph_def_ and the 257 // input feeds and fetches, given 'devices'. The graphs share a common 258 // function library 'flib_def'. 259 ::tensorflow::Status CreateGraphs( 260 const BuildGraphOptions& options, 261 std::unordered_map<string, std::unique_ptr<Graph>>* outputs, 262 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 263 RunStateArgs* run_state_args, DataTypeVector* input_types, 264 DataTypeVector* output_types, int64* collective_graph_key); 265 266 ::tensorflow::Status RunInternal( 267 int64 step_id, const RunOptions& run_options, 268 CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys, 269 RunMetadata* run_metadata, 270 const thread::ThreadPoolOptions& threadpool_options); 271 272 // Returns whether inter-op execution uses a global pool or the input 273 // `run_options` requests being run on inter_op_thread_pool = 0 in case 274 // multiple pools are configured. 275 bool ShouldUseRunHandlerPool(const RunOptions& run_options) const; 276 277 ::tensorflow::Status ExtendLocked(GraphDef graph) 278 TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); 279 280 ::tensorflow::Status ResourceHandleToInputTensor( 281 const Tensor& resource_tensor, Tensor* retrieved_tensor); 282 283 // Feeds more inputs to the executors, triggering further execution. 284 ::tensorflow::Status SendPRunInputs( 285 const std::vector<std::pair<string, Tensor>>& inputs, 286 const ExecutorsAndKeys* executors_and_keys, 287 IntraProcessRendezvous* rendez); 288 289 // Fetches more outputs from the executors. It waits until the output 290 // tensors are computed. 291 ::tensorflow::Status RecvPRunOutputs( 292 const std::vector<string>& output_names, 293 const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, 294 std::vector<Tensor>* outputs); 295 296 // Check if the specified fetches can be computed from the feeds 297 // that we have already provided. 298 ::tensorflow::Status CheckFetch( 299 const std::vector<std::pair<string, Tensor>>& feeds, 300 const std::vector<string>& fetches, 301 const ExecutorsAndKeys* executors_and_keys, 302 const PartialRunState* run_state); 303 304 // Use the appropriate WaitForNotification function based on whether 305 // operation_timeout_in_ms is greater than 0. 306 // 307 // If the timeout expires, the `cm->StartCancel()` will be called. 308 ::tensorflow::Status WaitForNotification(Notification* n, 309 int64 timeout_in_ms); 310 void WaitForNotification(Notification* n, RunState* run_state, 311 CancellationManager* cm, int64 timeout_in_ms); 312 CheckNotClosed()313 ::tensorflow::Status CheckNotClosed() { 314 mutex_lock l(closed_lock_); 315 if (closed_) return errors::Cancelled("Session has been closed."); 316 return ::tensorflow::Status::OK(); 317 } 318 CheckGraphCreated(const char * method)319 ::tensorflow::Status CheckGraphCreated(const char* method) { 320 mutex_lock l(graph_state_lock_); 321 if (!graph_created_) { 322 return errors::InvalidArgument( 323 "Session was not created with a graph before ", method, "!"); 324 } 325 return ::tensorflow::Status::OK(); 326 } 327 328 ::tensorflow::Status CreateDebuggerState( 329 const CallableOptions& options, int64 global_step, 330 int64 session_run_index, int64 executor_step_index, 331 std::unique_ptr<DebuggerStateInterface>* debugger_state); 332 333 ::tensorflow::Status DecorateAndPublishGraphForDebug( 334 const DebugOptions& debug_options, Graph* graph, Device* device); 335 336 const SessionOptions options_; 337 338 // Device structures. 339 const std::unique_ptr<const DeviceMgr> device_mgr_; 340 std::vector<Device*> devices_; // not owned 341 DeviceSet device_set_; 342 343 // Unique session identifier. 344 string session_handle_; 345 mutex graph_state_lock_; 346 bool graph_created_ TF_GUARDED_BY(graph_state_lock_) = false; 347 bool finalized_ TF_GUARDED_BY(graph_state_lock_) = false; 348 349 // The thread-pools to use for running ops, with a bool indicating if the pool 350 // is owned. 351 std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_; 352 353 Status init_error_; // Set to an error if construction failed. 354 355 // If true, blocks until device has finished all queued operations in a step. 356 bool sync_on_finish_ = true; 357 358 std::vector<std::unique_ptr<FunctionInfo>> functions_ 359 TF_GUARDED_BY(executor_lock_); 360 361 mutex executor_lock_; // protects executors_ 362 // Holds mappings from signature to the executors that process 363 // it. The reason for a level of indirection around mapped_type is 364 // to guarantee address stability. 365 // The map value is a shared_ptr since multiple map keys can point to the 366 // same ExecutorsAndKey object. 367 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ 368 TF_GUARDED_BY(executor_lock_); 369 370 class RunCallableCallFrame; 371 struct Callable { 372 std::shared_ptr<ExecutorsAndKeys> executors_and_keys; 373 std::shared_ptr<FunctionInfo> function_info; 374 ~Callable(); 375 }; 376 mutex callables_lock_; 377 int64 next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0; 378 std::unordered_map<int64, Callable> callables_ TF_GUARDED_BY(callables_lock_); 379 380 // Holds mappings from handle to partial run state. 381 std::unordered_map<string, std::unique_ptr<PartialRunState>> partial_runs_ 382 TF_GUARDED_BY(executor_lock_); 383 384 // This holds all the tensors that are currently alive in the session. 385 SessionState session_state_; 386 387 DirectSessionFactory* const factory_; // not owned 388 CancellationManager* cancellation_manager_; 389 std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_; 390 391 // Map of placed stateful nodes, i.e. nodes for which is_stateful() 392 // is true, such as "params" and "queue" nodes. Once placed these 393 // nodes can not be moved to a different device. Maps node names to 394 // device names. 395 std::unordered_map<string, string> stateful_placements_ 396 TF_GUARDED_BY(graph_state_lock_); 397 398 // Execution_state; used when placing the entire graph. 399 std::unique_ptr<GraphExecutionState> execution_state_ 400 TF_GUARDED_BY(graph_state_lock_); 401 402 // The function library, before any rewrites or optimizations have been 403 // performed. In particular, CreateGraphs() may need to modify the function 404 // library; it copies and modifies the function library. 405 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 406 407 // true if the Session has been Closed. 408 mutex closed_lock_; 409 bool closed_ TF_GUARDED_BY(closed_lock_) = false; 410 411 // For generating unique names for this session instance. 412 std::atomic<int64> edge_name_counter_ = {0}; 413 std::atomic<int64> handle_name_counter_ = {0}; 414 415 // For generating step ids that are unique among all sessions. 416 static std::atomic_int_fast64_t step_id_counter_; 417 418 // Global timeout for all blocking operations in this session. 419 const int64 operation_timeout_in_ms_ = 0; 420 421 // Manages all the cost models for the graphs executed in this session. 422 CostModelManager cost_model_manager_; 423 424 // For testing collective graph key generation. 425 mutex collective_graph_key_lock_; 426 int64 collective_graph_key_ TF_GUARDED_BY(collective_graph_key_lock_) = -1; 427 428 // Run in caller's thread if RunOptions.inter_op_thread_pool is negative or 429 // all of following conditions are met: 430 // 1. This session doesn't own any thread pool. 431 // 2. RunOptions.inter_op_thread_pool is unspecified or 0. 432 // 3. This session has a single executor. 433 // 4. config.inter_op_parallelism_threads is specified to negative explicitly 434 // or through environment variable TF_NUM_INTEROP_THREADS. 435 // 5. RunOptions.experimental.use_run_handler_pool is unspecified or false. 436 // Otherwise run in global thread pool, session owned thread pool or handler 437 // pool according to other specifications of RunOptions and ConfigProto. 438 bool run_in_caller_thread_ = false; 439 440 TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); 441 442 // EXPERIMENTAL: debugger (tfdbg) related 443 friend class DebugGateway; 444 }; 445 446 } // end namespace tensorflow 447 448 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ 449