• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
17 #define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
18 
19 #include <array>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/shape_refiner.h"
25 #include "tensorflow/core/framework/shape_inference.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
28 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/util/padding.h"
32 
33 namespace tensorflow {
34 
35 class GraphTransferInfo;
36 class GraphTransferNodeInfo;
37 class GraphTransferNodeInputInfo;
38 
39 // GraphTransferer transfers graph definitions into SoC memory.
40 // This functionality is effective if SoC is capable to run
41 // the graph on that chip.
42 // TODO(satok): support transferring subgraphs to be able to split graphs
43 // to avoid unsupported ops in SoC.
44 class GraphTransferer {
45  public:
46   // TODO(satok): Remove. Use proto definition instead.
47   static constexpr int MAX_SUPPORTED_RANK = 4;
48   // TODO(satok): Remove. Use proto definition instead.
49   static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
50   using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
51 
52   GraphTransferer();
53 
54   ~GraphTransferer();
55 
56   // Load graph structure into GraphTransferer
57   // TODO(satok): Pass a pair of TensorShape and DataType instead of
58   // Tensor as input_node_info_list.
59   Status LoadGraphFromProto(
60       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
61       const GraphDef& graph_def,
62       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
63       const std::vector<string>& output_node_names,
64       const bool shape_inference_for_unknown_shape);
65 
66   // Load graph structure into GraphTransferer from protobuf file
67   // TODO(satok): Pass a pair of TensorShape and DataType instead of
68   // Tensor as input_node_info_list.
69   Status LoadGraphFromProtoFile(
70       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
71       const string& graph_def_path,
72       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
73       const std::vector<string>& output_node_names, const bool is_text_proto,
74       const bool shape_inference_for_unknown_shape,
75       const bool dry_run_for_unknown_shape);
76 
77   // Sort params so that all input nodes appear before consumer nodes.
78   // CAVEAT: This may be slow if the number of nodes are too large
79   void SortParams(const std::vector<string>& output_node_names);
80 
81   void EnableStrictCheckMode(bool enable);
82 
83   // Import parameters for transfer
84   void SetSerializedGraphTransferInfo(const string& serialized_proto);
85 
86   // Return parameters for graph transfer
87   const GraphTransferInfo& GetGraphTransferInfo() const;
88 
89   // Return mutable GraphTransferInfo for graph transfer
90   GraphTransferInfo& GetMutableGraphTransferInfo();
91 
92   // Dump verification string of parameters to verify with offline tools
93   void DumpVerificationStringOfNodeTransferParams() const;
94 
95   static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray(
96       const TensorShape& shape);
97 
98  private:
99   class TransferParamsComparator {
100    public:
101     TransferParamsComparator(
102         const std::unordered_map<int, std::unordered_set<int>>& dep_map);
103     bool operator()(const GraphTransferNodeInfo& obj0,
104                     const GraphTransferNodeInfo& obj1);
105     const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
106   };
107 
108   void CacheNode(const Node& node);
109 
110   bool AreAllInputsCached(const Node& node) const;
111 
112   // Transform a remote fused graph to add an aggregated input node which takes
113   // all inputs of the remote graph.
114   Status TransformGraphToAddAggregatedInputNode(
115       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
116       Graph* graph, ShapeRefiner* shape_refiner);
117 
118   Status RegisterNode(
119       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
120       const ShapeRefiner& shape_refiner, const Node& node,
121       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
122       const std::vector<string>& output_node_names);
123 
124   void RegisterConstantNode(const ShapeRefiner& shape_refiner,
125                             const Node& node);
126 
127   int RegisterConstantShape(const std::vector<int>& shape);
128 
129   int RegisterConstTensor(const Tensor& tensor, const string& suffix);
130 
131   int RegisterConstScalar(const DataType dt, const int val, const int dst_id,
132                           const int dst_input_count);
133 
134   bool HasPaddingAndStrides(const Node& node);
135 
136   bool NeedsToAddRank(const Node& node);
137 
138   bool IsPadNode(const Node& node);
139 
140   // Return true if the node is a reshape op which just flattens input
141   // TODO(satok): Remove this method once generic reshape op is implemented in
142   // SOC
143   bool IsNodeFlattenReshape(const Node& node,
144                             const ShapeRefiner& shape_refiner);
145 
146   void RegisterNodeWithPaddingAndStrides(
147       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
148       const ShapeRefiner& shape_refiner, const Node& node);
149 
150   void RegisterNodeWithRank(
151       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
152       const ShapeRefiner& shape_refiner, const Node& node);
153 
154   void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
155                        const ShapeRefiner& shape_refiner, const Node& node);
156 
157   void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
158                          const ShapeRefiner& shape_refiner, const Node& node);
159 
160   void RegisterFlattenNode(
161       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
162       const ShapeRefiner& shape_refiner, const Node& node);
163 
164   void RegisterGenericNode(
165       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
166       const ShapeRefiner& shape_refiner, const Node& node);
167 
168   Status RegisterNodeIfAllInputsAreCached(
169       const IRemoteFusedGraphOpsDefinitions& ops_definitions,
170       const ShapeRefiner& shape_refiner, const Node& node,
171       const bool only_register_const_node,
172       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
173       const std::vector<string>& output_node_names);
174 
175   void AppendNodeParams(const string& name, const int id, const string& type,
176                         const int type_id, const int padding,
177                         const int inputs_size,
178                         const std::vector<int>& extra_inputs,
179                         const int outputs_size);
180 
181   void AddNodeInputByInputIndex(const Node& node, const int idx,
182                                 GraphTransferNodeInputInfo* node_input_info);
183 
184   void AppendNodeInputParams(const int id, const Node& node,
185                              const std::vector<int>& extra_inputs);
186 
187   void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
188                               const Node& node);
189 
190   static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
191       const shape_inference::ShapeHandle& shape_handle,
192       shape_inference::InferenceContext* context);
193 
194   void AppendNodeParamsWithIoParams(
195       const ShapeRefiner& shape_refiner, const Node& node, const string& name,
196       const int id, const string& type, const int type_id, const int padding,
197       const int inputs_size, const std::vector<int>& extra_inputs,
198       const int outputs_size, const bool append_input_params,
199       const bool append_output_params);
200 
201   static string ToPaddingDebugString(int padding);
202 
203   // Create dependency map
204   static void FillDependencyRec(
205       int node_id, std::unordered_map<int, std::unordered_set<int>>& dep_map,
206       std::unordered_set<int>& completed);
207 
208   // Build tensor from proto
209   static Status MakeTensorFromProto(const TensorProto& tensor_proto,
210                                     Tensor* tensor);
211 
212   void ClearCache();
213 
214   // Dump pretty print of parameters
215   void DumpNodeTransferParams() const;
216 
217   GraphTransferInfo* graph_transfer_info_;
218 
219   std::vector<const Node*> node_name_cache_list_{};
220   std::unordered_map<string, int> node_name_to_id_cache_map_{};
221 
222   // strict check mode is true by default.  Disable this if the ops' shape
223   // inferences are not implemented correctly.
224   bool strict_check_mode_{true};
225 
226   TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer);
227 };
228 
229 }  // namespace tensorflow
230 
231 #endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
232