1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/build_graph_options.h" 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/common_runtime/device_set.h" 27 #include "tensorflow/core/framework/function.h" 28 #include "tensorflow/core/framework/graph.pb.h" 29 #include "tensorflow/core/graph/costmodel.h" 30 #include "tensorflow/core/graph/graph.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace tensorflow { 36 struct SessionOptions; 37 38 namespace subgraph { 39 struct RewriteGraphMetadata; 40 } 41 42 struct GraphExecutionStateOptions { 43 const DeviceSet* device_set = nullptr; 44 const SessionOptions* session_options = nullptr; 45 // Unique session identifier. Can be empty. 46 string session_handle; 47 // A map from node name to device name, representing the unchangeable 48 // placement of stateful nodes. 49 std::unordered_map<string, string> stateful_placements; 50 }; 51 52 // A ClientGraph is simply a sub-graph of the full graph as induced by 53 // BuildGraphOptions. 54 struct ClientGraph { ClientGraphClientGraph55 explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, 56 DataTypeVector feed_types, DataTypeVector fetch_types, 57 int64 collective_graph_key) 58 : flib_def(std::move(flib)), 59 graph(flib_def.get()), 60 feed_types(std::move(feed_types)), 61 fetch_types(std::move(fetch_types)), 62 collective_graph_key(collective_graph_key) {} 63 // Each client-graph gets its own function library since optimization passes 64 // post rewrite for execution might want to introduce new functions. 65 std::unique_ptr<FunctionLibraryDefinition> flib_def; 66 Graph graph; 67 DataTypeVector feed_types; 68 DataTypeVector fetch_types; 69 int64 collective_graph_key; 70 }; 71 72 // GraphExecutionState is responsible for generating an 73 // executable ClientGraph from the original GraphDef that specifies 74 // the complete graph and from BuildGraphOptions which specifies 75 // input/output nodes. 76 // 77 // An executable Graph differs from a GraphDef by being Placed, 78 // meaning that each Node is assigned to a single Device in the 79 // available set. 80 // 81 // When GraphExecutionState is first constructed it instantiates 82 // a full Graph from the provided GraphDef, and places it, using only 83 // the static device assignments from the GraphDef. Nodes without are 84 // currently placed in a very naive way. Since stateful Nodes cannot 85 // be moved after initial placement, it is important that stateful 86 // Nodes get sensible initial device assignments in the graph 87 // definition. 88 // 89 // Subsequently, GraphExecutionState generates a SimpleClientGraph on 90 // demand, which is a sub-graph of the latest placement of the full 91 // Graph. MasterSession uses such a ClientGraph to execute one or 92 // more similar client requests. 93 // 94 // GraphExecutionState is thread-safe. 95 96 class GraphExecutionState { 97 public: 98 virtual ~GraphExecutionState(); 99 100 // Creates a new `GraphExecutionState` for the given 101 // `graph_def`, which represents the entire graph for a session. 102 static Status MakeForBaseGraph( 103 GraphDef&& graph_def, const GraphExecutionStateOptions& options, 104 std::unique_ptr<GraphExecutionState>* out_state); 105 106 // Creates a new `GraphExecutionState` and `SimpleClientGraph` 107 // for the subgraph of `original_graph_def` defined by 108 // `subgraph_options`. 109 static Status MakeForPrunedGraph( 110 const GraphExecutionState& base_execution_state, 111 const GraphExecutionStateOptions& options, 112 const BuildGraphOptions& subgraph_options, 113 std::unique_ptr<GraphExecutionState>* out_state, 114 std::unique_ptr<ClientGraph>* out_client_graph); 115 116 // Creates a new GraphExecutionState representing the 117 // concatenation of this graph, and the graph defined by 118 // "extension_def". The same name may not be used to define a node 119 // in both this graph and "extension_def". 120 // 121 // If successful, returns OK and the caller takes ownership of "*out". 122 // Otherwise returns an error and does not modify "*out". 123 // 124 // After calling `old_state->Extend()`, `old_state` may no longer be 125 // used. 126 // 127 // NOTE(mrry): This method respects the placement of stateful nodes in 128 // in *this, but currently does not transfer any other placement 129 // or cost model information to the new graph. 130 Status Extend(const GraphDef& extension_def, 131 std::unique_ptr<GraphExecutionState>* out) const; 132 133 // Builds a ClientGraph (a sub-graph of the full graph as induced by 134 // the Node set specified in "options"). If successful, returns OK 135 // and the caller takes the ownership of "*out". Otherwise, returns 136 // an error. 137 Status BuildGraph(const BuildGraphOptions& options, 138 std::unique_ptr<ClientGraph>* out); 139 140 // Optimize the graph with the node set specified in `options`. 141 Status OptimizeGraph( 142 const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph, 143 std::unique_ptr<FunctionLibraryDefinition>* optimized_flib); 144 145 // The graph returned by BuildGraph may contain only the pruned 146 // graph, whereas some clients may want access to the full graph. full_graph()147 const Graph* full_graph() { return graph_; } 148 149 // The original function library of this graph. flib_def()150 const FunctionLibraryDefinition& flib_def() const { return *flib_def_; } 151 152 // Returns the node with the given name, or null if it does not exist. get_node_by_name(const string & name)153 const Node* get_node_by_name(const string& name) const { 154 NodeNameToCostIdMap::const_iterator iter = 155 node_name_to_cost_id_map_.find(name); 156 if (iter != node_name_to_cost_id_map_.end()) { 157 return graph_->FindNodeId(iter->second); 158 } else { 159 return nullptr; 160 } 161 } 162 163 // Returns the map of stateful placements as a map of 164 // node name to placement string. GetStatefulPlacements()165 std::unordered_map<string, string> GetStatefulPlacements() const { 166 return stateful_placements_; 167 } 168 169 private: 170 GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def, 171 std::unique_ptr<FunctionLibraryDefinition>&& flib_def, 172 const GraphExecutionStateOptions& options); 173 174 Status InitBaseGraph(std::unique_ptr<Graph>&& graph); 175 176 // Map of placed stateful nodes, i.e. nodes for which is_stateful() 177 // is true, such as "params" and "queue" nodes. Once placed these 178 // nodes can not be moved to a different device. Maps node names to 179 // device names. 180 std::unordered_map<string, string> stateful_placements_; // Immutable after 181 // ctor. 182 void SaveStatefulNodes(Graph* graph); 183 void RestoreStatefulNodes(Graph* graph); 184 185 // Extract the subset of the graph that needs to be run, adding feed/fetch 186 // ops as needed. 187 Status PruneGraph(const BuildGraphOptions& options, Graph* graph, 188 subgraph::RewriteGraphMetadata* out_rewrite_metadata); 189 190 // The GraphExecutionState must store a copy of the original GraphDef if 191 // either of the following conditions holds: 192 // 193 // * `session_options_.config.graph_options().place_pruned_graph()` is true. 194 // * `session_options_.config.experimental().optimize_for_static_graph()` is 195 // false. 196 const std::unique_ptr<GraphDef> original_graph_def_; 197 198 const DeviceSet* device_set_; // Not owned 199 const SessionOptions* session_options_; // Not owned 200 // Unique session identifier. Can be empty. 201 string session_handle_; 202 203 // Map from name to Node for the full graph in placed_. 204 NodeNameToCostIdMap node_name_to_cost_id_map_; 205 206 // 'flib_def_' is initialized from the initial graph def's library, 207 // and may be updated by a graph optimization pass. 208 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 209 210 // `rewrite_metadata_` is only set for GraphExecutionState 211 // objects created by `MakeForPrunedGraph()`. 212 std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_; 213 214 // The dataflow graph owned by this object. 215 Graph* graph_; 216 217 TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); 218 }; 219 220 } // namespace tensorflow 221 222 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 223