/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/format/to_format_base.h"
#include <set>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "nnacl/op_base.h"

using mindspore::lite::NHWC_SHAPE;
namespace mindspore {
namespace opt {
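// Insert a Transpose node either on input `index` of `cnode` (before == true) or on its output
// (before == false), infer or clone its abstract, tag the new node's format attribute according
// to `perm`, and rewire the graph through the manager.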
STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
                                 bool before, size_t index) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
  std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
                                  : cnode->fullname_with_scope() + "_post";
  auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);

  MS_ERROR_IF_NULL_W_RET_VAL(trans_cnode, lite::RET_ERROR);
  if (DecideWhetherInferShapeForNewNode()) {
    auto status = node_infer_shape_->InferShape(trans_cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer generated trans node failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto abstract = trans_input->abstract();
    if (abstract != nullptr) {
      trans_cnode->set_abstract(abstract->Clone());
    }
  }
  auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
  MS_ERROR_IF_NULL_W_RET_VAL(trans_prim, lite::RET_ERROR);
  if (perm == kNC2NH) {
    trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
  } else if (perm == kNH2NC) {
    trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
  }
  MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
  if (before) {
    manager_->SetEdge(cnode, index, trans_cnode);
  } else {
    if (!manager_->Replace(cnode, trans_cnode)) {
      MS_LOG(ERROR) << "replace old node failed, please check.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

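// Update a format-sensitive cnode after the surrounding Transpose nodes have been inserted:
// rewrite its format attribute(s) and permute every output shape of rank >= 3 in its abstract
// between NHWC and NCHW according to the destination format.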
STATUS ToFormatBase::ModifyCNode(const CNodePtr &cnode) {
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "current node's prim is nullptr, " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto insert_pos = sensitive_ops_[prim->name()];
  if (insert_pos.empty() || std::find(insert_pos.begin(), insert_pos.end(), 1) != insert_pos.end()) {
    prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format_));
    if (prim->HasAttr(opt::kOutputsFormat)) {
      auto org_format = CastToInt(prim->GetAttr(opt::kOutputsFormat));
      std::vector<int64_t> outputs_format(org_format.size(), format_);
      (void)prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
    }
  }
  auto abstract_base = cnode->abstract();
  MS_ERROR_IF_NULL_W_RET_VAL(abstract_base, lite::RET_ERROR);
  std::vector<AbstractBasePtr> abstracts;
  if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
    auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
    abstracts = abstract_tuple->elements();
  } else {
    abstracts.push_back(abstract_base);
  }
  for (auto &abstract : abstracts) {
    ShapeVector shape;
    if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
      MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
      return lite::RET_ERROR;
    }
    if (shape.size() < kInputSizeThree) {
      MS_LOG(DEBUG) << "shape don't need to modify.";
      continue;
    }
    if (format_ == mindspore::NCHW) {
      ShapeVector transfer_shape = shape;
      size_t shape_size = shape.size();
      transfer_shape[1] = shape[shape_size - 1];
      for (size_t i = kDim2; i < shape_size; i++) {
        transfer_shape[i] = shape[i - 1];
      }
      abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
    } else {
      ShapeVector transfer_shape = shape;
      size_t shape_size = shape.size();
      transfer_shape[shape_size - 1] = shape[1];
      for (size_t i = kDim1; i < shape_size - 1; i++) {
        transfer_shape[i] = shape[i + 1];
      }
      abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
    }
  }
  return lite::RET_OK;
}

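// Insert a Transpose with permutation `perm` in front of every format-sensitive input of `cnode`,
// as reported by GetFormatSensitiveOpInsertIndex.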
STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                        const std::vector<int> &perm) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  std::vector<size_t> insert_index;
  if (GetFormatSensitiveOpInsertIndex(cnode, &insert_index) != RET_OK) {
    MS_LOG(ERROR) << "GetFormatSensitiveOpInsertIndex failed.";
    return RET_ERROR;
  }
  if (insert_index.size() == 0) {
    MS_LOG(ERROR) << "op don't meet condition.";
    return lite::RET_ERROR;
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

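// Insert a Transpose with permutation `perm` behind `cnode`. For tuple (multi-output) nodes the
// Transpose is inserted behind each TupleGetItem user, creating one in training mode if absent.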
STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                         const std::vector<int> &perm) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto node_users = manager_->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
          if (!manager_->Replace(cnode, tuple_get_item, post_node)) {
            MS_LOG(ERROR) << "replace node failed.";
            return lite::RET_ERROR;
          }
          post_node = tuple_get_item;
        }
      }
      if (manager_->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        if (!manager_->Replace(tuple_get_item, tuple_get_item->input(1))) {
          MS_LOG(ERROR) << "replace old node failed. please check.";
          return lite::RET_ERROR;
        }
      }
    }
  }
  return lite::RET_OK;
}

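// Decide whether a graph input needs format conversion: only 4-D inputs qualify, and the
// conversion is skipped when a consumer already declares the destination format in kFormat.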
bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
                                                 const ShapeVector &shape) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
  MS_ERROR_IF_NULL_W_RET_VAL(input, false);
  if (shape.size() != kInputSizeFour) {
    return false;
  }
  MS_ERROR_IF_NULL_W_RET_VAL(manager_, false);
  auto node_users = manager_->node_users()[input];
  for (auto &node_user : node_users) {
    auto post_node = node_user.first;
    if (!utils::isa<CNode>(post_node)) {
      continue;
    }
    auto post_cnode = post_node->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
    MS_ERROR_IF_NULL_W_RET_VAL(prim, false);
    if (prim->GetAttr(ops::kFormat) != nullptr) {
      auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
      if (node_format == format_) {
        MS_LOG(DEBUG) << "this graph input don't need to change.";
        return false;
      }
    }
  }
  return true;
}

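// Convert each eligible graph input to the destination format: rewrite the parameter's shape to
// the new layout and insert a Transpose that restores the original layout for existing consumers.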
STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  auto graph_input = func_graph->get_inputs();
  for (auto &input : graph_input) {
    auto input_param = input->cast<ParameterPtr>();
    MS_ERROR_IF_NULL_W_RET_VAL(input_param, lite::RET_ERROR);
    auto abstract = input_param->abstract();
    MS_ERROR_IF_NULL_W_RET_VAL(abstract, lite::RET_ERROR);
    ShapeVector shape;
    if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
      MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
      return lite::RET_ERROR;
    }
    if (!DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
      continue;
    }
    ShapeVector transfer_shape;
    if (format_ == mindspore::NCHW) {
      transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
    } else {
      transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
    }
    CNodePtr trans_cnode;
    if (format_ == mindspore::NCHW) {
      trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
    } else {
      trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
    }
    if (trans_cnode == nullptr) {
      MS_LOG(ERROR) << "create transpose cnode failed.";
      return lite::RET_ERROR;
    }
    auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
    MS_ERROR_IF_NULL_W_RET_VAL(trans_prim, lite::RET_ERROR);
    if (format_ == mindspore::NCHW) {
      trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
    } else {
      trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
    }
    trans_cnode->set_abstract(abstract->Clone());
    abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
    if (!manager_->Replace(input, trans_cnode)) {
      MS_LOG(ERROR) << "replace old node failed, please check.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

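// Conv2dTransposeFusion converted from Conv2DBackpropInput carries an explicit input_size
// tensor whose elements are laid out in the source format. Rebuild it for the destination
// format by gathering the N, H/W, and C elements separately and concatenating them in the
// new order.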
STATUS ToFormatBase::DealConv2dTransposeFusionNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                   const std::vector<int> &perm) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  const int kInputSizeIndex = 3;
  auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
  auto val_ptr = prim->GetAttr(ops::kOriginalOpName);
  if (val_ptr == nullptr || GetValue<std::string>(val_ptr) != "Conv2DBackpropInput" ||
      cnode->size() < kInputSizeIndex + 1) {  // no input_size
    return lite::RET_OK;
  }
  if (func_graph->has_attr(lite::kIsDynamicShape) && GetValue<bool>(func_graph->get_attr(lite::kIsDynamicShape))) {
    MS_LOG(DEBUG) << "Dynamic input shape does not need Conv2dTransposeFusion format conversion";
    return lite::RET_OK;
  }
  auto gather_input = cnode->input(kInputSizeIndex);
  MS_CHECK_TRUE_MSG(gather_input != nullptr, RET_ERROR, "gather input is nullptr");
  auto abstract = gather_input->abstract();
  MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
  std::vector<int> gather_indices_n;
  std::vector<int> gather_indices_hw;
  std::vector<int> gather_indices_c;
  auto value_ptr = MakeValue<int64_t>(NCHW);
  if (perm == kNH2NC) {          // NHWC to NCHW
    gather_indices_n = {0};      // fetch N dimension
    gather_indices_hw = {1, 2};  // fetch H and W dimensions
    gather_indices_c = {3};      // fetch C dimension
  } else {                       // NCHW to NHWC
    gather_indices_n = {0};      // fetch N dimension
    gather_indices_hw = {2, 3};  // fetch H and W dimensions
    gather_indices_c = {1};      // fetch C dimension
    value_ptr = MakeValue<int64_t>(NHWC);
  }
  auto gather_name_n = cnode->fullname_with_scope() + "_gather_n";
  auto gather_cnode_n = opt::GenGatherNode(func_graph, gather_input, gather_indices_n, gather_name_n);
  MS_CHECK_TRUE_MSG(gather_cnode_n != nullptr, RET_ERROR, "create gather cnode n failed.");
  auto gather_prim_n = GetValueNode<PrimitivePtr>(gather_cnode_n->input(0));
  (void)gather_prim_n->AddAttr(ops::kFormat, value_ptr);
  ShapeVector gather_n_shape = {1};
  auto n_shape_ptr = std::make_shared<abstract::Shape>(gather_n_shape);
  MS_CHECK_TRUE_MSG(n_shape_ptr != nullptr, RET_ERROR, "n_shape_ptr is nullptr.");
  auto tmp_abstract = abstract->Clone();
  tmp_abstract->set_shape(n_shape_ptr);
  gather_cnode_n->set_abstract(tmp_abstract);

  auto gather_name_c = cnode->fullname_with_scope() + "_gather_c";
  auto gather_cnode_c = opt::GenGatherNode(func_graph, gather_input, gather_indices_c, gather_name_c);
  MS_CHECK_TRUE_MSG(gather_cnode_c != nullptr, RET_ERROR, "create gather cnode c failed.");
  auto gather_prim_c = GetValueNode<PrimitivePtr>(gather_cnode_c->input(0));
  (void)gather_prim_c->AddAttr(ops::kFormat, value_ptr);
  ShapeVector gather_c_shape = {1};
  auto c_shape_ptr = std::make_shared<abstract::Shape>(gather_c_shape);
  MS_CHECK_TRUE_MSG(c_shape_ptr != nullptr, RET_ERROR, "c_shape_ptr is nullptr.");
  tmp_abstract = abstract->Clone();
  tmp_abstract->set_shape(c_shape_ptr);
  gather_cnode_c->set_abstract(tmp_abstract);

  auto gather_name_hw = cnode->fullname_with_scope() + "_gather_hw";
  auto gather_cnode_hw = opt::GenGatherNode(func_graph, gather_input, gather_indices_hw, gather_name_hw);
  MS_CHECK_TRUE_MSG(gather_cnode_hw != nullptr, RET_ERROR, "create gather cnode hw failed.");
  auto gather_prim_hw = GetValueNode<PrimitivePtr>(gather_cnode_hw->input(0));
  (void)gather_prim_hw->AddAttr(ops::kFormat, value_ptr);
  ShapeVector gather_hw_shape = {2};
  auto hw_shape_ptr = std::make_shared<abstract::Shape>(gather_hw_shape);
  MS_CHECK_TRUE_MSG(hw_shape_ptr != nullptr, RET_ERROR, "hw_shape_ptr is nullptr.");
  tmp_abstract = abstract->Clone();
  tmp_abstract->set_shape(hw_shape_ptr);
  gather_cnode_hw->set_abstract(tmp_abstract);

  std::vector<AnfNodePtr> concat_inputnodes;
  if (perm == kNH2NC) {
    concat_inputnodes = {gather_cnode_n, gather_cnode_c, gather_cnode_hw};
  } else {
    concat_inputnodes = {gather_cnode_n, gather_cnode_hw, gather_cnode_c};
  }
  auto concat_name = cnode->fullname_with_scope() + "_concat_gather";
  auto concat_node = opt::GenConcatNode(func_graph, concat_inputnodes, concat_name);
  MS_CHECK_TRUE_MSG(concat_node != nullptr, RET_ERROR, "create concat_node failed.");
  auto concat_node_prim = GetValueNode<PrimitivePtr>(concat_node->input(0));
  (void)concat_node_prim->AddAttr(ops::kFormat, value_ptr);
  ShapeVector concat_shape = {4};
  auto concat_shape_ptr = std::make_shared<abstract::Shape>(concat_shape);
  MS_CHECK_TRUE_MSG(concat_shape_ptr != nullptr, RET_ERROR, "concat_shape_ptr is nullptr.");
  tmp_abstract = abstract->Clone();
  tmp_abstract->set_shape(concat_shape_ptr);
  concat_node->set_abstract(tmp_abstract);
  manager_->SetEdge(cnode, kInputSizeIndex, concat_node);
  return lite::RET_OK;
}

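// Copy kOriginalFormat into kFormat when the node has no explicit format attribute yet and its
// original format already matches the destination format.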
void SetCNodeFormat(const CNodePtr &cnode, mindspore::Format dst_format) {
  MS_ASSERT(cnode != nullptr);
  // update the format of cnode.
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_RET_VOID(prim != nullptr);
  auto format_value = prim->GetAttr(ops::kOriginalFormat);
  if (prim->GetAttr(ops::kFormat) == nullptr && format_value != nullptr) {
    auto format = GetValue<int64_t>(format_value);
    if (format == dst_format) {
      (void)prim->AddAttr(ops::kFormat, format_value);
    }
  }
  return;
}

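// Handle a single cnode: query the required pre/post transpose directions, insert the
// corresponding Transpose nodes, and adjust the node's format attribute and output shapes.
// Returns RET_NO_CHANGE when the node needs no transposition.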
STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
  opt::TransTypePair trans_info;
  if (GetTransNodeFormatType(cnode, &trans_info) != lite::RET_OK) {
    MS_LOG(ERROR) << "obtain node's transferring format type failed, " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
    SetCNodeFormat(cnode, format_);
    return lite::RET_NO_CHANGE;
  }
  auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
  auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
  if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) &&
      DealConv2dTransposeFusionNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "Deal conv2d transpose fusion attr: input_size failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    return lite::RET_OK;
  }
  if (ModifyCNode(cnode) != lite::RET_OK) {
    MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

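// Walk the graph in topological order and apply HandleGraphNode to every ordinary cnode,
// recursing into the subgraphs of If/While. For the main graph of non-MindIR models the
// graph inputs are converted as well.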
bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
  manager_->AddFuncGraph(func_graph);
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    MS_CHECK_TRUE_RET(node != nullptr, false);
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      if (!BasicProcess(sub_func_graph, false)) {
        MS_LOG(ERROR) << "process sub graph failed.";
        return false;
      }
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      if (!BasicProcess(sub_func_graph, false)) {
        MS_LOG(ERROR) << "process sub graph failed.";
        return false;
      }
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(INFO) << "this is a call cnode, whose input[0] is a func_graph, node " << cnode->fullname_with_scope();
      continue;
    }
    status = HandleGraphNode(func_graph, cnode);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "handle node failed.";
      return false;
    }
  }

  if (main_graph && save_type_ != kMindIR) {
    status = HandleGraphInput(func_graph);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "handle graph input failed.";
      return false;
    }
  }
  return true;
}

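// Unify the weight format of all weight-format-sensitive nodes in the graph, recursing into
// If/While subgraphs; `has_visited` prevents a shared weight from being converted twice.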
STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
  MS_ERROR_IF_NULL_W_RET_VAL(graph, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(has_visited, lite::RET_ERROR);
  manager_->AddFuncGraph(graph);
  auto node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return lite::RET_NULL_PTR;
      }
      if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
        MS_LOG(ERROR) << "transform conv weight format failed.";
        return lite::RET_ERROR;
      }
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return lite::RET_NULL_PTR;
      }
      if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
        MS_LOG(ERROR) << "transform conv weight format failed.";
        return lite::RET_ERROR;
      }
      continue;
    }
    if (!IsWeightNodeSensitive(cnode)) {
      continue;
    }
    if (has_visited->find(node) != has_visited->end()) {
      continue;
    }
    has_visited->insert(node);
    schema::Format src_format = schema::Format_NUM_OF_FORMAT;
    schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
    if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
      MS_LOG(ERROR) << "weight's src format and dst format get failed.";
      return lite::RET_ERROR;
    }
    auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, has_visited);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
      return status;
    }
  }
  return lite::RET_OK;
}

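// Single-node variant of ConvWeightFormatTrans: unify the weight format of `cnode` only.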
STATUS ToFormatBase::NodeConvWeightFormatTrans(const FuncGraphPtr &graph, const CNodePtr &cnode) {
  MS_ERROR_IF_NULL_W_RET_VAL(graph, lite::RET_ERROR);
  manager_->AddFuncGraph(graph);
  if (!IsWeightNodeSensitive(cnode)) {
    return RET_OK;
  }

  schema::Format src_format = schema::Format_NUM_OF_FORMAT;
  schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
  if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
    MS_LOG(ERROR) << "weight's src format and dst format get failed.";
    return lite::RET_ERROR;
  }
  std::set<AnfNodePtr> has_visited;
  auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, &has_visited);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
    return status;
  }
  return lite::RET_OK;
}

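// Run the format pass for a single node: set up the sensitive-op table, the shape-inference
// helper and the manager, then transform the node's conv weight and insert the required
// Transpose nodes around it.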
STATUS ToFormatBase::RunPassOneNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  SetSensitiveOps();
  node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
  if (node_infer_shape_ == nullptr) {
    MS_LOG(ERROR) << "create NodeInferShape object failed.";
    return lite::RET_ERROR;
  }
  manager_ = Manage(func_graph, true);
  if (manager_ == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return lite::RET_ERROR;
  }
  auto status = NodeConvWeightFormatTrans(func_graph, cnode);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
    return lite::RET_ERROR;
  }
  status = HandleGraphNode(func_graph, cnode);
  if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
    MS_LOG(ERROR) << "handle node failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

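// Pass entry point. Skips graphs already tagged with the destination format, unifies conv
// weight formats, inserts the Transpose nodes required by format-sensitive ops, and finally
// tags the graph with the destination format.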
bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, false);
  auto value = func_graph->get_attr(ops::kFormat);
  if (value != nullptr && GetValue<int64_t>(value) == format_) {
    return true;
  }
  if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
    MS_LOG(ERROR) << "format transferring only supports nc2nh or nh2nc.";
    return false;
  }
  manager_ = Manage(func_graph, true);
  if (manager_ == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
  if (node_infer_shape_ == nullptr) {
    MS_LOG(ERROR) << "create NodeInferShape object failed.";
    return false;
  }
  std::set<AnfNodePtr> has_visited;
  auto status = ConvWeightFormatTrans(func_graph, &has_visited);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
    return false;
  }
  SetSensitiveOps();
  if (!BasicProcess(func_graph, true)) {
    MS_LOG(ERROR) << "transfer format failed.";
    return false;
  }
  func_graph->set_attr(ops::kFormat, MakeValue<int64_t>(format_));

  return true;
}
}  // namespace opt
}  // namespace mindspore