1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/optimizer/pattern_to_pattern.h"
18 #include <algorithm>
19 #include <set>
20 #include <queue>
21 #include "ir/manager.h"
22 #include "include/common/utils/anfalgo.h"
23
24 namespace mindspore {
25 namespace opt {
// Pattern-condition predicate that accepts every candidate BaseRef unconditionally.
bool AlwaysReturnTrue(const BaseRef & /* unused */) {
  return true;
}
27
Contains(const std::string & name) const28 bool PatternMap::Contains(const std::string &name) const { return name_set_.count(name) > 0; }
29
CheckSeq(const std::string & name) const30 bool PatternMap::CheckSeq(const std::string &name) const {
31 return name_set_.count(name) > 0 && seq_map_.count(name) > 0;
32 }
33
Erase(const mindspore::HashSet<std::string> & del_set)34 void PatternMap::Erase(const mindspore::HashSet<std::string> &del_set) {
35 for (auto &s : del_set) {
36 name_set_.erase(s);
37 node_map_.erase(s);
38 }
39 }
40
Get(const std::string & name) const41 AnfNodePtr PatternMap::Get(const std::string &name) const {
42 if (!Contains(name)) {
43 MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not in PatternMap";
44 }
45
46 auto iter = node_map_.find(name);
47 if (iter == node_map_.end()) {
48 MS_LOG(INTERNAL_EXCEPTION) << "Var Key: " << name << " is not in PatternMap";
49 }
50 return iter->second;
51 }
52
GetSeq(const std::string & name) const53 const std::vector<AnfNodePtr> &PatternMap::GetSeq(const std::string &name) const {
54 if (!Contains(name)) {
55 MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not in PatternMap";
56 }
57
58 auto iter = seq_map_.find(name);
59 if (iter == seq_map_.end()) {
60 MS_LOG(INTERNAL_EXCEPTION) << "SeqVar Key: " << name << " is not in PatternMap";
61 }
62 return iter->second;
63 }
64
Emplace(const std::string & name,const AnfNodePtr & node)65 bool PatternMap::Emplace(const std::string &name, const AnfNodePtr &node) {
66 MS_EXCEPTION_IF_NULL(node);
67 name_set_.insert(name);
68 if (seq_map_.find(name) != seq_map_.end()) {
69 MS_LOG(INTERNAL_EXCEPTION) << "Var Key: " << name << " should not be in SeqVarMap.";
70 }
71
72 (void)opt_scope_.insert(node);
73
74 auto iter = node_map_.find(name);
75 if (iter == node_map_.end()) {
76 node_map_.emplace(name, node);
77 } else if (!opt::AnfEqual(node, iter->second)) {
78 MS_EXCEPTION_IF_NULL(iter->second);
79 MS_LOG(INFO) << "The value of key: " << name
80 << " is not equal to origin value, value: " + node->fullname_with_scope()
81 << " origin value: " << iter->second->fullname_with_scope();
82 return false;
83 }
84 return true;
85 }
86
Emplace(const std::string & name,const std::vector<AnfNodePtr> & v)87 bool PatternMap::Emplace(const std::string &name, const std::vector<AnfNodePtr> &v) {
88 name_set_.insert(name);
89 if (node_map_.find(name) != node_map_.end()) {
90 MS_LOG(INTERNAL_EXCEPTION) << "SeqVar Key: " << name << " should not be in VarMap.";
91 }
92
93 for (const auto &node : v) {
94 (void)opt_scope_.insert(node);
95 }
96
97 auto iter = seq_map_.find(name);
98 if (iter == seq_map_.end()) {
99 seq_map_.emplace(name, v);
100 } else {
101 auto &origin_v = iter->second;
102 if (v.size() != origin_v.size()) {
103 MS_LOG(INFO) << "The value of key: " << name
104 << " is not equal to origin value, v size: " + std::to_string(v.size()) +
105 ", origin_v size: " + std::to_string(origin_v.size());
106 return false;
107 }
108
109 for (size_t i = 0; i < v.size(); i++) {
110 MS_EXCEPTION_IF_NULL(v[i]);
111 MS_EXCEPTION_IF_NULL(origin_v[i]);
112 if (!opt::AnfEqual(v[i], origin_v[i])) {
113 MS_LOG(INFO) << "The value of key: " << name
114 << " is not equal to origin value, value: " + v[i]->fullname_with_scope()
115 << " origin value: " << origin_v[i]->fullname_with_scope();
116 return false;
117 }
118 }
119 }
120 return true;
121 }
122
Clear()123 void PatternMap::Clear() {
124 name_set_.clear();
125 node_map_.clear();
126 seq_map_.clear();
127 }
128
Check(const std::string & name,const AnfNodePtr & node) const129 bool PatternMap::Check(const std::string &name, const AnfNodePtr &node) const { return opt::AnfEqual(node, Get(name)); }
130
AddVar(const std::string & name,const PatternConditionFunc & f)131 SrcPattern &SrcPattern::AddVar(const std::string &name, const PatternConditionFunc &f) {
132 if (ref_map_.find(name) != ref_map_.end()) {
133 MS_LOG(INTERNAL_EXCEPTION) << "Var: " << name << " is already in SrcPattern.";
134 }
135
136 auto var = std::make_shared<CondVar>(f);
137 ref_map_.emplace(name, var);
138 return *this;
139 }
140
AddSeqVar(const std::string & name,const PatternConditionFunc & f)141 SrcPattern &SrcPattern::AddSeqVar(const std::string &name, const PatternConditionFunc &f) {
142 if (ref_map_.find(name) != ref_map_.end()) {
143 MS_LOG(INTERNAL_EXCEPTION) << "SeqVar: " << name << " is already in SrcPattern.";
144 }
145
146 auto seq_var = std::make_shared<SeqVar>(f);
147 ref_map_.emplace(name, seq_var);
148 return *this;
149 }
150
GetRef(const std::string & name) const151 const BaseRef &SrcPattern::GetRef(const std::string &name) const {
152 auto iter = ref_map_.find(name);
153 if (iter == ref_map_.end()) {
154 MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " not in PatternMap";
155 }
156 return iter->second;
157 }
158
AddCNode(const std::string & name,const std::initializer_list<PatternNode> & v)159 SrcPattern &SrcPattern::AddCNode(const std::string &name, const std::initializer_list<PatternNode> &v) {
160 if (ref_map_.find(name) != ref_map_.end()) {
161 MS_LOG(INTERNAL_EXCEPTION) << "CNode: " << name << " is already in SrcPattern.";
162 }
163
164 std::vector<BaseRef> ele;
165 for (auto &node : v) {
166 if (node.type_ == "name") {
167 ele.emplace_back(GetRef(node.name_));
168 } else if (node.type_ == "prim") {
169 ele.emplace_back(node.p_);
170 } else {
171 MS_LOG(INTERNAL_EXCEPTION) << "Error MatchNode Type: " << node.type_ << ", CNode: " << name;
172 }
173 }
174
175 MS_EXCEPTION_IF_CHECK_FAIL(
176 ele.size() == v.size(),
177 "The length of BaseRef Vector and CNode Input is not equal, BaseRef Vector length: " + std::to_string(ele.size()) +
178 " CNode Input length: " + std::to_string(v.size()) + ", CNode: " + name);
179
180 inputs_map_.emplace(name, v);
181 auto vec = VectorRef(ele);
182 ref_map_.emplace(name, vec);
183 has_root_ = true;
184 root_ = name;
185 return *this;
186 }
187
GetRoot() const188 BaseRef SrcPattern::GetRoot() const {
189 if (!has_root_) {
190 MS_LOG(INTERNAL_EXCEPTION) << "This SrcPattern has no root node.";
191 }
192 return GetRef(root_);
193 }
194
GetSeq(const std::string & pattern_name,const std::string & node_name,const VarPtr & var,const EquivPtr & equiv)195 const Seq &GetSeq(const std::string &pattern_name, const std::string &node_name, const VarPtr &var,
196 const EquivPtr &equiv) {
197 MS_EXCEPTION_IF_NULL(equiv);
198 auto equiv_iter = equiv->find(var);
199 if (equiv_iter == equiv->end()) {
200 MS_LOG(INTERNAL_EXCEPTION) << "The SeqVar Key: " << pattern_name << " is not in EquivMap, node name: " << node_name;
201 }
202
203 BaseRef &seq_ref = equiv_iter->second;
204 if (utils::isa<Seq>(seq_ref)) {
205 const Seq &seq = utils::cast<Seq>(seq_ref);
206 return seq;
207 }
208 MS_LOG(INTERNAL_EXCEPTION) << "The value of SeqVar Key: " << pattern_name
209 << " is not a seq, node name: " << node_name;
210 }
211
CheckEmptySeqVar(const std::string & name,const EquivPtr & equiv,const std::vector<PatternNode> & inputs,size_t * now_pattern)212 bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv,
213 const std::vector<PatternNode> &inputs, size_t *now_pattern) {
214 if (inputs.size() - (*now_pattern) == 1 && inputs.at(*now_pattern).type_ == "name") {
215 auto &pattern_node = inputs.at(*now_pattern);
216 auto &ref = GetRef(pattern_node.name_);
217 if (utils::isa<VarPtr>(ref) && utils::cast<VarPtr>(ref)->isa<SeqVar>()) {
218 const Seq &seq = GetSeq(pattern_node.name_, name, utils::cast<VarPtr>(ref), equiv);
219 MS_EXCEPTION_IF_CHECK_FAIL(seq.size() == IntToSize(0), "Match Failed, need zero seq, but get seq length: " +
220 std::to_string(seq.size()) + ", node name: " + name);
221 std::vector<AnfNodePtr> v;
222 MS_EXCEPTION_IF_NULL(m_);
223 if (!m_->Emplace(pattern_node.name_, v)) {
224 return false;
225 }
226 (*now_pattern)++;
227 }
228 }
229 return true;
230 }
231
match(const std::string & name,const AnfNodePtr & node,const EquivPtr & equiv)232 bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv) {
233 MS_EXCEPTION_IF_NULL(m_);
234 MS_EXCEPTION_IF_NULL(node);
235 MS_EXCEPTION_IF_NULL(equiv);
236 auto input_iter = inputs_map_.find(name);
237 if (input_iter == inputs_map_.end()) {
238 MS_LOG(INTERNAL_EXCEPTION) << "Key: " << name << " is not a CNode.";
239 }
240
241 if (m_->Contains(name)) {
242 return m_->Check(name, node);
243 }
244
245 auto &inputs = input_iter->second;
246 auto cnode = node->cast<CNodePtr>();
247 MS_EXCEPTION_IF_NULL(cnode);
248 auto cnode_inputs = cnode->inputs();
249 size_t now_pattern = 0;
250 size_t now_match = 0;
251 for (; now_pattern < inputs.size() && now_match < cnode_inputs.size(); now_pattern++, now_match++) {
252 auto &pattern_node = inputs[now_pattern];
253 auto &match_node = cnode_inputs[now_match];
254 if (pattern_node.type_ == "prim") {
255 // prim
256 MS_EXCEPTION_IF_NULL(pattern_node.p_);
257 MS_EXCEPTION_IF_NULL(match_node);
258 if (!opt::AnfEqual(pattern_node.p_, match_node)) {
259 MS_LOG(EXCEPTION) << "The value of Primitive is not equal to matched value, pattern value: " +
260 pattern_node.p_->ToString()
261 << " matched value: " + match_node->ToString() + ", node name: " + name;
262 }
263 continue;
264 }
265 // name
266 MS_EXCEPTION_IF_CHECK_FAIL(pattern_node.type_ == "name",
267 "Error MatchNode Type: " + pattern_node.type_ + ", node name: " + name);
268 auto &ref = GetRef(pattern_node.name_);
269 if (utils::isa<VarPtr>(ref)) {
270 if (utils::cast<VarPtr>(ref)->isa<SeqVar>()) {
271 // seq var
272 const Seq &seq = GetSeq(pattern_node.name_, name, utils::cast<VarPtr>(ref), equiv);
273 std::vector<AnfNodePtr> v;
274 for (size_t i = 0; i < seq.size(); i++) {
275 v.emplace_back(cnode_inputs.at(now_match + i));
276 }
277 if (!m_->Emplace(pattern_node.name_, v)) {
278 return false;
279 }
280 now_match += seq.size() - 1;
281 continue;
282 }
283 } else {
284 // cnode
285 if (!match(pattern_node.name_, match_node, equiv)) {
286 return false;
287 }
288 }
289 if (!m_->Emplace(pattern_node.name_, match_node)) {
290 return false;
291 }
292 }
293 // has a SeqVar at the end
294 if (now_match == cnode_inputs.size()) {
295 if (!CheckEmptySeqVar(name, equiv, inputs, &now_pattern)) {
296 return false;
297 }
298 }
299
300 MS_EXCEPTION_IF_CHECK_FAIL(
301 now_pattern == inputs.size() && now_match == cnode_inputs.size(),
302 "Match Failed, now_pattern: " + std::to_string(now_pattern) + ", inputs.size(): " + std::to_string(inputs.size()) +
303 ", now_match: " + std::to_string(now_match) + ", cnode_inputs.size(): " + std::to_string(cnode_inputs.size()) +
304 ", node name: " + name);
305
306 return m_->Emplace(name, node);
307 }
308
build_pattern_map(const AnfNodePtr & node,const EquivPtr & equiv)309 bool SrcPattern::build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv) {
310 MS_EXCEPTION_IF_NULL(m_);
311 if (!has_root_) {
312 MS_LOG(INTERNAL_EXCEPTION) << "This SourcePattern has no root node.";
313 }
314 m_->Clear();
315 return match(root_, node, equiv);
316 }
317
AddCNode(const string & name,const std::initializer_list<PatternNode> & inputs,const BuildCNodeFunc & buildfunc)318 DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list<PatternNode> &inputs,
319 const BuildCNodeFunc &buildfunc) {
320 MS_EXCEPTION_IF_NULL(m_);
321 if (fail_) {
322 return *this;
323 }
324
325 if (m_->Contains(name)) {
326 MS_LOG(INTERNAL_EXCEPTION) << "CNode: " + name + " is already in DstPattern";
327 }
328
329 std::vector<AnfNodePtr> anf_inputs;
330 for (auto &r : inputs) {
331 if (r.type_ == "prim") {
332 anf_inputs.emplace_back(r.p_);
333 } else if (r.type_ == "name") {
334 if (m_->CheckSeq(r.name_)) {
335 auto &v = m_->GetSeq(r.name_);
336 std::copy(v.begin(), v.end(), std::back_inserter(anf_inputs));
337 } else {
338 anf_inputs.emplace_back(m_->Get(r.name_));
339 }
340 } else if (r.type_ == "unpack") {
341 for (auto &it : r.v_) {
342 if (it.node_ == nullptr) {
343 anf_inputs.emplace_back(m_->Get(it.key_));
344 } else {
345 anf_inputs.emplace_back(it.node_);
346 }
347 }
348 } else {
349 MS_LOG(INTERNAL_EXCEPTION) << "Error ReplaceNode Type: " << r.type_ << ", CNode: " << name;
350 }
351 }
352
353 MS_EXCEPTION_IF_NULL(pass_);
354 auto default_node = pass_->NewCNode(anf_inputs, fg_);
355 auto new_node = buildfunc(*m_, default_node);
356 if (new_node == nullptr) {
357 fail_ = true;
358 } else {
359 auto cnode = new_node->cast<CNodePtr>();
360 MS_EXCEPTION_IF_NULL(cnode);
361 if (anf_inputs.size() != cnode->size()) {
362 MS_LOG(INTERNAL_EXCEPTION)
363 << "The actual input size does not correspond to the input size of the pattern, actual input size: "
364 << anf_inputs.size() << ", pattern input size: " << new_node->cast<CNodePtr>()->size() << ", CNode: " << name;
365 }
366 for (size_t i = 0; i < anf_inputs.size(); i++) {
367 MS_EXCEPTION_IF_NULL(anf_inputs[i]);
368 MS_EXCEPTION_IF_NULL(cnode->input(i));
369 if (!opt::AnfEqual(anf_inputs[i], cnode->input(i))) {
370 MS_LOG(INTERNAL_EXCEPTION)
371 << "The actual input does not correspond to the input of the pattern, the input index: " << i
372 << ", actual input: " << anf_inputs[i]->DebugString()
373 << ", pattern input: " << new_node->cast<CNodePtr>()->input(i)->DebugString() << ", CNode: " << name;
374 }
375 }
376 }
377
378 if (!m_->Emplace(name, new_node)) {
379 MS_LOG(EXCEPTION) << "CNode: " + name + " is already in DstPattern";
380 }
381 root_ = new_node;
382 return *this;
383 }
384
AddValueNode(const string & name,const BuildValueFunc & buildfunc)385 DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &buildfunc) {
386 MS_EXCEPTION_IF_NULL(m_);
387 if (fail_) {
388 return *this;
389 }
390
391 if (m_->Contains(name)) {
392 MS_LOG(INTERNAL_EXCEPTION) << "ValueNode: " + name + " is already in DstPattern";
393 }
394
395 auto node = buildfunc(*m_);
396 if (node == nullptr) {
397 fail_ = true;
398 }
399 dst_set_.insert(name);
400 (void)m_->Emplace(name, node);
401 root_ = node;
402 return *this;
403 }
404
clear()405 void DstPattern::clear() {
406 MS_EXCEPTION_IF_NULL(m_);
407 fail_ = false;
408 root_ = nullptr;
409 m_->Erase(dst_set_);
410 dst_set_.clear();
411 fg_ = nullptr;
412 pass_ = nullptr;
413 }
414
set_info(PatternToPatternPass * now_pass,const FuncGraphPtr & func_graph)415 void DstPattern::set_info(PatternToPatternPass *now_pass, const FuncGraphPtr &func_graph) {
416 pass_ = now_pass;
417 fg_ = func_graph;
418 }
419
Root()420 AnfNodePtr DstPattern::Root() {
421 if (fail_) {
422 return nullptr;
423 } else {
424 return root_;
425 }
426 }
427
operator =(const std::string & name)428 UnpackNode &UnpackNode::operator=(const std::string &name) {
429 key_ = name;
430 node_ = nullptr;
431 return *this;
432 }
433
GetSrcPatternRoot()434 AnfNodePtr PatternToPatternPass::GetSrcPatternRoot() {
435 if (src_pattern_root_ == nullptr) {
436 DefineSrcPattern(&src_pattern_);
437 VarPtr fg = std::make_shared<Var>("RootG");
438 src_pattern_root_ = SexpToNode(src_pattern_.GetRoot(), fg, primitive_vars_.get(), multigraph_);
439 }
440 return src_pattern_root_;
441 }
442
GetPatternRootPrimitiveName()443 std::string PatternToPatternPass::GetPatternRootPrimitiveName() {
444 auto src_pattern_root = GetSrcPatternRoot();
445 auto prim = GetCNodePrimitive(src_pattern_root);
446 if (prim != nullptr) {
447 return prim->name();
448 }
449 return "";
450 }
451
// Tries to rewrite `node`. It must be a CNode of the source-pattern root primitive, and the pattern engine must
// match it. The source bindings must be consistent, and CheckMatchedDAG must accept them. If so, the destination
// pattern is rebuilt and its root is returned. Returns nullptr whenever any of these steps does not succeed.
AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (src_pattern_root_ == nullptr) {
    (void)GetSrcPatternRoot();
  }

  auto primitive = GetCNodePrimitive(src_pattern_root_);
  if (!IsPrimitiveCNode(node, primitive)) {
    return nullptr;
  }

  MS_EXCEPTION_IF_NULL(primitive_vars_);
  MS_EXCEPTION_IF_NULL(equiv_);
  equiv_->clear();
  EquivPtr equiv = pattern_engine_.Match(src_pattern_root_, node, *primitive_vars_, equiv_);
  if (equiv == nullptr || equiv->empty()) {
    return nullptr;
  }
  if (!src_pattern_.build_pattern_map(node, equiv) || !CheckMatchedDAG(*m_, func_graph, node)) {
    return nullptr;
  }

  dst_pattern_.clear();
  dst_pattern_.set_info(this, func_graph);
  DefineDstPattern(&dst_pattern_);
  return dst_pattern_.Root();
}
478
479 namespace {
// BFS stages run by AfterProcess: remove the source pattern, add the destination pattern, then cascade deletions.
// Typed as size_t to match BFS's `stage` parameter, avoiding signed/unsigned comparisons.
constexpr size_t kStageZero = 0;
constexpr size_t kStageOne = 1;
constexpr size_t kStageTwo = 2;
483
DeleteCNode(const AnfNodePtr & node,const FuncGraphPtr & sub_graph,const FuncGraphIndexPtr & func_graph_index)484 void DeleteCNode(const AnfNodePtr &node, const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
485 MS_EXCEPTION_IF_NULL(node);
486 MS_EXCEPTION_IF_NULL(func_graph_index);
487 if (node->isa<CNode>()) {
488 auto name_to_cnode_iter = func_graph_index->name_to_cnode_.find(GetCNodeKey(node));
489 if (name_to_cnode_iter == func_graph_index->name_to_cnode_.end()) {
490 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find cnode_name: "
491 << common::AnfAlgo::GetCNodeName(node);
492 }
493 auto &cnode_set = name_to_cnode_iter->second;
494 auto cnode_set_iter = cnode_set.find(node);
495 if (cnode_set_iter == cnode_set.end()) {
496 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find node: "
497 << node->fullname_with_scope();
498 }
499 (void)cnode_set.erase(cnode_set_iter);
500 ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), sub_graph, &func_graph_index->subgraph_out_caller_map_, false);
501 }
502 }
503
AppendChild(const AnfNodePtr & node,const FuncGraphPtr & fg,std::queue<std::pair<AnfNodePtr,FuncGraphPtr>> * anf_q)504 void AppendChild(const AnfNodePtr &node, const FuncGraphPtr &fg,
505 std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> *anf_q) {
506 MS_EXCEPTION_IF_NULL(node);
507 MS_EXCEPTION_IF_NULL(fg);
508 MS_EXCEPTION_IF_NULL(anf_q);
509 if (IsValueNode<FuncGraph>(node)) {
510 auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
511 MS_EXCEPTION_IF_NULL(const_func_graph);
512 if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
513 (void)anf_q->emplace(const_func_graph->output(), const_func_graph);
514 }
515 } else if (node->isa<CNode>()) {
516 auto cnode = node->cast<CNodePtr>();
517 MS_EXCEPTION_IF_NULL(cnode);
518 for (const auto &input_node : cnode->inputs()) {
519 (void)anf_q->emplace(input_node, fg);
520 }
521 }
522 }
523
DelSrcPattern(const std::pair<AnfNodePtr,FuncGraphPtr> & top,const AnfNodePtr & root,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)524 bool DelSrcPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
525 const mindspore::HashSet<AnfNodePtr> &opt_scope,
526 std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
527 const FuncGraphIndexPtr &func_graph_index) {
528 MS_EXCEPTION_IF_NULL(root);
529 MS_EXCEPTION_IF_NULL(need_delete);
530 MS_EXCEPTION_IF_NULL(func_graph_index);
531 auto node = top.first;
532 auto fg = top.second;
533 MS_EXCEPTION_IF_NULL(node);
534 MS_EXCEPTION_IF_NULL(fg);
535 if (node != root) {
536 auto degree_iter = func_graph_index->node_degree_.find(node);
537 if (degree_iter == func_graph_index->node_degree_.end()) {
538 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
539 << " not in degree map";
540 }
541 if (degree_iter->second <= 0) {
542 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
543 << " degree error, degree: " << degree_iter->second;
544 }
545 degree_iter->second--;
546 if (degree_iter->second > 0) {
547 return false;
548 }
549 }
550 if (opt_scope.find(node) == opt_scope.end()) {
551 (void)(*need_delete).insert({node, fg});
552 return false;
553 }
554
555 DeleteCNode(node, fg, func_graph_index);
556 return true;
557 }
558
AddDstPattern(const std::pair<AnfNodePtr,FuncGraphPtr> & top,const AnfNodePtr & root,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)559 bool AddDstPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
560 const mindspore::HashSet<AnfNodePtr> &opt_scope,
561 std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
562 const FuncGraphIndexPtr &func_graph_index) {
563 MS_EXCEPTION_IF_NULL(root);
564 MS_EXCEPTION_IF_NULL(need_delete);
565 MS_EXCEPTION_IF_NULL(func_graph_index);
566 auto node = top.first;
567 auto fg = top.second;
568 MS_EXCEPTION_IF_NULL(node);
569 MS_EXCEPTION_IF_NULL(fg);
570 if (node->isa<CNode>()) {
571 ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), fg, &func_graph_index->subgraph_out_caller_map_);
572 (void)func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
573 func_graph_index->node_to_fg_[node] = fg;
574 }
575
576 if (node != root) {
577 auto degree_iter = func_graph_index->node_degree_.find(node);
578 if (degree_iter == func_graph_index->node_degree_.end()) {
579 func_graph_index->node_degree_[node] = 0;
580 degree_iter = func_graph_index->node_degree_.find(node);
581 }
582 degree_iter->second++;
583 if (degree_iter->second != 1) {
584 return false;
585 }
586 }
587 if (opt_scope.find(node) == opt_scope.end()) {
588 (void)(*need_delete).erase({node, fg});
589 return false;
590 }
591 return true;
592 }
593
DelCascadeNode(const std::pair<AnfNodePtr,FuncGraphPtr> & top,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index)594 bool DelCascadeNode(const std::pair<AnfNodePtr, FuncGraphPtr> &top,
595 std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
596 const FuncGraphIndexPtr &func_graph_index) {
597 MS_EXCEPTION_IF_NULL(need_delete);
598 MS_EXCEPTION_IF_NULL(func_graph_index);
599 auto node = top.first;
600 auto fg = top.second;
601 MS_EXCEPTION_IF_NULL(node);
602 MS_EXCEPTION_IF_NULL(fg);
603 if ((*need_delete).find({node, fg}) == (*need_delete).end()) {
604 auto degree_iter = func_graph_index->node_degree_.find(node);
605 if (degree_iter == func_graph_index->node_degree_.end()) {
606 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
607 << " not in degree map";
608 }
609 if (degree_iter->second <= 0) {
610 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
611 << " degree error, degree: " << degree_iter->second;
612 }
613 degree_iter->second--;
614 if (degree_iter->second > 0) {
615 return false;
616 }
617 }
618
619 DeleteCNode(node, fg, func_graph_index);
620 return true;
621 }
622
BFS(const AnfNodePtr & root,const FuncGraphPtr & sub_graph,const mindspore::HashSet<AnfNodePtr> & opt_scope,std::set<std::pair<AnfNodePtr,FuncGraphPtr>> * need_delete,const FuncGraphIndexPtr & func_graph_index,size_t stage)623 void BFS(const AnfNodePtr &root, const FuncGraphPtr &sub_graph, const mindspore::HashSet<AnfNodePtr> &opt_scope,
624 std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete, const FuncGraphIndexPtr &func_graph_index,
625 size_t stage) {
626 std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> anf_q;
627
628 if (stage == kStageZero || stage == kStageOne) {
629 (void)anf_q.emplace(root, sub_graph);
630 } else if (stage == kStageTwo) {
631 for (const auto &p : (*need_delete)) {
632 anf_q.push(p);
633 }
634 } else {
635 MS_LOG(INTERNAL_EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
636 }
637
638 while (!anf_q.empty()) {
639 auto top = anf_q.front();
640 anf_q.pop();
641
642 bool ret = false;
643 if (stage == kStageZero) {
644 ret = DelSrcPattern(top, root, opt_scope, need_delete, func_graph_index);
645 } else if (stage == kStageOne) {
646 ret = AddDstPattern(top, root, opt_scope, need_delete, func_graph_index);
647 } else if (stage == kStageTwo) {
648 ret = DelCascadeNode(top, need_delete, func_graph_index);
649 } else {
650 MS_LOG(INTERNAL_EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
651 }
652 if (!ret) {
653 continue;
654 }
655
656 AppendChild(top.first, top.second, &anf_q);
657 }
658 }
659 } // namespace
660
AfterProcess(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const FuncGraphPtr & sub_graph,const FuncGraphIndexPtr & func_graph_index)661 void PatternToPatternPass::AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node,
662 const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
663 MS_EXCEPTION_IF_NULL(m_);
664 MS_EXCEPTION_IF_NULL(old_node);
665 MS_EXCEPTION_IF_NULL(new_node);
666 MS_EXCEPTION_IF_NULL(sub_graph);
667 MS_EXCEPTION_IF_NULL(func_graph_index);
668 std::set<std::pair<AnfNodePtr, FuncGraphPtr>> need_delete;
669 auto &opt_scope = m_->GetOptScope();
670
671 auto old_node_iter = func_graph_index->node_degree_.find(old_node);
672 if (old_node_iter == func_graph_index->node_degree_.end()) {
673 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, old_node: " << old_node->fullname_with_scope()
674 << " not in degree map";
675 }
676 auto origin_degree = old_node_iter->second;
677
678 func_graph_index->node_degree_[new_node] = origin_degree;
679 func_graph_index->node_degree_[old_node] = 0;
680
681 BFS(old_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageZero);
682 BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageOne);
683 BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageTwo);
684 }
685
Unpacking(const std::string & s)686 std::vector<UnpackNode> PatternToPatternPass::Unpacking(const std::string &s) {
687 MS_EXCEPTION_IF_NULL(m_);
688 auto v = m_->GetSeq(s);
689 std::vector<UnpackNode> ret;
690 std::transform(v.begin(), v.end(), std::back_inserter(ret), [](const AnfNodePtr &node) { return UnpackNode(node); });
691 return ret;
692 }
693
IsFastPass()694 bool PatternToPatternPass::IsFastPass() { return is_fast_pass_; }
695 } // namespace opt
696 } // namespace mindspore
697