• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/build_graph_options.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/graph/costmodel.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 struct SessionOptions;
37 
38 namespace subgraph {
39 struct RewriteGraphMetadata;
40 }
41 
42 struct GraphExecutionStateOptions {
43   const DeviceSet* device_set = nullptr;
44   const SessionOptions* session_options = nullptr;
45   // Unique session identifier. Can be empty.
46   string session_handle;
47   // A map from node name to device name, representing the unchangeable
48   // placement of stateful nodes.
49   std::unordered_map<string, string> stateful_placements;
50 };
51 
52 // A ClientGraph is simply a sub-graph of the full graph as induced by
53 // BuildGraphOptions.
54 struct ClientGraph {
ClientGraphClientGraph55   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
56                        DataTypeVector feed_types, DataTypeVector fetch_types,
57                        int64_t collective_graph_key)
58       : flib_def(std::move(flib)),
59         graph(flib_def.get()),
60         feed_types(std::move(feed_types)),
61         fetch_types(std::move(fetch_types)),
62         collective_graph_key(collective_graph_key) {}
63   // Each client-graph gets its own function library since optimization passes
64   // post rewrite for execution might want to introduce new functions.
65   std::unique_ptr<FunctionLibraryDefinition> flib_def;
66   Graph graph;
67   DataTypeVector feed_types;
68   DataTypeVector fetch_types;
69   int64 collective_graph_key;
70 };
71 
72 // GraphExecutionState is responsible for generating an
73 // executable ClientGraph from the original GraphDef that specifies
74 // the complete graph and from BuildGraphOptions which specifies
75 // input/output nodes.
76 //
77 // An executable Graph differs from a GraphDef by being Placed,
78 // meaning that each Node is assigned to a single Device in the
79 // available set.
80 //
81 // When GraphExecutionState is first constructed it instantiates
82 // a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef.  Nodes without a
// static device assignment are currently placed in a very naive way.
// Since stateful Nodes cannot be moved after initial placement, it is
// important that stateful Nodes get sensible initial device assignments
// in the graph definition.
88 //
// Subsequently, GraphExecutionState generates a ClientGraph on
90 // demand, which is a sub-graph of the latest placement of the full
91 // Graph.  MasterSession uses such a ClientGraph to execute one or
92 // more similar client requests.
93 //
94 // GraphExecutionState is thread-safe.
95 
96 class GraphExecutionState {
97  public:
98   virtual ~GraphExecutionState();
99 
100   // Creates a new `GraphExecutionState` for the given
101   // `graph_def`, which represents the entire graph for a session.
102   static Status MakeForBaseGraph(
103       GraphDef&& graph_def, const GraphExecutionStateOptions& options,
104       std::unique_ptr<GraphExecutionState>* out_state);
105 
106   // Creates a new `GraphExecutionState` and `SimpleClientGraph`
107   // for the subgraph of `original_graph_def` defined by
108   // `subgraph_options`.
109   static Status MakeForPrunedGraph(
110       const GraphExecutionState& base_execution_state,
111       const GraphExecutionStateOptions& options,
112       const BuildGraphOptions& subgraph_options,
113       std::unique_ptr<GraphExecutionState>* out_state,
114       std::unique_ptr<ClientGraph>* out_client_graph);
115 
116   // Creates a new GraphExecutionState representing the
117   // concatenation of this graph, and the graph defined by
118   // "extension_def". The same name may not be used to define a node
119   // in both this graph and "extension_def".
120   //
121   // If successful, returns OK and the caller takes ownership of "*out".
122   // Otherwise returns an error and does not modify "*out".
123   //
124   // After calling `old_state->Extend()`, `old_state` may no longer be
125   // used.
126   //
127   // NOTE(mrry): This method respects the placement of stateful nodes in
128   // in *this, but currently does not transfer any other placement
129   // or cost model information to the new graph.
130   Status Extend(const GraphDef& extension_def,
131                 std::unique_ptr<GraphExecutionState>* out) const;
132 
133   // Builds a ClientGraph (a sub-graph of the full graph as induced by
134   // the Node set specified in "options").  If successful, returns OK
135   // and the caller takes the ownership of "*out". Otherwise, returns
136   // an error.
137   Status BuildGraph(const BuildGraphOptions& options,
138                     std::unique_ptr<ClientGraph>* out);
139 
140   // Optimize the graph with the node set specified in `options`.
141   Status OptimizeGraph(
142       const BuildGraphOptions& options, const Graph& graph,
143       const FunctionLibraryDefinition* flib_def,
144       std::unique_ptr<Graph>* optimized_graph,
145       std::unique_ptr<FunctionLibraryDefinition>* optimized_flib);
146 
147   // The graph returned by BuildGraph may contain only the pruned
148   // graph, whereas some clients may want access to the full graph.
full_graph()149   const Graph* full_graph() { return graph_; }
150 
151   // The original function library of this graph.
flib_def()152   const FunctionLibraryDefinition& flib_def() const { return *flib_def_; }
153 
154   // Returns the node with the given name, or null if it does not exist.
get_node_by_name(const string & name)155   const Node* get_node_by_name(const string& name) const {
156     NodeNameToCostIdMap::const_iterator iter =
157         node_name_to_cost_id_map_.find(name);
158     if (iter != node_name_to_cost_id_map_.end()) {
159       return graph_->FindNodeId(iter->second);
160     } else {
161       return nullptr;
162     }
163   }
164 
165   // Returns the map of stateful placements as a map of
166   // node name to placement string.
GetStatefulPlacements()167   std::unordered_map<string, string> GetStatefulPlacements() const {
168     return stateful_placements_;
169   }
170 
171  private:
172   GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def,
173                       std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
174                       const GraphExecutionStateOptions& options);
175 
176   Status InitBaseGraph(std::unique_ptr<Graph>&& graph);
177 
178   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
179   // is true, such as "params" and "queue" nodes.  Once placed these
180   // nodes can not be moved to a different device.  Maps node names to
181   // device names.
182   std::unordered_map<string, string> stateful_placements_;  // Immutable after
183                                                             // ctor.
184   void SaveStatefulNodes(Graph* graph);
185   void RestoreStatefulNodes(Graph* graph);
186 
187   // Extract the subset of the graph that needs to be run, adding feed/fetch
188   // ops as needed.
189   Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
190                     subgraph::RewriteGraphMetadata* out_rewrite_metadata);
191 
192   // The GraphExecutionState must store a copy of the original GraphDef if
193   // either of the following conditions holds:
194   //
195   // * `session_options_.config.graph_options().place_pruned_graph()` is true.
196   // * `session_options_.config.experimental().optimize_for_static_graph()` is
197   //   false.
198   const std::unique_ptr<GraphDef> original_graph_def_;
199 
200   const DeviceSet* device_set_;            // Not owned
201   const SessionOptions* session_options_;  // Not owned
202   // Unique session identifier. Can be empty.
203   string session_handle_;
204 
205   // Map from name to Node for the full graph in placed_.
206   NodeNameToCostIdMap node_name_to_cost_id_map_;
207 
208   // 'flib_def_' is initialized from the initial graph def's library,
209   // and may be updated by a graph optimization pass.
210   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
211 
212   // `rewrite_metadata_` is only set for GraphExecutionState
213   // objects created by `MakeForPrunedGraph()`.
214   std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
215 
216   // The dataflow graph owned by this object.
217   Graph* graph_;
218 
219   TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
220 };
221 
222 }  // namespace tensorflow
223 
224 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
225