1 /**
2 * Copyright 2020-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/nn_optimizer_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/optimizer/irpass.h"
29 #include "frontend/optimizer/optimizer_caller.h"
30 #include "frontend/optimizer/anf_visitor.h"
31 #include "ir/pattern_matcher.h"
32
33 namespace mindspore::opt::irpass {
34 namespace {
// Input layout of the CNodes handled in this file (input 0 is always the primitive):
// data = Load(input, attach)
// data = Depend(input, attach)
// monad = UpdateState(input, attach)
// Index of the primitive input of a CNode.
constexpr size_t kFirstInputIndex = 0;
// Index of the "input" operand of Load/Depend/UpdateState (the monad of an UpdateState).
constexpr size_t kInputIndex = 1;
// Index of the "attach" operand of Load/Depend/UpdateState.
constexpr size_t kAttachIndex = 2;
// Expected CNode sizes (primitive + operands) used by the size checks in this file.
constexpr size_t kMakeTupleSize = 3;    // MakeTuple with exactly two elements.
constexpr size_t kDependSize = 3;       // Depend(input, attach).
constexpr size_t kUpdateStateSize = 3;  // UpdateState(monad, attach).
constexpr size_t kAssignSize = 4;       // Assign(ref, value, monad).
// Operand indices of Assign(ref, value, monad).
constexpr size_t kAssignRefInputIndex = 1;
constexpr size_t kAssignMonadInputIndex = 3;
47
GetManager(const AnfNodePtr & node)48 FuncGraphManagerPtr GetManager(const AnfNodePtr &node) {
49 auto fg = node->func_graph();
50 if (fg == nullptr) {
51 return nullptr;
52 }
53 return fg->manager();
54 }
55
56 // Return true if the node(be_used_node) is only used by the given node.
OnlyUsedByOneNode(const AnfNodePtr & be_used_node,const CNodePtr & given_node)57 bool OnlyUsedByOneNode(const AnfNodePtr &be_used_node, const CNodePtr &given_node) {
58 auto mgr = GetManager(given_node);
59 if (mgr == nullptr) {
60 return false;
61 }
62 auto &node_users = mgr->node_users();
63 auto iter = node_users.find(be_used_node);
64 if (iter == node_users.end()) {
65 return false;
66 }
67 auto &partial_users = iter->second;
68 return (partial_users.size() == 1) && (partial_users.front().first == given_node);
69 }
70
71 // Return true if the node(be_used_node) is only used by the given two nodes(first_node and second_node).
OnlyUsedByTwoNode(const AnfNodePtr & be_used_node,const AnfNodePtr & first_node,const AnfNodePtr & second_node)72 bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_node, const AnfNodePtr &second_node) {
73 auto mgr = GetManager(be_used_node);
74 if (mgr == nullptr || first_node == second_node) {
75 return false;
76 }
77 auto &node_users = mgr->node_users();
78 auto iter = node_users.find(be_used_node);
79 if (iter == node_users.end()) {
80 return false;
81 }
82 constexpr size_t partial_users_cnt = 2;
83 auto &partial_users = iter->second;
84 if (partial_users.size() != partial_users_cnt) {
85 return false;
86 }
87 const auto &first_user = partial_users.front().first;
88 const auto &second_user = partial_users.back().first;
89 return (first_user == first_node && second_user == second_node) ||
90 (first_user == second_node && second_user == first_node);
91 }
92
93 // Determine whether there is a monad in the inputs of the node.
CheckHasMonadInput(const CNodePtr & cnode)94 bool CheckHasMonadInput(const CNodePtr &cnode) {
95 // If the last input is a monad, means the attach node has side-effect and
96 // we should keep UpdateState; otherwise, we will remove the UpdateState.
97 if (cnode->size() > 1 && HasAbstractMonad(cnode->inputs().back())) {
98 return true;
99 }
100
101 // Check the inputs of Call/Switch/SwitchLayer.
102 auto first_input_node = cnode->input(kFirstInputIndex);
103 if (IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
104 IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
105 for (auto &weak_input : first_input_node->cast<CNodePtr>()->weak_inputs()) {
106 auto input = weak_input.lock();
107 MS_EXCEPTION_IF_NULL(input);
108 if (HasAbstractMonad(input)) {
109 return true;
110 }
111 auto input_cnode = dyn_cast<CNode>(input);
112 if (input_cnode != nullptr && input_cnode->size() > 1 && HasAbstractMonad(input_cnode->inputs().back())) {
113 return true;
114 }
115 }
116 }
117 return false;
118 }
119
// Builds UpdateState(monad, attach), reusing the primitive and monad inputs of `update_state` and copying its
// abstract and scope. Returns nullptr when `update_state` does not belong to a graph.
AnfNodePtr NewUpdateStateWithAttach(const CNodePtr &update_state, const AnfNodePtr &attach) {
  const auto graph = update_state->func_graph();
  if (graph == nullptr) {
    return nullptr;
  }
  const auto &prim = update_state->input(kFirstInputIndex);
  const auto &monad = update_state->input(kInputIndex);
  auto replacement = graph->NewCNode({prim, monad, attach});
  replacement->set_abstract(update_state->abstract());
  replacement->set_scope(update_state->scope());
  return replacement;
}
131
// Tries to eliminate `update_state = UpdateState(u1, depend)` with `depend = Depend(x, u0)`.
// NOTE(review): this function only checks that the attach input is a CNode; presumably the caller has already
// matched it as a Depend — confirm at the call site.
// When u0 and u1 are monads with equal abstracts (and the guard against losing u1 below does not fire), every use
// of `depend` is replaced by `x` through the manager, and u0 is returned as the replacement for `update_state`.
// Otherwise nullptr is returned without modifying the graph.
// Raises MS_LOG(EXCEPTION) if the matched Depend does not have exactly two operands.
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state) {
  auto depend = update_state->input(kAttachIndex)->cast<CNodePtr>();
  constexpr auto recur_2 = 2;
  // If same Depend CNode is used by multiple UpdateState CNode, it may be replaced by previous elimination.
  if (depend == nullptr) {
    MS_LOG(DEBUG) << "UpdateState's input 2 Depend had been replaced: " << update_state->DebugString(recur_2);
    return nullptr;
  }
  // The last input of the Depend is its attach operand (u0); its size is only validated further below.
  auto input_monad = depend->inputs().back();
  if (!HasAbstractMonad(input_monad)) {
    // Skip if Depend attach input is not a monad.
    return nullptr;
  }
  auto update_monad = update_state->input(kInputIndex);
  if (update_monad->abstract() == nullptr || !HasAbstractMonad(update_monad)) {
    // Skip if UpdateState input is not a monad.
    MS_LOG(INFO) << "Not a monad input: " << update_state->DebugString();
    return nullptr;
  }
  // x1 = Depend(x, u0)        <-not match--   <--match--
  // u2 = UpdateState(u1, x1)                  <--match--
  // u3 = UpdateState(u2, x1)  <-not match--
  // u3 and x1 should not match otherwise u1 will be lost; u2 and x1 can match.
  if (IsPrimitiveCNode(update_monad, prim::kPrimUpdateState) &&
      update_monad->cast<CNodePtr>()->input(kAttachIndex) == depend) {
    MS_LOG(DEBUG) << "UpdateState should not be replaced. node: " << update_state->DebugString(recur_2);
    return nullptr;
  }
  // Check monad inputs: both must be the same kind of monad, compared by abstract equality.
  const auto &input_monad_abs = *(input_monad->abstract());
  const auto &update_monad_abs = *(update_monad->abstract());
  bool same_monad = (input_monad_abs == update_monad_abs);
  if (!same_monad) {
    // Skip if they are different monad (one is IO, another is U).
    return nullptr;
  }
  // Now we can eliminate the UpdateState and Depend nodes.
  auto mgr = GetManager(update_state);
  if (mgr == nullptr) {
    return nullptr;
  }
  // Replace Depend with its input.
  if (depend->size() != kDependSize) {
    MS_LOG(EXCEPTION) << "The Depend node has wrong inputs. " << depend->DebugString();
  }
  auto depend_input = depend->input(kInputIndex);
  (void)mgr->Replace(depend, depend_input);
  // Replace UpdateState node with the input monad of Depend.
  return input_monad;
}
182
ExistEnvironGet(const FuncGraphManagerPtr & manager)183 bool ExistEnvironGet(const FuncGraphManagerPtr &manager) {
184 const FuncGraphSet &fgs = manager->func_graphs();
185 for (auto &fg : fgs) {
186 auto &nodes = fg->value_nodes();
187 bool exist = std::any_of(nodes.begin(), nodes.end(),
188 [](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvironGet); });
189 if (exist) {
190 return true;
191 }
192 }
193 return false;
194 }
195
// Drops EnvironSet chains that are only reachable from an UpdateState's MakeTuple, when no managed graph
// contains an EnvironGet.
// Convert:
// cnode1 = EnvironSet(EnvCreate(), para1, attach1)
// cnode2 = EnvironSet(cnode1, para2, attach2)
// ...
// cnode_n = EnvironSet(cnode_n-1, para_n, attach_n)
// maketuple = maketuple(cnode_n, ...)
// updatestate = updatestate(umonad, maketuple)
// To:
// new_maketuple = maketuple(..., attach2, ..., attach_n)
// new_updatestate = updatestate(umonad, new_maketuple)
// The leading "..." of new_maketuple are the original MakeTuple inputs that are CNodes other than UpdateState
// (including EnvironSet nodes that have other users). Non-CNode inputs and UpdateState inputs are dropped.
// The attach (input 3) of an EnvironSet is kept only when its env input (input 1) is another EnvironSet used solely
// by it, so attach1 (whose env input is EnvCreate()) is not carried over.
// NOTE(review): confirm that dropping attach1 (and the attach of a lone EnvironSet) is intended.
// Returns the new UpdateState, or nullptr when the pattern does not apply. Existing nodes are not rewired here.
AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_state, const CNodePtr &make_tuple) {
  std::vector<AnfNodePtr> env_nodes;
  std::vector<AnfNodePtr> new_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  size_t input_size = make_tuple->size();
  bool has_environ_set = false;
  for (size_t i = 1; i < input_size; i++) {
    auto node = make_tuple->input(i);
    if (IsPrimitiveCNode(node, prim::kPrimEnvironSet) && OnlyUsedByOneNode(node, make_tuple)) {
      // Head of a chain that only this MakeTuple uses: candidate for elimination.
      (void)env_nodes.emplace_back(node);
      has_environ_set = true;
    } else if (node->isa<CNode>() && !IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
      (void)new_maketuple_inputs.emplace_back(node);
    }
  }
  if (!has_environ_set) {
    return nullptr;
  }
  // Check EnvironSet in MakeTuple
  auto mgr = GetManager(update_state);
  if (mgr == nullptr) {
    return nullptr;
  }
  // If exist EnvironGet, don't eliminate EnvironSet.
  if (ExistEnvironGet(mgr)) {
    return nullptr;
  }
  // EnvironSet(env, para, attach): env is input 1, attach is input 3.
  const size_t first_index = 1;
  const size_t attach_index = 3;
  // Attaches are inserted at this fixed position, so deeper EnvironSets end up earlier in the tuple.
  const size_t no_env_node_size = new_maketuple_inputs.size();
  while (!env_nodes.empty()) {
    auto env = env_nodes.back();
    env_nodes.pop_back();
    // Defensive: only EnvironSet CNodes are ever pushed onto env_nodes.
    if (!env->isa<CNode>()) {
      continue;
    }
    auto env_cnode = env->cast<CNodePtr>();
    auto env_input = env_cnode->input(first_index);
    auto attach = env_cnode->input(attach_index);
    if (IsPrimitiveCNode(env_input, prim::kPrimEnvironSet) && OnlyUsedByOneNode(env_input, env_cnode)) {
      (void)env_nodes.emplace_back(env_input);
      (void)new_maketuple_inputs.insert(new_maketuple_inputs.cbegin() + SizeToLong(no_env_node_size), attach);
    }
  }
  // Nothing but the MakeTuple primitive left: no replacement to build.
  if (new_maketuple_inputs.size() == 1) {
    return nullptr;
  }
  auto fg = update_state->func_graph();
  if (fg == nullptr) {
    return nullptr;
  }
  abstract::AbstractBasePtrList element_abstracts;
  (void)std::transform(new_maketuple_inputs.begin() + 1, new_maketuple_inputs.end(),
                       std::back_inserter(element_abstracts),
                       [](const AnfNodePtr &input) { return input->abstract(); });
  auto new_make_tuple = fg->NewCNode(new_maketuple_inputs);
  new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  auto new_update_state =
    fg->NewCNode({update_state->input(kFirstInputIndex), update_state->input(kInputIndex), new_make_tuple});
  new_update_state->set_abstract(update_state->abstract());
  new_update_state->set_scope(update_state->scope());
  return new_update_state;
}
268
EliminateUpdateStateMakeTupleWithUselessNode(const CNodePtr & update_state,const CNodePtr & make_tuple)269 AnfNodePtr EliminateUpdateStateMakeTupleWithUselessNode(const CNodePtr &update_state, const CNodePtr &make_tuple) {
270 if (make_tuple->size() != kMakeTupleSize) {
271 return nullptr;
272 }
273 AnfNodePtr attach_node = nullptr;
274 auto &first_input = make_tuple->input(kInputIndex);
275 auto &second_input = make_tuple->input(kAttachIndex);
276
277 // Eliminate useless make_tuple with 'DeadNode' or 'PolyNode'.
278 // UpdateState(u, MakeTuple(input, "DeadNode")) -> UpdateState(u, input)
279 if (IsDeadNode(second_input) || IsPolyNode(second_input)) {
280 return NewUpdateStateWithAttach(update_state, first_input);
281 }
282
283 // Eliminate useless make_tuple with useless Function.
284 // UpdateState(u, MakeTuple(Function, input) -> UpdateState(u, input)
285 // UpdateState(u, MakeTuple(input, Function) -> UpdateState(u, input)
286 if (IsValueNode<FuncGraph>(first_input) && OnlyUsedByOneNode(first_input, make_tuple)) {
287 return NewUpdateStateWithAttach(update_state, second_input);
288 }
289 if (IsValueNode<FuncGraph>(second_input) && OnlyUsedByOneNode(second_input, make_tuple)) {
290 return NewUpdateStateWithAttach(update_state, first_input);
291 }
292 return nullptr;
293 }
294
// Forward declarations: GetLoadsFromUpdateState, GetLoadsFollowLoad and GetLoadsFollowTuple call each other
// recursively while walking a chain of UpdateState/Load nodes.
void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
                        std::vector<CNodePtr> *loads);
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
                         std::vector<CNodePtr> *loads);
299
300 // Search consecutive load nodes from UpdateState node.
GetLoadsFromUpdateState(const CNodePtr & update_state,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)301 void GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
302 std::vector<CNodePtr> *loads) {
303 auto &attach = update_state->input(kAttachIndex);
304 if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
305 GetLoadsFollowLoad(update_state, attach->cast<CNodePtr>(), update_states, loads);
306 } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
307 GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
308 }
309 }
310
GetLoadsFollowLoad(const CNodePtr & update_state,const CNodePtr & load,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)311 void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
312 std::vector<CNodePtr> *loads) {
313 (void)update_states->emplace_back(update_state);
314 (void)loads->emplace_back(load);
315 auto &load_attach = load->input(kAttachIndex);
316 if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
317 GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads);
318 }
319 }
320
GetLoadsFollowTuple(const CNodePtr & update_state,const CNodePtr & make_tuple,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)321 void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
322 std::vector<CNodePtr> *loads) {
323 if (!OnlyUsedByOneNode(make_tuple, update_state)) {
324 // UpdateState should be the only user of make_tuple.
325 return;
326 }
327 auto &inputs = make_tuple->inputs();
328 const auto &monad = update_state->input(kInputIndex);
329 bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), [&monad](const AnfNodePtr &node) {
330 // Tuple element should be Load and use same monad that UpdateState used.
331 return (IsPrimitiveCNode(node, prim::kPrimLoad) && node->cast<CNodePtr>()->input(kAttachIndex) == monad);
332 });
333 if (!is_all_load) {
334 // Stop if not all tuple elements are load nodes and use same monad.
335 return;
336 }
337 // Add update_state and load nodes.
338 (void)update_states->emplace_back(update_state);
339 for (size_t i = 1; i < inputs.size(); ++i) {
340 auto &element = inputs.at(i);
341 (void)loads->emplace_back(element->cast<CNodePtr>());
342 }
343 // Follow prev update state if found.
344 auto prev_node = update_state->input(kInputIndex);
345 if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
346 GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads);
347 }
348 }
349
350 // Create a MakeTuple node before UpdateState for same nodes, if there are more than 1 node used.
MakeTupleForSameNodes(const FuncGraphPtr & fg,const CNodePtr & old_update_state,const AnfNodePtrList & make_tuple_inputs)351 AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_update_state,
352 const AnfNodePtrList &make_tuple_inputs) {
353 constexpr size_t kOneNodeInputSize = 2;
354 if (make_tuple_inputs.size() == kOneNodeInputSize) {
355 // We don't need make_tuple since there is only one load.
356 return make_tuple_inputs.at(1);
357 }
358 // Create MakeTuple cnode.
359 auto make_tuple = fg->NewCNode(make_tuple_inputs);
360 // Set abstract for the MakeTuple node.
361 abstract::AbstractBasePtrList element_abstracts;
362 std::transform(make_tuple_inputs.begin() + 1, make_tuple_inputs.end(), std::back_inserter(element_abstracts),
363 [](const AnfNodePtr &input) { return input->abstract(); });
364 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
365 make_tuple->set_scope(old_update_state->scope());
366 return make_tuple;
367 }
368
369 // Remove all nodes related to UpdateStates, if they're redundant.
EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> & update_states)370 void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
371 if (update_states.empty()) {
372 return;
373 }
374 auto mgr = GetManager(update_states[0]);
375 if (mgr == nullptr) {
376 return;
377 }
378
379 // 1. Remove the use of UpdateState nodes, except the last one.
380 for (auto i = update_states.size() - 1; i > 0; i--) {
381 auto &us = update_states[i];
382 (void)mgr->Replace(us, us->input(kInputIndex));
383 }
384
385 // 2. Remove the Depend users of last UpdateState node.
386 auto &node_users = mgr->node_users();
387 auto iter = node_users.find(update_states[0]);
388 if (iter == node_users.end()) {
389 return;
390 }
391 auto &us_users = iter->second;
392 if (us_users.size() < 2) {
393 return;
394 }
395 std::vector<AnfNodePtr> depend_nodes;
396 for (auto &user : us_users) {
397 if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) {
398 (void)depend_nodes.emplace_back(user.first);
399 }
400 }
401 if (depend_nodes.empty()) {
402 return;
403 }
404 ssize_t end = 0;
405 // If all users are Depend CNode.
406 if (depend_nodes.size() == us_users.size()) {
407 end = 1;
408 // Set abstract value for reserved Depend node.
409 auto &reserved_depend_node = depend_nodes[0];
410 auto &primary_node = reserved_depend_node->cast<CNodePtr>()->input(kInputIndex);
411 reserved_depend_node->set_abstract(primary_node->abstract());
412 }
413 for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
414 const auto &depend_node = depend_nodes[i];
415 const auto &depend_cnode = depend_node->cast<CNodePtr>();
416 (void)mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
417 }
418 }
419
420 // Eliminate UpdateStates for consecutive Loads.
421 // Convert:
422 // x1 = Load(x1, u)
423 // u1 = UpdateState(u, x1)
424 // x2 = Load(x2, u1)
425 // u2 = UpdateState(u1, x2)
426 // ...
427 // xN = Load(xN, u(N-1))
428 // uN = UpdateState(u(N-1), xN)
429 // To:
430 // x1 = Load(x1, u)
431 // x2 = Load(x2, u)
432 // ...
433 // xN = Load(xN, u)
434 // t = make_tuple(x1, x2, ... , xN)
435 // u1 = UpdateState(u, t)
EliminateUpdateStateForLoads(const CNodePtr & old_update_state,const std::vector<CNodePtr> & update_states,const std::vector<CNodePtr> & loads)436 AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
437 const std::vector<CNodePtr> &loads) {
438 auto fg = old_update_state->func_graph();
439 if (fg == nullptr) {
440 return nullptr;
441 }
442 auto mgr = fg->manager();
443 if (mgr == nullptr) {
444 return nullptr;
445 }
446 // Prepare tuple elements from Load nodes.
447 AnfNodePtrList make_tuple_inputs;
448 std::set<AnfNodePtr> loaded_para_set;
449 make_tuple_inputs.reserve(loads.size() + 1);
450 (void)make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
451 auto input_monad = loads.back()->input(kAttachIndex);
452 for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
453 auto &load = *iter;
454 auto result = loaded_para_set.emplace(load->input(kInputIndex));
455 const bool is_new_load = result.second;
456 if (is_new_load) {
457 // Put Load node as a tuple element, if the parameter is not loaded by other Load.
458 (void)make_tuple_inputs.emplace_back(load);
459 }
460 auto load_attach = load->input(kAttachIndex);
461 if (load_attach != input_monad) {
462 // Set all load use same input monad.
463 (void)mgr->Replace(load_attach, input_monad);
464 }
465 }
466
467 EliminateUselessNodesForUpdateStates(update_states);
468
469 if (make_tuple_inputs.size() == 1) {
470 // This should not happen.
471 MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
472 return nullptr;
473 }
474 // Create the new UpdateState node with a MakeTuple, replace the old UpdateStateNode.
475 auto attach = MakeTupleForSameNodes(fg, old_update_state, make_tuple_inputs);
476 auto update_state = NewValueNode(prim::kPrimUpdateState);
477 auto new_update_state = fg->NewCNode({update_state, input_monad, attach});
478 new_update_state->set_abstract(old_update_state->abstract());
479 new_update_state->set_scope(old_update_state->scope());
480 return new_update_state;
481 }
482
483 // Eliminate UpdateStates between Assign nodes.
484 // Covert:
485 // a1 = Assign(para1, value1, u1)
486 // u2 = UpdateState(u1, a1)
487 // a2 = Assign(para2, value2, u2) # para1 != para2, para1 != value2, para2 != value1
488 // u3 = UpdateState(u2, a2)
489 // To:
490 // a1 = Assign(para1, value1, u1)
491 // a2 = Assign(para2, value2, u1)
492 // t = MakeTuple(a1, a2)
493 // u3 = UpdateState(u1, t)
EliminateUpdateStateBetweenAssigns(const CNodePtr & update_state,const AnfNodePtr & assign)494 AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, const AnfNodePtr &assign) {
495 auto a2_cnode = assign->cast<CNodePtr>();
496 auto u2 = a2_cnode->input(kAssignMonadInputIndex);
497 auto a1 = u2->cast<CNodePtr>()->input(kAttachIndex);
498 if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
499 auto a1_cnode = a1->cast<CNodePtr>();
500 if (a1_cnode->size() != kAssignSize) {
501 return nullptr;
502 }
503 auto para1 = a1_cnode->input(kInputIndex);
504 auto value1 = a1_cnode->input(kAttachIndex);
505 auto para2 = a2_cnode->input(kInputIndex);
506 auto value2 = a2_cnode->input(kAttachIndex);
507 auto u1 = a1_cnode->input(kAssignMonadInputIndex);
508 if (para1 != para2 && para1 != value2 && para2 != value1) {
509 auto fg = update_state->func_graph();
510 MS_EXCEPTION_IF_NULL(fg);
511 auto mgr = fg->manager();
512 MS_EXCEPTION_IF_NULL(mgr);
513 (void)mgr->Replace(u2, u1);
514
515 AnfNodePtrList make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1, assign};
516 auto make_tuple = MakeTupleForSameNodes(fg, update_state, make_tuple_inputs);
517 auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, make_tuple});
518 new_update_state->set_abstract(update_state->abstract());
519 new_update_state->set_scope(update_state->scope());
520 return new_update_state;
521 }
522 }
523 return nullptr;
524 }
525
526 // Eliminate Load before Assign nodes.
527 // Covert:
528 // load = Load(parameter)
529 // a = Assign(load, value, u)
530 // To:
531 // a = Assign(parameter, value, u)
EliminateLoadBeforeAssigns(const FuncGraphManagerPtr & manager,const CNodePtr & update_state)532 bool EliminateLoadBeforeAssigns(const FuncGraphManagerPtr &manager, const CNodePtr &update_state) {
533 auto &attach = update_state->input(kAttachIndex);
534 // UpdateState(u, Assign(para, value, u))
535 if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
536 auto assign = attach->cast<CNodePtr>();
537 if (assign->size() != kAssignSize) {
538 return false;
539 }
540 // If assign's first input is load, eliminate load.
541 auto &ref_node = assign->input(kAssignRefInputIndex);
542 if (IsPrimitiveCNode(ref_node, prim::kPrimLoad)) {
543 auto load = ref_node->cast<CNodePtr>();
544 auto ¶meter = load->input(kInputIndex);
545 // If Load used by other nodes, keep load node.
546 auto assign_cnode = assign->cast<CNodePtr>();
547 if (OnlyUsedByOneNode(ref_node, assign_cnode)) {
548 (void)manager->Replace(ref_node, parameter);
549 } else {
550 manager->SetEdge(assign, kInputIndex, parameter);
551 }
552 return true;
553 }
554 }
555 return false;
556 }
557
// Eliminate UpdateStates between MakeTuple and Assign.
// Convert:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u2) # u2 == u1
// t = MakeTuple(a1, a2)
// u3 = UpdateState(u1, t)
// a3 = Assign(para3, value3, u3) # para3 != para1, para3 != para2, value3 != para1, value3 != para2
// # value1 != para3, value2 != para3
// u4 = UpdateState(u3, a3)
// To:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u1)
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
//
// `assign` is a3 and `update_state` is u4. u3 is dereferenced as a CNode without a check; the caller in this file
// (EliminateUpdateStateForAssign) checks that it is an UpdateState used only by a3 and u4.
// t must be a two-element MakeTuple of 3-operand Assigns, and u3 must be its only user.
// On a match, every use of u3 is replaced by u1 and the old MakeTuple is replaced by MakeTuple(a1, a2, a3).
// Then UpdateState(u1, MakeTuple(a1, a2, a3)) is returned. Returns nullptr otherwise.
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &assign) {
  auto a3_cnode = assign->cast<CNodePtr>();
  auto u3 = a3_cnode->input(kAssignMonadInputIndex);
  auto u3_cnode = u3->cast<CNodePtr>();
  auto make_tuple = u3_cnode->input(kAttachIndex);
  if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple) && OnlyUsedByOneNode(make_tuple, u3_cnode)) {
    auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
    if (make_tuple_cnode->size() != kMakeTupleSize) {
      return nullptr;
    }
    auto a1 = make_tuple_cnode->input(kInputIndex);
    auto a2 = make_tuple_cnode->input(kAttachIndex);
    if (IsPrimitiveCNode(a1, prim::kPrimAssign) && IsPrimitiveCNode(a2, prim::kPrimAssign)) {
      auto a1_cnode = a1->cast<CNodePtr>();
      auto a2_cnode = a2->cast<CNodePtr>();
      if (a1_cnode->size() != kAssignSize || a2_cnode->size() != kAssignSize) {
        return nullptr;
      }
      auto para1 = a1_cnode->input(kInputIndex);
      auto value1 = a1_cnode->input(kAttachIndex);
      auto u1 = a1_cnode->input(kAssignMonadInputIndex);
      auto para2 = a2_cnode->input(kInputIndex);
      auto value2 = a2_cnode->input(kAttachIndex);
      auto u2 = a2_cnode->input(kAssignMonadInputIndex);
      auto para3 = a3_cnode->input(kInputIndex);
      auto value3 = a3_cnode->input(kAttachIndex);
      // a1 and a2 must share the same monad. a3 must not write or read the targets of a1/a2, and
      // a1/a2 must not read a3's target.
      bool replace_judge = (u1 == u2) && (para1 != para3) && (para1 != value3) && (para2 != para3) &&
                           (para2 != value3) && (value1 != para3) && (value2 != para3);
      if (replace_judge) {
        auto fg = update_state->func_graph();
        MS_EXCEPTION_IF_NULL(fg);
        auto mgr = fg->manager();
        MS_EXCEPTION_IF_NULL(mgr);
        // Rewire every user of u3 (a3 included) onto u1.
        (void)mgr->Replace(u3, u1);

        AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), make_tuple_cnode->input(kInputIndex),
                                             make_tuple_cnode->input(kAttachIndex), assign};
        auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
        // Redirect any remaining users of the old two-element MakeTuple to the new one.
        (void)mgr->Replace(make_tuple, new_make_tuple);
        auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
        new_update_state->set_abstract(update_state->abstract());
        new_update_state->set_scope(update_state->scope());
        return new_update_state;
      }
    }
  }
  return nullptr;
}
621
// Eliminate UpdateStates between Assign and MakeTuple.
// Convert:
// a1 = Assign(para1, value1, u1)
// u2 = UpdateState(u1_1, a1) # u1_1 == u1
// a2 = Assign(para2, value2, u2)
// a3 = Assign(para3, value3, u3) # u2 == u3
// t = MakeTuple(a2, a3)
// u4 = UpdateState(u3, t) # para1 != para2, para1 != para3, para1 != value2, para1 != value3,
// # value1 != para2, value1 != para3
// To:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u1)
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
//
// `make_tuple` is t and `update_state` is u4. t must have exactly two elements and be used only by u4.
// u2 must be an UpdateState used only by a2 and a3.
// NOTE(review): since u2 (== u3) may have no user besides a2 and a3, u4's own monad input cannot be u3 when this
// matches (unlike the diagram). The returned UpdateState uses u1 regardless of u4's monad input — confirm intended.
// NOTE(review): `make_tuple` is passed to mgr->Replace below, which rewrites u4's attach slot. Callers should pass an
// owning AnfNodePtr rather than a reference into update_state's inputs.
// On a match, every use of u2 is replaced by u1 and the old MakeTuple is replaced by MakeTuple(a1, a2, a3).
// Then UpdateState(u1, MakeTuple(a1, a2, a3)) is returned. Returns nullptr otherwise.
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
  auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
  if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) {
    return nullptr;
  }
  auto a2 = make_tuple_cnode->input(kInputIndex);
  auto a3 = make_tuple_cnode->input(kAttachIndex);
  if (IsPrimitiveCNode(a2, prim::kPrimAssign) && IsPrimitiveCNode(a3, prim::kPrimAssign)) {
    auto a2_cnode = a2->cast<CNodePtr>();
    auto a3_cnode = a3->cast<CNodePtr>();
    if (a2_cnode->size() != kAssignSize || a3_cnode->size() != kAssignSize) {
      return nullptr;
    }
    auto para2 = a2_cnode->input(kInputIndex);
    auto value2 = a2_cnode->input(kAttachIndex);
    auto u2 = a2_cnode->input(kAssignMonadInputIndex);
    // u2 must be an UpdateState whose only users are a2 and a3.
    if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState) || !OnlyUsedByTwoNode(u2, a2, a3)) {
      return nullptr;
    }
    auto para3 = a3_cnode->input(kInputIndex);
    auto value3 = a3_cnode->input(kAttachIndex);
    auto u3 = a3_cnode->input(kAssignMonadInputIndex);
    if (u2 == u3) {
      auto u2_cnode = u2->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(u2_cnode);
      auto u1 = u2_cnode->input(kInputIndex);
      auto a1 = u2_cnode->input(kAttachIndex);
      if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
        auto a1_cnode = a1->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(a1_cnode);
        if (a1_cnode->size() != kAssignSize) {
          return nullptr;
        }
        auto para1 = a1_cnode->input(kInputIndex);
        auto value1 = a1_cnode->input(kAttachIndex);
        auto u1_1 = a1_cnode->input(kAssignMonadInputIndex);
        // a1 must consume the same monad as u2. a1 must not write or read the targets of a2/a3, and
        // a2/a3 must not read a1's target.
        bool replace_judge = (u1 == u1_1) && (para1 != para2) && (para1 != para3) && (para1 != value2) &&
                             (para1 != value3) && (para2 != value1) && (para3 != value1);
        if (replace_judge) {
          auto fg = update_state->func_graph();
          MS_EXCEPTION_IF_NULL(fg);
          auto mgr = fg->manager();
          MS_EXCEPTION_IF_NULL(mgr);
          // Rewire a2 and a3 from u2 onto u1.
          (void)mgr->Replace(u2, u1);
          AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1,
                                               make_tuple_cnode->input(kInputIndex),
                                               make_tuple_cnode->input(kAttachIndex)};
          auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
          (void)mgr->Replace(make_tuple, new_make_tuple);
          auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
          new_update_state->set_abstract(update_state->abstract());
          new_update_state->set_scope(update_state->scope());
          return new_update_state;
        }
      }
    }
  }
  return nullptr;
}
696
EliminateUpdateStateForAssign(const CNodePtr & update_state)697 AnfNodePtr EliminateUpdateStateForAssign(const CNodePtr &update_state) {
698 // UpdateState(u, MakeTuple)
699 auto &attach = update_state->input(kAttachIndex);
700 if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
701 return EliminateUpdateStateBetweenMakeTupleAssign(update_state, attach);
702 }
703 // UpdateState(u, Assign(para, value, u))
704 if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
705 auto assign = attach->cast<CNodePtr>();
706 if (assign->size() != kAssignSize) {
707 return nullptr;
708 }
709 auto u = assign->input(kAssignMonadInputIndex);
710 // u is UpdateState, u only be used by assign and update_state.
711 if (IsPrimitiveCNode(u, prim::kPrimUpdateState) && OnlyUsedByTwoNode(u, assign, update_state)) {
712 auto u_attach = u->cast<CNodePtr>()->input(kAttachIndex);
713 if (IsPrimitiveCNode(u_attach, prim::kPrimAssign)) {
714 return EliminateUpdateStateBetweenAssigns(update_state, assign);
715 }
716 if (IsPrimitiveCNode(u_attach, prim::kPrimMakeTuple)) {
717 return EliminateUpdateStateBetweenAssignMakeTuple(update_state, assign);
718 }
719 }
720 }
721 return nullptr;
722 }
723 } // namespace
724
725 // Eliminate useless node that only used by associated update_state.
operator ()(const OptimizerPtr &,const AnfNodePtr & node)726 AnfNodePtr UpdatestateUselessNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
727 auto update_state_node = dyn_cast<CNode>(node);
728 if (update_state_node == nullptr || update_state_node->size() != kUpdateStateSize) {
729 return nullptr;
730 }
731
732 // If update_state is the only user of partial/load, replace it with the input monad.
733 // UpdateState(u, Partial) -> u
734 // UpdateState(u, Load) -> u
735 // UpdateState(u, FuncGraph) -> u
736 auto &attach = update_state_node->input(kAttachIndex);
737 if (IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimLoad) ||
738 IsValueNode<FuncGraph>(attach)) {
739 // Replace UpdateState with the input monad.
740 if (OnlyUsedByOneNode(attach, update_state_node)) {
741 return update_state_node->input(kInputIndex);
742 }
743 return nullptr;
744 }
745
746 // Handling the case where the second input of update_state is make_tuple which contains DeadNode or useless function.
747 // UpdateState(u, MakeTuple(input, "Dead Node")) -> UpdateState(u, input)
748 // UpdateState(u, MakeTuple(Function, input) -> UpdateState(u, input)
749 // UpdateState(u, MakeTuple(input, Function) -> UpdateState(u, input)
750 if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
751 auto new_node = EliminateUpdateStateMakeTupleWithUselessNode(update_state_node, attach->cast<CNodePtr>());
752 if (new_node != nullptr) {
753 return new_node;
754 }
755 return EliminateUpdateStateMakeTupleWithUselessEnv(update_state_node, attach->cast<CNodePtr>());
756 }
757 return nullptr;
758 }
759
760 // Eliminate UpdateState that attaches a pure (no-side-effect) node.
761 // Convert:
762 // x = pure_node(args) # no side effect
763 // u1 = update_state(u, x)
764 // user(u1)
765 // To:
766 // x = pure_node(args)
767 // user(u)
operator ()(const OptimizerPtr &,const AnfNodePtr & node)768 AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
769 auto update_state_node = dyn_cast<CNode>(node);
770 if (update_state_node == nullptr || update_state_node->size() != kUpdateStateSize) {
771 return nullptr;
772 }
773
774 auto &attach = update_state_node->input(kAttachIndex);
775 // update_state(u, param) or update_state(u, value_node) is redundant.
776 auto cnode = dyn_cast<CNode>(attach);
777 if (cnode == nullptr) {
778 return update_state_node->input(kInputIndex);
779 }
780 const auto &first_input = cnode->input(0);
781 bool is_special_ops = cnode->IsApply(prim::kPrimTupleGetItem) || cnode->IsApply(prim::kPrimDepend) ||
782 cnode->IsApply(prim::kPrimPartial) || cnode->IsApply(prim::kPrimMakeTuple) ||
783 cnode->IsApply(prim::kPrimCall) || IsValueNode<FuncGraph>(first_input) ||
784 IsPrimitiveCNode(first_input, prim::kPrimJ) || IsPrimitiveCNode(first_input, prim::kPrimVmap) ||
785 IsPrimitiveCNode(first_input, prim::kPrimTaylor) ||
786 IsPrimitiveCNode(first_input, prim::kPrimShard);
787 if (is_special_ops) {
788 return nullptr;
789 }
790 if (CheckHasMonadInput(cnode)) {
791 return nullptr;
792 }
793 return update_state_node->input(kInputIndex);
794 }
795
796 // Eliminate redundant UpdateState/Depend pair nodes caused by inline.
797 // Convert:
798 // x1 = Depend(x, u0)
799 // u1 = UpdateState(u', x1)
800 // out = x_user(x1)
801 // u2 = u_user(u1)
802 // To:
803 // out = x_user(x)
804 // u2 = u_user(u0)
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)805 bool UpdatestateDependEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
806 // Filter nodes that do not match UpdateState(u, Depend).
807 auto filter = [](const AnfNodePtr &node) {
808 if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
809 auto update_state = node->cast<CNodePtr>();
810 if (update_state->size() != kUpdateStateSize) {
811 return true;
812 }
813 auto &attach = update_state->input(kAttachIndex);
814 if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
815 return false;
816 }
817 }
818 return true;
819 };
820
821 bool change = false;
822 auto manager = optimizer->manager();
823 MS_EXCEPTION_IF_NULL(manager);
824 auto &all_nodes = manager->all_nodes();
825 auto todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
826 for (auto &node : todo) {
827 if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
828 continue;
829 }
830 auto new_node = EliminateUpdateStateWithDepend(node->cast<CNodePtr>());
831 if (new_node != nullptr) {
832 (void)manager->Replace(node, new_node);
833 change = true;
834 }
835 }
836 return change;
837 }
838
839 // Eliminate UpdateStates for consecutive Assign.
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)840 bool UpdatestateAssignEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
841 // Filter nodes that do not match UpdateState(u, Assign) or UpdateState(u, MakeTuple).
842 auto filter = [](const AnfNodePtr &node) {
843 if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
844 auto update_state = node->cast<CNodePtr>();
845 if (update_state->size() != kUpdateStateSize) {
846 return true;
847 }
848 auto &attach = update_state->input(kAttachIndex);
849 if (IsPrimitiveCNode(attach, prim::kPrimAssign) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
850 return false;
851 }
852 }
853 return true;
854 };
855
856 bool change = false;
857 auto manager = optimizer->manager();
858 MS_EXCEPTION_IF_NULL(manager);
859 auto &all_nodes = manager->all_nodes();
860 std::vector<AnfNodePtr> todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
861 for (auto &node : todo) {
862 if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
863 continue;
864 }
865 auto new_node = EliminateUpdateStateForAssign(node->cast<CNodePtr>());
866 if (new_node != nullptr) {
867 (void)manager->Replace(node, new_node);
868 change = true;
869 }
870 bool load_eliminate = EliminateLoadBeforeAssigns(manager, node->cast<CNodePtr>());
871 change = change || load_eliminate;
872 }
873 return change;
874 }
875
876 // Eliminate UpdateStates for consecutive Loads.
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)877 bool UpdatestateLoadsEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
878 // Filter nodes that do not match UpdateState(u, Load) or UpdateState(u, MakeTuple).
879 auto filter = [](const AnfNodePtr &node) {
880 if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
881 auto update_state = node->cast<CNodePtr>();
882 if (update_state->size() != kUpdateStateSize) {
883 return true;
884 }
885 auto &attach = update_state->input(kAttachIndex);
886 if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
887 return false;
888 }
889 }
890 return true;
891 };
892
893 bool change = false;
894 auto manager = optimizer->manager();
895 MS_EXCEPTION_IF_NULL(manager);
896 auto &all_nodes = manager->all_nodes();
897 std::vector<AnfNodePtr> todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
898 for (auto &node : todo) {
899 if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
900 continue;
901 }
902 std::vector<CNodePtr> update_states;
903 std::vector<CNodePtr> loads;
904 auto update_state_node = node->cast<CNodePtr>();
905 GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
906 if (update_states.size() > 1 && loads.size() > 1) {
907 auto new_node = EliminateUpdateStateForLoads(update_state_node, update_states, loads);
908 if (new_node != nullptr) {
909 (void)manager->Replace(node, new_node);
910 change = true;
911 }
912 }
913 }
914 return change;
915 }
916
917 // Eliminate Monad parameter for switch call.
918 // Convert:
919 // x = Load(x, u)
920 // u = UpdateState(u, x)
921 // ...
922 // g1 = Partial(...)
923 // g2 = Partial(...)
924 // s = switch(cond, g1, g2)
925 // res = s(u)
926 // To:
927 // x = Load(x, u)
928 // u = UpdateState(u, x)
929 // ...
930 // g1 = Partial(..., u)
931 // g2 = Partial(..., u)
932 // s = switch(cond, g1, g2)
933 // res = s()
operator ()(const OptimizerPtr &,const AnfNodePtr & node)934 AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
935 const CNodePtr &switch_call = dyn_cast<CNode>(node);
936 if (switch_call == nullptr) {
937 return nullptr;
938 }
939 auto fg = switch_call->func_graph();
940 if (fg == nullptr) {
941 return nullptr;
942 }
943 auto mgr = fg->manager();
944 if (mgr == nullptr) {
945 return nullptr;
946 }
947 const size_t switch_call_input_size = 2;
948 if (switch_call->size() < switch_call_input_size) {
949 return nullptr;
950 }
951 constexpr size_t primary_index = 0;
952 auto switch_node = switch_call->input(primary_index);
953 if (!IsPrimitiveCNode(switch_node, prim::kPrimSwitch)) {
954 return nullptr;
955 }
956 MS_LOG(DEBUG) << "Found switch call with monad parameter, " << switch_call->DebugString();
957 auto switch_cnode = dyn_cast<CNode>(switch_node);
958 if (switch_cnode == nullptr) {
959 MS_LOG(EXCEPTION) << "switch node cast to CNode failed, " << switch_node->DebugString();
960 }
961 constexpr size_t condition_index = 1;
962 constexpr size_t first_fg_index = 2;
963 constexpr size_t second_fg_index = 3;
964 auto fg1_node = switch_cnode->input(first_fg_index);
965 auto fg2_node = switch_cnode->input(second_fg_index);
966 auto build_partial = [&fg, &switch_call](const AnfNodePtr &node) {
967 CNodePtr new_node;
968 if (IsPrimitiveCNode(node, prim::kPrimPartial)) { // Node is already Partial CNode.
969 new_node = fg->NewCNode(node->cast<CNodePtr>()->inputs());
970 } else { // Node is FuncGraph ValueNode.
971 new_node = fg->NewCNode({NewValueNode(prim::kPrimPartial), node});
972 }
973 constexpr size_t args_start_index = 1;
974 for (size_t i = args_start_index; i < switch_call->size(); i++) {
975 new_node->add_input(switch_call->input(i));
976 }
977 // partial's abstract is same with first input.
978 new_node->set_abstract(new_node->input(1)->abstract());
979 return new_node;
980 };
981 fg1_node = build_partial(fg1_node);
982 fg2_node = build_partial(fg2_node);
983 auto cond = switch_cnode->input(condition_index);
984 auto new_switch_cnode = fg->NewCNode({NewValueNode(prim::kPrimSwitch), cond, fg1_node, fg2_node});
985 auto new_switch_call = fg->NewCNode({new_switch_cnode});
986 new_switch_cnode->set_abstract(switch_node->abstract());
987 new_switch_call->set_abstract(switch_call->abstract());
988 return new_switch_call;
989 }
990 } // namespace mindspore::opt::irpass
991