1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 18 19 #include <atomic> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 23 #include "tensorflow/core/common_runtime/device_set.h" 24 #include "tensorflow/core/common_runtime/graph_execution_state.h" 25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h" 26 #include "tensorflow/core/distributed_runtime/call_options.h" 27 #include "tensorflow/core/distributed_runtime/master_env.h" 28 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 29 #include "tensorflow/core/distributed_runtime/worker_cache.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/protobuf/master.pb.h" 33 #include "tensorflow/core/public/session_options.h" 34 35 namespace tensorflow { 36 37 class Device; 38 struct MasterEnv; 39 40 // A session encapsulates a graph computation (resource allocation, 41 // placement, execution, etc.). 42 class MasterSession : public core::RefCounted { 43 public: 44 // This session encapsulates the graph computation for a graph. 45 // 46 // The session places nodes on devices in "remote_devs" and executes 47 // operations on these devices. 48 // 49 // The caller takes ownership of all remote devices. 50 MasterSession( 51 const SessionOptions& options, const MasterEnv* env, 52 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, 53 std::unique_ptr<WorkerCacheInterface> worker_cache, 54 std::unique_ptr<DeviceSet> device_set, 55 std::vector<string> filtered_worker_list, 56 StatsPublisherFactory stats_publisher_factory); 57 58 // Initialize the MasterSession for "def". Must be called before Extend(), 59 // Run(), or Close(). 60 // 61 // After this method returns, `def` will no longer be valid. 62 Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options); 63 64 // Returns the session handle. handle()65 const string& handle() const { return handle_; } 66 67 // Returns the last access time (the number of micro-seconds since 68 // some fixed point in time) of this session. last_access_time_usec()69 uint64 last_access_time_usec() const { return last_access_time_usec_.load(); } 70 71 // Attempt to extend the graph according to the given "req". 72 // (See master.proto for details of valid extensions.) 73 // 74 // PRECONDITION: The current version of this session's graph 75 // is "req->current_graph_version". 76 // 77 // POSTCONDITION: The current version of this session's graph 78 // is "resp->new_graph_version". 79 // 80 // Extend() may block the caller thread for a long time. 81 Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp); 82 83 // Setup a partial run call. 84 Status PartialRunSetup(const PartialRunSetupRequest* req, 85 PartialRunSetupResponse* resp); 86 87 // Run one step. 88 Status Run(CallOptions* opts, const RunStepRequestWrapper& req, 89 MutableRunStepResponseWrapper* resp); 90 91 Status ListDevices(ListDevicesResponse* resp) const; 92 93 Status MakeCallable(const MakeCallableRequest& req, 94 MakeCallableResponse* resp); 95 96 Status RunCallable(CallOptions* opts, const RunCallableRequest& req, 97 RunCallableResponse* resp); 98 99 Status ReleaseCallable(const ReleaseCallableRequest& req, 100 ReleaseCallableResponse* resp); 101 102 // Close this session and delete "*this". Returns OK if all known 103 // states are cleanup successfully. 104 // 105 // Close() may block the caller thread for a long time. 106 Status Close(); 107 108 // Close this session and release a reference on "*this". 109 // 110 // Note that, unlike Close(), this method does not block on the 111 // completion of all work. 112 void GarbageCollect(); 113 114 private: 115 SessionOptions session_opts_; 116 117 // Not owned. 118 const MasterEnv* env_; 119 120 // The opaque session handle. 121 const string handle_; 122 123 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_; 124 125 // The optional session-specific worker cluster. 126 // TODO(saeta): Convert to std::optional when available. 127 const std::unique_ptr<WorkerCacheInterface> worker_cache_; 128 // Retrieves either worker_cache_ or the env_->worker_cache as appropriate. 129 WorkerCacheInterface* get_worker_cache() const; 130 131 // The device set used by this session. 132 std::unique_ptr<DeviceSet> devices_; 133 134 // The (partial device) names of remote worker tasks that this 135 // session will contact. 136 const std::vector<string> filtered_worker_list_; 137 138 StatsPublisherFactory stats_publisher_factory_; 139 140 std::atomic_ulong last_access_time_usec_; 141 142 std::atomic<int64> partial_run_handle_counter_ = {0}; 143 144 uint64 NewStepId(int64 graph_key); 145 146 mutex mu_; 147 std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_); 148 int64 graph_version_; 149 150 // We keep a map from a signature of a run request to the 151 // ReffedClientGraph the can execute it. We keep up to one old copy 152 // of each ReffedClientGraph around because if it gets deallocated 153 // before a new substitute has been created, Variables can go out of 154 // scope and lose their state. 155 class ReffedClientGraph; 156 typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap; 157 RCGMap run_graphs_ GUARDED_BY(mu_); 158 RCGMap partial_run_graphs_ GUARDED_BY(mu_); 159 int64 next_callable_handle_ GUARDED_BY(mu_) = 0; 160 RCGMap callables_ GUARDED_BY(mu_); 161 162 struct PerStepState { 163 bool collect_costs = false; 164 bool collect_timeline = false; 165 bool collect_rpcs = false; 166 bool collect_partition_graphs = false; 167 bool report_tensor_allocations_upon_oom = false; 168 Microseconds start_micros = Microseconds(0); 169 Microseconds end_micros = Microseconds(0); 170 std::vector<StepStats> step_stats; // per partition 171 StepStats rpc_stats; // for RPC layer 172 CostGraphDef cost_graph; 173 }; 174 175 struct RunState { 176 std::unordered_map<string, bool> pending_inputs; // true if fed 177 std::unordered_map<string, bool> pending_outputs; // true if fetched 178 ReffedClientGraph* rcg = nullptr; 179 uint64 step_id; 180 int64 collective_graph_key; 181 int64 count = 0; 182 PerStepState pss; 183 std::unique_ptr<ProfileHandler> ph; 184 bool step_started = false; 185 186 RunState(const std::vector<string>& input_names, 187 const std::vector<string>& output_names, ReffedClientGraph* rcg, 188 const uint64 step_id, const int64 count); 189 190 bool PendingDone() const; 191 192 ~RunState(); 193 }; 194 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ 195 GUARDED_BY(mu_); 196 197 // Active RunStep calls. 198 condition_variable num_running_is_zero_; 199 int32 num_running_ GUARDED_BY(mu_) = 0; 200 201 bool closed_ GUARDED_BY(mu_) = false; 202 bool garbage_collected_ GUARDED_BY(mu_) = false; 203 204 std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_); 205 206 // We need to ensure that certain nodes added (e.g., send and recv 207 // nodes) are unique across all sub-graphs within this session. 208 int64 next_node_id_ GUARDED_BY(mu_) = 0; 209 210 // Used to cancel running steps on Close(). 211 CancellationManager cancellation_manager_; 212 213 // Private dtor. The client must call Close(). 214 virtual ~MasterSession(); 215 216 // Creates sessions on all workers. 217 // 218 // If this session is operating using the new ClusterSpec propagation behavior 219 // call this method in order to propagate the cluster membership to all 220 // workers. 221 Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); 222 223 bool should_delete_worker_sessions_ = false; 224 Status DeleteWorkerSessions(); 225 226 Status StartStep(const BuildGraphOptions& opts, bool is_partial, 227 ReffedClientGraph** out_rcg, int64* out_count); 228 void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, 229 RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_); 230 void FillPerStepState(MasterSession::ReffedClientGraph* rcg, 231 const RunOptions& run_options, uint64 step_id, 232 int64 count, PerStepState* out_pss, 233 std::unique_ptr<ProfileHandler>* out_ph); 234 Status DoRunWithLocalExecution(CallOptions* opts, 235 const RunStepRequestWrapper& req, 236 MutableRunStepResponseWrapper* resp); 237 Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, 238 MutableRunStepResponseWrapper* resp); 239 Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, 240 const RunCallableRequest& req, 241 RunCallableResponse* resp); 242 Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id, 243 const RunOptions& run_options, PerStepState* pss, 244 const std::unique_ptr<ProfileHandler>& ph, 245 const Status& run_status, 246 RunMetadata* out_run_metadata); 247 248 void MarkRunCompletion(); 249 void UpdateLastAccessTime(); 250 251 Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); 252 253 Status CreateDebuggerState( 254 const DebugOptions& debug_options, const RunStepRequestWrapper& req, 255 int64 rcg_execution_count, 256 std::unique_ptr<DebuggerStateInterface>* debugger_state); 257 258 TF_DISALLOW_COPY_AND_ASSIGN(MasterSession); 259 }; 260 261 } // end namespace tensorflow 262 263 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ 264