• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
17 #define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
28 
29 namespace tensorflow {
30 
// Classification of a node with respect to a remote fused graph.
// The numeric values are spelled out explicitly on purpose: they are
// presumably encoded into the ATTR_NODE_TYPE attribute by
// RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr, so they must stay stable.
// NOTE(review): these are unscoped enumerators living directly in namespace
// tensorflow (UNUSED in particular is a collision-prone name); switching to
// `enum class` would fix that but would break existing callers.
enum RemoteFusedGraphNodeType {
  UNUSED = 0,         // Not part of the remote fused graph.
  GRAPH_INPUT = 1,    // An input of the graph.
  GRAPH_OUTPUT = 2,   // An output of the graph.
  FUSED_NODE = 3,     // A node fused into the remote graph.
  BORDER_INPUT = 4,   // A node on the input border of a fused cluster.
  BORDER_OUTPUT = 5   // A node on the output border of a fused cluster.
};
39 
40 class RemoteFusedGraphExecuteInfo;
41 
42 // RemoteFusedGraphExecuteUtils provides APIs to register and get builder
43 // functions for IRemoteFusedGraphExecutor.
44 class RemoteFusedGraphExecuteUtils {
45  public:
46   // TODO(satok): Use "_output_data_types" to share a spec with other ops
47   static constexpr const char* const ATTR_OUTPUT_DATA_TYPES =
48       "_default_remote_graph_output_data_types";
49   // TODO(satok): Use "_output_shapes" to share a spec with other ops
50   static constexpr const char* const ATTR_OUTPUT_SHAPES =
51       "_default_remote_output_shapes";
52   static constexpr const char* const
53       ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO =
54           "serialized_remote_fused_graph_execute_info";
55   static constexpr const char* const ATTR_NODE_TYPE =
56       "_remote_fused_graph_node_type";
57 
58   // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp.
59   static constexpr const char* const
60       TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
61           "remote_fused_graph_executor_name";
62   static constexpr const char* const
63       TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME =
64           "remote_fused_graph_node_name";
65   static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes";
66   static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS =
67       "border_inputs";
68   static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS =
69       "border_outputs";
70   static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
71       "fused_op_types";
72   static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
73       "fuse_by_executor";
74   static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
75   static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
76       "input_shapes";
77 
78   using ExecutorBuildFunc = std::function<Status(
79       std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>;
80   // Registrar class for IRemoteFusedGraphExecutor.
81   class ExecutorBuildRegistrar {
82    public:
83     ExecutorBuildRegistrar(const string& name, ExecutorBuildFunc func);
84 
85    private:
86     TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBuildRegistrar);
87   };
88   using ExecutorBuildRegistry = std::map<string, ExecutorBuildFunc>;
89 
90   using TensorShapeType = std::pair<DataType, TensorShape>;
91   using TensorShapeMap = std::unordered_multimap<string,         // node name
92                                                  std::pair<int,  // port
93                                                            TensorShapeType>>;
94   using ClusterInfo = std::tuple<std::unordered_set<string>,  // node names
95                                  std::vector<string>,         // border inputs
96                                  std::vector<string>>;        // border outputs
97 
98   // Return registered ExecutorBuildFunc for given name.
99   static const ExecutorBuildFunc* GetExecutorBuildFunc(const string& name);
100 
101   // To determine shapes of output tensors of all nodes, dryrun the graph.
102   // This function supplies memory allocation information when loading
103   // the graph. This function is used to verify shape inference and actual
104   // output shape.
105   static Status DryRunInference(
106       const GraphDef& graph_def,
107       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
108       const std::vector<string>& output_node_names,
109       const bool initialize_by_zero,
110       std::vector<tensorflow::Tensor>* output_tensors);
111 
112   // Dry run inference to obtain shapes for all nodes.
113   // CAVEAT: Do not add or modify output_tensors in output_tensor_info
114   // otherwise, address map may be broken by re-allocation inside
115   // std::vector.
116   static Status DryRunInferenceForAllNode(
117       const GraphDef& graph_def,
118       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
119       const bool initialize_by_zero, TensorShapeMap* tensor_shape_map);
120 
121   static bool IsInputNode(
122       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
123       const string& node_name);
124 
125   static void ConvertToTensorShapeMap(
126       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
127       const std::vector<string>& output_node_names,
128       const std::vector<tensorflow::Tensor>& output_tensors,
129       TensorShapeMap* tensor_shape_map);
130 
131   static Status MakeTensorFromProto(const TensorProto& tensor_proto,
132                                     Tensor* tensor);
133 
134   static bool AddOutputTensorShapeType(const std::vector<DataType>& data_types,
135                                        const std::vector<TensorShape>& shapes,
136                                        NodeDef* node_def);
137 
138   static Status AddOutputTensorShapeTypeByTensorShapeMap(
139       const TensorShapeMap& tensor_shape_map, NodeDef* node_def);
140 
141   static Status GetOutputTensorShapeType(AttrSlice attrs,
142                                          std::vector<DataType>* data_types,
143                                          std::vector<TensorShape>* shapes);
144 
145   static bool GetOutputTensorShapeType(const GraphDef& graph_def,
146                                        const string& name_and_port,
147                                        DataType* data_type, TensorShape* shape);
148 
149   static Status PropagateShapeInference(
150       const GraphDef& graph_def,
151       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
152       Graph* graph, ShapeRefiner* shape_refiner);
153 
154   static Status BuildTensorShapeMapFromGraph(const Graph& graph,
155                                              const ShapeRefiner& shape_refiner,
156                                              TensorShapeMap* tensor_shape_map);
157 
158   static const TensorShapeType* GetTensorShapeType(
159       const TensorShapeMap& tensor_shape_map, const string& node_name);
160 
161   static const TensorShapeType* GetTensorShapeType(
162       const TensorShapeMap& tensor_shape_map, const string& node_name,
163       const int port);
164 
165   static void BuildRemoteGraphInputsAndOutputsFromProto(
166       const RemoteFusedGraphExecuteInfo& proto,
167       std::vector<std::pair<string, Tensor>>* inputs,
168       std::vector<string>* outputs);
169 
170   static Status BuildAndAddTensorShapes(
171       const std::vector<std::pair<string, Tensor>>& input_tensors,
172       const bool dry_run_inference, GraphDef* graph_def);
173 
174   // Build remote fused graph execute info.
175   static Status BuildRemoteFusedGraphExecuteInfo(
176       const string& executor_name, const GraphDef& subgraph_def,
177       const std::vector<string>& inputs, const std::vector<string>& outputs,
178       const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
179       DataTypeVector* input_types, DataTypeVector* output_types);
180 
181   // Build remote fused graph execute op node by fusing specified subgraph
182   // as remote fused graph execute info.
183   static Status BuildRemoteFusedGraphExecuteOpNode(
184       const string& node_name, const string& executor_name,
185       const GraphDef& subgraph_def, const std::vector<string>& inputs,
186       const std::vector<string>& outputs, const bool require_shape_type,
187       Graph* graph, Node** created_node);
188 
189   // Build Identity node to forward remote graph node output.
190   static Status BuildIdentityOpNode(const string& node_name,
191                                     const string& input_node_name,
192                                     const int input_node_port,
193                                     const DataType dt, Graph* graph,
194                                     Node** created_node);
195 
196   // Create clusters of given nodes.
197   static Status ClusterizeNodes(const std::unordered_set<string>& node_names,
198                                 const GraphDef& graph_def,
199                                 std::vector<ClusterInfo>* cluster_infos);
200 
201   // Build GraphDef of a given cluster.
202   static Status BuildClusterSubgraphDef(const ClusterInfo& cluster,
203                                         const GraphDef& graph_def,
204                                         GraphDef* subgraph_def);
205 
206   // Build a cluster by given border.
207   // CAVEAT: The border must be consistent for one cluster.
208   static Status BuildClusterByBorder(const std::vector<string>& border_inputs,
209                                      const std::vector<string>& border_outputs,
210                                      const GraphDef& graph_def,
211                                      ClusterInfo* cluster);
212 
213   // Fuse one cluster into a newly created RemoteFusedGraphExecuteOp node.
214   // The subgraph is stored as a graph in RemoteFusedGraphExecuteInfo.
215   // CAVEAT1: This transform strips unvisited nodes with given outputs.
216   // CAVEAT2: If you want to use a graph output as a border output,
217   // that graph output node is replaced by an identity node.  Therefore,
218   // the number of output of the node must be 1.
219   static Status FuseCluster(const GraphDef& input_graph_def,
220                             const std::vector<string>& inputs,
221                             const std::vector<string>& outputs,
222                             const string& remote_fused_graph_node_name,
223                             const ClusterInfo& cluster,
224                             const string& remote_graph_executor_name,
225                             const bool require_shape_type,
226                             GraphDef* output_graph_def);
227 
228   // Fuse subgraph of specified nodes.
229   static Status FuseRemoteGraphByNodeNames(
230       const GraphDef& input_graph_def, const std::vector<string>& inputs,
231       const std::vector<string>& outputs,
232       const string& remote_fused_graph_node_name_prefix,
233       const std::unordered_set<string>& subgraph_nodes,
234       const string& remote_fused_graph_executor_name,
235       const bool require_shape_type, GraphDef* output_graph_def);
236 
237   // Fuse subgraph of specified border.
238   static Status FuseRemoteGraphByBorder(
239       const GraphDef& input_graph_def, const std::vector<string>& inputs,
240       const std::vector<string>& outputs,
241       const string& remote_fused_graph_node_name,
242       const std::vector<string>& border_inputs,
243       const std::vector<string>& border_outputs,
244       const string& remote_graph_executor_name, const bool require_shape_type,
245       GraphDef* output_graph_def);
246 
247   // Fuse subgraph of specified op types.
248   static Status FuseRemoteGraphByOpTypes(
249       const GraphDef& input_graph_def, const std::vector<string>& inputs,
250       const std::vector<string>& outputs,
251       const string& remote_fused_graph_node_name_prefix,
252       const std::unordered_set<string>& fused_op_types,
253       const string& remote_fused_graph_executor_name,
254       const bool require_shape_type, GraphDef* output_graph_def);
255 
256   // Place arguments to fuse remote graph.
257   static Status PlaceRemoteGraphArguments(
258       const std::vector<string>& inputs, const std::vector<string>& outputs,
259       const std::unordered_set<string>& fused_node_names,
260       const std::vector<string>& border_inputs,
261       const std::vector<string>& border_outputs,
262       const std::unordered_set<string>& fused_op_types,
263       const string& remote_fused_graph_node_name,
264       const string& remote_graph_executor_name, GraphDef* graph_def);
265 
266   // Fuse remote graph by placed arguments.
267   static Status FuseRemoteGraphByPlacedArguments(
268       const GraphDef& input_graph_def,
269       const std::vector<std::pair<string, Tensor>>& input_tensors,
270       GraphDef* output_graph_def);
271 
272   static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
273                                           const std::vector<string>& inputs,
274                                           const std::vector<string>& outputs,
275                                           const string& executor_name,
276                                           GraphDef* output_graph_def);
277 
278   static bool IsFuseReady(
279       const GraphDef& input_graph_def,
280       const std::vector<std::pair<string, Tensor>>& input_tensors);
281 
282   // Copy a byte array to a tensor data.  Though tensor data must be
283   // updated with typed information in general, we can't guarantee that
284   // returned values from a remote processor has typed information because
285   // a logic running in the remote processor possibly be in a separate binary
286   // which may not link tensorflow libraries.  To deal with this situation,
287   // remote fused graph needs to overwrite the tensor data by a byte array.
288   static Status CopyByteArrayToTensor(const void* src_ptr, const int src_size,
289                                       Tensor* tensor);
290 
291   static std::unordered_set<string> BuildNodeMapFromOpTypes(
292       const GraphDef& graph_def, const std::unordered_set<string>& op_types);
293 
294   static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
295       const GraphDef& graph_def,
296       const IRemoteFusedGraphOpsDefinitions& ops_definitions);
297 
298  private:
299   static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
300                                      TensorShapeMap* tensor_shape_map);
301 
302   static Status ReplaceInputNodeByPlaceHolder(const string& input,
303                                               const DataType type,
304                                               const TensorShape& shape,
305                                               GraphDef* graph_def);
306 
307   static ExecutorBuildRegistry* GetExecutorBuildRegistry();
308 
309   static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,
310                                   const int port, const int index,
311                                   const string& executor_name,
312                                   const string& node_name);
313 
314   static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,
315                                   const int port, const int index);
316 
317   static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type);
318 
319   TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils);
320 };
321 }  // namespace tensorflow
322 
323 #endif  // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
324