/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/transform_op_optimizer.h"

#include <algorithm>
#include <vector>
#include <queue>
#include <map>
#include <set>
#include <sstream>
#include <utility>
#include <string>
#include <tuple>
#include <functional>

#include "ir/graph_utils.h"
#include "utils/anf_utils.h"
#include "backend/common/graph_kernel/model/lite_graph.h"
#include "backend/common/graph_kernel/model/graph_builder.h"
#include "backend/common/graph_kernel/model/op_register.h"
#include "backend/common/graph_kernel/core/graph_builder.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"

namespace mindspore::graphkernel {
constexpr const size_t num2 = 2;
struct Edge {
  size_t from;
  size_t to;
  bool operator<(const Edge &other) const { return from == other.from ? to < other.to : from < other.from; }
  friend std::ostream &operator<<(std::ostream &os, const Edge &e) {
    return os << "[" << e.from << " -> " << e.to << "]";
  }
};
inline std::ostream &operator<<(std::ostream &os, FormatType fmt) {
  return os << (fmt == FormatType::kFlexFormat ? "kFlexFormat"
                                               : (fmt == FormatType::kFormatA ? "kFormatA" : "kFormatB"));
}
inline std::ostream &operator<<(std::ostream &os, TransOpType trans) {
  return os << (trans == TransOpType::kTransAB ? "kTransAB" : "kTransBA");
}

// For format-inflexible nodes, index -1 represents the output field, and indices 0~n represent the input fields.
// For format-flexible nodes, a single index -1 represents all of the input and output fields.
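// e.g. a format-inflexible node with two inputs has the fields {-1, 0, 1};
// a format-flexible node has only the field {-1}.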
using NodeWithIndex = std::pair<NodePtr, int>;
using NodeIdWithFormat = std::pair<size_t, FormatType>;

namespace {
constexpr size_t INF = static_cast<size_t>(1) << 30;
class MinCut {
 private:
  struct MinCutEdge {
    size_t to;
    size_t capacity;
  };

  struct Vertex {
    FormatType format{FormatType::kFormatB};
    size_t depth{0};
    std::vector<size_t> out_edges;
  };

  // Add bidirectional edges between the vertices `from` and `to`.
  // The two edge ids are adjacent in the vector, x and x+1 (x = 0, 2, 4, ...),
  // so we can use (i xor 1) to get the inverse edge of any edge i.
  // e.g. edge_0 and edge_1 are a couple, 0^1=1, 1^1=0.
  void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
    (void)edges_.emplace_back(MinCutEdge{to, capacity});
    (void)nodes_[from].out_edges.emplace_back(edges_.size() - 1);
    // inverse edge
    (void)edges_.emplace_back(MinCutEdge{from, inv_capacity});
    (void)nodes_[to].out_edges.emplace_back(edges_.size() - 1);
  }

  bool BfsSetDepth() {
    std::queue<size_t> bfs_queue;
    for (auto &node : nodes_) {
      node.depth = 0;
    }
    nodes_[source_id_].depth = 1;
    bfs_queue.push(source_id_);
    while (!bfs_queue.empty()) {
      auto edge_from = bfs_queue.front();
      bfs_queue.pop();
      for (auto e_id : nodes_[edge_from].out_edges) {
        auto edge_to = edges_[e_id].to;
        if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
          nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
          bfs_queue.push(edge_to);
        }
      }
    }
    return nodes_[sink_id_].depth > 0;
  }

  size_t DfsMaxFlow(size_t node, size_t flow) {
    if (node == sink_id_) {
      return flow;
    }
    size_t max_flow = 0;
    for (size_t e_id : nodes_[node].out_edges) {
      if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
        auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
        if (tmp_flow > 0) {
          max_flow += tmp_flow;
          flow -= tmp_flow;
          edges_[e_id].capacity -= tmp_flow;
          edges_[e_id ^ 1].capacity += tmp_flow;
        }
      }
    }
    return max_flow;
  }

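  // Standard Dinic's max-flow: BFS builds the level graph, then DFS pushes blocking flow
  // along it; repeat until the sink becomes unreachable.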
  void Dinic() {
    while (BfsSetDepth()) {
      (void)DfsMaxFlow(source_id_, INF);
    }
  }

  // Set the nodes that are connected with the source node to kFormatA; the remaining nodes are seen as kFormatB.
  void SetFormat(size_t node_id) {
    nodes_[node_id].format = FormatType::kFormatA;
    MS_LOG(DEBUG) << "Set node_id " << node_id << " to kFormatA.";
    for (size_t i : nodes_[node_id].out_edges) {
      if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != FormatType::kFormatA) {
        SetFormat(edges_[i].to);
      }
    }
  }

  void BuildGraph(const std::vector<NodeIdWithFormat> &original_nodes) {
    for (size_t i = 0; i < origin_nodes_num_; ++i) {
      // link the source node to the nodes with FormatA,
      // link the nodes with FormatB to the sink node.
      if (original_nodes[i].second == FormatType::kFormatA) {
        AddEdge(source_id_, original_nodes[i].first, INF, 0);
      } else if (original_nodes[i].second == FormatType::kFormatB) {
        AddEdge(original_nodes[i].first, sink_id_, INF, 0);
      }
      // Each node is split into two parts, an input part and an output part.
      // The input part's id is the original node's id; the output part's id is the input id plus origin_nodes_num_.
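      // Cutting this capacity-1 internal edge assigns the two parts different formats,
      // i.e. one transform op is inserted at the node itself (see GetOneNodeOps).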
      AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
    }
    for (auto e : original_edges_) {
      AddEdge(e.from + origin_nodes_num_, e.to, 1, 1);
    }
  }

 public:
  MinCut(const std::vector<NodeIdWithFormat> &original_nodes, const std::vector<Edge> &original_edges)
      : origin_nodes_num_(original_nodes.size()),
        source_id_(0),
        sink_id_(num2 * original_nodes.size() + 1),
        nodes_(num2 * original_nodes.size() + num2),  // doubled nodes, plus the source/sink nodes
        original_edges_(original_edges) {
    BuildGraph(original_nodes);
  }
  ~MinCut() = default;

  void Run() {
    Dinic();
    SetFormat(source_id_);
  }

  std::vector<std::pair<size_t, TransOpType>> GetOneNodeOps() const {
    std::vector<std::pair<size_t, TransOpType>> one_node_ops;
    for (size_t i = 1; i <= origin_nodes_num_; ++i) {
      auto tmpi = i;  // to evade pclint warning "for statement index variable modified in body."
      if (nodes_[i].format == FormatType::kFormatA && nodes_[i + origin_nodes_num_].format != FormatType::kFormatA) {
        (void)one_node_ops.emplace_back(tmpi, TransOpType::kTransAB);
        MS_LOG(DEBUG) << "Inserted kTransAB for node_id " << tmpi;
      } else if (nodes_[i].format != FormatType::kFormatA &&
                 nodes_[i + origin_nodes_num_].format == FormatType::kFormatA) {
        (void)one_node_ops.emplace_back(tmpi, TransOpType::kTransBA);
        MS_LOG(DEBUG) << "Inserted kTransBA for node_id " << tmpi;
      }
    }
    return one_node_ops;
  }

  std::vector<std::pair<Edge, TransOpType>> GetTwoNodeOps() const {
    std::vector<std::pair<Edge, TransOpType>> two_node_ops;
    for (auto e : original_edges_) {
      if (nodes_[e.from + origin_nodes_num_].format == FormatType::kFormatA &&
          nodes_[e.to].format != FormatType::kFormatA) {
        (void)two_node_ops.emplace_back(e, TransOpType::kTransAB);
        MS_LOG(DEBUG) << "Inserted kTransAB for edge " << e;
      } else if (nodes_[e.from + origin_nodes_num_].format != FormatType::kFormatA &&
                 nodes_[e.to].format == FormatType::kFormatA) {
        (void)two_node_ops.emplace_back(e, TransOpType::kTransBA);
        MS_LOG(DEBUG) << "Inserted kTransBA for edge " << e;
      }
    }
    return two_node_ops;
  }

 private:
  size_t origin_nodes_num_;
  size_t source_id_;
  size_t sink_id_;
  std::vector<Vertex> nodes_;
  std::vector<MinCutEdge> edges_;
  std::vector<Edge> original_edges_;
};
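// Typical usage (see Mutator::RebuildLiteGraph below):
//   MinCut min_cut(graph_vertex_, graph_edges_);
//   min_cut.Run();
//   auto one_node_ops = min_cut.GetOneNodeOps();  // transform ops to insert on a node's output
//   auto two_node_ops = min_cut.GetTwoNodeOps();  // transform ops to insert on a single edge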

bool IsDynamicShapeGraph(const inner::LiteGraphPtr &litegraph) {
  MS_EXCEPTION_IF_NULL(litegraph);
  for (auto &op : litegraph->ops()) {
    if (IsDynamic(op->shape)) {
      return true;
    }
  }
  return false;
}
}  // namespace

using inner::LiteGraph;
using inner::LiteGraphPtr;
using inner::NodePtrList;
using inner::NType;
using inner::PrimOp;
using inner::PrimOpPtr;

TransformOp::TransformOp(const NodePtr &node)
    : op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}

size_t TransformOp::Hash() const {
  // TransAB and TransBA are seen as the same trans op.
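  // e.g. a Transpose from NCHW to NHWC and one from NHWC to NCHW produce the same hash.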
  auto fmt1 = format_a_;
  auto fmt2 = format_b_;
  if (fmt1 > fmt2) {
    std::swap(fmt1, fmt2);
  }
  return std::hash<std::string>{}(op_ + fmt1 + fmt2);
}

std::string TransformOp::GetFormat(const NodePtr &node) const { return node->format; }

bool TransformOp::IsTransformOp(const NodePtr &node) {
  if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
    return false;
  }
  auto format_in = GetFormat(node->input(0));
  auto format_out = GetFormat(node);
  if (format_in == format_a_ && format_out == format_b_) {
    return true;
  } else if (format_in == format_b_ && format_out == format_a_) {
    return true;
  }
  return false;
}

bool TransformOp::NeedInsert(const NodePtr &input_node) const {
  // A trick: if the node's tensor size is 1, there is no need to insert a transform op.
  return input_node->tensor_size() != 1;
}

FormatType TransformOp::GetFormatType(const std::string &fmt) {
  // nodes that are not flexible and not FormatA will be set to FormatB (including the "others" format)
  return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
}

void TransformOp::SetInput(const NodePtr &node, const NodePtr &input_node) { node->SetInputs({input_node}); }

bool TransformOpCreator::IsTransOp(const NodePtr &node) const {
  if (node->NodeType() == NType::Primitive) {
    if (node->As<PrimOp>()->op() == op_name_) {
      if (op_name_ == "Reshape") {
        return node->format == node->input(0)->format;
      }
      return true;
    }
  }
  return false;
}

class TransposeHandle : public TransformOp {
 public:
  using TransformOp::TransformOp;
  NodePtr GenTransformOp(const NodePtr &input_node, TransOpType trans_type) override {
    static std::map<std::tuple<size_t, std::string, std::string>, std::vector<int64_t>> perm_map = {
      // rank 3
      {{3, kOpFormat_NCHW, kOpFormat_NHWC}, {1, 2, 0}},
      {{3, kOpFormat_NHWC, kOpFormat_NCHW}, {2, 0, 1}},
      // rank 4
      {{4, kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
      {{4, kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
      {{4, kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
      {{4, kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
    };
    std::vector<int64_t> perm;
    std::string dst_format;
    auto rank = input_node->shape.size();
    if (trans_type == TransOpType::kTransAB) {
      perm = perm_map[{rank, format_a_, format_b_}];
      dst_format = format_b_;
    } else {
      perm = perm_map[{rank, format_b_, format_a_}];
      dst_format = format_a_;
    }
    if (perm.empty()) {
      MS_LOG(INFO) << "unsupported format: " << format_a_ << " to " << format_b_ << " of rank " << rank;
      return nullptr;
    }
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    auto perm_tensor = std::make_shared<tensor::Tensor>(perm, kInt64);
    node_to_input_tensor_map_[op] = perm_tensor;
    op->SetAttr(kAttrDstFormat, MakeValue(dst_format));
    return op;
  }

  void SetInput(const NodePtr &node, const NodePtr &input_node) override {
    inner::GraphBuilder gb;
    auto iter = node_to_input_tensor_map_.find(node);
    if (iter == node_to_input_tensor_map_.end()) {
      MS_LOG(EXCEPTION) << "Can't find input valueptr for node: " << node->ToString();
    }
    auto perm_tensor = iter->second;
    auto perm_node = gb.Value(perm_tensor);
    node->SetInputs({input_node, perm_node});
  }

 private:
  std::map<NodePtr, tensor::TensorPtr> node_to_input_tensor_map_;
};

class LayoutTransformHandle : public TransformOp {
 public:
  using TransformOp::TransformOp;
  NodePtr GenTransformOp(const NodePtr &, TransOpType trans_type) override {
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    if (trans_type == TransOpType::kTransAB) {
      op->SetAttr(kAttrSrcFormat, MakeValue(format_a_));
      op->SetAttr(kAttrDstFormat, MakeValue(format_b_));
    } else {
      op->SetAttr(kAttrSrcFormat, MakeValue(format_b_));
      op->SetAttr(kAttrDstFormat, MakeValue(format_a_));
    }
    return op;
  }
};

class ReshapeHandle : public TransformOp {
 public:
  explicit ReshapeHandle(const NodePtr &node) : TransformOp(node) {
    format_a_ = EncodeShape(node->input(0)->shape);
    format_b_ = EncodeShape(node->shape);
  }
  virtual ~ReshapeHandle() = default;

  std::string GetFormat(const NodePtr &node) const override {
    // Reshape op uses shape as format
    return EncodeShape(node->shape);
  }

  bool NeedInsert(const NodePtr &) const override {
    // Reshape ops must always be inserted, otherwise the output shape of a node may change and users may need to
    // infer the shape again.
    return true;
  }

  NodePtr GenTransformOp(const NodePtr &, TransOpType trans_type) override {
    auto op = inner::OpRegistry::Instance().NewOp(op_);
    auto out_format = trans_type == TransOpType::kTransAB ? format_b_ : format_a_;
    auto out_shape = DecodeShape(out_format);
    auto shape_tensor = std::make_shared<tensor::Tensor>(out_shape, kInt64);
    node_to_input_tensor_map_[op] = shape_tensor;
    return op;
  }

  void SetInput(const NodePtr &node, const NodePtr &input_node) override {
    inner::GraphBuilder gb;
    auto iter = node_to_input_tensor_map_.find(node);
    if (iter == node_to_input_tensor_map_.end()) {
      MS_LOG(EXCEPTION) << "Can't find input valueptr for node: " << node->ToString();
    }
    auto shape_tensor = iter->second;
    auto shape_node = gb.Value(shape_tensor);
    node->SetInputs({input_node, shape_node});
  }

 private:
  std::string EncodeShape(const ShapeVector &shape) const {
    std::string res;
    for (const auto &s : shape) {
      res += std::to_string(s) + "_";
    }
    return res;
  }

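  // e.g. EncodeShape({16, 32}) yields "16_32_", and DecodeShape("16_32_") restores {16, 32}.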
  ShapeVector DecodeShape(const std::string &shape) const {
    ShapeVector res;
    size_t l = 0;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == '_' && i > l) {
        std::istringstream iss(shape.substr(l, i - l));
        l = i + 1;
        int64_t s;
        iss >> s;
        res.push_back(s);
      }
    }
    return res;
  }

  std::map<NodePtr, tensor::TensorPtr> node_to_input_tensor_map_;
};

constexpr int kOutputIndex = -1;
class Mutator {
 public:
  Mutator(const NodePtr &node, const TransformOpPtr &handle) : op_handle_(handle), basenode_(node), ori_node_(1) {}
  ~Mutator() = default;

  bool Run(std::set<NodePtr> *changed_nodes) {
    VisitNode(basenode_, kOutputIndex);
    if (flexible_ops_.empty() && trans_ops_.size() <= 1) {
      return false;
    }
    // remove transform ops in litegraph
    RemoveTransOp();
    GenFormatGraph();
    RebuildLiteGraph(changed_nodes);
    changed_nodes->insert(flexible_ops_.begin(), flexible_ops_.end());
    return true;
  }

  size_t new_trans_op_num() const { return new_trans_op_num_; }

 private:
  void VisitNode(const NodePtr &node, int index) {
    if (visited_.count(node) > 0 && inflexible_ops_.count(node) == 0) {
      return;
    }
    (void)visited_.insert(node);
    if (op_handle_->IsTransformOp(node)) {
      (void)trans_ops_.insert(node);
    } else if (!IsFlexibleOp(node)) {
      VisitInflexibleOp(node, index);
      return;
    } else {
      (void)flexible_ops_.insert(node);
      fmt_type[{node, kOutputIndex}] = FormatType::kFlexFormat;
    }
    // for trans op or format-flexible op, visit node bidirectionally.
    for (auto &input : node->inputs()) {
      if (input->NodeType() != NType::Tensor && input->NodeType() != NType::Scalar) {
        VisitNode(input, kOutputIndex);
      }
    }
    for (auto &user : node->users()) {
      for (auto user_idx : user.second) {
        VisitNode(user.first->shared_from_this(), SizeToInt(user_idx));
      }
    }
  }

  void VisitInflexibleOp(const NodePtr &node, int index) {
    auto &visited_index = inflexible_ops_[node];
    if (!visited_index.insert(index).second) {
      return;
    }
    if (visited_index.size() == 1) {
      if (node->NodeType() != NType::Output) {
        fmt_type[{node, kOutputIndex}] = op_handle_->GetFormatType(op_handle_->GetFormat(node));
      }
      if (node->NodeType() != NType::Parameter) {
        for (size_t i = 0; i < node->inputs().size(); i++) {
          if (node->input(i)->NodeType() != NType::Tensor && node->input(i)->NodeType() != NType::Scalar) {
            fmt_type[{node, i}] = op_handle_->GetFormatType(op_handle_->GetFormat(node->input(i)));
          }
        }
      }
    }
    // this node is visited from the output direction, so visit its other users
    if (index < 0) {
      for (const auto &user : node->users()) {
        for (auto user_idx : user.second) {
          VisitNode(user.first->shared_from_this(), SizeToInt(user_idx));
        }
      }
    }
  }

  void RemoveTransOp() {
    for (const auto &node : trans_ops_) {
      (void)visited_.erase(node);
      node->ReplaceWith(node->input(0));
      // clear inputs, so that the node will not be the basenode again.
      node->ClearInputs();
    }
    trans_ops_.clear();
  }

  void GenFormatGraph() {
    for (const auto &node : visited_) {
      if (node->NodeType() == NType::Parameter) {
        continue;
      }
      bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.cend());
      size_t cur_id = 0;
      if (is_flexible) {
        cur_id = GetNodeId({node, kOutputIndex});
      }
      for (size_t i = 0; i < node->inputs().size(); i++) {
        if (visited_.count(node->input(i)) == 0) {
          continue;
        }
        if (!is_flexible) {
          cur_id = GetNodeId({node, SizeToInt(i)});
        }
        auto input_id = GetNodeId({node->input(i), kOutputIndex});
        (void)graph_edges_.emplace_back(Edge{input_id, cur_id});
      }
    }
  }

  NodePtr NewTransOp(const NodePtr &input, TransOpType trans_type, std::set<NodePtr> *changed_nodes) {
    if (!op_handle_->NeedInsert(input)) {
      return nullptr;
    }
    NodePtr trans_op = op_handle_->GenTransformOp(input, trans_type);
    MS_EXCEPTION_IF_NULL(trans_op);
    static size_t inc_id = 0;
    trans_op->SetDebugName("new_trans_op_" + std::to_string(inc_id++));
    MS_LOG(DEBUG) << "Create " << trans_op->debug_name() << " of " << trans_type << " with input node "
                  << input->debug_name();
    (void)changed_nodes->insert(trans_op);
    new_trans_op_num_++;
    return trans_op;
  }

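  // If the min-cut wants a transform op both on a node's output (a one-node op) and on one of that
  // node's out-edges (a two-node op), drop both and instead re-insert the one-node op on each of the
  // node's remaining out-edges, so that every consumer edge carries at most one transform op.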
  void RefineEdges(std::vector<std::pair<size_t, TransOpType>> *one_node_edge,
                   std::vector<std::pair<Edge, TransOpType>> *two_node_edge) const {
    std::map<size_t, TransOpType> one_node_edge_map;
    for (auto &one : *one_node_edge) {
      one_node_edge_map[one.first] = one.second;
    }
    std::set<Edge> removed_edges;
    std::set<size_t> removed_edges_from;
    for (auto iter = two_node_edge->begin(); iter != two_node_edge->end();) {
      if (one_node_edge_map.count(iter->first.from) == 0) {
        ++iter;
        continue;
      }
      auto from = iter->first.from;
      (void)removed_edges_from.insert(from);
      // remove the node from one_node_edge.
      auto rm_iter = std::find_if(one_node_edge->begin(), one_node_edge->end(),
                                  [from](const std::pair<size_t, TransOpType> &no) { return from == no.first; });
      if (rm_iter != one_node_edge->end()) {
        (void)one_node_edge->erase(rm_iter);
        MS_LOG(DEBUG) << "Removed edge for node_id " << from;
      }
      // remove the edge from two_node_edge. (log before erasing; erase invalidates the element iter points to)
      (void)removed_edges.insert(iter->first);
      MS_LOG(DEBUG) << "Removed edge " << iter->first.from << " -> " << iter->first.to;
      iter = two_node_edge->erase(iter);
    }
    for (auto &e : graph_edges_) {
      if (removed_edges_from.count(e.from) != 0 && removed_edges.count(e) == 0) {
        two_node_edge->push_back(std::make_pair(e, one_node_edge_map[e.from]));
        MS_LOG(DEBUG) << "Inserted " << (one_node_edge_map[e.from] == TransOpType::kTransAB ? "kTransAB" : "kTransBA")
                      << " for edge " << e.from << " -> " << e.to;
      }
    }
  }

  void RebuildLiteGraph(std::set<NodePtr> *changed_nodes) {
    MinCut min_cut(graph_vertex_, graph_edges_);
    min_cut.Run();
    auto one_node_edge = min_cut.GetOneNodeOps();
    auto two_node_edge = min_cut.GetTwoNodeOps();
    RefineEdges(&one_node_edge, &two_node_edge);
    for (auto [node_id, trans_type] : one_node_edge) {
      if (ori_node_[node_id].second != kOutputIndex) {
        MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id
                          << " index:" << ori_node_[node_id].second;
      }
      auto input_node = ori_node_[node_id].first;
      auto trans_op = NewTransOp(input_node, trans_type, changed_nodes);
      if (trans_op == nullptr) {
        continue;
      }
      input_node->ReplaceWith(trans_op);
      op_handle_->SetInput(trans_op, input_node);
      MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " after " << input_node->debug_name();
    }

    std::map<size_t, NodePtr> trans_op_cache;
    for (auto [insert_edge, trans_type] : two_node_edge) {
      if (ori_node_[insert_edge.from].second != kOutputIndex) {
        MS_LOG(EXCEPTION) << "node_from should be the output insert_edge. node_id:" << insert_edge.from
                          << " index:" << ori_node_[insert_edge.from].second;
      }
      auto node_from = ori_node_[insert_edge.from].first;
      auto node_to = ori_node_[insert_edge.to].first;
      if (trans_op_cache.count(insert_edge.from) == 0) {
        auto trans_op = NewTransOp(node_from, trans_type, changed_nodes);
        if (trans_op == nullptr) {
          continue;
        }
        trans_op_cache[insert_edge.from] = trans_op;
        op_handle_->SetInput(trans_op, node_from);
      }
      auto trans_op = trans_op_cache[insert_edge.from];
      if (ori_node_[insert_edge.to].second >= 0) {
        node_to->SetInput(IntToSize(ori_node_[insert_edge.to].second), trans_op);
        MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " before " << node_to->debug_name() << " (input "
                      << ori_node_[insert_edge.to].second << ")";
      } else {
        // "node_to" is flexible.
        for (size_t i = 0; i < node_to->inputs().size(); i++) {
          if (node_to->input(i) == node_from) {
            node_to->SetInput(i, trans_op);
            MS_LOG(DEBUG) << "Inserted " << trans_op->debug_name() << " before " << node_to->debug_name() << " (input "
                          << i << ")";
          }
        }
      }
    }
  }

  size_t GetNodeId(const NodeWithIndex &node_with_index) {
    // the nodes are indexed from 1 in the MinCut model.
    auto &id = node_id_[node_with_index];
    if (id == 0) {
      id = node_id_.size();
      ori_node_.push_back(node_with_index);
      // set format_type for new id.
      (void)graph_vertex_.emplace_back(id, fmt_type[node_with_index]);
      MS_LOG(DEBUG) << "Allot node_id " << id << " to " << node_with_index.first->debug_name() << " (index "
                    << node_with_index.second << ").";
    }
    return id;
  }

  bool IsFlexibleOp(const NodePtr &node) const {
    if (node->NodeType() != NType::Primitive) {
      return false;
    }
    if (node->As<PrimOp>()->compute_type() != PrimOp::ComputeType::ELEMWISE) {
      return false;
    }
    // check the input and output formats are all the same, except ConstValue.
    for (auto &inp : node->inputs()) {
      if (inp->NodeType() != NType::Tensor && inp->NodeType() != NType::Scalar &&
          op_handle_->GetFormat(inp) != op_handle_->GetFormat(node)) {
        return false;
      }
    }
    return true;
  }

  size_t new_trans_op_num_{0};

  TransformOpPtr op_handle_;
  NodePtr basenode_;
  std::set<NodePtr> flexible_ops_;
  std::set<NodePtr> trans_ops_;
  std::set<NodePtr> visited_;
  std::map<NodePtr, std::set<int>> inflexible_ops_;  // neither trans op nor flexible op; records the visited indices.

  std::map<NodeWithIndex, FormatType> fmt_type;
  std::map<NodeWithIndex, size_t> node_id_;
  std::vector<NodeWithIndex> ori_node_;  // node_id to NodePtr; this vector is indexed from 1
  std::vector<NodeIdWithFormat> graph_vertex_;
  std::vector<Edge> graph_edges_;
};

void TransformOpOptimizer::ReInfer(const LiteGraphPtr &litegraph, const std::set<NodePtr> &nodes_may_change) const {
  auto &new_ops = litegraph->GetOrderedNodes();
  MS_LOG(DEBUG) << "The changed graph before InferShape: \n" << litegraph->ToString();
  for (auto &op : new_ops) {
    if (nodes_may_change.count(op) != 0) {
      op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
    }
  }
  MS_LOG(DEBUG) << "The changed graph after InferShape: \n" << litegraph->ToString();
}

bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const TransformOpPtr &op_handle) const {
  MS_LOG(DEBUG) << "Process begin, handle is " << *op_handle << ". litegraph: \n" << litegraph->ToString();
  auto ops = litegraph->ops();
  bool changed = false;
  auto check_is_trans_op = [&op_handle](const NodePtr &node) { return op_handle->IsTransformOp(node); };
  size_t ori_trans_op_num = static_cast<size_t>(std::count_if(ops.begin(), ops.end(), check_is_trans_op));
  size_t new_trans_op_num = 0;
  for (auto &op : ops) {
    if (check_is_trans_op(op) && !op->inputs().empty()) {
      if (op_handle->GetFormat(op->input(0)) != op_handle->GetFormat(op)) {
        std::set<NodePtr> nodes_may_change;
        auto mutator = Mutator(op, op_handle);
        MS_LOG(DEBUG) << "Run mutator with basenode " << op->debug_name();
        auto ret = mutator.Run(&nodes_may_change);
        MS_LOG(DEBUG) << "Run mutator result: " << ret;
        new_trans_op_num += mutator.new_trans_op_num();
        if (ret) {
          changed = true;
          ReInfer(litegraph, nodes_may_change);
        }
      }
    }
  }
  bool result = changed && new_trans_op_num < ori_trans_op_num;
  MS_LOG(DEBUG) << "Process result=" << result << ". changed=" << changed << ", new_trans_op_num=" << new_trans_op_num
                << ", ori_trans_op_num=" << ori_trans_op_num;
  return result;
}

void TransformOpOptimizer::Init() {
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("Transpose", TransposeHandle));
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("LayoutTransform", LayoutTransformHandle));
  (void)supported_ops_.emplace_back(TRANS_OP_CREATOR("Reshape", ReshapeHandle));
}

std::vector<TransformOpPtr> TransformOpOptimizer::CreateOpHandles(const LiteGraphPtr &litegraph) const {
  HashSet<size_t> handle_hash;
  std::vector<TransformOpPtr> handles;
  for (auto &creator : supported_ops_) {
    if (creator.Name() == "Reshape" && IsDynamicShapeGraph(litegraph)) {
      // skip dynamic shape
      continue;
    }
    for (auto &op : litegraph->ops()) {
      if (creator.IsTransOp(op)) {
        auto handle = creator.CreateHandle(op);
        if (handle_hash.insert(handle->Hash()).second) {
          (void)handles.emplace_back(handle);
        }
      }
    }
  }
  return handles;
}

bool TransformOpOptimizer::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto todos = GkUtils::GetGraphKernelNodes(func_graph);
  bool changed = false;
  for (auto node : todos) {
    MS_LOG(DEBUG) << "Run the node: " << node->fullname_with_scope();
    auto sub_func_graph = GetCNodeFuncGraph(node);
    auto node_name = sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    auto litegraph = GkUtils::AnfGraph2LiteGraph(sub_func_graph);
    auto handles = CreateOpHandles(litegraph);
    for (size_t i = 0; i < handles.size(); i++) {
      // rebuild litegraph for every process
      if (i > 0) {
        litegraph = GkUtils::AnfGraph2LiteGraph(GetCNodeFuncGraph(node));
      }
      bool result = false;
      try {
        MS_LOG_TRY_CATCH_SCOPE;
        result = Process(litegraph, handles[i]);
      } catch (std::exception &e) {
        result = false;
        MS_LOG(INFO) << "Process node " << node->DebugString() << " failed. message: " << e.what();
      }
      if (result) {
        changed = true;
        MS_LOG(DEBUG) << "Replace with graph:\n" << litegraph->ToString(true);
        auto new_funcgraph = GkUtils::LiteGraph2AnfGraph(litegraph, Callback::Instance());
        MS_EXCEPTION_IF_NULL(new_funcgraph);
        new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, node_name);
        auto cnode = node->cast<CNodePtr>();
        AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
        (void)ConvertTensorToParameter(new_funcgraph, &inputs);
        auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs);
        (void)mng->Replace(node, new_node);
        node = new_node;
        mng->AddFuncGraph(new_funcgraph);
      }
    }
  }
  return changed;
}
}  // namespace mindspore::graphkernel