• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
19 #include <memory>
20 #include <vector>
21 #include <utility>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/nn_ops.h"
25 #include "mindspore/core/ops/lite_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "include/errorcode.h"
29 #include "tools/lite_exporter/fetch_content.h"
30 #include "ops/make_tuple.h"
31 #include "ops/depend.h"
32 #include "ops/fusion/pad_fusion.h"
33 #include "ops/op_utils.h"
34 #include "nnacl/op_base.h"
35 #include "include/common/utils/utils.h"
36 
37 namespace mindspore::opt {
38 namespace {
// Input index handed to the lite::FetchDataFrom*Node helpers when reading a node's input 2.
constexpr size_t kIndexNum = 2;
ReplaceUpdateStateWithMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool remove_side_effect)40 int ReplaceUpdateStateWithMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool remove_side_effect) {
41   if (!remove_side_effect) {
42     return lite::RET_NO_CHANGE;
43   }
44   // only solve UpdateState with at lease one Monad input
45   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
46   AnfNodePtr monad_input = nullptr;
47   auto first_input = cnode->input(kInputIndexOne);
48   if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
49     first_input = first_input->cast<CNodePtr>()->input(kInputIndexOne);
50     MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
51   }
52   auto second_input = cnode->input(kInputIndexTwo);
53   if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
54     second_input = second_input->cast<CNodePtr>()->input(kInputIndexOne);
55     MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
56   }
57   if (utils::isa<ValueNode>(first_input)) {
58     auto value_node = first_input->cast<ValueNodePtr>();
59     MS_ASSERT(value_node->value() != nullptr);
60     if (utils::isa<Monad>(value_node->value())) {
61       monad_input = first_input;
62     }
63   }
64   if (utils::isa<ValueNode>(second_input)) {
65     auto value_node = second_input->cast<ValueNodePtr>();
66     MS_ASSERT(value_node->value() != nullptr);
67     if (utils::isa<Monad>(value_node->value())) {
68       monad_input = second_input;
69     }
70   }
71   MS_CHECK_TRUE_MSG(monad_input != nullptr, lite::RET_NO_CHANGE, "not find monad input");
72 
73   // find monad input node, using monad node replace UpdateState node
74   auto manager = func_graph->manager();
75   MS_ASSERT(manager != nullptr);
76   manager->Replace(cnode, monad_input);
77   return lite::RET_OK;
78 }
79 
ProcessInputIsMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode)80 int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
81   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
82   auto first_input = cnode->input(1);
83   MS_ASSERT(first_input != nullptr);
84   if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
85     first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
86     MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
87   }
88   auto second_input = cnode->input(kInputIndexTwo);
89   MS_ASSERT(seconde_input != nullptr);
90   if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
91     second_input = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
92     MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
93   }
94   AnfNodePtr must_monad = nullptr;
95   AnfNodePtr not_must_monad = nullptr;
96   if (utils::isa<ValueNode>(first_input)) {
97     auto value_node = first_input->cast<ValueNodePtr>();
98     MS_ASSERT(value_node->value() != nullptr);
99     if (utils::isa<Monad>(value_node->value())) {
100       must_monad = first_input;
101       not_must_monad = second_input;
102     }
103   }
104   if (utils::isa<ValueNode>(second_input)) {
105     auto value_node = second_input->cast<ValueNodePtr>();
106     MS_ASSERT(value_node->value() != nullptr);
107     if (utils::isa<Monad>(value_node->value())) {
108       must_monad = second_input;
109       not_must_monad = first_input;
110     }
111   }
112   if (must_monad == nullptr) {
113     return lite::RET_NO_CHANGE;
114   }
115   auto manager = func_graph->manager();
116   MS_ASSERT(manager != nullptr);
117   if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
118     manager->Replace(cnode, must_monad);
119   } else {
120     manager->Replace(cnode, not_must_monad);
121   }
122   return lite::RET_OK;
123 }
124 
ProcessDependencyWithTwoNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool pre_node_is_first)125 int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
126   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
127   AnfNodePtr pre_node = cnode->input(1);
128   AnfNodePtr post_node = cnode->input(kInputIndexTwo);
129   MS_ASSERT(pre_node != nullptr);
130   MS_ASSERT(post_node != nullptr);
131   if (!pre_node_is_first) {
132     pre_node = cnode->input(kInputIndexTwo);
133     post_node = cnode->input(1);
134   }
135   if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
136     pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
137     MS_CHECK_TRUE_MSG(pre_node != nullptr, RET_ERROR, "pre_node is nullptr");
138   }
139   if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
140     post_node = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
141     MS_CHECK_TRUE_MSG(post_node != nullptr, RET_ERROR, "post_node is nullptr");
142   }
143   auto manager = func_graph->manager();
144   MS_ASSERT(manager != nullptr);
145   auto node_users = manager->node_users()[pre_node];
146   auto iter =
147     std::find_if(node_users.begin(), node_users.end(),
148                  [&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
149   if (iter == node_users.end()) {
150     return lite::RET_NO_CHANGE;
151   }
152   auto tr = manager->Transact();
153   tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
154   tr.Commit();
155   auto depend_prim = std::make_shared<ops::Depend>();
156   MS_CHECK_TRUE_MSG(depend_prim != nullptr, lite::RET_ERROR, "New Depend ops Failed");
157   auto depend_prim_c = depend_prim->GetPrim();
158   MS_CHECK_TRUE_MSG(depend_prim_c != nullptr, lite::RET_ERROR, "GetPrim Failed");
159   auto depend_node = func_graph->NewCNode(depend_prim_c, {post_node, pre_node});
160   MS_CHECK_TRUE_MSG(depend_node != nullptr, lite::RET_ERROR, "NewCNode Failed");
161   depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
162   manager->Replace(cnode, depend_node);
163   return lite::RET_OK;
164 }
165 
ProcessInputHaveDependency(const FuncGraphPtr & func_graph,const CNodePtr & cnode)166 int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
167   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
168   if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
169     return lite::RET_OK;
170   }
171   if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
172     return lite::RET_OK;
173   }
174   auto make_tuple_node = std::make_shared<ops::MakeTuple>();
175   MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, lite::RET_ERROR, "make tuple node Failed");
176   auto make_tuple_prim_c = make_tuple_node->GetPrim();
177   MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, lite::RET_ERROR, "make tuple prim c Failed");
178   auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
179   MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_ERROR, "NewCNode Failed");
180   auto manager = func_graph->manager();
181   MS_ASSERT(manager != nullptr);
182   if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
183     manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
184     return RET_OK;
185   }
186   manager->Replace(cnode->input(0), make_tuple_prim);
187   return lite::RET_OK;
188 }
189 }  // namespace
190 
ReplaceOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)191 int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
192   MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
193   MS_CHECK_TRUE_MSG(manager != nullptr, RET_ERROR, "manager is nullptr");
194   if (!utils::isa<CNodePtr>(anf_node)) {
195     MS_LOG(DEBUG) << "anf node is node a cnode.";
196     return lite::RET_NO_CHANGE;
197   }
198   auto cnode = anf_node->cast<CNodePtr>();
199   MS_ASSERT(cnode != nullptr);
200   if (CheckPrimitiveType(anf_node, kPrimIdentity)) {
201     if (cnode->size() != kInputSizeTwo) {
202       MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
203       remove_cnode_.insert(anf_node);
204       return lite::RET_NO_CHANGE;
205     }
206   }
207   if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) {
208     if (cnode->size() != kInputSizeTwo) {
209       MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
210       remove_cnode_.insert(anf_node);
211       return lite::RET_NO_CHANGE;
212     }
213   }
214   if (CheckPrimitiveType(anf_node, prim::kPrimTranspose)) {
215     if (cnode->size() != kInputSizeThree) {
216       MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
217       remove_cnode_.insert(anf_node);
218       return lite::RET_NO_CHANGE;
219     }
220   }
221 
222   bool replace_succ = manager->Replace(anf_node, cnode->input(1));
223   if (!replace_succ) {
224     MS_LOG(ERROR) << "replace redundant op failed.";
225     return lite::RET_ERROR;
226   }
227   return RET_OK;
228 }
229 
ReplaceUpdateStateOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node)230 int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
231   if (!utils::isa<CNodePtr>(anf_node)) {
232     MS_LOG(DEBUG) << "anf node is node a cnode.";
233     return lite::RET_NO_CHANGE;
234   }
235   auto cnode = anf_node->cast<CNodePtr>();
236   MS_ASSERT(cnode != nullptr);
237   if (ReplaceUpdateStateWithMonad(func_graph, cnode, remove_side_effect_) == lite::RET_OK) {
238     return lite::RET_OK;
239   }
240 
241   if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
242     return lite::RET_OK;
243   }
244   // both of two inputs are not monad, but have dependency.
245   return ProcessInputHaveDependency(func_graph, cnode);
246 }
247 
ReplaceTupleGetItem(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)248 int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
249   if (!utils::isa<CNodePtr>(anf_node)) {
250     MS_LOG(DEBUG) << "anf node is node a cnode.";
251     return lite::RET_NO_CHANGE;
252   }
253   if (!CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
254     return lite::RET_NO_CHANGE;
255   }
256   auto cnode = anf_node->cast<CNodePtr>();
257   MS_ASSERT(cnode != nullptr);
258   if (cnode->size() != kInputSizeThree) {
259     MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->size();
260     return RET_ERROR;
261   }
262   if (!CheckPrimitiveType(cnode->input(1), kPrimIdentity)) {
263     return lite::RET_NO_CHANGE;
264   }
265   auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>();
266   auto index_vnode = cnode->input(kInputIndexTwo);
267   if (!utils::isa<ValueNode>(index_vnode)) {
268     MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
269     return lite::RET_ERROR;
270   }
271   MS_CHECK_TRUE_MSG(!CastToInt(index_vnode->cast<ValueNodePtr>()->value()).empty(), RET_ERROR, "value is empty");
272   int index = CastToInt(index_vnode->cast<ValueNodePtr>()->value()).front();
273   int input_cnode_inputs_size = static_cast<int>(get_item_input_cnode->size());
274   if ((index + 1) >= input_cnode_inputs_size) {
275     MS_LOG(ERROR) << "value node index is out of range.";
276     return lite::RET_ERROR;
277   }
278   bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1));
279   if (!replace_succ) {
280     MS_LOG(ERROR) << "replace identity failed.";
281     return lite::RET_ERROR;
282   }
283   return lite::RET_OK;
284 }
285 
RemoveDropoutOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)286 int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
287   MS_ASSERT(anf_node != nullptr);
288   MS_ASSERT(manager != nullptr);
289   if (!utils::isa<CNodePtr>(anf_node)) {
290     MS_LOG(DEBUG) << "anf node is node a cnode.";
291     return lite::RET_NO_CHANGE;
292   }
293   auto cnode = anf_node->cast<CNodePtr>();
294   MS_ASSERT(cnode != nullptr);
295   if (cnode->size() > kInputSizeTwo) {
296     MS_LOG(ERROR) << "dropout input invalid.";
297     return lite::RET_ERROR;
298   }
299   if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) {
300     MS_LOG(DEBUG) << "dropout output size is one.";
301     manager->Replace(anf_node, cnode->input(1));
302   } else {
303     auto node_users = manager->node_users()[anf_node];
304     for (auto &node_user : node_users) {
305       auto node = node_user.first;
306       if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
307         MS_LOG(ERROR) << "dropout out node is invalid.";
308         return lite::RET_ERROR;
309       }
310       auto get_index_node = node->cast<CNodePtr>()->input(kInputIndexTwo)->cast<ValueNodePtr>();
311       if (get_index_node == nullptr) {
312         MS_LOG(ERROR) << "tuple get item node is invalid.";
313         return lite::RET_ERROR;
314       }
315       auto get_index = CastToInt(get_index_node->value()).front();
316       if (get_index > 0 && !manager->node_users()[node].empty()) {
317         MS_LOG(DEBUG) << "dropout's second output is useful.";
318         continue;
319       }
320       manager->Replace(node, cnode->input(1));
321     }
322   }
323   return lite::RET_OK;
324 }
325 
GetConstDataFromInputNode(const CNodePtr & cnode,lite::DataInfo * data_info)326 int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
327   MS_ASSERT(cnode != nullptr);
328   MS_ASSERT(data_info != nullptr);
329   auto padding_node = cnode->input(kInputIndexTwo);
330   MS_ASSERT(padding_node != nullptr);
331   if (utils::isa<Parameter>(padding_node)) {
332     auto status = lite::FetchDataFromParameterNode(cnode, kIndexNum, converter::kFmkTypeMs, data_info, true);
333     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
334       MS_LOG(ERROR) << "fetch data from parameter node failed.";
335       return lite::RET_ERROR;
336     }
337   } else if (utils::isa<ValueNode>(padding_node)) {
338     auto status = lite::FetchDataFromValueNode(cnode, kIndexNum, converter::kFmkTypeMs, false, data_info, true);
339     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
340       MS_LOG(ERROR) << "fetch data from value node failed.";
341       return lite::RET_ERROR;
342     }
343   }
344   return lite::RET_OK;
345 }
346 
RemoveInvalidPadOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)347 int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
348   if (!utils::isa<CNodePtr>(anf_node)) {
349     MS_LOG(DEBUG) << "anf node is node a cnode.";
350     return lite::RET_NO_CHANGE;
351   }
352   auto cnode = anf_node->cast<CNodePtr>();
353   MS_ASSERT(cnode != nullptr);
354   auto primitive = GetValueNode<mindspore::PrimitivePtr>(cnode->input(0));
355   if (primitive == nullptr) {
356     MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
357     return lite::RET_NO_CHANGE;
358   }
359   auto is_invalid = true;
360   if (cnode->size() > kInputSizeTwo) {
361     lite::DataInfo data_info;
362     if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
363       MS_LOG(ERROR) << "Get pad data failed.";
364       return lite::RET_ERROR;
365     }
366     if (!data_info.data_.empty()) {
367       auto pad_data = reinterpret_cast<int *>(data_info.data_.data());
368       size_t num = data_info.data_.size() / sizeof(int);
369       for (size_t i = 0; i < num; ++i) {
370         if (pad_data[i] != 0) {
371           is_invalid = false;
372           break;
373         }
374       }
375     } else {
376       is_invalid = false;
377     }
378   } else {
379     auto pad_prim = api::MakeShared<mindspore::ops::PadFusion>(primitive);
380     MS_CHECK_TRUE_RET(pad_prim != nullptr, lite::RET_ERROR);
381     MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPaddings) != nullptr, lite::RET_ERROR);
382     auto pad_data = pad_prim->get_paddings();
383     for (size_t i = 0; i < pad_data.size(); i++) {
384       for (size_t j = 0; j < pad_data[i].size(); j++) {
385         if (pad_data[i][j] != 0) {
386           is_invalid = false;
387           break;
388         }
389       }
390       if (is_invalid == false) {
391         break;
392       }
393     }
394   }
395   if (is_invalid) {
396     return ReplaceOp(anf_node, manager);
397   }
398   return lite::RET_OK;
399 }
400 
RemoveInvalidTransposeOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)401 int RemoveRedundantOpPass::RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
402   auto cnode = anf_node->cast<CNodePtr>();
403   MS_ASSERT(cnode != nullptr);
404   if (cnode->size() != kInputSizeThree) {
405     MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
406     return lite::RET_NO_CHANGE;
407   }
408   auto index_node = cnode->inputs()[kInputIndexTwo]->cast<ParameterPtr>();
409   if (index_node == nullptr || !index_node->has_default()) {
410     return RET_OK;
411   }
412   auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(index_node->default_param());
413   MS_ASSERT(tensor_info != nullptr);
414   if (tensor_info->Size() != 0) {
415     return RET_OK;
416   }
417   return ReplaceOp(anf_node, manager);
418 }
419 
FlattenMakeTuple(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)420 int RemoveRedundantOpPass::FlattenMakeTuple(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
421   MS_ASSERT(func_graph != nullptr);
422   MS_ASSERT(manager != nullptr);
423   auto node_list = TopoSort(func_graph->get_return());
424   for (auto &node : node_list) {
425     auto cnode = node->cast<CNodePtr>();
426     if (!cnode) {
427       continue;
428     }
429     if (opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
430       std::vector<AnfNodePtr> new_inputs;
431       auto inputs = cnode->inputs();
432       new_inputs.push_back(inputs[0]);
433       bool has_make_tuple = false;
434       if (lite::GetFlattenInputsIfMakeTuple(cnode, &new_inputs, &has_make_tuple) != RET_OK) {
435         MS_LOG(WARNING) << "Failed to get flatten inputs of cnode, node " << cnode->fullname_with_scope();
436         continue;
437       }
438       if (has_make_tuple) {
439         auto new_cnode = func_graph->NewCNode(new_inputs);
440         MS_CHECK_TRUE_MSG(new_cnode != nullptr, RET_ERROR, "Failed to create New node.");
441         new_cnode->set_abstract(cnode->abstract());
442         new_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_flatten");
443         manager->Replace(cnode, new_cnode);
444       }
445     } else if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
446       auto real_node = opt::GetTupleGetItemRealInput(cnode);
447       if (!real_node) {
448         MS_LOG(WARNING) << "Failed to get tuple real input, node " << cnode->fullname_with_scope();
449         continue;
450       }
451       auto real_node_as_cnode = real_node->cast<CNodePtr>();
452       if (real_node_as_cnode && CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
453         auto idx = opt::GetTupleGetItemOutIndex(cnode);
454         manager->Replace(cnode, real_node_as_cnode->input(idx));
455       }
456     }
457   }
458   return RET_OK;
459 }
460 
RemoveUmonad(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)461 int RemoveRedundantOpPass::RemoveUmonad(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
462   MS_ASSERT(func_graph != nullptr);
463   MS_ASSERT(manager != nullptr);
464   auto node_list = TopoSort(func_graph->get_return());
465   for (auto &node : node_list) {
466     auto cnode = node->cast<CNodePtr>();
467     if (!cnode) {
468       continue;
469     }
470     if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
471       continue;
472     }
473     if (cnode->size() < kDependInputSize) {
474       MS_LOG(ERROR) << "Depend input size " << cnode->size() << " cannot less than " << kDependInputSize;
475       continue;
476     }
477     auto depend_src = cnode->input(kIndex1);
478     auto depend_dst = cnode->input(kIndex2);
479     auto depend_dst_as_cnode = depend_dst->cast<CNodePtr>();
480     if (depend_dst_as_cnode && opt::CheckPrimitiveType(depend_dst_as_cnode, prim::kPrimUpdateState)) {
481       manager->Replace(cnode, depend_src);
482     }
483   }
484   return RET_OK;
485 }
486 
RemoveRedundantOp(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const AnfNodePtr & node)487 int RemoveRedundantOpPass::RemoveRedundantOp(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
488                                              const AnfNodePtr &node) {
489   int status = RET_OK;
490   if (CheckPrimitiveType(node, kPrimIdentity)) {
491     status = ReplaceOp(node, manager);
492   }
493   if (CheckPrimitiveType(node, prim::kPrimLoad)) {
494     status = ReplaceOp(node, manager);
495   }
496   if (CheckPrimitiveType(node, prim::kPrimTensorMove)) {
497     status = ReplaceOp(node, manager);
498   }
499   if (CheckPrimitiveType(node, prim::kPrimUpdateState) && !keep_update_state_) {
500     status = ReplaceUpdateStateOp(func_graph, node);
501   }
502   if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
503     status = ReplaceTupleGetItem(node, manager);
504   }
505   if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
506     status = RemoveDropoutOp(node, manager);
507   }
508   if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
509     status = RemoveInvalidPadOp(node, manager);
510   }
511   if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
512     status = RemoveInvalidTransposeOp(node, manager);
513   }
514   if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
515     auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
516     if (sub_func_graph == nullptr) {
517       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
518       return lite::RET_NULL_PTR;
519     }
520     (void)Run(sub_func_graph);
521     sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
522     if (sub_func_graph == nullptr) {
523       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
524       return lite::RET_NULL_PTR;
525     }
526     (void)Run(sub_func_graph);
527   }
528   return status;
529 }
530 
Run(const FuncGraphPtr & func_graph)531 bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
532   MS_ASSERT(func_graph != nullptr);
533   auto manager = Manage(func_graph, true);
534   MS_ASSERT(manager != nullptr);
535   if (!is_train_model_) {
536     auto ret = RemoveUmonad(func_graph, manager);
537     if (ret != lite::RET_OK) {
538       MS_LOG(ERROR) << "remove umonad.";
539       return false;
540     }
541   }
542 
543   auto node_list = TopoSort(func_graph->get_return());
544   int status = RET_OK;
545   for (auto &node : node_list) {
546     if (!utils::isa<CNodePtr>(node)) {
547       continue;
548     }
549     status = RemoveRedundantOp(func_graph, manager, node);
550     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
551       MS_LOG(ERROR) << "remove identity pass is failed.";
552       return false;
553     }
554   }
555   for (auto &node : remove_cnode_) {
556     func_graph->DropNode(node);
557   }
558   FlattenMakeTuple(func_graph, manager);
559   remove_cnode_.clear();
560   return true;
561 }
562 }  // namespace mindspore::opt
563