• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/backend/optimizer/node_pass.h"
17 
18 #include <deque>
19 #include <utility>
20 #include <vector>
21 #include <set>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/anf.h"
26 #include "ir/func_graph.h"
27 #include "ir/manager.h"
28 #include "utils/hash_map.h"
29 #include "utils/hash_set.h"
30 #include "include/backend/kernel_graph.h"
31 #include "include/common/utils/anfalgo.h"
32 
33 namespace mindspore {
34 namespace opt {
namespace {
// Input positions used when inspecting control-flow CNodes (input 0 of a CNode is its primitive).
// Input of a Switch CNode inspected to find the branch subgraph.
constexpr size_t kSwitchBranchIndex = 2;
// Input of a Call CNode holding the called FuncGraph value node.
constexpr size_t kCallArgsIndex = 1;
// Input of a Partial CNode holding the partially-applied FuncGraph value node.
constexpr size_t kPartialArgsIndex = 1;
}  // namespace
40 
UpdateCallerAbstract(const AnfNodePtr & call_node,const FuncGraphPtr & call_node_fg,const FuncGraphPtr & sub_graph)41 void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_node_fg,
42                           const FuncGraphPtr &sub_graph) {
43   MS_EXCEPTION_IF_NULL(call_node);
44   MS_EXCEPTION_IF_NULL(call_node_fg);
45   MS_EXCEPTION_IF_NULL(sub_graph);
46   MS_EXCEPTION_IF_NULL(sub_graph->output());
47   call_node->set_abstract(sub_graph->output()->abstract());
48   auto manager = call_node_fg->manager();
49   MS_EXCEPTION_IF_NULL(manager);
50 
51   // need update TupleGetItem abstract after call node
52   auto &node_users = manager->node_users();
53   auto iter = node_users.find(call_node);
54   if (iter == node_users.end()) {
55     return;
56   }
57   for (auto &node_index : iter->second) {
58     auto used_node = node_index.first;
59     MS_EXCEPTION_IF_NULL(used_node);
60     if (!common::AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) {
61       continue;
62     }
63     auto idx = common::AnfAlgo::GetTupleGetItemOutIndex(used_node->cast<CNodePtr>());
64     auto call_abstract = call_node->abstract();
65     MS_EXCEPTION_IF_NULL(call_abstract);
66     auto tuple_abstract = call_abstract->cast<abstract::AbstractSequencePtr>();
67     MS_EXCEPTION_IF_NULL(tuple_abstract);
68     auto cur_abstract = tuple_abstract->elements().at(idx);
69     MS_EXCEPTION_IF_NULL(cur_abstract);
70     used_node->set_abstract(cur_abstract->Clone());
71   }
72 }
73 
ModifyOutputAndCallerToMap(const CNodePtr & cnode,const FuncGraphPtr & fg,mindspore::HashMap<AnfNodePtr,std::set<AnfNodePtr>> * out_caller_map,bool is_add)74 void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
75                                 mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map, bool is_add) {
76   MS_EXCEPTION_IF_NULL(cnode);
77   MS_EXCEPTION_IF_NULL(out_caller_map);
78   auto inputs = cnode->inputs();
79   if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
80     FuncGraphPtr switch_subgraph = nullptr;
81     const auto &node = inputs.at(kSwitchBranchIndex);
82     MS_EXCEPTION_IF_NULL(node);
83     if (node->isa<CNode>()) {
84       auto partial_node = dyn_cast<CNode>(node);
85       const auto &partial_inputs = partial_node->inputs();
86       MS_EXCEPTION_IF_NULL(partial_inputs.at(0));
87       if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
88         MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
89         switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
90       } else if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartialInline)) {
91         switch_subgraph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(partial_node, kAttrKernelGraph);
92       } else {
93         MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
94       }
95     } else if (node->isa<ValueNode>()) {
96       switch_subgraph = GetValueNode<FuncGraphPtr>(node);
97     } else {
98       MS_LOG(EXCEPTION) << "Get unknown cnode: " << cnode->DebugString();
99     }
100     MS_EXCEPTION_IF_NULL(switch_subgraph);
101     if (is_add) {
102       (void)(*out_caller_map)[switch_subgraph->output()].insert(cnode);
103       UpdateCallerAbstract(cnode, fg, switch_subgraph);
104     } else {
105       (void)(*out_caller_map)[switch_subgraph->output()].erase(cnode);
106     }
107   } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
108     auto call_subgraph = GetValueNode<FuncGraphPtr>(inputs.at(kCallArgsIndex));
109     MS_EXCEPTION_IF_NULL(call_subgraph);
110     if (is_add) {
111       (void)(*out_caller_map)[call_subgraph->output()].insert(cnode);
112       UpdateCallerAbstract(cnode, fg, call_subgraph);
113     } else {
114       (void)(*out_caller_map)[call_subgraph->output()].erase(cnode);
115     }
116   }
117 }
118 
UpdateSubGraphCaller(const AnfNodePtr & origin_output,const FuncGraphPtr & fg,mindspore::HashMap<AnfNodePtr,std::set<AnfNodePtr>> * out_caller_map,const mindspore::HashMap<AnfNodePtr,FuncGraphWeakPtr> & node_to_fg)119 void UpdateSubGraphCaller(const AnfNodePtr &origin_output, const FuncGraphPtr &fg,
120                           mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map,
121                           const mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> &node_to_fg) {
122   MS_EXCEPTION_IF_NULL(fg);
123   MS_EXCEPTION_IF_NULL(fg->output());
124   auto find_iter = (*out_caller_map).find(origin_output);
125   if (find_iter != (*out_caller_map).end()) {
126     auto call_node_list = find_iter->second;
127     (void)(*out_caller_map).erase(find_iter);
128     for (auto &call_node : call_node_list) {
129       auto fg_iter = node_to_fg.find(call_node);
130       if (fg_iter == node_to_fg.end()) {
131         MS_LOG(EXCEPTION) << "Node to Funcgraph find failed: " << call_node->fullname_with_scope();
132       }
133       auto call_node_fg = fg_iter->second.lock();
134       UpdateCallerAbstract(call_node, call_node_fg, fg);
135     }
136     (*out_caller_map)[fg->output()] = call_node_list;
137   }
138 }
139 
SkipSameOp(const AnfNodePtr & old_node,const AnfNodePtr & new_node,mindspore::HashSet<AnfNodePtr> * seen_node)140 void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
141   MS_EXCEPTION_IF_NULL(seen_node);
142   MS_EXCEPTION_IF_NULL(old_node);
143   MS_EXCEPTION_IF_NULL(new_node);
144   if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
145       (common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
146     (void)seen_node->insert(new_node);
147   }
148 }
149 
GetCNodeKey(const AnfNodePtr & node)150 std::string GetCNodeKey(const AnfNodePtr &node) {
151   auto primitive = GetCNodePrimitive(node);
152   if (primitive != nullptr) {
153     return primitive->name();
154   } else {
155     return "";
156   }
157 }
158 
IsNeedUnfoldSubGraph(const FuncGraphPtr & func_graph)159 bool IsNeedUnfoldSubGraph(const FuncGraphPtr &func_graph) {
160   MS_EXCEPTION_IF_NULL(func_graph);
161   return !func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !func_graph->has_flag(kFlagJitCallGraph);
162 }
163 
GenIndex(const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index)164 void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
165   MS_EXCEPTION_IF_NULL(func_graph);
166   MS_EXCEPTION_IF_NULL(func_graph_index);
167   if (func_graph_index->has_gen_index()) {
168     return;
169   }
170 
171   func_graph_index->set_has_gen_index(true);
172   func_graph_index->node_to_fg_.clear();
173   func_graph_index->node_degree_.clear();
174   func_graph_index->name_to_cnode_.clear();
175   func_graph_index->subgraph_out_caller_map_.clear();
176 
177   FuncGraphManagerPtr manager = func_graph->manager();
178   MS_EXCEPTION_IF_NULL(manager);
179   mindspore::HashSet<AnfNodePtr> seen_node;
180   std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->output(), func_graph}};
181 
182   while (!todo.empty()) {
183     AnfNodePtr node = todo.front().first;
184     MS_EXCEPTION_IF_NULL(node);
185     auto fg = todo.front().second;
186     manager->AddFuncGraph(fg);
187     todo.pop_front();
188 
189     func_graph_index->node_to_fg_[node] = fg;
190     auto degree_iter = func_graph_index->node_degree_.find(node);
191     if (degree_iter == func_graph_index->node_degree_.end()) {
192       func_graph_index->node_degree_[node] = 1;
193     } else {
194       degree_iter->second++;
195     }
196     if (node->isa<CNode>()) {
197       (void)func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
198     }
199 
200     if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
201       continue;
202     }
203     (void)seen_node.insert(node);
204     TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
205 
206     if (IsValueNode<FuncGraph>(node)) {
207       auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
208       MS_EXCEPTION_IF_NULL(const_func_graph);
209       if (IsNeedUnfoldSubGraph(const_func_graph)) {
210         (void)todo.emplace_back(const_func_graph->output(), const_func_graph);
211       }
212     } else if (node->isa<CNode>()) {
213       auto cnode = node->cast<CNodePtr>();
214       MS_EXCEPTION_IF_NULL(cnode);
215       ModifyOutputAndCallerToMap(cnode, fg, &func_graph_index->subgraph_out_caller_map_);
216       auto inputs = cnode->inputs();
217       (void)std::for_each(inputs.begin(), inputs.end(),
218                           [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
219     }
220   }
221 }
222 
ProcessFastPassNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index,const FuncGraphManagerPtr & manager)223 bool NodePass::ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
224                                    const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager) {
225   MS_EXCEPTION_IF_NULL(node);
226   MS_EXCEPTION_IF_NULL(func_graph);
227   MS_EXCEPTION_IF_NULL(func_graph_index);
228   MS_EXCEPTION_IF_NULL(manager);
229   auto iter = func_graph_index->node_to_fg_.find(node);
230   if (iter == func_graph_index->node_to_fg_.end()) {
231     MS_LOG(EXCEPTION) << "Node to Funcgraph map can't find node: " << node->fullname_with_scope();
232   }
233   auto fg = iter->second.lock();
234   TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
235   auto degree_iter = func_graph_index->node_degree_.find(node);
236   if (degree_iter == func_graph_index->node_degree_.end()) {
237     MS_LOG(EXCEPTION) << "Node degree map can't find node: " << node->fullname_with_scope();
238   }
239   auto degree = degree_iter->second;
240   if (degree == 0 && node != func_graph->output()) {
241     return false;
242   }
243   // we may update return value in some pass.
244   MS_EXCEPTION_IF_NULL(fg);
245   auto origin_output = fg->output();
246   MS_EXCEPTION_IF_NULL(origin_output);
247   auto origin_abstract = origin_output->abstract();
248   AnfNodePtr new_node = Run(fg, node);
249   bool change = (new_node != nullptr);
250   MS_EXCEPTION_IF_NULL(fg->output());
251   if (origin_abstract != fg->output()->abstract()) {
252     UpdateSubGraphCaller(origin_output, fg, &func_graph_index->subgraph_out_caller_map_, func_graph_index->node_to_fg_);
253   }
254   if (new_node != nullptr && new_node != node) {
255     (void)manager->Replace(node, new_node);
256     // if replaced node is end_goto, refresh relative params in kernel graph
257     auto kernel_graph = fg->cast<std::shared_ptr<session::KernelGraph>>();
258     if (kernel_graph != nullptr && node->isa<CNode>()) {
259       auto cnode = node->cast<CNodePtr>();
260       MS_EXCEPTION_IF_NULL(cnode);
261       auto end_label = kernel_graph->get_end_goto();
262       if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
263         kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
264       }
265     }
266     AfterProcess(node, new_node, fg, func_graph_index);
267   }
268   return change;
269 }
270 
ProcessFastPass(const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index)271 bool NodePass::ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
272   MS_EXCEPTION_IF_NULL(func_graph);
273   MS_EXCEPTION_IF_NULL(func_graph_index);
274   if (!func_graph_index->has_gen_index()) {
275     MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, func graph has not gen index, pass name: " << name();
276   }
277   auto src_pattern_root_name = GetPatternRootPrimitiveName();
278   FuncGraphManagerPtr manager = func_graph->manager();
279   MS_EXCEPTION_IF_NULL(manager);
280   bool changes = false;
281 
282   std::vector<AnfNodePtr> cand_node;
283   if (!src_pattern_root_name.empty()) {
284     auto cnode_iter = func_graph_index->name_to_cnode_.find(src_pattern_root_name);
285     if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
286       return false;
287     }
288     (void)std::copy(cnode_iter->second.begin(), cnode_iter->second.end(), std::back_inserter(cand_node));
289   } else {
290     for (const auto &kv : func_graph_index->name_to_cnode_) {
291       (void)std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(cand_node));
292     }
293   }
294   for (const auto &node : cand_node) {
295     auto change = ProcessFastPassNode(node, func_graph, func_graph_index, manager);
296     changes = changes || change;
297   }
298   return changes;
299 }
300 
ProcessPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)301 bool NodePass::ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
302   MS_EXCEPTION_IF_NULL(func_graph);
303   MS_EXCEPTION_IF_NULL(manager);
304   bool changes = false;
305 
306   // maybe call subgraph many times
307   mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map = {};
308   mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg = {};
309   mindspore::HashSet<AnfNodePtr> seen_node;
310   std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->get_return(), func_graph}};
311   while (!todo.empty()) {
312     AnfNodePtr node = todo.front().first;
313     auto fg = todo.front().second;
314     MS_EXCEPTION_IF_NULL(node);
315     manager->AddFuncGraph(fg);
316     todo.pop_front();
317     node_to_fg[node] = fg;
318     if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
319       continue;
320     }
321     (void)seen_node.insert(node);
322     TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
323     // we may update return value in some pass.
324     MS_EXCEPTION_IF_NULL(fg);
325     auto origin_output = fg->output();
326     MS_EXCEPTION_IF_NULL(origin_output);
327     auto origin_abstract = origin_output->abstract();
328     AnfNodePtr new_node = Run(fg, node);
329     bool change = (new_node != nullptr);
330     if (origin_abstract != fg->output()->abstract()) {
331       UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map, node_to_fg);
332     }
333     if (new_node != nullptr && new_node != node) {
334       SkipSameOp(node, new_node, &seen_node);
335       (void)manager->Replace(node, new_node);
336       // if replaced node is end_goto, refresh relative params in kernel graph
337       auto kernel_graph = fg->cast<std::shared_ptr<session::KernelGraph>>();
338       if (kernel_graph != nullptr && node->isa<CNode>()) {
339         auto cnode = node->cast<CNodePtr>();
340         MS_EXCEPTION_IF_NULL(cnode);
341         auto end_label = kernel_graph->get_end_goto();
342         if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
343           kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
344         }
345       }
346       (void)seen_node.erase(node);
347     } else if (new_node == nullptr) {
348       new_node = node;
349     }
350     if (new_node && IsValueNode<FuncGraph>(new_node)) {
351       auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
352       MS_EXCEPTION_IF_NULL(const_func_graph);
353       if (IsNeedUnfoldSubGraph(const_func_graph)) {
354         (void)todo.emplace_back(const_func_graph->output(), const_func_graph);
355       }
356     } else if (new_node && new_node->isa<CNode>()) {
357       if (common::AnfAlgo::IsGraphKernel(new_node)) {
358         (void)todo.emplace_back(new_node, func_graph);
359       }
360       auto cnode = new_node->cast<CNodePtr>();
361       MS_EXCEPTION_IF_NULL(cnode);
362       ModifyOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map, is_add_);
363       auto inputs = cnode->inputs();
364       (void)std::for_each(inputs.begin(), inputs.end(),
365                           [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
366     }
367     changes = changes || change;
368   }
369   return changes;
370 }
371 
Run(const FuncGraphPtr & func_graph)372 bool NodePass::Run(const FuncGraphPtr &func_graph) {
373   MS_EXCEPTION_IF_NULL(func_graph);
374   FuncGraphManagerPtr manager = func_graph->manager();
375   MS_EXCEPTION_IF_NULL(manager);
376   manager->AddFuncGraph(func_graph);
377   if (!func_graph->has_user_data<FuncGraphPassIndex>()) {
378     func_graph->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
379   }
380   auto func_graph_index = func_graph->user_data<FuncGraphPassIndex>();
381   MS_EXCEPTION_IF_NULL(func_graph_index);
382 
383   if (IsFastPass()) {
384     MS_LOG(INFO) << "Run fast pass: " << name();
385     GenIndex(func_graph, func_graph_index);
386     return ProcessFastPass(func_graph, func_graph_index);
387   }
388   if (func_graph_index->has_gen_index()) {
389     const auto &ret = MustExistPrimitiveName();
390     for (const auto &primtive_name : ret) {
391       const auto cnode_iter = func_graph_index->name_to_cnode_.find(primtive_name);
392       if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
393         MS_LOG(INFO) << "Prim " << primtive_name << " not exist in name to cnode";
394         return false;
395       }
396     }
397     if (!ret.empty()) {
398       MS_LOG(INFO) << "Skip pass fail, run pass: " << name();
399     }
400   }
401   func_graph_index->set_has_gen_index(false);
402 
403   return ProcessPass(func_graph, manager);
404 }
405 }  // namespace opt
406 }  // namespace mindspore
407