• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/fusion/transpose_fusion.h"
18 #include <unordered_map>
19 #include <memory>
20 #include <vector>
21 #include "tools/converter/quant_param_holder.h"
22 #include "mindspore/core/ops/transpose.h"
23 #include "tools/optimizer/common/format_utils.h"
24 #include "nnacl/op_base.h"
25 
26 namespace mindspore::opt {
IsBNCNode(const BaseRef & n)27 bool IsBNCNode(const BaseRef &n) {
28   if (utils::isa<AnfNodePtr>(n)) {
29     auto anf_node = utils::cast<AnfNodePtr>(n);
30     return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
31            CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
32   }
33   return false;
34 }
35 
DefineBNPattern() const36 VectorRef TransposeFusion::DefineBNPattern() const {
37   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
38   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
39   auto is_conv = std::make_shared<CondVar>(IsConvNode);
40   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
41   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
42   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
43   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
44   auto is_bn = std::make_shared<CondVar>(IsBNCNode);
45   MS_CHECK_TRUE_RET(is_bn != nullptr, {});
46   auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
47   MS_CHECK_TRUE_RET(bn_mean_var != nullptr, {});
48   auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
49   MS_CHECK_TRUE_RET(bn_variable_var != nullptr, {});
50   auto bn_other_var = std::make_shared<SeqVar>();
51   MS_CHECK_TRUE_RET(bn_other_var != nullptr, {});
52   VectorRef bn_ref = VectorRef({is_bn, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var});
53   return bn_ref;
54 }
55 
DefineActivationscalePattern() const56 VectorRef TransposeFusion::DefineActivationscalePattern() const {
57   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
58   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
59   auto is_conv = std::make_shared<CondVar>(IsConvNode);
60   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
61   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
62   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
63   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
64   auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
65   MS_CHECK_TRUE_RET(is_scale != nullptr, {});
66   auto scale_var_1 = std::make_shared<CondVar>(IsParamNode);
67   MS_CHECK_TRUE_RET(scale_var_1 != nullptr, {});
68   auto scale_var_2 = std::make_shared<SeqVar>();
69   MS_CHECK_TRUE_RET(scale_var_2 != nullptr, {});
70   VectorRef sclae_ref = VectorRef({is_scale, transpose_conv_ref, scale_var_1, scale_var_2});
71   return sclae_ref;
72 }
73 
DefineActivationPattern() const74 VectorRef TransposeFusion::DefineActivationPattern() const {
75   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
76   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
77   auto is_conv = std::make_shared<CondVar>(IsConvNode);
78   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
79   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
80   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
81   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
82   auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
83   MS_CHECK_TRUE_RET(is_activation != nullptr, {});
84   VectorRef act_ref = VectorRef({is_activation, transpose_conv_ref});
85   return act_ref;
86 }
87 
DefineBiasAddPattern() const88 VectorRef TransposeFusion::DefineBiasAddPattern() const {
89   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
90   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
91   auto is_conv = std::make_shared<CondVar>(IsConvNode);
92   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
93   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
94   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
95   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
96   auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
97   MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
98   auto bias_param = std::make_shared<CondVar>(IsParamNode);
99   MS_CHECK_TRUE_RET(bias_param != nullptr, {});
100   VectorRef act_ref = VectorRef({is_bias_add, transpose_conv_ref, bias_param});
101   return act_ref;
102 }
103 
DefineTransTransPattern() const104 VectorRef TransposeFusion::DefineTransTransPattern() const {
105   auto is_transpose1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
106   MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
107   auto is_transpose2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
108   MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
109   auto transpose_param = std::make_shared<CondVar>(IsParamNode);
110   MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
111   VectorRef trans_trans_ref = VectorRef({is_transpose2, is_transpose1, transpose_param});
112   return trans_trans_ref;
113 }
114 
DefinePatterns() const115 std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
116   std::unordered_map<std::string, VectorRef> patterns;
117   patterns["BNPatternName"] = DefineBNPattern();
118   patterns["ActivationPatternName"] = DefineActivationPattern();
119   patterns["BiasAddPatternName"] = DefineBiasAddPattern();
120   patterns["ScalePatternName"] = DefineActivationscalePattern();
121   patterns["TransTransPatternName"] = DefineTransTransPattern();
122   return patterns;
123 }
124 
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const AnfNodePtr & perm,const std::string & cnode_name)125 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
126                           const std::string &cnode_name) {
127   MS_ASSERT(func_graph != nullptr && input_node != nullptr);
128   auto trans_prim = std::make_shared<ops::Transpose>();
129   MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
130   auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm});
131   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
132   cnode->set_fullname_with_scope(cnode_name);
133   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
134   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
135   auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
136   MS_ASSERT(trans_insert_prim != nullptr);
137   trans_insert_prim->AddAttr("quant_params", quant_params_holder);
138   return cnode;
139 }
140 
TransTransFusion(const mindspore::AnfNodePtr & node) const141 AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::AnfNodePtr &node) const {
142   MS_ASSERT(node != nullptr);
143   auto trans_cnode_2 = node->cast<CNodePtr>();
144   if (IsMarkedTrainOp(trans_cnode_2)) {
145     return nullptr;
146   }
147   MS_CHECK_TRUE_RET(trans_cnode_2 != nullptr, nullptr);
148   if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) ||
149       !CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) {
150     return nullptr;
151   }
152   std::vector<int> post_perm;
153   if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) {
154     MS_LOG(ERROR) << "get tanspose perm failed.";
155     return nullptr;
156   }
157   std::vector<int> pre_perm;
158   auto pre_node = trans_cnode_2->input(1);
159   auto pre_cnode = pre_node->cast<CNodePtr>();
160   if (pre_cnode == nullptr) {
161     return nullptr;
162   }
163   if (IsMarkedTrainOp(pre_cnode)) {
164     return nullptr;
165   }
166   if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
167     MS_LOG(ERROR) << "get tanspose perm failed.";
168     return nullptr;
169   }
170   if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
171     return pre_cnode->input(1);
172   }
173   return nullptr;
174 }
175 
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const176 AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
177                                     const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
178   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
179     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
180     return nullptr;
181   }
182   if (pattern_name == "TransTransPatternName") {
183     return TransTransFusion(node);
184   }
185   if (node->cast<CNodePtr>() == nullptr) {
186     return nullptr;
187   }
188   auto any_cnode = node->cast<CNodePtr>();
189   if (IsMarkedTrainOp(any_cnode)) {
190     return nullptr;
191   }
192   const auto transpose_node = any_cnode->input(1);
193   if (transpose_node == nullptr || transpose_node->cast<CNodePtr>() == nullptr) {
194     return nullptr;
195   }
196   const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
197   if (IsMarkedTrainOp(transpose_cnode)) {
198     return nullptr;
199   }
200   auto perm_node = transpose_cnode->input(kInputIndexTwo);
201   auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
202   MS_CHECK_TRUE_RET(trans_post_node != nullptr, nullptr);
203   if (any_cnode->abstract() != nullptr) {
204     trans_post_node->set_abstract(any_cnode->abstract()->Clone());
205   }
206   if (transpose_cnode->input(1)->abstract() != nullptr) {
207     any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone());
208   }
209   auto manager = func_graph->manager();
210   MS_ASSERT(manager != nullptr);
211   manager->SetEdge(any_cnode, 1, transpose_cnode->input(1));
212   return trans_post_node;
213 }
214 }  // namespace mindspore::opt
215