1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/int64_cast_int32_pass.h"
19 #include <vector>
20 #include <memory>
21 #include "mindspore/core/ops/comparison_ops.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "ops/op_utils.h"
25 #include "ops/auto_generate/gen_lite_ops.h"
26 #include "ops/sequence_len.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "src/tensor.h"
29 #include "src/common/log_adapter.h"
30 #include "nnacl/op_base.h"
31 #include "src/common/utils.h"
32
33 namespace mindspore::opt {
namespace {
// Number of CNode inputs that NotEqualInputsCheck() asserts for a NotEqual node
// (compared against cnode->size()).
constexpr size_t kNotEqualMinIndex{3};
}  // namespace
37
NotEqualInputsCheck(const CNodePtr & cnode)38 bool Int64CastInt32Pass::NotEqualInputsCheck(const CNodePtr &cnode) {
39 MS_ASSERT(cnode->size() == kNotEqualMinIndex);
40 auto abstract0 = GetCNodeInputAbstract(cnode, kInputIndexOne);
41 if (abstract0 == nullptr) {
42 MS_LOG(ERROR) << "Abstract of CNode is nullptr";
43 return false;
44 }
45 if (!utils::isa<abstract::AbstractTensorPtr>(abstract0)) {
46 MS_LOG(DEBUG) << "abstract is not AbstractTensor";
47 return false;
48 }
49 auto abstract1 = GetCNodeInputAbstract(cnode, kInputIndexTwo);
50 if (abstract1 == nullptr) {
51 MS_LOG(ERROR) << "Abstract of CNode is nullptr";
52 return false;
53 }
54 if (!utils::isa<abstract::AbstractTensorPtr>(abstract1)) {
55 MS_LOG(DEBUG) << "abstract is not AbstractTensor";
56 return false;
57 }
58 auto abstract0_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract0);
59 MS_ASSERT(abstract0_tensor != nullptr && abstract0_tensor->shape() != nullptr);
60 auto type0_ptr = abstract0_tensor->element()->GetTypeTrack();
61 MS_CHECK_TRUE_MSG(type0_ptr != nullptr, false, "type_ptr is nullptr");
62 auto abstract1_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract1);
63 MS_ASSERT(abstract1_tensor != nullptr && abstract1_tensor->shape() != nullptr);
64 auto type1_ptr = abstract1_tensor->element()->GetTypeTrack();
65 MS_CHECK_TRUE_MSG(type1_ptr != nullptr, false, "type_ptr is nullptr");
66 if (type0_ptr->type_id() == type1_ptr->type_id()) {
67 return true;
68 }
69 return false;
70 }
71
Run(const FuncGraphPtr & graph)72 bool Int64CastInt32Pass::Run(const FuncGraphPtr &graph) {
73 MS_ASSERT(graph != nullptr);
74 bool change_flag = false;
75 auto node_list = TopoSort(graph->get_return());
76 for (auto &node : node_list) {
77 MS_ASSERT(node != nullptr);
78 if (!utils::isa<CNode>(node)) {
79 continue;
80 }
81 if (CheckPrimitiveType(node, prim::kPrimCast) || CheckPrimitiveType(node, prim::kPrimSplit) ||
82 CheckPrimitiveType(node, prim::kPrimGather) || CheckPrimitiveType(node, prim::kPrimCustom) ||
83 CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
84 continue;
85 }
86 auto cnode = node->cast<CNodePtr>();
87 MS_ASSERT(cnode != nullptr);
88 auto inputs_size = cnode->size();
89 if (CheckPrimitiveType(node, prim::kPrimNotEqual)) {
90 if (NotEqualInputsCheck(cnode)) {
91 continue;
92 }
93 }
94 if (!IsRealCNodeKernel(cnode)) {
95 continue;
96 }
97
98 for (size_t index = kInputIndexOne; index < inputs_size; index++) {
99 auto abstract = GetCNodeInputAbstract(cnode, index);
100 if (abstract == nullptr) {
101 MS_LOG(DEBUG) << "Cnode " << cnode->fullname_with_scope() << " input " << index << " abstract is nullptr";
102 continue;
103 }
104 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
105 MS_LOG(DEBUG) << "abstract is not AbstractTensor";
106 continue;
107 }
108 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
109 MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
110 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
111 MS_CHECK_TRUE_MSG(type_ptr != nullptr, change_flag, "type_ptr is nullptr");
112 if (type_ptr->type_id() == mindspore::kNumberTypeInt64) {
113 auto new_cast = std::make_shared<mindspore::ops::Cast>();
114 MS_CHECK_TRUE_MSG(new_cast != nullptr, change_flag, "new_cast is nullptr");
115 auto new_cast_c = new_cast->GetPrim();
116 MS_CHECK_TRUE_MSG(new_cast_c != nullptr, change_flag, "new_cast_c is nullptr");
117 ValueNodePtr value_node = NewValueNode(new_cast_c);
118 MS_CHECK_TRUE_MSG(value_node != nullptr, change_flag, "NewValueNode Failed");
119
120 auto param_node = opt::BuildIntValueParameterNode(
121 graph, static_cast<int32_t>(kNumberTypeInt32),
122 cnode->fullname_with_scope() + "_input" + std::to_string(index) + "_cast_type");
123
124 auto cast_cnode = graph->NewCNode({value_node});
125 MS_CHECK_TRUE_MSG(cast_cnode != nullptr, change_flag, "new_cnode is nullptr");
126 cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_input" + std::to_string(index) +
127 "_pre_cast");
128 cast_cnode->set_abstract(abstract->Clone());
129 auto cast_abstract = cast_cnode->abstract();
130 MS_ASSERT(cast_abstract != nullptr);
131 cast_abstract->set_value(std::make_shared<ValueAny>());
132
133 auto manager = Manage(graph);
134 auto input_node = cnode->input(index);
135 (void)manager->Replace(input_node, cast_cnode);
136 manager->AddEdge(cast_cnode, input_node);
137 manager->AddEdge(cast_cnode, param_node);
138
139 change_flag = true;
140 }
141 }
142 if (change_flag) {
143 return change_flag;
144 }
145 }
146 return change_flag;
147 }
148 } // namespace mindspore::opt
149