• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/format/to_format_base.h"
18 #include <set>
19 #include "ops/op_utils.h"
20 #include "src/common/common.h"
21 #include "src/common/utils.h"
22 #include "tools/common/tensor_util.h"
23 #include "tools/converter/parser/parser_utils.h"
24 #include "nnacl/op_base.h"
25 
26 using mindspore::lite::NHWC_SHAPE;
27 namespace mindspore {
28 namespace opt {
GenNewInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)29 STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
30                                  bool before, size_t index) {
31   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
32   AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
33   std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
34                                   : cnode->fullname_with_scope() + "_post";
35   auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);
36   MS_ASSERT(trans_cnode != nullptr);
37   if (DecideWhetherInferShapeForNewNode()) {
38     auto status = node_infer_shape_->InferShape(trans_cnode);
39     if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
40       MS_LOG(ERROR) << "infer generated trans node failed.";
41       return lite::RET_ERROR;
42     }
43   } else {
44     auto abstract = trans_input->abstract();
45     if (abstract != nullptr) {
46       trans_cnode->set_abstract(abstract->Clone());
47     }
48   }
49   auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
50   MS_ASSERT(trans_prim != nullptr);
51   if (perm == kNC2NH) {
52     trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
53   } else if (perm == kNH2NC) {
54     trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
55   }
56   MS_ASSERT(manager_ != nullptr);
57   if (before) {
58     manager_->SetEdge(cnode, index, trans_cnode);
59   } else {
60     if (!manager_->Replace(cnode, trans_cnode)) {
61       MS_LOG(ERROR) << "replace old node failed, please check.";
62       return lite::RET_ERROR;
63     }
64   }
65   return lite::RET_OK;
66 }
67 
ModifyCNode(const CNodePtr & cnode)68 STATUS ToFormatBase::ModifyCNode(const CNodePtr &cnode) {
69   MS_ASSERT(cnode != nullptr);
70   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
71   if (prim == nullptr) {
72     MS_LOG(ERROR) << "current node's prim is nullptr, " << cnode->fullname_with_scope();
73     return lite::RET_ERROR;
74   }
75   auto insert_pos = sensitive_ops_[prim->name()];
76   if (insert_pos.empty() || std::find(insert_pos.begin(), insert_pos.end(), 1) != insert_pos.end()) {
77     prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format_));
78   }
79   auto abstract_base = cnode->abstract();
80   MS_ASSERT(abstract_base != nullptr);
81   std::vector<AbstractBasePtr> abstracts;
82   if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
83     auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
84     abstracts = abstract_tuple->elements();
85   } else {
86     abstracts.push_back(abstract_base);
87   }
88   for (auto &abstract : abstracts) {
89     ShapeVector shape;
90     if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
91       MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
92       return lite::RET_ERROR;
93     }
94     if (shape.size() != kInputSizeFour) {
95       MS_LOG(DEBUG) << "shape don't need to modify.";
96       continue;
97     }
98     if (format_ == mindspore::NCHW) {
99       ShapeVector transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
100       abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
101     } else {
102       ShapeVector transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
103       abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
104     }
105   }
106   return lite::RET_OK;
107 }
108 
InsertPreTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)109 STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
110                                         const std::vector<int> &perm) {
111   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
112   auto prim_node = cnode->input(0);
113   auto prim = GetValueNode<PrimitivePtr>(prim_node);
114   MS_ASSERT(prim != nullptr);
115   if (sensitive_ops_.find(prim->name()) == sensitive_ops_.end()) {
116     MS_LOG(ERROR) << "op don't meet condition.";
117     return lite::RET_ERROR;
118   }
119   auto insert_index = sensitive_ops_.at(prim->name());
120   if (insert_index.empty()) {
121     if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
122         GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
123       insert_index.push_back(1);
124     } else {
125       for (size_t i = 1; i < cnode->size(); ++i) {
126         insert_index.push_back(i);
127       }
128     }
129   }
130   for (auto &index : insert_index) {
131     if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
132       MS_LOG(ERROR) << "generate a new input failed.";
133       return lite::RET_ERROR;
134     }
135   }
136   return lite::RET_OK;
137 }
138 
InsertPostTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)139 STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
140                                          const std::vector<int> &perm) {
141   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
142   if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
143     if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
144       MS_LOG(ERROR) << "generate a new input failed.";
145       return lite::RET_ERROR;
146     }
147   } else {
148     auto node_users = manager_->node_users()[cnode];
149     for (auto &node_user : node_users) {
150       auto post_node = node_user.first;
151       CNodePtr tuple_get_item = nullptr;
152       if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
153         if (!train_flag_) {
154           MS_LOG(ERROR) << "post node is invalid.";
155           return lite::RET_ERROR;
156         } else {
157           tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
158           post_node = tuple_get_item;
159           manager_->Replace(cnode, tuple_get_item);
160         }
161       }
162       if (manager_->node_users()[post_node].empty()) {
163         continue;
164       }
165       auto post_cnode = post_node->cast<CNodePtr>();
166       if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
167         MS_LOG(ERROR) << "generate a new input failed.";
168         return lite::RET_ERROR;
169       }
170       if (tuple_get_item != nullptr) {
171         if (!manager_->Replace(tuple_get_item, tuple_get_item->input(1))) {
172           MS_LOG(ERROR) << "replace old node failed. please check.";
173           return lite::RET_ERROR;
174         }
175       }
176     }
177   }
178   return lite::RET_OK;
179 }
180 
DecideWhetherHandleGraphInput(const FuncGraphPtr & func_graph,const ParameterPtr & input,const ShapeVector & shape)181 bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
182                                                  const ShapeVector &shape) {
183   MS_ASSERT(func_graph != nullptr && input != nullptr);
184   if (shape.size() != kInputSizeFour) {
185     return false;
186   }
187   MS_ASSERT(manager_ != nullptr);
188   auto node_users = manager_->node_users()[input];
189   for (auto &node_user : node_users) {
190     auto post_node = node_user.first;
191     if (!utils::isa<CNode>(post_node)) {
192       continue;
193     }
194     auto post_cnode = post_node->cast<CNodePtr>();
195     auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
196     MS_ASSERT(prim != nullptr);
197     if (prim->GetAttr(ops::kFormat) != nullptr) {
198       auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
199       if (node_format == format_) {
200         MS_LOG(DEBUG) << "this graph input don't need to change.";
201         return false;
202       }
203     }
204   }
205   return true;
206 }
207 
HandleGraphInput(const FuncGraphPtr & func_graph)208 STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
209   MS_ASSERT(func_graph != nullptr);
210   auto graph_input = func_graph->get_inputs();
211   for (auto &input : graph_input) {
212     auto input_param = input->cast<ParameterPtr>();
213     MS_ASSERT(input_param != nullptr);
214     auto abstract = input_param->abstract();
215     MS_ASSERT(abstract != nullptr);
216     ShapeVector shape;
217     if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
218       MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
219       return lite::RET_ERROR;
220     }
221     if (!DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
222       continue;
223     }
224     ShapeVector transfer_shape;
225     if (format_ == mindspore::NCHW) {
226       transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
227     } else {
228       transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
229     }
230     CNodePtr trans_cnode;
231     if (format_ == mindspore::NCHW) {
232       trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
233     } else {
234       trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
235     }
236     if (trans_cnode == nullptr) {
237       MS_LOG(ERROR) << "create transpose cnode failed.";
238       return lite::RET_ERROR;
239     }
240     auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
241     MS_ASSERT(trans_prim != nullptr);
242     if (format_ == mindspore::NCHW) {
243       trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
244     } else {
245       trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
246     }
247     trans_cnode->set_abstract(abstract->Clone());
248     abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
249     if (!manager_->Replace(input, trans_cnode)) {
250       MS_LOG(ERROR) << "replace old node failed, please check.";
251       return lite::RET_ERROR;
252     }
253   }
254   return lite::RET_OK;
255 }
256 
HandleGraphNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)257 STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
258   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
259   opt::TransTypePair trans_info;
260   if (GetTransNodeFormatType(cnode, &trans_info) != lite::RET_OK) {
261     MS_LOG(ERROR) << "obtain node's transferring format type failed, " << cnode->fullname_with_scope();
262     return lite::RET_ERROR;
263   }
264   if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
265     return lite::RET_NO_CHANGE;
266   }
267   auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
268   auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
269   if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
270     MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
271     return lite::RET_ERROR;
272   }
273   if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
274     return lite::RET_OK;
275   }
276   if (ModifyCNode(cnode) != lite::RET_OK) {
277     MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
278     return lite::RET_ERROR;
279   }
280   if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
281     MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
282     return lite::RET_ERROR;
283   }
284   return lite::RET_OK;
285 }
286 
BasicProcess(const FuncGraphPtr & func_graph,bool main_graph)287 bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
288   MS_ASSERT(func_graph != nullptr);
289   manager_->AddFuncGraph(func_graph);
290   auto node_list = TopoSort(func_graph->get_return());
291   int status;
292   for (auto &node : node_list) {
293     MS_CHECK_TRUE_RET(node != nullptr, false);
294     if (!utils::isa<CNodePtr>(node)) {
295       continue;
296     }
297     auto cnode = node->cast<CNodePtr>();
298     if (IsSpecialType(cnode)) {
299       continue;
300     }
301     if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
302       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
303       if (sub_func_graph == nullptr) {
304         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
305         return false;
306       }
307       if (!BasicProcess(sub_func_graph, false)) {
308         MS_LOG(ERROR) << "process sub graph failed.";
309         return false;
310       }
311       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
312       if (sub_func_graph == nullptr) {
313         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
314         return false;
315       }
316       if (!BasicProcess(sub_func_graph, false)) {
317         MS_LOG(ERROR) << "process sub graph failed.";
318         return false;
319       }
320       continue;
321     }
322     status = HandleGraphNode(func_graph, cnode);
323     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
324       MS_LOG(ERROR) << "handle node failed.";
325       return false;
326     }
327   }
328   if (main_graph) {
329     status = HandleGraphInput(func_graph);
330     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
331       MS_LOG(ERROR) << "handle graph input failed.";
332       return false;
333     }
334   }
335   return true;
336 }
337 
ConvWeightFormatTrans(const FuncGraphPtr & graph,std::set<AnfNodePtr> * has_visited)338 STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
339   MS_ASSERT(graph != nullptr && has_visited != nullptr);
340   manager_->AddFuncGraph(graph);
341   auto node_list = TopoSort(graph->get_return());
342   for (auto &node : node_list) {
343     if (!utils::isa<CNodePtr>(node)) {
344       continue;
345     }
346     auto cnode = node->cast<CNodePtr>();
347     if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
348       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
349       if (sub_func_graph == nullptr) {
350         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
351         return false;
352       }
353       if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
354         MS_LOG(ERROR) << "transform conv weight format failed.";
355         return lite::RET_ERROR;
356       }
357       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
358       if (sub_func_graph == nullptr) {
359         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
360         return false;
361       }
362       if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
363         MS_LOG(ERROR) << "transform conv weight format failed.";
364         return lite::RET_ERROR;
365       }
366       continue;
367     }
368     if (!IsWeightNodeSensitive(cnode)) {
369       continue;
370     }
371     if (has_visited->find(node) != has_visited->end()) {
372       continue;
373     }
374     has_visited->insert(node);
375     schema::Format src_format = schema::Format_NUM_OF_FORMAT;
376     schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
377     if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
378       MS_LOG(ERROR) << "weight's src format and dst format get failed.";
379       return lite::RET_ERROR;
380     }
381     auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, has_visited);
382     if (status != lite::RET_OK) {
383       MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
384       return status;
385     }
386   }
387   return lite::RET_OK;
388 }
389 
Run(const FuncGraphPtr & func_graph)390 bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
391   MS_CHECK_TRUE_RET(func_graph != nullptr, false);
392   if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
393     MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc.";
394     return false;
395   }
396   manager_ = Manage(func_graph, true);
397   if (manager_ == nullptr) {
398     MS_LOG(ERROR) << "manager is nullptr.";
399     return false;
400   }
401   node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
402   if (node_infer_shape_ == nullptr) {
403     MS_LOG(ERROR) << "create NodeInferShape object failed.";
404     return false;
405   }
406   std::set<AnfNodePtr> has_visited;
407   auto status = ConvWeightFormatTrans(func_graph, &has_visited);
408   if (status != lite::RET_OK) {
409     MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
410     return false;
411   }
412   SetSensitiveOps();
413   if (!BasicProcess(func_graph, true)) {
414     MS_LOG(ERROR) << "transfer format failed.";
415     return false;
416   }
417   return true;
418 }
419 }  // namespace opt
420 }  // namespace mindspore
421