• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <cstdint>
#include <memory>
#include "tools/optimizer/graph/eliminate_redundant_cast_pass.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/graph/infershape_pass.h"
20 
21 namespace mindspore::opt {
RemoveCastOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)22 int EliminateRedundantCastPass::RemoveCastOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
23   const int expected_cast_input_count = 3;
24   auto cast_cnode = anf_node->cast<CNodePtr>();
25   MS_CHECK_TRUE_RET(cast_cnode->size() == expected_cast_input_count, lite::RET_NO_CHANGE);
26   TypeId first_type;
27   TypeId second_type;
28   if (opt::GetDataTypeFromAnfNode(cast_cnode->input(1), &first_type) != RET_OK) {
29     MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
30     return lite::RET_NO_CHANGE;
31   }
32 
33   auto dst_type_tensor = cast_cnode->input(2)->cast<ParameterPtr>();
34   MS_CHECK_TRUE_RET(dst_type_tensor != nullptr, lite::RET_NO_CHANGE);
35   auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(dst_type_tensor->default_param());
36   MS_CHECK_TRUE_RET(tensor_info != nullptr, lite::RET_NO_CHANGE);
37   MS_CHECK_TRUE_RET(tensor_info->ElementsNum() == 1, lite::RET_NO_CHANGE);
38 
39   second_type = static_cast<TypeId>(static_cast<int *>(tensor_info->data_c())[0]);
40   if (first_type == second_type) {
41     MS_LOG(DEBUG) << "Cast node " << anf_node->fullname_with_scope() << " is eliminated.";
42     (void)this->remove_cnode_.insert(anf_node);
43     return manager->Replace(anf_node, cast_cnode->input(1)) ? RET_OK : RET_ERROR;
44   }
45   return lite::RET_NO_CHANGE;
46 }
47 
Run(const FuncGraphPtr & func_graph)48 bool EliminateRedundantCastPass::Run(const FuncGraphPtr &func_graph) {
49   MS_ASSERT(func_graph != nullptr);
50   auto infer_shape_pass = std::make_shared<InferShapePass>(this->fmk_type_, this->train_flag_);
51   if (!infer_shape_pass->Run(func_graph)) {
52     return true;
53   }
54   auto manager = func_graph->manager();
55   MS_CHECK_TRUE_RET(manager != nullptr, false);
56   auto node_list = TopoSort(func_graph->get_return());
57   int status = RET_OK;
58   for (auto &node : node_list) {
59     if (!utils::isa<CNodePtr>(node)) {
60       continue;
61     }
62     if (CheckPrimitiveType(node, prim::kPrimCast)) {
63       status = this->RemoveCastOp(node, manager);
64     }
65     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
66       MS_LOG(ERROR) << "Failed to run cast elimination pass.";
67       return false;
68     }
69   }
70   for (auto &node : this->remove_cnode_) {
71     func_graph->DropNode(node);
72   }
73   return true;
74 }
75 }  // namespace mindspore::opt
76