• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/unify_format.h"
18 #include <map>
19 #include <memory>
20 #include <vector>
21 #include "tools/common/tensor_util.h"
22 #include "nnacl/op_base.h"
23 
24 namespace mindspore {
25 namespace lite {
26 namespace {
// Generic positional indices, used both as shape-vector slots and as cnode input slots.
constexpr int kNumIndex_0 = 0;
constexpr int kNumIndex_1 = 1;
constexpr int kNumIndex_2 = 2;
constexpr int kNumIndex_3 = 3;
// Value compared against dimension 3 of a graph input in DecideWhetherHandleGraphInput.
constexpr int kInputChannal = 3;
// Accepted lengths (dimension 0) of a constant Resize shape operand.
constexpr int kNumGatherIndiceSize_4 = 4;
constexpr int kNumGatherIndiceSize_2 = 2;
// Input slot of a Resize cnode that holds its shape operand.
constexpr int kNumResizeInputShape = 2;
// A Resize cnode needs strictly more inputs than this for its shape operand to be examined.
constexpr int kNumInputSize = 2;
DecideMINDIRConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)36 STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
37   MS_ASSERT(cnode != nullptr && src_format != nullptr);
38   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
39   if (prim == nullptr) {
40     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
41     return lite::RET_ERROR;
42   }
43   int64_t format =
44     prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
45   if (format == schema::Format_NHWC) {
46     *src_format = schema::Format_KHWC;
47   } else if (format == schema::Format_NCHW) {
48     *src_format = schema::Format_KCHW;
49   } else {
50     MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope();
51     return RET_ERROR;
52   }
53   return RET_OK;
54 }
55 
DecideTFConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)56 STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
57   MS_ASSERT(cnode != nullptr && src_format != nullptr);
58   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
59   if (prim == nullptr) {
60     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
61     return lite::RET_ERROR;
62   }
63   bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
64   if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
65     if (!is_depth_wise) {
66       *src_format = schema::Format_HWCK;
67     } else {
68       *src_format = schema::Format_HWKC;
69     }
70   } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
71     *src_format = schema::Format::Format_HWCK;
72   } else {
73     MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check. " << cnode->fullname_with_scope();
74     return RET_ERROR;
75   }
76   return RET_OK;
77 }
78 
DecideTFLITEConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)79 STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
80   MS_ASSERT(cnode != nullptr && src_format != nullptr);
81   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
82   if (prim == nullptr) {
83     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
84     return lite::RET_ERROR;
85   }
86   bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
87   if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
88     if (!is_depth_wise) {
89       *src_format = schema::Format_KHWC;
90     } else {
91       *src_format = schema::Format_CHWK;
92     }
93   } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
94     *src_format = schema::Format_CHWK;
95   } else {
96     MS_LOG(ERROR) << "cannot decide weight format, current situation need to check. " << cnode->fullname_with_scope();
97     return RET_NOT_SUPPORT;
98   }
99   return RET_OK;
100 }
101 
DecideCAFFEConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)102 STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
103   MS_ASSERT(cnode != nullptr && src_format != nullptr);
104   *src_format = schema::Format_KCHW;
105   return RET_OK;
106 }
107 
DecideONNXConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)108 STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
109   MS_ASSERT(cnode != nullptr && src_format != nullptr);
110   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
111   if (prim == nullptr) {
112     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
113     return lite::RET_ERROR;
114   }
115   int64_t format =
116     prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
117   if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
118       opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
119     if (format == schema::Format_NHWC) {
120       *src_format = schema::Format_KHWC;
121     } else if (format == schema::Format_NCHW) {
122       *src_format = schema::Format_KCHW;
123     } else {
124       MS_LOG(ERROR) << "format is invalid, format is " << format;
125       return RET_ERROR;
126     }
127   } else {
128     MS_LOG(ERROR) << "unknown op, please check.";
129     return RET_ERROR;
130   }
131   return RET_OK;
132 }
133 }  // namespace
134 
GetTransNodeFormatType(const CNodePtr & cnode,opt::TransTypePair * trans_info)135 STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
136   MS_ASSERT(cnode != nullptr && trans_info != nullptr);
137   auto prim_node = cnode->input(0);
138   auto prim = GetValueNode<PrimitivePtr>(prim_node);
139   if (prim == nullptr) {
140     return RET_OK;
141   }
142   auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
143   auto &specify_nchw_op_map = opt::GetNCHWOpMap();
144   if (fmk_type_ == converter::kFmkTypeTflite) {
145     if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
146       return lite::RET_OK;
147     }
148     trans_info->pre_ = opt::kNHWC2NCHW;
149     trans_info->post_ = opt::kNCHW2NHWC;
150   } else if (fmk_type_ == converter::kFmkTypeTf) {
151     if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() &&
152         prim->GetAttr(ops::kOriginalFormat) != nullptr &&
153         GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NCHW) {
154       trans_info->pre_ = opt::kNCHW2NHWC;
155       trans_info->post_ = opt::kNHWC2NCHW;
156     }
157     if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
158       trans_info->pre_ = opt::kNHWC2NCHW;
159       trans_info->post_ = opt::kNCHW2NHWC;
160     }
161   } else {
162     if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
163       if (fmk_type_ == converter::kFmkTypeOnnx && prim->GetAttr(ops::kOriginalFormat) != nullptr &&
164           GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NHWC) {
165         return lite::RET_OK;
166       }
167       trans_info->pre_ = opt::kNCHW2NHWC;
168       trans_info->post_ = opt::kNHWC2NCHW;
169     }
170   }
171   return lite::RET_OK;
172 }
173 
SetSensitiveOps()174 void UnifyFormatToNHWC::SetSensitiveOps() {
175   auto &sensitive_nhwc_ops = opt::GetNHWCOpMap();
176   auto &sensitive_nchw_ops = opt::GetNCHWOpMap();
177   sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end());
178   sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
179 }
180 
DecideWhetherHandleGraphInput(const FuncGraphPtr & func_graph,const ParameterPtr & input,const ShapeVector & shape)181 bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
182                                                       const ShapeVector &shape) {
183   MS_ASSERT(func_graph != nullptr);
184   MS_ASSERT(input != nullptr);
185   if (shape.size() != opt::kInputSizeFour) {
186     return false;
187   }
188   if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
189     return false;
190   }
191   if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::kFmkTypeOnnx &&
192       shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
193     return false;
194   }
195   return true;
196 }
197 
DecideWhetherInferShapeForNewNode()198 bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; }
199 
DecideConvWeightSrcAndDstFormat(const CNodePtr & cnode,schema::Format * src_format,schema::Format * dst_format)200 STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
201                                                           schema::Format *dst_format) {
202   MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
203   *dst_format = schema::Format_KHWC;
204   std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = {
205     {converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
206     {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
207     {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
208     {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
209     {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
210   auto iter = decide_functions.find(fmk_type_);
211   if (iter == decide_functions.end()) {
212     MS_LOG(ERROR) << "current fmk don't support, please check.";
213     return RET_NOT_SUPPORT;
214   }
215   auto decide_func = iter->second;
216   MS_ASSERT(decide_func != nullptr);
217   if (decide_func(cnode, src_format) != RET_OK) {
218     MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format.";
219     return RET_ERROR;
220   }
221   return RET_OK;
222 }
223 
ConvertOnnxResizeForConstShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode)224 STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
225   auto resize_shape_node = cnode->input(kNumResizeInputShape)->cast<ParameterPtr>();
226   auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(resize_shape_node->default_param());
227   if (shape_tensor == nullptr) {
228     MS_LOG(ERROR) << " shape tensor is nullptr.";
229     return RET_ERROR;
230   }
231   MS_CHECK_TRUE_MSG(shape_tensor->data_c() != nullptr, RET_ERROR, "shape_tensor->data_c() is nullptr.");
232   auto shape_data = static_cast<float *>(shape_tensor->data_c());
233   std::vector<float> new_shape;
234   MS_CHECK_TRUE_MSG(!shape_tensor->shape().empty(), RET_NULL_PTR, "out of range.");
235   if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_4) {
236     new_shape = {shape_data[kNumIndex_0], shape_data[kNumIndex_2], shape_data[kNumIndex_3], shape_data[kNumIndex_1]};
237   } else if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_2) {
238     return RET_OK;
239   } else {
240     return RET_ERROR;
241   }
242   auto new_shape_node = func_graph->add_parameter();
243   MS_CHECK_TRUE_MSG(new_shape_node != nullptr, RET_NULL_PTR, "new_shape_node is nullptr.");
244   auto tensor_info = CreateTensorInfo(nullptr, 0, shape_tensor->shape(), shape_tensor->data_type());
245   if (tensor_info == nullptr) {
246     MS_LOG(ERROR) << "create tensor info failed.";
247     return RET_ERROR;
248   }
249   auto new_shape_data = static_cast<float *>(tensor_info->data_c());
250   if (new_shape_data == nullptr) {
251     MS_LOG(ERROR) << "data is nullptr";
252     return RET_ERROR;
253   }
254   auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), tensor_info->Size());
255   if (status != RET_OK) {
256     MS_LOG(ERROR) << "init parameter from tensor info failed";
257     return RET_ERROR;
258   }
259   status = InitParameterFromTensorInfo(new_shape_node, tensor_info);
260   if (status != RET_OK) {
261     MS_LOG(ERROR) << "init parameter from tensor info failed";
262     return RET_ERROR;
263   }
264   manager_->SetEdge(cnode, kNumResizeInputShape, new_shape_node);
265   return RET_OK;
266 }
267 
ConvertOnnxResizeForVariableShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode)268 STATUS UnifyFormatToNHWC::ConvertOnnxResizeForVariableShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
269   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr.");
270   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr.");
271   auto gather_name = cnode->fullname_with_scope() + "_gather";
272   auto gather_input = cnode->input(kNumResizeInputShape);
273   MS_CHECK_TRUE_MSG(gather_input != nullptr, RET_ERROR, "gather_input is nullptr.");
274   auto abstract = cnode->input(kNumResizeInputShape)->abstract();
275   MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr.");
276   std::vector<int> gather_indices = {0, 2, 3, 1};  // NCHW to NHWC
277   auto gather_cnode = opt::GenGatherNode(func_graph, gather_input, gather_indices, gather_name);
278   if (gather_cnode == nullptr) {
279     MS_LOG(ERROR) << "create gather cnode failed.";
280     return RET_ERROR;
281   }
282   ShapeVector indices_shape = {kNumGatherIndiceSize_4};
283   auto gather_prim = GetValueNode<PrimitivePtr>(gather_cnode->input(0));
284   MS_CHECK_TRUE_MSG(gather_prim != nullptr, RET_NULL_PTR, "gather_prim is nullptr.");
285   auto value_ptr = MakeValue<int64_t>(NHWC);
286   MS_CHECK_TRUE_MSG(value_ptr != nullptr, RET_NULL_PTR, "value_ptr is nullptr.");
287   (void)gather_prim->AddAttr(ops::kFormat, value_ptr);
288   gather_cnode->set_abstract(abstract->Clone());
289   auto shape_ptr = std::make_shared<abstract::Shape>(indices_shape);
290   MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
291   abstract->set_shape(shape_ptr);
292   manager_->SetEdge(cnode, kNumIndex_2, gather_cnode);
293   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
294   (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
295   return RET_OK;
296 }
297 
ResizeNodeProcess(const FuncGraphPtr & func_graph,const CNodePtr & cnode)298 STATUS UnifyFormatToNHWC::ResizeNodeProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
299   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr.");
300   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr.");
301   if (fmk_type_ != converter::kFmkTypeOnnx) {
302     return RET_OK;
303   }
304   if (cnode->inputs().size() > kNumInputSize && utils::isa<ParameterPtr>(cnode->input(kNumResizeInputShape))) {
305     auto status = ConvertOnnxResizeForConstShape(func_graph, cnode);
306     if (status != RET_OK) {
307       MS_LOG(ERROR) << "ConvertOnnxResizeForConstShape failed.";
308       return RET_ERROR;
309     }
310   } else if (cnode->inputs().size() > kNumInputSize && utils::isa<CNodePtr>(cnode->input(kNumResizeInputShape))) {
311     auto status = ConvertOnnxResizeForVariableShape(func_graph, cnode);
312     if (status != RET_OK) {
313       MS_LOG(ERROR) << "ConvertResizeForVariableShape failed.";
314       return RET_ERROR;
315     }
316   }
317   return RET_OK;
318 }
319 
ProcessResizeAndFormat(const FuncGraphPtr & func_graph)320 bool UnifyFormatToNHWC::ProcessResizeAndFormat(const FuncGraphPtr &func_graph) {
321   MS_ASSERT(func_graph != nullptr);
322   manager_->AddFuncGraph(func_graph);
323   auto node_list = TopoSort(func_graph->get_return());
324   int status;
325   for (auto &node : node_list) {
326     if (!utils::isa<CNodePtr>(node)) {
327       continue;
328     }
329     auto cnode = node->cast<CNodePtr>();
330     if (opt::IsSpecialType(cnode)) {
331       continue;
332     }
333     auto value_node = cnode->input(0)->cast<ValueNodePtr>();
334     if (value_node == nullptr) {
335       if (cnode->input(0)->cast<CNodePtr>() != nullptr) {
336         continue;
337       }
338       MS_LOG(ERROR) << "cnode first input is invalid.";
339       return false;
340     }
341     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
342     if (prim == nullptr) {
343       continue;
344     }
345     if (prim->GetAttr(ops::kFormat) == nullptr && prim->GetAttr(ops::kOriginalFormat) != nullptr) {
346       prim->AddAttr(mindspore::ops::kFormat, prim->GetAttr(ops::kOriginalFormat));
347     }
348     if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
349       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_1));
350       if (sub_func_graph == nullptr) {
351         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
352         return false;
353       }
354       (void)ProcessResizeAndFormat(sub_func_graph);
355       sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_2));
356       if (sub_func_graph == nullptr) {
357         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
358         return false;
359       }
360       (void)ProcessResizeAndFormat(sub_func_graph);
361       continue;
362     }
363     if (opt::CheckPrimitiveType(node, prim::kPrimResize)) {
364       status = ResizeNodeProcess(func_graph, cnode);
365       if (status != lite::RET_OK) {
366         return false;
367       }
368     }
369   }
370   return true;
371 }
372 
Run(const FuncGraphPtr & func_graph)373 bool UnifyFormatToNHWC::Run(const FuncGraphPtr &func_graph) {
374   MS_ASSERT(func_graph != nullptr);
375   manager_ = Manage(func_graph, true);
376   if (manager_ == nullptr) {
377     MS_LOG(ERROR) << "manager is nullptr.";
378     return false;
379   }
380   if (!ProcessResizeAndFormat(func_graph)) {
381     MS_LOG(ERROR) << "ProcessResizeAndFormat failed.";
382     return false;
383   }
384   if (!opt::ToFormatBase::Run(func_graph)) {
385     MS_LOG(ERROR) << "run ToFormatBase failed.";
386     return false;
387   }
388   return true;
389 }
390 }  // namespace lite
391 }  // namespace mindspore
392