/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
namespace {
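// Breadth-first search from a transpose node to collect the area it bounds: transpose nodes reached
// from the input side (and nodes fed only by parameters) are recorded in in_nodes, transpose nodes or
// graph outputs reached on the output side in out_nodes, and the ops in between in middle_nodes.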
STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
                                     std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
                                     std::set<CNodePtr> *middle_nodes) {
  MS_ASSERT(func_graph != nullptr && root_node != nullptr);
  MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
  std::queue<CNodePtr> queue_nodes{};
  queue_nodes.push(root_node);
  std::queue<bool> is_pre_nodes;
  is_pre_nodes.push(true);
  while (!queue_nodes.empty()) {
    auto cur_node = queue_nodes.front();
    auto is_pre_node = is_pre_nodes.front();
    queue_nodes.pop();
    is_pre_nodes.pop();
    if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
      if (is_pre_node) {
        in_nodes->insert(cur_node);
      } else {
        out_nodes->insert(cur_node);
        continue;
      }
    }
    if (middle_nodes->find(cur_node) != middle_nodes->end()) {
      continue;
    }
    if (in_nodes->find(cur_node) == in_nodes->end()) {
      middle_nodes->insert(cur_node);
      // insert pre nodes.
      auto origin_inputs = cur_node->inputs();
      lite::RemoveIfDepend(cur_node);
      for (size_t i = 1; i < cur_node->size(); ++i) {
        if (!utils::isa<CNodePtr>(cur_node->input(i))) {
          continue;
        }
        auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
        MS_ASSERT(cur_node_input != nullptr);
        if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
            in_nodes->find(cur_node_input) != in_nodes->end()) {
          continue;
        }
        queue_nodes.push(cur_node_input);
        is_pre_nodes.push(true);
      }
      if (CheckIsAllInputsParam(cur_node)) {
        in_nodes->insert(cur_node);
      }
      cur_node->set_inputs(origin_inputs);
    }
    // insert post nodes
    auto cur_node_users = func_graph->manager()->node_users()[cur_node];
    for (auto &cur_node_user : cur_node_users) {
      if (!utils::isa<CNodePtr>(cur_node_user.first)) {
        MS_LOG(ERROR) << "post node is not cnode.";
        return lite::RET_ERROR;
      }
      auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(cur_node_post != nullptr, RET_ERROR, "cast ptr failed");
      if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
          out_nodes->find(cur_node_post) != out_nodes->end()) {
        continue;
      }
      queue_nodes.push(cur_node_post);
      is_pre_nodes.push(false);
    }
    if (cur_node_users.empty()) {
      out_nodes->insert(cur_node);
    }
  }
  return lite::RET_OK;
}

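// Determines the single transpose direction shared by all of the given transpose nodes. If any node
// is not an NH2NC/NC2NH transpose, or the directions are mixed, *trans_type is reset to kNONE.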
void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) {
  MS_ASSERT(trans_type != nullptr);
  FormatTransNodeType local_trans_type;
  for (auto &cnode : cnodes) {
    std::vector<int> perm;
    if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
        (perm != kNH2NC && perm != kNC2NH)) {
      *trans_type = kNONE;
      return;
    }
    local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC;
    *trans_type = *trans_type == kNONE ? local_trans_type : *trans_type;
    if (*trans_type != local_trans_type) {
      *trans_type = kNONE;
      return;
    }
  }
}

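// An area can be optimized only if the surrounding pre- and post-transposes each have a consistent
// direction, the two directions differ, and every (non-special) middle op supports an axis change.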
bool JudgeCanOptimizerForMultiOp(const std::set<CNodePtr> &in_nodes, const std::set<CNodePtr> &out_nodes,
                                 const std::set<CNodePtr> &middle_nodes, TransTypePair *trans_info) {
  MS_ASSERT(trans_info != nullptr);
  SetTransType(in_nodes, &trans_info->pre_);
  if (trans_info->pre_ == kNONE) {
    return false;
  }
  SetTransType(out_nodes, &trans_info->post_);
  if (trans_info->post_ == kNONE) {
    return false;
  }
  if (trans_info->pre_ == trans_info->post_) {
    return false;
  }
  TransposeStrategy transpose_strategy;
  for (auto &middle_cnode : middle_nodes) {
    if (IsSpecialType(middle_cnode)) {
      continue;
    }
    auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
    MS_CHECK_TRUE_MSG(middle_node_prim != nullptr, false, "GetValueNode failed");
    if (!transpose_strategy.CanChangeOpAxis(middle_cnode)) {
      return false;
    }
  }
  return true;
}

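// Expands a 1D/2D/3D const float input of `cnode` to 4D and transposes its data between KHWC and
// KCHW, so the constant stays consistent with the op once the surrounding transposes are removed.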
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
                          bool train_flag, FormatTransNodeType trans_type) {
  MS_ASSERT(cnode != nullptr);
  if (utils::isa<CNodePtr>(cnode->input(index))) {
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status = 0;
  if (utils::isa<ParameterPtr>(cnode->input(index))) {
    auto input_node = cnode->input(index)->cast<ParameterPtr>();
    MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
    if (!input_node->has_default()) {
      return lite::RET_OK;
    }
    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
  } else {
    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
  }
  if (status != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  if (data_info.shape_.empty() ||
      (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
    return lite::RET_OK;
  }
  ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
  if (data_info.shape_.size() == 1) {
    expand_shape = {1, 1, 1, data_info.shape_[0]};
  } else if (data_info.shape_.size() == kInputSizeTwo) {
    expand_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
  } else if (data_info.shape_.size() == kInputSizeThree) {
    expand_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]};
  }
  auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
                                                 data_info.data_.data(), data_info.data_.size());
  MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
  if (trans_type == kNHWC2NCHW) {
    (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
  } else {
    (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
  }
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
  param_node->set_name(cnode->input(index)->fullname_with_scope());
  status = lite::InitParameterFromTensorInfo(param_node, tensor);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return lite::RET_ERROR;
  }
  auto tr = func_graph->manager()->Transact();
  tr.SetEdge(cnode, index, param_node);
  tr.Commit();
  return lite::RET_OK;
}
}  // namespace

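// If `cnode` is a transpose and one of its users is the inverse transpose (NH2NC followed by NC2NH,
// or vice versa), the pair cancels and the user is replaced with the transpose's own input.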
STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
    return lite::RET_OK;
  }
  std::vector<int> cur_perm;
  if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get transpose perm failed.";
    return lite::RET_ERROR;
  }
  auto node_users = func_graph->manager()->node_users()[cnode];
  for (auto &node_user : node_users) {
    auto post_node = node_user.first;
    if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
      std::vector<int> post_trans_perm;
      auto post_trans_node = post_node->cast<CNodePtr>();
      MS_ASSERT(post_trans_node != nullptr);
      if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
        MS_LOG(ERROR) << "get post transpose node perm failed.";
        return lite::RET_ERROR;
      }
      if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) {
        func_graph->manager()->Replace(post_node, cnode->input(1));
      }
    }
  }
  return lite::RET_OK;
}

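// Generates (or fuses) a transpose with `perm` either in front of input `index` of `cnode` or behind
// `cnode`, infers the new node's shape, and splices it into the graph (SetEdge for an input edge,
// Replace plus a post-transpose-fusion check for an output).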
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          const std::vector<int> perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    MS_ASSERT(new_cnode_input != nullptr);
    int status = lite::RET_OK;
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  auto tr = manager->Transact();
  if (before) {
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    func_graph->manager()->Replace(cnode, new_input);
    if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

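// Inserts pre-transposes with `perm` on the input indices registered for this op in the NHWC/NCHW op
// maps; when no indices are registered, all inputs are transposed (ResizeGrad with the NEAREST method
// only needs its first input handled).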
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                 const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
      specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
    MS_LOG(ERROR) << "op doesn't meet the nhwc condition.";
    return lite::RET_ERROR;
  }
  std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
                                       ? specify_nhwc_op_map.at(prim->name())
                                       : specify_nchw_op_map.at(prim->name());
  if (insert_index.empty()) {
    if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index.push_back(1);
    } else {
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index.push_back(i);
      }
    }
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

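// Single-op path: if inserting transposes around `cnode` would let them fuse away, change the op's
// axes for the new layout, insert pre-transposes on its inputs (including MakeTuple elements), update
// the format attribute, and re-infer the shape. trans_insert_info reports the directions used.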
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                 TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_insert_info != nullptr);
  TransTypePair trans_info;
  auto origin_inputs = cnode->inputs();
  lite::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NO_CHANGE;
  }
  cnode->set_inputs(origin_inputs);
  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_);
  if (status == lite::RET_NOT_SUPPORT) {
    return lite::RET_NO_CHANGE;
  } else if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "change op attr failed.";
    return lite::RET_ERROR;
  }
  auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (IsMonadNode(cnode->input(i))) {
      continue;
    }
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
        CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
      auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
      MS_ASSERT(input_make_tuple != nullptr);
      for (size_t j = 1; j < input_make_tuple->size(); ++j) {
        if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
          MS_LOG(ERROR) << "generate a new input failed.";
          return lite::RET_ERROR;
        }
      }
      continue;
    }
    if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
    return lite::RET_ERROR;
  }
  status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

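// Inserts a post-transpose with `perm` behind `cnode`; for tuple outputs the transpose is inserted
// behind each TupleGetItem user, and in training mode a TupleGetItem is generated first when a user
// is not one.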
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                  const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto node_users = func_graph->manager()->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
          MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
          post_node = tuple_get_item;
          func_graph->manager()->Replace(cnode, tuple_get_item);
        }
      }
      if (func_graph->manager()->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      MS_ASSERT(post_cnode != nullptr);
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}

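// Multi-op path: find the area bounded by transposes around `cnode`; if the whole area can switch
// layout, drop the surrounding transposes and convert the constant inputs, axes, format attribute,
// and inferred shapes of every middle op.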
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                   std::set<CNodePtr> *visit_transposes) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager is nullptr");
  std::set<CNodePtr> middle_nodes{};
  std::set<CNodePtr> in_nodes{};
  std::set<CNodePtr> out_nodes{};
  auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &middle_nodes);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
    return status;
  }
  for (auto &in_cnode : in_nodes) {
    if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
      visit_transposes->insert(in_cnode);
    }
  }
  TransTypePair trans_info;
  if (!JudgeCanOptimizerForMultiOp(in_nodes, out_nodes, middle_nodes, &trans_info)) {
    return lite::RET_NO_CHANGE;
  }
  auto node_list = TopoSort(func_graph->get_return());
  std::vector<CNodePtr> middle_ops_vec;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
      middle_ops_vec.push_back(node->cast<CNodePtr>());
      middle_nodes.erase(node->cast<CNodePtr>());
    }
  }
  for (auto &in_cnode : in_nodes) {
    manager->Replace(in_cnode, in_cnode->input(1));
  }
  for (auto &out_cnode : out_nodes) {
    manager->Replace(out_cnode, out_cnode->input(1));
  }
  for (auto &middle_cnode : middle_ops_vec) {
    if (IsSpecialType(middle_cnode)) {
      continue;
    }
    for (size_t i = 1; i < middle_cnode->size(); ++i) {
      status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
      if (status != lite::RET_OK) {
        MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
        return lite::RET_ERROR;
      }
    }
    status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "change op attr failed.";
      return lite::RET_ERROR;
    }
    status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
      return lite::RET_ERROR;
    }
    status = node_infer_shape_.InferShape(middle_cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

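// Binds the parameters of an If/While body graph to the matching inputs of the caller cnode (the
// trailing index in the parameter name, plus an offset, selects the caller input), copying abstracts
// and default tensors; the original parameters are remembered in sub_inputs_map_ for ResetSubGraphInput.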
int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto sub_inputs = sub_graph->get_inputs();
  sub_inputs_map_[sub_graph] = sub_inputs;
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    auto node_name = node->fullname_with_scope();
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    auto index = 0;
    try {
      index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Get index failed: " << e.what();
      return lite::RET_ERROR;
    }
    param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
    if (utils::isa<CNodePtr>(cnode->input(index))) {
      ShapeVector shape_vec = {-1};
      auto out_cnode = cnode->input(index)->cast<CNodePtr>();
      MS_ASSERT(out_cnode != nullptr);
      auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
      MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
      if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
        param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
      }
    } else {
      lite::DataInfo data_info;
      if (utils::isa<ParameterPtr>(cnode->input(index))) {
        if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
          param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
        }
        continue;
      }
      auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
      if (status != lite::RET_OK) {
        continue;
      }
      ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
      if (data_info.data_.empty()) {
        param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
      } else {
        param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
                                                                       data_info.data_.data(), data_info.data_.size()));
      }
    }
  }
  return lite::RET_OK;
}

int DecreaseTransposeAlgo::ResetSubGraphInput() {
  for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
    auto &sub_graph = iter->first;
    auto &sub_inputs = iter->second;
    auto manager = sub_graph->manager();
    MS_ASSERT(manager != nullptr);
    for (auto &sub_input : sub_inputs) {
      auto param_node = sub_graph->add_parameter();
      MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
      param_node->set_abstract(sub_input->abstract()->Clone());
      param_node->set_name(sub_input->fullname_with_scope());
      manager->Replace(sub_input, param_node);
      auto sub_param_input = sub_input->cast<ParameterPtr>();
      MS_ASSERT(sub_param_input != nullptr);
      sub_param_input->set_default_param(nullptr);
    }
  }
  return lite::RET_OK;
}

int DecreaseTransposeAlgo::SetSubGraphOutput(const FuncGraphPtr &sub_graph) {
  MS_ASSERT(sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_ASSERT(return_node != nullptr);
  auto origin_input = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  for (size_t i = 1; i < return_node->size(); ++i) {
    if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    auto trans_input_name = trans_input->fullname_with_scope();
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  return_node->set_inputs(origin_input);
  return lite::RET_OK;
}

int DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  MS_CHECK_TRUE_MSG(return_node != nullptr, lite::RET_ERROR, "return_node is nullptr");
  auto origin_inputs = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  AbstractBasePtrList abstract_list;
  bool infer_done = true;
  for (size_t i = 1; i < return_node->size(); ++i) {
    auto abstract_base = GetCNodeInputAbstract(return_node, i);
    MS_CHECK_TRUE_MSG(abstract_base != nullptr, lite::RET_ERROR, "GetCNodeInputAbstract failed");
    abstract_list.emplace_back(abstract_base->Clone());
    auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
    MS_ASSERT(abstract_tensor != nullptr);
    auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
    MS_ASSERT(shape_ptr != nullptr);
    auto shape = shape_ptr->shape();
    if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
      infer_done = false;
    }
    if (utils::isa<CNodePtr>(return_node->input(i))) {
      auto input_cnode = return_node->input(i)->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr");
      if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
        input_cnode = input_cnode->input(1)->cast<CNodePtr>();
      }
      auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      MS_CHECK_TRUE_MSG(input_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
      if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
        infer_done = false;
      }
    }
  }
  return_node->set_inputs(origin_inputs);
  if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  } else {
    if (abstract_list.size() != 1) {
      MS_LOG(ERROR) << "cnode output is invalid.";
    }
    cnode->set_abstract(abstract_list.front());
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
  prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));

  return lite::RET_OK;
}

int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
  MS_ASSERT(cnode != nullptr);
  if (pre_trans_type == kNONE) {
    return lite::RET_OK;
  }
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
  if (pre_trans_type == kNHWC2NCHW) {
    primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
  } else {
    primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
  }
  return lite::RET_OK;
}

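// Walks the graph (recursing into If/While subgraphs) and applies the single-op insertion path to
// every op reported by IsDynamicFormatOp, adding matching post-transposes where needed.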
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status = 0;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      auto ret = SetSubGraphInput(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphInput failed";
        return false;
      }
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      ret = SetSubGraphOutput(sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphOutput failed";
        return false;
      }
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      ret = SetSubGraphInput(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphInput failed";
        return false;
      }
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      ret = SetSubGraphOutput(sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphOutput failed";
        return false;
      }
      ret = SetSubGraphAbstract(cnode, sub_func_graph);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "SetSubGraphAbstract failed";
        return false;
      }
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
    if (!IsDynamicFormatOp(prim->name())) {
      continue;
    }
    TransTypePair trans_insert_info;
    status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
    if (status == lite::RET_NO_CHANGE) {
      continue;
    } else if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "insert pre node failed.";
      return false;
    }
    auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
    if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
      MS_LOG(ERROR) << "insert post node failed: " << cnode->fullname_with_scope();
      return false;
    }
  }
  return true;
}

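// Walks the graph (recursing into If/While subgraphs) and, starting from every unvisited NH2NC
// transpose, applies the multi-op optimization to the area it bounds.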
bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  std::set<CNodePtr> visit_transposes;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)DecreaseTransposeForMultiOp(sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)DecreaseTransposeForMultiOp(sub_func_graph);
    }
    std::vector<int> perm{};
    if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
        perm != kNH2NC) {
      continue;
    }
    auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "global optimizer failed.";
      return false;
    }
  }
  return true;
}

bool DecreaseTransposeAlgo::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
  auto &nchw_op = GetNCHWOpMap();
  if (!utils::isa<CNodePtr>(cnode->input(1))) {
    return true;
  }
  auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
  if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
    InsertPreTransNode(func_graph, cnode, kNH2NC);
    InsertPostTransNode(func_graph, cnode, kNC2NH);
  }
  return true;
}

bool DecreaseTransposeAlgo::DoFixFormat(const FuncGraphPtr &func_graph) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!DoFixFormat(sub_func_graph)) {
        MS_LOG(ERROR) << "fix format for subgraph failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(sub_func_graph);

      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!DoFixFormat(sub_func_graph)) {
        MS_LOG(ERROR) << "fix format for subgraph failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      continue;
    }
    if (!RunDoFixFormat(func_graph, cnode)) {
      return false;
    }
  }
  return true;
}

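// Entry point: delete redundant transposes, fix the format of NCHW-only ops, then run the single-op
// and multi-op transpose-reduction passes, resetting subgraph inputs in between.
//
// Usage sketch (illustrative only; assumes the constructor simply stores the fmk type and train flag
// that initialize fmk_type_ and train_flag_):
//   opt::DecreaseTransposeAlgo decrease_transpose(fmk_type, train_flag);
//   if (!decrease_transpose.Run(func_graph)) {
//     MS_LOG(ERROR) << "decrease transpose pass failed.";
//   }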
bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  node_infer_shape_.Init(fmk_type_, train_flag_);
  transpose_strategy_.Init(fmk_type_, train_flag_);
  if (!delete_redundant_transpose_.Run(func_graph)) {
    MS_LOG(ERROR) << "Run delete-redundant-transpose pass failed.";
    return false;
  }

  if (!DoFixFormat(func_graph)) {
    MS_LOG(ERROR) << "DoFixFormat failed.";
    return false;
  }
  ResetSubGraphInput();

  if (!DecreaseTransposeForSingleOp(func_graph)) {
    MS_LOG(ERROR) << "run local trans insert optimizer failed.";
    return false;
  }

  auto ret = ResetSubGraphInput();
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "ResetSubGraphInput failed.";
    return false;
  }
  // If the input format of several ops surrounded only by transpose ops can all be NHWC, these
  // transpose ops can be deleted and the middle ops transformed at the same time.
  if (!DecreaseTransposeForMultiOp(func_graph)) {
    MS_LOG(ERROR) << "run global trans insert optimizer failed.";
    return false;
  }
  return true;
}
}  // namespace opt
}  // namespace mindspore