1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/graph/decrease_transpose_algo.h"
18 #include <queue>
19 #include <set>
20 #include <unordered_map>
21 #include <utility>
22 #include "ops/op_utils.h"
23 #include "src/common/common.h"
24 #include "src/common/utils.h"
25 #include "tools/common/tensor_util.h"
26 #include "nnacl/op_base.h"
27
28 namespace mindspore {
29 namespace opt {
30 namespace {
// BFS from root_node to partition the graph into an area bounded by transpose ops:
// - in_nodes: transpose nodes feeding the area (plus nodes whose inputs are all parameters),
// - out_nodes: transpose nodes consumed by the area (plus sink nodes with no users),
// - middle_nodes: the ops inside the area.
// Returns RET_ERROR if a node user is not a CNode.
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
                                     std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
                                     std::set<CNodePtr> *middle_nodes) {
  MS_ASSERT(func_graph != nullptr && root_node != nullptr);
  MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
  std::queue<CNodePtr> queue_nodes{};
  queue_nodes.push(root_node);
  // is_pre_nodes mirrors queue_nodes: true means the node was reached as a
  // producer (walking upwards), false means it was reached as a consumer.
  std::queue<bool> is_pre_nodes;
  is_pre_nodes.push(true);
  while (!queue_nodes.empty()) {
    auto cur_node = queue_nodes.front();
    auto is_pre_node = is_pre_nodes.front();
    queue_nodes.pop();
    is_pre_nodes.pop();
    if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
      if (is_pre_node) {
        in_nodes->insert(cur_node);
      } else {
        // A transpose reached as a consumer closes the area; do not traverse past it.
        out_nodes->insert(cur_node);
        continue;
      }
    }
    if (middle_nodes->find(cur_node) != middle_nodes->end()) {
      continue;  // already processed
    }
    if (in_nodes->find(cur_node) == in_nodes->end()) {
      middle_nodes->insert(cur_node);
      // insert pre nodes.
      auto origin_inputs = cur_node->inputs();
      // Temporarily strip Depend wrappers so only real data inputs are traversed.
      lite::RemoveIfDepend(cur_node);
      for (size_t i = 1; i < cur_node->size(); ++i) {
        if (!utils::isa<CNodePtr>(cur_node->input(i))) {
          continue;
        }
        auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
        MS_ASSERT(cur_node_input != nullptr);
        if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
            in_nodes->find(cur_node_input) != in_nodes->end()) {
          continue;
        }
        queue_nodes.push(cur_node_input);
        is_pre_nodes.push(true);
      }
      // A node fed only by parameters is a source of the area.
      if (CheckIsAllInputsParam(cur_node)) {
        in_nodes->insert(cur_node);
      }
      cur_node->set_inputs(origin_inputs);  // restore the stripped inputs
    }
    // insert post nodes
    auto cur_node_users = func_graph->manager()->node_users()[cur_node];
    for (auto &cur_node_user : cur_node_users) {
      if (!utils::isa<CNodePtr>(cur_node_user.first)) {
        MS_LOG(ERROR) << "post node is not cnode.";
        return lite::RET_ERROR;
      }
      auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(cur_node_post != nullptr, RET_ERROR, "cast ptr failed");
      if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
          out_nodes->find(cur_node_post) != out_nodes->end()) {
        continue;
      }
      queue_nodes.push(cur_node_post);
      is_pre_nodes.push(false);
    }
    // A node without users is a graph output and therefore a sink of the area.
    if (cur_node_users.empty()) {
      out_nodes->insert(cur_node);
    }
  }
  return lite::RET_OK;
}
101
SetTransType(const std::set<CNodePtr> & cnodes,FormatTransNodeType * trans_type)102 void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
103 MS_ASSERT(trans_type != nullptr);
104 FormatTransNodeType local_trans_type;
105 for (auto &cnode : cnodes) {
106 std::vector<int> perm;
107 if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
108 (perm != kNH2NC && perm != kNC2NH)) {
109 *trans_type = kNONE;
110 return;
111 }
112 local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
113 *trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
114 if (*trans_type != local_trans_type) {
115 *trans_type = kNONE;
116 return;
117 }
118 }
119 }
120
JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> & in_nodes,const std::set<CNodePtr> & out_nodes,const std::set<CNodePtr> & middle_nodes,TransTypePair * trans_info)121 bool JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> &in_nodes, const std::set<CNodePtr> &out_nodes,
122 const std::set<CNodePtr> &middle_nodes, TransTypePair *trans_info) {
123 MS_ASSERT(trans_info != nullptr);
124 SetTransType(in_nodes, &trans_info->pre_);
125 if (trans_info->pre_ == kNONE) {
126 return false;
127 }
128 SetTransType(out_nodes, &trans_info->post_);
129 if (trans_info->post_ == kNONE) {
130 return false;
131 }
132 if (trans_info->pre_ == trans_info->post_) {
133 return false;
134 }
135 TransposeStrategy transpose_strategy;
136 for (auto &middle_cnode : middle_nodes) {
137 if (IsSpecialType(middle_cnode)) {
138 continue;
139 }
140 auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
141 MS_CHECK_TRUE_MSG(middle_node_prim != nullptr, false, "GetValueNode failed");
142 if (!transpose_strategy.CanChangeOpAxis(middle_cnode)) {
143 return false;
144 }
145 }
146 return true;
147 }
148
// Converts a constant input (Parameter with default data or ValueNode) of
// `cnode` at position `index` between KHWC and KCHW layouts according to
// `trans_type`. CNode inputs, parameters without default data, non-float
// tensors and empty shapes are left untouched; tensors with rank < 4 are
// expanded to 4-D by prepending 1s first. The converted data is attached to
// a new Parameter that replaces the original input via the graph manager.
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
                          bool train_flag, FormatTransNodeType trans_type) {
  MS_ASSERT(cnode != nullptr);
  // Outputs of other cnodes are handled by transpose insertion, not here.
  if (utils::isa<CNodePtr>(cnode->input(index))) {
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status = 0;
  if (utils::isa<ParameterPtr>(cnode->input(index))) {
    auto input_node = cnode->input(index)->cast<ParameterPtr>();
    MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
    // Parameters without default data are graph inputs; nothing to convert.
    if (!input_node->has_default()) {
      return lite::RET_OK;
    }
    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
  } else {
    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
  }
  if (status != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  // Only float tensors with a known shape participate in layout conversion.
  if (data_info.shape_.empty() ||
      (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
    return lite::RET_OK;
  }
  // Expand low-rank shapes to 4-D so the 4-D filter transform applies uniformly.
  ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
  if (data_info.shape_.size() == 1) {
    expand_shape = {1, 1, 1, data_info.shape_[0]};
  } else if (data_info.shape_.size() == kInputSizeTwo) {
    expand_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
  } else if (data_info.shape_.size() == kInputSizeThree) {
    expand_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
  }
  auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
                                                 data_info.data_.data(), data_info.data_.size());
  MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
  if (trans_type == kNHWC2NCHW) {
    (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
  } else {
    (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
  }
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
  param_node->set_name(cnode->input(index)->fullname_with_scope());
  status = lite::InitParameterFromTensorInfo(param_node, tensor);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return lite::RET_ERROR;
  }
  // Rewire cnode's input to the converted parameter.
  auto tr = func_graph->manager()->Transact();
  tr.SetEdge(cnode, index, param_node);
  tr.Commit();
  return lite::RET_OK;
}
203 } // namespace
204
// If `cnode` is a transpose, scan its users: any user transpose whose perm is
// the inverse (NH2NC vs NC2NH) cancels with cnode, so that user is replaced by
// cnode's data input. Non-transpose cnodes are skipped with RET_OK.
STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
    return lite::RET_OK;
  }
  std::vector<int> cur_perm;
  if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get transpose perm failed.";
    return lite::RET_ERROR;
  }
  auto node_users = func_graph->manager()->node_users()[cnode];
  for (auto &node_user : node_users) {
    auto post_node = node_user.first;
    if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
      std::vector<int> post_trans_perm;
      auto post_trans_node = post_node->cast<CNodePtr>();
      MS_ASSERT(post_trans_node != nullptr);
      if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
        MS_LOG(ERROR) << "get post transpose node perm failed.";
        return lite::RET_ERROR;
      }
      // transpose(NH2NC) followed by transpose(NC2NH) (or the reverse) is an
      // identity: bypass both by wiring the post node's users to cnode's input.
      if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
        func_graph->manager()->Replace(post_node, cnode->input(1));
      }
    }
  }
  return lite::RET_OK;
}
233
// Creates a transpose with permutation `perm` around `cnode` and wires it in.
// before==true: the transpose becomes input `index` of cnode.
// before==false: the transpose replaces cnode for all its users (index unused),
// and inverse post-transpose pairs are fused away afterwards.
// TransposePairFuseWhenInsert may return cnode itself or the original input
// when insertion would cancel with an existing transpose; that case is a no-op.
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          const std::vector<int> perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  // Fusion decided nothing needs to be inserted: graph unchanged.
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    MS_ASSERT(new_cnode_input != nullptr);
    int status = lite::RET_OK;
    // A freshly created transpose needs its output shape inferred.
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  auto tr = manager->Transact();
  if (before) {
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    func_graph->manager()->Replace(cnode, new_input);
    // The inserted post transpose may in turn cancel with a following one.
    if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
274
// Inserts a transpose with `perm` in front of selected inputs of `cnode`.
// Which inputs to wrap comes from the NHWC/NCHW op tables; when the table
// entry is empty every data input is wrapped, except for ResizeGrad with
// method NEAREST, which only needs its first input transposed.
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                 const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
      specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
    MS_LOG(ERROR) << "op don't meet nhwc condition.";
    return lite::RET_ERROR;
  }
  // Prefer the NCHW table entry when present, otherwise use the NHWC entry.
  std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
                                       ? specify_nhwc_op_map.at(prim->name())
                                       : specify_nchw_op_map.at(prim->name());
  if (insert_index.empty()) {
    if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index.push_back(1);
    } else {
      // Empty entry means: wrap every data input.
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index.push_back(i);
      }
    }
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
310
// Uses TransposeStrategy to decide whether wrapping `cnode` with transposes
// would let neighbouring transposes fuse away; if so, changes the op's axes
// to the new layout, inserts the pre-transposes (descending into MakeTuple
// inputs), records the new format on the primitive and re-infers the shape.
// Returns RET_NO_CHANGE when the op should be left as is.
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                 TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_insert_info != nullptr);
  TransTypePair trans_info;
  auto origin_inputs = cnode->inputs();
  // Temporarily strip MakeTuple/monad inputs so the strategy only sees data inputs.
  lite::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NO_CHANGE;
  }
  cnode->set_inputs(origin_inputs);  // restore before mutating the node
  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
  if (status == lite::RET_NOT_SUPPORT) {
    return lite::RET_NO_CHANGE;
  } else if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "change op attr failed.";
    return lite::RET_ERROR;
  }
  auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (IsMonadNode(cnode->input(i))) {
      continue;
    }
    // For tuple inputs the transposes go in front of each tuple element.
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
        CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
      auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
      MS_ASSERT(input_make_tuple != nullptr);
      for (size_t j = 1; j < input_make_tuple->size(); ++j) {
        if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
          MS_LOG(ERROR) << "generate a new input failed.";
          return lite::RET_ERROR;
        }
      }
      continue;
    }
    if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
    return lite::RET_ERROR;
  }
  status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
365
// Inserts a transpose with `perm` behind `cnode`. For single-output nodes the
// transpose replaces cnode for its users. For tuple outputs the transpose is
// inserted behind each TupleGetItem user; in training mode a missing
// TupleGetItem (output 0) is created on the fly and removed again after the
// transpose has been wired in.
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                  const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto node_users = func_graph->manager()->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          // Tuple outputs must be consumed via TupleGetItem in inference mode.
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
          MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
          post_node = tuple_get_item;
          func_graph->manager()->Replace(cnode, tuple_get_item);
        }
      }
      // Skip outputs nobody consumes.
      if (func_graph->manager()->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      MS_ASSERT(post_cnode != nullptr);
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      // The temporary TupleGetItem was only needed to anchor the transpose.
      if (tuple_get_item != nullptr) {
        func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}
406
// Starting from the transpose `cnode`, finds the area it bounds together with
// the matching boundary transposes. If every input transpose shares one
// direction, every output transpose the inverse, and every middle op can
// change its axes, the boundary transposes are removed and every middle op is
// converted to the other layout (weights transposed, axes changed, format
// attribute updated, shape re-inferred). Input transposes are recorded in
// visit_transposes so the caller does not reprocess them.
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                   std::set<CNodePtr> *visit_transposes) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager is nullptr");
  std::set<CNodePtr> middle_nodes{};
  std::set<CNodePtr> in_nodes{};
  std::set<CNodePtr> out_nodes{};
  auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &middle_nodes);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
    return status;
  }
  for (auto &in_cnode : in_nodes) {
    if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
      visit_transposes->insert(in_cnode);
    }
  }
  TransTypePair trans_info;
  if (!JudgeCanOptimizerForMultiOp(in_nodes, out_nodes, middle_nodes, &trans_info)) {
    return lite::RET_NO_CHANGE;
  }
  // Collect middle ops in topological order so producers are converted before users.
  auto node_list = TopoSort(func_graph->get_return());
  std::vector<CNodePtr> middle_ops_vec;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
      middle_ops_vec.push_back(node->cast<CNodePtr>());
      middle_nodes.erase(node->cast<CNodePtr>());
    }
  }
  // Drop the boundary transposes: users are rewired to the transposes' inputs.
  for (auto &in_cnode : in_nodes) {
    manager->Replace(in_cnode, in_cnode->input(1));
  }
  for (auto &out_cnode : out_nodes) {
    manager->Replace(out_cnode, out_cnode->input(1));
  }
  for (auto &middle_cnode : middle_ops_vec) {
    if (IsSpecialType(middle_cnode)) {
      continue;
    }
    // Convert constant inputs (weights) of the middle op to the new layout.
    for (size_t i = 1; i < middle_cnode->size(); ++i) {
      status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
      if (status != lite::RET_OK) {
        MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
        return lite::RET_ERROR;
      }
    }
    status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "change op attr failed.";
      return lite::RET_ERROR;
    }
    status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
      return lite::RET_ERROR;
    }
    status = node_infer_shape_.InferShape(middle_cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
475
SetSubGraphInput(const CNodePtr & cnode,const FuncGraphPtr & sub_graph)476 int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
477 MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
478 auto sub_inputs = sub_graph->get_inputs();
479 sub_inputs_map_[sub_graph] = sub_inputs;
480 for (auto &node : sub_inputs) {
481 auto param_node = node->cast<ParameterPtr>();
482 MS_ASSERT(param_node != nullptr);
483 auto node_name = node->fullname_with_scope();
484 auto last_underline = node_name.find_last_of("_");
485 node_name = node_name.substr(0, last_underline);
486 last_underline = node_name.find_last_of("_");
487 auto index = 0;
488 try {
489 index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
490 } catch (const std::exception &e) {
491 MS_LOG(ERROR) << "Get index failed: " << e.what();
492 return lite::RET_ERROR;
493 }
494 param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
495 if (utils::isa<CNodePtr>(cnode->input(index))) {
496 ShapeVector shape_vec = {-1};
497 auto out_cnode = cnode->input(index)->cast<CNodePtr>();
498 MS_ASSERT(out_cnode != nullptr);
499 MS_ASSERT(trans_cnode != nullptr);
500 auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
501 MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
502 if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
503 param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
504 }
505 } else {
506 lite::DataInfo data_info;
507 if (utils::isa<ParameterPtr>(cnode->input(index))) {
508 if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
509 param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
510 }
511 continue;
512 }
513 auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
514 if (status != lite::RET_OK) {
515 continue;
516 }
517 ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
518 if (data_info.data_.empty()) {
519 param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
520 } else {
521 param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
522 data_info.data_.data(), data_info.data_.size()));
523 }
524 }
525 }
526 return lite::RET_OK;
527 }
528
ResetSubGraphInput()529 int DecreaseTransposeAlgo::ResetSubGraphInput() {
530 for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
531 auto &sub_graph = iter->first;
532 auto &sub_inputs = iter->second;
533 auto manager = sub_graph->manager();
534 MS_ASSERT(manager != nullptr);
535 for (auto &sub_input : sub_inputs) {
536 auto param_node = sub_graph->add_parameter();
537 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
538 param_node->set_abstract(sub_input->abstract()->Clone());
539 param_node->set_name(sub_input->fullname_with_scope());
540 manager->Replace(sub_input, param_node);
541 auto sub_param_input = sub_input->cast<ParameterPtr>();
542 MS_ASSERT(sub_param_input != nullptr);
543 sub_param_input->set_default_param(nullptr);
544 }
545 }
546 return lite::RET_OK;
547 }
548
// After processing a sub-graph, renames the "*_post" transposes that feed the
// return node: the transpose takes over its input's name with a "_cnode"
// suffix and the input inherits the original "*_post" name, keeping output
// naming stable across the sub-graph boundary.
int DecreaseTransposeAlgo::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
  MS_ASSERT(sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_ASSERT(return_node != nullptr);
  auto origin_input = return_node->inputs();
  // Temporarily strip Depend/MakeTuple wrappers to reach the real outputs.
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  for (size_t i = 1; i < return_node->size(); ++i) {
    if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    // Only transposes whose name ends in "_post" are handled here.
    if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    auto trans_input_name = trans_input->fullname_with_scope();
    // Swap names: the input takes the "*_post" name ...
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    // ... and the transpose gets the input's base name with a "_cnode" suffix.
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  return_node->set_inputs(origin_input);  // restore the stripped wrappers
  return lite::RET_OK;
}
579
// Propagates the sub-graph's output abstracts back onto the control-flow cnode
// and records whether shape inference completed for every output via the
// kInferDone attribute: any -1 dimension, or a producer whose primitive lacks
// kInferDone, marks the result as not fully inferred.
int DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_CHECK_TRUE_MSG(return_node != nullptr, lite::RET_ERROR, "return_node is nullptr");
  auto origin_inputs = return_node->inputs();
  // Temporarily strip Depend/MakeTuple wrappers to reach the real outputs.
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  AbstractBasePtrList abstract_list;
  bool infer_done = true;
  for (size_t i = 1; i < return_node->size(); ++i) {
    auto abstract_base = GetCNodeInputAbstract(return_node, i);
    MS_CHECK_TRUE_MSG(abstract_base != nullptr, lite::RET_ERROR, "GetCNodeInputAbstract failed");
    abstract_list.emplace_back(abstract_base->Clone());
    auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
    MS_ASSERT(abstract_tensor != nullptr);
    auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
    MS_ASSERT(shape_ptr != nullptr);
    auto shape = shape_ptr->shape();
    // A dynamic (-1) dimension means inference is incomplete.
    if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
      infer_done = false;
    }
    if (utils::isa<CNodePtr>(return_node->input(i))) {
      auto input_cnode = return_node->input(i)->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr");
      // Look through TupleGetItem to the real producer.
      if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
        input_cnode = input_cnode->input(1)->cast<CNodePtr>();
      }
      auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      MS_CHECK_TRUE_MSG(input_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
      if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
        infer_done = false;
      }
    }
  }
  return_node->set_inputs(origin_inputs);  // restore the stripped wrappers
  // Tuple-output cnodes receive the whole list; single-output cnodes the sole element.
  if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  } else {
    if (abstract_list.size() != 1) {
      MS_LOG(ERROR) << "cnode output is invalid.";
    }
    cnode->set_abstract(abstract_list.front());
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
  prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));

  return lite::RET_OK;
}
629
ModifyCNodeFormat(const CNodePtr & cnode,FormatTransNodeType pre_trans_type)630 int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
631 MS_ASSERT(cnode != nullptr);
632 if (pre_trans_type == kNONE) {
633 return lite::RET_OK;
634 }
635 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
636 MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
637 if (pre_trans_type == kNHWC2NCHW) {
638 primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
639 } else {
640 primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
641 }
642 return lite::RET_OK;
643 }
644
// Per-op pass: for each op whose format can be switched, lets TransposeStrategy
// decide whether inserting a transpose pair around it would fuse with existing
// transposes (net decrease); if so, the op's axes/format are changed and the
// pre/post transposes inserted. If/While sub-graphs (inputs 1 and 2) are
// processed recursively, keeping their inputs/outputs/abstracts synchronized
// with the control-flow node.
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status = 0;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      // First sub-graph at input 1.
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      auto ret = SetSubGraphInput(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphInput failed";
        return false;
      }
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      ret = SetSubGraphOutput(sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphOutput failed";
        return false;
      }
      // Second sub-graph at input 2.
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      ret = SetSubGraphInput(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphInput failed";
        return false;
      }
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      ret = SetSubGraphOutput(sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphOutput failed";
        return false;
      }
      ret = SetSubGraphAbstract(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphAbstract failed";
        return false;
      }
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
    // Only ops whose format can be switched are candidates.
    if (!IsDynamicFormatOp(prim->name())) {
      continue;
    }
    TransTypePair trans_insert_info;
    status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
    if (status == lite::RET_NO_CHANGE) {
      continue;
    } else if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "insert pre node failed.";
      return false;
    }
    // Undo the layout change at the output so downstream ops see the old format.
    auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
    if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
      MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
      return false;
    }
  }
  return true;
}
724
// Area pass: each unvisited NH2NC transpose seeds an attempt to remove a whole
// transpose-bounded region at once (see HandleGraphMultiNode). If/While
// sub-graphs are recursed into first.
bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  std::set<CNodePtr> visit_transposes;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)DecreaseTransposeForMultiOp(sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)DecreaseTransposeForMultiOp(sub_func_graph);
    }
    // Only NH2NC transposes seed an area search; the matching NC2NH boundary
    // is discovered by the traversal itself.
    std::vector<int> perm{};
    if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
        perm != kNH2NC) {
      continue;
    }
    auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "global optimizer failed.";
      return false;
    }
  }
  return true;
}
768
RunDoFixFormat(const FuncGraphPtr & func_graph,const CNodePtr & cnode)769 bool DecreaseTransposeAlgo::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
770 auto prim_node = cnode->input(0);
771 auto prim = GetValueNode<PrimitivePtr>(prim_node);
772 MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
773 auto &nchw_op = GetNCHWOpMap();
774 if (!utils::isa<CNodePtr>(cnode->input(1))) {
775 return true;
776 }
777 if (utils::isa<CNodePtr>(cnode->input(1))) {
778 auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
779 if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
780 InsertPreTransNode(func_graph, cnode, kNH2NC);
781 InsertPostTransNode(func_graph, cnode, kNC2NH);
782 }
783 }
784 return true;
785 }
786
DoFixFormat(const FuncGraphPtr & func_graph)787 bool DecreaseTransposeAlgo::DoFixFormat(const FuncGraphPtr &func_graph) {
788 auto node_list = TopoSort(func_graph->get_return());
789 for (auto &node : node_list) {
790 if (!utils::isa<CNodePtr>(node)) {
791 continue;
792 }
793 auto cnode = node->cast<CNodePtr>();
794 MS_ASSERT(cnode != nullptr);
795 if (IsSpecialType(cnode)) {
796 continue;
797 }
798 if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
799 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
800 if (sub_func_graph == nullptr) {
801 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
802 return false;
803 }
804 SetSubGraphInput(cnode, sub_func_graph);
805 if (!DoFixFormat(sub_func_graph)) {
806 MS_LOG(ERROR) << "subgraph infer shape failed.";
807 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
808 return false;
809 }
810 SetSubGraphOutput(sub_func_graph);
811
812 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
813 if (sub_func_graph == nullptr) {
814 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
815 return false;
816 }
817 SetSubGraphInput(cnode, sub_func_graph);
818 if (!DoFixFormat(sub_func_graph)) {
819 MS_LOG(ERROR) << "subgraph infer shape failed.";
820 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
821 return false;
822 }
823 SetSubGraphOutput(sub_func_graph);
824 SetSubGraphAbstract(cnode, sub_func_graph);
825 continue;
826 }
827 if (!RunDoFixFormat(func_graph, cnode)) {
828 return false;
829 }
830 }
831 return true;
832 }
833
Run(const FuncGraphPtr & func_graph)834 bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
835 MS_ASSERT(func_graph != nullptr);
836 node_infer_shape_.Init(fmk_type_, train_flag_);
837 transpose_strategy_.Init(fmk_type_, train_flag_);
838 if (!delete_redundant_transpose_.Run(func_graph)) {
839 MS_LOG(ERROR) << "Run delete-redundant-transpose pass failed.";
840 return false;
841 }
842 auto node_list = TopoSort(func_graph->get_return());
843 for (auto &node : node_list) {
844 auto prim = GetValueNode<PrimitivePtr>(node);
845 if (prim == nullptr) {
846 continue;
847 }
848 }
849
850 if (!DoFixFormat(func_graph)) {
851 MS_LOG(ERROR) << "DoFixFormat failed.";
852 return false;
853 }
854 ResetSubGraphInput();
855
856 if (!DecreaseTransposeForSingleOp(func_graph)) {
857 MS_LOG(ERROR) << "run local trans insert optimizer failed.";
858 return false;
859 }
860
861 auto ret = ResetSubGraphInput();
862 if (ret != lite::RET_OK) {
863 MS_LOG(ERROR) << "ResetSubGraphInput failed.";
864 return false;
865 }
866 // if input format of several ops surrounded only by transpose op all can be NHWC,
867 // we can delete these transpose ops, and at the same time, transform these middle ops.
868 if (!DecreaseTransposeForMultiOp(func_graph)) {
869 MS_LOG(ERROR) << "run global trans insert optimizer failed.";
870 return false;
871 }
872 return true;
873 }
874 } // namespace opt
875 } // namespace mindspore
876