• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
18 
19 #include <atomic>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/graph_execution_state.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/distributed_runtime/call_options.h"
27 #include "tensorflow/core/distributed_runtime/master_env.h"
28 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/protobuf/master.pb.h"
33 #include "tensorflow/core/public/session_options.h"
34 
35 namespace tensorflow {
36 
37 class Device;
38 struct MasterEnv;
39 
40 // A session encapsulates a graph computation (resource allocation,
41 // placement, execution, etc.).
42 class MasterSession : public core::RefCounted {
43  public:
44   // This session encapsulates the graph computation for a graph.
45   //
46   // The session places nodes on devices in "remote_devs" and executes
47   // operations on these devices.
48   //
49   // The caller takes ownership of all remote devices.
50   MasterSession(
51       const SessionOptions& options, const MasterEnv* env,
52       std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
53       std::unique_ptr<WorkerCacheInterface> worker_cache,
54       std::unique_ptr<DeviceSet> device_set,
55       std::vector<string> filtered_worker_list,
56       StatsPublisherFactory stats_publisher_factory);
57 
58   // Initialize the MasterSession for "def".  Must be called before Extend(),
59   // Run(), or Close().
60   Status Create(GraphDef&& def, const WorkerCacheFactoryOptions& options);
61 
62   // Returns the session handle.
handle()63   const string& handle() const { return handle_; }
64 
65   // Returns the last access time (the number of micro-seconds since
66   // some fixed point in time) of this session.
last_access_time_usec()67   uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }
68 
69   // Attempt to extend the graph according to the given "req".
70   // (See master.proto for details of valid extensions.)
71   //
72   // PRECONDITION: The current version of this session's graph
73   //   is "req->current_graph_version".
74   //
75   // POSTCONDITION: The current version of this session's graph
76   //   is "resp->new_graph_version".
77   //
78   // Extend() may block the caller thread for a long time.
79   Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);
80 
81   // Setup a partial run call.
82   Status PartialRunSetup(const PartialRunSetupRequest* req,
83                          PartialRunSetupResponse* resp);
84 
85   // Run one step.
86   Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
87              MutableRunStepResponseWrapper* resp);
88 
89   Status ListDevices(ListDevicesResponse* resp) const;
90 
91   Status MakeCallable(const MakeCallableRequest& req,
92                       MakeCallableResponse* resp);
93 
94   Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
95                      RunCallableResponse* resp);
96 
97   Status ReleaseCallable(const ReleaseCallableRequest& req,
98                          ReleaseCallableResponse* resp);
99 
100   // Close this session and delete "*this". Returns OK if all known
101   // states are cleanup successfully.
102   //
103   // Close() may block the caller thread for a long time.
104   Status Close();
105 
106   // Close this session and release a reference on "*this".
107   //
108   // Note that, unlike Close(), this method does not block on the
109   // completion of all work.
110   void GarbageCollect();
111 
112  private:
113   SessionOptions session_opts_;
114 
115   // Not owned.
116   const MasterEnv* env_;
117 
118   // The opaque session handle.
119   const string handle_;
120 
121   std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
122 
123   // The optional session-specific worker cluster.
124   // TODO(saeta): Convert to std::optional when available.
125   const std::unique_ptr<WorkerCacheInterface> worker_cache_;
126   // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
127   WorkerCacheInterface* get_worker_cache() const;
128 
129   // The device set used by this session.
130   std::unique_ptr<DeviceSet> devices_;
131 
132   // The (partial device) names of remote worker tasks that this
133   // session will contact.
134   const std::vector<string> filtered_worker_list_;
135 
136   StatsPublisherFactory stats_publisher_factory_;
137 
138   std::atomic_ulong last_access_time_usec_;
139 
140   std::atomic<int64> partial_run_handle_counter_ = {0};
141 
142   uint64 NewStepId(int64 graph_key);
143 
144   mutex mu_;
145   std::unique_ptr<GraphExecutionState> execution_state_ TF_GUARDED_BY(mu_);
146   int64 graph_version_;
147 
148   // We keep a map from a signature of a run request to the
149   // ReffedClientGraph the can execute it.  We keep up to one old copy
150   // of each ReffedClientGraph around because if it gets deallocated
151   // before a new substitute has been created, Variables can go out of
152   // scope and lose their state.
153   class ReffedClientGraph;
154   typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
155   RCGMap run_graphs_ TF_GUARDED_BY(mu_);
156   RCGMap partial_run_graphs_ TF_GUARDED_BY(mu_);
157   int64 next_callable_handle_ TF_GUARDED_BY(mu_) = 0;
158   RCGMap callables_ TF_GUARDED_BY(mu_);
159 
160   struct PerStepState {
161     bool collect_costs = false;
162     bool collect_timeline = false;
163     bool collect_rpcs = false;
164     bool collect_partition_graphs = false;
165     bool report_tensor_allocations_upon_oom = false;
166     Microseconds start_micros = Microseconds(0);
167     Microseconds end_micros = Microseconds(0);
168     std::vector<StepStats> step_stats;  // per partition
169     StepStats rpc_stats;                // for RPC layer
170     CostGraphDef cost_graph;
171   };
172 
173   struct RunState {
174     std::unordered_map<string, bool> pending_inputs;   // true if fed
175     std::unordered_map<string, bool> pending_outputs;  // true if fetched
176     ReffedClientGraph* rcg = nullptr;
177     uint64 step_id;
178     int64 collective_graph_key;
179     int64 count = 0;
180     PerStepState pss;
181     std::unique_ptr<ProfileHandler> ph;
182     bool step_started = false;
183 
184     RunState(const std::vector<string>& input_names,
185              const std::vector<string>& output_names, ReffedClientGraph* rcg,
186              const uint64 step_id, const int64 count);
187 
188     bool PendingDone() const;
189 
190     ~RunState();
191   };
192   std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
193       TF_GUARDED_BY(mu_);
194 
195   // Active RunStep calls.
196   condition_variable num_running_is_zero_;
197   int32 num_running_ TF_GUARDED_BY(mu_) = 0;
198 
199   bool closed_ TF_GUARDED_BY(mu_) = false;
200   bool garbage_collected_ TF_GUARDED_BY(mu_) = false;
201 
202   std::unordered_map<uint64, int64> subgraph_execution_counts_
203       TF_GUARDED_BY(mu_);
204 
205   // We need to ensure that certain nodes added (e.g., send and recv
206   // nodes) are unique across all sub-graphs within this session.
207   int64 next_node_id_ TF_GUARDED_BY(mu_) = 0;
208 
209   // Used to cancel running steps on Close().
210   CancellationManager cancellation_manager_;
211 
212   // Private dtor. The client must call Close().
213   virtual ~MasterSession();
214 
215   // Creates sessions on all workers.
216   //
217   // If this session is operating using the new ClusterSpec propagation behavior
218   // call this method in order to propagate the cluster membership to all
219   // workers.
220   Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
221 
222   bool should_delete_worker_sessions_ = false;
223   Status DeleteWorkerSessions();
224 
225   Status StartStep(const BuildGraphOptions& opts, bool is_partial,
226                    ReffedClientGraph** out_rcg, int64* out_count);
227   void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
228                       RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
229   void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
230                         const RunOptions& run_options, uint64 step_id,
231                         int64 count, PerStepState* out_pss,
232                         std::unique_ptr<ProfileHandler>* out_ph);
233   Status DoRunWithLocalExecution(CallOptions* opts,
234                                  const RunStepRequestWrapper& req,
235                                  MutableRunStepResponseWrapper* resp);
236   Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
237                       MutableRunStepResponseWrapper* resp);
238   Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
239                        const RunCallableRequest& req,
240                        RunCallableResponse* resp);
241   Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
242                         const RunOptions& run_options, PerStepState* pss,
243                         const std::unique_ptr<ProfileHandler>& ph,
244                         const Status& run_status,
245                         RunMetadata* out_run_metadata);
246 
247   void MarkRunCompletion();
248   void UpdateLastAccessTime();
249 
250   Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
251 
252   Status CreateDebuggerState(
253       const DebugOptions& debug_options, const RunStepRequestWrapper& req,
254       int64 rcg_execution_count,
255       std::unique_ptr<DebuggerStateInterface>* debugger_state);
256 
257   TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
258 };
259 
260 }  // end namespace tensorflow
261 
262 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
263