/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_

#include <atomic>
#include <vector>

#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class Device;
struct MasterEnv;

// A session encapsulates a graph computation (resource allocation,
// placement, execution, etc.).
class MasterSession : public core::RefCounted {
 public:
  // This session encapsulates the computation for the given graph.
  //
  // The session places nodes on devices in "remote_devs" and executes
  // operations on these devices.
  //
  // The session takes ownership of all remote devices.
  MasterSession(
      const SessionOptions& options, const MasterEnv* env,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      std::unique_ptr<DeviceSet> device_set,
      std::vector<string> filtered_worker_list,
      StatsPublisherFactory stats_publisher_factory);

  // Initializes the MasterSession for "def". Must be called before Extend(),
  // Run(), or Close().
  Status Create(GraphDef&& def, const WorkerCacheFactoryOptions& options);

  // Returns the session handle.
  const string& handle() const { return handle_; }

  // Returns the last access time (the number of microseconds since
  // some fixed point in time) of this session.
  uint64 last_access_time_usec() const {
    return last_access_time_usec_.load();
  }

  // Attempts to extend the graph according to the given "req".
  // (See master.proto for details of valid extensions.)
  //
  // PRECONDITION: The current version of this session's graph
  // is "req->current_graph_version".
  //
  // POSTCONDITION: The current version of this session's graph
  // is "resp->new_graph_version".
  //
  // Extend() may block the caller thread for a long time.
  Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
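  // A minimal usage sketch (illustrative only; `opts`, `env`, `remote_devs`,
  // `graph_def`, and the other lowercase names below are assumed to be set up
  // by the caller and are not part of this header): a master constructs a
  // session, initializes it with Create(), serves Run() calls, and finally
  // calls Close().
  //
  //   MasterSession* session = new MasterSession(
  //       opts, env, std::move(remote_devs), std::move(worker_cache),
  //       std::move(device_set), filtered_workers, stats_factory);
  //   TF_RETURN_IF_ERROR(session->Create(std::move(graph_def), cache_opts));
  //   TF_RETURN_IF_ERROR(session->Run(&call_opts, run_req, &run_resp));
  //   TF_RETURN_IF_ERROR(session->Close());  // See Close(): deletes "*this".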
  // Set up a partial run call.
  Status PartialRunSetup(const PartialRunSetupRequest* req,
                         PartialRunSetupResponse* resp);

  // Run one step.
  Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
             MutableRunStepResponseWrapper* resp);

  Status ListDevices(ListDevicesResponse* resp) const;

  Status MakeCallable(const MakeCallableRequest& req,
                      MakeCallableResponse* resp);

  Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
                     RunCallableResponse* resp);

  Status ReleaseCallable(const ReleaseCallableRequest& req,
                         ReleaseCallableResponse* resp);

  // Close this session and delete "*this". Returns OK if all known
  // state is cleaned up successfully.
  //
  // Close() may block the caller thread for a long time.
  Status Close();

  // Close this session and release a reference on "*this".
  //
  // Note that, unlike Close(), this method does not block on the
  // completion of all work.
  void GarbageCollect();
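  // A hedged sketch of the callable workflow (the client-side variable names
  // here are illustrative, not part of this API): register a subgraph once
  // with MakeCallable(), run it repeatedly by handle with RunCallable(), and
  // release it with ReleaseCallable() when done.
  //
  //   MakeCallableResponse make_resp;
  //   TF_RETURN_IF_ERROR(session->MakeCallable(make_req, &make_resp));
  //   for (int step = 0; step < num_steps; ++step) {
  //     RunCallableRequest run_req;
  //     run_req.set_handle(make_resp.handle());
  //     RunCallableResponse run_resp;
  //     TF_RETURN_IF_ERROR(
  //         session->RunCallable(&call_opts, run_req, &run_resp));
  //   }
  //   ReleaseCallableRequest release_req;
  //   release_req.set_handle(make_resp.handle());
  //   ReleaseCallableResponse release_resp;
  //   TF_RETURN_IF_ERROR(session->ReleaseCallable(release_req, &release_resp));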
 private:
  SessionOptions session_opts_;

  // Not owned.
  const MasterEnv* env_;

  // The opaque session handle.
  const string handle_;

  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;

  // The optional session-specific worker cluster.
  // TODO(saeta): Convert to std::optional when available.
  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
  // Returns either worker_cache_ or env_->worker_cache, as appropriate.
  WorkerCacheInterface* get_worker_cache() const;

  // The device set used by this session.
  std::unique_ptr<DeviceSet> devices_;

  // The (partial device) names of remote worker tasks that this
  // session will contact.
  const std::vector<string> filtered_worker_list_;

  StatsPublisherFactory stats_publisher_factory_;

  std::atomic_ulong last_access_time_usec_;

  std::atomic<int64> partial_run_handle_counter_ = {0};

  uint64 NewStepId(int64 graph_key);

  mutex mu_;
  std::unique_ptr<GraphExecutionState> execution_state_ TF_GUARDED_BY(mu_);
  int64 graph_version_;

  // We keep a map from a signature of a run request to the
  // ReffedClientGraph that can execute it. We keep up to one old copy of
  // each ReffedClientGraph around, because if it gets deallocated before a
  // new substitute has been created, Variables can go out of scope and lose
  // their state.
  class ReffedClientGraph;
  typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
  RCGMap run_graphs_ TF_GUARDED_BY(mu_);
  RCGMap partial_run_graphs_ TF_GUARDED_BY(mu_);
  int64 next_callable_handle_ TF_GUARDED_BY(mu_) = 0;
  RCGMap callables_ TF_GUARDED_BY(mu_);

  struct PerStepState {
    bool collect_costs = false;
    bool collect_timeline = false;
    bool collect_rpcs = false;
    bool collect_partition_graphs = false;
    bool report_tensor_allocations_upon_oom = false;
    Microseconds start_micros = Microseconds(0);
    Microseconds end_micros = Microseconds(0);
    std::vector<StepStats> step_stats;  // per partition
    StepStats rpc_stats;                // for RPC layer
    CostGraphDef cost_graph;
  };

  struct RunState {
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    ReffedClientGraph* rcg = nullptr;
    uint64 step_id;
    int64 collective_graph_key;
    int64 count = 0;
    PerStepState pss;
    std::unique_ptr<ProfileHandler> ph;
    bool step_started = false;

    RunState(const std::vector<string>& input_names,
             const std::vector<string>& output_names, ReffedClientGraph* rcg,
             const uint64 step_id, const int64 count);

    bool PendingDone() const;

    ~RunState();
  };
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      TF_GUARDED_BY(mu_);
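  // A hedged sketch (assumed flow, not the literal implementation) of how a
  // partial run uses this map; `partial_run_handle`, `fed_name`, and
  // `fetched_name` below are illustrative. PartialRunSetup() registers a
  // RunState under a freshly generated handle with every promised feed and
  // fetch initially pending; each Run() carrying that handle marks what it
  // actually fed and fetched, and the entry is finished and erased once
  // PendingDone() reports nothing is left pending.
  //
  //   partial_runs_.emplace(
  //       partial_run_handle,
  //       std::unique_ptr<RunState>(new RunState(input_names, output_names,
  //                                              rcg, step_id, count)));
  //   ...
  //   RunState* run_state = partial_runs_[partial_run_handle].get();
  //   run_state->pending_inputs[fed_name] = true;       // marked when fed
  //   run_state->pending_outputs[fetched_name] = true;  // marked when fetched
  //   if (run_state->PendingDone()) {
  //     // All promised feeds/fetches have been seen; finish the step and
  //     // erase the entry from partial_runs_.
  //   }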
  // Active RunStep calls.
  condition_variable num_running_is_zero_;
  int32 num_running_ TF_GUARDED_BY(mu_) = 0;

  bool closed_ TF_GUARDED_BY(mu_) = false;
  bool garbage_collected_ TF_GUARDED_BY(mu_) = false;

  std::unordered_map<uint64, int64> subgraph_execution_counts_
      TF_GUARDED_BY(mu_);

  // We need to ensure that certain nodes added (e.g., send and recv
  // nodes) are unique across all sub-graphs within this session.
  int64 next_node_id_ TF_GUARDED_BY(mu_) = 0;

  // Used to cancel running steps on Close().
  CancellationManager cancellation_manager_;

  // Private dtor. The client must call Close().
  virtual ~MasterSession();

  // Creates sessions on all workers.
  //
  // If this session is operating using the new ClusterSpec propagation
  // behavior, call this method to propagate the cluster membership to all
  // workers.
  Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);

  bool should_delete_worker_sessions_ = false;
  Status DeleteWorkerSessions();

  Status StartStep(const BuildGraphOptions& opts, bool is_partial,
                   ReffedClientGraph** out_rcg, int64* out_count);
  void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                      RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                        const RunOptions& run_options, uint64 step_id,
                        int64 count, PerStepState* out_pss,
                        std::unique_ptr<ProfileHandler>* out_ph);
  Status DoRunWithLocalExecution(CallOptions* opts,
                                 const RunStepRequestWrapper& req,
                                 MutableRunStepResponseWrapper* resp);
  Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                      MutableRunStepResponseWrapper* resp);
  Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                       const RunCallableRequest& req,
                       RunCallableResponse* resp);
  Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
                        const RunOptions& run_options, PerStepState* pss,
                        const std::unique_ptr<ProfileHandler>& ph,
                        const Status& run_status,
                        RunMetadata* out_run_metadata);

  void MarkRunCompletion();
  void UpdateLastAccessTime();

  Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

  Status CreateDebuggerState(
      const DebugOptions& debug_options, const RunStepRequestWrapper& req,
      int64 rcg_execution_count,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_