1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/format/to_format_base.h"
19 #include <set>
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/nn_optimizer_ops.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "ops/op_utils.h"
25 #include "src/common/common.h"
26 #include "src/common/utils.h"
27 #include "tools/common/tensor_util.h"
28 #include "tools/converter/parser/parser_utils.h"
29 #include "nnacl/op_base.h"
30
31 using mindspore::lite::NHWC_SHAPE;
32 namespace mindspore {
33 namespace opt {
GenNewInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)34 STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
35 bool before, size_t index) {
36 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
37 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
38 AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
39 std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
40 : cnode->fullname_with_scope() + "_post";
41 auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);
42
43 MS_ERROR_IF_NULL_W_RET_VAL(trans_cnode, lite::RET_ERROR);
44 if (DecideWhetherInferShapeForNewNode()) {
45 auto status = node_infer_shape_->InferShape(trans_cnode);
46 if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
47 MS_LOG(ERROR) << "infer generated trans node failed.";
48 return lite::RET_ERROR;
49 }
50 } else {
51 auto abstract = trans_input->abstract();
52 if (abstract != nullptr) {
53 trans_cnode->set_abstract(abstract->Clone());
54 }
55 }
56 auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
57 MS_ERROR_IF_NULL_W_RET_VAL(trans_prim, lite::RET_ERROR);
58 if (perm == kNC2NH) {
59 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
60 } else if (perm == kNH2NC) {
61 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
62 }
63 MS_ERROR_IF_NULL_W_RET_VAL(manager_, lite::RET_ERROR);
64 if (before) {
65 manager_->SetEdge(cnode, index, trans_cnode);
66 } else {
67 if (!manager_->Replace(cnode, trans_cnode)) {
68 MS_LOG(ERROR) << "replace old node failed, please check.";
69 return lite::RET_ERROR;
70 }
71 }
72 return lite::RET_OK;
73 }
74
ModifyCNode(const CNodePtr & cnode)75 STATUS ToFormatBase::ModifyCNode(const CNodePtr &cnode) {
76 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
77 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
78 if (prim == nullptr) {
79 MS_LOG(ERROR) << "current node's prim is nullptr, " << cnode->fullname_with_scope();
80 return lite::RET_ERROR;
81 }
82 auto insert_pos = sensitive_ops_[prim->name()];
83 if (insert_pos.empty() || std::find(insert_pos.begin(), insert_pos.end(), 1) != insert_pos.end()) {
84 prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format_));
85 if (prim->HasAttr(opt::kOutputsFormat)) {
86 auto org_format = CastToInt(prim->GetAttr(opt::kOutputsFormat));
87 std::vector<int64_t> outputs_format(org_format.size(), format_);
88 (void)prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
89 }
90 }
91 auto abstract_base = cnode->abstract();
92 MS_ERROR_IF_NULL_W_RET_VAL(abstract_base, lite::RET_ERROR);
93 std::vector<AbstractBasePtr> abstracts;
94 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
95 auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
96 abstracts = abstract_tuple->elements();
97 } else {
98 abstracts.push_back(abstract_base);
99 }
100 for (auto &abstract : abstracts) {
101 ShapeVector shape;
102 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
103 MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
104 return lite::RET_ERROR;
105 }
106 if (shape.size() < kInputSizeThree) {
107 MS_LOG(DEBUG) << "shape don't need to modify.";
108 continue;
109 }
110 if (format_ == mindspore::NCHW) {
111 ShapeVector transfer_shape = shape;
112 size_t shape_size = shape.size();
113 transfer_shape[1] = shape[shape_size - 1];
114 for (size_t i = kDim2; i < shape_size; i++) {
115 transfer_shape[i] = shape[i - 1];
116 }
117 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
118 } else {
119 ShapeVector transfer_shape = shape;
120 size_t shape_size = shape.size();
121 transfer_shape[shape_size - 1] = shape[1];
122 for (size_t i = kDim1; i < shape_size - 1; i++) {
123 transfer_shape[i] = shape[i + 1];
124 }
125 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
126 }
127 }
128 return lite::RET_OK;
129 }
130
InsertPreTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)131 STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
132 const std::vector<int> &perm) {
133 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
134 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
135 std::vector<size_t> insert_index;
136 if (GetFormatSensitiveOpInsertIndex(cnode, &insert_index) != RET_OK) {
137 MS_LOG(ERROR) << "GetFormatSensitiveOpInsertIndex failed.";
138 return RET_ERROR;
139 }
140 if (insert_index.size() == 0) {
141 MS_LOG(ERROR) << "op don't meet condition.";
142 return lite::RET_ERROR;
143 }
144 for (auto &index : insert_index) {
145 if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
146 MS_LOG(ERROR) << "generate a new input failed.";
147 return lite::RET_ERROR;
148 }
149 }
150 return lite::RET_OK;
151 }
152
InsertPostTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)153 STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
154 const std::vector<int> &perm) {
155 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
156 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
157 if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
158 if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
159 MS_LOG(ERROR) << "generate a new input failed.";
160 return lite::RET_ERROR;
161 }
162 } else {
163 auto node_users = manager_->node_users()[cnode];
164 for (auto &node_user : node_users) {
165 auto post_node = node_user.first;
166 CNodePtr tuple_get_item = nullptr;
167 if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
168 if (!train_flag_) {
169 MS_LOG(ERROR) << "post node is invalid.";
170 return lite::RET_ERROR;
171 } else {
172 tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
173 if (!manager_->Replace(cnode, tuple_get_item, post_node)) {
174 MS_LOG(ERROR) << "replace node failed.";
175 return lite::RET_ERROR;
176 }
177 post_node = tuple_get_item;
178 }
179 }
180 if (manager_->node_users()[post_node].empty()) {
181 continue;
182 }
183 auto post_cnode = post_node->cast<CNodePtr>();
184 if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
185 MS_LOG(ERROR) << "generate a new input failed.";
186 return lite::RET_ERROR;
187 }
188 if (tuple_get_item != nullptr) {
189 if (!manager_->Replace(tuple_get_item, tuple_get_item->input(1))) {
190 MS_LOG(ERROR) << "replace old node failed. please check.";
191 return lite::RET_ERROR;
192 }
193 }
194 }
195 }
196 return lite::RET_OK;
197 }
198
DecideWhetherHandleGraphInput(const FuncGraphPtr & func_graph,const ParameterPtr & input,const ShapeVector & shape)199 bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
200 const ShapeVector &shape) {
201 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
202 MS_ERROR_IF_NULL_W_RET_VAL(input, false);
203 if (shape.size() != kInputSizeFour) {
204 return false;
205 }
206 MS_ERROR_IF_NULL_W_RET_VAL(manager_, false);
207 auto node_users = manager_->node_users()[input];
208 for (auto &node_user : node_users) {
209 auto post_node = node_user.first;
210 if (!utils::isa<CNode>(post_node)) {
211 continue;
212 }
213 auto post_cnode = post_node->cast<CNodePtr>();
214 auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
215 MS_ERROR_IF_NULL_W_RET_VAL(prim, false);
216 if (prim->GetAttr(ops::kFormat) != nullptr) {
217 auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
218 if (node_format == format_) {
219 MS_LOG(DEBUG) << "this graph input don't need to change.";
220 return false;
221 }
222 }
223 }
224 return true;
225 }
226
HandleGraphInput(const FuncGraphPtr & func_graph)227 STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
228 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
229 auto graph_input = func_graph->get_inputs();
230 for (auto &input : graph_input) {
231 auto input_param = input->cast<ParameterPtr>();
232 MS_ERROR_IF_NULL_W_RET_VAL(input_param, lite::RET_ERROR);
233 auto abstract = input_param->abstract();
234 MS_ERROR_IF_NULL_W_RET_VAL(abstract, lite::RET_ERROR);
235 ShapeVector shape;
236 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
237 MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
238 return lite::RET_ERROR;
239 }
240 if (!DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
241 continue;
242 }
243 ShapeVector transfer_shape;
244 if (format_ == mindspore::NCHW) {
245 transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
246 } else {
247 transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
248 }
249 CNodePtr trans_cnode;
250 if (format_ == mindspore::NCHW) {
251 trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
252 } else {
253 trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
254 }
255 if (trans_cnode == nullptr) {
256 MS_LOG(ERROR) << "create transpose cnode failed.";
257 return lite::RET_ERROR;
258 }
259 auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
260 MS_ERROR_IF_NULL_W_RET_VAL(trans_prim, lite::RET_ERROR);
261 if (format_ == mindspore::NCHW) {
262 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
263 } else {
264 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
265 }
266 trans_cnode->set_abstract(abstract->Clone());
267 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
268 if (!manager_->Replace(input, trans_cnode)) {
269 MS_LOG(ERROR) << "replace old node failed, please check.";
270 return lite::RET_ERROR;
271 }
272 }
273 return lite::RET_OK;
274 }
275
DealConv2dTransposeFusionNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)276 STATUS ToFormatBase::DealConv2dTransposeFusionNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
277 const std::vector<int> &perm) {
278 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
279 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
280 const int kInputSizeIndex = 3;
281 auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
282 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
283 MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
284 auto val_ptr = prim->GetAttr(ops::kOriginalOpName);
285 if (val_ptr == nullptr || GetValue<std::string>(val_ptr) != "Conv2DBackpropInput" ||
286 cnode->size() < kInputSizeIndex + 1) { // no input_size
287 return lite::RET_OK;
288 }
289 if (func_graph->has_attr(lite::kIsDynamicShape) && GetValue<bool>(func_graph->get_attr(lite::kIsDynamicShape))) {
290 MS_LOG(DEBUG) << "Dynamic input shape does not need Conv2dTransposeFusion format conversion";
291 return lite::RET_OK;
292 }
293 auto gather_input = cnode->input(kInputSizeIndex);
294 MS_CHECK_TRUE_MSG(gather_input != nullptr, RET_ERROR, "gather input is nullptr");
295 auto abstract = gather_input->abstract();
296 MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
297 std::vector<int> gather_indices_n;
298 std::vector<int> gather_indices_hw;
299 std::vector<int> gather_indices_c;
300 auto value_ptr = MakeValue<int64_t>(NCHW);
301 if (perm == kNH2NC) { // NHWC To NCHW
302 gather_indices_n = {0}; // fetch N dimension
303 gather_indices_hw = {1, 2}; // fetch H and W dimension
304 gather_indices_c = {3}; // fetch C dimension
305 } else { // NCHW To NHWC
306 gather_indices_n = {0}; // fetch N dimension;
307 gather_indices_hw = {2, 3}; // fetch H and W dimension
308 gather_indices_c = {1}; // fetch C dimension
309 value_ptr = MakeValue<int64_t>(NHWC);
310 }
311 auto gather_name_n = cnode->fullname_with_scope() + "_gather_n";
312 auto gather_cnode_n = opt::GenGatherNode(func_graph, gather_input, gather_indices_n, gather_name_n);
313 MS_CHECK_TRUE_MSG(gather_cnode_n != nullptr, RET_ERROR, "create gather cnode n failed.");
314 auto gather_prim_n = GetValueNode<PrimitivePtr>(gather_cnode_n->input(0));
315 (void)gather_prim_n->AddAttr(ops::kFormat, value_ptr);
316 ShapeVector gather_n_shape = {1};
317 auto n_shape_ptr = std::make_shared<abstract::Shape>(gather_n_shape);
318 MS_CHECK_TRUE_MSG(n_shape_ptr != nullptr, RET_ERROR, "n_shape_ptr is nullptr.");
319 auto tmp_abstract = abstract->Clone();
320 tmp_abstract->set_shape(n_shape_ptr);
321 gather_cnode_n->set_abstract(tmp_abstract);
322
323 auto gather_name_c = cnode->fullname_with_scope() + "_gather_c";
324 auto gather_cnode_c = opt::GenGatherNode(func_graph, gather_input, gather_indices_c, gather_name_c);
325 MS_CHECK_TRUE_MSG(gather_cnode_c != nullptr, RET_ERROR, "create gather cnode c failed.");
326 auto gather_prim_c = GetValueNode<PrimitivePtr>(gather_cnode_c->input(0));
327 (void)gather_prim_c->AddAttr(ops::kFormat, value_ptr);
328 ShapeVector gather_c_shape = {1};
329 auto c_shape_ptr = std::make_shared<abstract::Shape>(gather_c_shape);
330 MS_CHECK_TRUE_MSG(c_shape_ptr != nullptr, RET_ERROR, "c_shape_ptr is nullptr.");
331 tmp_abstract = abstract->Clone();
332 tmp_abstract->set_shape(c_shape_ptr);
333 gather_cnode_c->set_abstract(tmp_abstract);
334
335 auto gather_name_hw = cnode->fullname_with_scope() + "_gather_hw";
336 auto gather_cnode_hw = opt::GenGatherNode(func_graph, gather_input, gather_indices_hw, gather_name_hw);
337 MS_CHECK_TRUE_MSG(gather_cnode_hw != nullptr, RET_ERROR, "create gather cnode hw failed.");
338 auto gather_prim_hw = GetValueNode<PrimitivePtr>(gather_cnode_hw->input(0));
339 (void)gather_prim_hw->AddAttr(ops::kFormat, value_ptr);
340 ShapeVector gather_hw_shape = {2};
341 auto hw_shape_ptr = std::make_shared<abstract::Shape>(gather_hw_shape);
342 MS_CHECK_TRUE_MSG(hw_shape_ptr != nullptr, RET_ERROR, "hw_shape_ptr is nullptr.");
343 tmp_abstract = abstract->Clone();
344 tmp_abstract->set_shape(hw_shape_ptr);
345 gather_cnode_hw->set_abstract(tmp_abstract);
346
347 std::vector<AnfNodePtr> concat_inputnodes;
348 if (perm == kNH2NC) {
349 concat_inputnodes = {gather_cnode_n, gather_cnode_c, gather_cnode_hw};
350 } else {
351 concat_inputnodes = {gather_cnode_n, gather_cnode_hw, gather_cnode_c};
352 }
353 auto concat_name = cnode->fullname_with_scope() + "_concat_gather";
354 auto concat_node = opt::GenConcatNode(func_graph, concat_inputnodes, concat_name);
355 MS_CHECK_TRUE_MSG(concat_node != nullptr, RET_ERROR, "create concat_node failed.");
356 auto concat_node_prim = GetValueNode<PrimitivePtr>(concat_node->input(0));
357 (void)concat_node_prim->AddAttr(ops::kFormat, value_ptr);
358 ShapeVector concat_shape = {4};
359 auto concat_shape_ptr = std::make_shared<abstract::Shape>(concat_shape);
360 MS_CHECK_TRUE_MSG(concat_shape_ptr != nullptr, RET_ERROR, "concat_shape_ptr is nullptr.");
361 tmp_abstract = abstract->Clone();
362 tmp_abstract->set_shape(concat_shape_ptr);
363 concat_node->set_abstract(tmp_abstract);
364 manager_->SetEdge(cnode, kInputSizeIndex, concat_node);
365 return lite::RET_OK;
366 }
367
SetCNodeFormat(const CNodePtr & cnode,mindspore::Format dst_format)368 void SetCNodeFormat(const CNodePtr &cnode, mindspore::Format dst_format) {
369 MS_ASSERT(cnode != nullptr);
370 // update the format of cnode.
371 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
372 MS_CHECK_TRUE_RET_VOID(prim != nullptr);
373 auto format_value = prim->GetAttr(ops::kOriginalFormat);
374 if (prim->GetAttr(ops::kFormat) == nullptr && format_value != nullptr) {
375 auto format = GetValue<int64_t>(format_value);
376 if (format == dst_format) {
377 (void)prim->AddAttr(ops::kFormat, format_value);
378 }
379 }
380 return;
381 }
382
HandleGraphNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)383 STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
384 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, lite::RET_ERROR);
385 MS_ERROR_IF_NULL_W_RET_VAL(cnode, lite::RET_ERROR);
386 opt::TransTypePair trans_info;
387 if (GetTransNodeFormatType(cnode, &trans_info) != lite::RET_OK) {
388 MS_LOG(ERROR) << "obtain node's transferring format type failed, " << cnode->fullname_with_scope();
389 return lite::RET_ERROR;
390 }
391 if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
392 SetCNodeFormat(cnode, format_);
393 return lite::RET_NO_CHANGE;
394 }
395 auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
396 auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
397 if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) &&
398 DealConv2dTransposeFusionNode(func_graph, cnode, before_perm) != lite::RET_OK) {
399 MS_LOG(ERROR) << "Deal conv2d transpose fusion attr: input_size failed." << cnode->fullname_with_scope();
400 return lite::RET_ERROR;
401 }
402 if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
403 MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
404 return lite::RET_ERROR;
405 }
406 if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
407 return lite::RET_OK;
408 }
409 if (ModifyCNode(cnode) != lite::RET_OK) {
410 MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
411 return lite::RET_ERROR;
412 }
413 if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
414 MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
415 return lite::RET_ERROR;
416 }
417 return lite::RET_OK;
418 }
419
BasicProcess(const FuncGraphPtr & func_graph,bool main_graph)420 bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
421 MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
422 manager_->AddFuncGraph(func_graph);
423 auto node_list = TopoSort(func_graph->get_return());
424 int status;
425 for (auto &node : node_list) {
426 MS_CHECK_TRUE_RET(node != nullptr, false);
427 if (!utils::isa<CNodePtr>(node)) {
428 continue;
429 }
430 auto cnode = node->cast<CNodePtr>();
431 if (IsSpecialType(cnode)) {
432 continue;
433 }
434 if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
435 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
436 if (sub_func_graph == nullptr) {
437 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
438 return false;
439 }
440 if (!BasicProcess(sub_func_graph, false)) {
441 MS_LOG(ERROR) << "process sub graph failed.";
442 return false;
443 }
444 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
445 if (sub_func_graph == nullptr) {
446 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
447 return false;
448 }
449 if (!BasicProcess(sub_func_graph, false)) {
450 MS_LOG(ERROR) << "process sub graph failed.";
451 return false;
452 }
453 continue;
454 }
455 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
456 if (prim == nullptr) {
457 MS_LOG(INFO) << "this is a call cnode, which input[0] is fg, node " << cnode->fullname_with_scope();
458 continue;
459 }
460 status = HandleGraphNode(func_graph, cnode);
461 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
462 MS_LOG(ERROR) << "handle node failed.";
463 return false;
464 }
465 }
466
467 if (main_graph && save_type_ != kMindIR) {
468 status = HandleGraphInput(func_graph);
469 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
470 MS_LOG(ERROR) << "handle graph input failed.";
471 return false;
472 }
473 }
474 return true;
475 }
476
ConvWeightFormatTrans(const FuncGraphPtr & graph,std::set<AnfNodePtr> * has_visited)477 STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
478 MS_ERROR_IF_NULL_W_RET_VAL(graph, lite::RET_ERROR);
479 MS_ERROR_IF_NULL_W_RET_VAL(has_visited, lite::RET_ERROR);
480 manager_->AddFuncGraph(graph);
481 auto node_list = TopoSort(graph->get_return());
482 for (auto &node : node_list) {
483 if (!utils::isa<CNodePtr>(node)) {
484 continue;
485 }
486 auto cnode = node->cast<CNodePtr>();
487 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
488 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
489 if (sub_func_graph == nullptr) {
490 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
491 return lite::RET_NULL_PTR;
492 }
493 if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
494 MS_LOG(ERROR) << "transform conv weight format failed.";
495 return lite::RET_ERROR;
496 }
497 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
498 if (sub_func_graph == nullptr) {
499 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
500 return lite::RET_NULL_PTR;
501 }
502 if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
503 MS_LOG(ERROR) << "transform conv weight format failed.";
504 return lite::RET_ERROR;
505 }
506 continue;
507 }
508 if (!IsWeightNodeSensitive(cnode)) {
509 continue;
510 }
511 if (has_visited->find(node) != has_visited->end()) {
512 continue;
513 }
514 has_visited->insert(node);
515 schema::Format src_format = schema::Format_NUM_OF_FORMAT;
516 schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
517 if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
518 MS_LOG(ERROR) << "weight's src format and dst format get failed.";
519 return lite::RET_ERROR;
520 }
521 auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, has_visited);
522 if (status != lite::RET_OK) {
523 MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
524 return status;
525 }
526 }
527 return lite::RET_OK;
528 }
529
NodeConvWeightFormatTrans(const FuncGraphPtr & graph,const CNodePtr & cnode)530 STATUS ToFormatBase::NodeConvWeightFormatTrans(const FuncGraphPtr &graph, const CNodePtr &cnode) {
531 MS_ERROR_IF_NULL_W_RET_VAL(graph, lite::RET_ERROR);
532 manager_->AddFuncGraph(graph);
533 if (!IsWeightNodeSensitive(cnode)) {
534 return RET_OK;
535 }
536
537 schema::Format src_format = schema::Format_NUM_OF_FORMAT;
538 schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
539 if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
540 MS_LOG(ERROR) << "weight's src format and dst format get failed.";
541 return lite::RET_ERROR;
542 }
543 std::set<AnfNodePtr> has_visited;
544 auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, &has_visited);
545 if (status != lite::RET_OK) {
546 MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
547 return status;
548 }
549 return lite::RET_OK;
550 }
551
RunPassOneNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)552 STATUS ToFormatBase::RunPassOneNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
553 SetSensitiveOps();
554 node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
555 if (node_infer_shape_ == nullptr) {
556 MS_LOG(ERROR) << "create NodeInferShape object failed.";
557 return false;
558 }
559 manager_ = Manage(func_graph, true);
560 if (manager_ == nullptr) {
561 MS_LOG(ERROR) << "manager is nullptr.";
562 return false;
563 }
564 auto status = NodeConvWeightFormatTrans(func_graph, cnode);
565 if (status != lite::RET_OK) {
566 MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
567 return false;
568 }
569 status = HandleGraphNode(func_graph, cnode);
570 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
571 MS_LOG(ERROR) << "handle node failed.";
572 return RET_ERROR;
573 }
574 return RET_OK;
575 }
576
Run(const FuncGraphPtr & func_graph)577 bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
578 MS_CHECK_TRUE_RET(func_graph != nullptr, false);
579 auto value = func_graph->get_attr(ops::kFormat);
580 if (value != nullptr && GetValue<int64_t>(value) == format_) {
581 return true;
582 }
583 if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
584 MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc.";
585 return false;
586 }
587 manager_ = Manage(func_graph, true);
588 if (manager_ == nullptr) {
589 MS_LOG(ERROR) << "manager is nullptr.";
590 return false;
591 }
592 node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
593 if (node_infer_shape_ == nullptr) {
594 MS_LOG(ERROR) << "create NodeInferShape object failed.";
595 return false;
596 }
597 std::set<AnfNodePtr> has_visited;
598 auto status = ConvWeightFormatTrans(func_graph, &has_visited);
599 if (status != lite::RET_OK) {
600 MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
601 return false;
602 }
603 SetSensitiveOps();
604 if (!BasicProcess(func_graph, true)) {
605 MS_LOG(ERROR) << "transfer format failed.";
606 return false;
607 }
608 func_graph->set_attr(ops::kFormat, MakeValue<int64_t>(format_));
609
610 return true;
611 }
612 } // namespace opt
613 } // namespace mindspore
614