1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/unify_format.h"
18 #include <map>
19 #include <memory>
20 #include <vector>
21 #include "tools/common/tensor_util.h"
22 #include "nnacl/op_base.h"
23
24 namespace mindspore {
25 namespace lite {
26 namespace {
// Compared against dimension kInputIndexThree of a single ONNX graph input when deciding
// whether that input should be converted (spelling kept because other code refers to it).
constexpr int kInputChannal{3};
// Lengths of a Resize shape tensor: a 4-element tensor is permuted, a 2-element one is kept.
constexpr int kNumGatherIndiceSize_4{4};
constexpr int kNumGatherIndiceSize_2{2};
// Input slot of a Resize cnode that holds the shape tensor (slot 0 holds the primitive).
constexpr int kNumResizeInputShape{2};
// A Resize cnode needs more than this many inputs before its shape slot is read.
constexpr int kNumInputSize{2};
// Plain positional indices used for cnode inputs and for the NCHW -> NHWC element permutation.
constexpr int kNumIndex_0{0};
constexpr int kNumIndex_1{1};
constexpr int kNumIndex_2{2};
constexpr int kNumIndex_3{3};
DecideMINDIRConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)36 STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
37 MS_ASSERT(cnode != nullptr && src_format != nullptr);
38 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
39 if (prim == nullptr) {
40 MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
41 return lite::RET_ERROR;
42 }
43 int64_t format =
44 prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
45 if (format == schema::Format_NHWC) {
46 *src_format = schema::Format_KHWC;
47 } else if (format == schema::Format_NCHW) {
48 *src_format = schema::Format_KCHW;
49 } else {
50 MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope();
51 return RET_ERROR;
52 }
53 return RET_OK;
54 }
55
DecideTFConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)56 STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
57 MS_ASSERT(cnode != nullptr && src_format != nullptr);
58 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
59 if (prim == nullptr) {
60 MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
61 return lite::RET_ERROR;
62 }
63 bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
64 if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
65 if (!is_depth_wise) {
66 *src_format = schema::Format_HWCK;
67 } else {
68 *src_format = schema::Format_HWKC;
69 }
70 } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
71 *src_format = schema::Format::Format_HWCK;
72 } else {
73 MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check. " << cnode->fullname_with_scope();
74 return RET_ERROR;
75 }
76 return RET_OK;
77 }
78
DecideTFLITEConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)79 STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
80 MS_ASSERT(cnode != nullptr && src_format != nullptr);
81 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
82 if (prim == nullptr) {
83 MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
84 return lite::RET_ERROR;
85 }
86 bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
87 if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
88 if (!is_depth_wise) {
89 *src_format = schema::Format_KHWC;
90 } else {
91 *src_format = schema::Format_CHWK;
92 }
93 } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
94 *src_format = schema::Format_CHWK;
95 } else {
96 MS_LOG(ERROR) << "cannot decide weight format, current situation need to check. " << cnode->fullname_with_scope();
97 return RET_NOT_SUPPORT;
98 }
99 return RET_OK;
100 }
101
DecideCAFFEConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)102 STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
103 MS_ASSERT(cnode != nullptr && src_format != nullptr);
104 *src_format = schema::Format_KCHW;
105 return RET_OK;
106 }
107
DecideONNXConvWeightSrcFormat(const CNodePtr & cnode,schema::Format * src_format)108 STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
109 MS_ASSERT(cnode != nullptr && src_format != nullptr);
110 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
111 if (prim == nullptr) {
112 MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
113 return lite::RET_ERROR;
114 }
115 int64_t format =
116 prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
117 if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
118 opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
119 if (format == schema::Format_NHWC) {
120 *src_format = schema::Format_KHWC;
121 } else if (format == schema::Format_NCHW) {
122 *src_format = schema::Format_KCHW;
123 } else {
124 MS_LOG(ERROR) << "format is invalid, format is " << format;
125 return RET_ERROR;
126 }
127 } else {
128 MS_LOG(ERROR) << "unknown op, please check.";
129 return RET_ERROR;
130 }
131 return RET_OK;
132 }
133 } // namespace
134
GetTransNodeFormatType(const CNodePtr & cnode,opt::TransTypePair * trans_info)135 STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
136 MS_ASSERT(cnode != nullptr && trans_info != nullptr);
137 auto prim_node = cnode->input(0);
138 auto prim = GetValueNode<PrimitivePtr>(prim_node);
139 if (prim == nullptr) {
140 return RET_OK;
141 }
142 auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
143 auto &specify_nchw_op_map = opt::GetNCHWOpMap();
144 if (fmk_type_ == converter::kFmkTypeTflite) {
145 if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
146 return lite::RET_OK;
147 }
148 trans_info->pre_ = opt::kNHWC2NCHW;
149 trans_info->post_ = opt::kNCHW2NHWC;
150 } else if (fmk_type_ == converter::kFmkTypeTf) {
151 if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() &&
152 prim->GetAttr(ops::kOriginalFormat) != nullptr &&
153 GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NCHW) {
154 trans_info->pre_ = opt::kNCHW2NHWC;
155 trans_info->post_ = opt::kNHWC2NCHW;
156 }
157 if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
158 trans_info->pre_ = opt::kNHWC2NCHW;
159 trans_info->post_ = opt::kNCHW2NHWC;
160 }
161 } else {
162 if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
163 if (fmk_type_ == converter::kFmkTypeOnnx && prim->GetAttr(ops::kOriginalFormat) != nullptr &&
164 GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) == NHWC) {
165 return lite::RET_OK;
166 }
167 trans_info->pre_ = opt::kNCHW2NHWC;
168 trans_info->post_ = opt::kNHWC2NCHW;
169 }
170 }
171 return lite::RET_OK;
172 }
173
SetSensitiveOps()174 void UnifyFormatToNHWC::SetSensitiveOps() {
175 auto &sensitive_nhwc_ops = opt::GetNHWCOpMap();
176 auto &sensitive_nchw_ops = opt::GetNCHWOpMap();
177 sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end());
178 sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
179 }
180
DecideWhetherHandleGraphInput(const FuncGraphPtr & func_graph,const ParameterPtr & input,const ShapeVector & shape)181 bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
182 const ShapeVector &shape) {
183 MS_ASSERT(func_graph != nullptr);
184 MS_ASSERT(input != nullptr);
185 if (shape.size() != opt::kInputSizeFour) {
186 return false;
187 }
188 if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
189 return false;
190 }
191 if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::kFmkTypeOnnx &&
192 shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
193 return false;
194 }
195 return true;
196 }
197
DecideWhetherInferShapeForNewNode()198 bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; }
199
DecideConvWeightSrcAndDstFormat(const CNodePtr & cnode,schema::Format * src_format,schema::Format * dst_format)200 STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
201 schema::Format *dst_format) {
202 MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
203 *dst_format = schema::Format_KHWC;
204 std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = {
205 {converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
206 {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
207 {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
208 {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
209 {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
210 auto iter = decide_functions.find(fmk_type_);
211 if (iter == decide_functions.end()) {
212 MS_LOG(ERROR) << "current fmk don't support, please check.";
213 return RET_NOT_SUPPORT;
214 }
215 auto decide_func = iter->second;
216 MS_ASSERT(decide_func != nullptr);
217 if (decide_func(cnode, src_format) != RET_OK) {
218 MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format.";
219 return RET_ERROR;
220 }
221 return RET_OK;
222 }
223
ConvertOnnxResizeForConstShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode)224 STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
225 auto resize_shape_node = cnode->input(kNumResizeInputShape)->cast<ParameterPtr>();
226 auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(resize_shape_node->default_param());
227 if (shape_tensor == nullptr) {
228 MS_LOG(ERROR) << " shape tensor is nullptr.";
229 return RET_ERROR;
230 }
231 MS_CHECK_TRUE_MSG(shape_tensor->data_c() != nullptr, RET_ERROR, "shape_tensor->data_c() is nullptr.");
232 auto shape_data = static_cast<float *>(shape_tensor->data_c());
233 std::vector<float> new_shape;
234 MS_CHECK_TRUE_MSG(!shape_tensor->shape().empty(), RET_NULL_PTR, "out of range.");
235 if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_4) {
236 new_shape = {shape_data[kNumIndex_0], shape_data[kNumIndex_2], shape_data[kNumIndex_3], shape_data[kNumIndex_1]};
237 } else if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_2) {
238 return RET_OK;
239 } else {
240 return RET_ERROR;
241 }
242 auto new_shape_node = func_graph->add_parameter();
243 MS_CHECK_TRUE_MSG(new_shape_node != nullptr, RET_NULL_PTR, "new_shape_node is nullptr.");
244 auto tensor_info = CreateTensorInfo(nullptr, 0, shape_tensor->shape(), shape_tensor->data_type());
245 if (tensor_info == nullptr) {
246 MS_LOG(ERROR) << "create tensor info failed.";
247 return RET_ERROR;
248 }
249 auto new_shape_data = static_cast<float *>(tensor_info->data_c());
250 if (new_shape_data == nullptr) {
251 MS_LOG(ERROR) << "data is nullptr";
252 return RET_ERROR;
253 }
254 auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), tensor_info->Size());
255 if (status != RET_OK) {
256 MS_LOG(ERROR) << "init parameter from tensor info failed";
257 return RET_ERROR;
258 }
259 status = InitParameterFromTensorInfo(new_shape_node, tensor_info);
260 if (status != RET_OK) {
261 MS_LOG(ERROR) << "init parameter from tensor info failed";
262 return RET_ERROR;
263 }
264 manager_->SetEdge(cnode, kNumResizeInputShape, new_shape_node);
265 return RET_OK;
266 }
267
ConvertOnnxResizeForVariableShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode)268 STATUS UnifyFormatToNHWC::ConvertOnnxResizeForVariableShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
269 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr.");
270 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr.");
271 auto gather_name = cnode->fullname_with_scope() + "_gather";
272 auto gather_input = cnode->input(kNumResizeInputShape);
273 MS_CHECK_TRUE_MSG(gather_input != nullptr, RET_ERROR, "gather_input is nullptr.");
274 auto abstract = cnode->input(kNumResizeInputShape)->abstract();
275 MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr.");
276 std::vector<int> gather_indices = {0, 2, 3, 1}; // NCHW to NHWC
277 auto gather_cnode = opt::GenGatherNode(func_graph, gather_input, gather_indices, gather_name);
278 if (gather_cnode == nullptr) {
279 MS_LOG(ERROR) << "create gather cnode failed.";
280 return RET_ERROR;
281 }
282 ShapeVector indices_shape = {kNumGatherIndiceSize_4};
283 auto gather_prim = GetValueNode<PrimitivePtr>(gather_cnode->input(0));
284 MS_CHECK_TRUE_MSG(gather_prim != nullptr, RET_NULL_PTR, "gather_prim is nullptr.");
285 auto value_ptr = MakeValue<int64_t>(NHWC);
286 MS_CHECK_TRUE_MSG(value_ptr != nullptr, RET_NULL_PTR, "value_ptr is nullptr.");
287 (void)gather_prim->AddAttr(ops::kFormat, value_ptr);
288 gather_cnode->set_abstract(abstract->Clone());
289 auto shape_ptr = std::make_shared<abstract::Shape>(indices_shape);
290 MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
291 abstract->set_shape(shape_ptr);
292 manager_->SetEdge(cnode, kNumIndex_2, gather_cnode);
293 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
294 (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
295 return RET_OK;
296 }
297
ResizeNodeProcess(const FuncGraphPtr & func_graph,const CNodePtr & cnode)298 STATUS UnifyFormatToNHWC::ResizeNodeProcess(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
299 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr.");
300 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr.");
301 if (fmk_type_ != converter::kFmkTypeOnnx) {
302 return RET_OK;
303 }
304 if (cnode->inputs().size() > kNumInputSize && utils::isa<ParameterPtr>(cnode->input(kNumResizeInputShape))) {
305 auto status = ConvertOnnxResizeForConstShape(func_graph, cnode);
306 if (status != RET_OK) {
307 MS_LOG(ERROR) << "ConvertOnnxResizeForConstShape failed.";
308 return RET_ERROR;
309 }
310 } else if (cnode->inputs().size() > kNumInputSize && utils::isa<CNodePtr>(cnode->input(kNumResizeInputShape))) {
311 auto status = ConvertOnnxResizeForVariableShape(func_graph, cnode);
312 if (status != RET_OK) {
313 MS_LOG(ERROR) << "ConvertResizeForVariableShape failed.";
314 return RET_ERROR;
315 }
316 }
317 return RET_OK;
318 }
319
ProcessResizeAndFormat(const FuncGraphPtr & func_graph)320 bool UnifyFormatToNHWC::ProcessResizeAndFormat(const FuncGraphPtr &func_graph) {
321 MS_ASSERT(func_graph != nullptr);
322 manager_->AddFuncGraph(func_graph);
323 auto node_list = TopoSort(func_graph->get_return());
324 int status;
325 for (auto &node : node_list) {
326 if (!utils::isa<CNodePtr>(node)) {
327 continue;
328 }
329 auto cnode = node->cast<CNodePtr>();
330 if (opt::IsSpecialType(cnode)) {
331 continue;
332 }
333 auto value_node = cnode->input(0)->cast<ValueNodePtr>();
334 if (value_node == nullptr) {
335 if (cnode->input(0)->cast<CNodePtr>() != nullptr) {
336 continue;
337 }
338 MS_LOG(ERROR) << "cnode first input is invalid.";
339 return false;
340 }
341 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
342 if (prim == nullptr) {
343 continue;
344 }
345 if (prim->GetAttr(ops::kFormat) == nullptr && prim->GetAttr(ops::kOriginalFormat) != nullptr) {
346 prim->AddAttr(mindspore::ops::kFormat, prim->GetAttr(ops::kOriginalFormat));
347 }
348 if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
349 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_1));
350 if (sub_func_graph == nullptr) {
351 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
352 return false;
353 }
354 (void)ProcessResizeAndFormat(sub_func_graph);
355 sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kNumIndex_2));
356 if (sub_func_graph == nullptr) {
357 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
358 return false;
359 }
360 (void)ProcessResizeAndFormat(sub_func_graph);
361 continue;
362 }
363 if (opt::CheckPrimitiveType(node, prim::kPrimResize)) {
364 status = ResizeNodeProcess(func_graph, cnode);
365 if (status != lite::RET_OK) {
366 return false;
367 }
368 }
369 }
370 return true;
371 }
372
Run(const FuncGraphPtr & func_graph)373 bool UnifyFormatToNHWC::Run(const FuncGraphPtr &func_graph) {
374 MS_ASSERT(func_graph != nullptr);
375 manager_ = Manage(func_graph, true);
376 if (manager_ == nullptr) {
377 MS_LOG(ERROR) << "manager is nullptr.";
378 return false;
379 }
380 if (!ProcessResizeAndFormat(func_graph)) {
381 MS_LOG(ERROR) << "ProcessResizeAndFormat failed.";
382 return false;
383 }
384 if (!opt::ToFormatBase::Run(func_graph)) {
385 MS_LOG(ERROR) << "run ToFormatBase failed.";
386 return false;
387 }
388 return true;
389 }
390 } // namespace lite
391 } // namespace mindspore
392