• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/format/delete_redundant_transpose.h"
19 #include <vector>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "mindspore/core/ops/array_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "tools/optimizer/common/format_utils.h"
24 #include "nnacl/op_base.h"
25 #include "ops/op_utils.h"
26 #include "tools/common/node_util.h"
27 #include "tools/converter/quantizer/quant_params.h"
28 
29 namespace mindspore {
30 namespace opt {
DeleteControlFlowTranspose(const CNodePtr & cnode)31 STATUS DeleteRedundantTranspose::DeleteControlFlowTranspose(const CNodePtr &cnode) {
32   auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
33   if (sub_func_graph == nullptr) {
34     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
35     return lite::RET_NULL_PTR;
36   }
37   if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
38     MS_LOG(ERROR) << "delete transpose failed.";
39     return lite::RET_ERROR;
40   }
41   sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
42   if (sub_func_graph == nullptr) {
43     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
44     return lite::RET_NULL_PTR;
45   }
46   if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) {
47     MS_LOG(ERROR) << "delete transpose failed.";
48     return lite::RET_ERROR;
49   }
50   return lite::RET_OK;
51 }
52 
DeleteNot4DTranspose(const FuncGraphPtr & func_graph)53 STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_graph) {
54   MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
55   MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
56   manager_->AddFuncGraph(func_graph);
57   auto node_list = TopoSort(func_graph->get_return());
58   for (auto &node : node_list) {
59     MS_CHECK_TRUE_RET(node != nullptr, lite::RET_NULL_PTR);
60     if (!utils::isa<CNode>(node)) {
61       continue;
62     }
63     auto cnode = node->cast<CNodePtr>();
64     if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
65       if (DeleteControlFlowTranspose(cnode) != RET_OK) {
66         MS_LOG(ERROR) << "DeleteControlFlowTranspose failed.";
67         return lite::RET_ERROR;
68       }
69       continue;
70     }
71     if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
72       continue;
73     }
74     auto abstract = GetCNodeInputAbstract(cnode, 1);
75     ShapeVector shape;
76     if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
77       MS_LOG(ERROR) << "fetch shape failed.";
78       return lite::RET_ERROR;
79     }
80     std::vector<int> perm;
81     if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
82       MS_LOG(ERROR) << "fetch transpose perm failed.";
83       return lite::RET_ERROR;
84     }
85     int start_dat = 0;
86     bool useless = true;
87     for (auto dat : perm) {
88       if (dat == start_dat) {
89         start_dat += 1;
90       } else {
91         useless = false;
92         break;
93       }
94     }
95     if (useless) {
96       if (!manager_->Replace(node, cnode->input(1))) {
97         MS_LOG(ERROR) << "replace old node failed, please check.";
98         return lite::RET_ERROR;
99       }
100       continue;
101     }
102     if (!lite::JudgeDynamicShape(shape) && shape.size() != perm.size()) {
103       MS_LOG(DEBUG) << "transpose node need to be deleted.";
104       if (UpdateNodeFormat(cnode) != lite::RET_OK) {
105         MS_LOG(ERROR) << "update cnode format failed.";
106         return lite::RET_ERROR;
107       }
108       if (!manager_->Replace(node, cnode->input(1))) {
109         MS_LOG(ERROR) << "replace old node failed, please check.";
110         return lite::RET_ERROR;
111       }
112     }
113   }
114   return lite::RET_OK;
115 }
116 
DoTransTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & cnode)117 STATUS DeleteRedundantTranspose::DoTransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
118   if (func_graph == nullptr || cnode == nullptr) {
119     return lite::RET_ERROR;
120   }
121   if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
122     return lite::RET_OK;
123   }
124   if (cnode->size() <= 1 || cnode->input(1) == nullptr) {
125     MS_LOG(INFO) << "Failed to get input 1 of cnode " << cnode->fullname_with_scope() << ", input size "
126                  << cnode->size();
127     return lite::RET_ERROR;
128   }
129   auto pre_cnode = cnode->input(1)->cast<CNodePtr>();
130   if (pre_cnode == nullptr) {
131     MS_LOG(INFO) << "node input 1 is not a cnode, node " << cnode->fullname_with_scope();
132     return lite::RET_OK;
133   }
134   if (!CheckPrimitiveType(pre_cnode, prim::kPrimTranspose) || IsMultiOutputTensors(func_graph, pre_cnode)) {
135     return lite::RET_OK;
136   }
137   std::vector<int> post_perm;
138   if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
139     MS_LOG(ERROR) << "transpose perm cannot be obtained, " << cnode->fullname_with_scope();
140     return lite::RET_ERROR;
141   }
142   std::vector<int> pre_perm;
143   if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
144     MS_LOG(ERROR) << "transpose perm cannot be obtained, " << pre_cnode->fullname_with_scope();
145     return lite::RET_ERROR;
146   }
147   if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) {
148     auto node_users = manager_->node_users()[cnode];
149     MS_LOG(INFO) << "node_users map size: " << node_users.size();
150     if (!manager_->Replace(cnode, pre_cnode->input(1))) {
151       MS_LOG(ERROR) << "replace old node failed, please check.";
152       return lite::RET_ERROR;
153     }
154     if (CopyQuantParam(cnode, pre_cnode, node_users) != RET_OK) {
155       MS_LOG(ERROR) << "Copy quant param failed, please check.";
156       return lite::RET_ERROR;
157     }
158     func_graph->DropNode(cnode->input(kInputIndexTwo));
159     func_graph->DropNode(pre_cnode->input(kInputIndexTwo));
160   }
161   return lite::RET_OK;
162 }
163 
TransTransFusion(const FuncGraphPtr & func_graph)164 STATUS DeleteRedundantTranspose::TransTransFusion(const FuncGraphPtr &func_graph) {
165   MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
166   MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
167   manager_->AddFuncGraph(func_graph);
168   auto node_lite = TopoSort(func_graph->get_return());
169   for (auto &node : node_lite) {
170     MS_CHECK_TRUE_RET(node != nullptr, lite::RET_NULL_PTR);
171     if (!utils::isa<CNode>(node)) {
172       continue;
173     }
174     auto cnode = node->cast<CNodePtr>();
175     if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
176       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
177       MS_CHECK_TRUE_MSG(sub_func_graph != nullptr, lite::RET_NULL_PTR, "find a subgraph is a nullptr.");
178       if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
179         MS_LOG(ERROR) << "delete transpose failed.";
180         return lite::RET_ERROR;
181       }
182       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
183       MS_CHECK_TRUE_MSG(sub_func_graph != nullptr, lite::RET_NULL_PTR, "find a subgraph is a nullptr.");
184       if (TransTransFusion(sub_func_graph) != lite::RET_OK) {
185         MS_LOG(ERROR) << "delete transpose failed.";
186         return lite::RET_ERROR;
187       }
188       continue;
189     }
190     auto ret = DoTransTransFusion(func_graph, cnode);
191     if (ret != lite::RET_OK) {
192       return ret;
193     }
194   }
195   return lite::RET_OK;
196 }
197 
UpdateNodeFormat(const CNodePtr & cnode)198 STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
199   MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
200   MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
201   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
202   MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
203   if (prim->GetAttr(ops::kFormat) == nullptr) {
204     return lite::RET_OK;
205   }
206   auto forward_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
207   const int max_search_depth{3};
208   int loop{0};
209   auto search_node = cnode->input(1);
210   while (loop < max_search_depth) {
211     MS_CHECK_TRUE_RET(search_node != nullptr, lite::RET_ERROR);
212     auto search_cnode = search_node->cast<CNodePtr>();
213     if (search_cnode == nullptr) {
214       break;
215     }
216     auto primitive = GetCNodePrimitive(search_cnode);
217     if (primitive == nullptr) {
218       break;
219     }
220     if (primitive->GetAttr(ops::kFormat) != nullptr) {
221       forward_format = GetValue<int64_t>(primitive->GetAttr(ops::kFormat));
222       break;
223     }
224     search_node = search_cnode->input(1);
225     ++loop;
226   }
227   auto node_users = manager_->node_users()[cnode];
228   for (auto &node_user : node_users) {
229     if (node_user.second != 1) {
230       continue;
231     }
232     if (!utils::isa<CNode>(node_user.first)) {
233       MS_LOG(ERROR) << "post node is not cnode, which is invalid.";
234       return lite::RET_ERROR;
235     }
236     auto post_cnode = node_user.first->cast<CNodePtr>();
237     auto post_prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
238     MS_ERROR_IF_NULL_W_RET_VAL(post_prim, lite::RET_ERROR);
239     post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(forward_format));
240     if (prim->HasAttr(opt::kOutputsFormat)) {
241       auto org_format = CastToInt(prim->GetAttr(opt::kOutputsFormat));
242       std::vector<int64_t> outputs_format(org_format.size(), forward_format);
243       (void)prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
244     }
245   }
246   return lite::RET_OK;
247 }
248 
Run(const FuncGraphPtr & func_graph)249 bool DeleteRedundantTranspose::Run(const FuncGraphPtr &func_graph) {
250   MS_CHECK_TRUE_RET(func_graph != nullptr, false);
251   manager_ = Manage(func_graph, true);
252   if (manager_ == nullptr) {
253     MS_LOG(ERROR) << "manager is nullptr.";
254     return false;
255   }
256   if (TransTransFusion(func_graph) != lite::RET_OK) {
257     MS_LOG(ERROR) << "ranspose and transpose fusion failed.";
258     return false;
259   }
260   if (DeleteNot4DTranspose(func_graph) != lite::RET_OK) {
261     MS_LOG(ERROR) << "delete not 4D transpose failed.";
262     return false;
263   }
264   return true;
265 }
266 
267 // copy quant info from transpose to post_cnode or input_cnode
CopyQuantParam(const CNodePtr & cnode,const CNodePtr & pre_cnode,const AnfNodeIndexSet & node_users)268 STATUS DeleteRedundantTranspose::CopyQuantParam(const CNodePtr &cnode, const CNodePtr &pre_cnode,
269                                                 const AnfNodeIndexSet &node_users) {
270   auto input_node = pre_cnode->input(Index1);
271   CHECK_NULL_RETURN(input_node);
272   auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
273   CHECK_NULL_RETURN(cnode_primitive);
274   auto pre_cnode_primitive = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
275   CHECK_NULL_RETURN(pre_cnode_primitive);
276   if (lite::IsGraphInput(input_node)) {
277     for (auto &node_user : node_users) {
278       auto post_cnode = node_user.first->cast<CNodePtr>();
279       CHECK_NULL_RETURN(post_cnode);
280       auto post_cnode_primitive = GetValueNode<PrimitivePtr>(post_cnode->input(0));
281       CHECK_NULL_RETURN(post_cnode_primitive);
282       if (cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
283         auto quantization_param_value = cnode_primitive->GetAttr(lite::quant::kQuantParam);
284         CHECK_NULL_RETURN(quantization_param_value);
285         auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
286         if (!quantization_param_list.empty()) {
287           MS_LOG(INFO) << "Copy quant param to " << post_cnode->fullname_with_scope();
288           post_cnode_primitive->AddAttr(lite::quant::kGraphInputQuantParam, quantization_param_list.front());
289         }
290       }
291       if (pre_cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
292         auto quantization_param_value = pre_cnode_primitive->GetAttr(lite::quant::kQuantParam);
293         CHECK_NULL_RETURN(quantization_param_value);
294         auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
295         if (!quantization_param_list.empty()) {
296           MS_LOG(INFO) << "Copy quant param to " << post_cnode->fullname_with_scope();
297           post_cnode_primitive->AddAttr(lite::quant::kGraphInputQuantParam, quantization_param_list.front());
298         }
299       }
300     }
301   } else if (input_node->isa<mindspore::CNode>()) {
302     auto input_cnode = input_node->cast<mindspore::CNodePtr>();
303     auto input_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
304     CHECK_NULL_RETURN(input_primitive);
305     if (cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
306       input_primitive->AddAttr(lite::quant::kQuantParam, cnode_primitive->GetAttr(lite::quant::kQuantParam));
307     }
308     if (pre_cnode_primitive->HasAttr(lite::quant::kQuantParam)) {
309       input_primitive->AddAttr(lite::quant::kQuantParam, pre_cnode_primitive->GetAttr(lite::quant::kQuantParam));
310     }
311   } else {
312     MS_LOG(ERROR) << input_node->fullname_with_scope() << " Not supported type.";
313     return RET_ERROR;
314   }
315   return RET_OK;
316 }
317 }  // namespace opt
318 }  // namespace mindspore
319