• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/optimizer/pattern_to_pattern.h"
18 #include <algorithm>
19 #include <set>
20 #include <queue>
21 #include "ir/manager.h"
22 #include "include/common/utils/anfalgo.h"
23 
24 namespace mindspore {
25 namespace opt {
// Pattern condition predicate that accepts any candidate value unconditionally.
bool AlwaysReturnTrue(const BaseRef & /* candidate */) {
  return true;
}
27 
Contains(const std::string & name) const28 bool PatternMap::Contains(const std::string &name) const { return name_set_.count(name) > 0; }
29 
CheckSeq(const std::string & name) const30 bool PatternMap::CheckSeq(const std::string &name) const {
31   return name_set_.count(name) > 0 && seq_map_.count(name) > 0;
32 }
33 
Erase(const mindspore::HashSet<std::string> & del_set)34 void PatternMap::Erase(const mindspore::HashSet<std::string> &del_set) {
35   for (auto &s : del_set) {
36     name_set_.erase(s);
37     node_map_.erase(s);
38   }
39 }
40 
Get(const std::string & name) const41 AnfNodePtr PatternMap::Get(const std::string &name) const {
42   if (!Contains(name)) {
43     MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not in PatternMap";
44   }
45 
46   auto iter = node_map_.find(name);
47   if (iter == node_map_.end()) {
48     MS_LOG(INTERNAL_EXCEPTION) << "Var Key: " << name << " is not in PatternMap";
49   }
50   return iter->second;
51 }
52 
GetSeq(const std::string & name) const53 const std::vector<AnfNodePtr> &PatternMap::GetSeq(const std::string &name) const {
54   if (!Contains(name)) {
55     MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not in PatternMap";
56   }
57 
58   auto iter = seq_map_.find(name);
59   if (iter == seq_map_.end()) {
60     MS_LOG(INTERNAL_EXCEPTION) << "SeqVar Key: " << name << " is not in PatternMap";
61   }
62   return iter->second;
63 }
64 
Emplace(const std::string & name,const AnfNodePtr & node)65 bool PatternMap::Emplace(const std::string &name, const AnfNodePtr &node) {
66   MS_EXCEPTION_IF_NULL(node);
67   name_set_.insert(name);
68   if (seq_map_.find(name) != seq_map_.end()) {
69     MS_LOG(INTERNAL_EXCEPTION) << "Var Key: " << name << " should not be in SeqVarMap.";
70   }
71 
72   (void)opt_scope_.insert(node);
73 
74   auto iter = node_map_.find(name);
75   if (iter == node_map_.end()) {
76     node_map_.emplace(name, node);
77   } else if (!opt::AnfEqual(node, iter->second)) {
78     MS_EXCEPTION_IF_NULL(iter->second);
79     MS_LOG(INFO) << "The value of key: " << name
80                  << " is not equal to origin value, value: " + node->fullname_with_scope()
81                  << " origin value: " << iter->second->fullname_with_scope();
82     return false;
83   }
84   return true;
85 }
86 
Emplace(const std::string & name,const std::vector<AnfNodePtr> & v)87 bool PatternMap::Emplace(const std::string &name, const std::vector<AnfNodePtr> &v) {
88   name_set_.insert(name);
89   if (node_map_.find(name) != node_map_.end()) {
90     MS_LOG(INTERNAL_EXCEPTION) << "SeqVar Key: " << name << " should not be in VarMap.";
91   }
92 
93   for (const auto &node : v) {
94     (void)opt_scope_.insert(node);
95   }
96 
97   auto iter = seq_map_.find(name);
98   if (iter == seq_map_.end()) {
99     seq_map_.emplace(name, v);
100   } else {
101     auto &origin_v = iter->second;
102     if (v.size() != origin_v.size()) {
103       MS_LOG(INFO) << "The value of key: " << name
104                    << " is not equal to origin value, v size: " + std::to_string(v.size()) +
105                         ", origin_v size: " + std::to_string(origin_v.size());
106       return false;
107     }
108 
109     for (size_t i = 0; i < v.size(); i++) {
110       MS_EXCEPTION_IF_NULL(v[i]);
111       MS_EXCEPTION_IF_NULL(origin_v[i]);
112       if (!opt::AnfEqual(v[i], origin_v[i])) {
113         MS_LOG(INFO) << "The value of key: " << name
114                      << " is not equal to origin value, value: " + v[i]->fullname_with_scope()
115                      << " origin value: " << origin_v[i]->fullname_with_scope();
116         return false;
117       }
118     }
119   }
120   return true;
121 }
122 
Clear()123 void PatternMap::Clear() {
124   name_set_.clear();
125   node_map_.clear();
126   seq_map_.clear();
127 }
128 
Check(const std::string & name,const AnfNodePtr & node) const129 bool PatternMap::Check(const std::string &name, const AnfNodePtr &node) const { return opt::AnfEqual(node, Get(name)); }
130 
AddVar(const std::string & name,const PatternConditionFunc & f)131 SrcPattern &SrcPattern::AddVar(const std::string &name, const PatternConditionFunc &f) {
132   if (ref_map_.find(name) != ref_map_.end()) {
133     MS_LOG(INTERNAL_EXCEPTION) << "Var: " << name << " is already in SrcPattern.";
134   }
135 
136   auto var = std::make_shared<CondVar>(f);
137   ref_map_.emplace(name, var);
138   return *this;
139 }
140 
AddSeqVar(const std::string & name,const PatternConditionFunc & f)141 SrcPattern &SrcPattern::AddSeqVar(const std::string &name, const PatternConditionFunc &f) {
142   if (ref_map_.find(name) != ref_map_.end()) {
143     MS_LOG(INTERNAL_EXCEPTION) << "SeqVar: " << name << " is already in SrcPattern.";
144   }
145 
146   auto seq_var = std::make_shared<SeqVar>(f);
147   ref_map_.emplace(name, seq_var);
148   return *this;
149 }
150 
GetRef(const std::string & name) const151 const BaseRef &SrcPattern::GetRef(const std::string &name) const {
152   auto iter = ref_map_.find(name);
153   if (iter == ref_map_.end()) {
154     MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " not in PatternMap";
155   }
156   return iter->second;
157 }
158 
AddCNode(const std::string & name,const std::initializer_list<PatternNode> & v)159 SrcPattern &SrcPattern::AddCNode(const std::string &name, const std::initializer_list<PatternNode> &v) {
160   if (ref_map_.find(name) != ref_map_.end()) {
161     MS_LOG(INTERNAL_EXCEPTION) << "CNode: " << name << " is already in SrcPattern.";
162   }
163 
164   std::vector<BaseRef> ele;
165   for (auto &node : v) {
166     if (node.type_ == "name") {
167       ele.emplace_back(GetRef(node.name_));
168     } else if (node.type_ == "prim") {
169       ele.emplace_back(node.p_);
170     } else {
171       MS_LOG(INTERNAL_EXCEPTION) << "Error MatchNode Type: " << node.type_ << ", CNode: " << name;
172     }
173   }
174 
175   MS_EXCEPTION_IF_CHECK_FAIL(
176     ele.size() == v.size(),
177     "The length of BaseRef Vector and CNode Input is not equal, BaseRef Vector length: " + std::to_string(ele.size()) +
178       " CNode Input length: " + std::to_string(v.size()) + ", CNode: " + name);
179 
180   inputs_map_.emplace(name, v);
181   auto vec = VectorRef(ele);
182   ref_map_.emplace(name, vec);
183   has_root_ = true;
184   root_ = name;
185   return *this;
186 }
187 
GetRoot() const188 BaseRef SrcPattern::GetRoot() const {
189   if (!has_root_) {
190     MS_LOG(INTERNAL_EXCEPTION) << "This SrcPattern has no root node.";
191   }
192   return GetRef(root_);
193 }
194 
GetSeq(const std::string & pattern_name,const std::string & node_name,const VarPtr & var,const EquivPtr & equiv)195 const Seq &GetSeq(const std::string &pattern_name, const std::string &node_name, const VarPtr &var,
196                   const EquivPtr &equiv) {
197   MS_EXCEPTION_IF_NULL(equiv);
198   auto equiv_iter = equiv->find(var);
199   if (equiv_iter == equiv->end()) {
200     MS_LOG(INTERNAL_EXCEPTION) << "The SeqVar Key: " << pattern_name << " is not in EquivMap, node name: " << node_name;
201   }
202 
203   BaseRef &seq_ref = equiv_iter->second;
204   if (utils::isa<Seq>(seq_ref)) {
205     const Seq &seq = utils::cast<Seq>(seq_ref);
206     return seq;
207   }
208   MS_LOG(INTERNAL_EXCEPTION) << "The value of SeqVar Key: " << pattern_name
209                              << " is not a seq, node name: " << node_name;
210 }
211 
CheckEmptySeqVar(const std::string & name,const EquivPtr & equiv,const std::vector<PatternNode> & inputs,size_t * now_pattern)212 bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv,
213                                   const std::vector<PatternNode> &inputs, size_t *now_pattern) {
214   if (inputs.size() - (*now_pattern) == 1 && inputs.at(*now_pattern).type_ == "name") {
215     auto &pattern_node = inputs.at(*now_pattern);
216     auto &ref = GetRef(pattern_node.name_);
217     if (utils::isa<VarPtr>(ref) && utils::cast<VarPtr>(ref)->isa<SeqVar>()) {
218       const Seq &seq = GetSeq(pattern_node.name_, name, utils::cast<VarPtr>(ref), equiv);
219       MS_EXCEPTION_IF_CHECK_FAIL(seq.size() == IntToSize(0), "Match Failed, need zero seq, but get seq length: " +
220                                                                std::to_string(seq.size()) + ", node name: " + name);
221       std::vector<AnfNodePtr> v;
222       MS_EXCEPTION_IF_NULL(m_);
223       if (!m_->Emplace(pattern_node.name_, v)) {
224         return false;
225       }
226       (*now_pattern)++;
227     }
228   }
229   return true;
230 }
231 
match(const std::string & name,const AnfNodePtr & node,const EquivPtr & equiv)232 bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv) {
233   MS_EXCEPTION_IF_NULL(m_);
234   MS_EXCEPTION_IF_NULL(node);
235   MS_EXCEPTION_IF_NULL(equiv);
236   auto input_iter = inputs_map_.find(name);
237   if (input_iter == inputs_map_.end()) {
238     MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not a CNode.";
239   }
240 
241   if (m_->Contains(name)) {
242     return m_->Check(name, node);
243   }
244 
245   auto &inputs = input_iter->second;
246   auto cnode = node->cast<CNodePtr>();
247   MS_EXCEPTION_IF_NULL(cnode);
248   auto cnode_inputs = cnode->inputs();
249   size_t now_pattern = 0;
250   size_t now_match = 0;
251   for (; now_pattern < inputs.size() && now_match < cnode_inputs.size(); now_pattern++, now_match++) {
252     auto &pattern_node = inputs[now_pattern];
253     auto &match_node = cnode_inputs[now_match];
254     if (pattern_node.type_ == "prim") {
255       // prim
256       MS_EXCEPTION_IF_NULL(pattern_node.p_);
257       MS_EXCEPTION_IF_NULL(match_node);
258       if (!opt::AnfEqual(pattern_node.p_, match_node)) {
259         MS_LOG(EXCEPTION) << "The value of Primitive is not equal to matched value, pattern value: " +
260                                pattern_node.p_->ToString()
261                           << " matched value: " + match_node->ToString() + ", node name: " + name;
262       }
263       continue;
264     }
265     // name
266     MS_EXCEPTION_IF_CHECK_FAIL(pattern_node.type_ == "name",
267                                "Error MatchNode Type: " + pattern_node.type_ + ", node name: " + name);
268     auto &ref = GetRef(pattern_node.name_);
269     if (utils::isa<VarPtr>(ref)) {
270       if (utils::cast<VarPtr>(ref)->isa<SeqVar>()) {
271         // seq var
272         const Seq &seq = GetSeq(pattern_node.name_, name, utils::cast<VarPtr>(ref), equiv);
273         std::vector<AnfNodePtr> v;
274         for (size_t i = 0; i < seq.size(); i++) {
275           v.emplace_back(cnode_inputs.at(now_match + i));
276         }
277         if (!m_->Emplace(pattern_node.name_, v)) {
278           return false;
279         }
280         now_match += seq.size() - 1;
281         continue;
282       }
283     } else {
284       // cnode
285       if (!match(pattern_node.name_, match_node, equiv)) {
286         return false;
287       }
288     }
289     if (!m_->Emplace(pattern_node.name_, match_node)) {
290       return false;
291     }
292   }
293   // has a SeqVar at the end
294   if (now_match == cnode_inputs.size()) {
295     if (!CheckEmptySeqVar(name, equiv, inputs, &now_pattern)) {
296       return false;
297     }
298   }
299 
300   MS_EXCEPTION_IF_CHECK_FAIL(
301     now_pattern == inputs.size() && now_match == cnode_inputs.size(),
302     "Match Failed, now_pattern: " + std::to_string(now_pattern) + ", inputs.size(): " + std::to_string(inputs.size()) +
303       ", now_match: " + std::to_string(now_match) + ", cnode_inputs.size(): " + std::to_string(cnode_inputs.size()) +
304       ", node name: " + name);
305 
306   return m_->Emplace(name, node);
307 }
308 
build_pattern_map(const AnfNodePtr & node,const EquivPtr & equiv)309 bool SrcPattern::build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv) {
310   MS_EXCEPTION_IF_NULL(m_);
311   if (!has_root_) {
312     MS_LOG(INTERNAL_EXCEPTION) << "This SourcePattern has no root node.";
313   }
314   m_->Clear();
315   return match(root_, node, equiv);
316 }
317 
AddCNode(const string & name,const std::initializer_list<PatternNode> & inputs,const BuildCNodeFunc & buildfunc)318 DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list<PatternNode> &inputs,
319                                  const BuildCNodeFunc &buildfunc) {
320   MS_EXCEPTION_IF_NULL(m_);
321   if (fail_) {
322     return *this;
323   }
324 
325   if (m_->Contains(name)) {
326     MS_LOG(INTERNAL_EXCEPTION) << "CNode: " + name + " is already in DstPattern";
327   }
328 
329   std::vector<AnfNodePtr> anf_inputs;
330   for (auto &r : inputs) {
331     if (r.type_ == "prim") {
332       anf_inputs.emplace_back(r.p_);
333     } else if (r.type_ == "name") {
334       if (m_->CheckSeq(r.name_)) {
335         auto &v = m_->GetSeq(r.name_);
336         std::copy(v.begin(), v.end(), std::back_inserter(anf_inputs));
337       } else {
338         anf_inputs.emplace_back(m_->Get(r.name_));
339       }
340     } else if (r.type_ == "unpack") {
341       for (auto &it : r.v_) {
342         if (it.node_ == nullptr) {
343           anf_inputs.emplace_back(m_->Get(it.key_));
344         } else {
345           anf_inputs.emplace_back(it.node_);
346         }
347       }
348     } else {
349       MS_LOG(INTERNAL_EXCEPTION) << "Error ReplaceNode Type: " << r.type_ << ", CNode: " << name;
350     }
351   }
352 
353   MS_EXCEPTION_IF_NULL(pass_);
354   auto default_node = pass_->NewCNode(anf_inputs, fg_);
355   auto new_node = buildfunc(*m_, default_node);
356   if (new_node == nullptr) {
357     fail_ = true;
358   } else {
359     auto cnode = new_node->cast<CNodePtr>();
360     MS_EXCEPTION_IF_NULL(cnode);
361     if (anf_inputs.size() != cnode->size()) {
362       MS_LOG(INTERNAL_EXCEPTION)
363         << "The actual input size does not correspond to the input size of the pattern, actual input size: "
364         << anf_inputs.size() << ", pattern input size: " << new_node->cast<CNodePtr>()->size() << ", CNode: " << name;
365     }
366     for (size_t i = 0; i < anf_inputs.size(); i++) {
367       MS_EXCEPTION_IF_NULL(anf_inputs[i]);
368       MS_EXCEPTION_IF_NULL(cnode->input(i));
369       if (!opt::AnfEqual(anf_inputs[i], cnode->input(i))) {
370         MS_LOG(INTERNAL_EXCEPTION)
371           << "The actual input does not correspond to the input of the pattern, the input index: " << i
372           << ", actual input: " << anf_inputs[i]->DebugString()
373           << ", pattern input: " << new_node->cast<CNodePtr>()->input(i)->DebugString() << ", CNode: " << name;
374       }
375     }
376   }
377 
378   if (!m_->Emplace(name, new_node)) {
379     MS_LOG(EXCEPTION) << "CNode: " + name + " is already in DstPattern";
380   }
381   root_ = new_node;
382   return *this;
383 }
384 
AddValueNode(const string & name,const BuildValueFunc & buildfunc)385 DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &buildfunc) {
386   MS_EXCEPTION_IF_NULL(m_);
387   if (fail_) {
388     return *this;
389   }
390 
391   if (m_->Contains(name)) {
392     MS_LOG(INTERNAL_EXCEPTION) << "ValueNode: " + name + " is already in DstPattern";
393   }
394 
395   auto node = buildfunc(*m_);
396   if (node == nullptr) {
397     fail_ = true;
398   }
399   dst_set_.insert(name);
400   (void)m_->Emplace(name, node);
401   root_ = node;
402   return *this;
403 }
404 
clear()405 void DstPattern::clear() {
406   MS_EXCEPTION_IF_NULL(m_);
407   fail_ = false;
408   root_ = nullptr;
409   m_->Erase(dst_set_);
410   dst_set_.clear();
411   fg_ = nullptr;
412   pass_ = nullptr;
413 }
414 
set_info(PatternToPatternPass * now_pass,const FuncGraphPtr & func_graph)415 void DstPattern::set_info(PatternToPatternPass *now_pass, const FuncGraphPtr &func_graph) {
416   pass_ = now_pass;
417   fg_ = func_graph;
418 }
419 
Root()420 AnfNodePtr DstPattern::Root() {
421   if (fail_) {
422     return nullptr;
423   } else {
424     return root_;
425   }
426 }
427 
operator =(const std::string & name)428 UnpackNode &UnpackNode::operator=(const std::string &name) {
429   key_ = name;
430   node_ = nullptr;
431   return *this;
432 }
433 
GetSrcPatternRoot()434 AnfNodePtr PatternToPatternPass::GetSrcPatternRoot() {
435   if (src_pattern_root_ == nullptr) {
436     DefineSrcPattern(&src_pattern_);
437     VarPtr fg = std::make_shared<Var>("RootG");
438     src_pattern_root_ = SexpToNode(src_pattern_.GetRoot(), fg, primitive_vars_.get(), multigraph_);
439   }
440   return src_pattern_root_;
441 }
442 
GetPatternRootPrimitiveName()443 std::string PatternToPatternPass::GetPatternRootPrimitiveName() {
444   auto src_pattern_root = GetSrcPatternRoot();
445   auto prim = GetCNodePrimitive(src_pattern_root);
446   if (prim != nullptr) {
447     return prim->name();
448   }
449   return "";
450 }
451 
// Tries to rewrite `node`. When it matches the source pattern, the pattern map is built from the
// match, the DAG check passes, and the destination pattern builds successfully in `func_graph`,
// the new root is returned. Otherwise returns nullptr.
AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (src_pattern_root_ == nullptr) {
    (void)GetSrcPatternRoot();
  }

  auto primitive = GetCNodePrimitive(src_pattern_root_);
  if (!IsPrimitiveCNode(node, primitive)) {
    return nullptr;
  }

  MS_EXCEPTION_IF_NULL(primitive_vars_);
  MS_EXCEPTION_IF_NULL(equiv_);
  equiv_->clear();
  EquivPtr equiv = pattern_engine_.Match(src_pattern_root_, node, *primitive_vars_, equiv_);
  if (equiv == nullptr || equiv->empty()) {
    return nullptr;
  }
  if (!src_pattern_.build_pattern_map(node, equiv) || !CheckMatchedDAG(*m_, func_graph, node)) {
    return nullptr;
  }

  dst_pattern_.clear();
  dst_pattern_.set_info(this, func_graph);
  DefineDstPattern(&dst_pattern_);
  return dst_pattern_.Root();
}
478 
479 namespace {
// BFS stages used when updating the graph index after a rewrite:
//   kStageZero - release the matched source pattern (DelSrcPattern),
//   kStageOne  - register the destination pattern (AddDstPattern),
//   kStageTwo  - cascade-delete nodes that became unreferenced (DelCascadeNode).
// Typed as size_t to match BFS's `stage` parameter instead of deducing int.
constexpr size_t kStageZero = 0;
constexpr size_t kStageOne = 1;
constexpr size_t kStageTwo = 2;
483 
DeleteCNode(const AnfNodePtr & node,const FuncGraphPtr & sub_graph,const FuncGraphIndexPtr & func_graph_index)484 void DeleteCNode(const AnfNodePtr &node, const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
485   MS_EXCEPTION_IF_NULL(node);
486   MS_EXCEPTION_IF_NULL(func_graph_index);
487   if (node->isa<CNode>()) {
488     auto name_to_cnode_iter = func_graph_index->name_to_cnode_.find(GetCNodeKey(node));
489     if (name_to_cnode_iter == func_graph_index->name_to_cnode_.end()) {
490       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find cnode_name: "
491                                  << common::AnfAlgo::GetCNodeName(node);
492     }
493     auto &cnode_set = name_to_cnode_iter->second;
494     auto cnode_set_iter = cnode_set.find(node);
495     if (cnode_set_iter == cnode_set.end()) {
496       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find node: "
497                                  << node->fullname_with_scope();
498     }
499     (void)cnode_set.erase(cnode_set_iter);
500     ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), sub_graph, &func_graph_index->subgraph_out_caller_map_, false);
501   }
502 }
503 
AppendChild(const AnfNodePtr & node,const FuncGraphPtr & fg,std::queue<std::pair<AnfNodePtr,FuncGraphPtr>> * anf_q)504 void AppendChild(const AnfNodePtr &node, const FuncGraphPtr &fg,
505                  std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> *anf_q) {
506   MS_EXCEPTION_IF_NULL(node);
507   MS_EXCEPTION_IF_NULL(fg);
508   MS_EXCEPTION_IF_NULL(anf_q);
509   if (IsValueNode<FuncGraph>(node)) {
510     auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
511     MS_EXCEPTION_IF_NULL(const_func_graph);
512     if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
513       (void)anf_q->emplace(const_func_graph->output(), const_func_graph);
514     }
515   } else if (node->isa<CNode>()) {
516     auto cnode = node->cast<CNodePtr>();
517     MS_EXCEPTION_IF_NULL(cnode);
518     for (const auto &input_node : cnode->inputs()) {
519       (void)anf_q->emplace(input_node, fg);
520     }
521   }
522 }
523 
DelSrcPattern(const std::pair<AnfNodePtr,FuncGraphPtr> & top,const AnfNodePtr & root,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)524 bool DelSrcPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
525                    const mindspore::HashSet<AnfNodePtr> &opt_scope,
526                    std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
527                    const FuncGraphIndexPtr &func_graph_index) {
528   MS_EXCEPTION_IF_NULL(root);
529   MS_EXCEPTION_IF_NULL(need_delete);
530   MS_EXCEPTION_IF_NULL(func_graph_index);
531   auto node = top.first;
532   auto fg = top.second;
533   MS_EXCEPTION_IF_NULL(node);
534   MS_EXCEPTION_IF_NULL(fg);
535   if (node != root) {
536     auto degree_iter = func_graph_index->node_degree_.find(node);
537     if (degree_iter == func_graph_index->node_degree_.end()) {
538       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
539                                  << " not in degree map";
540     }
541     if (degree_iter->second <= 0) {
542       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
543                                  << " degree error, degree: " << degree_iter->second;
544     }
545     degree_iter->second--;
546     if (degree_iter->second > 0) {
547       return false;
548     }
549   }
550   if (opt_scope.find(node) == opt_scope.end()) {
551     (void)(*need_delete).insert({node, fg});
552     return false;
553   }
554 
555   DeleteCNode(node, fg, func_graph_index);
556   return true;
557 }
558 
AddDstPattern(const std::pair<AnfNodePtr,FuncGraphPtr> & top,const AnfNodePtr & root,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)559 bool AddDstPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
560                    const mindspore::HashSet<AnfNodePtr> &opt_scope,
561                    std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
562                    const FuncGraphIndexPtr &func_graph_index) {
563   MS_EXCEPTION_IF_NULL(root);
564   MS_EXCEPTION_IF_NULL(need_delete);
565   MS_EXCEPTION_IF_NULL(func_graph_index);
566   auto node = top.first;
567   auto fg = top.second;
568   MS_EXCEPTION_IF_NULL(node);
569   MS_EXCEPTION_IF_NULL(fg);
570   if (node->isa<CNode>()) {
571     ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), fg, &func_graph_index->subgraph_out_caller_map_);
572     (void)func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
573     func_graph_index->node_to_fg_[node] = fg;
574   }
575 
576   if (node != root) {
577     auto degree_iter = func_graph_index->node_degree_.find(node);
578     if (degree_iter == func_graph_index->node_degree_.end()) {
579       func_graph_index->node_degree_[node] = 0;
580       degree_iter = func_graph_index->node_degree_.find(node);
581     }
582     degree_iter->second++;
583     if (degree_iter->second != 1) {
584       return false;
585     }
586   }
587   if (opt_scope.find(node) == opt_scope.end()) {
588     (void)(*need_delete).erase({node, fg});
589     return false;
590   }
591   return true;
592 }
593 
DelCascadeNode(const std::pair<AnfNodePtr,FuncGraphPtr> & top,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)594 bool DelCascadeNode(const std::pair<AnfNodePtr, FuncGraphPtr> &top,
595                     std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
596                     const FuncGraphIndexPtr &func_graph_index) {
597   MS_EXCEPTION_IF_NULL(need_delete);
598   MS_EXCEPTION_IF_NULL(func_graph_index);
599   auto node = top.first;
600   auto fg = top.second;
601   MS_EXCEPTION_IF_NULL(node);
602   MS_EXCEPTION_IF_NULL(fg);
603   if ((*need_delete).find({node, fg}) == (*need_delete).end()) {
604     auto degree_iter = func_graph_index->node_degree_.find(node);
605     if (degree_iter == func_graph_index->node_degree_.end()) {
606       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
607                                  << " not in degree map";
608     }
609     if (degree_iter->second <= 0) {
610       MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
611                                  << " degree error, degree: " << degree_iter->second;
612     }
613     degree_iter->second--;
614     if (degree_iter->second > 0) {
615       return false;
616     }
617   }
618 
619   DeleteCNode(node, fg, func_graph_index);
620   return true;
621 }
622 
BFS(const AnfNodePtr & root,const FuncGraphPtr & sub_graph,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index,size_t stage)623 void BFS(const AnfNodePtr &root, const FuncGraphPtr &sub_graph, const mindspore::HashSet<AnfNodePtr> &opt_scope,
624          std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete, const FuncGraphIndexPtr &func_graph_index,
625          size_t stage) {
626   std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> anf_q;
627 
628   if (stage == kStageZero || stage == kStageOne) {
629     (void)anf_q.emplace(root, sub_graph);
630   } else if (stage == kStageTwo) {
631     for (const auto &p : (*need_delete)) {
632       anf_q.push(p);
633     }
634   } else {
635     MS_LOG(INTERNAL_EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
636   }
637 
638   while (!anf_q.empty()) {
639     auto top = anf_q.front();
640     anf_q.pop();
641 
642     bool ret = false;
643     if (stage == kStageZero) {
644       ret = DelSrcPattern(top, root, opt_scope, need_delete, func_graph_index);
645     } else if (stage == kStageOne) {
646       ret = AddDstPattern(top, root, opt_scope, need_delete, func_graph_index);
647     } else if (stage == kStageTwo) {
648       ret = DelCascadeNode(top, need_delete, func_graph_index);
649     } else {
650       MS_LOG(INTERNAL_EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
651     }
652     if (!ret) {
653       continue;
654     }
655 
656     AppendChild(top.first, top.second, &anf_q);
657   }
658 }
659 }  // namespace
660 
AfterProcess(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const FuncGraphPtr & sub_graph,const FuncGraphIndexPtr & func_graph_index)661 void PatternToPatternPass::AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node,
662                                         const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
663   MS_EXCEPTION_IF_NULL(m_);
664   MS_EXCEPTION_IF_NULL(old_node);
665   MS_EXCEPTION_IF_NULL(new_node);
666   MS_EXCEPTION_IF_NULL(sub_graph);
667   MS_EXCEPTION_IF_NULL(func_graph_index);
668   std::set<std::pair<AnfNodePtr, FuncGraphPtr>> need_delete;
669   auto &opt_scope = m_->GetOptScope();
670 
671   auto old_node_iter = func_graph_index->node_degree_.find(old_node);
672   if (old_node_iter == func_graph_index->node_degree_.end()) {
673     MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, old_node: " << old_node->fullname_with_scope()
674                                << " not in degree map";
675   }
676   auto origin_degree = old_node_iter->second;
677 
678   func_graph_index->node_degree_[new_node] = origin_degree;
679   func_graph_index->node_degree_[old_node] = 0;
680 
681   BFS(old_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageZero);
682   BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageOne);
683   BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageTwo);
684 }
685 
Unpacking(const std::string & s)686 std::vector<UnpackNode> PatternToPatternPass::Unpacking(const std::string &s) {
687   MS_EXCEPTION_IF_NULL(m_);
688   auto v = m_->GetSeq(s);
689   std::vector<UnpackNode> ret;
690   std::transform(v.begin(), v.end(), std::back_inserter(ret), [](const AnfNodePtr &node) { return UnpackNode(node); });
691   return ret;
692 }
693 
IsFastPass()694 bool PatternToPatternPass::IsFastPass() { return is_fast_pass_; }
695 }  // namespace opt
696 }  // namespace mindspore
697