• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
18 
19 #include <atomic>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "tensorflow/core/common_runtime/costmodel_manager.h"
27 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/executor.h"
31 #include "tensorflow/core/common_runtime/graph_execution_state.h"
32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34 #include "tensorflow/core/common_runtime/session_factory.h"
35 #include "tensorflow/core/framework/cancellation.h"
36 #include "tensorflow/core/framework/collective.h"
37 #include "tensorflow/core/framework/graph.pb.h"
38 #include "tensorflow/core/framework/session_state.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/mutex.h"
44 #include "tensorflow/core/platform/thread_annotations.h"
45 #include "tensorflow/core/platform/types.h"
46 #include "tensorflow/core/public/session.h"
47 
48 namespace tensorflow {
49 
50 class CostModel;
51 class DebugGateway;
52 class Device;
53 class DirectSessionFactory;
54 
55 class DirectSession : public Session {
56  public:
57   typedef std::function<void(Session*)> CloseCallback;
58 
59   // Takes ownership of 'device_mgr'.
60   // 'factory' is used to unregister the DirectSession with 'factory' when its
61   // closed. This ensures that Reset requests from the 'factory' don't get sent
62   // to sessions that are already closed.
63   DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
64                 DirectSessionFactory* factory);
65   ~DirectSession() override;
66 
67   typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
68   typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
69 
70   ::tensorflow::Status Create(const GraphDef& graph) override;
71   ::tensorflow::Status Extend(const GraphDef& graph) override;
72   ::tensorflow::Status Run(const NamedTensorList& inputs,
73                            const std::vector<string>& output_names,
74                            const std::vector<string>& target_nodes,
75                            std::vector<Tensor>* outputs) override;
76 
77   // NOTE: Experimental and subject to change.
78   ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
79                            const NamedTensorList& inputs,
80                            const std::vector<string>& output_names,
81                            const std::vector<string>& target_nodes,
82                            std::vector<Tensor>* outputs,
83                            RunMetadata* run_metadata) override;
84 
85   // NOTE: PRunSetup and PRun are added to support partial execution. This
86   // feature is experimental and subject to change.
87   ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
88                                  const std::vector<string>& output_names,
89                                  const std::vector<string>& target_nodes,
90                                  string* handle) override;
91   ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
92                             const std::vector<string>& output_names,
93                             std::vector<Tensor>* outputs) override;
94 
95   // Reset clears 'containers' from the device_mgr of the DirectSession.
96   // If 'containers' is empty, then Reset clears the default container.
97   ::tensorflow::Status Reset(const std::vector<string>& containers);
98 
99   ::tensorflow::Status ListDevices(
100       std::vector<DeviceAttributes>* response) override;
101   ::tensorflow::Status Close() override;
LocalDeviceManager(const DeviceMgr ** output)102   ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
103     *output = device_mgr_.get();
104     return ::tensorflow::Status::OK();
105   }
106 
ExportCostModels(CostModelManager::CostModelMap * cost_models)107   void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
108     cost_model_manager_.ExportCostModels(cost_models);
109   }
110 
111   ::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
112                                     CallableHandle* out_handle) override;
113   ::tensorflow::Status RunCallable(CallableHandle handle,
114                                    const std::vector<Tensor>& feed_tensors,
115                                    std::vector<Tensor>* fetch_tensors,
116                                    RunMetadata* run_metadata) override;
117   ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
118 
119  private:
120   // For access to collective_graph_key_.
121   friend class DirectSessionCollectiveTest;
122 
123   // We create one executor and its dependent library runtime for
124   // every partition.
125   struct PerPartitionExecutorsAndLib {
126     Graph* graph = nullptr;                  // not owned.
127     Device* device = nullptr;                // not owned.
128     FunctionLibraryRuntime* flib = nullptr;  // not owned.
129     std::unique_ptr<Executor> executor;
130   };
131 
132   // An ExecutorsAndKeys is created for a given set of feeds/fetches.
133   // 'step_count' is the number of times this graph is executed.
134   // 'graph' is the entire graph being executed. 'name_to_node'
135   // maps node name to node. We keep 'graph' and 'name_to_node' only in
136   // the case of partial runs. Each item in 'items' is the executor for
137   // a partition of the graph bundled with its dependent library runtime.
138   // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
139   // are rendezvous keys for the fetches.
140   struct ExecutorsAndKeys {
ExecutorsAndKeysExecutorsAndKeys141     ExecutorsAndKeys() : step_count(0) {}
142 
143     std::atomic_int_fast64_t step_count;
144     std::unique_ptr<Graph> graph;
145     NameNodeMap name_to_node;
146     std::vector<PerPartitionExecutorsAndLib> items;
147     std::unordered_map<string, size_t> input_name_to_index;
148     std::unordered_map<string, string> input_name_to_rendezvous_key;
149     std::unordered_map<string, size_t> output_name_to_index;
150     std::unordered_map<string, string> output_name_to_rendezvous_key;
151 
152     DataTypeVector input_types;
153     DataTypeVector output_types;
154 
155     CallableOptions callable_options;
156 
157     int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
158   };
159 
160   // A FunctionInfo object is created for every unique set of feeds/fetches.
161   // This info could be folded into the ExecutorsAndKeys object but we would
162   // like to maintain a deletion order in which the OpKernels (owned by the
163   // executor) should be destroyed first, followed by the resources in the
164   // device and then followed by the function stuff.
165   // TODO(rohanj): Consolidate function library definitions so that we can
166   // instantiate only one ProcFLR and lib_def and make this just a member
167   // variable and not a vector.
168   // 'flib_def' is the function library used.
169   // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
170   // device.
171   struct FunctionInfo {
172     std::unique_ptr<FunctionLibraryDefinition> flib_def;
173     std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
174   };
175 
176   // For each live partial execution, the session maintains a RunState.
177   // 'status' is the current status of this partial execution. 'executor_done'
178   // is "notified" when all executors are done. 'pending_inputs' are the set
179   // of pending feeds and 'pending_outputs' are the set of pending fetches.
180   struct RunState {
181     mutex mu_;
182     Status status GUARDED_BY(mu_);
183     IntraProcessRendezvous* rendez = nullptr;
184     std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
185     std::unique_ptr<StepStatsCollector> collector;
186     Notification executors_done;
187     std::unordered_map<string, bool> pending_inputs;   // true if fed
188     std::unordered_map<string, bool> pending_outputs;  // true if fetched
189     TensorStore tensor_store;
190     ScopedStepContainer step_container;
191 
192     RunState(int64 step_id, const std::vector<Device*>* devices);
193 
194     RunState(const std::vector<string>& pending_input_names,
195              const std::vector<string>& pending_output_names, int64 step_id,
196              const std::vector<Device*>* devices);
197 
198     // Returns true if all pending inputs and outputs have been completed.
199     bool PendingDone() const;
200 
201     ~RunState();
202   };
203 
204   struct RunStateArgs {
RunStateArgsRunStateArgs205     RunStateArgs(const DebugOptions& options) : debug_options(options) {}
206 
207     bool is_partial_run = false;
208     string handle;
209     std::unique_ptr<Graph> graph;
210     const DebugOptions& debug_options;
211     int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
212   };
213 
214   // Initializes the base execution state given the 'graph',
215   // if not already initialized.
216   Status MaybeInitializeExecutionState(const GraphDef& graph,
217                                        bool* out_already_initialized)
218       EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
219 
220   // Retrieves an already existing set of executors to run 'inputs' and
221   // 'outputs', or creates and caches them for future use.
222   ::tensorflow::Status GetOrCreateExecutors(
223       gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
224       gtl::ArraySlice<string> target_nodes,
225       ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
226 
227   // Creates a set of executors to run the subgraph defined by
228   // `callable_options`.
229   ::tensorflow::Status CreateExecutors(
230       const CallableOptions& callable_options,
231       std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
232       std::unique_ptr<FunctionInfo>* out_func_info,
233       RunStateArgs* run_state_args);
234 
235   // Creates several graphs given the existing graph_def_ and the
236   // input feeds and fetches, given 'devices'. The graphs share a common
237   // function library 'flib_def'.
238   ::tensorflow::Status CreateGraphs(
239       const BuildGraphOptions& options,
240       std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
241       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
242       RunStateArgs* run_state_args, DataTypeVector* input_types,
243       DataTypeVector* output_types, int64* collective_graph_key);
244 
245   ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
246                                    CallFrameInterface* call_frame,
247                                    ExecutorsAndKeys* executors_and_keys,
248                                    RunMetadata* run_metadata);
249 
250   // Returns whether inter-op execution uses a global pool or the input
251   // `run_options` requests being run on inter_op_thread_pool = 0 in case
252   // multiple pools are configured.
253   bool ShouldUseRunHandlerPool(const RunOptions& run_options) const;
254 
255   ::tensorflow::Status ExtendLocked(const GraphDef& graph)
256       EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
257 
258   ::tensorflow::Status ResourceHandleToInputTensor(
259       const Tensor& resource_tensor, Tensor* retrieved_tensor);
260 
261   // Feeds more inputs to the executors, triggering further execution.
262   ::tensorflow::Status SendPRunInputs(
263       const std::vector<std::pair<string, Tensor>>& inputs,
264       const ExecutorsAndKeys* executors_and_keys,
265       IntraProcessRendezvous* rendez);
266 
267   // Fetches more outputs from the executors. It waits until the output
268   // tensors are computed.
269   ::tensorflow::Status RecvPRunOutputs(
270       const std::vector<string>& output_names,
271       const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
272       std::vector<Tensor>* outputs);
273 
274   // Check if the specified fetches can be computed from the feeds
275   // that we have already provided.
276   ::tensorflow::Status CheckFetch(
277       const std::vector<std::pair<string, Tensor>>& feeds,
278       const std::vector<string>& fetches,
279       const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
280 
281   // Use the appropriate WaitForNotification function based on whether
282   // operation_timeout_in_ms is greater than 0.
283   //
284   // If the timeout expires, the `cm->StartCancel()` will be called.
285   ::tensorflow::Status WaitForNotification(Notification* n,
286                                            int64 timeout_in_ms);
287   void WaitForNotification(RunState* run_state, CancellationManager* cm,
288                            int64 timeout_in_ms);
289 
CheckNotClosed()290   ::tensorflow::Status CheckNotClosed() {
291     mutex_lock l(closed_lock_);
292     if (closed_) return errors::Cancelled("Session has been closed.");
293     return ::tensorflow::Status::OK();
294   }
295 
CheckGraphCreated(const char * method)296   ::tensorflow::Status CheckGraphCreated(const char* method) {
297     mutex_lock l(graph_state_lock_);
298     if (!graph_created_) {
299       return errors::InvalidArgument(
300           "Session was not created with a graph before ", method, "!");
301     }
302     return ::tensorflow::Status::OK();
303   }
304 
305   ::tensorflow::Status CreateDebuggerState(
306       const CallableOptions& options, int64 global_step,
307       int64 session_run_index, int64 executor_step_index,
308       std::unique_ptr<DebuggerStateInterface>* debugger_state);
309 
310   ::tensorflow::Status DecorateAndPublishGraphForDebug(
311       const DebugOptions& debug_options, Graph* graph, Device* device);
312 
313   const SessionOptions options_;
314 
315   // Device structures.
316   const std::unique_ptr<const DeviceMgr> device_mgr_;
317   std::vector<Device*> devices_;  // not owned
318   DeviceSet device_set_;
319 
320   // Unique session identifier.
321   string session_handle_;
322   mutex graph_state_lock_;
323   bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
324 
325   // The thread-pools to use for running ops, with a bool indicating if the pool
326   // is owned.
327   std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;
328 
329   Status init_error_;  // Set to an error if construction failed.
330 
331   // If true, blocks until device has finished all queued operations in a step.
332   bool sync_on_finish_ = true;
333   // Schedules 'c' for execution on pool.
334   void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);
335 
336   std::vector<std::unique_ptr<FunctionInfo>> functions_
337       GUARDED_BY(executor_lock_);
338 
339   mutex executor_lock_;  // protects executors_
340   // Holds mappings from signature to the executors that process
341   // it. The reason for a level of indirection around mapped_type is
342   // to guarantee address stability.
343   // The map value is a shared_ptr since multiple map keys can point to the
344   // same ExecutorsAndKey object.
345   std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
346       GUARDED_BY(executor_lock_);
347 
348   class RunCallableCallFrame;
349   struct Callable {
350     std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
351     std::shared_ptr<FunctionInfo> function_info;
352     ~Callable();
353   };
354   mutex callables_lock_;
355   int64 next_callable_handle_ GUARDED_BY(callables_lock_) = 0;
356   std::unordered_map<int64, Callable> callables_ GUARDED_BY(callables_lock_);
357 
358   // Holds mappings from handle to partial run state.
359   std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
360       GUARDED_BY(executor_lock_);
361 
362   // This holds all the tensors that are currently alive in the session.
363   SessionState session_state_;
364 
365   DirectSessionFactory* const factory_;  // not owned
366   CancellationManager* cancellation_manager_;
367   std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
368 
369   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
370   // is true, such as "params" and "queue" nodes.  Once placed these
371   // nodes can not be moved to a different device.  Maps node names to
372   // device names.
373   std::unordered_map<string, string> stateful_placements_
374       GUARDED_BY(graph_state_lock_);
375 
376   // Execution_state; used when placing the entire graph.
377   std::unique_ptr<GraphExecutionState> execution_state_
378       GUARDED_BY(graph_state_lock_);
379 
380   // The function library, before any rewrites or optimizations have been
381   // performed. In particular, CreateGraphs() may need to modify the function
382   // library; it copies and modifies the function library.
383   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
384 
385   // true if the Session has been Closed.
386   mutex closed_lock_;
387   bool closed_ GUARDED_BY(closed_lock_) = false;
388 
389   // For generating unique names for this session instance.
390   std::atomic<int64> edge_name_counter_ = {0};
391   std::atomic<int64> handle_name_counter_ = {0};
392 
393   // For generating step ids that are unique across this sessions.
394   static std::atomic_int_fast64_t step_id_counter_;
395 
396   // Global timeout for all blocking operations in this session.
397   const int64 operation_timeout_in_ms_ = 0;
398 
399   // Manages all the cost models for the graphs executed in this session.
400   CostModelManager cost_model_manager_;
401 
402   // For testing collective graph key generation.
403   mutex collective_graph_key_lock_;
404   int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
405 
406   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
407 
408   // EXPERIMENTAL: debugger (tfdbg) related
409   friend class DebugGateway;
410 };
411 
412 }  // end namespace tensorflow
413 
414 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
415