1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ 17 #define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ 18 19 #include <unordered_map> 20 #include <unordered_set> 21 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/graph/graph_constructor.h" 24 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" 25 #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace tensorflow { 30 31 enum RemoteFusedGraphNodeType { 32 UNUSED = 0, 33 GRAPH_INPUT = 1, 34 GRAPH_OUTPUT = 2, 35 FUSED_NODE = 3, 36 BORDER_INPUT = 4, 37 BORDER_OUTPUT = 5, 38 }; 39 40 class RemoteFusedGraphExecuteInfo; 41 42 // RemoteFusedGraphExecuteUtils provides APIs to register and get builder 43 // functions for IRemoteFusedGraphExecutor. 44 class RemoteFusedGraphExecuteUtils { 45 public: 46 // TODO(satok): Use "_output_data_types" to share a spec with other ops 47 static constexpr const char* const ATTR_OUTPUT_DATA_TYPES = 48 "_default_remote_graph_output_data_types"; 49 // TODO(satok): Use "_output_shapes" to share a spec with other ops 50 static constexpr const char* const ATTR_OUTPUT_SHAPES = 51 "_default_remote_output_shapes"; 52 static constexpr const char* const 53 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO = 54 "serialized_remote_fused_graph_execute_info"; 55 static constexpr const char* const ATTR_NODE_TYPE = 56 "_remote_fused_graph_node_type"; 57 58 // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp. 59 static constexpr const char* const 60 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME = 61 "remote_fused_graph_executor_name"; 62 static constexpr const char* const 63 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME = 64 "remote_fused_graph_node_name"; 65 static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes"; 66 static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS = 67 "border_inputs"; 68 static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS = 69 "border_outputs"; 70 static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES = 71 "fused_op_types"; 72 static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR = 73 "fuse_by_executor"; 74 static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types"; 75 static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES = 76 "input_shapes"; 77 78 using ExecutorBuildFunc = std::function<Status( 79 std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>; 80 // Registrar class for IRemoteFusedGraphExecutor. 81 class ExecutorBuildRegistrar { 82 public: 83 ExecutorBuildRegistrar(const string& name, ExecutorBuildFunc func); 84 85 private: 86 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBuildRegistrar); 87 }; 88 using ExecutorBuildRegistry = std::map<string, ExecutorBuildFunc>; 89 90 using TensorShapeType = std::pair<DataType, TensorShape>; 91 using TensorShapeMap = std::unordered_multimap<string, // node name 92 std::pair<int, // port 93 TensorShapeType>>; 94 using ClusterInfo = std::tuple<std::unordered_set<string>, // node names 95 std::vector<string>, // border inputs 96 std::vector<string>>; // border outputs 97 98 // Return registered ExecutorBuildFunc for given name. 99 static const ExecutorBuildFunc* GetExecutorBuildFunc(const string& name); 100 101 // To determine shapes of output tensors of all nodes, dryrun the graph. 102 // This function supplies memory allocation information when loading 103 // the graph. This function is used to verify shape inference and actual 104 // output shape. 105 static Status DryRunInference( 106 const GraphDef& graph_def, 107 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 108 const std::vector<string>& output_node_names, 109 const bool initialize_by_zero, 110 std::vector<tensorflow::Tensor>* output_tensors); 111 112 // Dry run inference to obtain shapes for all nodes. 113 // CAVEAT: Do not add or modify output_tensors in output_tensor_info 114 // otherwise, address map may be broken by re-allocation inside 115 // std::vector. 116 static Status DryRunInferenceForAllNode( 117 const GraphDef& graph_def, 118 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 119 const bool initialize_by_zero, TensorShapeMap* tensor_shape_map); 120 121 static bool IsInputNode( 122 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 123 const string& node_name); 124 125 static void ConvertToTensorShapeMap( 126 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 127 const std::vector<string>& output_node_names, 128 const std::vector<tensorflow::Tensor>& output_tensors, 129 TensorShapeMap* tensor_shape_map); 130 131 static Status MakeTensorFromProto(const TensorProto& tensor_proto, 132 Tensor* tensor); 133 134 static bool AddOutputTensorShapeType(const std::vector<DataType>& data_types, 135 const std::vector<TensorShape>& shapes, 136 NodeDef* node_def); 137 138 static Status AddOutputTensorShapeTypeByTensorShapeMap( 139 const TensorShapeMap& tensor_shape_map, NodeDef* node_def); 140 141 static Status GetOutputTensorShapeType(AttrSlice attrs, 142 std::vector<DataType>* data_types, 143 std::vector<TensorShape>* shapes); 144 145 static bool GetOutputTensorShapeType(const GraphDef& graph_def, 146 const string& name_and_port, 147 DataType* data_type, TensorShape* shape); 148 149 static Status PropagateShapeInference( 150 const GraphDef& graph_def, 151 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 152 Graph* graph, ShapeRefiner* shape_refiner); 153 154 static Status BuildTensorShapeMapFromGraph(const Graph& graph, 155 const ShapeRefiner& shape_refiner, 156 TensorShapeMap* tensor_shape_map); 157 158 static const TensorShapeType* GetTensorShapeType( 159 const TensorShapeMap& tensor_shape_map, const string& node_name); 160 161 static const TensorShapeType* GetTensorShapeType( 162 const TensorShapeMap& tensor_shape_map, const string& node_name, 163 const int port); 164 165 static void BuildRemoteGraphInputsAndOutputsFromProto( 166 const RemoteFusedGraphExecuteInfo& proto, 167 std::vector<std::pair<string, Tensor>>* inputs, 168 std::vector<string>* outputs); 169 170 static Status BuildAndAddTensorShapes( 171 const std::vector<std::pair<string, Tensor>>& input_tensors, 172 const bool dry_run_inference, GraphDef* graph_def); 173 174 // Build remote fused graph execute info. 175 static Status BuildRemoteFusedGraphExecuteInfo( 176 const string& executor_name, const GraphDef& subgraph_def, 177 const std::vector<string>& inputs, const std::vector<string>& outputs, 178 const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info, 179 DataTypeVector* input_types, DataTypeVector* output_types); 180 181 // Build remote fused graph execute op node by fusing specified subgraph 182 // as remote fused graph execute info. 183 static Status BuildRemoteFusedGraphExecuteOpNode( 184 const string& node_name, const string& executor_name, 185 const GraphDef& subgraph_def, const std::vector<string>& inputs, 186 const std::vector<string>& outputs, const bool require_shape_type, 187 Graph* graph, Node** created_node); 188 189 // Build Identity node to forward remote graph node output. 190 static Status BuildIdentityOpNode(const string& node_name, 191 const string& input_node_name, 192 const int input_node_port, 193 const DataType dt, Graph* graph, 194 Node** created_node); 195 196 // Create clusters of given nodes. 197 static Status ClusterizeNodes(const std::unordered_set<string>& node_names, 198 const GraphDef& graph_def, 199 std::vector<ClusterInfo>* cluster_infos); 200 201 // Build GraphDef of a given cluster. 202 static Status BuildClusterSubgraphDef(const ClusterInfo& cluster, 203 const GraphDef& graph_def, 204 GraphDef* subgraph_def); 205 206 // Build a cluster by given border. 207 // CAVEAT: The border must be consistent for one cluster. 208 static Status BuildClusterByBorder(const std::vector<string>& border_inputs, 209 const std::vector<string>& border_outputs, 210 const GraphDef& graph_def, 211 ClusterInfo* cluster); 212 213 // Fuse one cluster into a newly created RemoteFusedGraphExecuteOp node. 214 // The subgraph is stored as a graph in RemoteFusedGraphExecuteInfo. 215 // CAVEAT1: This transform strips unvisited nodes with given outputs. 216 // CAVEAT2: If you want to use a graph output as a border output, 217 // that graph output node is replaced by an identity node. Therefore, 218 // the number of output of the node must be 1. 219 static Status FuseCluster(const GraphDef& input_graph_def, 220 const std::vector<string>& inputs, 221 const std::vector<string>& outputs, 222 const string& remote_fused_graph_node_name, 223 const ClusterInfo& cluster, 224 const string& remote_graph_executor_name, 225 const bool require_shape_type, 226 GraphDef* output_graph_def); 227 228 // Fuse subgraph of specified nodes. 229 static Status FuseRemoteGraphByNodeNames( 230 const GraphDef& input_graph_def, const std::vector<string>& inputs, 231 const std::vector<string>& outputs, 232 const string& remote_fused_graph_node_name_prefix, 233 const std::unordered_set<string>& subgraph_nodes, 234 const string& remote_fused_graph_executor_name, 235 const bool require_shape_type, GraphDef* output_graph_def); 236 237 // Fuse subgraph of specified border. 238 static Status FuseRemoteGraphByBorder( 239 const GraphDef& input_graph_def, const std::vector<string>& inputs, 240 const std::vector<string>& outputs, 241 const string& remote_fused_graph_node_name, 242 const std::vector<string>& border_inputs, 243 const std::vector<string>& border_outputs, 244 const string& remote_graph_executor_name, const bool require_shape_type, 245 GraphDef* output_graph_def); 246 247 // Fuse subgraph of specified op types. 248 static Status FuseRemoteGraphByOpTypes( 249 const GraphDef& input_graph_def, const std::vector<string>& inputs, 250 const std::vector<string>& outputs, 251 const string& remote_fused_graph_node_name_prefix, 252 const std::unordered_set<string>& fused_op_types, 253 const string& remote_fused_graph_executor_name, 254 const bool require_shape_type, GraphDef* output_graph_def); 255 256 // Place arguments to fuse remote graph. 257 static Status PlaceRemoteGraphArguments( 258 const std::vector<string>& inputs, const std::vector<string>& outputs, 259 const std::unordered_set<string>& fused_node_names, 260 const std::vector<string>& border_inputs, 261 const std::vector<string>& border_outputs, 262 const std::unordered_set<string>& fused_op_types, 263 const string& remote_fused_graph_node_name, 264 const string& remote_graph_executor_name, GraphDef* graph_def); 265 266 // Fuse remote graph by placed arguments. 267 static Status FuseRemoteGraphByPlacedArguments( 268 const GraphDef& input_graph_def, 269 const std::vector<std::pair<string, Tensor>>& input_tensors, 270 GraphDef* output_graph_def); 271 272 static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def, 273 const std::vector<string>& inputs, 274 const std::vector<string>& outputs, 275 const string& executor_name, 276 GraphDef* output_graph_def); 277 278 static bool IsFuseReady( 279 const GraphDef& input_graph_def, 280 const std::vector<std::pair<string, Tensor>>& input_tensors); 281 282 // Copy a byte array to a tensor data. Though tensor data must be 283 // updated with typed information in general, we can't guarantee that 284 // returned values from a remote processor has typed information because 285 // a logic running in the remote processor possibly be in a separate binary 286 // which may not link tensorflow libraries. To deal with this situation, 287 // remote fused graph needs to overwrite the tensor data by a byte array. 288 static Status CopyByteArrayToTensor(const void* src_ptr, const int src_size, 289 Tensor* tensor); 290 291 static std::unordered_set<string> BuildNodeMapFromOpTypes( 292 const GraphDef& graph_def, const std::unordered_set<string>& op_types); 293 294 static std::unordered_set<string> BuildNodeMapFromOpsDefinitions( 295 const GraphDef& graph_def, 296 const IRemoteFusedGraphOpsDefinitions& ops_definitions); 297 298 private: 299 static void EmplaceTensorShapeType(const string& name, const Tensor& tensor, 300 TensorShapeMap* tensor_shape_map); 301 302 static Status ReplaceInputNodeByPlaceHolder(const string& input, 303 const DataType type, 304 const TensorShape& shape, 305 GraphDef* graph_def); 306 307 static ExecutorBuildRegistry* GetExecutorBuildRegistry(); 308 309 static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type, 310 const int port, const int index, 311 const string& executor_name, 312 const string& node_name); 313 314 static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type, 315 const int port, const int index); 316 317 static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type); 318 319 TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils); 320 }; 321 } // namespace tensorflow 322 323 #endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ 324