• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/decrease_transpose_algo.h"
19 #include <queue>
20 #include <set>
21 #include <unordered_map>
22 #include <utility>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "ops/op_utils.h"
28 #include "src/common/common.h"
29 #include "src/common/utils.h"
30 #include "tools/common/tensor_util.h"
31 #include "nnacl/op_base.h"
32 #include "tools/optimizer/graph/specify_graph_input_format.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace {
__anon9066e5770202(const CNodePtr &cnode) 37 std::function<bool(const CNodePtr &)> check_node = [](const CNodePtr &cnode) {
38   if (IsSpecialType(cnode) || CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
39     return true;
40   }
41   if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
42     // if concat's pre node is MakeTuple, the shape is complex, not optimize this case
43     if (cnode->size() > 1) {
44       auto tuple_node = cnode->input(1)->cast<CNodePtr>();
45       if (CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
46         for (size_t i = 1; i < tuple_node->size(); ++i) {
47           auto input_node = tuple_node->input(i);
48           if (utils::isa<ValueNodePtr>(input_node) || utils::isa<Parameter>(input_node)) {
49             return false;
50           }
51         }
52       }
53     }
54   }
55   TransposeStrategy transpose_strategy;
56   return transpose_strategy.CanChangeOpAxis(cnode);
57 };
58 
DealWithInputNodes(const FuncGraphPtr & func_graph,const CNodePtr & cur_node,std::set<CNodePtr> * in_nodes,AnfNodeIndexSet * not_trans_in_nodes,std::set<CNodePtr> * middle_nodes,std::queue<CNodePtr> * queue_nodes,std::queue<bool> * is_pre_nodes)59 void DealWithInputNodes(const FuncGraphPtr &func_graph, const CNodePtr &cur_node, std::set<CNodePtr> *in_nodes,
60                         AnfNodeIndexSet *not_trans_in_nodes, std::set<CNodePtr> *middle_nodes,
61                         std::queue<CNodePtr> *queue_nodes, std::queue<bool> *is_pre_nodes) {
62   MS_ASSERT(func_graph != nullptr && cur_node != nullptr);
63   MS_ASSERT(out_nodes != nullptr && not_trans_out_nodes != nullptr && middle_nodes != nullptr);
64   MS_ASSERT(queue_nodes != nullptr && is_pre_nodes != nullptr);
65   for (size_t i = 1; i < cur_node->size(); ++i) {
66     if (!utils::isa<CNodePtr>(cur_node->input(i))) {
67       continue;
68     }
69     auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
70     MS_ASSERT(cur_node_input != nullptr);
71     if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
72         in_nodes->find(cur_node_input) != in_nodes->end()) {
73       continue;
74     }
75     if (!check_node(cur_node_input)) {
76       (void)not_trans_in_nodes->insert(std::make_pair(cur_node, i));
77       continue;
78     }
79     queue_nodes->push(cur_node_input);
80     is_pre_nodes->push(true);
81   }
82 }
83 
DealWithOutputNodes(const FuncGraphPtr & func_graph,const CNodePtr & cur_node,std::set<CNodePtr> * out_nodes,AnfNodeIndexSet * not_trans_out_nodes,std::set<CNodePtr> * middle_nodes,std::queue<CNodePtr> * queue_nodes,std::queue<bool> * is_pre_nodes)84 STATUS DealWithOutputNodes(const FuncGraphPtr &func_graph, const CNodePtr &cur_node, std::set<CNodePtr> *out_nodes,
85                            AnfNodeIndexSet *not_trans_out_nodes, std::set<CNodePtr> *middle_nodes,
86                            std::queue<CNodePtr> *queue_nodes, std::queue<bool> *is_pre_nodes) {
87   MS_ASSERT(func_graph != nullptr && cur_node != nullptr);
88   MS_ASSERT(out_nodes != nullptr && not_trans_out_nodes != nullptr && middle_nodes != nullptr);
89   MS_ASSERT(queue_nodes != nullptr && is_pre_nodes != nullptr);
90   auto cur_node_users = func_graph->manager()->node_users()[cur_node];
91   if (cur_node_users.empty()) {
92     (void)out_nodes->insert(cur_node);
93     return lite::RET_OK;
94   }
95   for (auto &cur_node_user : cur_node_users) {
96     MS_CHECK_TRUE_RET(utils::isa<CNodePtr>(cur_node_user.first), lite::RET_ERROR);
97     auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
98     MS_CHECK_TRUE_RET(cur_node_post != nullptr, lite::RET_ERROR);
99     if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
100         out_nodes->find(cur_node_post) != out_nodes->end()) {
101       continue;
102     }
103     if (!check_node(cur_node_post)) {
104       (void)not_trans_out_nodes->insert(cur_node_user);
105       continue;
106     }
107     queue_nodes->push(cur_node_post);
108     is_pre_nodes->push(false);
109   }
110   return lite::RET_OK;
111 }
112 
// BFS from root_node to delimit the maximal region whose interior ops can
// all be axis-rewritten (per check_node).  Results:
//  - in_nodes: Transpose nodes reached while walking towards inputs, plus
//    nodes whose inputs are all parameters/constants;
//  - out_nodes: Transpose nodes reached while walking towards outputs, plus
//    nodes with no users;
//  - not_trans_in_nodes / not_trans_out_nodes: boundary edges whose op
//    cannot be rewritten, stored as (consumer, input index) pairs;
//  - middle_nodes: the region interior.
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
                                     std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
                                     AnfNodeIndexSet *not_trans_in_nodes, AnfNodeIndexSet *not_trans_out_nodes,
                                     std::set<CNodePtr> *middle_nodes) {
  MS_ASSERT(func_graph != nullptr && root_node != nullptr);
  MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
  MS_ASSERT(not_trans_in_nodes != nullptr && not_trans_out_nodes != nullptr);
  // The two queues run in lock-step: is_pre_nodes records the direction from
  // which the matching entry in queue_nodes was discovered.
  std::queue<CNodePtr> queue_nodes{};
  queue_nodes.push(root_node);
  std::queue<bool> is_pre_nodes;
  is_pre_nodes.push(true);
  while (!queue_nodes.empty()) {
    auto cur_node = queue_nodes.front();
    auto is_pre_node = is_pre_nodes.front();
    queue_nodes.pop();
    is_pre_nodes.pop();
    if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
      if (is_pre_node) {
        // Transpose on the input side: record as region input, but keep
        // expanding so its users are also classified.
        (void)in_nodes->insert(cur_node);
      } else {
        // Transpose on the output side bounds the region; do not expand.
        (void)out_nodes->insert(cur_node);
        continue;
      }
    }
    if (middle_nodes->find(cur_node) != middle_nodes->end()) {
      continue;
    }
    if (in_nodes->find(cur_node) == in_nodes->end()) {
      middle_nodes->insert(cur_node);
      // insert pre nodes.
      auto origin_inputs = cur_node->inputs();
      lite::RemoveIfDepend(cur_node);
      if (CheckIsAllInputsParam(cur_node)) {
        // Only constant/parameter inputs: this node is itself an input boundary.
        in_nodes->insert(cur_node);
      } else {
        DealWithInputNodes(func_graph, cur_node, in_nodes, not_trans_in_nodes, middle_nodes, &queue_nodes,
                           &is_pre_nodes);
      }
      // Restore inputs that RemoveIfDepend may have stripped.
      cur_node->set_inputs(origin_inputs);
    }
    // insert post nodes
    if (DealWithOutputNodes(func_graph, cur_node, out_nodes, not_trans_out_nodes, middle_nodes, &queue_nodes,
                            &is_pre_nodes) != lite::RET_OK) {
      MS_LOG(ERROR) << "deal with output failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
162 
SetTransType(const std::set<CNodePtr> & cnodes,FormatTransNodeType * trans_type)163 void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
164   MS_ASSERT(trans_type != nullptr);
165   FormatTransNodeType local_trans_type;
166   for (auto &cnode : cnodes) {
167     std::vector<int> perm;
168     if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
169         (perm != kNH2NC && perm != kNC2NH)) {
170       *trans_type = kNONE;
171       return;
172     }
173     local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
174     *trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
175     if (*trans_type != local_trans_type) {
176       *trans_type = kNONE;
177       return;
178     }
179   }
180 }
181 
JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> & in_nodes,const std::set<CNodePtr> & out_nodes,const AnfNodeIndexSet & not_trans_in_nodes,const AnfNodeIndexSet & not_trans_out_nodes,const std::set<CNodePtr> & middle_nodes,TransTypePair * trans_info,bool all_trans)182 bool JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> &in_nodes, const std::set<CNodePtr> &out_nodes,
183                                  const AnfNodeIndexSet &not_trans_in_nodes, const AnfNodeIndexSet &not_trans_out_nodes,
184                                  const std::set<CNodePtr> &middle_nodes, TransTypePair *trans_info, bool all_trans) {
185   MS_ASSERT(trans_info != nullptr);
186   if (all_trans && (!not_trans_in_nodes.empty() || !not_trans_out_nodes.empty())) {
187     return false;
188   }
189   if (not_trans_in_nodes.size() + not_trans_out_nodes.size() >= in_nodes.size() + out_nodes.size()) {
190     return false;
191   }
192   SetTransType(in_nodes, &trans_info->pre_);
193   if (trans_info->pre_ == kNONE) {
194     return false;
195   }
196   SetTransType(out_nodes, &trans_info->post_);
197   if (trans_info->post_ == kNONE) {
198     return false;
199   }
200   if (trans_info->pre_ == trans_info->post_) {
201     return false;
202   }
203   return true;
204 }
205 
// Rewrite a constant (Parameter/ValueNode) input of `cnode` in place so its
// data matches the target layout given by `trans_type`.  Variable (CNode)
// inputs, dynamic/empty shapes, all-one shapes and non-float tensors are
// left untouched.  On success the input edge is redirected to a fresh
// parameter holding the transposed data.
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
                          bool train_flag, FormatTransNodeType trans_type) {
  MS_ASSERT(cnode != nullptr);
  if (utils::isa<CNodePtr>(cnode->input(index))) {
    // Variable input: layout change is handled by inserting a Transpose node.
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status = 0;
  if (utils::isa<ParameterPtr>(cnode->input(index))) {
    auto input_node = cnode->input(index)->cast<ParameterPtr>();
    MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
    if (!input_node->has_default()) {
      // Graph input without constant data: nothing to convert.
      return lite::RET_OK;
    }
    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, &data_info, true);
  } else {
    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info, true);
  }
  if (status != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  // Skip tensors whose layout is irrelevant: dynamic shape, empty tensor,
  // all-one shape (scalar-like broadcast), or non-float32 data.
  if (lite::JudgeDynamicShape(data_info.shape_) || (data_info.shape_.size() == 1 && data_info.shape_[0] == 0) ||
      std::all_of(data_info.shape_.begin(), data_info.shape_.end(), [](int val) { return val == 1; }) ||
      (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
    return lite::RET_OK;
  }
  ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
  // Pad with leading 1s up to 4D so the 4D filter-format conversion applies;
  // ScaleFusion keeps its original rank.
  bool need_expand = !CheckPrimitiveType(cnode, prim::kPrimScaleFusion);
  if (need_expand && expand_shape.size() <= DIMENSION_4D) {
    ShapeVector tmp_shape(DIMENSION_4D - expand_shape.size(), 1);
    (void)expand_shape.insert(expand_shape.begin(), tmp_shape.begin(), tmp_shape.end());
  }
  auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
                                                 data_info.data_.data(), data_info.data_.size());
  MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
  // KHWC<->KCHW is used here as the 4D permutation pair equivalent to
  // NHWC<->NCHW for weight-like data.
  if (trans_type == kNHWC2NCHW) {
    (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
  } else {
    (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
  }
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
  param_node->set_name(cnode->input(index)->fullname_with_scope());
  status = lite::InitParameterFromTensorInfo(param_node, tensor);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return lite::RET_ERROR;
  }
  // Redirect the input edge to the converted parameter.
  auto tr = func_graph->manager()->Transact();
  tr.SetEdge(cnode, index, param_node);
  tr.Commit();
  return lite::RET_OK;
}
259 }  // namespace
260 
PostTransposeFusion(const FuncGraphPtr & func_graph,const CNodePtr & cnode)261 STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
262   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
263   if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
264     return lite::RET_OK;
265   }
266   std::vector<int> cur_perm;
267   if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
268     MS_LOG(ERROR) << "get transpose perm failed.";
269     return lite::RET_ERROR;
270   }
271   auto manager = func_graph->manager();
272   MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
273   MS_CHECK_TRUE_RET(manager->node_users().find(cnode) != manager->node_users().end(), lite::RET_ERROR);
274   auto node_users = manager->node_users()[cnode];
275   for (auto &node_user : node_users) {
276     auto post_node = node_user.first;
277     if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
278       std::vector<int> post_trans_perm;
279       auto post_trans_node = post_node->cast<CNodePtr>();
280       MS_ASSERT(post_trans_node != nullptr);
281       if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
282         MS_LOG(ERROR) << "get post transpose node perm failed.";
283         return lite::RET_ERROR;
284       }
285       if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
286         (void)manager->Replace(post_node, cnode->input(1));
287       }
288     }
289   }
290   return lite::RET_OK;
291 }
292 
// Insert (or fuse) a Transpose with permutation `perm` around `cnode`:
// before input `index` when `before` is true, otherwise after the node.
// TransposePairFuseWhenInsert may return an already-existing node when the
// new transpose cancels with a neighbouring one.
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          const std::vector<int> perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  if (new_input == cnode->input(index) || new_input == cnode) {
    // Fully fused away: the graph is already in the desired form.
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    MS_ASSERT(new_cnode_input != nullptr);
    int status = lite::RET_OK;
    // Only freshly created Transpose nodes need their shape inferred here.
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  auto tr = manager->Transact();
  if (before) {
    // Pre-insertion: rewire only the chosen input edge.
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    // Post-insertion: new_input replaces cnode for all users, then try to
    // cancel it against any following inverse transpose.
    (void)manager->Replace(cnode, new_input);
    if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
333 
// Try to run `cnode` in the opposite layout when doing so removes more
// transposes than it inserts: check fusibility, rewrite the op's axis
// attributes, insert the pre-transposes, update the node format and
// re-infer its shape.  Returns RET_NO_CHANGE when not profitable/supported.
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                 TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_insert_info != nullptr);
  TransTypePair trans_info;
  // Temporarily strip MakeTuple/Monad wrappers so the strategy sees the real
  // data inputs; the original inputs are restored on every exit path below.
  auto origin_inputs = cnode->inputs();
  lite::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NO_CHANGE;
  }
  cnode->set_inputs(origin_inputs);
  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
  if (status == lite::RET_NOT_SUPPORT) {
    // Not an error: this op simply cannot be rewritten; leave it unchanged.
    return lite::RET_NO_CHANGE;
  } else if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "change op attr failed.";
    return lite::RET_ERROR;
  }
  status = DoPreInsert(func_graph, cnode, trans_insert_info->pre_);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "Do pre insert failed.";
    return lite::RET_ERROR;
  }
  status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
    return lite::RET_ERROR;
  }
  status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
371 
DoPreInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type)372 STATUS DecreaseTransposeAlgo::DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
373                                           FormatTransNodeType trans_type) {
374   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
375   auto abstract = cnode->abstract();
376   MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
377   if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
378     auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
379     auto abstract_list = abstract_tuple->elements();
380     MS_CHECK_TRUE_RET(!abstract_list.empty(), lite::RET_OUT_OF_TENSOR_RANGE);
381     abstract = abstract_list.front();
382     MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
383   }
384   auto HandleFunc = [this](const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index,
385                            FormatTransNodeType trans_type) -> STATUS {
386     STATUS ret = lite::RET_OK;
387     auto before_perm = trans_type == kNHWC2NCHW ? kNH2NC : kNC2NH;
388     if (IsNeedGenNewInput(func_graph, cnode, index)) {
389       ret = GenNewInput(func_graph, cnode, before_perm, true, index);
390       if (ret != lite::RET_OK) {
391         MS_LOG(ERROR) << "GenNewInput failed";
392       }
393     } else {
394       ret = ConvertTensorToNCOrNH(func_graph, cnode, index, fmk_type_, train_flag_, trans_type);
395       if (ret != lite::RET_OK) {
396         MS_LOG(ERROR) << "ConvertTensorToNCOrNH faileded";
397       }
398     }
399     return ret;
400   };
401   for (size_t i = 1; i < cnode->size(); ++i) {
402     MS_CHECK_TRUE_RET(cnode->input(i) != nullptr, lite::RET_NULL_PTR);
403     if (IsMonadNode(cnode->input(i))) {
404       continue;
405     }
406     if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
407         CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTupleV2)) {
408       auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
409       MS_ASSERT(input_make_tuple != nullptr);
410       for (size_t j = 1; j < input_make_tuple->size(); ++j) {
411         MS_CHECK_TRUE_RET(input_make_tuple->input(j) != nullptr, lite::RET_NULL_PTR);
412         if (HandleFunc(func_graph, input_make_tuple, j, trans_type) != lite::RET_OK) {
413           MS_LOG(ERROR) << "handle pre insert failed.";
414           return lite::RET_ERROR;
415         }
416       }
417       continue;
418     }
419     if (HandleFunc(func_graph, cnode, i, trans_type) != lite::RET_OK) {
420       MS_LOG(ERROR) << "handle pre insert failed.";
421       return lite::RET_ERROR;
422     }
423   }
424   return lite::RET_OK;
425 }
426 
// Insert a Transpose with `perm` after `cnode`.  Single-output nodes get the
// transpose directly; tuple-output nodes get one transpose after each
// TupleGetItem user (in training mode a TupleGetItem is synthesized when a
// user consumes the tuple directly).
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                  const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
  auto abstract = cnode->abstract();
  MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_ERROR);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    MS_CHECK_TRUE_RET(manager->node_users().find(cnode) != manager->node_users().end(), lite::RET_ERROR);
    auto node_users = manager->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          // Training graphs may consume the tuple directly; synthesize a
          // TupleGetItem(cnode, 0) as the attachment point.
          tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
          MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
          post_node = tuple_get_item;
          (void)manager->Replace(cnode, tuple_get_item);
        }
      }
      MS_CHECK_TRUE_RET(manager->node_users().find(post_node) != manager->node_users().end(), lite::RET_ERROR);
      if (manager->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      MS_ASSERT(post_cnode != nullptr);
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        // The synthesized TupleGetItem was only a hook point; fold it away.
        (void)manager->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}
473 
HandleGraphSingleNode(const FuncGraphPtr & func_graph,const TransTypePair & trans_info,const CNodePtr & cnode)474 STATUS DecreaseTransposeAlgo::HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info,
475                                                     const CNodePtr &cnode) {
476   for (size_t i = 1; i < cnode->size(); ++i) {
477     auto status = ConvertTensorToNCOrNH(func_graph, cnode, i, fmk_type_, train_flag_, trans_info.post_);
478     if (status != lite::RET_OK) {
479       MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
480       return lite::RET_ERROR;
481     }
482   }
483   auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_info.post_);
484   if (status != lite::RET_OK) {
485     MS_LOG(ERROR) << "change op attr failed.";
486     return lite::RET_ERROR;
487   }
488   status = ModifyCNodeFormat(cnode, trans_info.post_);
489   if (status != lite::RET_OK) {
490     MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
491     return lite::RET_ERROR;
492   }
493   status = node_infer_shape_.InferShape(cnode);
494   if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
495     MS_LOG(ERROR) << "infer shape failed.";
496     return lite::RET_ERROR;
497   }
498   return RET_OK;
499 }
500 
// For boundary edges whose op could not be axis-rewritten, insert an inverse
// Transpose so those consumers keep seeing data in the original layout:
// `not_trans_in_nodes` edges use trans_info.pre_, `not_trans_out_nodes`
// edges use trans_info.post_.
int DecreaseTransposeAlgo::InsertPreTransForNonTransInOut(const FuncGraphPtr &func_graph,
                                                          const AnfNodeIndexSet &not_trans_in_nodes,
                                                          const AnfNodeIndexSet &not_trans_out_nodes,
                                                          TransTypePair trans_info) {
  // Recursive helper so TupleGetItem/MakeTuple wrappers are traversed down
  // to the actual consumer edges before a transpose is inserted.
  std::function<bool(const AnfNodePtr &, size_t, FormatTransNodeType)> insert_pre_trans =
    [&](const AnfNodePtr &node, size_t index, FormatTransNodeType format_trans_type) -> bool {
    MS_CHECK_TRUE_RET(node != nullptr, false);
    auto cnode = node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, false);
    if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
      // Push the insertion past TupleGetItem to each of its users.
      auto node_users = func_graph->manager()->node_users()[node];
      return std::all_of(node_users.begin(), node_users.end(),
                         [&insert_pre_trans, &format_trans_type](const std::pair<AnfNodePtr, int> &pair) {
                           return insert_pre_trans(pair.first, pair.second, format_trans_type);
                         });
    } else if (CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTuple) ||
               CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTupleV2)) {
      auto make_tuple_cnode = cnode->input(1)->cast<CNodePtr>();
      MS_CHECK_TRUE_RET(make_tuple_cnode != nullptr, false);
      // NOTE(review): this loop starts at 0, so the first recursion receives
      // the MakeTuple's primitive value-node (input 0), which cannot cast to
      // a CNode and would make insert_pre_trans return false.  Verify whether
      // it should start at 1 like the other input walks in this file.
      for (size_t i = 0; i < make_tuple_cnode->size(); i++) {
        if (!insert_pre_trans(make_tuple_cnode->input(i), i, format_trans_type)) {
          return false;
        }
      }
      return true;
    }
    // Insert the inverse permutation before input `index` of the consumer.
    auto perm = format_trans_type == kNHWC2NCHW ? kNC2NH : kNH2NC;
    return GenNewInput(func_graph, cnode, perm, true, index) == lite::RET_OK;
  };
  bool deal_inputs = std::all_of(not_trans_in_nodes.begin(), not_trans_in_nodes.end(),
                                 [&insert_pre_trans, &trans_info](const std::pair<AnfNodePtr, int> &pair) {
                                   return insert_pre_trans(pair.first, pair.second, trans_info.pre_);
                                 });
  if (!deal_inputs) {
    MS_LOG(ERROR) << "Insert transpose for inputs failed.";
    return lite::RET_ERROR;
  }
  bool deal_outputs = std::all_of(not_trans_out_nodes.begin(), not_trans_out_nodes.end(),
                                  [&insert_pre_trans, &trans_info](const std::pair<AnfNodePtr, int> &pair) {
                                    return insert_pre_trans(pair.first, pair.second, trans_info.post_);
                                  });
  if (!deal_outputs) {
    MS_LOG(ERROR) << "Insert transpose for outputs failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
548 
// Multi-op optimization entry point: starting from transpose `cnode`, find
// the surrounded area; if the boundary transposes agree on (inverse)
// directions, delete them and rewrite every interior op to the new layout.
// Boundary transposes are recorded in visit_transposes so the caller skips
// them later.
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                   std::set<CNodePtr> *visit_transposes) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager is nullptr");
  std::set<CNodePtr> middle_nodes{};
  std::set<CNodePtr> in_nodes{};
  std::set<CNodePtr> out_nodes{};
  AnfNodeIndexSet not_trans_in_nodes;
  AnfNodeIndexSet not_trans_out_nodes;
  auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &not_trans_in_nodes,
                                              &not_trans_out_nodes, &middle_nodes);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
    return status;
  }
  // Mark boundary transposes visited even if the area cannot be optimized.
  for (auto &in_cnode : in_nodes) {
    if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
      visit_transposes->insert(in_cnode);
    }
  }
  TransTypePair trans_info;
  if (!JudgeCanOptimizerForMultiOp(in_nodes, out_nodes, not_trans_in_nodes, not_trans_out_nodes, middle_nodes,
                                   &trans_info, (train_flag_ || all_trans_))) {
    return lite::RET_NO_CHANGE;
  }
  // Collect the interior ops in topological order before mutating the graph.
  auto node_list = TopoSort(func_graph->get_return());
  std::vector<CNodePtr> middle_ops_vec;
  for (auto &node : node_list) {
    if (utils::isa<CNodePtr>(node) && middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
      middle_ops_vec.push_back(node->cast<CNodePtr>());
      middle_nodes.erase(node->cast<CNodePtr>());
    }
  }
  // Drop the boundary transposes: bypass each one with its data input.
  (void)std::for_each(in_nodes.begin(), in_nodes.end(),
                      [&manager](const CNodePtr &in_cnode) { (void)manager->Replace(in_cnode, in_cnode->input(1)); });
  (void)std::for_each(out_nodes.begin(), out_nodes.end(),
                      [&manager](const CNodePtr &out_node) { (void)manager->Replace(out_node, out_node->input(1)); });

  // Re-insert transposes on edges whose consumer could not be rewritten.
  if (InsertPreTransForNonTransInOut(func_graph, not_trans_in_nodes, not_trans_out_nodes, trans_info) != lite::RET_OK) {
    MS_LOG(ERROR) << "Insert pre transpose failed for non-transpose inputs or outputs.";
    return lite::RET_ERROR;
  }

  // Rewrite every interior op to the new layout.
  for (auto &middle_cnode : middle_ops_vec) {
    if (IsSpecialType(middle_cnode)) {
      continue;
    }
    status = HandleGraphSingleNode(func_graph, trans_info, middle_cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Decrease transpose failed for op: " << middle_cnode->fullname_with_scope();
      return status;
    }
  }
  return lite::RET_OK;
}
605 
// Rebind a sub-graph's parameters to the caller's actual arguments: clone
// the argument abstract, propagate default tensors, and mark not-yet-
// inferred variable arguments as dynamic shape.  The parameter's trailing
// numeric name suffix plus kInputSizeThree selects the caller input slot
// (offset presumably accounts for partial-call bookkeeping inputs — TODO
// confirm against the caller's node layout).
int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto sub_inputs = sub_graph->get_inputs();
  // Remember the original inputs so ResetSubGraphInput can restore them.
  sub_inputs_map_[sub_graph] = sub_inputs;
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    // Strip the last "_..." segment, then parse the now-trailing number.
    auto node_name = node->fullname_with_scope();
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    size_t index = 0;
    try {
      index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Get index failed: " << e.what();
      return lite::RET_ERROR;
    }
    auto abstract = GetCNodeInputAbstract(cnode, index);
    MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
    param_node->set_abstract(abstract->Clone());
    if (utils::isa<CNodePtr>(cnode->input(index))) {
      // Variable argument: if its shape has not been inferred yet, mark the
      // parameter shape as dynamic (-1).
      ShapeVector shape_vec = {-1};
      bool has_inferred{false};
      auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
      if (!has_inferred) {
        auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
        MS_CHECK_TRUE_MSG(abstract_shape != nullptr, RET_ERROR, "create shape failed.");
        param_node->abstract()->set_shape(abstract_shape);
      }
    }
    if (utils::isa<Parameter>(cnode->input(index))) {
      // Parameter argument: forward its default tensor.
      param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
    }
    if (utils::isa<ValueNode>(cnode->input(index))) {
      // Constant argument: materialize it as the parameter's default tensor.
      // A fetch failure is treated as best-effort and skipped, not an error.
      lite::DataInfo data_info;
      auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
      if (status != lite::RET_OK) {
        continue;
      }
      ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
      auto tensor_info = lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec,
                                                static_cast<TypeId>(data_info.data_type_));
      MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create a tensor failed.");
      param_node->set_default_param(tensor_info);
    }
  }
  return lite::RET_OK;
}
656 
ResetSubGraphInput()657 int DecreaseTransposeAlgo::ResetSubGraphInput() {
658   for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
659     auto &sub_graph = iter->first;
660     auto &sub_inputs = iter->second;
661     auto manager = sub_graph->manager();
662     MS_ASSERT(manager != nullptr);
663     for (auto &sub_input : sub_inputs) {
664       auto param_node = sub_graph->add_parameter();
665       MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
666       param_node->set_abstract(sub_input->abstract()->Clone());
667       param_node->set_name(sub_input->fullname_with_scope());
668       (void)manager->Replace(sub_input, param_node);
669       auto sub_param_input = sub_input->cast<ParameterPtr>();
670       MS_ASSERT(sub_param_input != nullptr);
671       sub_param_input->set_default_param(nullptr);
672     }
673   }
674   return lite::RET_OK;
675 }
676 
int DecreaseTransposeAlgo::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
  MS_ASSERT(sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_ASSERT(return_node != nullptr);
  // Remember the original return inputs; Depend/MakeTuple wrappers are peeled
  // off below so outputs can be inspected, and restored at the end.
  auto origin_input = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  for (size_t i = 1; i < return_node->size(); ++i) {
    // Only transpose outputs whose name ends in "_post" are renamed — that
    // suffix marks transposes inserted by this pass (kInputSizeFive matches
    // the length of "_post").
    if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    // Capture the input's name BEFORE it is overwritten below — the renamed
    // transpose derives its "_cnode" name from the input's ORIGINAL name.
    auto trans_input_name = trans_input->fullname_with_scope();
    // The transpose's input inherits the "_post" output name so the graph's
    // externally visible output naming is preserved.
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  // Restore the original (possibly Depend/MakeTuple-wrapped) return inputs.
  return_node->set_inputs(origin_input);
  return lite::RET_OK;
}
707 
ModifyCNodeFormat(const CNodePtr & cnode,FormatTransNodeType pre_trans_type)708 int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
709   MS_ASSERT(cnode != nullptr);
710   if (pre_trans_type == kNONE) {
711     return lite::RET_OK;
712   }
713   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
714   MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
715   mindspore::Format new_format;
716   if (pre_trans_type == kNHWC2NCHW) {
717     new_format = mindspore::NCHW;
718   } else {
719     new_format = mindspore::NHWC;
720   }
721   primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(new_format));
722   if (primitive->HasAttr(opt::kOutputsFormat)) {
723     auto org_format = CastToInt(primitive->GetAttr(opt::kOutputsFormat));
724     std::vector<int64_t> outputs_format(org_format.size(), new_format);
725     (void)primitive->AddAttr(kOutputsFormat, MakeValue(outputs_format));
726   }
727   return lite::RET_OK;
728 }
729 
DecreaseTransposeForSingleOp(const FuncGraphPtr & func_graph)730 bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
731   MS_ASSERT(func_graph != nullptr);
732   auto manager = Manage(func_graph, true);
733   if (manager == nullptr) {
734     MS_LOG(ERROR) << "manager is nullptr.";
735     return false;
736   }
737   auto node_list = TopoSort(func_graph->get_return());
738   int status = 0;
739   for (auto &node : node_list) {
740     if (!utils::isa<CNodePtr>(node)) {
741       continue;
742     }
743     auto cnode = node->cast<CNodePtr>();
744     MS_ASSERT(cnode != nullptr);
745     if (IsSpecialType(cnode)) {
746       continue;
747     }
748     if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
749       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
750       if (sub_func_graph == nullptr) {
751         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
752         return false;
753       }
754       auto ret = SetSubGraphInput(cnode, sub_func_graph);
755       if (ret != lite::RET_OK) {
756         MS_LOG(ERROR) << "SetSubGraphInput failed";
757         return false;
758       }
759       (void)DecreaseTransposeForSingleOp(sub_func_graph);
760       ret = SetSubGraphOutput(sub_func_graph);
761       if (ret != lite::RET_OK) {
762         MS_LOG(ERROR) << "SetSubGraphOutput failed";
763         return false;
764       }
765       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
766       if (sub_func_graph == nullptr) {
767         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
768         return false;
769       }
770       ret = SetSubGraphInput(cnode, sub_func_graph);
771       if (ret != lite::RET_OK) {
772         MS_LOG(ERROR) << "SetSubGraphInput failed";
773         return false;
774       }
775       (void)DecreaseTransposeForSingleOp(sub_func_graph);
776       ret = SetSubGraphOutput(sub_func_graph);
777       if (ret != lite::RET_OK) {
778         MS_LOG(ERROR) << "SetSubGraphOutput failed";
779         return false;
780       }
781       continue;
782     }
783     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
784     if (prim == nullptr) {
785       MS_LOG(INFO) << "this is a call cnode, which input[0] is fg, node " << cnode->fullname_with_scope();
786       continue;
787     }
788     if (!IsDynamicFormatOp(prim->name())) {
789       continue;
790     }
791     TransTypePair trans_insert_info;
792     status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
793     if (status == lite::RET_NO_CHANGE) {
794       continue;
795     } else if (status != lite::RET_OK) {
796       MS_LOG(ERROR) << "insert pre node failed.";
797       return false;
798     }
799     auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
800     if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
801       MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
802       return false;
803     }
804   }
805   return true;
806 }
807 
DecreaseTransposeForMultiOp(const FuncGraphPtr & func_graph)808 bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
809   MS_ASSERT(func_graph != nullptr);
810   auto manager = Manage(func_graph, true);
811   if (manager == nullptr) {
812     MS_LOG(ERROR) << "manager is nullptr.";
813     return false;
814   }
815   auto node_list = TopoSort(func_graph->get_return());
816   std::set<CNodePtr> visit_transposes;
817   for (auto &node : node_list) {
818     if (!utils::isa<CNodePtr>(node)) {
819       continue;
820     }
821     auto cnode = node->cast<CNodePtr>();
822     MS_ASSERT(cnode != nullptr);
823     if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
824       continue;
825     }
826     if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
827       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
828       if (sub_func_graph == nullptr) {
829         return false;
830       }
831       (void)DecreaseTransposeForMultiOp(sub_func_graph);
832       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
833       if (sub_func_graph == nullptr) {
834         return false;
835       }
836       (void)DecreaseTransposeForMultiOp(sub_func_graph);
837     }
838     std::vector<int> perm{};
839     if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
840         (perm != kNH2NC && perm != kNC2NH)) {
841       continue;
842     }
843     auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
844     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
845       MS_LOG(ERROR) << "global optimizer failed.";
846       return false;
847     }
848   }
849   return true;
850 }
851 
IsNeedGenNewInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t index)852 bool DecreaseTransposeAlgo::IsNeedGenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index) {
853   if (cnode->input(index)->isa<CNode>()) {
854     MS_LOG(DEBUG) << "The input index" << index << " of " << cnode->fullname_with_scope() << " can cast cnode";
855     return true;
856   }
857   if (utils::isa<ParameterPtr>(cnode->input(index))) {
858     auto input_node = cnode->input(index)->cast<ParameterPtr>();
859     if (!input_node->has_default()) {
860       Format cur_input_format = DEFAULT_FORMAT;
861       if (!opt::SpecifyGraphInputFormat::GetCurGraphInputFormat(func_graph, fmk_type_, &cur_input_format)) {
862         MS_LOG(WARNING) << "Failed to get current format of graph input";
863         return false;
864       }
865       if (cur_input_format == NCHW || cur_input_format == NHWC) {
866         MS_LOG(DEBUG) << "Parameter is the input of graph";
867         return true;
868       }
869     }
870   }
871   return false;
872 }
873 
Run(const FuncGraphPtr & func_graph)874 bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
875   MS_ASSERT(func_graph != nullptr);
876   node_infer_shape_.Init(fmk_type_, train_flag_);
877   transpose_strategy_.Init(fmk_type_, train_flag_);
878   if (!delete_redundant_transpose_.Run(func_graph)) {
879     MS_LOG(ERROR) << "Run delete-redundant-transpose pass failed.";
880     return false;
881   }
882   if (!DecreaseTransposeForSingleOp(func_graph)) {
883     MS_LOG(ERROR) << "run local trans insert optimizer failed.";
884     return false;
885   }
886 
887   auto ret = ResetSubGraphInput();
888   if (ret != lite::RET_OK) {
889     MS_LOG(ERROR) << "ResetSubGraphInput failed.";
890     return false;
891   }
892   // if input format of several ops surrounded only by transpose op all can be NHWC,
893   // we can delete these transpose ops, and at the same time, transform these middle ops.
894   if (!DecreaseTransposeForMultiOp(func_graph)) {
895     MS_LOG(ERROR) << "run global trans insert optimizer failed.";
896     return false;
897   }
898   return true;
899 }
900 }  // namespace opt
901 }  // namespace mindspore
902