1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/decrease_transpose_algo.h"
19 #include <queue>
20 #include <set>
21 #include <unordered_map>
22 #include <utility>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "ops/op_utils.h"
28 #include "src/common/common.h"
29 #include "src/common/utils.h"
30 #include "tools/common/tensor_util.h"
31 #include "nnacl/op_base.h"
32 #include "tools/optimizer/graph/specify_graph_input_format.h"
33
34 namespace mindspore {
35 namespace opt {
36 namespace {
__anon9066e5770202(const CNodePtr &cnode) 37 std::function<bool(const CNodePtr &)> check_node = [](const CNodePtr &cnode) {
38 if (IsSpecialType(cnode) || CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
39 return true;
40 }
41 if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
42 // if concat's pre node is MakeTuple, the shape is complex, not optimize this case
43 if (cnode->size() > 1) {
44 auto tuple_node = cnode->input(1)->cast<CNodePtr>();
45 if (CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
46 for (size_t i = 1; i < tuple_node->size(); ++i) {
47 auto input_node = tuple_node->input(i);
48 if (utils::isa<ValueNodePtr>(input_node) || utils::isa<Parameter>(input_node)) {
49 return false;
50 }
51 }
52 }
53 }
54 }
55 TransposeStrategy transpose_strategy;
56 return transpose_strategy.CanChangeOpAxis(cnode);
57 };
58
DealWithInputNodes(const FuncGraphPtr & func_graph,const CNodePtr & cur_node,std::set<CNodePtr> * in_nodes,AnfNodeIndexSet * not_trans_in_nodes,std::set<CNodePtr> * middle_nodes,std::queue<CNodePtr> * queue_nodes,std::queue<bool> * is_pre_nodes)59 void DealWithInputNodes(const FuncGraphPtr &func_graph, const CNodePtr &cur_node, std::set<CNodePtr> *in_nodes,
60 AnfNodeIndexSet *not_trans_in_nodes, std::set<CNodePtr> *middle_nodes,
61 std::queue<CNodePtr> *queue_nodes, std::queue<bool> *is_pre_nodes) {
62 MS_ASSERT(func_graph != nullptr && cur_node != nullptr);
63 MS_ASSERT(out_nodes != nullptr && not_trans_out_nodes != nullptr && middle_nodes != nullptr);
64 MS_ASSERT(queue_nodes != nullptr && is_pre_nodes != nullptr);
65 for (size_t i = 1; i < cur_node->size(); ++i) {
66 if (!utils::isa<CNodePtr>(cur_node->input(i))) {
67 continue;
68 }
69 auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
70 MS_ASSERT(cur_node_input != nullptr);
71 if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
72 in_nodes->find(cur_node_input) != in_nodes->end()) {
73 continue;
74 }
75 if (!check_node(cur_node_input)) {
76 (void)not_trans_in_nodes->insert(std::make_pair(cur_node, i));
77 continue;
78 }
79 queue_nodes->push(cur_node_input);
80 is_pre_nodes->push(true);
81 }
82 }
83
DealWithOutputNodes(const FuncGraphPtr & func_graph,const CNodePtr & cur_node,std::set<CNodePtr> * out_nodes,AnfNodeIndexSet * not_trans_out_nodes,std::set<CNodePtr> * middle_nodes,std::queue<CNodePtr> * queue_nodes,std::queue<bool> * is_pre_nodes)84 STATUS DealWithOutputNodes(const FuncGraphPtr &func_graph, const CNodePtr &cur_node, std::set<CNodePtr> *out_nodes,
85 AnfNodeIndexSet *not_trans_out_nodes, std::set<CNodePtr> *middle_nodes,
86 std::queue<CNodePtr> *queue_nodes, std::queue<bool> *is_pre_nodes) {
87 MS_ASSERT(func_graph != nullptr && cur_node != nullptr);
88 MS_ASSERT(out_nodes != nullptr && not_trans_out_nodes != nullptr && middle_nodes != nullptr);
89 MS_ASSERT(queue_nodes != nullptr && is_pre_nodes != nullptr);
90 auto cur_node_users = func_graph->manager()->node_users()[cur_node];
91 if (cur_node_users.empty()) {
92 (void)out_nodes->insert(cur_node);
93 return lite::RET_OK;
94 }
95 for (auto &cur_node_user : cur_node_users) {
96 MS_CHECK_TRUE_RET(utils::isa<CNodePtr>(cur_node_user.first), lite::RET_ERROR);
97 auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
98 MS_CHECK_TRUE_RET(cur_node_post != nullptr, lite::RET_ERROR);
99 if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
100 out_nodes->find(cur_node_post) != out_nodes->end()) {
101 continue;
102 }
103 if (!check_node(cur_node_post)) {
104 (void)not_trans_out_nodes->insert(cur_node_user);
105 continue;
106 }
107 queue_nodes->push(cur_node_post);
108 is_pre_nodes->push(false);
109 }
110 return lite::RET_OK;
111 }
112
// BFS from |root_node| to delimit a connected area whose boundaries are ideally
// Transpose ops: producers land in |in_nodes|, consumers in |out_nodes|, and the
// interior in |middle_nodes|. Boundary edges whose op cannot change its axis are
// recorded (node, input-index) in |not_trans_in_nodes| / |not_trans_out_nodes|.
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
                                     std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
                                     AnfNodeIndexSet *not_trans_in_nodes, AnfNodeIndexSet *not_trans_out_nodes,
                                     std::set<CNodePtr> *middle_nodes) {
  MS_ASSERT(func_graph != nullptr && root_node != nullptr);
  MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
  MS_ASSERT(not_trans_in_nodes != nullptr && not_trans_out_nodes != nullptr);
  std::queue<CNodePtr> queue_nodes{};
  queue_nodes.push(root_node);
  // Parallel queue: true when the node was reached while walking towards inputs,
  // false when reached while walking towards outputs.
  std::queue<bool> is_pre_nodes;
  is_pre_nodes.push(true);
  while (!queue_nodes.empty()) {
    auto cur_node = queue_nodes.front();
    auto is_pre_node = is_pre_nodes.front();
    queue_nodes.pop();
    is_pre_nodes.pop();
    if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
      if (is_pre_node) {
        // Transpose met on the input side bounds the area; its users are still
        // explored below (no continue).
        (void)in_nodes->insert(cur_node);
      } else {
        // Transpose met on the output side closes this branch entirely.
        (void)out_nodes->insert(cur_node);
        continue;
      }
    }
    if (middle_nodes->find(cur_node) != middle_nodes->end()) {
      continue;
    }
    if (in_nodes->find(cur_node) == in_nodes->end()) {
      middle_nodes->insert(cur_node);
      // insert pre nodes.
      auto origin_inputs = cur_node->inputs();
      lite::RemoveIfDepend(cur_node);
      if (CheckIsAllInputsParam(cur_node)) {
        // Every input is a parameter/const: the node itself is an input boundary.
        in_nodes->insert(cur_node);
      } else {
        DealWithInputNodes(func_graph, cur_node, in_nodes, not_trans_in_nodes, middle_nodes, &queue_nodes,
                           &is_pre_nodes);
      }
      // Restore the inputs stripped by RemoveIfDepend.
      cur_node->set_inputs(origin_inputs);
    }
    // insert post nodes
    if (DealWithOutputNodes(func_graph, cur_node, out_nodes, not_trans_out_nodes, middle_nodes, &queue_nodes,
                            &is_pre_nodes) != lite::RET_OK) {
      MS_LOG(ERROR) << "deal with output failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
162
SetTransType(const std::set<CNodePtr> & cnodes,FormatTransNodeType * trans_type)163 void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
164 MS_ASSERT(trans_type != nullptr);
165 FormatTransNodeType local_trans_type;
166 for (auto &cnode : cnodes) {
167 std::vector<int> perm;
168 if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
169 (perm != kNH2NC && perm != kNC2NH)) {
170 *trans_type = kNONE;
171 return;
172 }
173 local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
174 *trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
175 if (*trans_type != local_trans_type) {
176 *trans_type = kNONE;
177 return;
178 }
179 }
180 }
181
JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> & in_nodes,const std::set<CNodePtr> & out_nodes,const AnfNodeIndexSet & not_trans_in_nodes,const AnfNodeIndexSet & not_trans_out_nodes,const std::set<CNodePtr> & middle_nodes,TransTypePair * trans_info,bool all_trans)182 bool JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> &in_nodes, const std::set<CNodePtr> &out_nodes,
183 const AnfNodeIndexSet ¬_trans_in_nodes, const AnfNodeIndexSet ¬_trans_out_nodes,
184 const std::set<CNodePtr> &middle_nodes, TransTypePair *trans_info, bool all_trans) {
185 MS_ASSERT(trans_info != nullptr);
186 if (all_trans && (!not_trans_in_nodes.empty() || !not_trans_out_nodes.empty())) {
187 return false;
188 }
189 if (not_trans_in_nodes.size() + not_trans_out_nodes.size() >= in_nodes.size() + out_nodes.size()) {
190 return false;
191 }
192 SetTransType(in_nodes, &trans_info->pre_);
193 if (trans_info->pre_ == kNONE) {
194 return false;
195 }
196 SetTransType(out_nodes, &trans_info->post_);
197 if (trans_info->post_ == kNONE) {
198 return false;
199 }
200 if (trans_info->pre_ == trans_info->post_) {
201 return false;
202 }
203 return true;
204 }
205
// Permute a const input (Parameter with default, or ValueNode) of |cnode| at
// |index| between NHWC- and NCHW-ordered layouts in place, so no runtime
// Transpose is required for it. CNode inputs, dynamic/scalar-like shapes and
// non-fp32 tensors are left untouched. Returns lite::RET_OK / lite::RET_ERROR.
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
                          bool train_flag, FormatTransNodeType trans_type) {
  MS_ASSERT(cnode != nullptr);
  if (utils::isa<CNodePtr>(cnode->input(index))) {
    // Runtime tensors are handled by inserting transpose nodes, not here.
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status = 0;
  if (utils::isa<ParameterPtr>(cnode->input(index))) {
    auto input_node = cnode->input(index)->cast<ParameterPtr>();
    MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
    if (!input_node->has_default()) {
      // Graph input without const data: nothing to permute.
      return lite::RET_OK;
    }
    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, &data_info, true);
  } else {
    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info, true);
  }
  if (status != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  // Skip dynamic shapes, empty tensors, all-ones shapes (layout-invariant) and
  // anything that is not fp32 data.
  if (lite::JudgeDynamicShape(data_info.shape_) || (data_info.shape_.size() == 1 && data_info.shape_[0] == 0) ||
      std::all_of(data_info.shape_.begin(), data_info.shape_.end(), [](int val) { return val == 1; }) ||
      (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
    return lite::RET_OK;
  }
  ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
  // ScaleFusion keeps its original rank; other ops get leading 1s up to 4D so the
  // KHWC<->KCHW filter transform below applies.
  bool need_expand = !CheckPrimitiveType(cnode, prim::kPrimScaleFusion);
  if (need_expand && expand_shape.size() <= DIMENSION_4D) {
    ShapeVector tmp_shape(DIMENSION_4D - expand_shape.size(), 1);
    (void)expand_shape.insert(expand_shape.begin(), tmp_shape.begin(), tmp_shape.end());
  }
  auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
                                                 data_info.data_.data(), data_info.data_.size());
  MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
  if (trans_type == kNHWC2NCHW) {
    (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
  } else {
    (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
  }
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
  param_node->set_name(cnode->input(index)->fullname_with_scope());
  status = lite::InitParameterFromTensorInfo(param_node, tensor);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return lite::RET_ERROR;
  }
  // Rewire the consuming edge to the new permuted parameter.
  auto tr = func_graph->manager()->Transact();
  tr.SetEdge(cnode, index, param_node);
  tr.Commit();
  return lite::RET_OK;
}
259 } // namespace
260
PostTransposeFusion(const FuncGraphPtr & func_graph,const CNodePtr & cnode)261 STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
262 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
263 if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
264 return lite::RET_OK;
265 }
266 std::vector<int> cur_perm;
267 if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
268 MS_LOG(ERROR) << "get transpose perm failed.";
269 return lite::RET_ERROR;
270 }
271 auto manager = func_graph->manager();
272 MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
273 MS_CHECK_TRUE_RET(manager->node_users().find(cnode) != manager->node_users().end(), lite::RET_ERROR);
274 auto node_users = manager->node_users()[cnode];
275 for (auto &node_user : node_users) {
276 auto post_node = node_user.first;
277 if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
278 std::vector<int> post_trans_perm;
279 auto post_trans_node = post_node->cast<CNodePtr>();
280 MS_ASSERT(post_trans_node != nullptr);
281 if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
282 MS_LOG(ERROR) << "get post transpose node perm failed.";
283 return lite::RET_ERROR;
284 }
285 if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
286 (void)manager->Replace(post_node, cnode->input(1));
287 }
288 }
289 }
290 return lite::RET_OK;
291 }
292
// Insert (or fuse away) a Transpose with |perm| either before input |index| of
// |cnode| (before == true) or after |cnode| (before == false).
// TransposePairFuseWhenInsert may return an existing node when the insertion
// cancels out with a neighbouring transpose.
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          const std::vector<int> perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  if (new_input == cnode->input(index) || new_input == cnode) {
    // Insertion cancelled out: the graph is already in the desired form.
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    MS_ASSERT(new_cnode_input != nullptr);
    int status = lite::RET_OK;
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      // A freshly created transpose needs its output shape inferred.
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  auto tr = manager->Transact();
  if (before) {
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    // Post-insertion: new_input replaces cnode for all its users, then try to
    // fuse it with any inverse transpose that immediately follows.
    (void)manager->Replace(cnode, new_input);
    if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
333
InsertPreTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,TransTypePair * trans_insert_info)334 STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
335 TransTypePair *trans_insert_info) {
336 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
337 MS_ASSERT(trans_insert_info != nullptr);
338 TransTypePair trans_info;
339 auto origin_inputs = cnode->inputs();
340 lite::RemoveIfMakeTuple(cnode);
341 RemoveIfMonad(cnode);
342 if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
343 cnode->set_inputs(origin_inputs);
344 return lite::RET_NO_CHANGE;
345 }
346 cnode->set_inputs(origin_inputs);
347 auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
348 if (status == lite::RET_NOT_SUPPORT) {
349 return lite::RET_NO_CHANGE;
350 } else if (status != lite::RET_OK) {
351 MS_LOG(ERROR) << "change op attr failed.";
352 return lite::RET_ERROR;
353 }
354 status = DoPreInsert(func_graph, cnode, trans_insert_info->pre_);
355 if (status != lite::RET_OK) {
356 MS_LOG(ERROR) << "Do pre insert failed.";
357 return lite::RET_ERROR;
358 }
359 status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
360 if (status != lite::RET_OK) {
361 MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
362 return lite::RET_ERROR;
363 }
364 status = node_infer_shape_.InferShape(cnode);
365 if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
366 MS_LOG(ERROR) << "infer shape failed.";
367 return lite::RET_ERROR;
368 }
369 return lite::RET_OK;
370 }
371
// Insert the pre-transpose for every input of |cnode|: CNode inputs that need a
// new transpose get one via GenNewInput; const inputs are permuted in place via
// ConvertTensorToNCOrNH. MakeTuple inputs are handled element-wise; monads are
// skipped.
STATUS DecreaseTransposeAlgo::DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          FormatTransNodeType trans_type) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto abstract = cnode->abstract();
  MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
  if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
    // For multi-output nodes, use the first element's abstract.
    auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
    auto abstract_list = abstract_tuple->elements();
    MS_CHECK_TRUE_RET(!abstract_list.empty(), lite::RET_OUT_OF_TENSOR_RANGE);
    abstract = abstract_list.front();
    MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
  }
  // Handle one (node, input index) edge: new transpose node vs. in-place const permute.
  auto HandleFunc = [this](const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index,
                           FormatTransNodeType trans_type) -> STATUS {
    STATUS ret = lite::RET_OK;
    // The inserted transpose is the inverse direction of the node's target layout.
    auto before_perm = trans_type == kNHWC2NCHW ? kNH2NC : kNC2NH;
    if (IsNeedGenNewInput(func_graph, cnode, index)) {
      ret = GenNewInput(func_graph, cnode, before_perm, true, index);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "GenNewInput failed";
      }
    } else {
      ret = ConvertTensorToNCOrNH(func_graph, cnode, index, fmk_type_, train_flag_, trans_type);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "ConvertTensorToNCOrNH faileded";
      }
    }
    return ret;
  };
  for (size_t i = 1; i < cnode->size(); ++i) {
    MS_CHECK_TRUE_RET(cnode->input(i) != nullptr, lite::RET_NULL_PTR);
    if (IsMonadNode(cnode->input(i))) {
      continue;
    }
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
        CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTupleV2)) {
      // Descend into the tuple and handle each element edge individually.
      auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
      MS_ASSERT(input_make_tuple != nullptr);
      for (size_t j = 1; j < input_make_tuple->size(); ++j) {
        MS_CHECK_TRUE_RET(input_make_tuple->input(j) != nullptr, lite::RET_NULL_PTR);
        if (HandleFunc(func_graph, input_make_tuple, j, trans_type) != lite::RET_OK) {
          MS_LOG(ERROR) << "handle pre insert failed.";
          return lite::RET_ERROR;
        }
      }
      continue;
    }
    if (HandleFunc(func_graph, cnode, i, trans_type) != lite::RET_OK) {
      MS_LOG(ERROR) << "handle pre insert failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
426
// Insert a Transpose with |perm| after |cnode|. Single-output nodes get the
// transpose directly; tuple-output nodes get one transpose after each
// TupleGetItem user (creating a temporary TupleGetItem in train mode when a user
// consumes the tuple directly).
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                  const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
  auto abstract = cnode->abstract();
  MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_ERROR);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    MS_CHECK_TRUE_RET(manager->node_users().find(cnode) != manager->node_users().end(), lite::RET_ERROR);
    auto node_users = manager->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          // In inference mode a tuple output must be consumed via TupleGetItem.
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          // Train mode: materialize a TupleGetItem(cnode, 0) to hang the transpose on.
          tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
          MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
          post_node = tuple_get_item;
          (void)manager->Replace(cnode, tuple_get_item);
        }
      }
      MS_CHECK_TRUE_RET(manager->node_users().find(post_node) != manager->node_users().end(), lite::RET_ERROR);
      if (manager->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      MS_ASSERT(post_cnode != nullptr);
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        // Drop the temporary TupleGetItem now that the transpose is attached.
        (void)manager->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}
473
HandleGraphSingleNode(const FuncGraphPtr & func_graph,const TransTypePair & trans_info,const CNodePtr & cnode)474 STATUS DecreaseTransposeAlgo::HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info,
475 const CNodePtr &cnode) {
476 for (size_t i = 1; i < cnode->size(); ++i) {
477 auto status = ConvertTensorToNCOrNH(func_graph, cnode, i, fmk_type_, train_flag_, trans_info.post_);
478 if (status != lite::RET_OK) {
479 MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
480 return lite::RET_ERROR;
481 }
482 }
483 auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_info.post_);
484 if (status != lite::RET_OK) {
485 MS_LOG(ERROR) << "change op attr failed.";
486 return lite::RET_ERROR;
487 }
488 status = ModifyCNodeFormat(cnode, trans_info.post_);
489 if (status != lite::RET_OK) {
490 MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
491 return lite::RET_ERROR;
492 }
493 status = node_infer_shape_.InferShape(cnode);
494 if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
495 MS_LOG(ERROR) << "infer shape failed.";
496 return lite::RET_ERROR;
497 }
498 return RET_OK;
499 }
500
// For boundary edges whose op could not change its axis, insert explicit
// transposes so those consumers still see the original layout. The recursive
// lambda forwards through TupleGetItem users and MakeTuple producers down to
// plain cnode edges, where the actual transpose is generated.
int DecreaseTransposeAlgo::InsertPreTransForNonTransInOut(const FuncGraphPtr &func_graph,
                                                          const AnfNodeIndexSet &not_trans_in_nodes,
                                                          const AnfNodeIndexSet &not_trans_out_nodes,
                                                          TransTypePair trans_info) {
  std::function<bool(const AnfNodePtr &, size_t, FormatTransNodeType)> insert_pre_trans =
    [&](const AnfNodePtr &node, size_t index, FormatTransNodeType format_trans_type) -> bool {
    MS_CHECK_TRUE_RET(node != nullptr, false);
    auto cnode = node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, false);
    if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
      // Push the transpose past TupleGetItem onto each of its users.
      auto node_users = func_graph->manager()->node_users()[node];
      return std::all_of(node_users.begin(), node_users.end(),
                         [&insert_pre_trans, &format_trans_type](const std::pair<AnfNodePtr, int> &pair) {
                           return insert_pre_trans(pair.first, pair.second, format_trans_type);
                         });
    } else if (CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTuple) ||
               CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTupleV2)) {
      // Handle every element feeding the MakeTuple individually.
      auto make_tuple_cnode = cnode->input(1)->cast<CNodePtr>();
      MS_CHECK_TRUE_RET(make_tuple_cnode != nullptr, false);
      for (size_t i = 0; i < make_tuple_cnode->size(); i++) {
        if (!insert_pre_trans(make_tuple_cnode->input(i), i, format_trans_type)) {
          return false;
        }
      }
      return true;
    }
    // Base case: insert the inverse transpose on this (cnode, index) edge.
    auto perm = format_trans_type == kNHWC2NCHW ? kNC2NH : kNH2NC;
    return GenNewInput(func_graph, cnode, perm, true, index) == lite::RET_OK;
  };
  bool deal_inputs = std::all_of(not_trans_in_nodes.begin(), not_trans_in_nodes.end(),
                                 [&insert_pre_trans, &trans_info](const std::pair<AnfNodePtr, int> &pair) {
                                   return insert_pre_trans(pair.first, pair.second, trans_info.pre_);
                                 });
  if (!deal_inputs) {
    MS_LOG(ERROR) << "Insert transpose for inputs failed.";
    return lite::RET_ERROR;
  }
  bool deal_outputs = std::all_of(not_trans_out_nodes.begin(), not_trans_out_nodes.end(),
                                  [&insert_pre_trans, &trans_info](const std::pair<AnfNodePtr, int> &pair) {
                                    return insert_pre_trans(pair.first, pair.second, trans_info.post_);
                                  });
  if (!deal_outputs) {
    MS_LOG(ERROR) << "Insert transpose for outputs failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
548
// Optimize a whole area surrounded by transposes starting from |cnode|: find the
// area, check the boundary transposes are uniform and opposite, remove them,
// patch the remaining non-transpose boundary edges, and re-adapt every middle op
// to the new layout. Visited boundary transposes are recorded so the caller
// skips them.
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                   std::set<CNodePtr> *visit_transposes) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager is nullptr");
  std::set<CNodePtr> middle_nodes{};
  std::set<CNodePtr> in_nodes{};
  std::set<CNodePtr> out_nodes{};
  AnfNodeIndexSet not_trans_in_nodes;
  AnfNodeIndexSet not_trans_out_nodes;
  auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &not_trans_in_nodes,
                                              &not_trans_out_nodes, &middle_nodes);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
    return status;
  }
  // Mark boundary transposes as visited regardless of whether we optimize.
  for (auto &in_cnode : in_nodes) {
    if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
      visit_transposes->insert(in_cnode);
    }
  }
  TransTypePair trans_info;
  if (!JudgeCanOptimizerForMultiOp(in_nodes, out_nodes, not_trans_in_nodes, not_trans_out_nodes, middle_nodes,
                                   &trans_info, (train_flag_ || all_trans_))) {
    return lite::RET_NO_CHANGE;
  }
  // Snapshot the middle nodes in topological order before the graph is mutated.
  auto node_list = TopoSort(func_graph->get_return());
  std::vector<CNodePtr> middle_ops_vec;
  for (auto &node : node_list) {
    if (utils::isa<CNodePtr>(node) && middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
      middle_ops_vec.push_back(node->cast<CNodePtr>());
      middle_nodes.erase(node->cast<CNodePtr>());
    }
  }
  // Drop the surrounding transposes: each is replaced by its own data input.
  (void)std::for_each(in_nodes.begin(), in_nodes.end(),
                      [&manager](const CNodePtr &in_cnode) { (void)manager->Replace(in_cnode, in_cnode->input(1)); });
  (void)std::for_each(out_nodes.begin(), out_nodes.end(),
                      [&manager](const CNodePtr &out_node) { (void)manager->Replace(out_node, out_node->input(1)); });

  if (InsertPreTransForNonTransInOut(func_graph, not_trans_in_nodes, not_trans_out_nodes, trans_info) != lite::RET_OK) {
    MS_LOG(ERROR) << "Insert pre transpose failed for non-transpose inputs or outputs.";
    return lite::RET_ERROR;
  }

  for (auto &middle_cnode : middle_ops_vec) {
    if (IsSpecialType(middle_cnode)) {
      continue;
    }
    status = HandleGraphSingleNode(func_graph, trans_info, middle_cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Decrease transpose failed for op: " << middle_cnode->fullname_with_scope();
      return status;
    }
  }
  return lite::RET_OK;
}
605
// Propagate abstracts and const data from the caller |cnode| into |sub_graph|'s
// input parameters. The caller-side input index is recovered from the trailing
// number in the parameter's name (offset by kInputSizeThree); the original
// inputs are remembered in sub_inputs_map_ for ResetSubGraphInput.
int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto sub_inputs = sub_graph->get_inputs();
  sub_inputs_map_[sub_graph] = sub_inputs;
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    auto node_name = node->fullname_with_scope();
    // Strip two trailing "_<suffix>" segments, then parse the remaining trailing
    // number as the caller input index.
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    size_t index = 0;
    try {
      index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Get index failed: " << e.what();
      return lite::RET_ERROR;
    }
    auto abstract = GetCNodeInputAbstract(cnode, index);
    MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
    param_node->set_abstract(abstract->Clone());
    if (utils::isa<CNodePtr>(cnode->input(index))) {
      // Dynamic producer: unless its shape is already inferred, mark the
      // parameter shape as unknown (-1).
      ShapeVector shape_vec = {-1};
      bool has_inferred{false};
      auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
      if (!has_inferred) {
        auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
        MS_CHECK_TRUE_MSG(abstract_shape != nullptr, RET_ERROR, "create shape failed.");
        param_node->abstract()->set_shape(abstract_shape);
      }
    }
    if (utils::isa<Parameter>(cnode->input(index))) {
      // Share the caller parameter's const data.
      param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
    }
    if (utils::isa<ValueNode>(cnode->input(index))) {
      // Materialize the caller's value node as this parameter's default tensor;
      // fetch failures are tolerated (best effort).
      lite::DataInfo data_info;
      auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
      if (status != lite::RET_OK) {
        continue;
      }
      ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
      auto tensor_info = lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec,
                                                static_cast<TypeId>(data_info.data_type_));
      MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create a tensor failed.");
      param_node->set_default_param(tensor_info);
    }
  }
  return lite::RET_OK;
}
656
ResetSubGraphInput()657 int DecreaseTransposeAlgo::ResetSubGraphInput() {
658 for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
659 auto &sub_graph = iter->first;
660 auto &sub_inputs = iter->second;
661 auto manager = sub_graph->manager();
662 MS_ASSERT(manager != nullptr);
663 for (auto &sub_input : sub_inputs) {
664 auto param_node = sub_graph->add_parameter();
665 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
666 param_node->set_abstract(sub_input->abstract()->Clone());
667 param_node->set_name(sub_input->fullname_with_scope());
668 (void)manager->Replace(sub_input, param_node);
669 auto sub_param_input = sub_input->cast<ParameterPtr>();
670 MS_ASSERT(sub_param_input != nullptr);
671 sub_param_input->set_default_param(nullptr);
672 }
673 }
674 return lite::RET_OK;
675 }
676
// Rename post-inserted transposes ("*_post") feeding the subgraph return so that
// the transpose takes over its input's name (suffixed "_cnode") and the input
// takes the transpose's original name, keeping names stable across subgraph
// boundaries.
int DecreaseTransposeAlgo::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
  MS_ASSERT(sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_ASSERT(return_node != nullptr);
  // Temporarily flatten Depend/MakeTuple to see the real outputs.
  auto origin_input = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  for (size_t i = 1; i < return_node->size(); ++i) {
    if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    // Only transposes named "*_post" (5 chars) are the ones this pass inserted.
    if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    auto trans_input_name = trans_input->fullname_with_scope();
    // Swap names: the input inherits the "_post" name ...
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    // ... and the transpose gets the input's base name with a "_cnode" suffix.
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  // Restore the original return inputs.
  return_node->set_inputs(origin_input);
  return lite::RET_OK;
}
707
ModifyCNodeFormat(const CNodePtr & cnode,FormatTransNodeType pre_trans_type)708 int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
709 MS_ASSERT(cnode != nullptr);
710 if (pre_trans_type == kNONE) {
711 return lite::RET_OK;
712 }
713 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
714 MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
715 mindspore::Format new_format;
716 if (pre_trans_type == kNHWC2NCHW) {
717 new_format = mindspore::NCHW;
718 } else {
719 new_format = mindspore::NHWC;
720 }
721 primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(new_format));
722 if (primitive->HasAttr(opt::kOutputsFormat)) {
723 auto org_format = CastToInt(primitive->GetAttr(opt::kOutputsFormat));
724 std::vector<int64_t> outputs_format(org_format.size(), new_format);
725 (void)primitive->AddAttr(kOutputsFormat, MakeValue(outputs_format));
726 }
727 return lite::RET_OK;
728 }
729
DecreaseTransposeForSingleOp(const FuncGraphPtr & func_graph)730 bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
731 MS_ASSERT(func_graph != nullptr);
732 auto manager = Manage(func_graph, true);
733 if (manager == nullptr) {
734 MS_LOG(ERROR) << "manager is nullptr.";
735 return false;
736 }
737 auto node_list = TopoSort(func_graph->get_return());
738 int status = 0;
739 for (auto &node : node_list) {
740 if (!utils::isa<CNodePtr>(node)) {
741 continue;
742 }
743 auto cnode = node->cast<CNodePtr>();
744 MS_ASSERT(cnode != nullptr);
745 if (IsSpecialType(cnode)) {
746 continue;
747 }
748 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
749 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
750 if (sub_func_graph == nullptr) {
751 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
752 return false;
753 }
754 auto ret = SetSubGraphInput(cnode, sub_func_graph);
755 if (ret != lite::RET_OK) {
756 MS_LOG(ERROR) << "SetSubGraphInput failed";
757 return false;
758 }
759 (void)DecreaseTransposeForSingleOp(sub_func_graph);
760 ret = SetSubGraphOutput(sub_func_graph);
761 if (ret != lite::RET_OK) {
762 MS_LOG(ERROR) << "SetSubGraphOutput failed";
763 return false;
764 }
765 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
766 if (sub_func_graph == nullptr) {
767 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
768 return false;
769 }
770 ret = SetSubGraphInput(cnode, sub_func_graph);
771 if (ret != lite::RET_OK) {
772 MS_LOG(ERROR) << "SetSubGraphInput failed";
773 return false;
774 }
775 (void)DecreaseTransposeForSingleOp(sub_func_graph);
776 ret = SetSubGraphOutput(sub_func_graph);
777 if (ret != lite::RET_OK) {
778 MS_LOG(ERROR) << "SetSubGraphOutput failed";
779 return false;
780 }
781 continue;
782 }
783 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
784 if (prim == nullptr) {
785 MS_LOG(INFO) << "this is a call cnode, which input[0] is fg, node " << cnode->fullname_with_scope();
786 continue;
787 }
788 if (!IsDynamicFormatOp(prim->name())) {
789 continue;
790 }
791 TransTypePair trans_insert_info;
792 status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
793 if (status == lite::RET_NO_CHANGE) {
794 continue;
795 } else if (status != lite::RET_OK) {
796 MS_LOG(ERROR) << "insert pre node failed.";
797 return false;
798 }
799 auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
800 if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
801 MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
802 return false;
803 }
804 }
805 return true;
806 }
807
DecreaseTransposeForMultiOp(const FuncGraphPtr & func_graph)808 bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
809 MS_ASSERT(func_graph != nullptr);
810 auto manager = Manage(func_graph, true);
811 if (manager == nullptr) {
812 MS_LOG(ERROR) << "manager is nullptr.";
813 return false;
814 }
815 auto node_list = TopoSort(func_graph->get_return());
816 std::set<CNodePtr> visit_transposes;
817 for (auto &node : node_list) {
818 if (!utils::isa<CNodePtr>(node)) {
819 continue;
820 }
821 auto cnode = node->cast<CNodePtr>();
822 MS_ASSERT(cnode != nullptr);
823 if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
824 continue;
825 }
826 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
827 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
828 if (sub_func_graph == nullptr) {
829 return false;
830 }
831 (void)DecreaseTransposeForMultiOp(sub_func_graph);
832 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
833 if (sub_func_graph == nullptr) {
834 return false;
835 }
836 (void)DecreaseTransposeForMultiOp(sub_func_graph);
837 }
838 std::vector<int> perm{};
839 if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
840 (perm != kNH2NC && perm != kNC2NH)) {
841 continue;
842 }
843 auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
844 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
845 MS_LOG(ERROR) << "global optimizer failed.";
846 return false;
847 }
848 }
849 return true;
850 }
851
IsNeedGenNewInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t index)852 bool DecreaseTransposeAlgo::IsNeedGenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index) {
853 if (cnode->input(index)->isa<CNode>()) {
854 MS_LOG(DEBUG) << "The input index" << index << " of " << cnode->fullname_with_scope() << " can cast cnode";
855 return true;
856 }
857 if (utils::isa<ParameterPtr>(cnode->input(index))) {
858 auto input_node = cnode->input(index)->cast<ParameterPtr>();
859 if (!input_node->has_default()) {
860 Format cur_input_format = DEFAULT_FORMAT;
861 if (!opt::SpecifyGraphInputFormat::GetCurGraphInputFormat(func_graph, fmk_type_, &cur_input_format)) {
862 MS_LOG(WARNING) << "Failed to get current format of graph input";
863 return false;
864 }
865 if (cur_input_format == NCHW || cur_input_format == NHWC) {
866 MS_LOG(DEBUG) << "Parameter is the input of graph";
867 return true;
868 }
869 }
870 }
871 return false;
872 }
873
Run(const FuncGraphPtr & func_graph)874 bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
875 MS_ASSERT(func_graph != nullptr);
876 node_infer_shape_.Init(fmk_type_, train_flag_);
877 transpose_strategy_.Init(fmk_type_, train_flag_);
878 if (!delete_redundant_transpose_.Run(func_graph)) {
879 MS_LOG(ERROR) << "Run delete-redundant-transpose pass failed.";
880 return false;
881 }
882 if (!DecreaseTransposeForSingleOp(func_graph)) {
883 MS_LOG(ERROR) << "run local trans insert optimizer failed.";
884 return false;
885 }
886
887 auto ret = ResetSubGraphInput();
888 if (ret != lite::RET_OK) {
889 MS_LOG(ERROR) << "ResetSubGraphInput failed.";
890 return false;
891 }
892 // if input format of several ops surrounded only by transpose op all can be NHWC,
893 // we can delete these transpose ops, and at the same time, transform these middle ops.
894 if (!DecreaseTransposeForMultiOp(func_graph)) {
895 MS_LOG(ERROR) << "run global trans insert optimizer failed.";
896 return false;
897 }
898 return true;
899 }
900 } // namespace opt
901 } // namespace mindspore
902