• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/transpose_fusion.h"
19 #include <unordered_map>
20 #include <memory>
21 #include <vector>
22 #include "mindspore/core/ops/nn_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "tools/converter/quantizer/quant_param_holder.h"
26 #include "ops/auto_generate/gen_lite_ops.h"
27 #include "tools/optimizer/common/format_utils.h"
28 #include "ops/fusion/scale_fusion.h"
29 #include "nnacl/op_base.h"
30 #include "ops/op_utils.h"
31 
32 namespace mindspore::opt {
IsBNCNode(const BaseRef & n)33 bool IsBNCNode(const BaseRef &n) {
34   if (utils::isa<AnfNodePtr>(n)) {
35     auto anf_node = utils::cast<AnfNodePtr>(n);
36     return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
37            CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
38   }
39   return false;
40 }
41 
IsSoftmaxNode(const BaseRef & n)42 bool IsSoftmaxNode(const BaseRef &n) {
43   if (utils::isa<AnfNodePtr>(n)) {
44     auto anf_node = utils::cast<AnfNodePtr>(n);
45     return CheckPrimitiveType(anf_node, prim::kPrimSoftmax) || CheckPrimitiveType(anf_node, prim::kPrimLogSoftmax);
46   }
47   return false;
48 }
49 
DefineBNPattern() const50 VectorRef TransposeFusion::DefineBNPattern() const {
51   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
52   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
53   auto is_conv = std::make_shared<CondVar>(IsConvNode);
54   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
55   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
56   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
57   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
58   auto is_bn = std::make_shared<CondVar>(IsBNCNode);
59   MS_CHECK_TRUE_RET(is_bn != nullptr, {});
60   auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
61   MS_CHECK_TRUE_RET(bn_mean_var != nullptr, {});
62   auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
63   MS_CHECK_TRUE_RET(bn_variable_var != nullptr, {});
64   auto bn_other_var = std::make_shared<SeqVar>();
65   MS_CHECK_TRUE_RET(bn_other_var != nullptr, {});
66   VectorRef bn_ref = VectorRef({is_bn, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var});
67   return bn_ref;
68 }
69 
DefineActivationscalePattern() const70 VectorRef TransposeFusion::DefineActivationscalePattern() const {
71   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
72   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
73   auto is_conv = std::make_shared<CondVar>(IsConvNode);
74   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
75   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
76   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
77   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
78   auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
79   MS_CHECK_TRUE_RET(is_scale != nullptr, {});
80   auto scale_var_1 = std::make_shared<CondVar>(IsParamNode);
81   MS_CHECK_TRUE_RET(scale_var_1 != nullptr, {});
82   auto scale_var_2 = std::make_shared<SeqVar>();
83   MS_CHECK_TRUE_RET(scale_var_2 != nullptr, {});
84   VectorRef sclae_ref = VectorRef({is_scale, transpose_conv_ref, scale_var_1, scale_var_2});
85   return sclae_ref;
86 }
87 
DefineActivationPattern() const88 VectorRef TransposeFusion::DefineActivationPattern() const {
89   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
90   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
91   auto is_conv = std::make_shared<CondVar>(IsConvNode);
92   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
93   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
94   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
95   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
96   auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
97   MS_CHECK_TRUE_RET(is_activation != nullptr, {});
98   VectorRef act_ref = VectorRef({is_activation, transpose_conv_ref});
99   return act_ref;
100 }
101 
DefineBiasAddPattern() const102 VectorRef TransposeFusion::DefineBiasAddPattern() const {
103   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
104   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
105   auto is_conv = std::make_shared<CondVar>(IsConvNode);
106   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
107   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
108   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
109   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
110   auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
111   MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
112   auto bias_param = std::make_shared<CondVar>(IsParamNode);
113   MS_CHECK_TRUE_RET(bias_param != nullptr, {});
114   VectorRef act_ref = VectorRef({is_bias_add, transpose_conv_ref, bias_param});
115   return act_ref;
116 }
117 
DefineScalePattern() const118 VectorRef TransposeFusion::DefineScalePattern() const {
119   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
120   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
121   auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
122   MS_CHECK_TRUE_RET(is_scale != nullptr, {});
123   auto is_weight_param = std::make_shared<CondVar>(IsParamNode);
124   MS_CHECK_TRUE_RET(is_weight_param != nullptr, {});
125   auto is_seq_var = std::make_shared<SeqVar>();
126   MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
127   VectorRef trans_scale_ref = VectorRef({is_scale, is_transpose, is_weight_param, is_seq_var});
128   return trans_scale_ref;
129 }
130 
DefineSoftmaxPattern() const131 VectorRef TransposeFusion::DefineSoftmaxPattern() const {
132   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
133   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
134   auto is_softmax = std::make_shared<CondVar>(IsSoftmaxNode);
135   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
136   VectorRef trans_softmax_ref = VectorRef({is_softmax, is_transpose});
137   return trans_softmax_ref;
138 }
139 
DefineTransTransPattern() const140 VectorRef TransposeFusion::DefineTransTransPattern() const {
141   auto is_transpose1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
142   MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
143   auto is_transpose2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
144   MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
145   auto transpose_param = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
146   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
147   VectorRef trans_trans_ref = VectorRef({is_transpose2, is_transpose1, transpose_param});
148   return trans_trans_ref;
149 }
150 
DefinePatterns() const151 std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
152   std::unordered_map<std::string, VectorRef> patterns;
153   patterns["BNPatternName"] = DefineBNPattern();
154   patterns["ActivationPatternName"] = DefineActivationPattern();
155   patterns["BiasAddPatternName"] = DefineBiasAddPattern();
156   patterns["ScalePatternName"] = DefineActivationscalePattern();
157   patterns["TransScalePatternName"] = DefineScalePattern();
158   patterns["TransSoftmaxPatternName"] = DefineSoftmaxPattern();
159   patterns["TransTransPatternName"] = DefineTransTransPattern();
160   return patterns;
161 }
162 
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const AnfNodePtr & perm,const std::string & cnode_name)163 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
164                           const std::string &cnode_name) {
165   MS_ASSERT(func_graph != nullptr && input_node != nullptr);
166   auto trans_prim = std::make_shared<ops::Transpose>();
167   MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
168   auto trans_prim_c = trans_prim->GetPrim();
169   MS_CHECK_TRUE_RET(trans_prim_c != nullptr, nullptr);
170   auto cnode = func_graph->NewCNode(trans_prim_c, {input_node, perm});
171   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
172   cnode->set_fullname_with_scope(cnode_name);
173   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
174   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
175   auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
176   MS_ASSERT(trans_insert_prim != nullptr);
177   trans_insert_prim->AddAttr("quant_params", quant_params_holder);
178   return cnode;
179 }
180 
TransTransFusion(const FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node) const181 AnfNodePtr TransposeFusion::TransTransFusion(const FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const {
182   MS_ASSERT(func_graph != nullptr && node != nullptr);
183   auto trans_cnode_2 = node->cast<CNodePtr>();
184   MS_ASSERT(trans_cnode_2 != nullptr);
185   if (IsMarkedTrainOp(trans_cnode_2)) {
186     return nullptr;
187   }
188   MS_CHECK_TRUE_RET(trans_cnode_2 != nullptr, nullptr);
189   MS_CHECK_TRUE_RET(trans_cnode_2->size() == kInputSizeThree, nullptr);
190   if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) ||
191       !CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) {
192     return nullptr;
193   }
194   std::vector<int> post_perm;
195   if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) {
196     MS_LOG(ERROR) << "get transpose perm failed.";
197     return nullptr;
198   }
199   std::vector<int> pre_perm;
200   auto pre_node = trans_cnode_2->input(1);
201   MS_CHECK_TRUE_RET(pre_node != nullptr, nullptr);
202   auto pre_cnode = pre_node->cast<CNodePtr>();
203   if (pre_cnode == nullptr) {
204     return nullptr;
205   }
206   if (IsMarkedTrainOp(pre_cnode)) {
207     return nullptr;
208   }
209   if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
210     MS_LOG(ERROR) << "get transpose perm failed.";
211     return nullptr;
212   }
213   if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
214     return pre_cnode->input(1);
215   }
216   if (pre_perm.size() == post_perm.size()) {
217     std::vector<int> perm;
218     for (auto idx : post_perm) {
219       MS_CHECK_TRUE_RET(idx >= 0 && static_cast<size_t>(idx) < pre_perm.size(), nullptr);
220       perm.push_back(pre_perm[idx]);
221     }
222     std::vector<int> ori_perm = [&pre_perm]() {
223       std::vector<int> value;
224       for (int i = 0; i < static_cast<int>(pre_perm.size()); i++) value.push_back(i);
225       return value;
226     }();
227     if (perm == ori_perm) {
228       return pre_cnode->input(1);
229     }
230     auto name = trans_cnode_2->fullname_with_scope();
231     auto perm_node = BuildIntVecParameterNode(func_graph, perm, name + "_perm");
232     auto manager = func_graph->manager();
233     MS_ASSERT(manager != nullptr);
234     manager->SetEdge(trans_cnode_2, 1, pre_cnode->input(1));
235     manager->SetEdge(trans_cnode_2, kInputIndexTwo, perm_node);
236   }
237   return nullptr;
238 }
239 
AdjustAxis(const mindspore::AnfNodePtr & node) const240 int TransposeFusion::AdjustAxis(const mindspore::AnfNodePtr &node) const {
241   MS_ASSERT(node != nullptr);
242   auto cnode = node->cast<CNodePtr>();
243   MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_ERROR);
244   if (IsMarkedTrainOp(cnode)) {
245     return lite::RET_ERROR;
246   }
247   bool has_inferred{false};
248   if (DetermineCertainVarInputHasInferred(cnode, 1, &has_inferred) != RET_OK || !has_inferred) {
249     return lite::RET_ERROR;
250   }
251   auto transpose_node = cnode->input(1);
252   auto transpose_cnode = transpose_node->cast<CNodePtr>();
253   if (transpose_cnode == nullptr) {
254     return lite::RET_ERROR;
255   }
256   if (IsMarkedTrainOp(transpose_cnode)) {
257     return lite::RET_ERROR;
258   }
259   if (CheckPrimitiveType(cnode, prim::kPrimScaleFusion)) {
260     auto weight_param = cnode->input(2);
261     MS_CHECK_TRUE_RET(weight_param != nullptr, lite::RET_ERROR);
262     std::vector<int64_t> weight_shape;
263     if (FetchShapeFromAbstract(weight_param->abstract(), &weight_shape) != lite::RET_OK) {
264       MS_LOG(ERROR) << "Get shape from abstract failed.";
265       return lite::RET_ERROR;
266     }
267     if (weight_shape.size() != 1) {
268       return lite::RET_ERROR;
269     }
270   }
271   std::vector<int> perm;
272   if (GetTransposePerm(transpose_cnode, &perm) != lite::RET_OK) {
273     MS_LOG(ERROR) << "get tanspose perm failed.";
274     return lite::RET_ERROR;
275   }
276   MS_CHECK_TRUE_RET(!perm.empty(), lite::RET_ERROR);
277 
278   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
279   MS_CHECK_TRUE_RET(prim != nullptr, lite::RET_ERROR);
280   auto axis_value_ptr = prim->GetAttr(ops::kAxis);
281   MS_CHECK_TRUE_RET(axis_value_ptr != nullptr, lite::RET_ERROR);
282   int64_t axis = !utils::isa<ValueSequencePtr>(axis_value_ptr) ? GetValue<int64_t>(axis_value_ptr)
283                                                                : GetValue<std::vector<int64_t>>(axis_value_ptr).front();
284   axis = axis < 0 ? axis + perm.size() : axis;
285   MS_CHECK_TRUE_RET(axis >= 0 && static_cast<size_t>(axis) < perm.size(), lite::RET_ERROR);
286   auto axis_attr = !utils::isa<ValueSequencePtr>(axis_value_ptr) ? MakeValue<int64_t>(perm.at(axis))
287                                                                  : MakeValue<std::vector<int64_t>>({perm.at(axis)});
288   (void)prim->AddAttr(ops::kAxis, axis_attr);
289   if (perm == kNC2NH) {
290     (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
291   } else if (perm == kNH2NC) {
292     (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
293   }
294   return lite::RET_OK;
295 }
296 
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const297 AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
298                                     const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
299   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
300     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
301     return nullptr;
302   }
303   if (pattern_name == "TransTransPatternName") {
304     return TransTransFusion(func_graph, node);
305   } else if (pattern_name == "TransScalePatternName" || pattern_name == "TransSoftmaxPatternName") {
306     if (AdjustAxis(node) != lite::RET_OK) {
307       return nullptr;
308     }
309   }
310 
311   if (node->cast<CNodePtr>() == nullptr) {
312     return nullptr;
313   }
314   auto any_cnode = node->cast<CNodePtr>();
315   if (IsMarkedTrainOp(any_cnode)) {
316     return nullptr;
317   }
318   const auto transpose_node = any_cnode->input(1);
319   if (transpose_node == nullptr || transpose_node->cast<CNodePtr>() == nullptr) {
320     return nullptr;
321   }
322   const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
323   if (IsMarkedTrainOp(transpose_cnode)) {
324     return nullptr;
325   }
326   auto perm_node = transpose_cnode->input(kInputIndexTwo);
327   MS_CHECK_TRUE_RET(perm_node != nullptr, nullptr);
328   auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
329   MS_CHECK_TRUE_RET(trans_post_node != nullptr, nullptr);
330   if (any_cnode->abstract() != nullptr) {
331     trans_post_node->set_abstract(any_cnode->abstract()->Clone());
332   }
333   if (transpose_cnode->input(1)->abstract() != nullptr) {
334     any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone());
335   }
336   auto manager = func_graph->manager();
337   MS_ASSERT(manager != nullptr);
338   manager->SetEdge(any_cnode, 1, transpose_cnode->input(1));
339   return trans_post_node;
340 }
341 }  // namespace mindspore::opt
342