1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/format/delete_redundant_transpose.h"
19 #include <vector>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "mindspore/core/ops/array_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "tools/optimizer/common/format_utils.h"
24 #include "nnacl/op_base.h"
25 #include "ops/op_utils.h"
26 #include "tools/common/node_util.h"
27 #include "tools/converter/quantizer/quant_params.h"
28
29 namespace mindspore {
30 namespace opt {
DeleteControlFlowTranspose(const CNodePtr & cnode)31 STATUS DeleteRedundantTranspose::DeleteControlFlowTranspose(const CNodePtr &cnode) {
32 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
33 if (sub_func_graph == nullptr) {
34 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
35 return lite::RET_NULL_PTR;
36 }
37 if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
38 MS_LOG(ERROR) << "delete transpose failed.";
39 return lite::RET_ERROR;
40 }
41 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
42 if (sub_func_graph == nullptr) {
43 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
44 return lite::RET_NULL_PTR;
45 }
46 if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
47 MS_LOG(ERROR) << "delete transpose failed.";
48 return lite::RET_ERROR;
49 }
50 return lite::RET_OK;
51 }
52
DeleteNot4DTranspose(const FuncGraphPtr & func_graph)53 STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_graph) {
54 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
55 MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
56 manager_->AddFuncGraph(func_graph);
57 auto node_list = TopoSort(func_graph->get_return());
58 for (auto &node : node_list) {
59 MS_CHECK_TRUE_RET(node != nullptr, lite::RET_NULL_PTR);
60 if (!utils::isa<CNode>(node)) {
61 continue;
62 }
63 auto cnode = node->cast<CNodePtr>();
64 if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
65 if (DeleteControlFlowTranspose(cnode) != RET_OK) {
66 MS_LOG(ERROR) << "DeleteControlFlowTranspose failed.";
67 return lite::RET_ERROR;
68 }
69 continue;
70 }
71 if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
72 continue;
73 }
74 auto abstract = GetCNodeInputAbstract(cnode, 1);
75 ShapeVector shape;
76 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
77 MS_LOG(ERROR) << "fetch shape failed.";
78 return lite::RET_ERROR;
79 }
80 std::vector<int> perm;
81 if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
82 MS_LOG(ERROR) << "fetch transpose perm failed.";
83 return lite::RET_ERROR;
84 }
85 int start_dat = 0;
86 bool useless = true;
87 for (auto dat : perm) {
88 if (dat == start_dat) {
89 start_dat += 1;
90 } else {
91 useless = false;
92 break;
93 }
94 }
95 if (useless) {
96 if (!manager_->Replace(node, cnode->input(1))) {
97 MS_LOG(ERROR) << "replace old node failed, please check.";
98 return lite::RET_ERROR;
99 }
100 continue;
101 }
102 if (!lite::JudgeDynamicShape(shape) && shape.size() != perm.size()) {
103 MS_LOG(DEBUG) << "transpose node need to be deleted.";
104 if (UpdateNodeFormat(cnode) != lite::RET_OK) {
105 MS_LOG(ERROR) << "update cnode format failed.";
106 return lite::RET_ERROR;
107 }
108 if (!manager_->Replace(node, cnode->input(1))) {
109 MS_LOG(ERROR) << "replace old node failed, please check.";
110 return lite::RET_ERROR;
111 }
112 }
113 }
114 return lite::RET_OK;
115 }
116
DoTransTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & cnode)117 STATUS DeleteRedundantTranspose::DoTransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
118 if (func_graph == nullptr || cnode == nullptr) {
119 return lite::RET_ERROR;
120 }
121 if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
122 return lite::RET_OK;
123 }
124 if (cnode->size() <= 1 || cnode->input(1) == nullptr) {
125 MS_LOG(INFO) << "Failed to get input 1 of cnode " << cnode->fullname_with_scope() << ", input size "
126 << cnode->size();
127 return lite::RET_ERROR;
128 }
129 auto pre_cnode = cnode->input(1)->cast<CNodePtr>();
130 if (pre_cnode == nullptr) {
131 MS_LOG(INFO) << "node input 1 is not a cnode, node " << cnode->fullname_with_scope();
132 return lite::RET_OK;
133 }
134 if (!CheckPrimitiveType(pre_cnode, prim::kPrimTranspose) || IsMultiOutputTensors(func_graph, pre_cnode)) {
135 return lite::RET_OK;
136 }
137 std::vector<int> post_perm;
138 if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
139 MS_LOG(ERROR) << "transpose perm cannot be obtained, " << cnode->fullname_with_scope();
140 return lite::RET_ERROR;
141 }
142 std::vector<int> pre_perm;
143 if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
144 MS_LOG(ERROR) << "transpose perm cannot be obtained, " << pre_cnode->fullname_with_scope();
145 return lite::RET_ERROR;
146 }
147 if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
148 auto node_users = manager_->node_users()[cnode];
149 MS_LOG(INFO) << "node_users map size: " << node_users.size();
150 if (!manager_->Replace(cnode, pre_cnode->input(1))) {
151 MS_LOG(ERROR) << "replace old node failed, please check.";
152 return lite::RET_ERROR;
153 }
154 if (CopyQuantParam(cnode, pre_cnode, node_users) != RET_OK) {
155 MS_LOG(ERROR) << "Copy quant param failed, please check.";
156 return lite::RET_ERROR;
157 }
158 func_graph->DropNode(cnode->input(kInputIndexTwo));
159 func_graph->DropNode(pre_cnode->input(kInputIndexTwo));
160 }
161 return lite::RET_OK;
162 }
163
TransTransFusion(const FuncGraphPtr & func_graph)164 STATUS DeleteRedundantTranspose::TransTransFusion(const FuncGraphPtr &func_graph) {
165 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
166 MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
167 manager_->AddFuncGraph(func_graph);
168 auto node_lite = TopoSort(func_graph->get_return());
169 for (auto &node : node_lite) {
170 MS_CHECK_TRUE_RET(node != nullptr, lite::RET_NULL_PTR);
171 if (!utils::isa<CNode>(node)) {
172 continue;
173 }
174 auto cnode = node->cast<CNodePtr>();
175 if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
176 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
177 MS_CHECK_TRUE_MSG(sub_func_graph != nullptr, lite::RET_NULL_PTR, "find a subgraph is a nullptr.");
178 if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
179 MS_LOG(ERROR) << "delete transpose failed.";
180 return lite::RET_ERROR;
181 }
182 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
183 MS_CHECK_TRUE_MSG(sub_func_graph != nullptr, lite::RET_NULL_PTR, "find a subgraph is a nullptr.");
184 if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
185 MS_LOG(ERROR) << "delete transpose failed.";
186 return lite::RET_ERROR;
187 }
188 continue;
189 }
190 auto ret = DoTransTransFusion(func_graph, cnode);
191 if (ret != lite::RET_OK) {
192 return ret;
193 }
194 }
195 return lite::RET_OK;
196 }
197
UpdateNodeFormat(const CNodePtr & cnode)198 STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
199 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
200 MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
201 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
202 MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
203 if (prim->GetAttr(ops::kFormat) == nullptr) {
204 return lite::RET_OK;
205 }
206 auto forward_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
207 const int max_search_depth{3};
208 int loop{0};
209 auto search_node = cnode->input(1);
210 while (loop < max_search_depth) {
211 MS_CHECK_TRUE_RET(search_node != nullptr, lite::RET_ERROR);
212 auto search_cnode = search_node->cast<CNodePtr>();
213 if (search_cnode == nullptr) {
214 break;
215 }
216 auto primitive = GetCNodePrimitive(search_cnode);
217 if (primitive == nullptr) {
218 break;
219 }
220 if (primitive->GetAttr(ops::kFormat) != nullptr) {
221 forward_format = GetValue<int64_t>(primitive->GetAttr(ops::kFormat));
222 break;
223 }
224 search_node = search_cnode->input(1);
225 ++loop;
226 }
227 auto node_users = manager_->node_users()[cnode];
228 for (auto &node_user : node_users) {
229 if (node_user.second != 1) {
230 continue;
231 }
232 if (!utils::isa<CNode>(node_user.first)) {
233 MS_LOG(ERROR) << "post node is not cnode, which is invalid.";
234 return lite::RET_ERROR;
235 }
236 auto post_cnode = node_user.first->cast<CNodePtr>();
237 auto post_prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
238 MS_ERROR_IF_NULL_W_RET_VAL(post_prim, lite::RET_ERROR);
239 post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(forward_format));
240 if (prim->HasAttr(opt::kOutputsFormat)) {
241 auto org_format = CastToInt(prim->GetAttr(opt::kOutputsFormat));
242 std::vector<int64_t> outputs_format(org_format.size(), forward_format);
243 (void)prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
244 }
245 }
246 return lite::RET_OK;
247 }
248
Run(const FuncGraphPtr & func_graph)249 bool DeleteRedundantTranspose::Run(const FuncGraphPtr &func_graph) {
250 MS_CHECK_TRUE_RET(func_graph != nullptr, false);
251 manager_ = Manage(func_graph, true);
252 if (manager_ == nullptr) {
253 MS_LOG(ERROR) << "manager is nullptr.";
254 return false;
255 }
256 if (TransTransFusion(func_graph) != lite::RET_OK) {
257 MS_LOG(ERROR) << "ranspose and transpose fusion failed.";
258 return false;
259 }
260 if (DeleteNot4DTranspose(func_graph) != lite::RET_OK) {
261 MS_LOG(ERROR) << "delete not 4D transpose failed.";
262 return false;
263 }
264 return true;
265 }
266
267 // copy quant info from transpose to post_cnode or input_cnode
CopyQuantParam(const CNodePtr & cnode,const CNodePtr & pre_cnode,const AnfNodeIndexSet & node_users)268 STATUS DeleteRedundantTranspose::CopyQuantParam(const CNodePtr &cnode, const CNodePtr &pre_cnode,
269 const AnfNodeIndexSet &node_users) {
270 auto input_node = pre_cnode->input(Index1);
271 CHECK_NULL_RETURN(input_node);
272 auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
273 CHECK_NULL_RETURN(cnode_primitive);
274 auto pre_cnode_primitive = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
275 CHECK_NULL_RETURN(pre_cnode_primitive);
276 if (lite::IsGraphInput(input_node)) {
277 for (auto &node_user : node_users) {
278 auto post_cnode = node_user.first->cast<CNodePtr>();
279 CHECK_NULL_RETURN(post_cnode);
280 auto post_cnode_primitive = GetValueNode<PrimitivePtr>(post_cnode->input(0));
281 CHECK_NULL_RETURN(post_cnode_primitive);
282 if (cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
283 auto quantization_param_value = cnode_primitive->GetAttr(lite::quant::kQuantParam);
284 CHECK_NULL_RETURN(quantization_param_value);
285 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
286 if (!quantization_param_list.empty()) {
287 MS_LOG(INFO) << "Copy quant param to " << post_cnode->fullname_with_scope();
288 post_cnode_primitive->AddAttr(lite::quant::kGraphInputQuantParam, quantization_param_list.front());
289 }
290 }
291 if (pre_cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
292 auto quantization_param_value = pre_cnode_primitive->GetAttr(lite::quant::kQuantParam);
293 CHECK_NULL_RETURN(quantization_param_value);
294 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
295 if (!quantization_param_list.empty()) {
296 MS_LOG(INFO) << "Copy quant param to " << post_cnode->fullname_with_scope();
297 post_cnode_primitive->AddAttr(lite::quant::kGraphInputQuantParam, quantization_param_list.front());
298 }
299 }
300 }
301 } else if (input_node->isa<mindspore::CNode>()) {
302 auto input_cnode = input_node->cast<mindspore::CNodePtr>();
303 auto input_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
304 CHECK_NULL_RETURN(input_primitive);
305 if (cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
306 input_primitive->AddAttr(lite::quant::kQuantParam, cnode_primitive->GetAttr(lite::quant::kQuantParam));
307 }
308 if (pre_cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
309 input_primitive->AddAttr(lite::quant::kQuantParam, pre_cnode_primitive->GetAttr(lite::quant::kQuantParam));
310 }
311 } else {
312 MS_LOG(ERROR) << input_node->fullname_with_scope() << " Not supported type.";
313 return RET_ERROR;
314 }
315 return RET_OK;
316 }
317 } // namespace opt
318 } // namespace mindspore
319