1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/format/to_format_base.h"
18 #include <set>
19 #include "ops/op_utils.h"
20 #include "src/common/common.h"
21 #include "src/common/utils.h"
22 #include "tools/common/tensor_util.h"
23 #include "tools/converter/parser/parser_utils.h"
24 #include "nnacl/op_base.h"
25
26 using mindspore::lite::NHWC_SHAPE;
27 namespace mindspore {
28 namespace opt {
GenNewInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)29 STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
30 bool before, size_t index) {
31 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
32 AnfNodePtr trans_input = before ? cnode->input(index) : cnode;
33 std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1)
34 : cnode->fullname_with_scope() + "_post";
35 auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name);
36 MS_ASSERT(trans_cnode != nullptr);
37 if (DecideWhetherInferShapeForNewNode()) {
38 auto status = node_infer_shape_->InferShape(trans_cnode);
39 if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
40 MS_LOG(ERROR) << "infer generated trans node failed.";
41 return lite::RET_ERROR;
42 }
43 } else {
44 auto abstract = trans_input->abstract();
45 if (abstract != nullptr) {
46 trans_cnode->set_abstract(abstract->Clone());
47 }
48 }
49 auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
50 MS_ASSERT(trans_prim != nullptr);
51 if (perm == kNC2NH) {
52 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
53 } else if (perm == kNH2NC) {
54 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
55 }
56 MS_ASSERT(manager_ != nullptr);
57 if (before) {
58 manager_->SetEdge(cnode, index, trans_cnode);
59 } else {
60 if (!manager_->Replace(cnode, trans_cnode)) {
61 MS_LOG(ERROR) << "replace old node failed, please check.";
62 return lite::RET_ERROR;
63 }
64 }
65 return lite::RET_OK;
66 }
67
ModifyCNode(const CNodePtr & cnode)68 STATUS ToFormatBase::ModifyCNode(const CNodePtr &cnode) {
69 MS_ASSERT(cnode != nullptr);
70 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
71 if (prim == nullptr) {
72 MS_LOG(ERROR) << "current node's prim is nullptr, " << cnode->fullname_with_scope();
73 return lite::RET_ERROR;
74 }
75 auto insert_pos = sensitive_ops_[prim->name()];
76 if (insert_pos.empty() || std::find(insert_pos.begin(), insert_pos.end(), 1) != insert_pos.end()) {
77 prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format_));
78 }
79 auto abstract_base = cnode->abstract();
80 MS_ASSERT(abstract_base != nullptr);
81 std::vector<AbstractBasePtr> abstracts;
82 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
83 auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
84 abstracts = abstract_tuple->elements();
85 } else {
86 abstracts.push_back(abstract_base);
87 }
88 for (auto &abstract : abstracts) {
89 ShapeVector shape;
90 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
91 MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope();
92 return lite::RET_ERROR;
93 }
94 if (shape.size() != kInputSizeFour) {
95 MS_LOG(DEBUG) << "shape don't need to modify.";
96 continue;
97 }
98 if (format_ == mindspore::NCHW) {
99 ShapeVector transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
100 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
101 } else {
102 ShapeVector transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
103 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
104 }
105 }
106 return lite::RET_OK;
107 }
108
InsertPreTransNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm)109 STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
110 const std::vector<int> &perm) {
111 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
112 auto prim_node = cnode->input(0);
113 auto prim = GetValueNode<PrimitivePtr>(prim_node);
114 MS_ASSERT(prim != nullptr);
115 if (sensitive_ops_.find(prim->name()) == sensitive_ops_.end()) {
116 MS_LOG(ERROR) << "op don't meet condition.";
117 return lite::RET_ERROR;
118 }
119 auto insert_index = sensitive_ops_.at(prim->name());
120 if (insert_index.empty()) {
121 if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
122 GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
123 insert_index.push_back(1);
124 } else {
125 for (size_t i = 1; i < cnode->size(); ++i) {
126 insert_index.push_back(i);
127 }
128 }
129 }
130 for (auto &index : insert_index) {
131 if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
132 MS_LOG(ERROR) << "generate a new input failed.";
133 return lite::RET_ERROR;
134 }
135 }
136 return lite::RET_OK;
137 }
138
// Insert a Transpose (with `perm`) after `cnode` so downstream consumers keep
// seeing data in the original layout. For tuple-output nodes the transpose is
// attached after each TupleGetItem user rather than after the node itself.
STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                         const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    // Single output: GenNewInput(before=false) creates the transpose and
    // replaces cnode with it for all users.
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    // Tuple output: walk a snapshot of the users, since the loop mutates edges.
    auto node_users = manager_->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          // Training graphs may consume the tuple directly: materialize a
          // temporary TupleGetItem(cnode, 0) so the transpose has a
          // single-tensor producer to attach to.
          tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
          post_node = tuple_get_item;
          manager_->Replace(cnode, tuple_get_item);
        }
      }
      if (manager_->node_users()[post_node].empty()) {
        // Dead output branch: nothing consumes it, so no transpose is needed.
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        // The helper TupleGetItem was only scaffolding; splice it back out by
        // reconnecting its users to its own input.
        if (!manager_->Replace(tuple_get_item, tuple_get_item->input(1))) {
          MS_LOG(ERROR) << "replace old node failed. please check.";
          return lite::RET_ERROR;
        }
      }
    }
  }
  return lite::RET_OK;
}
180
DecideWhetherHandleGraphInput(const FuncGraphPtr & func_graph,const ParameterPtr & input,const ShapeVector & shape)181 bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
182 const ShapeVector &shape) {
183 MS_ASSERT(func_graph != nullptr && input != nullptr);
184 if (shape.size() != kInputSizeFour) {
185 return false;
186 }
187 MS_ASSERT(manager_ != nullptr);
188 auto node_users = manager_->node_users()[input];
189 for (auto &node_user : node_users) {
190 auto post_node = node_user.first;
191 if (!utils::isa<CNode>(post_node)) {
192 continue;
193 }
194 auto post_cnode = post_node->cast<CNodePtr>();
195 auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
196 MS_ASSERT(prim != nullptr);
197 if (prim->GetAttr(ops::kFormat) != nullptr) {
198 auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
199 if (node_format == format_) {
200 MS_LOG(DEBUG) << "this graph input don't need to change.";
201 return false;
202 }
203 }
204 }
205 return true;
206 }
207
HandleGraphInput(const FuncGraphPtr & func_graph)208 STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
209 MS_ASSERT(func_graph != nullptr);
210 auto graph_input = func_graph->get_inputs();
211 for (auto &input : graph_input) {
212 auto input_param = input->cast<ParameterPtr>();
213 MS_ASSERT(input_param != nullptr);
214 auto abstract = input_param->abstract();
215 MS_ASSERT(abstract != nullptr);
216 ShapeVector shape;
217 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
218 MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
219 return lite::RET_ERROR;
220 }
221 if (!DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
222 continue;
223 }
224 ShapeVector transfer_shape;
225 if (format_ == mindspore::NCHW) {
226 transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
227 } else {
228 transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
229 }
230 CNodePtr trans_cnode;
231 if (format_ == mindspore::NCHW) {
232 trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
233 } else {
234 trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
235 }
236 if (trans_cnode == nullptr) {
237 MS_LOG(ERROR) << "create transpose cnode failed.";
238 return lite::RET_ERROR;
239 }
240 auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
241 MS_ASSERT(trans_prim != nullptr);
242 if (format_ == mindspore::NCHW) {
243 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
244 } else {
245 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
246 }
247 trans_cnode->set_abstract(abstract->Clone());
248 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
249 if (!manager_->Replace(input, trans_cnode)) {
250 MS_LOG(ERROR) << "replace old node failed, please check.";
251 return lite::RET_ERROR;
252 }
253 }
254 return lite::RET_OK;
255 }
256
HandleGraphNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)257 STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
258 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
259 opt::TransTypePair trans_info;
260 if (GetTransNodeFormatType(cnode, &trans_info) != lite::RET_OK) {
261 MS_LOG(ERROR) << "obtain node's transferring format type failed, " << cnode->fullname_with_scope();
262 return lite::RET_ERROR;
263 }
264 if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
265 return lite::RET_NO_CHANGE;
266 }
267 auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH;
268 auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC;
269 if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
270 MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
271 return lite::RET_ERROR;
272 }
273 if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
274 return lite::RET_OK;
275 }
276 if (ModifyCNode(cnode) != lite::RET_OK) {
277 MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope();
278 return lite::RET_ERROR;
279 }
280 if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
281 MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
282 return lite::RET_ERROR;
283 }
284 return lite::RET_OK;
285 }
286
BasicProcess(const FuncGraphPtr & func_graph,bool main_graph)287 bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
288 MS_ASSERT(func_graph != nullptr);
289 manager_->AddFuncGraph(func_graph);
290 auto node_list = TopoSort(func_graph->get_return());
291 int status;
292 for (auto &node : node_list) {
293 MS_CHECK_TRUE_RET(node != nullptr, false);
294 if (!utils::isa<CNodePtr>(node)) {
295 continue;
296 }
297 auto cnode = node->cast<CNodePtr>();
298 if (IsSpecialType(cnode)) {
299 continue;
300 }
301 if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
302 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
303 if (sub_func_graph == nullptr) {
304 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
305 return false;
306 }
307 if (!BasicProcess(sub_func_graph, false)) {
308 MS_LOG(ERROR) << "process sub graph failed.";
309 return false;
310 }
311 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
312 if (sub_func_graph == nullptr) {
313 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
314 return false;
315 }
316 if (!BasicProcess(sub_func_graph, false)) {
317 MS_LOG(ERROR) << "process sub graph failed.";
318 return false;
319 }
320 continue;
321 }
322 status = HandleGraphNode(func_graph, cnode);
323 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
324 MS_LOG(ERROR) << "handle node failed.";
325 return false;
326 }
327 }
328 if (main_graph) {
329 status = HandleGraphInput(func_graph);
330 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
331 MS_LOG(ERROR) << "handle graph input failed.";
332 return false;
333 }
334 }
335 return true;
336 }
337
ConvWeightFormatTrans(const FuncGraphPtr & graph,std::set<AnfNodePtr> * has_visited)338 STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
339 MS_ASSERT(graph != nullptr && has_visited != nullptr);
340 manager_->AddFuncGraph(graph);
341 auto node_list = TopoSort(graph->get_return());
342 for (auto &node : node_list) {
343 if (!utils::isa<CNodePtr>(node)) {
344 continue;
345 }
346 auto cnode = node->cast<CNodePtr>();
347 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
348 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
349 if (sub_func_graph == nullptr) {
350 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
351 return false;
352 }
353 if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
354 MS_LOG(ERROR) << "transform conv weight format failed.";
355 return lite::RET_ERROR;
356 }
357 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
358 if (sub_func_graph == nullptr) {
359 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
360 return false;
361 }
362 if (ConvWeightFormatTrans(sub_func_graph, has_visited) != lite::RET_OK) {
363 MS_LOG(ERROR) << "transform conv weight format failed.";
364 return lite::RET_ERROR;
365 }
366 continue;
367 }
368 if (!IsWeightNodeSensitive(cnode)) {
369 continue;
370 }
371 if (has_visited->find(node) != has_visited->end()) {
372 continue;
373 }
374 has_visited->insert(node);
375 schema::Format src_format = schema::Format_NUM_OF_FORMAT;
376 schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
377 if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
378 MS_LOG(ERROR) << "weight's src format and dst format get failed.";
379 return lite::RET_ERROR;
380 }
381 auto status = lite::UnifyConvWeightFormat(graph, cnode, src_format, dst_format, has_visited);
382 if (status != lite::RET_OK) {
383 MS_LOG(ERROR) << "unify conv weight failed, current node name is " << cnode->fullname_with_scope();
384 return status;
385 }
386 }
387 return lite::RET_OK;
388 }
389
Run(const FuncGraphPtr & func_graph)390 bool ToFormatBase::Run(const FuncGraphPtr &func_graph) {
391 MS_CHECK_TRUE_RET(func_graph != nullptr, false);
392 if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) {
393 MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc.";
394 return false;
395 }
396 manager_ = Manage(func_graph, true);
397 if (manager_ == nullptr) {
398 MS_LOG(ERROR) << "manager is nullptr.";
399 return false;
400 }
401 node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
402 if (node_infer_shape_ == nullptr) {
403 MS_LOG(ERROR) << "create NodeInferShape object failed.";
404 return false;
405 }
406 std::set<AnfNodePtr> has_visited;
407 auto status = ConvWeightFormatTrans(func_graph, &has_visited);
408 if (status != lite::RET_OK) {
409 MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
410 return false;
411 }
412 SetSensitiveOps();
413 if (!BasicProcess(func_graph, true)) {
414 MS_LOG(ERROR) << "transfer format failed.";
415 return false;
416 }
417 return true;
418 }
419 } // namespace opt
420 } // namespace mindspore
421