/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_

#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

class CostModel;
class DebugGateway;
class Device;
class DirectSessionFactory;

class DirectSession : public Session {
 public:
  typedef std::function<void(Session*)> CloseCallback;

  // Takes ownership of 'device_mgr'.
  // 'factory' is used to unregister the DirectSession with 'factory' when it
  // is closed. This ensures that Reset requests from the 'factory' don't get
  // sent to sessions that are already closed.
  DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
                DirectSessionFactory* factory);
  ~DirectSession() override;

  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;

  ::tensorflow::Status Create(const GraphDef& graph) override;
  ::tensorflow::Status Create(GraphDef&& graph) override;
  ::tensorflow::Status Extend(const GraphDef& graph) override;
  ::tensorflow::Status Extend(GraphDef&& graph) override;
  ::tensorflow::Status Run(const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
                           std::vector<Tensor>* outputs) override;

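  // Illustrative usage (a minimal sketch, not part of this header; the
  // `graph_def` variable and the node names "x" and "y:0" are assumptions):
  //
  //   std::unique_ptr<Session> session(NewSession(SessionOptions()));
  //   TF_RETURN_IF_ERROR(session->Create(graph_def));
  //   Tensor x(DT_FLOAT, TensorShape({2, 2}));
  //   std::vector<Tensor> outputs;
  //   TF_RETURN_IF_ERROR(
  //       session->Run({{"x", x}}, {"y:0"}, /*target_nodes=*/{}, &outputs));
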
  // NOTE: Experimental and subject to change.
  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
                           const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
                           std::vector<Tensor>* outputs,
                           RunMetadata* run_metadata) override;

  // NOTE: Experimental and subject to change.
  ::tensorflow::Status Run(
      const ::tensorflow::RunOptions& run_options,
      const NamedTensorList& inputs, const std::vector<string>& output_names,
      const std::vector<string>& target_nodes, std::vector<Tensor>* outputs,
      RunMetadata* run_metadata,
      const thread::ThreadPoolOptions& threadpool_options) override;

  // NOTE: PRunSetup and PRun are added to support partial execution. This
  // feature is experimental and subject to change.
  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
                                 const std::vector<string>& output_names,
                                 const std::vector<string>& target_nodes,
                                 string* handle) override;
  ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
                            const std::vector<string>& output_names,
                            std::vector<Tensor>* outputs) override;

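  // Partial execution sketch (illustrative only; the feed/fetch names "a",
  // "b", and "c:0" and the tensors `a` and `b` below are assumptions):
  //
  //   string handle;
  //   TF_RETURN_IF_ERROR(session->PRunSetup({"a", "b"}, {"c:0"}, {}, &handle));
  //   std::vector<Tensor> outputs;
  //   // Feeds may be supplied across multiple PRun calls; a fetch becomes
  //   // available once every feed it depends on has been provided.
  //   TF_RETURN_IF_ERROR(session->PRun(handle, {{"a", a}, {"b", b}}, {"c:0"},
  //                                    &outputs));
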
  // Reset clears 'containers' from the device_mgr of the DirectSession.
  // If 'containers' is empty, then Reset clears the default container.
  ::tensorflow::Status Reset(const std::vector<string>& containers);

  ::tensorflow::Status ListDevices(
      std::vector<DeviceAttributes>* response) override;
  ::tensorflow::Status Close() override;
  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
    *output = device_mgr_.get();
    return OkStatus();
  }

  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
    cost_model_manager_.ExportCostModels(cost_models);
  }

  ::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
                                    CallableHandle* out_handle) override;

  ::tensorflow::Status RunCallable(CallableHandle handle,
                                   const std::vector<Tensor>& feed_tensors,
                                   std::vector<Tensor>* fetch_tensors,
                                   RunMetadata* run_metadata) override;

  ::tensorflow::Status RunCallable(
      CallableHandle handle, const std::vector<Tensor>& feed_tensors,
      std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
      const thread::ThreadPoolOptions& threadpool_options) override;

  ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;

  ::tensorflow::Status Finalize() override;

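  // Callable usage sketch (illustrative only; the feed/fetch names "x" and
  // "y:0" and the tensor `x` below are assumptions):
  //
  //   CallableOptions opts;
  //   opts.add_feed("x");
  //   opts.add_fetch("y:0");
  //   Session::CallableHandle handle;
  //   TF_RETURN_IF_ERROR(session->MakeCallable(opts, &handle));
  //   std::vector<Tensor> fetches;
  //   TF_RETURN_IF_ERROR(
  //       session->RunCallable(handle, {x}, &fetches, /*run_metadata=*/nullptr));
  //   TF_RETURN_IF_ERROR(session->ReleaseCallable(handle));
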
  const SessionOptions& options() const { return options_; }

 private:
  // For access to collective_graph_key_.
  friend class DirectSessionCollectiveTest;

  // We create one executor and its dependent library runtime for
  // every partition.
  struct PerPartitionExecutorsAndLib {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;                // not owned.
    FunctionLibraryRuntime* flib = nullptr;  // not owned.
    std::unique_ptr<Executor> executor;
  };

  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
  // 'step_count' is the number of times this graph is executed.
  // 'graph' is the entire graph being executed. 'name_to_node'
  // maps node name to node. We keep 'graph' and 'name_to_node' only in
  // the case of partial runs. Each item in 'items' is the executor for
  // a partition of the graph bundled with its dependent library runtime.
  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
  // are rendezvous keys for the fetches.
  struct ExecutorsAndKeys {
    ExecutorsAndKeys() : step_count(0) {}

    std::atomic_int_fast64_t step_count;
    std::unique_ptr<Graph> graph;
    NameNodeMap name_to_node;
    std::vector<PerPartitionExecutorsAndLib> items;
    std::unordered_map<string, size_t> input_name_to_index;
    std::unordered_map<string, string> input_name_to_rendezvous_key;
    std::unordered_map<string, size_t> output_name_to_index;
    std::unordered_map<string, string> output_name_to_rendezvous_key;

    DataTypeVector input_types;
    DataTypeVector output_types;

    CallableOptions callable_options;

    int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
  };

  // A FunctionInfo object is created for every unique set of feeds/fetches.
  // This info could be folded into the ExecutorsAndKeys object but we would
  // like to maintain a deletion order in which the OpKernels (owned by the
  // executor) should be destroyed first, followed by the resources in the
  // device and then followed by the function library objects.
  // TODO(rohanj): Consolidate function library definitions so that we can
  // instantiate only one ProcFLR and lib_def and make this just a member
  // variable and not a vector.
  // 'flib_def' is the function library used.
  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
  // device.
  struct FunctionInfo {
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
  };

  // For each live Run() call, the session maintains a RunState.
  // 'status' is the current status of the execution.
  struct RunState {
    mutex mu;
    Status status TF_GUARDED_BY(mu);
    std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
    std::unique_ptr<StepStatsCollector> collector;
    TensorStore tensor_store;
    ScopedStepContainer step_container;

    RunState(int64_t step_id, const std::vector<Device*>* devices);
  };

  // For each live partial execution, the session maintains a PartialRunState.
  // 'executors_done' is "notified" when all executors are done.
  // 'pending_inputs' are the set of pending feeds and 'pending_outputs' are
  // the set of pending fetches.
  struct PartialRunState : public RunState {
    Notification executors_done;
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;

    PartialRunState(const std::vector<string>& pending_input_names,
                    const std::vector<string>& pending_output_names,
                    int64_t step_id, const std::vector<Device*>* devices);

    // Returns true if all pending inputs and outputs have been completed.
    bool PendingDone() const;

    ~PartialRunState();
  };

  struct RunStateArgs {
    explicit RunStateArgs(const DebugOptions& options)
        : debug_options(options) {}

    bool is_partial_run = false;
    string handle;
    std::unique_ptr<Graph> graph;
    const DebugOptions& debug_options;
    int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
  };

  // Retrieves an already existing set of executors to run 'inputs' and
  // 'outputs', or creates and caches them for future use.
  ::tensorflow::Status GetOrCreateExecutors(
      gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
      gtl::ArraySlice<string> target_nodes,
      ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);

  // Creates a set of executors to run the subgraph defined by
  // `callable_options`.
  ::tensorflow::Status CreateExecutors(
      const CallableOptions& callable_options,
      std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
      std::unique_ptr<FunctionInfo>* out_func_info,
      RunStateArgs* run_state_args);

  // Creates several graphs given the existing graph_def_ and the
  // input feeds and fetches, given 'devices'. The graphs share a common
  // function library 'flib_def'.
  ::tensorflow::Status CreateGraphs(
      const BuildGraphOptions& options,
      std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
      std::unique_ptr<FunctionLibraryDefinition>* flib_def,
      RunStateArgs* run_state_args, DataTypeVector* input_types,
      DataTypeVector* output_types, int64_t* collective_graph_key);

  ::tensorflow::Status RunInternal(
      int64_t step_id, const RunOptions& run_options,
      CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
      RunMetadata* run_metadata,
      const thread::ThreadPoolOptions& threadpool_options);

  // Returns whether inter-op execution uses a global pool, or whether the
  // input `run_options` requests being run on inter_op_thread_pool = 0 when
  // multiple pools are configured.
  bool ShouldUseRunHandlerPool(const RunOptions& run_options) const;

  ::tensorflow::Status ExtendLocked(GraphDef&& graph)
      TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);

  ::tensorflow::Status ResourceHandleToInputTensor(
      const Tensor& resource_tensor, Tensor* retrieved_tensor);

  // Feeds more inputs to the executors, triggering further execution.
  ::tensorflow::Status SendPRunInputs(
      const std::vector<std::pair<string, Tensor>>& inputs,
      const ExecutorsAndKeys* executors_and_keys,
      IntraProcessRendezvous* rendez);

  // Fetches more outputs from the executors. It waits until the output
  // tensors are computed.
  ::tensorflow::Status RecvPRunOutputs(
      const std::vector<string>& output_names,
      const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
      std::vector<Tensor>* outputs);

  // Check if the specified fetches can be computed from the feeds
  // that we have already provided.
  ::tensorflow::Status CheckFetch(
      const std::vector<std::pair<string, Tensor>>& feeds,
      const std::vector<string>& fetches,
      const ExecutorsAndKeys* executors_and_keys,
      const PartialRunState* run_state);

  // Use the appropriate WaitForNotification function based on whether
  // operation_timeout_in_ms is greater than 0.
  //
  // If the timeout expires, `cm->StartCancel()` will be called.
  ::tensorflow::Status WaitForNotification(Notification* n,
                                           int64_t timeout_in_ms);
  void WaitForNotification(Notification* n, RunState* run_state,
                           CancellationManager* cm, int64_t timeout_in_ms);

  ::tensorflow::Status CheckNotClosed() {
    mutex_lock l(closed_lock_);
    if (closed_) return errors::Cancelled("Session has been closed.");
    return OkStatus();
  }

  ::tensorflow::Status CheckGraphCreated(const char* method) {
    mutex_lock l(graph_state_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before ", method, "!");
    }
    return OkStatus();
  }

  ::tensorflow::Status CreateDebuggerState(
      const CallableOptions& options, int64_t global_step,
      int64_t session_run_index, int64_t executor_step_index,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  ::tensorflow::Status DecorateAndPublishGraphForDebug(
      const DebugOptions& debug_options, Graph* graph, Device* device);

  const SessionOptions options_;

  // Device structures.
  const std::unique_ptr<const DeviceMgr> device_mgr_;
  std::vector<Device*> devices_;  // not owned
  DeviceSet device_set_;

  // Unique session identifier.
  string session_handle_;
  mutex graph_state_lock_;
  bool graph_created_ TF_GUARDED_BY(graph_state_lock_) = false;
  bool finalized_ TF_GUARDED_BY(graph_state_lock_) = false;

  // The thread-pools to use for running ops, with a bool indicating if the pool
  // is owned.
  std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;

  Status init_error_;  // Set to an error if construction failed.

  // If true, blocks until device has finished all queued operations in a step.
  bool sync_on_finish_ = true;

  std::vector<std::unique_ptr<FunctionInfo>> functions_
      TF_GUARDED_BY(executor_lock_);

  mutex executor_lock_;  // protects executors_
  // Holds mappings from signature to the executors that process
  // it. The reason for a level of indirection around mapped_type is
  // to guarantee address stability.
  // The map value is a shared_ptr since multiple map keys can point to the
  // same ExecutorsAndKeys object.
  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
      TF_GUARDED_BY(executor_lock_);

  class RunCallableCallFrame;
  struct Callable {
    std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
    std::shared_ptr<FunctionInfo> function_info;
    ~Callable();
  };
  mutex callables_lock_;
  int64_t next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0;
  std::unordered_map<int64_t, Callable> callables_
      TF_GUARDED_BY(callables_lock_);

  // Holds mappings from handle to partial run state.
  std::unordered_map<string, std::unique_ptr<PartialRunState>> partial_runs_
      TF_GUARDED_BY(executor_lock_);

  // This holds all the tensors that are currently alive in the session.
  SessionState session_state_;

  DirectSessionFactory* const factory_;  // not owned
  CancellationManager* cancellation_manager_;
  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;

  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes.  Once placed these
  // nodes can not be moved to a different device.  Maps node names to
  // device names.
  std::unordered_map<string, string> stateful_placements_
      TF_GUARDED_BY(graph_state_lock_);

  // Execution state; used when placing the entire graph.
  std::unique_ptr<GraphExecutionState> execution_state_
      TF_GUARDED_BY(graph_state_lock_);

  // The function library, before any rewrites or optimizations have been
  // performed. In particular, CreateGraphs() may need to modify the function
  // library; it copies the function library and modifies the copy.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;

  // true if the Session has been Closed.
  mutex closed_lock_;
  bool closed_ TF_GUARDED_BY(closed_lock_) = false;

  // For generating unique names for this session instance.
  std::atomic<int64_t> edge_name_counter_ = {0};
  std::atomic<int64_t> handle_name_counter_ = {0};

  // For generating step ids that are unique among all sessions.
  static std::atomic_int_fast64_t step_id_counter_;

  // Global timeout for all blocking operations in this session.
  const int64_t operation_timeout_in_ms_ = 0;

  // Manages all the cost models for the graphs executed in this session.
  CostModelManager cost_model_manager_;

  // For testing collective graph key generation.
  mutex collective_graph_key_lock_;
  int64_t collective_graph_key_ TF_GUARDED_BY(collective_graph_key_lock_) = -1;

  // Run in the caller's thread if RunOptions.inter_op_thread_pool is negative
  // or all of the following conditions are met:
  // 1. This session doesn't own any thread pool.
  // 2. RunOptions.inter_op_thread_pool is unspecified or 0.
  // 3. This session has a single executor.
  // 4. config.inter_op_parallelism_threads is explicitly set to a negative
  //    value, either directly or through the environment variable
  //    TF_NUM_INTEROP_THREADS.
  // 5. RunOptions.experimental.use_run_handler_pool is unspecified or false.
  // Otherwise, run in the global thread pool, the session-owned thread pool,
  // or the handler pool, according to the other specifications in RunOptions
  // and ConfigProto.
  bool run_in_caller_thread_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);

  // EXPERIMENTAL: debugger (tfdbg) related
  friend class DebugGateway;
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_