1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/cast_fusion.h"
19 #include <unordered_map>
20 #include <memory>
21 #include <vector>
22 #include <set>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/comparison_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "tools/converter/quantizer/quant_param_holder.h"
27 #include "tools/optimizer/common/format_utils.h"
28 #include "nnacl/op_base.h"
29 #include "tools/lite_exporter/fetch_content.h"
30
31 namespace mindspore::opt {
32 namespace {
IsGoodCastSplitFusion(const FuncGraphPtr & func_graph,const CNodePtr & split_cnode_2)33 bool IsGoodCastSplitFusion(const FuncGraphPtr &func_graph, const CNodePtr &split_cnode_2) {
34 auto manager = func_graph->manager();
35 MS_ASSERT(manager != nullptr);
36 MS_ASSERT(split_cnode_2 != nullptr);
37 auto node_users = manager->node_users();
38 auto split_node_users = node_users[split_cnode_2];
39 for (auto &node_user : split_node_users) {
40 auto post_node = node_user.first;
41 if (opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
42 auto post_item_nodes = node_users[post_node];
43 for (auto &post_item : post_item_nodes) {
44 auto post_item_node = post_item.first;
45 if (opt::CheckPrimitiveType(post_item_node, prim::kPrimGather) && post_item.second == kInputIndexTwo) {
46 continue;
47 }
48 if (opt::CheckPrimitiveType(post_item_node, prim::kPrimCast)) {
49 int post_item_cast_type;
50 if (GetCastDstDataType(post_item_node->cast<CNodePtr>(), &post_item_cast_type) != lite::RET_OK) {
51 MS_LOG(ERROR) << "get cast dst type failed.";
52 return false;
53 }
54 if (post_item_cast_type == kNumberTypeInt32) {
55 continue;
56 }
57 }
58 return false;
59 }
60 } else {
61 return false;
62 }
63 }
64 return true;
65 }
// Wildcard predicate for pattern variables: accepts every node unconditionally.
bool IsAnyNode(const BaseRef & /* n */) { return true; }
67 } // namespace
68
DefineCastCastPattern() const69 VectorRef CastFusionPass::DefineCastCastPattern() const {
70 auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
71 MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
72 auto is_cast2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
73 MS_CHECK_TRUE_RET(is_cast2 != nullptr, {});
74 auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
75 MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
76 VectorRef cast_cast_ref = VectorRef({is_cast2, is_cast1, is_weight_param});
77 return cast_cast_ref;
78 }
79
DefineCastGatherPattern() const80 VectorRef CastFusionPass::DefineCastGatherPattern() const {
81 auto is_any0 = std::make_shared<CondVar>(IsAnyNode);
82 MS_CHECK_TRUE_RET(is_any0 != nullptr, {});
83 auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
84 MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
85 auto is_gather2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimGather>);
86 MS_CHECK_TRUE_RET(is_gather2 != nullptr, {});
87 auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
88 MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
89 VectorRef cast_cast_ref = VectorRef({is_gather2, is_any0, is_cast1, is_weight_param});
90 return cast_cast_ref;
91 }
92
DefineCastEqualPattern() const93 VectorRef CastFusionPass::DefineCastEqualPattern() const {
94 auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
95 MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
96 auto is_notequal2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimNotEqual>);
97 MS_CHECK_TRUE_RET(is_notequal2 != nullptr, {});
98 auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
99 MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
100 VectorRef cast_cast_ref = VectorRef({is_notequal2, is_cast1, is_weight_param});
101 return cast_cast_ref;
102 }
103
DefineCastEqual2Pattern() const104 VectorRef CastFusionPass::DefineCastEqual2Pattern() const {
105 auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
106 MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
107 auto is_notequal2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimNotEqual>);
108 MS_CHECK_TRUE_RET(is_notequal2 != nullptr, {});
109 auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
110 MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
111 VectorRef cast_cast_ref = VectorRef({is_notequal2, is_weight_param, is_cast1});
112 return cast_cast_ref;
113 }
114
DefineCastSplitPattern() const115 VectorRef CastFusionPass::DefineCastSplitPattern() const {
116 auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
117 MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
118 auto is_split2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSplit>);
119 MS_CHECK_TRUE_RET(is_split2 != nullptr, {});
120 VectorRef cast_cast_ref = VectorRef({is_split2, is_cast1});
121 return cast_cast_ref;
122 }
123
DefinePatterns() const124 std::unordered_map<std::string, VectorRef> CastFusionPass::DefinePatterns() const {
125 std::unordered_map<std::string, VectorRef> patterns;
126 patterns["CastCastPatternName"] = DefineCastCastPattern();
127 patterns["CastGatherPatternName"] = DefineCastGatherPattern();
128 patterns["CastNotEqualPatternName"] = DefineCastEqualPattern();
129 patterns["CastNotEqual2PatternName"] = DefineCastEqual2Pattern();
130 patterns["CastSplitPatternName"] = DefineCastSplitPattern();
131 return patterns;
132 }
133
CastCastFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const134 AnfNodePtr CastFusionPass::CastCastFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
135 MS_ASSERT(func_graph != nullptr && node != nullptr);
136 auto cast_cnode_2 = node->cast<CNodePtr>();
137 MS_ASSERT(cast_cnode_2 != nullptr);
138 if (IsMarkedTrainOp(cast_cnode_2)) {
139 return nullptr;
140 }
141 MS_CHECK_TRUE_RET(cast_cnode_2 != nullptr, nullptr);
142 MS_CHECK_TRUE_RET(cast_cnode_2->size() == kInputSizeThree, nullptr);
143 if (!CheckPrimitiveType(cast_cnode_2, prim::kPrimCast) ||
144 !CheckPrimitiveType(cast_cnode_2->input(1), prim::kPrimCast)) {
145 return nullptr;
146 }
147 int post_cast_type;
148 if (GetCastDstDataType(cast_cnode_2, &post_cast_type) != lite::RET_OK) {
149 MS_LOG(ERROR) << "get cast dst type failed.";
150 return nullptr;
151 }
152 int pre_cast_type;
153 auto pre_node = cast_cnode_2->input(1);
154 MS_CHECK_TRUE_RET(pre_node != nullptr, nullptr);
155 auto pre_cnode = pre_node->cast<CNodePtr>();
156 if (pre_cnode == nullptr) {
157 return nullptr;
158 }
159 if (IsMarkedTrainOp(pre_cnode)) {
160 return nullptr;
161 }
162 if (GetCastDstDataType(pre_cnode, &pre_cast_type) != lite::RET_OK) {
163 MS_LOG(ERROR) << "get cast dst type failed.";
164 return nullptr;
165 }
166 auto pre_node_input = pre_cnode->input(1);
167 MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
168 TypeId input_data_type;
169 if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
170 MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
171 return nullptr;
172 }
173
174 if (static_cast<int>(input_data_type) == post_cast_type) {
175 return pre_cnode->input(1);
176 }
177 auto manager = func_graph->manager();
178 MS_ASSERT(manager != nullptr);
179 manager->SetEdge(cast_cnode_2, 1, pre_cnode->input(1));
180 return nullptr;
181 }
182
CastGatherFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const183 AnfNodePtr CastFusionPass::CastGatherFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
184 MS_ASSERT(func_graph != nullptr && node != nullptr);
185 auto gather_cnode_2 = node->cast<CNodePtr>();
186 MS_ASSERT(gather_cnode_2 != nullptr);
187 if (IsMarkedTrainOp(gather_cnode_2)) {
188 return nullptr;
189 }
190 MS_CHECK_TRUE_RET(gather_cnode_2 != nullptr, nullptr);
191 MS_CHECK_TRUE_RET(gather_cnode_2->size() == kInputSizeFour, nullptr);
192 if (!CheckPrimitiveType(gather_cnode_2, prim::kPrimGather) ||
193 !CheckPrimitiveType(gather_cnode_2->input(kInputIndexTwo), prim::kPrimCast)) {
194 return nullptr;
195 }
196 auto pre_cast_cnode = gather_cnode_2->input(kInputIndexTwo)->cast<CNodePtr>();
197 MS_ASSERT(pre_cast_cnode != nullptr);
198 int post_cast_type;
199 if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
200 MS_LOG(ERROR) << "get cast dst type failed.";
201 return nullptr;
202 }
203 auto pre_node_input = pre_cast_cnode->input(1);
204 MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
205 TypeId input_data_type;
206 if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
207 MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
208 return nullptr;
209 }
210 const std::set<TypeId> support_dtype = {kNumberTypeInt64, kNumberTypeInt32, kNumberTypeBool};
211 if (support_dtype.find(input_data_type) != support_dtype.end() &&
212 support_dtype.find(static_cast<TypeId>(post_cast_type)) != support_dtype.end() &&
213 (static_cast<TypeId>(post_cast_type) != kNumberTypeBool || input_data_type == kNumberTypeBool)) {
214 auto manager = func_graph->manager();
215 MS_ASSERT(manager != nullptr);
216 manager->SetEdge(gather_cnode_2, kInputIndexTwo, pre_node_input);
217 return nullptr;
218 }
219 return nullptr;
220 }
221
CastSplitFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const222 AnfNodePtr CastFusionPass::CastSplitFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
223 MS_ASSERT(func_graph != nullptr && node != nullptr);
224 auto split_cnode_2 = node->cast<CNodePtr>();
225 MS_ASSERT(split_cnode_2 != nullptr);
226 if (IsMarkedTrainOp(split_cnode_2)) {
227 return nullptr;
228 }
229 MS_CHECK_TRUE_RET(split_cnode_2 != nullptr, nullptr);
230 MS_CHECK_TRUE_RET(split_cnode_2->size() == kInputSizeTwo, nullptr);
231 if (!CheckPrimitiveType(split_cnode_2, prim::kPrimSplit) ||
232 !CheckPrimitiveType(split_cnode_2->input(kInputIndexOne), prim::kPrimCast)) {
233 return nullptr;
234 }
235 auto pre_cast_cnode = split_cnode_2->input(kInputIndexOne)->cast<CNodePtr>();
236 MS_ASSERT(pre_cast_cnode != nullptr);
237 int post_cast_type;
238 if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
239 MS_LOG(ERROR) << "get cast dst type failed.";
240 return nullptr;
241 }
242 auto pre_node_input = pre_cast_cnode->input(kInputIndexOne);
243 MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
244 TypeId input_data_type;
245 if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
246 MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
247 return nullptr;
248 }
249 if (input_data_type == kNumberTypeInt32 && post_cast_type == kNumberTypeInt64) {
250 if (IsGoodCastSplitFusion(func_graph, split_cnode_2)) {
251 auto manager = func_graph->manager();
252 MS_ASSERT(manager != nullptr);
253 manager->SetEdge(split_cnode_2, kInputIndexOne, pre_node_input);
254 }
255 }
256 return nullptr;
257 }
258
CastNotEqualFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const259 AnfNodePtr CastFusionPass::CastNotEqualFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
260 MS_ASSERT(func_graph != nullptr && node != nullptr);
261 auto not_equal_cnode_2 = node->cast<CNodePtr>();
262 MS_ASSERT(not_equal_cnode_2 != nullptr);
263 if (IsMarkedTrainOp(not_equal_cnode_2)) {
264 return nullptr;
265 }
266 MS_CHECK_TRUE_RET(not_equal_cnode_2 != nullptr, nullptr);
267 MS_CHECK_TRUE_RET(not_equal_cnode_2->size() == kInputSizeThree, nullptr);
268 if (!CheckPrimitiveType(not_equal_cnode_2, prim::kPrimNotEqual)) {
269 return nullptr;
270 }
271 CNodePtr pre_cast_cnode;
272 ParameterPtr param_node;
273 lite::DataInfo data_info;
274 int status;
275 int cast_index;
276 if (CheckPrimitiveType(not_equal_cnode_2->input(kInputIndexOne), prim::kPrimCast)) {
277 cast_index = kInputIndexOne;
278 pre_cast_cnode = not_equal_cnode_2->input(kInputIndexOne)->cast<CNodePtr>();
279 MS_CHECK_TRUE_RET(IsParamNode(not_equal_cnode_2->input(kInputIndexTwo)), nullptr);
280 param_node = not_equal_cnode_2->input(kInputIndexTwo)->cast<ParameterPtr>();
281 status =
282 lite::FetchDataFromParameterNode(not_equal_cnode_2, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
283 } else if (CheckPrimitiveType(not_equal_cnode_2->input(kInputIndexTwo), prim::kPrimCast)) {
284 cast_index = kInputIndexTwo;
285 pre_cast_cnode = not_equal_cnode_2->input(kInputIndexTwo)->cast<CNodePtr>();
286 MS_CHECK_TRUE_RET(IsParamNode(not_equal_cnode_2->input(kInputIndexOne)), nullptr);
287 param_node = not_equal_cnode_2->input(kInputIndexOne)->cast<ParameterPtr>();
288 status =
289 lite::FetchDataFromParameterNode(not_equal_cnode_2, kInputIndexOne, converter::kFmkTypeMs, &data_info, true);
290 } else {
291 return nullptr;
292 }
293 if (status != lite::RET_OK) {
294 MS_LOG(ERROR) << "fetch transpose perm data failed.";
295 return nullptr;
296 }
297 int post_cast_type;
298 if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
299 MS_LOG(ERROR) << "get cast dst type failed.";
300 return nullptr;
301 }
302 if (post_cast_type != data_info.data_type_) {
303 return nullptr;
304 }
305 if (data_info.data_type_ == kNumberTypeInt64) {
306 if (data_info.data_.size() < sizeof(int32_t)) {
307 MS_LOG(ERROR) << "Data and datatype of data-info not match.";
308 return nullptr;
309 }
310 auto p_data = reinterpret_cast<int64_t *>(data_info.data_.data());
311 for (size_t i = 0; i < (data_info.data_.size() / sizeof(int64_t)); i++) {
312 if ((p_data[i] > INT32_MAX) || (p_data[i] < INT_MIN)) {
313 return nullptr;
314 }
315 }
316 auto abstract = param_node->abstract();
317 MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
318 auto new_abstract = abstract->Clone();
319 new_abstract->set_value(std::make_shared<ValueAny>());
320 if (GenCastNode(func_graph, param_node, param_node->fullname_with_scope() + "_post_cast",
321 static_cast<TypeId>(kNumberTypeInt32), new_abstract) == nullptr) {
322 MS_LOG(ERROR) << "GenCastNode failed.";
323 return nullptr;
324 }
325 auto manager = func_graph->manager();
326 MS_ASSERT(manager != nullptr);
327 manager->SetEdge(not_equal_cnode_2, cast_index, pre_cast_cnode->input(kInputIndexOne));
328 }
329
330 return nullptr;
331 }
332
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const333 AnfNodePtr CastFusionPass::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
334 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
335 if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
336 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
337 return nullptr;
338 }
339
340 if (pattern_name == "CastCastPatternName") {
341 return CastCastFusion(func_graph, node);
342 } else if (pattern_name == "CastGatherPatternName") {
343 return CastGatherFusion(func_graph, node);
344 } else if (pattern_name == "CastSplitPatternName") {
345 return CastSplitFusion(func_graph, node);
346 } else if (pattern_name == "CastNotEqualPatternName") {
347 return CastNotEqualFusion(func_graph, node);
348 } else if (pattern_name == "CastNotEqual2PatternName") {
349 return CastNotEqualFusion(func_graph, node);
350 }
351 return nullptr;
352 }
353 } // namespace mindspore::opt
354