1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <cstdint>
#include <memory>
#include "tools/optimizer/graph/eliminate_redundant_cast_pass.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/graph/infershape_pass.h"
20
21 namespace mindspore::opt {
RemoveCastOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)22 int EliminateRedundantCastPass::RemoveCastOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
23 const int expected_cast_input_count = 3;
24 auto cast_cnode = anf_node->cast<CNodePtr>();
25 MS_CHECK_TRUE_RET(cast_cnode->size() == expected_cast_input_count, lite::RET_NO_CHANGE);
26 TypeId first_type;
27 TypeId second_type;
28 if (opt::GetDataTypeFromAnfNode(cast_cnode->input(1), &first_type) != RET_OK) {
29 MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
30 return lite::RET_NO_CHANGE;
31 }
32
33 auto dst_type_tensor = cast_cnode->input(2)->cast<ParameterPtr>();
34 MS_CHECK_TRUE_RET(dst_type_tensor != nullptr, lite::RET_NO_CHANGE);
35 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(dst_type_tensor->default_param());
36 MS_CHECK_TRUE_RET(tensor_info != nullptr, lite::RET_NO_CHANGE);
37 MS_CHECK_TRUE_RET(tensor_info->ElementsNum() == 1, lite::RET_NO_CHANGE);
38
39 second_type = static_cast<TypeId>(static_cast<int *>(tensor_info->data_c())[0]);
40 if (first_type == second_type) {
41 MS_LOG(DEBUG) << "Cast node " << anf_node->fullname_with_scope() << " is eliminated.";
42 (void)this->remove_cnode_.insert(anf_node);
43 return manager->Replace(anf_node, cast_cnode->input(1)) ? RET_OK : RET_ERROR;
44 }
45 return lite::RET_NO_CHANGE;
46 }
47
Run(const FuncGraphPtr & func_graph)48 bool EliminateRedundantCastPass::Run(const FuncGraphPtr &func_graph) {
49 MS_ASSERT(func_graph != nullptr);
50 auto infer_shape_pass = std::make_shared<InferShapePass>(this->fmk_type_, this->train_flag_);
51 if (!infer_shape_pass->Run(func_graph)) {
52 return true;
53 }
54 auto manager = func_graph->manager();
55 MS_CHECK_TRUE_RET(manager != nullptr, false);
56 auto node_list = TopoSort(func_graph->get_return());
57 int status = RET_OK;
58 for (auto &node : node_list) {
59 if (!utils::isa<CNodePtr>(node)) {
60 continue;
61 }
62 if (CheckPrimitiveType(node, prim::kPrimCast)) {
63 status = this->RemoveCastOp(node, manager);
64 }
65 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
66 MS_LOG(ERROR) << "Failed to run cast elimination pass.";
67 return false;
68 }
69 }
70 for (auto &node : this->remove_cnode_) {
71 func_graph->DropNode(node);
72 }
73 return true;
74 }
75 } // namespace mindspore::opt
76