/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/converter/parser/unify_format.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "tools/common/tensor_util.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace lite {
namespace {
constexpr int kInputChannel = 3;
constexpr int kNumGatherIndiceSize_4 = 4;
constexpr int kNumGatherIndiceSize_2 = 2;
constexpr int kNumResizeInputShape = 2;
constexpr int kNumResizeInputShape_3 = 3;
constexpr int kNumResizeInputShape_5 = 5;
constexpr int kNumInputSize = 2;
constexpr int kNumIndex_0 = 0;
constexpr int kNumIndex_1 = 1;
constexpr int kNumIndex_2 = 2;
constexpr int kNumIndex_3 = 3;
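// Convolution weights exported from MindIR follow the node's original format attribute:
// NHWC maps to KHWC weights and NCHW maps to KCHW weights.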
STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
    return lite::RET_ERROR;
  }
  int64_t format =
    prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
  if (format == schema::Format_NHWC) {
    *src_format = schema::Format_KHWC;
  } else if (format == schema::Format_NCHW) {
    *src_format = schema::Format_KCHW;
  } else {
    MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  return RET_OK;
}

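// TF stores non-depthwise Conv2D weights as HWCK and depthwise weights as HWKC;
// Conv2DTranspose weights are HWCK. Depthwise transposed convolution is rejected.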
STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
    return lite::RET_ERROR;
  }
  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
    if (!is_depth_wise) {
      *src_format = schema::Format_HWCK;
    } else {
      *src_format = schema::Format_HWKC;
    }
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    *src_format = schema::Format_HWCK;
  } else {
    MS_LOG(ERROR) << "depthwise conv2dTranspose needs to be checked. " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  return RET_OK;
}

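// TFLite stores non-depthwise Conv2D weights as KHWC and depthwise weights as CHWK;
// Conv2DTranspose weights are CHWK. Other combinations are not supported.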
STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
    return lite::RET_ERROR;
  }
  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
    if (!is_depth_wise) {
      *src_format = schema::Format_KHWC;
    } else {
      *src_format = schema::Format_CHWK;
    }
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    *src_format = schema::Format_CHWK;
  } else {
    MS_LOG(ERROR) << "cannot decide weight format, current situation needs to be checked. "
                  << cnode->fullname_with_scope();
    return RET_NOT_SUPPORT;
  }
  return RET_OK;
}

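// Caffe always stores convolution weights as KCHW.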
STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr);
  *src_format = schema::Format_KCHW;
  return RET_OK;
}

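// ONNX (and PyTorch, which reuses this rule) convolution weights follow the node's original format:
// NHWC graphs use KHWC weights and NCHW graphs use KCHW weights.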
STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
    return lite::RET_ERROR;
  }
  int64_t format =
    prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
  if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
      opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
    if (format == schema::Format_NHWC) {
      *src_format = schema::Format_KHWC;
    } else if (format == schema::Format_NCHW) {
      *src_format = schema::Format_KCHW;
    } else {
      MS_LOG(ERROR) << "format is invalid, format is " << format;
      return RET_ERROR;
    }
  } else {
    MS_LOG(ERROR) << "unknown op, please check.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace

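// Decide which transpose pair must wrap a format-sensitive node when unifying the graph to NHWC:
// TFLite only adjusts NCHW-specified ops, TF consults both op maps together with the original
// format, and other frameworks wrap NHWC-specified ops unless ONNX already marks them as NHWC.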
STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
  MS_ASSERT(cnode != nullptr && trans_info != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  if (prim == nullptr) {
    return RET_OK;
  }
  auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
  auto &specify_nchw_op_map = opt::GetNCHWOpMap();
  if (fmk_type_ == converter::kFmkTypeTflite) {
    if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
      return lite::RET_OK;
    }
    trans_info->pre_ = opt::kNHWC2NCHW;
    trans_info->post_ = opt::kNCHW2NHWC;
  } else if (fmk_type_ == converter::kFmkTypeTf) {
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() &&
        prim->GetAttr(ops::kOriginalFormat) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NCHW) {
      trans_info->pre_ = opt::kNCHW2NHWC;
      trans_info->post_ = opt::kNHWC2NCHW;
    }
    if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
      trans_info->pre_ = opt::kNHWC2NCHW;
      trans_info->post_ = opt::kNCHW2NHWC;
    }
  } else {
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
      if (fmk_type_ == converter::kFmkTypeOnnx && prim->GetAttr(ops::kOriginalFormat) != nullptr &&
          GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NHWC) {
        return lite::RET_OK;
      }
      trans_info->pre_ = opt::kNCHW2NHWC;
      trans_info->post_ = opt::kNHWC2NCHW;
    }
  }
  return lite::RET_OK;
}

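// Register every op from both the NHWC and NCHW maps as format-sensitive for this pass.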
void UnifyFormatToNHWC::SetSensitiveOps() {
  auto &sensitive_nhwc_ops = opt::GetNHWCOpMap();
  auto &sensitive_nchw_ops = opt::GetNCHWOpMap();
  sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end());
  sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
}

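// Only 4-D graph inputs need a transpose, and TF/TFLite inputs are already NHWC. A single ONNX
// input whose last dimension equals the channel count 3 is treated as already NHWC and left alone.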
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
                                                      const ShapeVector &shape) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input != nullptr);
  if (shape.size() != opt::kInputSizeFour) {
    return false;
  }
  if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
    return false;
  }
  if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::kFmkTypeOnnx &&
      shape[opt::kInputIndexThree] == kInputChannel) {
    return false;
  }
  return true;
}

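// Nodes created by this pass do not need their shapes inferred again.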
bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; }

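// The destination weight format is always KHWC; the source format is chosen by a per-framework
// decide function registered below.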
STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
                                                          schema::Format *dst_format) {
  MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
  *dst_format = schema::Format_KHWC;
  std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = {
    {converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
    {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
    {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
    {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
    {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat},
    {converter::kFmkTypePytorch, DecideONNXConvWeightSrcFormat}};
  auto iter = decide_functions.find(fmk_type_);
  if (iter == decide_functions.end()) {
    MS_LOG(ERROR) << "current fmk is not supported, please check.";
    return RET_NOT_SUPPORT;
  }
  auto decide_func = iter->second;
  MS_ASSERT(decide_func != nullptr);
  if (decide_func(cnode, src_format) != RET_OK) {
    MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format.";
    return RET_ERROR;
  }
  return RET_OK;
}

namespace {
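// Copy the H and W entries (indices 2 and 3) of an NCHW-ordered shape tensor into tensor_info.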
template <typename T>
int SetTensorInfoData(const tensor::Tensor *shape_tensor, tensor::TensorPtr tensor_info) {
  MS_CHECK_TRUE_MSG(shape_tensor->data_c() != nullptr, RET_ERROR, "shape_tensor->data_c() is nullptr.");
  T *shape_data = static_cast<T *>(shape_tensor->data_c());
  std::vector<T> new_shape = {shape_data[kNumIndex_2], shape_data[kNumIndex_3]};
  auto status = memcpy_s(static_cast<T *>(tensor_info->data_c()), tensor_info->Size(), new_shape.data(),
                         sizeof(T) * new_shape.size());
  MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "memcpy_s failed.");
  return RET_OK;
}

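// Dispatch SetTensorInfoData on the shape tensor's element type.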
int SetShapeTensorInfo(const tensor::Tensor *shape_tensor, tensor::TensorPtr tensor_info) {
  int ret = RET_ERROR;
  switch (shape_tensor->data_type()) {
    case kNumberTypeFloat:
    case kNumberTypeFloat32: {
      ret = SetTensorInfoData<float>(shape_tensor, tensor_info);
    } break;
    case kNumberTypeInt:
    case kNumberTypeInt32: {
      ret = SetTensorInfoData<int32_t>(shape_tensor, tensor_info);
    } break;
    case kNumberTypeInt64: {
      ret = SetTensorInfoData<int64_t>(shape_tensor, tensor_info);
    } break;
    default:
      MS_LOG(ERROR) << "unsupported shape tensor data type: " << shape_tensor->data_type();
  }
  return ret;
}
}  // namespace

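// For an ONNX Resize whose size input is constant, replace the 4-element NCHW constant with a new
// 2-element parameter holding only H and W; 2-, 3- and 5-element inputs are left for the axis pass.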
STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto resize_shape_node = cnode->input(kNumResizeInputShape)->cast<ParameterPtr>();
  MS_CHECK_TRUE_MSG(resize_shape_node != nullptr, RET_ERROR, "resize shape node is not a parameter.");
  auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(resize_shape_node->default_param());
  if (shape_tensor == nullptr) {
    MS_LOG(ERROR) << "shape tensor is nullptr.";
    return RET_ERROR;
  }
  MS_CHECK_TRUE_MSG(shape_tensor->data_c() != nullptr, RET_ERROR, "shape_tensor->data_c() is nullptr.");
  MS_CHECK_TRUE_MSG(!shape_tensor->shape().empty(), RET_NULL_PTR, "shape_tensor's shape is empty.");
  if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_2 || shape_tensor->shape().at(0) == kNumResizeInputShape_3 ||
      shape_tensor->shape().at(0) == kNumResizeInputShape_5) {
    MS_LOG(INFO) << "The pass for the axis transform will run.";
    return RET_OK;
  }

  if (shape_tensor->shape().at(0) != kNumGatherIndiceSize_4) {
    MS_LOG(ERROR) << "The dimension of resize must be 2, 3, 4 or 5, but got " << shape_tensor->shape().at(0) << "!";
    return RET_ERROR;
  }

  auto new_shape_node = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(new_shape_node != nullptr, RET_NULL_PTR, "new_shape_node is nullptr.");
  auto tensor_info = CreateTensorInfo(nullptr, 0, std::vector<int64_t>{kNumInputSize}, shape_tensor->data_type());
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return RET_ERROR;
  }
  auto status = SetShapeTensorInfo(shape_tensor.get(), tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "set shape for tensor info failed";
    return RET_ERROR;
  }
  status = InitParameterFromTensorInfo(new_shape_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  new_shape_node->set_name(cnode->fullname_with_scope() + "_" + resize_shape_node->fullname_with_scope());
  manager_->SetEdge(cnode, kNumResizeInputShape, new_shape_node);
  return RET_OK;
}

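// For an ONNX Resize whose size input is computed at runtime, insert a Gather that picks the H and W
// entries (indices 2 and 3) out of the NCHW-ordered shape input and mark the nodes as NHWC.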
STATUS UnifyFormatToNHWC::ConvertOnnxResizeForVariableShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto gather_name = cnode->fullname_with_scope() + "_gather";
  auto gather_input = cnode->input(kNumResizeInputShape);
  MS_CHECK_TRUE_MSG(gather_input != nullptr, RET_ERROR, "gather_input is nullptr.");
  auto abstract = cnode->input(kNumResizeInputShape)->abstract();
  MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr.");
  std::vector<int> gather_indices = {kNumIndex_2, kNumIndex_3};  // fetch the H and W dimensions
  auto gather_cnode = opt::GenGatherNode(func_graph, gather_input, gather_indices, gather_name);
  if (gather_cnode == nullptr) {
    MS_LOG(ERROR) << "create gather cnode failed.";
    return RET_ERROR;
  }
  ShapeVector indices_shape = {kNumGatherIndiceSize_4};
  auto gather_prim = GetValueNode<PrimitivePtr>(gather_cnode->input(0));
  MS_CHECK_TRUE_MSG(gather_prim != nullptr, RET_NULL_PTR, "gather_prim is nullptr.");
  auto value_ptr = MakeValue<int64_t>(NHWC);
  MS_CHECK_TRUE_MSG(value_ptr != nullptr, RET_NULL_PTR, "value_ptr is nullptr.");
  (void)gather_prim->AddAttr(ops::kFormat, value_ptr);
  gather_cnode->set_abstract(abstract->Clone());
  auto shape_ptr = std::make_shared<abstract::Shape>(indices_shape);
  MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
  abstract->set_shape(shape_ptr);
  manager_->SetEdge(cnode, kNumResizeInputShape, gather_cnode);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "prim is nullptr.");
  (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
  return RET_OK;
}

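// Rewrite the sizes input of ONNX Resize nodes from NCHW to NHWC ordering, except for bicubic
// resizes, whose sizes must stay NCHW for ge::Resize.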
STATUS UnifyFormatToNHWC::ResizeNodeProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (fmk_type_ != converter::kFmkTypeOnnx || cnode->size() <= kNumInputSize) {
    return RET_OK;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    return RET_ERROR;
  }
  auto val_ptr = prim->GetAttr(ops::kMode);
  // ge::Resize requires the sizes input to stay in NCHW order for bicubic mode.
  if (val_ptr != nullptr && GetValue<std::string>(val_ptr) == "bicubic") {
    return RET_OK;
  }

  auto size_node = cnode->input(kNumInputSize);
  MS_CHECK_TRUE_RET(size_node != nullptr, RET_ERROR);
  if (utils::isa<ParameterPtr>(size_node) && size_node->cast<ParameterPtr>() != nullptr &&
      size_node->cast<ParameterPtr>()->has_default()) {
    auto status = ConvertOnnxResizeForConstShape(func_graph, cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "ConvertOnnxResizeForConstShape failed.";
      return RET_ERROR;
    }
  } else {
    auto status = ConvertOnnxResizeForVariableShape(func_graph, cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "ConvertOnnxResizeForVariableShape failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

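// Walk the graph (including If/While subgraphs) and preprocess every Resize node before the common
// NHWC conversion runs.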
bool UnifyFormatToNHWC::ProcessResizeAndFormat(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  manager_->AddFuncGraph(func_graph);
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (opt::IsSpecialType(cnode)) {
      continue;
    }
    auto anf_node = cnode->input(0);
    MS_CHECK_TRUE_MSG(anf_node != nullptr, false, "anf_node is nullptr.");
    auto value_node = anf_node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      if (anf_node->cast<CNodePtr>() != nullptr) {
        continue;
      }
      MS_LOG(ERROR) << "cnode first input is invalid.";
      return false;
    }
    auto prim = GetValueNode<PrimitivePtr>(anf_node);
    if (prim == nullptr) {
      continue;
    }
    if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      (void)ProcessResizeAndFormat(sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      (void)ProcessResizeAndFormat(sub_func_graph);
      continue;
    }
    if (opt::CheckPrimitiveType(node, prim::kPrimResize)) {
      auto status = ResizeNodeProcess(func_graph, cnode);
      if (status != lite::RET_OK) {
        return false;
      }
    }
  }
  return true;
}

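// Entry point: preprocess Resize nodes, then run the common ToFormatBase pass to unify the graph to NHWC.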
bool UnifyFormatToNHWC::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  manager_ = Manage(func_graph, true);
  if (manager_ == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  if (!ProcessResizeAndFormat(func_graph)) {
    MS_LOG(ERROR) << "ProcessResizeAndFormat failed.";
    return false;
  }
  if (!opt::ToFormatBase::Run(func_graph)) {
    MS_LOG(ERROR) << "run ToFormatBase failed.";
    return false;
  }
  return true;
}
}  // namespace lite
}  // namespace mindspore