1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
19 #include <memory>
20 #include <vector>
21 #include <utility>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/nn_ops.h"
25 #include "mindspore/core/ops/lite_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "include/errorcode.h"
29 #include "tools/lite_exporter/fetch_content.h"
30 #include "ops/make_tuple.h"
31 #include "ops/depend.h"
32 #include "ops/fusion/pad_fusion.h"
33 #include "ops/op_utils.h"
34 #include "nnacl/op_base.h"
35 #include "include/common/utils/utils.h"
36
37 namespace mindspore::opt {
38 namespace {
// Position of the PadFusion paddings operand (the same slot GetConstDataFromInputNode reads via kInputIndexTwo),
// handed to the lite::FetchDataFrom*Node helpers.
constexpr size_t kIndexNum = 2;
ReplaceUpdateStateWithMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool remove_side_effect)40 int ReplaceUpdateStateWithMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool remove_side_effect) {
41 if (!remove_side_effect) {
42 return lite::RET_NO_CHANGE;
43 }
44 // only solve UpdateState with at lease one Monad input
45 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
46 AnfNodePtr monad_input = nullptr;
47 auto first_input = cnode->input(kInputIndexOne);
48 if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
49 first_input = first_input->cast<CNodePtr>()->input(kInputIndexOne);
50 MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
51 }
52 auto second_input = cnode->input(kInputIndexTwo);
53 if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
54 second_input = second_input->cast<CNodePtr>()->input(kInputIndexOne);
55 MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
56 }
57 if (utils::isa<ValueNode>(first_input)) {
58 auto value_node = first_input->cast<ValueNodePtr>();
59 MS_ASSERT(value_node->value() != nullptr);
60 if (utils::isa<Monad>(value_node->value())) {
61 monad_input = first_input;
62 }
63 }
64 if (utils::isa<ValueNode>(second_input)) {
65 auto value_node = second_input->cast<ValueNodePtr>();
66 MS_ASSERT(value_node->value() != nullptr);
67 if (utils::isa<Monad>(value_node->value())) {
68 monad_input = second_input;
69 }
70 }
71 MS_CHECK_TRUE_MSG(monad_input != nullptr, lite::RET_NO_CHANGE, "not find monad input");
72
73 // find monad input node, using monad node replace UpdateState node
74 auto manager = func_graph->manager();
75 MS_ASSERT(manager != nullptr);
76 manager->Replace(cnode, monad_input);
77 return lite::RET_OK;
78 }
79
ProcessInputIsMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode)80 int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
81 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
82 auto first_input = cnode->input(1);
83 MS_ASSERT(first_input != nullptr);
84 if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
85 first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
86 MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
87 }
88 auto second_input = cnode->input(kInputIndexTwo);
89 MS_ASSERT(seconde_input != nullptr);
90 if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
91 second_input = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
92 MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
93 }
94 AnfNodePtr must_monad = nullptr;
95 AnfNodePtr not_must_monad = nullptr;
96 if (utils::isa<ValueNode>(first_input)) {
97 auto value_node = first_input->cast<ValueNodePtr>();
98 MS_ASSERT(value_node->value() != nullptr);
99 if (utils::isa<Monad>(value_node->value())) {
100 must_monad = first_input;
101 not_must_monad = second_input;
102 }
103 }
104 if (utils::isa<ValueNode>(second_input)) {
105 auto value_node = second_input->cast<ValueNodePtr>();
106 MS_ASSERT(value_node->value() != nullptr);
107 if (utils::isa<Monad>(value_node->value())) {
108 must_monad = second_input;
109 not_must_monad = first_input;
110 }
111 }
112 if (must_monad == nullptr) {
113 return lite::RET_NO_CHANGE;
114 }
115 auto manager = func_graph->manager();
116 MS_ASSERT(manager != nullptr);
117 if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
118 manager->Replace(cnode, must_monad);
119 } else {
120 manager->Replace(cnode, not_must_monad);
121 }
122 return lite::RET_OK;
123 }
124
ProcessDependencyWithTwoNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool pre_node_is_first)125 int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
126 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
127 AnfNodePtr pre_node = cnode->input(1);
128 AnfNodePtr post_node = cnode->input(kInputIndexTwo);
129 MS_ASSERT(pre_node != nullptr);
130 MS_ASSERT(post_node != nullptr);
131 if (!pre_node_is_first) {
132 pre_node = cnode->input(kInputIndexTwo);
133 post_node = cnode->input(1);
134 }
135 if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
136 pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
137 MS_CHECK_TRUE_MSG(pre_node != nullptr, RET_ERROR, "pre_node is nullptr");
138 }
139 if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
140 post_node = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
141 MS_CHECK_TRUE_MSG(post_node != nullptr, RET_ERROR, "post_node is nullptr");
142 }
143 auto manager = func_graph->manager();
144 MS_ASSERT(manager != nullptr);
145 auto node_users = manager->node_users()[pre_node];
146 auto iter =
147 std::find_if(node_users.begin(), node_users.end(),
148 [&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
149 if (iter == node_users.end()) {
150 return lite::RET_NO_CHANGE;
151 }
152 auto tr = manager->Transact();
153 tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
154 tr.Commit();
155 auto depend_prim = std::make_shared<ops::Depend>();
156 MS_CHECK_TRUE_MSG(depend_prim != nullptr, lite::RET_ERROR, "New Depend ops Failed");
157 auto depend_prim_c = depend_prim->GetPrim();
158 MS_CHECK_TRUE_MSG(depend_prim_c != nullptr, lite::RET_ERROR, "GetPrim Failed");
159 auto depend_node = func_graph->NewCNode(depend_prim_c, {post_node, pre_node});
160 MS_CHECK_TRUE_MSG(depend_node != nullptr, lite::RET_ERROR, "NewCNode Failed");
161 depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
162 manager->Replace(cnode, depend_node);
163 return lite::RET_OK;
164 }
165
ProcessInputHaveDependency(const FuncGraphPtr & func_graph,const CNodePtr & cnode)166 int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
167 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
168 if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
169 return lite::RET_OK;
170 }
171 if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
172 return lite::RET_OK;
173 }
174 auto make_tuple_node = std::make_shared<ops::MakeTuple>();
175 MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, lite::RET_ERROR, "make tuple node Failed");
176 auto make_tuple_prim_c = make_tuple_node->GetPrim();
177 MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, lite::RET_ERROR, "make tuple prim c Failed");
178 auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
179 MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_ERROR, "NewCNode Failed");
180 auto manager = func_graph->manager();
181 MS_ASSERT(manager != nullptr);
182 if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
183 manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
184 return RET_OK;
185 }
186 manager->Replace(cnode->input(0), make_tuple_prim);
187 return lite::RET_OK;
188 }
189 } // namespace
190
ReplaceOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)191 int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
192 MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
193 MS_CHECK_TRUE_MSG(manager != nullptr, RET_ERROR, "manager is nullptr");
194 if (!utils::isa<CNodePtr>(anf_node)) {
195 MS_LOG(DEBUG) << "anf node is node a cnode.";
196 return lite::RET_NO_CHANGE;
197 }
198 auto cnode = anf_node->cast<CNodePtr>();
199 MS_ASSERT(cnode != nullptr);
200 if (CheckPrimitiveType(anf_node, kPrimIdentity)) {
201 if (cnode->size() != kInputSizeTwo) {
202 MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
203 remove_cnode_.insert(anf_node);
204 return lite::RET_NO_CHANGE;
205 }
206 }
207 if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) {
208 if (cnode->size() != kInputSizeTwo) {
209 MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
210 remove_cnode_.insert(anf_node);
211 return lite::RET_NO_CHANGE;
212 }
213 }
214 if (CheckPrimitiveType(anf_node, prim::kPrimTranspose)) {
215 if (cnode->size() != kInputSizeThree) {
216 MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
217 remove_cnode_.insert(anf_node);
218 return lite::RET_NO_CHANGE;
219 }
220 }
221
222 bool replace_succ = manager->Replace(anf_node, cnode->input(1));
223 if (!replace_succ) {
224 MS_LOG(ERROR) << "replace redundant op failed.";
225 return lite::RET_ERROR;
226 }
227 return RET_OK;
228 }
229
ReplaceUpdateStateOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node)230 int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
231 if (!utils::isa<CNodePtr>(anf_node)) {
232 MS_LOG(DEBUG) << "anf node is node a cnode.";
233 return lite::RET_NO_CHANGE;
234 }
235 auto cnode = anf_node->cast<CNodePtr>();
236 MS_ASSERT(cnode != nullptr);
237 if (ReplaceUpdateStateWithMonad(func_graph, cnode, remove_side_effect_) == lite::RET_OK) {
238 return lite::RET_OK;
239 }
240
241 if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
242 return lite::RET_OK;
243 }
244 // both of two inputs are not monad, but have dependency.
245 return ProcessInputHaveDependency(func_graph, cnode);
246 }
247
ReplaceTupleGetItem(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)248 int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
249 if (!utils::isa<CNodePtr>(anf_node)) {
250 MS_LOG(DEBUG) << "anf node is node a cnode.";
251 return lite::RET_NO_CHANGE;
252 }
253 if (!CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
254 return lite::RET_NO_CHANGE;
255 }
256 auto cnode = anf_node->cast<CNodePtr>();
257 MS_ASSERT(cnode != nullptr);
258 if (cnode->size() != kInputSizeThree) {
259 MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->size();
260 return RET_ERROR;
261 }
262 if (!CheckPrimitiveType(cnode->input(1), kPrimIdentity)) {
263 return lite::RET_NO_CHANGE;
264 }
265 auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>();
266 auto index_vnode = cnode->input(kInputIndexTwo);
267 if (!utils::isa<ValueNode>(index_vnode)) {
268 MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
269 return lite::RET_ERROR;
270 }
271 MS_CHECK_TRUE_MSG(!CastToInt(index_vnode->cast<ValueNodePtr>()->value()).empty(), RET_ERROR, "value is empty");
272 int index = CastToInt(index_vnode->cast<ValueNodePtr>()->value()).front();
273 int input_cnode_inputs_size = static_cast<int>(get_item_input_cnode->size());
274 if ((index + 1) >= input_cnode_inputs_size) {
275 MS_LOG(ERROR) << "value node index is out of range.";
276 return lite::RET_ERROR;
277 }
278 bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1));
279 if (!replace_succ) {
280 MS_LOG(ERROR) << "replace identity failed.";
281 return lite::RET_ERROR;
282 }
283 return lite::RET_OK;
284 }
285
RemoveDropoutOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)286 int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
287 MS_ASSERT(anf_node != nullptr);
288 MS_ASSERT(manager != nullptr);
289 if (!utils::isa<CNodePtr>(anf_node)) {
290 MS_LOG(DEBUG) << "anf node is node a cnode.";
291 return lite::RET_NO_CHANGE;
292 }
293 auto cnode = anf_node->cast<CNodePtr>();
294 MS_ASSERT(cnode != nullptr);
295 if (cnode->size() > kInputSizeTwo) {
296 MS_LOG(ERROR) << "dropout input invalid.";
297 return lite::RET_ERROR;
298 }
299 if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) {
300 MS_LOG(DEBUG) << "dropout output size is one.";
301 manager->Replace(anf_node, cnode->input(1));
302 } else {
303 auto node_users = manager->node_users()[anf_node];
304 for (auto &node_user : node_users) {
305 auto node = node_user.first;
306 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
307 MS_LOG(ERROR) << "dropout out node is invalid.";
308 return lite::RET_ERROR;
309 }
310 auto get_index_node = node->cast<CNodePtr>()->input(kInputIndexTwo)->cast<ValueNodePtr>();
311 if (get_index_node == nullptr) {
312 MS_LOG(ERROR) << "tuple get item node is invalid.";
313 return lite::RET_ERROR;
314 }
315 auto get_index = CastToInt(get_index_node->value()).front();
316 if (get_index > 0 && !manager->node_users()[node].empty()) {
317 MS_LOG(DEBUG) << "dropout's second output is useful.";
318 continue;
319 }
320 manager->Replace(node, cnode->input(1));
321 }
322 }
323 return lite::RET_OK;
324 }
325
GetConstDataFromInputNode(const CNodePtr & cnode,lite::DataInfo * data_info)326 int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
327 MS_ASSERT(cnode != nullptr);
328 MS_ASSERT(data_info != nullptr);
329 auto padding_node = cnode->input(kInputIndexTwo);
330 MS_ASSERT(padding_node != nullptr);
331 if (utils::isa<Parameter>(padding_node)) {
332 auto status = lite::FetchDataFromParameterNode(cnode, kIndexNum, converter::kFmkTypeMs, data_info, true);
333 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
334 MS_LOG(ERROR) << "fetch data from parameter node failed.";
335 return lite::RET_ERROR;
336 }
337 } else if (utils::isa<ValueNode>(padding_node)) {
338 auto status = lite::FetchDataFromValueNode(cnode, kIndexNum, converter::kFmkTypeMs, false, data_info, true);
339 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
340 MS_LOG(ERROR) << "fetch data from value node failed.";
341 return lite::RET_ERROR;
342 }
343 }
344 return lite::RET_OK;
345 }
346
RemoveInvalidPadOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)347 int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
348 if (!utils::isa<CNodePtr>(anf_node)) {
349 MS_LOG(DEBUG) << "anf node is node a cnode.";
350 return lite::RET_NO_CHANGE;
351 }
352 auto cnode = anf_node->cast<CNodePtr>();
353 MS_ASSERT(cnode != nullptr);
354 auto primitive = GetValueNode<mindspore::PrimitivePtr>(cnode->input(0));
355 if (primitive == nullptr) {
356 MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
357 return lite::RET_NO_CHANGE;
358 }
359 auto is_invalid = true;
360 if (cnode->size() > kInputSizeTwo) {
361 lite::DataInfo data_info;
362 if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
363 MS_LOG(ERROR) << "Get pad data failed.";
364 return lite::RET_ERROR;
365 }
366 if (!data_info.data_.empty()) {
367 auto pad_data = reinterpret_cast<int *>(data_info.data_.data());
368 size_t num = data_info.data_.size() / sizeof(int);
369 for (size_t i = 0; i < num; ++i) {
370 if (pad_data[i] != 0) {
371 is_invalid = false;
372 break;
373 }
374 }
375 } else {
376 is_invalid = false;
377 }
378 } else {
379 auto pad_prim = api::MakeShared<mindspore::ops::PadFusion>(primitive);
380 MS_CHECK_TRUE_RET(pad_prim != nullptr, lite::RET_ERROR);
381 MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPaddings) != nullptr, lite::RET_ERROR);
382 auto pad_data = pad_prim->get_paddings();
383 for (size_t i = 0; i < pad_data.size(); i++) {
384 for (size_t j = 0; j < pad_data[i].size(); j++) {
385 if (pad_data[i][j] != 0) {
386 is_invalid = false;
387 break;
388 }
389 }
390 if (is_invalid == false) {
391 break;
392 }
393 }
394 }
395 if (is_invalid) {
396 return ReplaceOp(anf_node, manager);
397 }
398 return lite::RET_OK;
399 }
400
RemoveInvalidTransposeOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)401 int RemoveRedundantOpPass::RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
402 auto cnode = anf_node->cast<CNodePtr>();
403 MS_ASSERT(cnode != nullptr);
404 if (cnode->size() != kInputSizeThree) {
405 MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
406 return lite::RET_NO_CHANGE;
407 }
408 auto index_node = cnode->inputs()[kInputIndexTwo]->cast<ParameterPtr>();
409 if (index_node == nullptr || !index_node->has_default()) {
410 return RET_OK;
411 }
412 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(index_node->default_param());
413 MS_ASSERT(tensor_info != nullptr);
414 if (tensor_info->Size() != 0) {
415 return RET_OK;
416 }
417 return ReplaceOp(anf_node, manager);
418 }
419
FlattenMakeTuple(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)420 int RemoveRedundantOpPass::FlattenMakeTuple(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
421 MS_ASSERT(func_graph != nullptr);
422 MS_ASSERT(manager != nullptr);
423 auto node_list = TopoSort(func_graph->get_return());
424 for (auto &node : node_list) {
425 auto cnode = node->cast<CNodePtr>();
426 if (!cnode) {
427 continue;
428 }
429 if (opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
430 std::vector<AnfNodePtr> new_inputs;
431 auto inputs = cnode->inputs();
432 new_inputs.push_back(inputs[0]);
433 bool has_make_tuple = false;
434 if (lite::GetFlattenInputsIfMakeTuple(cnode, &new_inputs, &has_make_tuple) != RET_OK) {
435 MS_LOG(WARNING) << "Failed to get flatten inputs of cnode, node " << cnode->fullname_with_scope();
436 continue;
437 }
438 if (has_make_tuple) {
439 auto new_cnode = func_graph->NewCNode(new_inputs);
440 MS_CHECK_TRUE_MSG(new_cnode != nullptr, RET_ERROR, "Failed to create New node.");
441 new_cnode->set_abstract(cnode->abstract());
442 new_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_flatten");
443 manager->Replace(cnode, new_cnode);
444 }
445 } else if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
446 auto real_node = opt::GetTupleGetItemRealInput(cnode);
447 if (!real_node) {
448 MS_LOG(WARNING) << "Failed to get tuple real input, node " << cnode->fullname_with_scope();
449 continue;
450 }
451 auto real_node_as_cnode = real_node->cast<CNodePtr>();
452 if (real_node_as_cnode && CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
453 auto idx = opt::GetTupleGetItemOutIndex(cnode);
454 manager->Replace(cnode, real_node_as_cnode->input(idx));
455 }
456 }
457 }
458 return RET_OK;
459 }
460
RemoveUmonad(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)461 int RemoveRedundantOpPass::RemoveUmonad(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
462 MS_ASSERT(func_graph != nullptr);
463 MS_ASSERT(manager != nullptr);
464 auto node_list = TopoSort(func_graph->get_return());
465 for (auto &node : node_list) {
466 auto cnode = node->cast<CNodePtr>();
467 if (!cnode) {
468 continue;
469 }
470 if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
471 continue;
472 }
473 if (cnode->size() < kDependInputSize) {
474 MS_LOG(ERROR) << "Depend input size " << cnode->size() << " cannot less than " << kDependInputSize;
475 continue;
476 }
477 auto depend_src = cnode->input(kIndex1);
478 auto depend_dst = cnode->input(kIndex2);
479 auto depend_dst_as_cnode = depend_dst->cast<CNodePtr>();
480 if (depend_dst_as_cnode && opt::CheckPrimitiveType(depend_dst_as_cnode, prim::kPrimUpdateState)) {
481 manager->Replace(cnode, depend_src);
482 }
483 }
484 return RET_OK;
485 }
486
RemoveRedundantOp(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,const AnfNodePtr & node)487 int RemoveRedundantOpPass::RemoveRedundantOp(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
488 const AnfNodePtr &node) {
489 int status = RET_OK;
490 if (CheckPrimitiveType(node, kPrimIdentity)) {
491 status = ReplaceOp(node, manager);
492 }
493 if (CheckPrimitiveType(node, prim::kPrimLoad)) {
494 status = ReplaceOp(node, manager);
495 }
496 if (CheckPrimitiveType(node, prim::kPrimTensorMove)) {
497 status = ReplaceOp(node, manager);
498 }
499 if (CheckPrimitiveType(node, prim::kPrimUpdateState) && !keep_update_state_) {
500 status = ReplaceUpdateStateOp(func_graph, node);
501 }
502 if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
503 status = ReplaceTupleGetItem(node, manager);
504 }
505 if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
506 status = RemoveDropoutOp(node, manager);
507 }
508 if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
509 status = RemoveInvalidPadOp(node, manager);
510 }
511 if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
512 status = RemoveInvalidTransposeOp(node, manager);
513 }
514 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
515 auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
516 if (sub_func_graph == nullptr) {
517 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
518 return lite::RET_NULL_PTR;
519 }
520 (void)Run(sub_func_graph);
521 sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
522 if (sub_func_graph == nullptr) {
523 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
524 return lite::RET_NULL_PTR;
525 }
526 (void)Run(sub_func_graph);
527 }
528 return status;
529 }
530
Run(const FuncGraphPtr & func_graph)531 bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
532 MS_ASSERT(func_graph != nullptr);
533 auto manager = Manage(func_graph, true);
534 MS_ASSERT(manager != nullptr);
535 if (!is_train_model_) {
536 auto ret = RemoveUmonad(func_graph, manager);
537 if (ret != lite::RET_OK) {
538 MS_LOG(ERROR) << "remove umonad.";
539 return false;
540 }
541 }
542
543 auto node_list = TopoSort(func_graph->get_return());
544 int status = RET_OK;
545 for (auto &node : node_list) {
546 if (!utils::isa<CNodePtr>(node)) {
547 continue;
548 }
549 status = RemoveRedundantOp(func_graph, manager, node);
550 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
551 MS_LOG(ERROR) << "remove identity pass is failed.";
552 return false;
553 }
554 }
555 for (auto &node : remove_cnode_) {
556 func_graph->DropNode(node);
557 }
558 FlattenMakeTuple(func_graph, manager);
559 remove_cnode_.clear();
560 return true;
561 }
562 } // namespace mindspore::opt
563