1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/build_graph_options.h" 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/common_runtime/device_set.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/graph/costmodel.h" 29 #include "tensorflow/core/graph/graph.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 struct SessionOptions; 36 37 namespace subgraph { 38 struct RewriteGraphMetadata; 39 } 40 41 struct GraphExecutionStateOptions { 42 const DeviceSet* device_set = nullptr; 43 const SessionOptions* session_options = nullptr; 44 // Unique session identifier. Can be empty. 45 string session_handle; 46 // A map from node name to device name, representing the unchangeable 47 // placement of stateful nodes. 48 std::unordered_map<string, string> stateful_placements; 49 }; 50 51 // A ClientGraph is simply a sub-graph of the full graph as induced by 52 // BuildGraphOptions. 53 struct ClientGraph { ClientGraphClientGraph54 explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, 55 DataTypeVector feed_types, DataTypeVector fetch_types, 56 int64 collective_graph_key) 57 : flib_def(std::move(flib)), 58 graph(flib_def.get()), 59 feed_types(std::move(feed_types)), 60 fetch_types(std::move(fetch_types)), 61 collective_graph_key(collective_graph_key) {} 62 // Each client-graph gets its own function library since optimization passes 63 // post rewrite for execution might want to introduce new functions. 64 std::unique_ptr<FunctionLibraryDefinition> flib_def; 65 Graph graph; 66 DataTypeVector feed_types; 67 DataTypeVector fetch_types; 68 int64 collective_graph_key; 69 }; 70 71 // GraphExecutionState is responsible for generating an 72 // executable ClientGraph from the original GraphDef that specifies 73 // the complete graph and from BuildGraphOptions which specifies 74 // input/output nodes. 75 // 76 // An executable Graph differs from a GraphDef by being Placed, 77 // meaning that each Node is assigned to a single Device in the 78 // available set. 79 // 80 // When GraphExecutionState is first constructed it instantiates 81 // a full Graph from the provided GraphDef, and places it, using only 82 // the static device assignments from the GraphDef. Nodes without are 83 // currently placed in a very naive way. Since stateful Nodes cannot 84 // be moved after initial placement, it is important that stateful 85 // Nodes get sensible initial device assignments in the graph 86 // definition. 87 // 88 // Subsequently, GraphExecutionState generates a SimpleClientGraph on 89 // demand, which is a sub-graph of the latest placement of the full 90 // Graph. MasterSession uses such a ClientGraph to execute one or 91 // more similar client requests. 92 // 93 // GraphExecutionState is thread-safe. 94 95 class GraphExecutionState { 96 public: 97 virtual ~GraphExecutionState(); 98 99 // Creates a new `GraphExecutionState` for the given 100 // `graph_def`, which represents the entire graph for a session. 101 // 102 // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def` 103 // in an undefined state. If it is necessary to use `*graph_def` 104 // after this call, make an explicit copy of the graph before 105 // calling this method. 106 static Status MakeForBaseGraph( 107 GraphDef* graph_def, const GraphExecutionStateOptions& options, 108 std::unique_ptr<GraphExecutionState>* out_state); 109 110 // Creates a new `GraphExecutionState` and `SimpleClientGraph` 111 // for the subgraph of `original_graph_def` defined by 112 // `subgraph_options`. 113 static Status MakeForPrunedGraph( 114 const FunctionDefLibrary& func_def_lib, 115 const GraphExecutionStateOptions& options, 116 const GraphDef& original_graph_def, 117 const BuildGraphOptions& subgraph_options, 118 std::unique_ptr<GraphExecutionState>* out_state, 119 std::unique_ptr<ClientGraph>* out_client_graph); 120 121 // Creates a new GraphExecutionState representing the 122 // concatenation of this graph, and the graph defined by 123 // "extension_def". The same name may not be used to define a node 124 // in both this graph and "extension_def". 125 // 126 // If successful, returns OK and the caller takes ownership of "*out". 127 // Otherwise returns an error and does not modify "*out". 128 // 129 // After calling `old_state->Extend()`, `old_state` may no longer be 130 // used. 131 // 132 // NOTE(mrry): This method respects the placement of stateful nodes in 133 // in *this, but currently does not transfer any other placement 134 // or cost model information to the new graph. 135 Status Extend(const GraphDef& extension_def, 136 std::unique_ptr<GraphExecutionState>* out) const; 137 138 // Builds a ClientGraph (a sub-graph of the full graph as induced by 139 // the Node set specified in "options"). If successful, returns OK 140 // and the caller takes the ownership of "*out". Otherwise, returns 141 // an error. 142 Status BuildGraph(const BuildGraphOptions& options, 143 std::unique_ptr<ClientGraph>* out); 144 145 // The graph returned by BuildGraph may contain only the pruned 146 // graph, whereas some clients may want access to the full graph. full_graph()147 const Graph* full_graph() { return graph_; } 148 149 // Returns the node with the given name, or null if it does not exist. get_node_by_name(const string & name)150 const Node* get_node_by_name(const string& name) const { 151 NodeNameToCostIdMap::const_iterator iter = 152 node_name_to_cost_id_map_.find(name); 153 if (iter != node_name_to_cost_id_map_.end()) { 154 return graph_->FindNodeId(iter->second); 155 } else { 156 return nullptr; 157 } 158 } 159 160 // Returns a reference to the current graph_def. Use must 161 // not extend beyond lifetime of GrahExecutionState object. original_graph_def()162 const GraphDef& original_graph_def() { return original_graph_def_; } 163 164 // Returns the map of stateful placements as a map of 165 // node name to placement string. GetStatefulPlacements()166 std::unordered_map<string, string> GetStatefulPlacements() const { 167 return stateful_placements_; 168 } 169 170 private: 171 GraphExecutionState(GraphDef* graph_def, 172 const GraphExecutionStateOptions& options); 173 174 Status InitBaseGraph(const BuildGraphOptions& options); 175 176 // Map of placed stateful nodes, i.e. nodes for which is_stateful() 177 // is true, such as "params" and "queue" nodes. Once placed these 178 // nodes can not be moved to a different device. Maps node names to 179 // device names. 180 std::unordered_map<string, string> stateful_placements_; // Immutable after 181 // ctor. 182 void SaveStatefulNodes(Graph* graph); 183 void RestoreStatefulNodes(Graph* graph); 184 185 // Extract the subset of the graph that needs to be run, adding feed/fetch 186 // ops as needed. 187 Status PruneGraph(const BuildGraphOptions& options, Graph* graph, 188 subgraph::RewriteGraphMetadata* out_rewrite_metadata); 189 190 Status OptimizeGraph( 191 const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph, 192 std::unique_ptr<FunctionLibraryDefinition>* optimized_flib); 193 194 GraphDef original_graph_def_; // Immutable after ctor. 195 const DeviceSet* device_set_; // Not owned 196 const SessionOptions* session_options_; // Not owned 197 // Unique session identifier. Can be empty. 198 string session_handle_; 199 200 // Map from name to Node for the full graph in placed_. 201 NodeNameToCostIdMap node_name_to_cost_id_map_; 202 203 // 'flib_def_' is initialized from the initial graph def's library, 204 // and may be updated by a graph optimization pass. 205 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 206 207 // `rewrite_metadata_` is only set for GraphExecutionState 208 // objects created by `MakeForPrunedGraph()`. 209 std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_; 210 211 // The dataflow graph owned by this object. 212 Graph* graph_; 213 214 TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); 215 }; 216 217 } // namespace tensorflow 218 219 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 220