• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/cast_fusion.h"
19 #include <unordered_map>
20 #include <memory>
21 #include <vector>
22 #include <set>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/comparison_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "tools/converter/quantizer/quant_param_holder.h"
27 #include "tools/optimizer/common/format_utils.h"
28 #include "nnacl/op_base.h"
29 #include "tools/lite_exporter/fetch_content.h"
30 
31 namespace mindspore::opt {
32 namespace {
IsGoodCastSplitFusion(const FuncGraphPtr & func_graph,const CNodePtr & split_cnode_2)33 bool IsGoodCastSplitFusion(const FuncGraphPtr &func_graph, const CNodePtr &split_cnode_2) {
34   auto manager = func_graph->manager();
35   MS_ASSERT(manager != nullptr);
36   MS_ASSERT(split_cnode_2 != nullptr);
37   auto node_users = manager->node_users();
38   auto split_node_users = node_users[split_cnode_2];
39   for (auto &node_user : split_node_users) {
40     auto post_node = node_user.first;
41     if (opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
42       auto post_item_nodes = node_users[post_node];
43       for (auto &post_item : post_item_nodes) {
44         auto post_item_node = post_item.first;
45         if (opt::CheckPrimitiveType(post_item_node, prim::kPrimGather) && post_item.second == kInputIndexTwo) {
46           continue;
47         }
48         if (opt::CheckPrimitiveType(post_item_node, prim::kPrimCast)) {
49           int post_item_cast_type;
50           if (GetCastDstDataType(post_item_node->cast<CNodePtr>(), &post_item_cast_type) != lite::RET_OK) {
51             MS_LOG(ERROR) << "get cast dst type failed.";
52             return false;
53           }
54           if (post_item_cast_type == kNumberTypeInt32) {
55             continue;
56           }
57         }
58         return false;
59       }
60     } else {
61       return false;
62     }
63   }
64   return true;
65 }
// Wildcard pattern predicate: accepts every node unconditionally.
bool IsAnyNode(const BaseRef & /* n */) { return true; }
67 }  // namespace
68 
DefineCastCastPattern() const69 VectorRef CastFusionPass::DefineCastCastPattern() const {
70   auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
71   MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
72   auto is_cast2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
73   MS_CHECK_TRUE_RET(is_cast2 != nullptr, {});
74   auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
75   MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
76   VectorRef cast_cast_ref = VectorRef({is_cast2, is_cast1, is_weight_param});
77   return cast_cast_ref;
78 }
79 
DefineCastGatherPattern() const80 VectorRef CastFusionPass::DefineCastGatherPattern() const {
81   auto is_any0 = std::make_shared<CondVar>(IsAnyNode);
82   MS_CHECK_TRUE_RET(is_any0 != nullptr, {});
83   auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
84   MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
85   auto is_gather2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimGather>);
86   MS_CHECK_TRUE_RET(is_gather2 != nullptr, {});
87   auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
88   MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
89   VectorRef cast_cast_ref = VectorRef({is_gather2, is_any0, is_cast1, is_weight_param});
90   return cast_cast_ref;
91 }
92 
DefineCastEqualPattern() const93 VectorRef CastFusionPass::DefineCastEqualPattern() const {
94   auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
95   MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
96   auto is_notequal2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimNotEqual>);
97   MS_CHECK_TRUE_RET(is_notequal2 != nullptr, {});
98   auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
99   MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
100   VectorRef cast_cast_ref = VectorRef({is_notequal2, is_cast1, is_weight_param});
101   return cast_cast_ref;
102 }
103 
DefineCastEqual2Pattern() const104 VectorRef CastFusionPass::DefineCastEqual2Pattern() const {
105   auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
106   MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
107   auto is_notequal2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimNotEqual>);
108   MS_CHECK_TRUE_RET(is_notequal2 != nullptr, {});
109   auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
110   MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
111   VectorRef cast_cast_ref = VectorRef({is_notequal2, is_weight_param, is_cast1});
112   return cast_cast_ref;
113 }
114 
DefineCastSplitPattern() const115 VectorRef CastFusionPass::DefineCastSplitPattern() const {
116   auto is_cast1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
117   MS_CHECK_TRUE_RET(is_cast1 != nullptr, {});
118   auto is_split2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSplit>);
119   MS_CHECK_TRUE_RET(is_split2 != nullptr, {});
120   VectorRef cast_cast_ref = VectorRef({is_split2, is_cast1});
121   return cast_cast_ref;
122 }
123 
DefinePatterns() const124 std::unordered_map<std::string, VectorRef> CastFusionPass::DefinePatterns() const {
125   std::unordered_map<std::string, VectorRef> patterns;
126   patterns["CastCastPatternName"] = DefineCastCastPattern();
127   patterns["CastGatherPatternName"] = DefineCastGatherPattern();
128   patterns["CastNotEqualPatternName"] = DefineCastEqualPattern();
129   patterns["CastNotEqual2PatternName"] = DefineCastEqual2Pattern();
130   patterns["CastSplitPatternName"] = DefineCastSplitPattern();
131   return patterns;
132 }
133 
CastCastFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const134 AnfNodePtr CastFusionPass::CastCastFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
135   MS_ASSERT(func_graph != nullptr && node != nullptr);
136   auto cast_cnode_2 = node->cast<CNodePtr>();
137   MS_ASSERT(cast_cnode_2 != nullptr);
138   if (IsMarkedTrainOp(cast_cnode_2)) {
139     return nullptr;
140   }
141   MS_CHECK_TRUE_RET(cast_cnode_2 != nullptr, nullptr);
142   MS_CHECK_TRUE_RET(cast_cnode_2->size() == kInputSizeThree, nullptr);
143   if (!CheckPrimitiveType(cast_cnode_2, prim::kPrimCast) ||
144       !CheckPrimitiveType(cast_cnode_2->input(1), prim::kPrimCast)) {
145     return nullptr;
146   }
147   int post_cast_type;
148   if (GetCastDstDataType(cast_cnode_2, &post_cast_type) != lite::RET_OK) {
149     MS_LOG(ERROR) << "get cast dst type failed.";
150     return nullptr;
151   }
152   int pre_cast_type;
153   auto pre_node = cast_cnode_2->input(1);
154   MS_CHECK_TRUE_RET(pre_node != nullptr, nullptr);
155   auto pre_cnode = pre_node->cast<CNodePtr>();
156   if (pre_cnode == nullptr) {
157     return nullptr;
158   }
159   if (IsMarkedTrainOp(pre_cnode)) {
160     return nullptr;
161   }
162   if (GetCastDstDataType(pre_cnode, &pre_cast_type) != lite::RET_OK) {
163     MS_LOG(ERROR) << "get cast dst type failed.";
164     return nullptr;
165   }
166   auto pre_node_input = pre_cnode->input(1);
167   MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
168   TypeId input_data_type;
169   if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
170     MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
171     return nullptr;
172   }
173 
174   if (static_cast<int>(input_data_type) == post_cast_type) {
175     return pre_cnode->input(1);
176   }
177   auto manager = func_graph->manager();
178   MS_ASSERT(manager != nullptr);
179   manager->SetEdge(cast_cnode_2, 1, pre_cnode->input(1));
180   return nullptr;
181 }
182 
CastGatherFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const183 AnfNodePtr CastFusionPass::CastGatherFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
184   MS_ASSERT(func_graph != nullptr && node != nullptr);
185   auto gather_cnode_2 = node->cast<CNodePtr>();
186   MS_ASSERT(gather_cnode_2 != nullptr);
187   if (IsMarkedTrainOp(gather_cnode_2)) {
188     return nullptr;
189   }
190   MS_CHECK_TRUE_RET(gather_cnode_2 != nullptr, nullptr);
191   MS_CHECK_TRUE_RET(gather_cnode_2->size() == kInputSizeFour, nullptr);
192   if (!CheckPrimitiveType(gather_cnode_2, prim::kPrimGather) ||
193       !CheckPrimitiveType(gather_cnode_2->input(kInputIndexTwo), prim::kPrimCast)) {
194     return nullptr;
195   }
196   auto pre_cast_cnode = gather_cnode_2->input(kInputIndexTwo)->cast<CNodePtr>();
197   MS_ASSERT(pre_cast_cnode != nullptr);
198   int post_cast_type;
199   if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
200     MS_LOG(ERROR) << "get cast dst type failed.";
201     return nullptr;
202   }
203   auto pre_node_input = pre_cast_cnode->input(1);
204   MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
205   TypeId input_data_type;
206   if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
207     MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
208     return nullptr;
209   }
210   const std::set<TypeId> support_dtype = {kNumberTypeInt64, kNumberTypeInt32, kNumberTypeBool};
211   if (support_dtype.find(input_data_type) != support_dtype.end() &&
212       support_dtype.find(static_cast<TypeId>(post_cast_type)) != support_dtype.end() &&
213       (static_cast<TypeId>(post_cast_type) != kNumberTypeBool || input_data_type == kNumberTypeBool)) {
214     auto manager = func_graph->manager();
215     MS_ASSERT(manager != nullptr);
216     manager->SetEdge(gather_cnode_2, kInputIndexTwo, pre_node_input);
217     return nullptr;
218   }
219   return nullptr;
220 }
221 
CastSplitFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const222 AnfNodePtr CastFusionPass::CastSplitFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
223   MS_ASSERT(func_graph != nullptr && node != nullptr);
224   auto split_cnode_2 = node->cast<CNodePtr>();
225   MS_ASSERT(split_cnode_2 != nullptr);
226   if (IsMarkedTrainOp(split_cnode_2)) {
227     return nullptr;
228   }
229   MS_CHECK_TRUE_RET(split_cnode_2 != nullptr, nullptr);
230   MS_CHECK_TRUE_RET(split_cnode_2->size() == kInputSizeTwo, nullptr);
231   if (!CheckPrimitiveType(split_cnode_2, prim::kPrimSplit) ||
232       !CheckPrimitiveType(split_cnode_2->input(kInputIndexOne), prim::kPrimCast)) {
233     return nullptr;
234   }
235   auto pre_cast_cnode = split_cnode_2->input(kInputIndexOne)->cast<CNodePtr>();
236   MS_ASSERT(pre_cast_cnode != nullptr);
237   int post_cast_type;
238   if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
239     MS_LOG(ERROR) << "get cast dst type failed.";
240     return nullptr;
241   }
242   auto pre_node_input = pre_cast_cnode->input(kInputIndexOne);
243   MS_CHECK_TRUE_RET(pre_node_input != nullptr, nullptr);
244   TypeId input_data_type;
245   if (GetDataTypeFromAnfNode(pre_node_input, &input_data_type) != RET_OK) {
246     MS_LOG(ERROR) << "get input node data type failed." << pre_node_input->fullname_with_scope();
247     return nullptr;
248   }
249   if (input_data_type == kNumberTypeInt32 && post_cast_type == kNumberTypeInt64) {
250     if (IsGoodCastSplitFusion(func_graph, split_cnode_2)) {
251       auto manager = func_graph->manager();
252       MS_ASSERT(manager != nullptr);
253       manager->SetEdge(split_cnode_2, kInputIndexOne, pre_node_input);
254     }
255   }
256   return nullptr;
257 }
258 
CastNotEqualFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const259 AnfNodePtr CastFusionPass::CastNotEqualFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
260   MS_ASSERT(func_graph != nullptr && node != nullptr);
261   auto not_equal_cnode_2 = node->cast<CNodePtr>();
262   MS_ASSERT(not_equal_cnode_2 != nullptr);
263   if (IsMarkedTrainOp(not_equal_cnode_2)) {
264     return nullptr;
265   }
266   MS_CHECK_TRUE_RET(not_equal_cnode_2 != nullptr, nullptr);
267   MS_CHECK_TRUE_RET(not_equal_cnode_2->size() == kInputSizeThree, nullptr);
268   if (!CheckPrimitiveType(not_equal_cnode_2, prim::kPrimNotEqual)) {
269     return nullptr;
270   }
271   CNodePtr pre_cast_cnode;
272   ParameterPtr param_node;
273   lite::DataInfo data_info;
274   int status;
275   int cast_index;
276   if (CheckPrimitiveType(not_equal_cnode_2->input(kInputIndexOne), prim::kPrimCast)) {
277     cast_index = kInputIndexOne;
278     pre_cast_cnode = not_equal_cnode_2->input(kInputIndexOne)->cast<CNodePtr>();
279     MS_CHECK_TRUE_RET(IsParamNode(not_equal_cnode_2->input(kInputIndexTwo)), nullptr);
280     param_node = not_equal_cnode_2->input(kInputIndexTwo)->cast<ParameterPtr>();
281     status =
282       lite::FetchDataFromParameterNode(not_equal_cnode_2, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
283   } else if (CheckPrimitiveType(not_equal_cnode_2->input(kInputIndexTwo), prim::kPrimCast)) {
284     cast_index = kInputIndexTwo;
285     pre_cast_cnode = not_equal_cnode_2->input(kInputIndexTwo)->cast<CNodePtr>();
286     MS_CHECK_TRUE_RET(IsParamNode(not_equal_cnode_2->input(kInputIndexOne)), nullptr);
287     param_node = not_equal_cnode_2->input(kInputIndexOne)->cast<ParameterPtr>();
288     status =
289       lite::FetchDataFromParameterNode(not_equal_cnode_2, kInputIndexOne, converter::kFmkTypeMs, &data_info, true);
290   } else {
291     return nullptr;
292   }
293   if (status != lite::RET_OK) {
294     MS_LOG(ERROR) << "fetch transpose perm data failed.";
295     return nullptr;
296   }
297   int post_cast_type;
298   if (GetCastDstDataType(pre_cast_cnode, &post_cast_type) != lite::RET_OK) {
299     MS_LOG(ERROR) << "get cast dst type failed.";
300     return nullptr;
301   }
302   if (post_cast_type != data_info.data_type_) {
303     return nullptr;
304   }
305   if (data_info.data_type_ == kNumberTypeInt64) {
306     if (data_info.data_.size() < sizeof(int32_t)) {
307       MS_LOG(ERROR) << "Data and datatype of data-info not match.";
308       return nullptr;
309     }
310     auto p_data = reinterpret_cast<int64_t *>(data_info.data_.data());
311     for (size_t i = 0; i < (data_info.data_.size() / sizeof(int64_t)); i++) {
312       if ((p_data[i] > INT32_MAX) || (p_data[i] < INT_MIN)) {
313         return nullptr;
314       }
315     }
316     auto abstract = param_node->abstract();
317     MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
318     auto new_abstract = abstract->Clone();
319     new_abstract->set_value(std::make_shared<ValueAny>());
320     if (GenCastNode(func_graph, param_node, param_node->fullname_with_scope() + "_post_cast",
321                     static_cast<TypeId>(kNumberTypeInt32), new_abstract) == nullptr) {
322       MS_LOG(ERROR) << "GenCastNode failed.";
323       return nullptr;
324     }
325     auto manager = func_graph->manager();
326     MS_ASSERT(manager != nullptr);
327     manager->SetEdge(not_equal_cnode_2, cast_index, pre_cast_cnode->input(kInputIndexOne));
328   }
329 
330   return nullptr;
331 }
332 
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const333 AnfNodePtr CastFusionPass::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
334                                    const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
335   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
336     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
337     return nullptr;
338   }
339 
340   if (pattern_name == "CastCastPatternName") {
341     return CastCastFusion(func_graph, node);
342   } else if (pattern_name == "CastGatherPatternName") {
343     return CastGatherFusion(func_graph, node);
344   } else if (pattern_name == "CastSplitPatternName") {
345     return CastSplitFusion(func_graph, node);
346   } else if (pattern_name == "CastNotEqualPatternName") {
347     return CastNotEqualFusion(func_graph, node);
348   } else if (pattern_name == "CastNotEqual2PatternName") {
349     return CastNotEqualFusion(func_graph, node);
350   }
351   return nullptr;
352 }
353 }  // namespace mindspore::opt
354