• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
17 #include "base/core_ops.h"
18 #include "ir/graph_utils.h"
19 #include "debug/common.h"
20 #include "backend/kernel_compiler/common_utils.h"
21 #include "backend/session/anf_runtime_algorithm.h"
22 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
23 #include "backend/optimizer/graph_kernel/model/lite_graph.h"
24 #include "backend/optimizer/graph_kernel/model/op_register.h"
25 
26 namespace mindspore {
27 namespace opt {
28 namespace {
enum FormatType { kFormatUnknown, kFormatA, kFormatB };
enum TransOpType { kTransAB, kTransBA };

// A directed edge of the flow network. Its residual (inverse) edge is stored
// adjacently in the edge vector; see MinCut::AddEdge for the pairing scheme.
struct Edge {
  size_t to;
  size_t capacity;
};

struct Vertex {
  // After MinCut::Run, kFormatA marks the source side of the minimal cut.
  FormatType format{kFormatB};
  // BFS level used by Dinic's algorithm; 0 means "not reached".
  size_t depth{0};
  // Indices into MinCut::edges_ of the outgoing (and residual) edges.
  std::vector<size_t> out_edges;
};

constexpr size_t INF = static_cast<size_t>(1) << 30;

// Models the format-assignment problem as a minimal cut and solves it with
// Dinic's max-flow algorithm. Each original node is split into an input part
// (id i) and an output part (id i + origin_nodes_num_) joined by a unit-capacity
// edge; the source is linked to FormatA nodes and FormatB nodes to the sink,
// so the min cut is the minimal set of transform ops to insert.
class MinCut {
 private:
  // Add the bidirectional edges for the vertex `from` and `to`.
  // the two edge ids are adjacent in vector, x and x+1 (x are 0,2,4,...)
  // we can use (i xor 1) to get the inverse edge for any edge i.
  // e.g. edge_0 and edge_1 are a couple, 0^1=1, 1^1=0.
  void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
    (void)edges_.emplace_back(Edge{to, capacity});
    (void)nodes_[from].out_edges.emplace_back(edges_.size() - 1);
    // inverse edge
    (void)edges_.emplace_back(Edge{from, inv_capacity});
    (void)nodes_[to].out_edges.emplace_back(edges_.size() - 1);
  }

  // BFS phase of Dinic: assign levels starting from the source through edges
  // with remaining capacity. Returns true while the sink is still reachable.
  bool BfsSetDepth() {
    std::queue<size_t> bfs_queue;
    for (auto &node : nodes_) {
      node.depth = 0;
    }
    nodes_[source_id_].depth = 1;
    bfs_queue.push(source_id_);
    while (!bfs_queue.empty()) {
      auto edge_from = bfs_queue.front();
      bfs_queue.pop();
      for (auto e_id : nodes_[edge_from].out_edges) {
        auto edge_to = edges_[e_id].to;
        if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
          nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
          bfs_queue.push(edge_to);
        }
      }
    }
    return nodes_[sink_id_].depth > 0;
  }

  // DFS phase of Dinic: push blocking flow along level-increasing edges.
  // Updates residual capacities via the (e_id ^ 1) inverse-edge pairing.
  size_t DfsMaxFlow(size_t node, size_t flow) {
    if (node == sink_id_) return flow;
    size_t max_flow = 0;
    for (size_t e_id : nodes_[node].out_edges) {
      if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
        auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
        if (tmp_flow > 0) {
          max_flow += tmp_flow;
          flow -= tmp_flow;
          edges_[e_id].capacity -= tmp_flow;
          edges_[e_id ^ 1].capacity += tmp_flow;
        }
      }
    }
    return max_flow;
  }

  void Dinic() {
    while (BfsSetDepth()) {
      (void)DfsMaxFlow(source_id_, INF);
    }
  }

  // Mark every vertex reachable from `node_id` through edges with remaining
  // capacity as kFormatA (the source side of the cut). Implemented as an
  // iterative DFS: the previous recursive version could overflow the call
  // stack on large graphs (recursion depth was bounded only by graph size).
  void SetFormat(size_t node_id) {
    std::vector<size_t> todo_stack;
    nodes_[node_id].format = kFormatA;
    todo_stack.push_back(node_id);
    while (!todo_stack.empty()) {
      auto cur = todo_stack.back();
      todo_stack.pop_back();
      for (size_t i : nodes_[cur].out_edges) {
        auto next = edges_[i].to;
        if (edges_[i].capacity > 0 && nodes_[next].format != kFormatA) {
          nodes_[next].format = kFormatA;
          todo_stack.push_back(next);
        }
      }
    }
  }

  void BuildGraph(const std::vector<std::pair<size_t, FormatType>> &original_nodes) {
    for (size_t i = 0; i < origin_nodes_num_; ++i) {
      // link the source node to the nodes with FormatA,
      // link the nodes with FormatB to the sink node.
      if (original_nodes[i].second == kFormatA) {
        AddEdge(source_id_, original_nodes[i].first, INF, 0);
      } else if (original_nodes[i].second == kFormatB) {
        AddEdge(original_nodes[i].first, sink_id_, INF, 0);
      }
      // each nodes was split into two part, input part and output part.
      // the input part's id is the original node's id, the output part's id is input id + origin_nodes_num_.
      AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
    }
    for (auto e : original_edges_) {
      auto from = e.first, to = e.second;
      AddEdge(from + origin_nodes_num_, to, 1, 1);
    }
  }

 public:
  // `original_nodes`: (node id, format type) pairs; ids are expected to be
  // 1..n (id 0 is the source, 2n+1 the sink).
  // `original_edges`: (from id, to id) pairs between original nodes.
  MinCut(const std::vector<std::pair<size_t, FormatType>> &original_nodes,
         const std::vector<std::pair<size_t, size_t>> &original_edges)
      : origin_nodes_num_(original_nodes.size()),
        sink_id_(2 * original_nodes.size() + 1),  // source_id_ is 0
        nodes_(2 * original_nodes.size() + 2),    // double nodes, and source_node/sink_node
        original_edges_(original_edges) {
    BuildGraph(original_nodes);
  }
  ~MinCut() = default;

  void Run() {
    Dinic();
    SetFormat(source_id_);
  }

  // Transform ops to insert on a node's own output edge: the node's input
  // part and output part ended up on different sides of the cut.
  std::vector<std::pair<size_t, TransOpType>> GetOneNodeOps() const {
    std::vector<std::pair<size_t, TransOpType>> one_node_ops;
    for (size_t i = 1; i <= origin_nodes_num_; ++i) {
      if (nodes_[i].format == kFormatA && nodes_[i + origin_nodes_num_].format != kFormatA) {
        (void)one_node_ops.emplace_back(i, kTransAB);
      } else if (nodes_[i].format != kFormatA && nodes_[i + origin_nodes_num_].format == kFormatA) {
        (void)one_node_ops.emplace_back(i, kTransBA);
      }
    }
    return one_node_ops;
  }

  // Transform ops to insert between two nodes: the producer's output part and
  // the consumer's input part ended up on different sides of the cut.
  std::vector<std::pair<std::pair<size_t, size_t>, TransOpType>> GetTwoNodeOps() const {
    std::vector<std::pair<std::pair<size_t, size_t>, TransOpType>> two_node_ops;
    for (auto i : original_edges_) {
      if (nodes_[i.first + origin_nodes_num_].format == kFormatA && nodes_[i.second].format != kFormatA) {
        (void)two_node_ops.emplace_back(i, kTransAB);
      } else if (nodes_[i.first + origin_nodes_num_].format != kFormatA && nodes_[i.second].format == kFormatA) {
        (void)two_node_ops.emplace_back(i, kTransBA);
      }
    }
    return two_node_ops;
  }

 private:
  size_t origin_nodes_num_;
  size_t source_id_{0};
  size_t sink_id_;
  std::vector<Vertex> nodes_;
  std::vector<Edge> edges_;
  std::vector<std::pair<size_t, size_t>> original_edges_;
};
178 }  // namespace
179 
180 using graphkernel::LiteGraph;
181 using graphkernel::LiteGraphPtr;
182 using graphkernel::Node;
183 using graphkernel::NodePtr;
184 using graphkernel::NodePtrList;
185 using graphkernel::NType;
186 using graphkernel::PrimOp;
187 using graphkernel::PrimOpPtr;
188 
189 class TransformOp {
190  public:
TransformOp(const NodePtr & node)191   explicit TransformOp(const NodePtr &node)
192       : op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}
193   ~TransformOp() = default;
IsTransformOp(const NodePtr & node)194   bool IsTransformOp(const NodePtr &node) {
195     if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
196       return false;
197     }
198     if (node->input(0)->format == format_a_ && node->format == format_b_) {
199       return true;
200     } else if (node->input(0)->format == format_b_ && node->format == format_a_) {
201       return true;
202     }
203     return false;
204   }
205 
GetFormatType(const std::string & fmt)206   FormatType GetFormatType(const std::string &fmt) {
207     return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
208   }
209 
GenTransformOp(TransOpType trans_type)210   NodePtr GenTransformOp(TransOpType trans_type) {
211     // Only support Transpose now
212     static std::map<std::pair<std::string, std::string>, std::vector<int64_t>> perm_map = {
213       {{kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
214       {{kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
215       {{kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
216       {{kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
217     };
218     std::vector<int64_t> perm;
219     if (trans_type == TransOpType::kTransAB) {
220       perm = perm_map[{format_a_, format_b_}];
221     } else {
222       perm = perm_map[{format_b_, format_a_}];
223     }
224     if (perm.empty()) {
225       MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
226     }
227     auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
228     op->SetAttr("perm", MakeValue(perm));
229     return op;
230   }
231 
232  private:
233   std::string op_;
234   std::string format_a_;
235   std::string format_b_;
236 };
237 
IsFlexibleOp(const NodePtr & node)238 bool IsFlexibleOp(const NodePtr &node) {
239   static const std::set<std::string> format_flexible_ops = {
240     "Abs",  "Add",     "Sub",     "Mul",   "Round",   "Cast",         "Neg",  "Exp",       "Log",
241     "Pow",  "Minimum", "Maximum", "Rsqrt", "Sqrt",    "Reciprocal",   "Tanh", "Sin",       "Cos",
242     "Asin", "ACos",    "RealDiv", "Equal", "Greater", "GreaterEqual", "Less", "LessEqual", "Sign"};
243   if (node->NodeType() != NType::Primitive) {
244     return false;
245   }
246   if (format_flexible_ops.count(node->As<PrimOp>()->op()) == 0) {
247     return false;
248   }
249   // check the input and output formats are all the same, except ConstValue.
250   for (auto &inp : node->inputs()) {
251     if (inp->NodeType() != NType::Value && inp->format != node->format) {
252       return false;
253     }
254   }
255   return true;
256 }
257 
// Starting from one transform op (`basenode_`), collects the connected region
// of flexible ops and its boundary, removes the region's transform ops,
// solves a MinCut to choose each node's format, and re-inserts the minimal
// set of transform ops into the litegraph.
class Mutator {
 public:
  explicit Mutator(const NodePtr &node) : op_checker_(node), basenode_(node), ori_node_(1) {}
  ~Mutator() = default;
  // Returns true if the litegraph was mutated. No changes are made when no
  // flexible op is connected to the base node.
  bool Run() {
    VisitNode(basenode_);
    if (flexible_ops_.empty()) return false;
    // remove transform ops in litegraph
    RemoveTransOp();
    GenFormatGraph();
    RebuildLiteGraph();
    return true;
  }

 private:
  // visit nodes bidirectionally
  // Flood-fills from basenode_ through inputs and users. Transform ops and
  // flexible ops keep expanding the region; any other node is a boundary:
  // its (node, index) slots get fixed format types and traversal stops there.
  // Key convention for fmt_type: index -1 means the node's own output slot,
  // index >= 0 means the node's i-th input slot.
  void VisitNode(const NodePtr &node) {
    if (visited_.count(node) > 0) return;
    (void)visited_.insert(node);
    if (op_checker_.IsTransformOp(node)) {
      (void)trans_ops_.insert(node);
    } else if (!IsFlexibleOp(node)) {
      // boundary node: record fixed formats for its output (unless it is the
      // graph output) and for each of its non-constant inputs.
      if (node->NodeType() != NType::Output) {
        fmt_type[{node, -1}] = op_checker_.GetFormatType(node->format);
      }
      if (node->NodeType() != NType::Parameter) {
        for (size_t i = 0; i < node->inputs().size(); i++) {
          if (node->input(i)->NodeType() == NType::Value) {
            continue;
          }
          fmt_type[{node, i}] = op_checker_.GetFormatType(node->input(i)->format);
        }
      }
      // do not traverse past a boundary node.
      return;
    } else {
      // flexible op: its format is decided later by the MinCut.
      (void)flexible_ops_.insert(node);
      fmt_type[{node, -1}] = FormatType::kFormatUnknown;
    }

    for (auto &input : node->inputs()) {
      if (input->NodeType() != NType::Value) {
        VisitNode(input);
      }
    }
    for (auto &user : node->users()) {
      VisitNode(user.first->shared_from_this());
    }
  }

  // Bypass every collected transform op (connect its input directly to its
  // users) and drop it from the visited set.
  void RemoveTransOp() {
    for (auto &node : trans_ops_) {
      (void)visited_.erase(node);
      node->ReplaceWith(node->input(0));
      // clear inputs, so that the node will not be the basenode again.
      node->SetInputs({});
    }
    trans_ops_.clear();
  }

  // Build the MinCut vertex/edge lists from the visited region. Flexible ops
  // contribute one vertex (their output slot); boundary ops contribute one
  // vertex per input slot plus one for the output slot.
  void GenFormatGraph() {
    for (auto &node : visited_) {
      if (node->NodeType() == NType::Parameter) continue;
      bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.end());
      size_t cur_id = 0;
      if (is_flexible) {
        cur_id = GetId({node, -1});
      }
      for (size_t i = 0; i < node->inputs().size(); i++) {
        if (visited_.count(node->input(i)) == 0) continue;
        if (!is_flexible) {
          cur_id = GetId({node, SizeToInt(i)});
        }
        auto input_id = GetId({node->input(i), -1});
        (void)graph_edges_.emplace_back(input_id, cur_id);
      }
    }
  }

  // Run the MinCut and insert the transform ops it selected back into the
  // litegraph: "one node" ops go on a producer's output, "two node" ops go on
  // a specific producer->consumer edge (one shared op per producer).
  void RebuildLiteGraph() {
    MinCut min_cut(graph_vertex_, graph_edges_);
    min_cut.Run();
    for (auto [node_id, trans_type] : min_cut.GetOneNodeOps()) {
      // a one-node op must sit on an output slot (index -1 by convention).
      if (ori_node_[node_id].second != -1) {
        MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id
                          << " index:" << ori_node_[node_id].second;
      }
      auto trans_op = op_checker_.GenTransformOp(trans_type);
      ori_node_[node_id].first->ReplaceWith(trans_op);
      trans_op->SetInputs({ori_node_[node_id].first});
    }

    // cache one transform op per producer so multiple consumers share it.
    std::map<size_t, NodePtr> trans_op_cache;
    for (auto [edge, trans_type] : min_cut.GetTwoNodeOps()) {
      auto node_id_from = edge.first;
      auto node_id_to = edge.second;
      if (ori_node_[node_id_from].second != -1) {
        MS_LOG(EXCEPTION) << "node_from should be the output edge. node_id:" << node_id_from
                          << " index:" << ori_node_[node_id_from].second;
      }
      auto node_from = ori_node_[node_id_from].first;
      auto node_to = ori_node_[node_id_to].first;
      auto &trans_op = trans_op_cache[node_id_from];
      if (trans_op == nullptr) {
        trans_op = op_checker_.GenTransformOp(trans_type);
        trans_op->SetInputs({node_from});
      }
      if (ori_node_[node_id_to].second >= 0) {
        // consumer vertex was a specific input slot: rewire just that slot.
        node_to->SetInput(IntToSize(ori_node_[node_id_to].second), trans_op);
      } else {
        // consumer vertex was an output slot: rewire every input that
        // currently points at the producer.
        for (size_t i = 0; i < node_to->inputs().size(); i++) {
          if (node_to->input(i) == node_from) {
            node_to->SetInput(i, trans_op);
          }
        }
      }
    }
  }

  // Return the MinCut id for a (node, index) slot, assigning a fresh id (and
  // registering the vertex with its format type) on first sight.
  size_t GetId(const std::pair<NodePtr, int> &node) {
    // the nodes are indexed from 1 in the MinCut model.
    auto &id = node_id_[node];
    if (id == 0) {
      id = node_id_.size();
      ori_node_.push_back(node);
      // set format_type for new id.
      (void)graph_vertex_.emplace_back(id, fmt_type[node]);
    }
    return id;
  }

  TransformOp op_checker_;
  NodePtr basenode_;
  std::set<NodePtr> flexible_ops_;
  std::set<NodePtr> trans_ops_;
  std::set<NodePtr> visited_;

  // (node, input-index) -> format; index -1 denotes the node's output slot.
  std::map<std::pair<NodePtr, int>, FormatType> fmt_type;
  // (node, index) -> MinCut vertex id (ids start at 1; 0 means unassigned).
  std::map<std::pair<NodePtr, int>, size_t> node_id_;
  // vertex id -> (node, index); slot 0 is a placeholder (ids start at 1).
  std::vector<std::pair<NodePtr, int>> ori_node_;
  std::vector<std::pair<size_t, FormatType>> graph_vertex_;
  std::vector<std::pair<size_t, size_t>> graph_edges_;
};
400 
Process(const LiteGraphPtr & litegraph,const std::string & trans_op_name)401 bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const std::string &trans_op_name) {
402   ori_trans_op_num_ = 0;
403   auto &ops = litegraph->ops();
404   bool changed = true;
405   auto check_is_trans_op = [&trans_op_name](const NodePtr &node) { return node->As<PrimOp>()->op() == trans_op_name; };
406   auto ori_trans_op_num = std::count_if(ops.begin(), ops.end(), check_is_trans_op);
407   for (auto &op : ops) {
408     if (check_is_trans_op(op) && !op->inputs().empty() && op->input(0)->format != op->format) {
409       auto mutator = Mutator(op);
410       changed = mutator.Run() || changed;
411     }
412   }
413   if (!changed) return false;
414   auto &new_ops = litegraph->GetOrderedNodes();
415   auto new_trans_op_num = std::count_if(new_ops.begin(), new_ops.end(), check_is_trans_op);
416   if (new_trans_op_num >= ori_trans_op_num) {
417     return false;
418   }
419   for (auto &op : new_ops) {
420     op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
421   }
422   return true;
423 }
424 
/**
 * Pass entry: for every graph-kernel node in `kernel_graph`, convert its
 * sub-graph to a LiteGraph, run Process() on it, and when the optimization
 * succeeded, rebuild an Anf sub-graph from the optimized LiteGraph and
 * replace the original fused node. Returns true if any node was replaced.
 */
bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
  auto mng = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto todos = TopoSort(kernel_graph->get_return());
  bool changed = false;
  for (auto node : todos) {
    if (!AnfAlgo::IsGraphKernel(node)) continue;
    auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
    auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
    // called with the default trans_op_name (declared in the header — confirm).
    if (Process(litegraph)) {
      changed = true;
      AnfNodePtrList outputs;
      auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
      // preserve the graph-kernel attribute on the rebuilt sub-graph.
      new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
      auto cnode = node->cast<CNodePtr>();
      // skip inputs[0] (the primitive); keep the real operands.
      AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
      auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
      SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
      (void)mng->Replace(node, new_node);
      mng->AddFuncGraph(new_funcgraph);
    }
  }
  return changed;
}
449 }  // namespace opt
450 }  // namespace mindspore
451