/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/transform_op_optimizer.h"

#include <algorithm>
#include <vector>
#include <queue>
#include <map>
#include <set>
#include <utility>
#include <string>
#include <sstream>
#include <tuple>
#include <functional>
#include <memory>

#include "ir/graph_utils.h"
#include "utils/anf_utils.h"
#include "backend/common/graph_kernel/model/lite_graph.h"
#include "backend/common/graph_kernel/model/graph_builder.h"
#include "backend/common/graph_kernel/model/op_register.h"
#include "backend/common/graph_kernel/core/graph_builder.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"

namespace mindspore::graphkernel {
constexpr const size_t num2 = 2;
struct Edge {
  size_t from;
  size_t to;
  bool operator<(const Edge &other) const { return from == other.from ? to < other.to : from < other.from; }
  friend std::ostream &operator<<(std::ostream &os, const Edge &e) {
    return os << "[" << e.from << " -> " << e.to << "]";
  }
};
inline std::ostream &operator<<(std::ostream &os, FormatType fmt) {
  return os << (fmt == FormatType::kFlexFormat ? "kFlexFormat"
                                               : (fmt == FormatType::kFormatA ? "kFormatA" : "kFormatB"));
}
inline std::ostream &operator<<(std::ostream &os, TransOpType trans) {
  return os << (trans == TransOpType::kTransAB ? "kTransAB" : "kTransBA");
}

// For format-inflexible nodes, index -1 represents the output field, and indices 0~n represent the input fields.
// For format-flexible nodes, only index -1 is used; it represents all of the input and output fields.
using NodeWithIndex = std::pair<NodePtr, int>;
using NodeIdWithFormat = std::pair<size_t, FormatType>;

namespace {
constexpr size_t INF = static_cast<size_t>(1) << 30;
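// Model the format selection as a minimum cut on a flow network: nodes preferring
// kFormatA are linked from the source and nodes preferring kFormatB are linked to
// the sink with infinite capacity, while all remaining edges have capacity 1.
// By the max-flow min-cut theorem, the Dinic max flow computed below saturates a
// minimal set of unit-capacity edges, and those cut edges are exactly the places
// where transform ops have to be inserted.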
class MinCut {
 private:
  struct MinCutEdge {
    size_t to;
    size_t capacity;
  };

  struct Vertex {
    FormatType format{FormatType::kFormatB};
    size_t depth{0};
    std::vector<size_t> out_edges;
  };

  // Add bidirectional edges for the vertices `from` and `to`.
  // The two edge ids are adjacent in the vector, x and x+1 (x = 0, 2, 4, ...),
  // so we can use (i xor 1) to get the inverse edge of any edge i.
  // e.g. edge_0 and edge_1 are a couple, 0^1=1, 1^1=0.
  void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
    (void)edges_.emplace_back(MinCutEdge{to, capacity});
    (void)nodes_[from].out_edges.emplace_back(edges_.size() - 1);
    // inverse edge
    (void)edges_.emplace_back(MinCutEdge{from, inv_capacity});
    (void)nodes_[to].out_edges.emplace_back(edges_.size() - 1);
  }

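  // BFS from the source to build the level graph of Dinic's algorithm: each vertex's
  // depth is its BFS distance from the source, following only edges with remaining
  // capacity. Returns true while the sink is still reachable, i.e. another
  // augmenting round is possible.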
  bool BfsSetDepth() {
    std::queue<size_t> bfs_queue;
    for (auto &node : nodes_) {
      node.depth = 0;
    }
    nodes_[source_id_].depth = 1;
    bfs_queue.push(source_id_);
    while (!bfs_queue.empty()) {
      auto edge_from = bfs_queue.front();
      bfs_queue.pop();
      for (auto e_id : nodes_[edge_from].out_edges) {
        auto edge_to = edges_[e_id].to;
        if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
          nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
          bfs_queue.push(edge_to);
        }
      }
    }
    return nodes_[sink_id_].depth > 0;
  }

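  // DFS along the level graph (depth increases by exactly 1 per step) to push a
  // blocking flow. Pushed flow is subtracted from an edge's capacity and added to
  // its inverse edge (e_id ^ 1), forming the residual network.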
  size_t DfsMaxFlow(size_t node, size_t flow) {
    if (node == sink_id_) {
      return flow;
    }
    size_t max_flow = 0;
    for (size_t e_id : nodes_[node].out_edges) {
      if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
        auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
        if (tmp_flow > 0) {
          max_flow += tmp_flow;
          flow -= tmp_flow;
          edges_[e_id].capacity -= tmp_flow;
          edges_[e_id ^ 1].capacity += tmp_flow;
        }
      }
    }
    return max_flow;
  }

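  // Dinic's algorithm: alternate BFS level-graph construction and DFS blocking flow
  // until the sink becomes unreachable. The residual network then separates the
  // source side from the sink side along a minimum cut.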
  void Dinic() {
    while (BfsSetDepth()) {
      (void)DfsMaxFlow(source_id_, INF);
    }
  }

  // Set the nodes that are still reachable from the source node in the residual network to kFormatA;
  // the remaining nodes are treated as kFormatB.
  void SetFormat(size_t node_id) {
    nodes_[node_id].format = FormatType::kFormatA;
    MS_LOG(DEBUG) << "Set node_id " << node_id << " to kFormatA.";
    for (size_t i : nodes_[node_id].out_edges) {
      if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != FormatType::kFormatA) {
        SetFormat(edges_[i].to);
      }
    }
  }

  void BuildGraph(const std::vector<NodeIdWithFormat> &original_nodes) {
    for (size_t i = 0; i < origin_nodes_num_; ++i) {
      // Link the source node to the nodes with FormatA,
      // and link the nodes with FormatB to the sink node.
      if (original_nodes[i].second == FormatType::kFormatA) {
        AddEdge(source_id_, original_nodes[i].first, INF, 0);
      } else if (original_nodes[i].second == FormatType::kFormatB) {
        AddEdge(original_nodes[i].first, sink_id_, INF, 0);
      }
      // Each node is split into two parts, an input part and an output part.
      // The input part's id is the original node's id, and the output part's id is the input id plus
      // origin_nodes_num_.
      AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
    }
    for (auto e : original_edges_) {
      AddEdge(e.from + origin_nodes_num_, e.to, 1, 1);
    }
  }

 public:
  MinCut(const std::vector<NodeIdWithFormat> &original_nodes, const std::vector<Edge> &original_edges)
      : origin_nodes_num_(original_nodes.size()),
        source_id_(0),
        sink_id_(num2 * original_nodes.size() + 1),
        nodes_(num2 * original_nodes.size() + num2),  // double nodes, and source_node/sink_node
        original_edges_(original_edges) {
    BuildGraph(original_nodes);
  }
  ~MinCut() = default;

  void Run() {
    Dinic();
    SetFormat(source_id_);
  }

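  // After Run(), an internal edge (i -> i + origin_nodes_num_) crossing the cut means
  // a trans op is inserted on node i's output: kTransAB if the input part stays on the
  // source (FormatA) side, kTransBA if only the output part does.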
  std::vector<std::pair<size_t, TransOpType>> GetOneNodeOps() const {
    std::vector<std::pair<size_t, TransOpType>> one_node_ops;
    for (size_t i = 1; i <= origin_nodes_num_; ++i) {
      auto tmpi = i;  // to evade pclint warning "for statement index variable modified in body."
      if (nodes_[i].format == FormatType::kFormatA && nodes_[i + origin_nodes_num_].format != FormatType::kFormatA) {
        (void)one_node_ops.emplace_back(tmpi, TransOpType::kTransAB);
        MS_LOG(DEBUG) << "Inserted kTransAB for node_id " << tmpi;
      } else if (nodes_[i].format != FormatType::kFormatA &&
                 nodes_[i + origin_nodes_num_].format == FormatType::kFormatA) {
        (void)one_node_ops.emplace_back(tmpi, TransOpType::kTransBA);
        MS_LOG(DEBUG) << "Inserted kTransBA for node_id " << tmpi;
      }
    }
    return one_node_ops;
  }

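  // Likewise, an original graph edge whose endpoints end up on different sides of the
  // cut needs a trans op inserted between the two nodes.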
  std::vector<std::pair<Edge, TransOpType>> GetTwoNodeOps() const {
    std::vector<std::pair<Edge, TransOpType>> two_node_ops;
    for (auto e : original_edges_) {
      if (nodes_[e.from + origin_nodes_num_].format == FormatType::kFormatA &&
          nodes_[e.to].format != FormatType::kFormatA) {
        (void)two_node_ops.emplace_back(e, TransOpType::kTransAB);
        MS_LOG(DEBUG) << "Inserted kTransAB for edge " << e;
      } else if (nodes_[e.from + origin_nodes_num_].format != FormatType::kFormatA &&
                 nodes_[e.to].format == FormatType::kFormatA) {
        (void)two_node_ops.emplace_back(e, TransOpType::kTransBA);
        MS_LOG(DEBUG) << "Inserted kTransBA for edge " << e;
      }
    }
    return two_node_ops;
  }

 private:
  size_t origin_nodes_num_;
  size_t source_id_;
  size_t sink_id_;
  std::vector<Vertex> nodes_;
  std::vector<MinCutEdge> edges_;
  std::vector<Edge> original_edges_;
};

bool IsDynamicShapeGraph(const inner::LiteGraphPtr &litegraph) {
  MS_EXCEPTION_IF_NULL(litegraph);
  for (auto &op : litegraph->ops()) {
    if (IsDynamic(op->shape)) {
      return true;
    }
  }
  return false;
}
}  // namespace

using inner::LiteGraph;
using inner::LiteGraphPtr;
using inner::NodePtrList;
using inner::NType;
using inner::PrimOp;
using inner::PrimOpPtr;

TransformOp::TransformOp(const NodePtr &node)
    : op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}

size_t TransformOp::Hash() const {
  // TransAB and TransBA are seen as the same trans op.
  auto fmt1 = format_a_;
  auto fmt2 = format_b_;
  if (fmt1 > fmt2) {
    std::swap(fmt1, fmt2);
  }
  return std::hash<std::string>{}(op_ + fmt1 + fmt2);
}

std::string TransformOp::GetFormat(const NodePtr &node) const { return node->format; }

bool TransformOp::IsTransformOp(const NodePtr &node) {
  if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
    return false;
  }
  auto format_in = GetFormat(node->input(0));
  auto format_out = GetFormat(node);
  return (format_in == format_a_ && format_out == format_b_) ||
         (format_in == format_b_ && format_out == format_a_);
}

bool TransformOp::NeedInsert(const NodePtr &input_node) const {
  // A trick: if the input node's tensor size is 1, there is no need to insert a transform op.
  return input_node->tensor_size() != 1;
}

FormatType TransformOp::GetFormatType(const std::string &fmt) {
  // Nodes that are neither flexible nor FormatA are treated as FormatB (including "other" formats).
  return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
}

void TransformOp::SetInput(const NodePtr &node, const NodePtr &input_node) { node->SetInputs({input_node}); }

bool TransformOpCreator::IsTransOp(const NodePtr &node) const {
  if (node->NodeType() == NType::Primitive) {
    if (node->As<PrimOp>()->op() == op_name_) {
      if (op_name_ == "Reshape") {
        return node->format == node->input(0)->format;
      }
      return true;
    }
  }
  return false;
}

class TransposeHandle : public TransformOp {
 public:
  using TransformOp::TransformOp;
  NodePtr GenTransformOp(const NodePtr &input_node, TransOpType trans_type) override {
    static std::map<std::tuple<size_t, std::string, std::string>, std::vector<int64_t>> perm_map = {
      // rank 3
      {{3, kOpFormat_NCHW, kOpFormat_NHWC}, {1, 2, 0}},
      {{3, kOpFormat_NHWC, kOpFormat_NCHW}, {2, 0, 1}},
      // rank 4
      {{4, kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
      {{4, kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
      {{4, kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
      {{4, kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
    };
    std::vector<int64_t> perm;
    std::string dst_format;
    auto rank = input_node->shape.size();
    if (trans_type == TransOpType::kTransAB) {
      perm = perm_map[{rank, format_a_, format_b_}];
      dst_format = format_b_;
    } else {
      perm = perm_map[{rank, format_b_, format_a_}];
      dst_format = format_a_;
    }
    if (perm.empty()) {
      MS_LOG(INFO) << "unsupported format: " << format_a_ << " to " << format_b_ << " of rank " << rank;
      return nullptr;
    }
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    auto perm_tensor = std::make_shared<tensor::Tensor>(perm, kInt64);
    node_to_input_tensor_map_[op] = perm_tensor;
    op->SetAttr(kAttrDstFormat, MakeValue(dst_format));
    return op;
  }

  void SetInput(const NodePtr &node, const NodePtr &input_node) override {
    inner::GraphBuilder gb;
    auto iter = node_to_input_tensor_map_.find(node);
    if (iter == node_to_input_tensor_map_.end()) {
      MS_LOG(EXCEPTION) << "Can't find input valueptr for node: " << node->ToString();
    }
    auto perm_tensor = iter->second;
    auto perm_node = gb.Value(perm_tensor);
    node->SetInputs({input_node, perm_node});
  }

 private:
  std::map<NodePtr, tensor::TensorPtr> node_to_input_tensor_map_;
};

class LayoutTransformHandle : public TransformOp {
 public:
  using TransformOp::TransformOp;
  NodePtr GenTransformOp(const NodePtr &, TransOpType trans_type) override {
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    if (trans_type == TransOpType::kTransAB) {
      op->SetAttr(kAttrSrcFormat, MakeValue(format_a_));
      op->SetAttr(kAttrDstFormat, MakeValue(format_b_));
    } else {
      op->SetAttr(kAttrSrcFormat, MakeValue(format_b_));
      op->SetAttr(kAttrDstFormat, MakeValue(format_a_));
    }
    return op;
  }
};

class ReshapeHandle : public TransformOp {
 public:
  explicit ReshapeHandle(const NodePtr &node) : TransformOp(node) {
    format_a_ = EncodeShape(node->input(0)->shape);
    format_b_ = EncodeShape(node->shape);
  }
  virtual ~ReshapeHandle() = default;

  std::string GetFormat(const NodePtr &node) const override {
    // The Reshape op uses the shape as its format.
    return EncodeShape(node->shape);
  }

  bool NeedInsert(const NodePtr &) const override {
    // A Reshape op must always be inserted; otherwise the output shape of a node may change and its users
    // would need to infer the shape again.
    return true;
  }

  NodePtr GenTransformOp(const NodePtr &, TransOpType trans_type) override {
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    auto out_format = trans_type == TransOpType::kTransAB ? format_b_ : format_a_;
    auto out_shape = DecodeShape(out_format);
    auto shape_tensor = std::make_shared<tensor::Tensor>(out_shape, kInt64);
    node_to_input_tensor_map_[op] = shape_tensor;
    return op;
  }

  void SetInput(const NodePtr &node, const NodePtr &input_node) override {
    inner::GraphBuilder gb;
    auto iter = node_to_input_tensor_map_.find(node);
    if (iter == node_to_input_tensor_map_.end()) {
      MS_LOG(EXCEPTION) << "Can't find input valueptr for node: " << node->ToString();
    }
    auto shape_tensor = iter->second;
    auto shape_node = gb.Value(shape_tensor);
    node->SetInputs({input_node, shape_node});
  }

 private:
  std::string EncodeShape(const ShapeVector &shape) const {
    std::string res;
    for (const auto &s : shape) {
      res += std::to_string(s) + "_";
    }
    return res;
  }

  ShapeVector DecodeShape(const std::string &shape) const {
    ShapeVector res;
    size_t l = 0;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == '_' && i > l) {
        // parse the dimension between the previous '_' and the current one.
        std::istringstream iss(shape.substr(l, i - l));
        l = i + 1;
        int64_t s;
        iss >> s;
        res.push_back(s);
      }
    }
    return res;
  }

  std::map<NodePtr, tensor::TensorPtr> node_to_input_tensor_map_;
};

constexpr int kOutputIndex = -1;
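// The Mutator starts from one trans op (the basenode), expands bidirectionally through the
// connected region of trans ops and format-flexible ops, removes the existing trans ops,
// builds the format graph for the MinCut model, and re-inserts a minimal set of trans ops
// according to the cut result.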
class Mutator {
 public:
  Mutator(const NodePtr &node, const TransformOpPtr &handle) : op_handle_(handle), basenode_(node), ori_node_(1) {}
  ~Mutator() = default;

  bool Run(std::set<NodePtr> *changed_nodes) {
    VisitNode(basenode_, kOutputIndex);
    if (flexible_ops_.empty() && trans_ops_.size() <= 1) {
      return false;
    }
    // remove transform ops in litegraph
    RemoveTransOp();
    GenFormatGraph();
    RebuildLiteGraph(changed_nodes);
    changed_nodes->insert(flexible_ops_.begin(), flexible_ops_.end());
    return true;
  }

  size_t new_trans_op_num() const { return new_trans_op_num_; }

 private:
  void VisitNode(const NodePtr &node, int index) {
    if (visited_.count(node) > 0 && inflexible_ops_.count(node) == 0) {
      return;
    }
    (void)visited_.insert(node);
    if (op_handle_->IsTransformOp(node)) {
      (void)trans_ops_.insert(node);
    } else if (!IsFlexibleOp(node)) {
      VisitInflexibleOp(node, index);
      return;
    } else {
      (void)flexible_ops_.insert(node);
      fmt_type[{node, kOutputIndex}] = FormatType::kFlexFormat;
    }
    // for trans op or format-flexible op, visit node bidirectionally.
    for (auto &input : node->inputs()) {
      if (input->NodeType() != NType::Tensor && input->NodeType() != NType::Scalar) {
        VisitNode(input, kOutputIndex);
      }
    }
    for (auto &user : node->users()) {
      for (auto user_idx : user.second) {
        VisitNode(user.first->shared_from_this(), SizeToInt(user_idx));
      }
    }
  }

  void VisitInflexibleOp(const NodePtr &node, int index) {
    auto &visited_index = inflexible_ops_[node];
    if (!visited_index.insert(index).second) {
      return;
    }
    if (visited_index.size() == 1) {
      if (node->NodeType() != NType::Output) {
        fmt_type[{node, kOutputIndex}] = op_handle_->GetFormatType(op_handle_->GetFormat(node));
      }
      if (node->NodeType() != NType::Parameter) {
        for (size_t i = 0; i < node->inputs().size(); i++) {
          if (node->input(i)->NodeType() != NType::Tensor && node->input(i)->NodeType() != NType::Scalar) {
            fmt_type[{node, i}] = op_handle_->GetFormatType(op_handle_->GetFormat(node->input(i)));
          }
        }
      }
    }
    // this node is visited from the output direction, so visit its other users.
    if (index < 0) {
      for (const auto &user : node->users()) {
        for (auto user_idx : user.second) {
          VisitNode(user.first->shared_from_this(), SizeToInt(user_idx));
        }
      }
    }
  }

  void RemoveTransOp() {
    for (const auto &node : trans_ops_) {
      (void)visited_.erase(node);
      node->ReplaceWith(node->input(0));
      // clear inputs, so that the node will not be the basenode again.
      node->ClearInputs();
    }
    trans_ops_.clear();
  }

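  // Build the format graph for the MinCut model: each visited node gets one vertex per
  // format field (a single vertex for flexible nodes, one per input field for inflexible
  // nodes), and the data dependencies become edges.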
  void GenFormatGraph() {
    for (const auto &node : visited_) {
      if (node->NodeType() == NType::Parameter) {
        continue;
      }
      bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.cend());
      size_t cur_id = 0;
      if (is_flexible) {
        cur_id = GetNodeId({node, kOutputIndex});
      }
      for (size_t i = 0; i < node->inputs().size(); i++) {
        if (visited_.count(node->input(i)) == 0) {
          continue;
        }
        if (!is_flexible) {
          cur_id = GetNodeId({node, SizeToInt(i)});
        }
        auto input_id = GetNodeId({node->input(i), kOutputIndex});
        (void)graph_edges_.emplace_back(Edge{input_id, cur_id});
      }
    }
  }

  NodePtr NewTransOp(const NodePtr &input, TransOpType trans_type, std::set<NodePtr> *changed_nodes) {
    if (!op_handle_->NeedInsert(input)) {
      return nullptr;
    }
    NodePtr trans_op = op_handle_->GenTransformOp(input, trans_type);
    MS_EXCEPTION_IF_NULL(trans_op);
    static size_t inc_id = 0;
    trans_op->SetDebugName("new_trans_op_" + std::to_string(inc_id++));
    MS_LOG(DEBUG) << "Create " << trans_op->debug_name() << " of " << trans_type << " with input node "
                  << input->debug_name();
    (void)changed_nodes->insert(trans_op);
    new_trans_op_num_++;
    return trans_op;
  }

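  // If a node needs a trans op on its whole output (one_node_edge) and other trans ops on
  // some of its out-edges (two_node_edge), the two would cancel out along those edges.
  // So drop both entries and push the one-node trans type onto the node's remaining
  // out-edges instead.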
  void RefineEdges(std::vector<std::pair<size_t, TransOpType>> *one_node_edge,
                   std::vector<std::pair<Edge, TransOpType>> *two_node_edge) const {
    std::map<size_t, TransOpType> one_node_edge_map;
    for (auto &one : *one_node_edge) {
      one_node_edge_map[one.first] = one.second;
    }
    std::set<Edge> removed_edges;
    std::set<size_t> removed_edges_from;
    for (auto iter = two_node_edge->begin(); iter != two_node_edge->end();) {
      if (one_node_edge_map.count(iter->first.from) == 0) {
        ++iter;
        continue;
      }
      auto from = iter->first.from;
      (void)removed_edges_from.insert(from);
      // remove node from one_node_edge.
      auto rm_iter = std::find_if(one_node_edge->begin(), one_node_edge->end(),
                                  [from](const std::pair<size_t, TransOpType> &no) { return from == no.first; });
      if (rm_iter != one_node_edge->end()) {
        (void)one_node_edge->erase(rm_iter);
        MS_LOG(DEBUG) << "Removed edge for node_id " << from;
      }
      // remove node from two_node_edge. (log before erasing, since erase invalidates the iterator)
      (void)removed_edges.insert(iter->first);
      MS_LOG(DEBUG) << "Removed edge " << iter->first;
      iter = two_node_edge->erase(iter);
    }
    for (auto &e : graph_edges_) {
      if (removed_edges_from.count(e.from) != 0 && removed_edges.count(e) == 0) {
        two_node_edge->push_back(std::make_pair(e, one_node_edge_map[e.from]));
        MS_LOG(DEBUG) << "Inserted " << one_node_edge_map[e.from] << " for edge " << e;
      }
    }
  }

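  // Rebuild the litegraph according to the min-cut result: run the MinCut model, refine
  // the raw cut edges, then insert new trans ops on the selected node outputs
  // (one_node_edge) and graph edges (two_node_edge).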
  void RebuildLiteGraph(std::set<NodePtr> *changed_nodes) {
    MinCut min_cut(graph_vertex_, graph_edges_);
    min_cut.Run();
    auto one_node_edge = min_cut.GetOneNodeOps();
    auto two_node_edge = min_cut.GetTwoNodeOps();
    RefineEdges(&one_node_edge, &two_node_edge);
    for (auto [node_id, trans_type] : one_node_edge) {
      if (ori_node_[node_id].second != kOutputIndex) {
        MS_LOG(EXCEPTION) << "OneNodeOp should be an output edge. node_id:" << node_id
                          << " index:" << ori_node_[node_id].second;
      }
      auto input_node = ori_node_[node_id].first;
      auto trans_op = NewTransOp(input_node, trans_type, changed_nodes);
      if (trans_op == nullptr) {
        continue;
      }
      input_node->ReplaceWith(trans_op);
      op_handle_->SetInput(trans_op, input_node);
      MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " after " << input_node->debug_name();
    }

    std::map<size_t, NodePtr> trans_op_cache;
    for (auto [insert_edge, trans_type] : two_node_edge) {
      if (ori_node_[insert_edge.from].second != kOutputIndex) {
        MS_LOG(EXCEPTION) << "node_from should be an output edge. node_id:" << insert_edge.from
                          << " index:" << ori_node_[insert_edge.from].second;
      }
      auto node_from = ori_node_[insert_edge.from].first;
      auto node_to = ori_node_[insert_edge.to].first;
      if (trans_op_cache.count(insert_edge.from) == 0) {
        auto trans_op = NewTransOp(node_from, trans_type, changed_nodes);
        if (trans_op == nullptr) {
          continue;
        }
        trans_op_cache[insert_edge.from] = trans_op;
        op_handle_->SetInput(trans_op, node_from);
      }
      auto trans_op = trans_op_cache[insert_edge.from];
      if (ori_node_[insert_edge.to].second >= 0) {
        node_to->SetInput(IntToSize(ori_node_[insert_edge.to].second), trans_op);
        MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " before " << node_to->debug_name() << " (input "
                      << ori_node_[insert_edge.to].second << ")";
      } else {
        // "node_to" is flexible.
        for (size_t i = 0; i < node_to->inputs().size(); i++) {
          if (node_to->input(i) == node_from) {
            node_to->SetInput(i, trans_op);
            MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " before " << node_to->debug_name()
                          << " (input " << i << ")";
          }
        }
      }
    }
  }

  size_t GetNodeId(const NodeWithIndex &node_with_index) {
    // the nodes are indexed from 1 in the MinCut model.
    auto &id = node_id_[node_with_index];
    if (id == 0) {
      id = node_id_.size();
      ori_node_.push_back(node_with_index);
      // set format_type for new id.
      (void)graph_vertex_.emplace_back(id, fmt_type[node_with_index]);
      MS_LOG(DEBUG) << "Allot node_id " << id << " to " << node_with_index.first->debug_name() << " (index "
                    << node_with_index.second << ").";
    }
    return id;
  }

  bool IsFlexibleOp(const NodePtr &node) const {
    if (node->NodeType() != NType::Primitive) {
      return false;
    }
    if (node->As<PrimOp>()->compute_type() != PrimOp::ComputeType::ELEMWISE) {
      return false;
    }
    // check that the input and output formats are all the same, except for const values.
    for (auto &inp : node->inputs()) {
      if (inp->NodeType() != NType::Tensor && inp->NodeType() != NType::Scalar &&
          op_handle_->GetFormat(inp) != op_handle_->GetFormat(node)) {
        return false;
      }
    }
    return true;
  }

  size_t new_trans_op_num_{0};

  TransformOpPtr op_handle_;
  NodePtr basenode_;
  std::set<NodePtr> flexible_ops_;
  std::set<NodePtr> trans_ops_;
  std::set<NodePtr> visited_;
  std::map<NodePtr, std::set<int>> inflexible_ops_;  // neither trans op nor flexible op; records the visit indices.

  std::map<NodeWithIndex, FormatType> fmt_type;
  std::map<NodeWithIndex, size_t> node_id_;
  std::vector<NodeWithIndex> ori_node_;  // maps node_id to NodeWithIndex; this vector is indexed from 1.
  std::vector<NodeIdWithFormat> graph_vertex_;
  std::vector<Edge> graph_edges_;
};

void TransformOpOptimizer::ReInfer(const LiteGraphPtr &litegraph, const std::set<NodePtr> &nodes_may_change) const {
  auto &new_ops = litegraph->GetOrderedNodes();
  MS_LOG(DEBUG) << "The changed graph before InferShape: \n" << litegraph->ToString();
  for (auto &op : new_ops) {
    if (nodes_may_change.count(op) != 0) {
      op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
    }
  }
  MS_LOG(DEBUG) << "The changed graph after InferShape: \n" << litegraph->ToString();
}

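// Process one litegraph with one trans-op handle: run a Mutator from every trans op whose
// input and output formats differ, and accept the change only when the total number of
// trans ops strictly decreases.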
bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const TransformOpPtr &op_handle) const {
  MS_LOG(DEBUG) << "Process begin, handle is " << *op_handle << ". litegraph: \n" << litegraph->ToString();
  auto ops = litegraph->ops();
  bool changed = false;
  auto check_is_trans_op = [&op_handle](const NodePtr &node) { return op_handle->IsTransformOp(node); };
  size_t ori_trans_op_num = static_cast<size_t>(std::count_if(ops.begin(), ops.end(), check_is_trans_op));
  size_t new_trans_op_num = 0;
  for (auto &op : ops) {
    if (check_is_trans_op(op) && !op->inputs().empty()) {
      if (op_handle->GetFormat(op->input(0)) != op_handle->GetFormat(op)) {
        std::set<NodePtr> nodes_may_change;
        auto mutator = Mutator(op, op_handle);
        MS_LOG(DEBUG) << "Run mutator with basenode " << op->debug_name();
        auto ret = mutator.Run(&nodes_may_change);
        MS_LOG(DEBUG) << "Run mutator result: " << ret;
        new_trans_op_num += mutator.new_trans_op_num();
        if (ret) {
          changed = true;
          ReInfer(litegraph, nodes_may_change);
        }
      }
    }
  }
  bool result = changed && new_trans_op_num < ori_trans_op_num;
  MS_LOG(DEBUG) << "Process result=" << result << ". changed=" << changed << ", new_trans_op_num=" << new_trans_op_num
                << ", ori_trans_op_num=" << ori_trans_op_num;
  return result;
}

void TransformOpOptimizer::Init() {
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("Transpose", TransposeHandle));
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("LayoutTransform", LayoutTransformHandle));
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("Reshape", ReshapeHandle));
}

std::vector<TransformOpPtr> TransformOpOptimizer::CreateOpHandles(const LiteGraphPtr &litegraph) const {
  HashSet<size_t> handle_hash;
  std::vector<TransformOpPtr> handles;
  for (auto &creator : supported_ops_) {
    if (creator.Name() == "Reshape" && IsDynamicShapeGraph(litegraph)) {
      // skip the Reshape handle for dynamic shape graphs.
      continue;
    }
    for (auto &op : litegraph->ops()) {
      if (creator.IsTransOp(op)) {
        auto handle = creator.CreateHandle(op);
        if (handle_hash.insert(handle->Hash()).second) {
          (void)handles.emplace_back(handle);
        }
      }
    }
  }
  return handles;
}

bool TransformOpOptimizer::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto todos = GkUtils::GetGraphKernelNodes(func_graph);
  bool changed = false;
  for (auto node : todos) {
    MS_LOG(DEBUG) << "Run the node: " << node->fullname_with_scope();
    auto sub_func_graph = GetCNodeFuncGraph(node);
    auto node_name = sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    auto litegraph = GkUtils::AnfGraph2LiteGraph(sub_func_graph);
    auto handles = CreateOpHandles(litegraph);
    for (size_t i = 0; i < handles.size(); i++) {
      // rebuild the litegraph for every process.
      if (i > 0) {
        litegraph = GkUtils::AnfGraph2LiteGraph(GetCNodeFuncGraph(node));
      }
      bool result = false;
      try {
        MS_LOG_TRY_CATCH_SCOPE;
        result = Process(litegraph, handles[i]);
      } catch (std::exception &e) {
        result = false;
        MS_LOG(INFO) << "Process node " << node->DebugString() << " failed. message: " << e.what();
      }
      if (result) {
        changed = true;
        MS_LOG(DEBUG) << "Replace with graph:\n" << litegraph->ToString(true);
        auto new_funcgraph = GkUtils::LiteGraph2AnfGraph(litegraph, Callback::Instance());
        MS_EXCEPTION_IF_NULL(new_funcgraph);
        new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, node_name);
        auto cnode = node->cast<CNodePtr>();
        AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
        (void)ConvertTensorToParameter(new_funcgraph, &inputs);
        auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs);
        (void)mng->Replace(node, new_node);
        node = new_node;
        mng->AddFuncGraph(new_funcgraph);
      }
    }
  }
  return changed;
}
}  // namespace mindspore::graphkernel