• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/build_graph_options.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/graph/costmodel.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 struct SessionOptions;
37 
// Forward declaration only: this header refers to RewriteGraphMetadata solely
// through pointers, so the full definition is not needed here.
namespace subgraph {
struct RewriteGraphMetadata;
}  // namespace subgraph
41 
42 struct GraphExecutionStateOptions {
43   const DeviceSet* device_set = nullptr;
44   const SessionOptions* session_options = nullptr;
45   // Unique session identifier. Can be empty.
46   string session_handle;
47   // A map from node name to device name, representing the unchangeable
48   // placement of stateful nodes.
49   std::unordered_map<string, string> stateful_placements;
50 };
51 
52 // A ClientGraph is simply a sub-graph of the full graph as induced by
53 // BuildGraphOptions.
54 struct ClientGraph {
ClientGraphClientGraph55   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
56                        DataTypeVector feed_types, DataTypeVector fetch_types,
57                        int64 collective_graph_key)
58       : flib_def(std::move(flib)),
59         graph(flib_def.get()),
60         feed_types(std::move(feed_types)),
61         fetch_types(std::move(fetch_types)),
62         collective_graph_key(collective_graph_key) {}
63   // Each client-graph gets its own function library since optimization passes
64   // post rewrite for execution might want to introduce new functions.
65   std::unique_ptr<FunctionLibraryDefinition> flib_def;
66   Graph graph;
67   DataTypeVector feed_types;
68   DataTypeVector fetch_types;
69   int64 collective_graph_key;
70 };
71 
72 // GraphExecutionState is responsible for generating an
73 // executable ClientGraph from the original GraphDef that specifies
74 // the complete graph and from BuildGraphOptions which specifies
75 // input/output nodes.
76 //
77 // An executable Graph differs from a GraphDef by being Placed,
78 // meaning that each Node is assigned to a single Device in the
79 // available set.
80 //
81 // When GraphExecutionState is first constructed it instantiates
82 // a full Graph from the provided GraphDef, and places it, using only
83 // the static device assignments from the GraphDef.  Nodes without one are
84 // currently placed in a very naive way.  Since stateful Nodes cannot
85 // be moved after initial placement, it is important that stateful
86 // Nodes get sensible initial device assignments in the graph
87 // definition.
88 //
89 // Subsequently, GraphExecutionState generates a ClientGraph on
90 // demand, which is a sub-graph of the latest placement of the full
91 // Graph.  MasterSession uses such a ClientGraph to execute one or
92 // more similar client requests.
93 //
94 // GraphExecutionState is thread-safe.
95 
96 class GraphExecutionState {
97  public:
98   virtual ~GraphExecutionState();
99 
100   // Creates a new `GraphExecutionState` for the given
101   // `graph_def`, which represents the entire graph for a session.
102   static Status MakeForBaseGraph(
103       GraphDef&& graph_def, const GraphExecutionStateOptions& options,
104       std::unique_ptr<GraphExecutionState>* out_state);
105 
106   // Creates a new `GraphExecutionState` and `SimpleClientGraph`
107   // for the subgraph of `original_graph_def` defined by
108   // `subgraph_options`.
109   static Status MakeForPrunedGraph(
110       const GraphExecutionState& base_execution_state,
111       const GraphExecutionStateOptions& options,
112       const BuildGraphOptions& subgraph_options,
113       std::unique_ptr<GraphExecutionState>* out_state,
114       std::unique_ptr<ClientGraph>* out_client_graph);
115 
116   // Creates a new GraphExecutionState representing the
117   // concatenation of this graph, and the graph defined by
118   // "extension_def". The same name may not be used to define a node
119   // in both this graph and "extension_def".
120   //
121   // If successful, returns OK and the caller takes ownership of "*out".
122   // Otherwise returns an error and does not modify "*out".
123   //
124   // After calling `old_state->Extend()`, `old_state` may no longer be
125   // used.
126   //
127   // NOTE(mrry): This method respects the placement of stateful nodes in
128   // in *this, but currently does not transfer any other placement
129   // or cost model information to the new graph.
130   Status Extend(const GraphDef& extension_def,
131                 std::unique_ptr<GraphExecutionState>* out) const;
132 
133   // Builds a ClientGraph (a sub-graph of the full graph as induced by
134   // the Node set specified in "options").  If successful, returns OK
135   // and the caller takes the ownership of "*out". Otherwise, returns
136   // an error.
137   Status BuildGraph(const BuildGraphOptions& options,
138                     std::unique_ptr<ClientGraph>* out);
139 
140   // Optimize the graph with the node set specified in `options`.
141   Status OptimizeGraph(
142       const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
143       std::unique_ptr<FunctionLibraryDefinition>* optimized_flib);
144 
145   // The graph returned by BuildGraph may contain only the pruned
146   // graph, whereas some clients may want access to the full graph.
full_graph()147   const Graph* full_graph() { return graph_; }
148 
149   // The original function library of this graph.
flib_def()150   const FunctionLibraryDefinition& flib_def() const { return *flib_def_; }
151 
152   // Returns the node with the given name, or null if it does not exist.
get_node_by_name(const string & name)153   const Node* get_node_by_name(const string& name) const {
154     NodeNameToCostIdMap::const_iterator iter =
155         node_name_to_cost_id_map_.find(name);
156     if (iter != node_name_to_cost_id_map_.end()) {
157       return graph_->FindNodeId(iter->second);
158     } else {
159       return nullptr;
160     }
161   }
162 
163   // Returns the map of stateful placements as a map of
164   // node name to placement string.
GetStatefulPlacements()165   std::unordered_map<string, string> GetStatefulPlacements() const {
166     return stateful_placements_;
167   }
168 
169  private:
170   GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def,
171                       std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
172                       const GraphExecutionStateOptions& options);
173 
174   Status InitBaseGraph(std::unique_ptr<Graph>&& graph);
175 
176   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
177   // is true, such as "params" and "queue" nodes.  Once placed these
178   // nodes can not be moved to a different device.  Maps node names to
179   // device names.
180   std::unordered_map<string, string> stateful_placements_;  // Immutable after
181                                                             // ctor.
182   void SaveStatefulNodes(Graph* graph);
183   void RestoreStatefulNodes(Graph* graph);
184 
185   // Extract the subset of the graph that needs to be run, adding feed/fetch
186   // ops as needed.
187   Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
188                     subgraph::RewriteGraphMetadata* out_rewrite_metadata);
189 
190   // The GraphExecutionState must store a copy of the original GraphDef if
191   // either of the following conditions holds:
192   //
193   // * `session_options_.config.graph_options().place_pruned_graph()` is true.
194   // * `session_options_.config.experimental().optimize_for_static_graph()` is
195   //   false.
196   const std::unique_ptr<GraphDef> original_graph_def_;
197 
198   const DeviceSet* device_set_;            // Not owned
199   const SessionOptions* session_options_;  // Not owned
200   // Unique session identifier. Can be empty.
201   string session_handle_;
202 
203   // Map from name to Node for the full graph in placed_.
204   NodeNameToCostIdMap node_name_to_cost_id_map_;
205 
206   // 'flib_def_' is initialized from the initial graph def's library,
207   // and may be updated by a graph optimization pass.
208   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
209 
210   // `rewrite_metadata_` is only set for GraphExecutionState
211   // objects created by `MakeForPrunedGraph()`.
212   std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
213 
214   // The dataflow graph owned by this object.
215   Graph* graph_;
216 
217   TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
218 };
219 
220 }  // namespace tensorflow
221 
222 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
223