1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/fusion/transpose_fusion.h"
18 #include <unordered_map>
19 #include <memory>
20 #include <vector>
21 #include "tools/converter/quant_param_holder.h"
22 #include "mindspore/core/ops/transpose.h"
23 #include "tools/optimizer/common/format_utils.h"
24 #include "nnacl/op_base.h"
25
26 namespace mindspore::opt {
IsBNCNode(const BaseRef & n)27 bool IsBNCNode(const BaseRef &n) {
28 if (utils::isa<AnfNodePtr>(n)) {
29 auto anf_node = utils::cast<AnfNodePtr>(n);
30 return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
31 CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
32 }
33 return false;
34 }
35
DefineBNPattern() const36 VectorRef TransposeFusion::DefineBNPattern() const {
37 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
38 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
39 auto is_conv = std::make_shared<CondVar>(IsConvNode);
40 MS_CHECK_TRUE_RET(is_conv != nullptr, {});
41 auto transpose_param = std::make_shared<CondVar>(IsParamNode);
42 MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
43 VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
44 auto is_bn = std::make_shared<CondVar>(IsBNCNode);
45 MS_CHECK_TRUE_RET(is_bn != nullptr, {});
46 auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
47 MS_CHECK_TRUE_RET(bn_mean_var != nullptr, {});
48 auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
49 MS_CHECK_TRUE_RET(bn_variable_var != nullptr, {});
50 auto bn_other_var = std::make_shared<SeqVar>();
51 MS_CHECK_TRUE_RET(bn_other_var != nullptr, {});
52 VectorRef bn_ref = VectorRef({is_bn, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var});
53 return bn_ref;
54 }
55
DefineActivationscalePattern() const56 VectorRef TransposeFusion::DefineActivationscalePattern() const {
57 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
58 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
59 auto is_conv = std::make_shared<CondVar>(IsConvNode);
60 MS_CHECK_TRUE_RET(is_conv != nullptr, {});
61 auto transpose_param = std::make_shared<CondVar>(IsParamNode);
62 MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
63 VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
64 auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
65 MS_CHECK_TRUE_RET(is_scale != nullptr, {});
66 auto scale_var_1 = std::make_shared<CondVar>(IsParamNode);
67 MS_CHECK_TRUE_RET(scale_var_1 != nullptr, {});
68 auto scale_var_2 = std::make_shared<SeqVar>();
69 MS_CHECK_TRUE_RET(scale_var_2 != nullptr, {});
70 VectorRef sclae_ref = VectorRef({is_scale, transpose_conv_ref, scale_var_1, scale_var_2});
71 return sclae_ref;
72 }
73
DefineActivationPattern() const74 VectorRef TransposeFusion::DefineActivationPattern() const {
75 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
76 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
77 auto is_conv = std::make_shared<CondVar>(IsConvNode);
78 MS_CHECK_TRUE_RET(is_conv != nullptr, {});
79 auto transpose_param = std::make_shared<CondVar>(IsParamNode);
80 MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
81 VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
82 auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
83 MS_CHECK_TRUE_RET(is_activation != nullptr, {});
84 VectorRef act_ref = VectorRef({is_activation, transpose_conv_ref});
85 return act_ref;
86 }
87
DefineBiasAddPattern() const88 VectorRef TransposeFusion::DefineBiasAddPattern() const {
89 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
90 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
91 auto is_conv = std::make_shared<CondVar>(IsConvNode);
92 MS_CHECK_TRUE_RET(is_conv != nullptr, {});
93 auto transpose_param = std::make_shared<CondVar>(IsParamNode);
94 MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
95 VectorRef transpose_conv_ref = VectorRef({is_transpose, is_conv, transpose_param});
96 auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
97 MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
98 auto bias_param = std::make_shared<CondVar>(IsParamNode);
99 MS_CHECK_TRUE_RET(bias_param != nullptr, {});
100 VectorRef act_ref = VectorRef({is_bias_add, transpose_conv_ref, bias_param});
101 return act_ref;
102 }
103
DefineTransTransPattern() const104 VectorRef TransposeFusion::DefineTransTransPattern() const {
105 auto is_transpose1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
106 MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
107 auto is_transpose2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
108 MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
109 auto transpose_param = std::make_shared<CondVar>(IsParamNode);
110 MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
111 VectorRef trans_trans_ref = VectorRef({is_transpose2, is_transpose1, transpose_param});
112 return trans_trans_ref;
113 }
114
DefinePatterns() const115 std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
116 std::unordered_map<std::string, VectorRef> patterns;
117 patterns["BNPatternName"] = DefineBNPattern();
118 patterns["ActivationPatternName"] = DefineActivationPattern();
119 patterns["BiasAddPatternName"] = DefineBiasAddPattern();
120 patterns["ScalePatternName"] = DefineActivationscalePattern();
121 patterns["TransTransPatternName"] = DefineTransTransPattern();
122 return patterns;
123 }
124
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const AnfNodePtr & perm,const std::string & cnode_name)125 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
126 const std::string &cnode_name) {
127 MS_ASSERT(func_graph != nullptr && input_node != nullptr);
128 auto trans_prim = std::make_shared<ops::Transpose>();
129 MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
130 auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm});
131 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
132 cnode->set_fullname_with_scope(cnode_name);
133 auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
134 MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
135 auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
136 MS_ASSERT(trans_insert_prim != nullptr);
137 trans_insert_prim->AddAttr("quant_params", quant_params_holder);
138 return cnode;
139 }
140
TransTransFusion(const mindspore::AnfNodePtr & node) const141 AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::AnfNodePtr &node) const {
142 MS_ASSERT(node != nullptr);
143 auto trans_cnode_2 = node->cast<CNodePtr>();
144 if (IsMarkedTrainOp(trans_cnode_2)) {
145 return nullptr;
146 }
147 MS_CHECK_TRUE_RET(trans_cnode_2 != nullptr, nullptr);
148 if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) ||
149 !CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) {
150 return nullptr;
151 }
152 std::vector<int> post_perm;
153 if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) {
154 MS_LOG(ERROR) << "get tanspose perm failed.";
155 return nullptr;
156 }
157 std::vector<int> pre_perm;
158 auto pre_node = trans_cnode_2->input(1);
159 auto pre_cnode = pre_node->cast<CNodePtr>();
160 if (pre_cnode == nullptr) {
161 return nullptr;
162 }
163 if (IsMarkedTrainOp(pre_cnode)) {
164 return nullptr;
165 }
166 if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
167 MS_LOG(ERROR) << "get tanspose perm failed.";
168 return nullptr;
169 }
170 if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
171 return pre_cnode->input(1);
172 }
173 return nullptr;
174 }
175
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const176 AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
177 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
178 if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
179 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
180 return nullptr;
181 }
182 if (pattern_name == "TransTransPatternName") {
183 return TransTransFusion(node);
184 }
185 if (node->cast<CNodePtr>() == nullptr) {
186 return nullptr;
187 }
188 auto any_cnode = node->cast<CNodePtr>();
189 if (IsMarkedTrainOp(any_cnode)) {
190 return nullptr;
191 }
192 const auto transpose_node = any_cnode->input(1);
193 if (transpose_node == nullptr || transpose_node->cast<CNodePtr>() == nullptr) {
194 return nullptr;
195 }
196 const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
197 if (IsMarkedTrainOp(transpose_cnode)) {
198 return nullptr;
199 }
200 auto perm_node = transpose_cnode->input(kInputIndexTwo);
201 auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
202 MS_CHECK_TRUE_RET(trans_post_node != nullptr, nullptr);
203 if (any_cnode->abstract() != nullptr) {
204 trans_post_node->set_abstract(any_cnode->abstract()->Clone());
205 }
206 if (transpose_cnode->input(1)->abstract() != nullptr) {
207 any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone());
208 }
209 auto manager = func_graph->manager();
210 MS_ASSERT(manager != nullptr);
211 manager->SetEdge(any_cnode, 1, transpose_cnode->input(1));
212 return trans_post_node;
213 }
214 } // namespace mindspore::opt
215