• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23 
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/nn_optimizer_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/optimizer/irpass.h"
29 #include "frontend/optimizer/optimizer_caller.h"
30 #include "frontend/optimizer/anf_visitor.h"
31 #include "ir/pattern_matcher.h"
32 
33 namespace mindspore::opt::irpass {
34 namespace {
// Input slot layouts of the CNodes rewritten in this file (slot 0 always holds the primitive):
//   data  = Load(input, attach)
//   data  = Depend(input, attach)
//   monad = UpdateState(input, attach)
//   res   = Assign(ref, value, monad)
// Generic input slots.
constexpr size_t kFirstInputIndex = 0;
constexpr size_t kInputIndex = kFirstInputIndex + 1;
constexpr size_t kAttachIndex = kInputIndex + 1;
// Expected CNode sizes, primitive slot included.
constexpr size_t kMakeTupleSize = 3;  // MakeTuple with exactly two elements.
constexpr size_t kDependSize = 3;
constexpr size_t kUpdateStateSize = 3;
constexpr size_t kAssignSize = 4;
// Assign-specific input slots.
constexpr size_t kAssignRefInputIndex = 1;
constexpr size_t kAssignMonadInputIndex = 3;
47 
GetManager(const AnfNodePtr & node)48 FuncGraphManagerPtr GetManager(const AnfNodePtr &node) {
49   auto fg = node->func_graph();
50   if (fg == nullptr) {
51     return nullptr;
52   }
53   return fg->manager();
54 }
55 
56 // Return true if the node(be_used_node) is only used by the given node.
OnlyUsedByOneNode(const AnfNodePtr & be_used_node,const CNodePtr & given_node)57 bool OnlyUsedByOneNode(const AnfNodePtr &be_used_node, const CNodePtr &given_node) {
58   auto mgr = GetManager(given_node);
59   if (mgr == nullptr) {
60     return false;
61   }
62   auto &node_users = mgr->node_users();
63   auto iter = node_users.find(be_used_node);
64   if (iter == node_users.end()) {
65     return false;
66   }
67   auto &partial_users = iter->second;
68   return (partial_users.size() == 1) && (partial_users.front().first == given_node);
69 }
70 
71 // Return true if the node(be_used_node) is only used by the given two nodes(first_node and second_node).
OnlyUsedByTwoNode(const AnfNodePtr & be_used_node,const AnfNodePtr & first_node,const AnfNodePtr & second_node)72 bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_node, const AnfNodePtr &second_node) {
73   auto mgr = GetManager(be_used_node);
74   if (mgr == nullptr || first_node == second_node) {
75     return false;
76   }
77   auto &node_users = mgr->node_users();
78   auto iter = node_users.find(be_used_node);
79   if (iter == node_users.end()) {
80     return false;
81   }
82   constexpr size_t partial_users_cnt = 2;
83   auto &partial_users = iter->second;
84   if (partial_users.size() != partial_users_cnt) {
85     return false;
86   }
87   const auto &first_user = partial_users.front().first;
88   const auto &second_user = partial_users.back().first;
89   return (first_user == first_node && second_user == second_node) ||
90          (first_user == second_node && second_user == first_node);
91 }
92 
93 // Determine whether there is a monad in the inputs of the node.
CheckHasMonadInput(const CNodePtr & cnode)94 bool CheckHasMonadInput(const CNodePtr &cnode) {
95   // If the last input is a monad, means the attach node has side-effect and
96   // we should keep UpdateState; otherwise, we will remove the UpdateState.
97   if (cnode->size() > 1 && HasAbstractMonad(cnode->inputs().back())) {
98     return true;
99   }
100 
101   // Check the inputs of Call/Switch/SwitchLayer.
102   auto first_input_node = cnode->input(kFirstInputIndex);
103   if (IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
104       IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
105     for (auto &weak_input : first_input_node->cast<CNodePtr>()->weak_inputs()) {
106       auto input = weak_input.lock();
107       MS_EXCEPTION_IF_NULL(input);
108       if (HasAbstractMonad(input)) {
109         return true;
110       }
111       auto input_cnode = dyn_cast<CNode>(input);
112       if (input_cnode != nullptr && input_cnode->size() > 1 && HasAbstractMonad(input_cnode->inputs().back())) {
113         return true;
114       }
115     }
116   }
117   return false;
118 }
119 
// Build a new UpdateState that reuses the primitive and input monad of `update_state` but takes
// `attach` as its attach input. Abstract and scope are copied from `update_state`.
// Returns nullptr if `update_state` has no owning func graph.
AnfNodePtr NewUpdateStateWithAttach(const CNodePtr &update_state, const AnfNodePtr &attach) {
  const auto fg = update_state->func_graph();
  if (fg == nullptr) {
    return nullptr;
  }
  AnfNodePtrList new_inputs{update_state->input(kFirstInputIndex), update_state->input(kInputIndex), attach};
  auto replacement = fg->NewCNode(new_inputs);
  replacement->set_abstract(update_state->abstract());
  replacement->set_scope(update_state->scope());
  return replacement;
}
131 
// Eliminate an UpdateState whose attach input is a Depend on a monad:
//   x1 = Depend(x, u0)
//   u2 = UpdateState(u1, x1)
// On success, the Depend (x1) is replaced by its first input (x) through the manager, and the
// Depend's last input (u0) is returned as the replacement for the UpdateState. The returned node
// does not reference the UpdateState's own input monad (u1).
// The rewrite applies only when:
//   - the attach input is still a CNode (it may already have been replaced),
//   - the Depend's last input and the UpdateState's input both have monad abstracts,
//   - u1 is not an UpdateState whose attach is this same Depend (see the diagram below),
//   - the two monad abstracts compare equal (e.g. not one IO and one U),
//   - a manager is reachable from `update_state`.
// Returns nullptr otherwise.
// Raises (MS_LOG(EXCEPTION)) if the Depend does not have exactly kDependSize inputs once all
// checks above have passed.
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state) {
  auto depend = update_state->input(kAttachIndex)->cast<CNodePtr>();
  constexpr auto recur_2 = 2;
  // If same Depend CNode is used by multiple UpdateState CNode, it may be replaced by previous elimination.
  if (depend == nullptr) {
    MS_LOG(DEBUG) << "UpdateState's input 2 Depend had been replaced: " << update_state->DebugString(recur_2);
    return nullptr;
  }
  // The Depend's last input is its attach (u0 in the diagram above).
  auto input_monad = depend->inputs().back();
  if (!HasAbstractMonad(input_monad)) {
    // Skip if Depend attach input is not a monad.
    return nullptr;
  }
  auto update_monad = update_state->input(kInputIndex);
  if (update_monad->abstract() == nullptr || !HasAbstractMonad(update_monad)) {
    // Skip if UpdateState input is not a monad.
    MS_LOG(INFO) << "Not a monad input: " << update_state->DebugString();
    return nullptr;
  }
  // x1 = Depend(x, u0)        <-not match--   <--match--
  // u2 = UpdateState(u1, x1)                  <--match--
  // u3 = UpdateState(u2, x1)  <-not match--
  // u3 and x1 should not match otherwise u1 will be lost; u2 and x1 can match.
  if (IsPrimitiveCNode(update_monad, prim::kPrimUpdateState) &&
      update_monad->cast<CNodePtr>()->input(kAttachIndex) == depend) {
    MS_LOG(DEBUG) << "UpdateState should not be replaced. node: " << update_state->DebugString(recur_2);
    return nullptr;
  }
  // Check monad inputs: both abstracts were verified to be monads above, so dereferencing is
  // expected to be safe here.
  const auto &input_monad_abs = *(input_monad->abstract());
  const auto &update_monad_abs = *(update_monad->abstract());
  bool same_monad = (input_monad_abs == update_monad_abs);
  if (!same_monad) {
    // Skip if they are different monad (one is IO, another is U).
    return nullptr;
  }
  // Now we can eliminate the UpdateState and Depend nodes.
  auto mgr = GetManager(update_state);
  if (mgr == nullptr) {
    return nullptr;
  }
  // Replace Depend with its input.
  if (depend->size() != kDependSize) {
    MS_LOG(EXCEPTION) << "The Depend node has wrong inputs. " << depend->DebugString();
  }
  auto depend_input = depend->input(kInputIndex);
  (void)mgr->Replace(depend, depend_input);
  // Replace UpdateState node with the input monad of Depend.
  return input_monad;
}
182 
ExistEnvironGet(const FuncGraphManagerPtr & manager)183 bool ExistEnvironGet(const FuncGraphManagerPtr &manager) {
184   const FuncGraphSet &fgs = manager->func_graphs();
185   for (auto &fg : fgs) {
186     auto &nodes = fg->value_nodes();
187     bool exist = std::any_of(nodes.begin(), nodes.end(),
188                              [](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvironGet); });
189     if (exist) {
190       return true;
191     }
192   }
193   return false;
194 }
195 
// Drop EnvironSet chains from the MakeTuple attached to an UpdateState, keeping their attach inputs.
// Convert:
// cnode1 = EnvironSet(EnvCreate(), para1, attach1)
// cnode2 = EnvironSet(cnode1, para2, attach2)
// ...
// cnode_n = EnvironSet(cnode_n-1, para_n, attach_n)
// maketuple = maketuple(cnode_n, ...)
// updatestate = updatestate(umonad, maketuple)
// To:
// new_maketuple = maketuple(..., attach2, ..., attach_n)
// new_updatestate = updatestate(umonad, new_maketuple)
//
// Details of the rewrite as implemented:
// - Only EnvironSet inputs used solely by `make_tuple` start a chain. A chain is followed while
//   each predecessor (input 1) is an EnvironSet used solely by its successor.
// - An EnvironSet's attach (input 3) is collected only when its predecessor continues the chain.
//   As a result, the innermost EnvironSet's attach (attach1 above) is not collected.
//   NOTE(review): confirm that dropping attach1 is intended.
// - Other MakeTuple inputs are kept only if they are CNodes and not UpdateState nodes. Kept
//   inputs come first, followed by the collected attaches.
// - Returns nullptr (no rewrite) if:
//   - no such EnvironSet exists,
//   - no manager or func graph is reachable,
//   - any func graph of the manager contains EnvironGet,
//   - or the new tuple would be empty.
// The new MakeTuple gets an AbstractTuple of its element abstracts. The new UpdateState copies
// the abstract and scope of `update_state`.
AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_state, const CNodePtr &make_tuple) {
  std::vector<AnfNodePtr> env_nodes;
  std::vector<AnfNodePtr> new_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  size_t input_size = make_tuple->size();
  bool has_environ_set = false;
  for (size_t i = 1; i < input_size; i++) {
    auto node = make_tuple->input(i);
    if (IsPrimitiveCNode(node, prim::kPrimEnvironSet) && OnlyUsedByOneNode(node, make_tuple)) {
      (void)env_nodes.emplace_back(node);
      has_environ_set = true;
    } else if (node->isa<CNode>() && !IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
      // Non-CNode inputs and UpdateState inputs are not carried over.
      (void)new_maketuple_inputs.emplace_back(node);
    }
  }
  if (!has_environ_set) {
    return nullptr;
  }
  // Check EnvironSet in MakeTuple
  auto mgr = GetManager(update_state);
  if (mgr == nullptr) {
    return nullptr;
  }
  // If exist EnvironGet, don't eliminate EnvironSet.
  if (ExistEnvironGet(mgr)) {
    return nullptr;
  }
  const size_t first_index = 1;
  const size_t attach_index = 3;
  // Collected attaches are inserted right after the kept non-EnvironSet inputs.
  const size_t no_env_node_size = new_maketuple_inputs.size();
  while (!env_nodes.empty()) {
    auto env = env_nodes.back();
    env_nodes.pop_back();
    if (!env->isa<CNode>()) {
      continue;
    }
    auto env_cnode = env->cast<CNodePtr>();
    auto env_input = env_cnode->input(first_index);
    auto attach = env_cnode->input(attach_index);
    if (IsPrimitiveCNode(env_input, prim::kPrimEnvironSet) && OnlyUsedByOneNode(env_input, env_cnode)) {
      (void)env_nodes.emplace_back(env_input);
      (void)new_maketuple_inputs.insert(new_maketuple_inputs.cbegin() + SizeToLong(no_env_node_size), attach);
    }
  }
  if (new_maketuple_inputs.size() == 1) {
    // Only the MakeTuple primitive is left: nothing to attach.
    return nullptr;
  }
  auto fg = update_state->func_graph();
  if (fg == nullptr) {
    return nullptr;
  }
  abstract::AbstractBasePtrList element_abstracts;
  (void)std::transform(new_maketuple_inputs.begin() + 1, new_maketuple_inputs.end(),
                       std::back_inserter(element_abstracts),
                       [](const AnfNodePtr &input) { return input->abstract(); });
  auto new_make_tuple = fg->NewCNode(new_maketuple_inputs);
  new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  auto new_update_state =
    fg->NewCNode({update_state->input(kFirstInputIndex), update_state->input(kInputIndex), new_make_tuple});
  new_update_state->set_abstract(update_state->abstract());
  new_update_state->set_scope(update_state->scope());
  return new_update_state;
}
268 
EliminateUpdateStateMakeTupleWithUselessNode(const CNodePtr & update_state,const CNodePtr & make_tuple)269 AnfNodePtr EliminateUpdateStateMakeTupleWithUselessNode(const CNodePtr &update_state, const CNodePtr &make_tuple) {
270   if (make_tuple->size() != kMakeTupleSize) {
271     return nullptr;
272   }
273   AnfNodePtr attach_node = nullptr;
274   auto &first_input = make_tuple->input(kInputIndex);
275   auto &second_input = make_tuple->input(kAttachIndex);
276 
277   // Eliminate useless make_tuple with 'DeadNode' or 'PolyNode'.
278   // UpdateState(u, MakeTuple(input, "DeadNode")) -> UpdateState(u, input)
279   if (IsDeadNode(second_input) || IsPolyNode(second_input)) {
280     return NewUpdateStateWithAttach(update_state, first_input);
281   }
282 
283   // Eliminate useless make_tuple with useless Function.
284   // UpdateState(u, MakeTuple(Function, input) -> UpdateState(u, input)
285   // UpdateState(u, MakeTuple(input, Function) -> UpdateState(u, input)
286   if (IsValueNode<FuncGraph>(first_input) && OnlyUsedByOneNode(first_input, make_tuple)) {
287     return NewUpdateStateWithAttach(update_state, second_input);
288   }
289   if (IsValueNode<FuncGraph>(second_input) && OnlyUsedByOneNode(second_input, make_tuple)) {
290     return NewUpdateStateWithAttach(update_state, first_input);
291   }
292   return nullptr;
293 }
294 
295 void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
296                         std::vector<CNodePtr> *loads);
297 void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
298                          std::vector<CNodePtr> *loads);
299 
300 // Search consecutive load nodes from UpdateState node.
GetLoadsFromUpdateState(const CNodePtr & update_state,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)301 void GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
302                              std::vector<CNodePtr> *loads) {
303   auto &attach = update_state->input(kAttachIndex);
304   if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
305     GetLoadsFollowLoad(update_state, attach->cast<CNodePtr>(), update_states, loads);
306   } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
307     GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
308   }
309 }
310 
GetLoadsFollowLoad(const CNodePtr & update_state,const CNodePtr & load,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)311 void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std::vector<CNodePtr> *update_states,
312                         std::vector<CNodePtr> *loads) {
313   (void)update_states->emplace_back(update_state);
314   (void)loads->emplace_back(load);
315   auto &load_attach = load->input(kAttachIndex);
316   if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
317     GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads);
318   }
319 }
320 
GetLoadsFollowTuple(const CNodePtr & update_state,const CNodePtr & make_tuple,std::vector<CNodePtr> * update_states,std::vector<CNodePtr> * loads)321 void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
322                          std::vector<CNodePtr> *loads) {
323   if (!OnlyUsedByOneNode(make_tuple, update_state)) {
324     // UpdateState should be the only user of make_tuple.
325     return;
326   }
327   auto &inputs = make_tuple->inputs();
328   const auto &monad = update_state->input(kInputIndex);
329   bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(), [&monad](const AnfNodePtr &node) {
330     // Tuple element should be Load and use same monad that UpdateState used.
331     return (IsPrimitiveCNode(node, prim::kPrimLoad) && node->cast<CNodePtr>()->input(kAttachIndex) == monad);
332   });
333   if (!is_all_load) {
334     // Stop if not all tuple elements are load nodes and use same monad.
335     return;
336   }
337   // Add update_state and load nodes.
338   (void)update_states->emplace_back(update_state);
339   for (size_t i = 1; i < inputs.size(); ++i) {
340     auto &element = inputs.at(i);
341     (void)loads->emplace_back(element->cast<CNodePtr>());
342   }
343   // Follow prev update state if found.
344   auto prev_node = update_state->input(kInputIndex);
345   if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
346     GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads);
347   }
348 }
349 
350 // Create a MakeTuple node before UpdateState for same nodes, if there are more than 1 node used.
MakeTupleForSameNodes(const FuncGraphPtr & fg,const CNodePtr & old_update_state,const AnfNodePtrList & make_tuple_inputs)351 AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_update_state,
352                                  const AnfNodePtrList &make_tuple_inputs) {
353   constexpr size_t kOneNodeInputSize = 2;
354   if (make_tuple_inputs.size() == kOneNodeInputSize) {
355     // We don't need make_tuple since there is only one load.
356     return make_tuple_inputs.at(1);
357   }
358   // Create MakeTuple cnode.
359   auto make_tuple = fg->NewCNode(make_tuple_inputs);
360   // Set abstract for the MakeTuple node.
361   abstract::AbstractBasePtrList element_abstracts;
362   std::transform(make_tuple_inputs.begin() + 1, make_tuple_inputs.end(), std::back_inserter(element_abstracts),
363                  [](const AnfNodePtr &input) { return input->abstract(); });
364   make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
365   make_tuple->set_scope(old_update_state->scope());
366   return make_tuple;
367 }
368 
// Remove all nodes related to UpdateStates, if they're redundant.
// `update_states` is an UpdateState chain collected for consecutive Loads. update_states[0] is
// the one that stays in the graph (presumably the latest UpdateState of the chain, as collected
// by GetLoadsFromUpdateState; confirm at the call site).
// Step 1: every other UpdateState is replaced by its input monad, walking from the back.
// Step 2: if update_states[0] has at least two users, those users that are Depend nodes using it
//   as their attach input (input 2) are replaced by their first input.
//   When every user is such a Depend, depend_nodes[0] is kept (presumably so that
//   update_states[0] still has a user) and its abstract is reset to its first input's abstract.
// Does nothing if `update_states` is empty or no manager is reachable.
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
  if (update_states.empty()) {
    return;
  }
  auto mgr = GetManager(update_states[0]);
  if (mgr == nullptr) {
    return;
  }

  // 1. Remove the use of UpdateState nodes, except the last one.
  for (auto i = update_states.size() - 1; i > 0; i--) {
    auto &us = update_states[i];
    (void)mgr->Replace(us, us->input(kInputIndex));
  }

  // 2. Remove the Depend users of last UpdateState node.
  auto &node_users = mgr->node_users();
  auto iter = node_users.find(update_states[0]);
  if (iter == node_users.end()) {
    return;
  }
  auto &us_users = iter->second;
  if (us_users.size() < 2) {
    return;
  }
  // Snapshot the Depend users first: the Replace calls below modify the users map.
  std::vector<AnfNodePtr> depend_nodes;
  for (auto &user : us_users) {
    if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) {
      (void)depend_nodes.emplace_back(user.first);
    }
  }
  if (depend_nodes.empty()) {
    return;
  }
  // Index of the first Depend that gets removed; 1 means depend_nodes[0] is kept.
  ssize_t end = 0;
  // If all users are Depend CNode.
  if (depend_nodes.size() == us_users.size()) {
    end = 1;
    // Set abstract value for reserved Depend node.
    auto &reserved_depend_node = depend_nodes[0];
    auto &primary_node = reserved_depend_node->cast<CNodePtr>()->input(kInputIndex);
    reserved_depend_node->set_abstract(primary_node->abstract());
  }
  for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
    const auto &depend_node = depend_nodes[i];
    const auto &depend_cnode = depend_node->cast<CNodePtr>();
    (void)mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
  }
}
419 
420 // Eliminate UpdateStates for consecutive Loads.
421 // Convert:
422 //    x1 = Load(x1, u)
423 //    u1 = UpdateState(u, x1)
424 //    x2 = Load(x2, u1)
425 //    u2 = UpdateState(u1, x2)
426 //    ...
427 //    xN = Load(xN, u(N-1))
428 //    uN = UpdateState(u(N-1), xN)
429 // To:
430 //    x1 = Load(x1, u)
431 //    x2 = Load(x2, u)
432 //    ...
433 //    xN = Load(xN, u)
434 //    t = make_tuple(x1, x2, ... , xN)
435 //    u1 = UpdateState(u, t)
EliminateUpdateStateForLoads(const CNodePtr & old_update_state,const std::vector<CNodePtr> & update_states,const std::vector<CNodePtr> & loads)436 AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
437                                         const std::vector<CNodePtr> &loads) {
438   auto fg = old_update_state->func_graph();
439   if (fg == nullptr) {
440     return nullptr;
441   }
442   auto mgr = fg->manager();
443   if (mgr == nullptr) {
444     return nullptr;
445   }
446   // Prepare tuple elements from Load nodes.
447   AnfNodePtrList make_tuple_inputs;
448   std::set<AnfNodePtr> loaded_para_set;
449   make_tuple_inputs.reserve(loads.size() + 1);
450   (void)make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
451   auto input_monad = loads.back()->input(kAttachIndex);
452   for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
453     auto &load = *iter;
454     auto result = loaded_para_set.emplace(load->input(kInputIndex));
455     const bool is_new_load = result.second;
456     if (is_new_load) {
457       // Put Load node as a tuple element, if the parameter is not loaded by other Load.
458       (void)make_tuple_inputs.emplace_back(load);
459     }
460     auto load_attach = load->input(kAttachIndex);
461     if (load_attach != input_monad) {
462       // Set all load use same input monad.
463       (void)mgr->Replace(load_attach, input_monad);
464     }
465   }
466 
467   EliminateUselessNodesForUpdateStates(update_states);
468 
469   if (make_tuple_inputs.size() == 1) {
470     // This should not happen.
471     MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
472     return nullptr;
473   }
474   // Create the new UpdateState node with a MakeTuple, replace the old UpdateStateNode.
475   auto attach = MakeTupleForSameNodes(fg, old_update_state, make_tuple_inputs);
476   auto update_state = NewValueNode(prim::kPrimUpdateState);
477   auto new_update_state = fg->NewCNode({update_state, input_monad, attach});
478   new_update_state->set_abstract(old_update_state->abstract());
479   new_update_state->set_scope(old_update_state->scope());
480   return new_update_state;
481 }
482 
483 // Eliminate UpdateStates between Assign nodes.
484 // Covert:
485 // a1 = Assign(para1, value1, u1)
486 // u2 = UpdateState(u1, a1)
487 // a2 = Assign(para2, value2, u2)  # para1 != para2, para1 != value2, para2 != value1
488 // u3 = UpdateState(u2, a2)
489 // To:
490 // a1 = Assign(para1, value1, u1)
491 // a2 = Assign(para2, value2, u1)
492 // t = MakeTuple(a1, a2)
493 // u3 = UpdateState(u1, t)
EliminateUpdateStateBetweenAssigns(const CNodePtr & update_state,const AnfNodePtr & assign)494 AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, const AnfNodePtr &assign) {
495   auto a2_cnode = assign->cast<CNodePtr>();
496   auto u2 = a2_cnode->input(kAssignMonadInputIndex);
497   auto a1 = u2->cast<CNodePtr>()->input(kAttachIndex);
498   if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
499     auto a1_cnode = a1->cast<CNodePtr>();
500     if (a1_cnode->size() != kAssignSize) {
501       return nullptr;
502     }
503     auto para1 = a1_cnode->input(kInputIndex);
504     auto value1 = a1_cnode->input(kAttachIndex);
505     auto para2 = a2_cnode->input(kInputIndex);
506     auto value2 = a2_cnode->input(kAttachIndex);
507     auto u1 = a1_cnode->input(kAssignMonadInputIndex);
508     if (para1 != para2 && para1 != value2 && para2 != value1) {
509       auto fg = update_state->func_graph();
510       MS_EXCEPTION_IF_NULL(fg);
511       auto mgr = fg->manager();
512       MS_EXCEPTION_IF_NULL(mgr);
513       (void)mgr->Replace(u2, u1);
514 
515       AnfNodePtrList make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1, assign};
516       auto make_tuple = MakeTupleForSameNodes(fg, update_state, make_tuple_inputs);
517       auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, make_tuple});
518       new_update_state->set_abstract(update_state->abstract());
519       new_update_state->set_scope(update_state->scope());
520       return new_update_state;
521     }
522   }
523   return nullptr;
524 }
525 
526 // Eliminate Load before Assign nodes.
527 // Covert:
528 // load = Load(parameter)
529 // a = Assign(load, value, u)
530 // To:
531 // a = Assign(parameter, value, u)
EliminateLoadBeforeAssigns(const FuncGraphManagerPtr & manager,const CNodePtr & update_state)532 bool EliminateLoadBeforeAssigns(const FuncGraphManagerPtr &manager, const CNodePtr &update_state) {
533   auto &attach = update_state->input(kAttachIndex);
534   // UpdateState(u, Assign(para, value, u))
535   if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
536     auto assign = attach->cast<CNodePtr>();
537     if (assign->size() != kAssignSize) {
538       return false;
539     }
540     // If assign's first input is load, eliminate load.
541     auto &ref_node = assign->input(kAssignRefInputIndex);
542     if (IsPrimitiveCNode(ref_node, prim::kPrimLoad)) {
543       auto load = ref_node->cast<CNodePtr>();
544       auto &parameter = load->input(kInputIndex);
545       // If Load used by other nodes, keep load node.
546       auto assign_cnode = assign->cast<CNodePtr>();
547       if (OnlyUsedByOneNode(ref_node, assign_cnode)) {
548         (void)manager->Replace(ref_node, parameter);
549       } else {
550         manager->SetEdge(assign, kInputIndex, parameter);
551       }
552       return true;
553     }
554   }
555   return false;
556 }
557 
// Eliminate UpdateStates between MakeTuple and Assign.
// Convert:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u2)  # u2 == u1
// t = MakeTuple(a1, a2)
// u3 = UpdateState(u1, t)
// a3 = Assign(para3, value3, u3)  # para3 != para1, para3 != para2, value3 != para1, value3 != para2
//                                 # value1 != para3, value2 != para3
// u4 = UpdateState(u3, a3)
// To:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u1)
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
// `update_state` is u4 and `assign` is a3. The caller is expected to have checked:
//   - a3's size,
//   - that u3 is an UpdateState used only by a3 and u4 (u3 is cast here without a check).
// This function checks that t is a two-element MakeTuple of Assigns used only by u3.
// On success:
//   - u3 is replaced by u1,
//   - the old MakeTuple is replaced by the new three-element one,
//   - the new u4 is returned.
// Otherwise nullptr is returned.
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &assign) {
  auto a3_cnode = assign->cast<CNodePtr>();
  auto u3 = a3_cnode->input(kAssignMonadInputIndex);
  auto u3_cnode = u3->cast<CNodePtr>();
  auto make_tuple = u3_cnode->input(kAttachIndex);
  if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple) && OnlyUsedByOneNode(make_tuple, u3_cnode)) {
    auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
    if (make_tuple_cnode->size() != kMakeTupleSize) {
      return nullptr;
    }
    auto a1 = make_tuple_cnode->input(kInputIndex);
    auto a2 = make_tuple_cnode->input(kAttachIndex);
    if (IsPrimitiveCNode(a1, prim::kPrimAssign) && IsPrimitiveCNode(a2, prim::kPrimAssign)) {
      auto a1_cnode = a1->cast<CNodePtr>();
      auto a2_cnode = a2->cast<CNodePtr>();
      if (a1_cnode->size() != kAssignSize || a2_cnode->size() != kAssignSize) {
        return nullptr;
      }
      auto para1 = a1_cnode->input(kInputIndex);
      auto value1 = a1_cnode->input(kAttachIndex);
      auto u1 = a1_cnode->input(kAssignMonadInputIndex);
      auto para2 = a2_cnode->input(kInputIndex);
      auto value2 = a2_cnode->input(kAttachIndex);
      auto u2 = a2_cnode->input(kAssignMonadInputIndex);
      auto para3 = a3_cnode->input(kInputIndex);
      auto value3 = a3_cnode->input(kAttachIndex);
      // a1 and a2 must share the same monad, and a3 must not read or write what a1/a2 touch.
      bool replace_judge = (u1 == u2) && (para1 != para3) && (para1 != value3) && (para2 != para3) &&
                           (para2 != value3) && (value1 != para3) && (value2 != para3);
      if (replace_judge) {
        auto fg = update_state->func_graph();
        MS_EXCEPTION_IF_NULL(fg);
        auto mgr = fg->manager();
        MS_EXCEPTION_IF_NULL(mgr);
        // Rewire a3 (and any other user of u3) to u1.
        (void)mgr->Replace(u3, u1);

        AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), make_tuple_cnode->input(kInputIndex),
                                             make_tuple_cnode->input(kAttachIndex), assign};
        auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
        (void)mgr->Replace(make_tuple, new_make_tuple);
        auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
        new_update_state->set_abstract(update_state->abstract());
        new_update_state->set_scope(update_state->scope());
        return new_update_state;
      }
    }
  }
  return nullptr;
}
621 
// Eliminate UpdateStates between Assign and MakeTuple.
// Convert:
// a1 = Assign(para1, value1, u1)
// u2 = UpdateState(u1_1, a1)      # u1_1 == u1
// a2 = Assign(para2, value2, u2)
// a3 = Assign(para3, value3, u3)  # u2 == u3, and u2 is used only by a2 and a3
// t = MakeTuple(a2, a3)           # t is used only by u4
// u4 = UpdateState(u, t)          # para1 != para2, para1 != para3, para1 != value2,
//                                 # para1 != value3, value1 != para2, value1 != para3
// To:
// a1 = Assign(para1, value1, u1)
// a2 = Assign(para2, value2, u1)
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
// `update_state` is u4 and `make_tuple` is t. Returns the new u4, or nullptr if the pattern does
// not match.
// NOTE(review): u4's own input monad (u above) is neither checked nor carried over; the new u4
// uses u1. Because u2 must have exactly the two users a2 and a3, u cannot be u2 here. Confirm
// that dropping u is safe.
// NOTE(review): `make_tuple` is a reference parameter. If it aliases update_state's input slot,
// mgr->Replace(make_tuple, ...) below rewrites the very slot it refers to. Callers should pass
// an owned handle.
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
  auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
  if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) {
    return nullptr;
  }
  auto a2 = make_tuple_cnode->input(kInputIndex);
  auto a3 = make_tuple_cnode->input(kAttachIndex);
  if (IsPrimitiveCNode(a2, prim::kPrimAssign) && IsPrimitiveCNode(a3, prim::kPrimAssign)) {
    auto a2_cnode = a2->cast<CNodePtr>();
    auto a3_cnode = a3->cast<CNodePtr>();
    if (a2_cnode->size() != kAssignSize || a3_cnode->size() != kAssignSize) {
      return nullptr;
    }
    auto para2 = a2_cnode->input(kInputIndex);
    auto value2 = a2_cnode->input(kAttachIndex);
    auto u2 = a2_cnode->input(kAssignMonadInputIndex);
    if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState) || !OnlyUsedByTwoNode(u2, a2, a3)) {
      return nullptr;
    }
    auto para3 = a3_cnode->input(kInputIndex);
    auto value3 = a3_cnode->input(kAttachIndex);
    auto u3 = a3_cnode->input(kAssignMonadInputIndex);
    if (u2 == u3) {
      auto u2_cnode = u2->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(u2_cnode);
      auto u1 = u2_cnode->input(kInputIndex);
      auto a1 = u2_cnode->input(kAttachIndex);
      if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
        auto a1_cnode = a1->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(a1_cnode);
        if (a1_cnode->size() != kAssignSize) {
          return nullptr;
        }
        auto para1 = a1_cnode->input(kInputIndex);
        auto value1 = a1_cnode->input(kAttachIndex);
        auto u1_1 = a1_cnode->input(kAssignMonadInputIndex);
        // a1 must sit on u2's input monad and must not read or write what a2/a3 touch.
        bool replace_judge = (u1 == u1_1) && (para1 != para2) && (para1 != para3) && (para1 != value2) &&
                             (para1 != value3) && (para2 != value1) && (para3 != value1);
        if (replace_judge) {
          auto fg = update_state->func_graph();
          MS_EXCEPTION_IF_NULL(fg);
          auto mgr = fg->manager();
          MS_EXCEPTION_IF_NULL(mgr);
          // Rewire a2 and a3 (u2's only users) to u1.
          (void)mgr->Replace(u2, u1);
          AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1,
                                               make_tuple_cnode->input(kInputIndex),
                                               make_tuple_cnode->input(kAttachIndex)};
          auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
          (void)mgr->Replace(make_tuple, new_make_tuple);
          auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
          new_update_state->set_abstract(update_state->abstract());
          new_update_state->set_scope(update_state->scope());
          return new_update_state;
        }
      }
    }
  }
  return nullptr;
}
696 
EliminateUpdateStateForAssign(const CNodePtr & update_state)697 AnfNodePtr EliminateUpdateStateForAssign(const CNodePtr &update_state) {
698   // UpdateState(u, MakeTuple)
699   auto &attach = update_state->input(kAttachIndex);
700   if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
701     return EliminateUpdateStateBetweenMakeTupleAssign(update_state, attach);
702   }
703   // UpdateState(u, Assign(para, value, u))
704   if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
705     auto assign = attach->cast<CNodePtr>();
706     if (assign->size() != kAssignSize) {
707       return nullptr;
708     }
709     auto u = assign->input(kAssignMonadInputIndex);
710     // u is UpdateState, u only be used by assign and update_state.
711     if (IsPrimitiveCNode(u, prim::kPrimUpdateState) && OnlyUsedByTwoNode(u, assign, update_state)) {
712       auto u_attach = u->cast<CNodePtr>()->input(kAttachIndex);
713       if (IsPrimitiveCNode(u_attach, prim::kPrimAssign)) {
714         return EliminateUpdateStateBetweenAssigns(update_state, assign);
715       }
716       if (IsPrimitiveCNode(u_attach, prim::kPrimMakeTuple)) {
717         return EliminateUpdateStateBetweenAssignMakeTuple(update_state, assign);
718       }
719     }
720   }
721   return nullptr;
722 }
723 }  // namespace
724 
725 // Eliminate useless node that only used by associated update_state.
operator ()(const OptimizerPtr &,const AnfNodePtr & node)726 AnfNodePtr UpdatestateUselessNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
727   auto update_state_node = dyn_cast<CNode>(node);
728   if (update_state_node == nullptr || update_state_node->size() != kUpdateStateSize) {
729     return nullptr;
730   }
731 
732   // If update_state is the only user of partial/load, replace it with the input monad.
733   // UpdateState(u, Partial) -> u
734   // UpdateState(u, Load) -> u
735   // UpdateState(u, FuncGraph) -> u
736   auto &attach = update_state_node->input(kAttachIndex);
737   if (IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimLoad) ||
738       IsValueNode<FuncGraph>(attach)) {
739     // Replace UpdateState with the input monad.
740     if (OnlyUsedByOneNode(attach, update_state_node)) {
741       return update_state_node->input(kInputIndex);
742     }
743     return nullptr;
744   }
745 
746   // Handling the case where the second input of update_state is make_tuple which contains DeadNode or useless function.
747   // UpdateState(u, MakeTuple(input, "Dead Node")) -> UpdateState(u, input)
748   // UpdateState(u, MakeTuple(Function, input) -> UpdateState(u, input)
749   // UpdateState(u, MakeTuple(input, Function) -> UpdateState(u, input)
750   if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
751     auto new_node = EliminateUpdateStateMakeTupleWithUselessNode(update_state_node, attach->cast<CNodePtr>());
752     if (new_node != nullptr) {
753       return new_node;
754     }
755     return EliminateUpdateStateMakeTupleWithUselessEnv(update_state_node, attach->cast<CNodePtr>());
756   }
757   return nullptr;
758 }
759 
760 // Eliminate UpdateState that attaches a pure (no-side-effect) node.
761 // Convert:
762 //   x = pure_node(args) # no side effect
763 //   u1 = update_state(u, x)
764 //   user(u1)
765 // To:
766 //   x = pure_node(args)
767 //   user(u)
operator ()(const OptimizerPtr &,const AnfNodePtr & node)768 AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
769   auto update_state_node = dyn_cast<CNode>(node);
770   if (update_state_node == nullptr || update_state_node->size() != kUpdateStateSize) {
771     return nullptr;
772   }
773 
774   auto &attach = update_state_node->input(kAttachIndex);
775   // update_state(u, param) or update_state(u, value_node) is redundant.
776   auto cnode = dyn_cast<CNode>(attach);
777   if (cnode == nullptr) {
778     return update_state_node->input(kInputIndex);
779   }
780   const auto &first_input = cnode->input(0);
781   bool is_special_ops = cnode->IsApply(prim::kPrimTupleGetItem) || cnode->IsApply(prim::kPrimDepend) ||
782                         cnode->IsApply(prim::kPrimPartial) || cnode->IsApply(prim::kPrimMakeTuple) ||
783                         cnode->IsApply(prim::kPrimCall) || IsValueNode<FuncGraph>(first_input) ||
784                         IsPrimitiveCNode(first_input, prim::kPrimJ) || IsPrimitiveCNode(first_input, prim::kPrimVmap) ||
785                         IsPrimitiveCNode(first_input, prim::kPrimTaylor) ||
786                         IsPrimitiveCNode(first_input, prim::kPrimShard);
787   if (is_special_ops) {
788     return nullptr;
789   }
790   if (CheckHasMonadInput(cnode)) {
791     return nullptr;
792   }
793   return update_state_node->input(kInputIndex);
794 }
795 
796 // Eliminate redundant UpdateState/Depend pair nodes caused by inline.
797 // Convert:
798 //    x1 = Depend(x, u0)
799 //    u1 = UpdateState(u', x1)
800 //    out = x_user(x1)
801 //    u2 = u_user(u1)
802 // To:
803 //    out = x_user(x)
804 //    u2 = u_user(u0)
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)805 bool UpdatestateDependEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
806   // Filter nodes that do not match UpdateState(u, Depend).
807   auto filter = [](const AnfNodePtr &node) {
808     if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
809       auto update_state = node->cast<CNodePtr>();
810       if (update_state->size() != kUpdateStateSize) {
811         return true;
812       }
813       auto &attach = update_state->input(kAttachIndex);
814       if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
815         return false;
816       }
817     }
818     return true;
819   };
820 
821   bool change = false;
822   auto manager = optimizer->manager();
823   MS_EXCEPTION_IF_NULL(manager);
824   auto &all_nodes = manager->all_nodes();
825   auto todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
826   for (auto &node : todo) {
827     if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
828       continue;
829     }
830     auto new_node = EliminateUpdateStateWithDepend(node->cast<CNodePtr>());
831     if (new_node != nullptr) {
832       (void)manager->Replace(node, new_node);
833       change = true;
834     }
835   }
836   return change;
837 }
838 
839 // Eliminate UpdateStates for consecutive Assign.
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)840 bool UpdatestateAssignEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
841   // Filter nodes that do not match UpdateState(u, Assign) or UpdateState(u, MakeTuple).
842   auto filter = [](const AnfNodePtr &node) {
843     if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
844       auto update_state = node->cast<CNodePtr>();
845       if (update_state->size() != kUpdateStateSize) {
846         return true;
847       }
848       auto &attach = update_state->input(kAttachIndex);
849       if (IsPrimitiveCNode(attach, prim::kPrimAssign) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
850         return false;
851       }
852     }
853     return true;
854   };
855 
856   bool change = false;
857   auto manager = optimizer->manager();
858   MS_EXCEPTION_IF_NULL(manager);
859   auto &all_nodes = manager->all_nodes();
860   std::vector<AnfNodePtr> todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
861   for (auto &node : todo) {
862     if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
863       continue;
864     }
865     auto new_node = EliminateUpdateStateForAssign(node->cast<CNodePtr>());
866     if (new_node != nullptr) {
867       (void)manager->Replace(node, new_node);
868       change = true;
869     }
870     bool load_eliminate = EliminateLoadBeforeAssigns(manager, node->cast<CNodePtr>());
871     change = change || load_eliminate;
872   }
873   return change;
874 }
875 
876 // Eliminate UpdateStates for consecutive Loads.
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)877 bool UpdatestateLoadsEliminater::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
878   // Filter nodes that do not match UpdateState(u, Load) or UpdateState(u, MakeTuple).
879   auto filter = [](const AnfNodePtr &node) {
880     if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
881       auto update_state = node->cast<CNodePtr>();
882       if (update_state->size() != kUpdateStateSize) {
883         return true;
884       }
885       auto &attach = update_state->input(kAttachIndex);
886       if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
887         return false;
888       }
889     }
890     return true;
891   };
892 
893   bool change = false;
894   auto manager = optimizer->manager();
895   MS_EXCEPTION_IF_NULL(manager);
896   auto &all_nodes = manager->all_nodes();
897   std::vector<AnfNodePtr> todo = TopoSort(func_graph->get_return(), SuccDeeperSimple);
898   for (auto &node : todo) {
899     if (node == nullptr || !all_nodes.contains(node) || filter(node)) {
900       continue;
901     }
902     std::vector<CNodePtr> update_states;
903     std::vector<CNodePtr> loads;
904     auto update_state_node = node->cast<CNodePtr>();
905     GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
906     if (update_states.size() > 1 && loads.size() > 1) {
907       auto new_node = EliminateUpdateStateForLoads(update_state_node, update_states, loads);
908       if (new_node != nullptr) {
909         (void)manager->Replace(node, new_node);
910         change = true;
911       }
912     }
913   }
914   return change;
915 }
916 
917 // Eliminate Monad parameter for switch call.
918 // Convert:
919 //     x = Load(x, u)
920 //     u = UpdateState(u, x)
921 //     ...
922 //     g1 = Partial(...)
923 //     g2 = Partial(...)
924 //     s = switch(cond, g1, g2)
925 //     res = s(u)
926 // To:
927 //     x = Load(x, u)
928 //     u = UpdateState(u, x)
929 //     ...
930 //     g1 = Partial(..., u)
931 //     g2 = Partial(..., u)
932 //     s = switch(cond, g1, g2)
933 //     res = s()
operator ()(const OptimizerPtr &,const AnfNodePtr & node)934 AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
935   const CNodePtr &switch_call = dyn_cast<CNode>(node);
936   if (switch_call == nullptr) {
937     return nullptr;
938   }
939   auto fg = switch_call->func_graph();
940   if (fg == nullptr) {
941     return nullptr;
942   }
943   auto mgr = fg->manager();
944   if (mgr == nullptr) {
945     return nullptr;
946   }
947   const size_t switch_call_input_size = 2;
948   if (switch_call->size() < switch_call_input_size) {
949     return nullptr;
950   }
951   constexpr size_t primary_index = 0;
952   auto switch_node = switch_call->input(primary_index);
953   if (!IsPrimitiveCNode(switch_node, prim::kPrimSwitch)) {
954     return nullptr;
955   }
956   MS_LOG(DEBUG) << "Found switch call with monad parameter, " << switch_call->DebugString();
957   auto switch_cnode = dyn_cast<CNode>(switch_node);
958   if (switch_cnode == nullptr) {
959     MS_LOG(EXCEPTION) << "switch node cast to CNode failed, " << switch_node->DebugString();
960   }
961   constexpr size_t condition_index = 1;
962   constexpr size_t first_fg_index = 2;
963   constexpr size_t second_fg_index = 3;
964   auto fg1_node = switch_cnode->input(first_fg_index);
965   auto fg2_node = switch_cnode->input(second_fg_index);
966   auto build_partial = [&fg, &switch_call](const AnfNodePtr &node) {
967     CNodePtr new_node;
968     if (IsPrimitiveCNode(node, prim::kPrimPartial)) {  // Node is already Partial CNode.
969       new_node = fg->NewCNode(node->cast<CNodePtr>()->inputs());
970     } else {  // Node is FuncGraph ValueNode.
971       new_node = fg->NewCNode({NewValueNode(prim::kPrimPartial), node});
972     }
973     constexpr size_t args_start_index = 1;
974     for (size_t i = args_start_index; i < switch_call->size(); i++) {
975       new_node->add_input(switch_call->input(i));
976     }
977     // partial's abstract is same with first input.
978     new_node->set_abstract(new_node->input(1)->abstract());
979     return new_node;
980   };
981   fg1_node = build_partial(fg1_node);
982   fg2_node = build_partial(fg2_node);
983   auto cond = switch_cnode->input(condition_index);
984   auto new_switch_cnode = fg->NewCNode({NewValueNode(prim::kPrimSwitch), cond, fg1_node, fg2_node});
985   auto new_switch_call = fg->NewCNode({new_switch_cnode});
986   new_switch_cnode->set_abstract(switch_node->abstract());
987   new_switch_call->set_abstract(switch_call->abstract());
988   return new_switch_call;
989 }
990 }  // namespace mindspore::opt::irpass
991