• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/int64_cast_int32_pass.h"
19 #include <vector>
20 #include <memory>
21 #include "mindspore/core/ops/comparison_ops.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "ops/op_utils.h"
25 #include "ops/auto_generate/gen_lite_ops.h"
26 #include "ops/sequence_len.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "src/tensor.h"
29 #include "src/common/log_adapter.h"
30 #include "nnacl/op_base.h"
31 #include "src/common/utils.h"
32 
33 namespace mindspore::opt {
34 namespace {
// Total CNode size (cnode->size()) that NotEqualInputsCheck asserts for the node it inspects; that function
// then reads inputs kInputIndexOne and kInputIndexTwo as the two operands to compare.
// NOTE(review): despite the "MinIndex" name this is used as an exact size, not as an index - presumably
// slot 0 holds the primitive and slots 1-2 hold the operands; confirm before reusing it elsewhere.
35 constexpr size_t kNotEqualMinIndex = 3;
36 }  // namespace
37 
NotEqualInputsCheck(const CNodePtr & cnode)38 bool Int64CastInt32Pass::NotEqualInputsCheck(const CNodePtr &cnode) {
39   MS_ASSERT(cnode->size() == kNotEqualMinIndex);
40   auto abstract0 = GetCNodeInputAbstract(cnode, kInputIndexOne);
41   if (abstract0 == nullptr) {
42     MS_LOG(ERROR) << "Abstract of CNode is nullptr";
43     return false;
44   }
45   if (!utils::isa<abstract::AbstractTensorPtr>(abstract0)) {
46     MS_LOG(DEBUG) << "abstract is not AbstractTensor";
47     return false;
48   }
49   auto abstract1 = GetCNodeInputAbstract(cnode, kInputIndexTwo);
50   if (abstract1 == nullptr) {
51     MS_LOG(ERROR) << "Abstract of CNode is nullptr";
52     return false;
53   }
54   if (!utils::isa<abstract::AbstractTensorPtr>(abstract1)) {
55     MS_LOG(DEBUG) << "abstract is not AbstractTensor";
56     return false;
57   }
58   auto abstract0_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract0);
59   MS_ASSERT(abstract0_tensor != nullptr && abstract0_tensor->shape() != nullptr);
60   auto type0_ptr = abstract0_tensor->element()->GetTypeTrack();
61   MS_CHECK_TRUE_MSG(type0_ptr != nullptr, false, "type_ptr is nullptr");
62   auto abstract1_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract1);
63   MS_ASSERT(abstract1_tensor != nullptr && abstract1_tensor->shape() != nullptr);
64   auto type1_ptr = abstract1_tensor->element()->GetTypeTrack();
65   MS_CHECK_TRUE_MSG(type1_ptr != nullptr, false, "type_ptr is nullptr");
66   if (type0_ptr->type_id() == type1_ptr->type_id()) {
67     return true;
68   }
69   return false;
70 }
71 
Run(const FuncGraphPtr & graph)72 bool Int64CastInt32Pass::Run(const FuncGraphPtr &graph) {
73   MS_ASSERT(graph != nullptr);
74   bool change_flag = false;
75   auto node_list = TopoSort(graph->get_return());
76   for (auto &node : node_list) {
77     MS_ASSERT(node != nullptr);
78     if (!utils::isa<CNode>(node)) {
79       continue;
80     }
81     if (CheckPrimitiveType(node, prim::kPrimCast) || CheckPrimitiveType(node, prim::kPrimSplit) ||
82         CheckPrimitiveType(node, prim::kPrimGather) || CheckPrimitiveType(node, prim::kPrimCustom) ||
83         CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
84       continue;
85     }
86     auto cnode = node->cast<CNodePtr>();
87     MS_ASSERT(cnode != nullptr);
88     auto inputs_size = cnode->size();
89     if (CheckPrimitiveType(node, prim::kPrimNotEqual)) {
90       if (NotEqualInputsCheck(cnode)) {
91         continue;
92       }
93     }
94     if (!IsRealCNodeKernel(cnode)) {
95       continue;
96     }
97 
98     for (size_t index = kInputIndexOne; index < inputs_size; index++) {
99       auto abstract = GetCNodeInputAbstract(cnode, index);
100       if (abstract == nullptr) {
101         MS_LOG(DEBUG) << "Cnode " << cnode->fullname_with_scope() << " input " << index << " abstract is nullptr";
102         continue;
103       }
104       if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
105         MS_LOG(DEBUG) << "abstract is not AbstractTensor";
106         continue;
107       }
108       auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
109       MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
110       auto type_ptr = abstract_tensor->element()->GetTypeTrack();
111       MS_CHECK_TRUE_MSG(type_ptr != nullptr, change_flag, "type_ptr is nullptr");
112       if (type_ptr->type_id() == mindspore::kNumberTypeInt64) {
113         auto new_cast = std::make_shared<mindspore::ops::Cast>();
114         MS_CHECK_TRUE_MSG(new_cast != nullptr, change_flag, "new_cast is nullptr");
115         auto new_cast_c = new_cast->GetPrim();
116         MS_CHECK_TRUE_MSG(new_cast_c != nullptr, change_flag, "new_cast_c is nullptr");
117         ValueNodePtr value_node = NewValueNode(new_cast_c);
118         MS_CHECK_TRUE_MSG(value_node != nullptr, change_flag, "NewValueNode Failed");
119 
120         auto param_node = opt::BuildIntValueParameterNode(
121           graph, static_cast<int32_t>(kNumberTypeInt32),
122           cnode->fullname_with_scope() + "_input" + std::to_string(index) + "_cast_type");
123 
124         auto cast_cnode = graph->NewCNode({value_node});
125         MS_CHECK_TRUE_MSG(cast_cnode != nullptr, change_flag, "new_cnode is nullptr");
126         cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_input" + std::to_string(index) +
127                                             "_pre_cast");
128         cast_cnode->set_abstract(abstract->Clone());
129         auto cast_abstract = cast_cnode->abstract();
130         MS_ASSERT(cast_abstract != nullptr);
131         cast_abstract->set_value(std::make_shared<ValueAny>());
132 
133         auto manager = Manage(graph);
134         auto input_node = cnode->input(index);
135         (void)manager->Replace(input_node, cast_cnode);
136         manager->AddEdge(cast_cnode, input_node);
137         manager->AddEdge(cast_cnode, param_node);
138 
139         change_flag = true;
140       }
141     }
142     if (change_flag) {
143       return change_flag;
144     }
145   }
146   return change_flag;
147 }
148 }  // namespace mindspore::opt
149