/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
17 #include "base/core_ops.h"
18 #include "ir/graph_utils.h"
19 #include "debug/common.h"
20 #include "backend/kernel_compiler/common_utils.h"
21 #include "backend/session/anf_runtime_algorithm.h"
22 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
23 #include "backend/optimizer/graph_kernel/model/lite_graph.h"
24 #include "backend/optimizer/graph_kernel/model/op_register.h"
25
26 namespace mindspore {
27 namespace opt {
28 namespace {
// Final format assignment of a vertex in the min-cut model; kFormatUnknown
// marks "flexible" vertices whose format is decided by the cut.
enum FormatType { kFormatUnknown, kFormatA, kFormatB };
// Direction of an inserted transform op: A->B or B->A.
enum TransOpType { kTransAB, kTransBA };
// Directed edge of the flow network; `capacity` is the residual capacity,
// updated in place while the max-flow runs.
struct Edge {
  size_t to;
  size_t capacity;
};

// Vertex of the flow network. `depth` is the BFS level used by Dinic's
// algorithm; `out_edges` stores indices into the shared edge vector.
struct Vertex {
  FormatType format{kFormatB};
  size_t depth{0};
  std::vector<size_t> out_edges;
};

// A capacity large enough to act as "infinity" for this model.
constexpr size_t INF = static_cast<size_t>(1) << 30;
43
// Solves a minimum-cut problem with Dinic's max-flow algorithm to decide the
// format of each node so that the number of format-transform ops is minimal.
// Each original node i is split into an input part (id i) and an output part
// (id i + origin_nodes_num_); cutting the unit-capacity edge between the two
// parts means "insert a transform op on this node", while cutting a
// unit-capacity edge between two nodes means "insert a transform op between
// them".
class MinCut {
 private:
  // Add the bidirectional edges for the vertex `from` and `to`.
  // the two edge ids are adjacent in vector, x and x+1 (x are 0,2,4,...)
  // we can use (i xor 1) to get the inverse edge for any edge i.
  // e.g. edge_0 and edge_1 are a couple, 0^1=1, 1^1=0.
  void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
    (void)edges_.emplace_back(Edge{to, capacity});
    (void)nodes_[from].out_edges.emplace_back(edges_.size() - 1);
    // inverse edge
    (void)edges_.emplace_back(Edge{from, inv_capacity});
    (void)nodes_[to].out_edges.emplace_back(edges_.size() - 1);
  }

  // BFS phase of Dinic's algorithm: assign a level (`depth`) to every vertex
  // reachable from the source through edges with remaining capacity.
  // Returns true while the sink is still reachable, i.e. while another
  // augmenting round is worthwhile.
  bool BfsSetDepth() {
    std::queue<size_t> bfs_queue;
    for (auto &node : nodes_) {
      node.depth = 0;  // depth 0 means "not visited in this round"
    }
    nodes_[source_id_].depth = 1;
    bfs_queue.push(source_id_);
    while (!bfs_queue.empty()) {
      auto edge_from = bfs_queue.front();
      bfs_queue.pop();
      for (auto e_id : nodes_[edge_from].out_edges) {
        auto edge_to = edges_[e_id].to;
        if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
          nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
          bfs_queue.push(edge_to);
        }
      }
    }
    return nodes_[sink_id_].depth > 0;
  }

  // DFS phase of Dinic's algorithm: push flow along paths whose depth
  // strictly increases by one per step, updating the residual capacities of
  // each edge e_id and its paired inverse edge (e_id ^ 1).
  // Returns the total flow pushed out of `node` (bounded by `flow`).
  size_t DfsMaxFlow(size_t node, size_t flow) {
    if (node == sink_id_) return flow;
    size_t max_flow = 0;
    for (size_t e_id : nodes_[node].out_edges) {
      if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
        auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
        if (tmp_flow > 0) {
          max_flow += tmp_flow;
          flow -= tmp_flow;
          edges_[e_id].capacity -= tmp_flow;
          edges_[e_id ^ 1].capacity += tmp_flow;
        }
      }
    }
    return max_flow;
  }

  // Dinic's max-flow: alternate level-graph construction (BFS) and blocking
  // flow (DFS) until the sink becomes unreachable in the residual graph.
  void Dinic() {
    while (BfsSetDepth()) {
      (void)DfsMaxFlow(source_id_, INF);
    }
  }

  // After max-flow, mark every vertex still reachable from `node_id` in the
  // residual graph as kFormatA; the unreachable side keeps kFormatB.
  // The saturated edges between the two sides form the minimum cut.
  void SetFormat(size_t node_id) {
    nodes_[node_id].format = kFormatA;
    for (size_t i : nodes_[node_id].out_edges) {
      if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != kFormatA) {
        SetFormat(edges_[i].to);
      }
    }
  }

  // Build the flow network from the original nodes/edges.
  // `original_nodes` maps node id -> fixed format constraint (or unknown).
  void BuildGraph(const std::vector<std::pair<size_t, FormatType>> &original_nodes) {
    for (size_t i = 0; i < origin_nodes_num_; ++i) {
      // link the source node to the nodes with FormatA,
      // link the nodes with FormatB to the sink node.
      if (original_nodes[i].second == kFormatA) {
        AddEdge(source_id_, original_nodes[i].first, INF, 0);
      } else if (original_nodes[i].second == kFormatB) {
        AddEdge(original_nodes[i].first, sink_id_, INF, 0);
      }
      // each nodes was split into two part, input part and output part.
      // the input part's id is the original node's id, the output part's id is input id + origin_nodes_num_.
      AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
    }
    for (auto e : original_edges_) {
      auto from = e.first, to = e.second;
      AddEdge(from + origin_nodes_num_, to, 1, 1);
    }
  }

 public:
  MinCut(const std::vector<std::pair<size_t, FormatType>> &original_nodes,
         const std::vector<std::pair<size_t, size_t>> &original_edges)
      : origin_nodes_num_(original_nodes.size()),
        sink_id_(2 * original_nodes.size() + 1),  // source_id_ is 0
        nodes_(2 * original_nodes.size() + 2),    // double nodes, and source_node/sink_node
        original_edges_(original_edges) {
    BuildGraph(original_nodes);
  }
  ~MinCut() = default;

  // Run the max-flow and derive each vertex's final format from the cut.
  void Run() {
    Dinic();
    SetFormat(source_id_);
  }

  // Nodes whose input part and output part ended up on different sides of
  // the cut: a transform op must be inserted on the node itself.
  std::vector<std::pair<size_t, TransOpType>> GetOneNodeOps() const {
    std::vector<std::pair<size_t, TransOpType>> one_node_ops;
    for (size_t i = 1; i <= origin_nodes_num_; ++i) {
      if (nodes_[i].format == kFormatA && nodes_[i + origin_nodes_num_].format != kFormatA) {
        (void)one_node_ops.emplace_back(i, kTransAB);
      } else if (nodes_[i].format != kFormatA && nodes_[i + origin_nodes_num_].format == kFormatA) {
        (void)one_node_ops.emplace_back(i, kTransBA);
      }
    }
    return one_node_ops;
  }

  // Original edges crossing the cut: a transform op must be inserted
  // between the two endpoint nodes.
  std::vector<std::pair<std::pair<size_t, size_t>, TransOpType>> GetTwoNodeOps() const {
    std::vector<std::pair<std::pair<size_t, size_t>, TransOpType>> two_node_ops;
    for (auto i : original_edges_) {
      if (nodes_[i.first + origin_nodes_num_].format == kFormatA && nodes_[i.second].format != kFormatA) {
        (void)two_node_ops.emplace_back(i, kTransAB);
      } else if (nodes_[i.first + origin_nodes_num_].format != kFormatA && nodes_[i.second].format == kFormatA) {
        (void)two_node_ops.emplace_back(i, kTransBA);
      }
    }
    return two_node_ops;
  }

 private:
  size_t origin_nodes_num_;
  size_t source_id_{0};
  size_t sink_id_;
  std::vector<Vertex> nodes_;
  std::vector<Edge> edges_;
  std::vector<std::pair<size_t, size_t>> original_edges_;
};
178 } // namespace
179
180 using graphkernel::LiteGraph;
181 using graphkernel::LiteGraphPtr;
182 using graphkernel::Node;
183 using graphkernel::NodePtr;
184 using graphkernel::NodePtrList;
185 using graphkernel::NType;
186 using graphkernel::PrimOp;
187 using graphkernel::PrimOpPtr;
188
189 class TransformOp {
190 public:
TransformOp(const NodePtr & node)191 explicit TransformOp(const NodePtr &node)
192 : op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}
193 ~TransformOp() = default;
IsTransformOp(const NodePtr & node)194 bool IsTransformOp(const NodePtr &node) {
195 if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
196 return false;
197 }
198 if (node->input(0)->format == format_a_ && node->format == format_b_) {
199 return true;
200 } else if (node->input(0)->format == format_b_ && node->format == format_a_) {
201 return true;
202 }
203 return false;
204 }
205
GetFormatType(const std::string & fmt)206 FormatType GetFormatType(const std::string &fmt) {
207 return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
208 }
209
GenTransformOp(TransOpType trans_type)210 NodePtr GenTransformOp(TransOpType trans_type) {
211 // Only support Transpose now
212 static std::map<std::pair<std::string, std::string>, std::vector<int64_t>> perm_map = {
213 {{kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
214 {{kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
215 {{kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
216 {{kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
217 };
218 std::vector<int64_t> perm;
219 if (trans_type == TransOpType::kTransAB) {
220 perm = perm_map[{format_a_, format_b_}];
221 } else {
222 perm = perm_map[{format_b_, format_a_}];
223 }
224 if (perm.empty()) {
225 MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
226 }
227 auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
228 op->SetAttr("perm", MakeValue(perm));
229 return op;
230 }
231
232 private:
233 std::string op_;
234 std::string format_a_;
235 std::string format_b_;
236 };
237
IsFlexibleOp(const NodePtr & node)238 bool IsFlexibleOp(const NodePtr &node) {
239 static const std::set<std::string> format_flexible_ops = {
240 "Abs", "Add", "Sub", "Mul", "Round", "Cast", "Neg", "Exp", "Log",
241 "Pow", "Minimum", "Maximum", "Rsqrt", "Sqrt", "Reciprocal", "Tanh", "Sin", "Cos",
242 "Asin", "ACos", "RealDiv", "Equal", "Greater", "GreaterEqual", "Less", "LessEqual", "Sign"};
243 if (node->NodeType() != NType::Primitive) {
244 return false;
245 }
246 if (format_flexible_ops.count(node->As<PrimOp>()->op()) == 0) {
247 return false;
248 }
249 // check the input and output formats are all the same, except ConstValue.
250 for (auto &inp : node->inputs()) {
251 if (inp->NodeType() != NType::Value && inp->format != node->format) {
252 return false;
253 }
254 }
255 return true;
256 }
257
258 class Mutator {
259 public:
Mutator(const NodePtr & node)260 explicit Mutator(const NodePtr &node) : op_checker_(node), basenode_(node), ori_node_(1) {}
261 ~Mutator() = default;
Run()262 bool Run() {
263 VisitNode(basenode_);
264 if (flexible_ops_.empty()) return false;
265 // remove transform ops in litegraph
266 RemoveTransOp();
267 GenFormatGraph();
268 RebuildLiteGraph();
269 return true;
270 }
271
272 private:
273 // visit nodes bidirectionally
VisitNode(const NodePtr & node)274 void VisitNode(const NodePtr &node) {
275 if (visited_.count(node) > 0) return;
276 (void)visited_.insert(node);
277 if (op_checker_.IsTransformOp(node)) {
278 (void)trans_ops_.insert(node);
279 } else if (!IsFlexibleOp(node)) {
280 if (node->NodeType() != NType::Output) {
281 fmt_type[{node, -1}] = op_checker_.GetFormatType(node->format);
282 }
283 if (node->NodeType() != NType::Parameter) {
284 for (size_t i = 0; i < node->inputs().size(); i++) {
285 if (node->input(i)->NodeType() == NType::Value) {
286 continue;
287 }
288 fmt_type[{node, i}] = op_checker_.GetFormatType(node->input(i)->format);
289 }
290 }
291 return;
292 } else {
293 (void)flexible_ops_.insert(node);
294 fmt_type[{node, -1}] = FormatType::kFormatUnknown;
295 }
296
297 for (auto &input : node->inputs()) {
298 if (input->NodeType() != NType::Value) {
299 VisitNode(input);
300 }
301 }
302 for (auto &user : node->users()) {
303 VisitNode(user.first->shared_from_this());
304 }
305 }
306
RemoveTransOp()307 void RemoveTransOp() {
308 for (auto &node : trans_ops_) {
309 (void)visited_.erase(node);
310 node->ReplaceWith(node->input(0));
311 // clear inputs, so that the node will not be the basenode again.
312 node->SetInputs({});
313 }
314 trans_ops_.clear();
315 }
316
GenFormatGraph()317 void GenFormatGraph() {
318 for (auto &node : visited_) {
319 if (node->NodeType() == NType::Parameter) continue;
320 bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.end());
321 size_t cur_id = 0;
322 if (is_flexible) {
323 cur_id = GetId({node, -1});
324 }
325 for (size_t i = 0; i < node->inputs().size(); i++) {
326 if (visited_.count(node->input(i)) == 0) continue;
327 if (!is_flexible) {
328 cur_id = GetId({node, SizeToInt(i)});
329 }
330 auto input_id = GetId({node->input(i), -1});
331 (void)graph_edges_.emplace_back(input_id, cur_id);
332 }
333 }
334 }
335
RebuildLiteGraph()336 void RebuildLiteGraph() {
337 MinCut min_cut(graph_vertex_, graph_edges_);
338 min_cut.Run();
339 for (auto [node_id, trans_type] : min_cut.GetOneNodeOps()) {
340 if (ori_node_[node_id].second != -1) {
341 MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id
342 << " index:" << ori_node_[node_id].second;
343 }
344 auto trans_op = op_checker_.GenTransformOp(trans_type);
345 ori_node_[node_id].first->ReplaceWith(trans_op);
346 trans_op->SetInputs({ori_node_[node_id].first});
347 }
348
349 std::map<size_t, NodePtr> trans_op_cache;
350 for (auto [edge, trans_type] : min_cut.GetTwoNodeOps()) {
351 auto node_id_from = edge.first;
352 auto node_id_to = edge.second;
353 if (ori_node_[node_id_from].second != -1) {
354 MS_LOG(EXCEPTION) << "node_from should be the output edge. node_id:" << node_id_from
355 << " index:" << ori_node_[node_id_from].second;
356 }
357 auto node_from = ori_node_[node_id_from].first;
358 auto node_to = ori_node_[node_id_to].first;
359 auto &trans_op = trans_op_cache[node_id_from];
360 if (trans_op == nullptr) {
361 trans_op = op_checker_.GenTransformOp(trans_type);
362 trans_op->SetInputs({node_from});
363 }
364 if (ori_node_[node_id_to].second >= 0) {
365 node_to->SetInput(IntToSize(ori_node_[node_id_to].second), trans_op);
366 } else {
367 for (size_t i = 0; i < node_to->inputs().size(); i++) {
368 if (node_to->input(i) == node_from) {
369 node_to->SetInput(i, trans_op);
370 }
371 }
372 }
373 }
374 }
375
GetId(const std::pair<NodePtr,int> & node)376 size_t GetId(const std::pair<NodePtr, int> &node) {
377 // the nodes are indexed from 1 in the MinCut model.
378 auto &id = node_id_[node];
379 if (id == 0) {
380 id = node_id_.size();
381 ori_node_.push_back(node);
382 // set format_type for new id.
383 (void)graph_vertex_.emplace_back(id, fmt_type[node]);
384 }
385 return id;
386 }
387
388 TransformOp op_checker_;
389 NodePtr basenode_;
390 std::set<NodePtr> flexible_ops_;
391 std::set<NodePtr> trans_ops_;
392 std::set<NodePtr> visited_;
393
394 std::map<std::pair<NodePtr, int>, FormatType> fmt_type;
395 std::map<std::pair<NodePtr, int>, size_t> node_id_;
396 std::vector<std::pair<NodePtr, int>> ori_node_;
397 std::vector<std::pair<size_t, FormatType>> graph_vertex_;
398 std::vector<std::pair<size_t, size_t>> graph_edges_;
399 };
400
Process(const LiteGraphPtr & litegraph,const std::string & trans_op_name)401 bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const std::string &trans_op_name) {
402 ori_trans_op_num_ = 0;
403 auto &ops = litegraph->ops();
404 bool changed = true;
405 auto check_is_trans_op = [&trans_op_name](const NodePtr &node) { return node->As<PrimOp>()->op() == trans_op_name; };
406 auto ori_trans_op_num = std::count_if(ops.begin(), ops.end(), check_is_trans_op);
407 for (auto &op : ops) {
408 if (check_is_trans_op(op) && !op->inputs().empty() && op->input(0)->format != op->format) {
409 auto mutator = Mutator(op);
410 changed = mutator.Run() || changed;
411 }
412 }
413 if (!changed) return false;
414 auto &new_ops = litegraph->GetOrderedNodes();
415 auto new_trans_op_num = std::count_if(new_ops.begin(), new_ops.end(), check_is_trans_op);
416 if (new_trans_op_num >= ori_trans_op_num) {
417 return false;
418 }
419 for (auto &op : new_ops) {
420 op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
421 }
422 return true;
423 }
424
Run(const FuncGraphPtr & kernel_graph)425 bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
426 auto mng = kernel_graph->manager();
427 MS_EXCEPTION_IF_NULL(mng);
428 auto todos = TopoSort(kernel_graph->get_return());
429 bool changed = false;
430 for (auto node : todos) {
431 if (!AnfAlgo::IsGraphKernel(node)) continue;
432 auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
433 auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
434 if (Process(litegraph)) {
435 changed = true;
436 AnfNodePtrList outputs;
437 auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
438 new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
439 auto cnode = node->cast<CNodePtr>();
440 AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
441 auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
442 SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
443 (void)mng->Replace(node, new_node);
444 mng->AddFuncGraph(new_funcgraph);
445 }
446 }
447 return changed;
448 }
449 } // namespace opt
450 } // namespace mindspore
451