1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/backend/optimizer/node_pass.h"
17
18 #include <deque>
19 #include <utility>
20 #include <vector>
21 #include <set>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/anf.h"
26 #include "ir/func_graph.h"
27 #include "ir/manager.h"
28 #include "utils/hash_map.h"
29 #include "utils/hash_set.h"
30 #include "include/backend/kernel_graph.h"
31 #include "include/common/utils/anfalgo.h"
32
33 namespace mindspore {
34 namespace opt {
namespace {
// Input positions used when inspecting control-flow CNodes below; input 0 of a CNode holds its primitive.
// Switch input whose sub graph output is tracked (presumably the true branch -- verify against Switch's layout).
constexpr size_t kSwitchBranchIndex = 2;
// Call input that holds the called FuncGraph value node.
constexpr size_t kCallArgsIndex = 1;
// Partial input that holds the partially-applied FuncGraph value node.
constexpr size_t kPartialArgsIndex = 1;
}  // namespace
40
UpdateCallerAbstract(const AnfNodePtr & call_node,const FuncGraphPtr & call_node_fg,const FuncGraphPtr & sub_graph)41 void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_node_fg,
42 const FuncGraphPtr &sub_graph) {
43 MS_EXCEPTION_IF_NULL(call_node);
44 MS_EXCEPTION_IF_NULL(call_node_fg);
45 MS_EXCEPTION_IF_NULL(sub_graph);
46 MS_EXCEPTION_IF_NULL(sub_graph->output());
47 call_node->set_abstract(sub_graph->output()->abstract());
48 auto manager = call_node_fg->manager();
49 MS_EXCEPTION_IF_NULL(manager);
50
51 // need update TupleGetItem abstract after call node
52 auto &node_users = manager->node_users();
53 auto iter = node_users.find(call_node);
54 if (iter == node_users.end()) {
55 return;
56 }
57 for (auto &node_index : iter->second) {
58 auto used_node = node_index.first;
59 MS_EXCEPTION_IF_NULL(used_node);
60 if (!common::AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) {
61 continue;
62 }
63 auto idx = common::AnfAlgo::GetTupleGetItemOutIndex(used_node->cast<CNodePtr>());
64 auto call_abstract = call_node->abstract();
65 MS_EXCEPTION_IF_NULL(call_abstract);
66 auto tuple_abstract = call_abstract->cast<abstract::AbstractSequencePtr>();
67 MS_EXCEPTION_IF_NULL(tuple_abstract);
68 auto cur_abstract = tuple_abstract->elements().at(idx);
69 MS_EXCEPTION_IF_NULL(cur_abstract);
70 used_node->set_abstract(cur_abstract->Clone());
71 }
72 }
73
ModifyOutputAndCallerToMap(const CNodePtr & cnode,const FuncGraphPtr & fg,mindspore::HashMap<AnfNodePtr,std::set<AnfNodePtr>> * out_caller_map,bool is_add)74 void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
75 mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map, bool is_add) {
76 MS_EXCEPTION_IF_NULL(cnode);
77 MS_EXCEPTION_IF_NULL(out_caller_map);
78 auto inputs = cnode->inputs();
79 if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
80 FuncGraphPtr switch_subgraph = nullptr;
81 const auto &node = inputs.at(kSwitchBranchIndex);
82 MS_EXCEPTION_IF_NULL(node);
83 if (node->isa<CNode>()) {
84 auto partial_node = dyn_cast<CNode>(node);
85 const auto &partial_inputs = partial_node->inputs();
86 MS_EXCEPTION_IF_NULL(partial_inputs.at(0));
87 if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
88 MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
89 switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
90 } else if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartialInline)) {
91 switch_subgraph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(partial_node, kAttrKernelGraph);
92 } else {
93 MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
94 }
95 } else if (node->isa<ValueNode>()) {
96 switch_subgraph = GetValueNode<FuncGraphPtr>(node);
97 } else {
98 MS_LOG(EXCEPTION) << "Get unknown cnode: " << cnode->DebugString();
99 }
100 MS_EXCEPTION_IF_NULL(switch_subgraph);
101 if (is_add) {
102 (void)(*out_caller_map)[switch_subgraph->output()].insert(cnode);
103 UpdateCallerAbstract(cnode, fg, switch_subgraph);
104 } else {
105 (void)(*out_caller_map)[switch_subgraph->output()].erase(cnode);
106 }
107 } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
108 auto call_subgraph = GetValueNode<FuncGraphPtr>(inputs.at(kCallArgsIndex));
109 MS_EXCEPTION_IF_NULL(call_subgraph);
110 if (is_add) {
111 (void)(*out_caller_map)[call_subgraph->output()].insert(cnode);
112 UpdateCallerAbstract(cnode, fg, call_subgraph);
113 } else {
114 (void)(*out_caller_map)[call_subgraph->output()].erase(cnode);
115 }
116 }
117 }
118
UpdateSubGraphCaller(const AnfNodePtr & origin_output,const FuncGraphPtr & fg,mindspore::HashMap<AnfNodePtr,std::set<AnfNodePtr>> * out_caller_map,const mindspore::HashMap<AnfNodePtr,FuncGraphWeakPtr> & node_to_fg)119 void UpdateSubGraphCaller(const AnfNodePtr &origin_output, const FuncGraphPtr &fg,
120 mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map,
121 const mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> &node_to_fg) {
122 MS_EXCEPTION_IF_NULL(fg);
123 MS_EXCEPTION_IF_NULL(fg->output());
124 auto find_iter = (*out_caller_map).find(origin_output);
125 if (find_iter != (*out_caller_map).end()) {
126 auto call_node_list = find_iter->second;
127 (void)(*out_caller_map).erase(find_iter);
128 for (auto &call_node : call_node_list) {
129 auto fg_iter = node_to_fg.find(call_node);
130 if (fg_iter == node_to_fg.end()) {
131 MS_LOG(EXCEPTION) << "Node to Funcgraph find failed: " << call_node->fullname_with_scope();
132 }
133 auto call_node_fg = fg_iter->second.lock();
134 UpdateCallerAbstract(call_node, call_node_fg, fg);
135 }
136 (*out_caller_map)[fg->output()] = call_node_list;
137 }
138 }
139
SkipSameOp(const AnfNodePtr & old_node,const AnfNodePtr & new_node,mindspore::HashSet<AnfNodePtr> * seen_node)140 void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
141 MS_EXCEPTION_IF_NULL(seen_node);
142 MS_EXCEPTION_IF_NULL(old_node);
143 MS_EXCEPTION_IF_NULL(new_node);
144 if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
145 (common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
146 (void)seen_node->insert(new_node);
147 }
148 }
149
GetCNodeKey(const AnfNodePtr & node)150 std::string GetCNodeKey(const AnfNodePtr &node) {
151 auto primitive = GetCNodePrimitive(node);
152 if (primitive != nullptr) {
153 return primitive->name();
154 } else {
155 return "";
156 }
157 }
158
IsNeedUnfoldSubGraph(const FuncGraphPtr & func_graph)159 bool IsNeedUnfoldSubGraph(const FuncGraphPtr &func_graph) {
160 MS_EXCEPTION_IF_NULL(func_graph);
161 return !func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !func_graph->has_flag(kFlagJitCallGraph);
162 }
163
GenIndex(const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index)164 void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
165 MS_EXCEPTION_IF_NULL(func_graph);
166 MS_EXCEPTION_IF_NULL(func_graph_index);
167 if (func_graph_index->has_gen_index()) {
168 return;
169 }
170
171 func_graph_index->set_has_gen_index(true);
172 func_graph_index->node_to_fg_.clear();
173 func_graph_index->node_degree_.clear();
174 func_graph_index->name_to_cnode_.clear();
175 func_graph_index->subgraph_out_caller_map_.clear();
176
177 FuncGraphManagerPtr manager = func_graph->manager();
178 MS_EXCEPTION_IF_NULL(manager);
179 mindspore::HashSet<AnfNodePtr> seen_node;
180 std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->output(), func_graph}};
181
182 while (!todo.empty()) {
183 AnfNodePtr node = todo.front().first;
184 MS_EXCEPTION_IF_NULL(node);
185 auto fg = todo.front().second;
186 manager->AddFuncGraph(fg);
187 todo.pop_front();
188
189 func_graph_index->node_to_fg_[node] = fg;
190 auto degree_iter = func_graph_index->node_degree_.find(node);
191 if (degree_iter == func_graph_index->node_degree_.end()) {
192 func_graph_index->node_degree_[node] = 1;
193 } else {
194 degree_iter->second++;
195 }
196 if (node->isa<CNode>()) {
197 (void)func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
198 }
199
200 if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
201 continue;
202 }
203 (void)seen_node.insert(node);
204 TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
205
206 if (IsValueNode<FuncGraph>(node)) {
207 auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
208 MS_EXCEPTION_IF_NULL(const_func_graph);
209 if (IsNeedUnfoldSubGraph(const_func_graph)) {
210 (void)todo.emplace_back(const_func_graph->output(), const_func_graph);
211 }
212 } else if (node->isa<CNode>()) {
213 auto cnode = node->cast<CNodePtr>();
214 MS_EXCEPTION_IF_NULL(cnode);
215 ModifyOutputAndCallerToMap(cnode, fg, &func_graph_index->subgraph_out_caller_map_);
216 auto inputs = cnode->inputs();
217 (void)std::for_each(inputs.begin(), inputs.end(),
218 [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
219 }
220 }
221 }
222
ProcessFastPassNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index,const FuncGraphManagerPtr & manager)223 bool NodePass::ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
224 const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager) {
225 MS_EXCEPTION_IF_NULL(node);
226 MS_EXCEPTION_IF_NULL(func_graph);
227 MS_EXCEPTION_IF_NULL(func_graph_index);
228 MS_EXCEPTION_IF_NULL(manager);
229 auto iter = func_graph_index->node_to_fg_.find(node);
230 if (iter == func_graph_index->node_to_fg_.end()) {
231 MS_LOG(EXCEPTION) << "Node to Funcgraph map can't find node: " << node->fullname_with_scope();
232 }
233 auto fg = iter->second.lock();
234 TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
235 auto degree_iter = func_graph_index->node_degree_.find(node);
236 if (degree_iter == func_graph_index->node_degree_.end()) {
237 MS_LOG(EXCEPTION) << "Node degree map can't find node: " << node->fullname_with_scope();
238 }
239 auto degree = degree_iter->second;
240 if (degree == 0 && node != func_graph->output()) {
241 return false;
242 }
243 // we may update return value in some pass.
244 MS_EXCEPTION_IF_NULL(fg);
245 auto origin_output = fg->output();
246 MS_EXCEPTION_IF_NULL(origin_output);
247 auto origin_abstract = origin_output->abstract();
248 AnfNodePtr new_node = Run(fg, node);
249 bool change = (new_node != nullptr);
250 MS_EXCEPTION_IF_NULL(fg->output());
251 if (origin_abstract != fg->output()->abstract()) {
252 UpdateSubGraphCaller(origin_output, fg, &func_graph_index->subgraph_out_caller_map_, func_graph_index->node_to_fg_);
253 }
254 if (new_node != nullptr && new_node != node) {
255 (void)manager->Replace(node, new_node);
256 // if replaced node is end_goto, refresh relative params in kernel graph
257 auto kernel_graph = fg->cast<std::shared_ptr<session::KernelGraph>>();
258 if (kernel_graph != nullptr && node->isa<CNode>()) {
259 auto cnode = node->cast<CNodePtr>();
260 MS_EXCEPTION_IF_NULL(cnode);
261 auto end_label = kernel_graph->get_end_goto();
262 if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
263 kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
264 }
265 }
266 AfterProcess(node, new_node, fg, func_graph_index);
267 }
268 return change;
269 }
270
ProcessFastPass(const FuncGraphPtr & func_graph,const FuncGraphIndexPtr & func_graph_index)271 bool NodePass::ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
272 MS_EXCEPTION_IF_NULL(func_graph);
273 MS_EXCEPTION_IF_NULL(func_graph_index);
274 if (!func_graph_index->has_gen_index()) {
275 MS_LOG(INTERNAL_EXCEPTION) << "ProcessFastPass Error, func graph has not gen index, pass name: " << name();
276 }
277 auto src_pattern_root_name = GetPatternRootPrimitiveName();
278 FuncGraphManagerPtr manager = func_graph->manager();
279 MS_EXCEPTION_IF_NULL(manager);
280 bool changes = false;
281
282 std::vector<AnfNodePtr> cand_node;
283 if (!src_pattern_root_name.empty()) {
284 auto cnode_iter = func_graph_index->name_to_cnode_.find(src_pattern_root_name);
285 if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
286 return false;
287 }
288 (void)std::copy(cnode_iter->second.begin(), cnode_iter->second.end(), std::back_inserter(cand_node));
289 } else {
290 for (const auto &kv : func_graph_index->name_to_cnode_) {
291 (void)std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(cand_node));
292 }
293 }
294 for (const auto &node : cand_node) {
295 auto change = ProcessFastPassNode(node, func_graph, func_graph_index, manager);
296 changes = changes || change;
297 }
298 return changes;
299 }
300
ProcessPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)301 bool NodePass::ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
302 MS_EXCEPTION_IF_NULL(func_graph);
303 MS_EXCEPTION_IF_NULL(manager);
304 bool changes = false;
305
306 // maybe call subgraph many times
307 mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map = {};
308 mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg = {};
309 mindspore::HashSet<AnfNodePtr> seen_node;
310 std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->get_return(), func_graph}};
311 while (!todo.empty()) {
312 AnfNodePtr node = todo.front().first;
313 auto fg = todo.front().second;
314 MS_EXCEPTION_IF_NULL(node);
315 manager->AddFuncGraph(fg);
316 todo.pop_front();
317 node_to_fg[node] = fg;
318 if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
319 continue;
320 }
321 (void)seen_node.insert(node);
322 TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
323 // we may update return value in some pass.
324 MS_EXCEPTION_IF_NULL(fg);
325 auto origin_output = fg->output();
326 MS_EXCEPTION_IF_NULL(origin_output);
327 auto origin_abstract = origin_output->abstract();
328 AnfNodePtr new_node = Run(fg, node);
329 bool change = (new_node != nullptr);
330 if (origin_abstract != fg->output()->abstract()) {
331 UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map, node_to_fg);
332 }
333 if (new_node != nullptr && new_node != node) {
334 SkipSameOp(node, new_node, &seen_node);
335 (void)manager->Replace(node, new_node);
336 // if replaced node is end_goto, refresh relative params in kernel graph
337 auto kernel_graph = fg->cast<std::shared_ptr<session::KernelGraph>>();
338 if (kernel_graph != nullptr && node->isa<CNode>()) {
339 auto cnode = node->cast<CNodePtr>();
340 MS_EXCEPTION_IF_NULL(cnode);
341 auto end_label = kernel_graph->get_end_goto();
342 if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
343 kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
344 }
345 }
346 (void)seen_node.erase(node);
347 } else if (new_node == nullptr) {
348 new_node = node;
349 }
350 if (new_node && IsValueNode<FuncGraph>(new_node)) {
351 auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
352 MS_EXCEPTION_IF_NULL(const_func_graph);
353 if (IsNeedUnfoldSubGraph(const_func_graph)) {
354 (void)todo.emplace_back(const_func_graph->output(), const_func_graph);
355 }
356 } else if (new_node && new_node->isa<CNode>()) {
357 if (common::AnfAlgo::IsGraphKernel(new_node)) {
358 (void)todo.emplace_back(new_node, func_graph);
359 }
360 auto cnode = new_node->cast<CNodePtr>();
361 MS_EXCEPTION_IF_NULL(cnode);
362 ModifyOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map, is_add_);
363 auto inputs = cnode->inputs();
364 (void)std::for_each(inputs.begin(), inputs.end(),
365 [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
366 }
367 changes = changes || change;
368 }
369 return changes;
370 }
371
Run(const FuncGraphPtr & func_graph)372 bool NodePass::Run(const FuncGraphPtr &func_graph) {
373 MS_EXCEPTION_IF_NULL(func_graph);
374 FuncGraphManagerPtr manager = func_graph->manager();
375 MS_EXCEPTION_IF_NULL(manager);
376 manager->AddFuncGraph(func_graph);
377 if (!func_graph->has_user_data<FuncGraphPassIndex>()) {
378 func_graph->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
379 }
380 auto func_graph_index = func_graph->user_data<FuncGraphPassIndex>();
381 MS_EXCEPTION_IF_NULL(func_graph_index);
382
383 if (IsFastPass()) {
384 MS_LOG(INFO) << "Run fast pass: " << name();
385 GenIndex(func_graph, func_graph_index);
386 return ProcessFastPass(func_graph, func_graph_index);
387 }
388 if (func_graph_index->has_gen_index()) {
389 const auto &ret = MustExistPrimitiveName();
390 for (const auto &primtive_name : ret) {
391 const auto cnode_iter = func_graph_index->name_to_cnode_.find(primtive_name);
392 if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
393 MS_LOG(INFO) << "Prim " << primtive_name << " not exist in name to cnode";
394 return false;
395 }
396 }
397 if (!ret.empty()) {
398 MS_LOG(INFO) << "Skip pass fail, run pass: " << name();
399 }
400 }
401 func_graph_index->set_has_gen_index(false);
402
403 return ProcessPass(func_graph, manager);
404 }
405 } // namespace opt
406 } // namespace mindspore
407