• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
18 
19 #include <atomic>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/graph_execution_state.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/distributed_runtime/call_options.h"
27 #include "tensorflow/core/distributed_runtime/master_env.h"
28 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/protobuf/master.pb.h"
33 #include "tensorflow/core/public/session_options.h"
34 
35 namespace tensorflow {
36 
37 class Device;
38 struct MasterEnv;
39 
40 // A session encapsulates a graph computation (resource allocation,
41 // placement, execution, etc.).
42 class MasterSession : public core::RefCounted {
43  public:
44   // This session encapsulates the graph computation for a graph.
45   //
46   // The session places nodes on devices in "remote_devs" and executes
47   // operations on these devices.
48   //
49   // The caller takes ownership of all remote devices.
50   MasterSession(
51       const SessionOptions& options, const MasterEnv* env,
52       std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
53       std::unique_ptr<WorkerCacheInterface> worker_cache,
54       std::unique_ptr<DeviceSet> device_set,
55       std::vector<string> filtered_worker_list,
56       StatsPublisherFactory stats_publisher_factory);
57 
58   // Initialize the MasterSession for "def".  Must be called before Extend(),
59   // Run(), or Close().
60   //
61   // After this method returns, `def` will no longer be valid.
62   Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
63 
64   // Returns the session handle.
handle()65   const string& handle() const { return handle_; }
66 
67   // Returns the last access time (the number of micro-seconds since
68   // some fixed point in time) of this session.
last_access_time_usec()69   uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }
70 
71   // Attempt to extend the graph according to the given "req".
72   // (See master.proto for details of valid extensions.)
73   //
74   // PRECONDITION: The current version of this session's graph
75   //   is "req->current_graph_version".
76   //
77   // POSTCONDITION: The current version of this session's graph
78   //   is "resp->new_graph_version".
79   //
80   // Extend() may block the caller thread for a long time.
81   Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
82 
83   // Setup a partial run call.
84   Status PartialRunSetup(const PartialRunSetupRequest* req,
85                          PartialRunSetupResponse* resp);
86 
87   // Run one step.
88   Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
89              MutableRunStepResponseWrapper* resp);
90 
91   Status ListDevices(ListDevicesResponse* resp) const;
92 
93   Status MakeCallable(const MakeCallableRequest& req,
94                       MakeCallableResponse* resp);
95 
96   Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
97                      RunCallableResponse* resp);
98 
99   Status ReleaseCallable(const ReleaseCallableRequest& req,
100                          ReleaseCallableResponse* resp);
101 
102   // Close this session and delete "*this". Returns OK if all known
103   // states are cleanup successfully.
104   //
105   // Close() may block the caller thread for a long time.
106   Status Close();
107 
108   // Close this session and release a reference on "*this".
109   //
110   // Note that, unlike Close(), this method does not block on the
111   // completion of all work.
112   void GarbageCollect();
113 
114  private:
115   SessionOptions session_opts_;
116 
117   // Not owned.
118   const MasterEnv* env_;
119 
120   // The opaque session handle.
121   const string handle_;
122 
123   std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
124 
125   // The optional session-specific worker cluster.
126   // TODO(saeta): Convert to std::optional when available.
127   const std::unique_ptr<WorkerCacheInterface> worker_cache_;
128   // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
129   WorkerCacheInterface* get_worker_cache() const;
130 
131   // The device set used by this session.
132   std::unique_ptr<DeviceSet> devices_;
133 
134   // The (partial device) names of remote worker tasks that this
135   // session will contact.
136   const std::vector<string> filtered_worker_list_;
137 
138   StatsPublisherFactory stats_publisher_factory_;
139 
140   std::atomic_ulong last_access_time_usec_;
141 
142   std::atomic<int64> partial_run_handle_counter_ = {0};
143 
144   uint64 NewStepId(int64 graph_key);
145 
146   mutex mu_;
147   std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
148   int64 graph_version_;
149 
150   // We keep a map from a signature of a run request to the
151   // ReffedClientGraph the can execute it.  We keep up to one old copy
152   // of each ReffedClientGraph around because if it gets deallocated
153   // before a new substitute has been created, Variables can go out of
154   // scope and lose their state.
155   class ReffedClientGraph;
156   typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
157   RCGMap run_graphs_ GUARDED_BY(mu_);
158   RCGMap partial_run_graphs_ GUARDED_BY(mu_);
159   int64 next_callable_handle_ GUARDED_BY(mu_) = 0;
160   RCGMap callables_ GUARDED_BY(mu_);
161 
162   struct PerStepState {
163     bool collect_costs = false;
164     bool collect_timeline = false;
165     bool collect_rpcs = false;
166     bool collect_partition_graphs = false;
167     bool report_tensor_allocations_upon_oom = false;
168     Microseconds start_micros = Microseconds(0);
169     Microseconds end_micros = Microseconds(0);
170     std::vector<StepStats> step_stats;  // per partition
171     StepStats rpc_stats;                // for RPC layer
172     CostGraphDef cost_graph;
173   };
174 
175   struct RunState {
176     std::unordered_map<string, bool> pending_inputs;   // true if fed
177     std::unordered_map<string, bool> pending_outputs;  // true if fetched
178     ReffedClientGraph* rcg = nullptr;
179     uint64 step_id;
180     int64 collective_graph_key;
181     int64 count = 0;
182     PerStepState pss;
183     std::unique_ptr<ProfileHandler> ph;
184     bool step_started = false;
185 
186     RunState(const std::vector<string>& input_names,
187              const std::vector<string>& output_names, ReffedClientGraph* rcg,
188              const uint64 step_id, const int64 count);
189 
190     bool PendingDone() const;
191 
192     ~RunState();
193   };
194   std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
195       GUARDED_BY(mu_);
196 
197   // Active RunStep calls.
198   condition_variable num_running_is_zero_;
199   int32 num_running_ GUARDED_BY(mu_) = 0;
200 
201   bool closed_ GUARDED_BY(mu_) = false;
202   bool garbage_collected_ GUARDED_BY(mu_) = false;
203 
204   std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);
205 
206   // We need to ensure that certain nodes added (e.g., send and recv
207   // nodes) are unique across all sub-graphs within this session.
208   int64 next_node_id_ GUARDED_BY(mu_) = 0;
209 
210   // Used to cancel running steps on Close().
211   CancellationManager cancellation_manager_;
212 
213   // Private dtor. The client must call Close().
214   virtual ~MasterSession();
215 
216   // Creates sessions on all workers.
217   //
218   // If this session is operating using the new ClusterSpec propagation behavior
219   // call this method in order to propagate the cluster membership to all
220   // workers.
221   Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
222 
223   bool should_delete_worker_sessions_ = false;
224   Status DeleteWorkerSessions();
225 
226   Status StartStep(const BuildGraphOptions& opts, bool is_partial,
227                    ReffedClientGraph** out_rcg, int64* out_count);
228   void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
229                       RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
230   void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
231                         const RunOptions& run_options, uint64 step_id,
232                         int64 count, PerStepState* out_pss,
233                         std::unique_ptr<ProfileHandler>* out_ph);
234   Status DoRunWithLocalExecution(CallOptions* opts,
235                                  const RunStepRequestWrapper& req,
236                                  MutableRunStepResponseWrapper* resp);
237   Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
238                       MutableRunStepResponseWrapper* resp);
239   Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
240                        const RunCallableRequest& req,
241                        RunCallableResponse* resp);
242   Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
243                         const RunOptions& run_options, PerStepState* pss,
244                         const std::unique_ptr<ProfileHandler>& ph,
245                         const Status& run_status,
246                         RunMetadata* out_run_metadata);
247 
248   void MarkRunCompletion();
249   void UpdateLastAccessTime();
250 
251   Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
252 
253   Status CreateDebuggerState(
254       const DebugOptions& debug_options, const RunStepRequestWrapper& req,
255       int64 rcg_execution_count,
256       std::unique_ptr<DebuggerStateInterface>* debugger_state);
257 
258   TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
259 };
260 
261 }  // end namespace tensorflow
262 
263 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
264