/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/common/format_utils.h"
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/image_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/nn_op_name.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/adam.h"
#include "ops/apply_momentum.h"
#include "ops/batch_to_space.h"
#include "ops/crop.h"
#include "ops/depth_to_space.h"
#include "ops/fused_batch_norm.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "ops/fusion/conv2d_backprop_filter_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/div_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pad_fusion.h"
#include "ops/fusion/pow_fusion.h"
#include "ops/fusion/prelu_fusion.h"
#include "ops/fusion/sub_fusion.h"
#include "ops/fusion/scale_fusion.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/fusion/topk_fusion.h"
#include "ops/eltwise.h"
#include "ops/grad/activation_grad.h"
#include "ops/grad/max_pool_grad.h"
#include "ops/grad/resize_grad.h"
#include "ops/instance_norm.h"
#include "ops/lrn.h"
#include "ops/op_utils.h"
#include "ops/quant_dtype_cast.h"
#include "ops/resize.h"
#include "ops/roi_pooling.h"
#include "ops/sgd.h"
#include "ops/space_to_batch.h"
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "ops/deformable_conv2d.h"
#include "ops/roi_align.h"
#include "tools/lite_exporter/fetch_content.h"
#include "nnacl/op_base.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace opt {
// Treat the weight of DeformableConv2d as an input rather than a constant, because the op's infer-shape only
// supports NCHW.
static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {kBatchNormOpName, {1}},
  {kBatchNormGradOpName, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameDeformableConv2d, {1, 2}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameInstanceNorm, {1}},
  {ops::kNameGridSampler2D, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameROIAlign, {1}},
  {ops::kNameROIPooling, {1}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}}};

static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {};

// Treat the weight of DeformableConv2d as an input rather than a constant, because the op's infer-shape only
// supports NCHW.
static const std::unordered_map<std::string, std::vector<size_t>> ToNCHWOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {kBatchNormOpName, {1}},
  {kBatchNormGradOpName, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameDeformableConv2d, {1, 2}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameGridSampler2D, {1}},
  {ops::kNameInstanceNorm, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameROIAlign, {1}},
  {ops::kNameROIPooling, {1}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}}};

// Ops whose input format is not fixed; the bool value indicates whether the op has an axis attribute.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
  {ops::kNameAddN, false},
  {ops::kNameCrop, true},
  {ops::kNameSplit, true},
  {ops::kNameConcat, true},
  {ops::kNameEltwise, false},
  {ops::kNameMaximum, false},
  {ops::kNameAddFusion, false},
  {ops::kNameDivFusion, false},
  {ops::kNameMulFusion, false},
  {ops::kNamePadFusion, false},
  {ops::kNamePowFusion, false},
  {ops::kNameActivation, false},
  {ops::kNameSliceFusion, true},
  {ops::kNameStridedSlice, true},
  {ops::kNameActivationGrad, false},
  {ops::kNameQuantDTypeCast, false},
  {ops::kNameCast, false},
  {ops::kNameSubFusion, false},
  {ops::kNameErf, false}};

const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap() { return ToNCHWOpMap; }
bool IsDynamicFormatOp(const std::string &op_type) {
  return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
}
bool IsDynamicFormatOpWithAxis(const std::string &op_type) {
  auto iter = DynamicFormatOpList.find(op_type);
  return iter != DynamicFormatOpList.end() && iter->second;
}

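// Fetch the destination data type of a Cast op from its second input and store it in *perm.
// Returns RET_OK without touching *perm when that input is a CNode, i.e. not a constant.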
STATUS GetCastDstDataType(const CNodePtr &cnode, int *perm) {
  MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
  MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "cast op input size must be three.";
    return lite::RET_ERROR;
  }
  if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status;
  if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
  } else {
    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
  }
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch cast dst data type failed.";
    return lite::RET_ERROR;
  }
  if (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) {
    MS_LOG(ERROR) << "cast data type is invalid.";
    return lite::RET_ERROR;
  }
  if (data_info.data_.size() != sizeof(int32_t)) {
    MS_LOG(ERROR) << "data and data type of data-info do not match.";
    return lite::RET_ERROR;
  }
  *perm = reinterpret_cast<int *>(data_info.data_.data())[0];
  return lite::RET_OK;
}

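// Fetch the permutation of a Transpose op from its second input. Returns RET_OK without
// touching *perm when the perm input is a CNode, i.e. not a constant.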
STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
  MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
  MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "transpose op input size must be three.";
    return lite::RET_ERROR;
  }
  if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status;
  if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
  } else {
    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
  }
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch transpose perm data failed.";
    return lite::RET_ERROR;
  }
  if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
      data_info.shape_.size() != 1) {
    MS_LOG(ERROR) << "transpose perm data is invalid.";
    return lite::RET_ERROR;
  }
  perm->resize(data_info.shape_[0]);
  if (!data_info.data_.empty() &&
      memcpy_s(perm->data(), perm->size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

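// Rebuild the cnode's input list, dropping every input that is a monad value node.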
void RemoveIfMonad(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  std::vector<AnfNodePtr> inputs{cnode->input(0)};
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (utils::isa<ValueNodePtr>(cnode->input(i))) {
      auto value_node = cnode->input(i)->cast<ValueNodePtr>();
      auto value = value_node->value();
      if (value->isa<Monad>()) {
        continue;
      }
    }
    inputs.push_back(cnode->input(i));
  }
  cnode->set_inputs(inputs);
}

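// A monad node is a ValueNode whose value is a Monad (used to order side-effect ops).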
bool IsMonadNode(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr.";
    return false;
  }
  if (!utils::isa<ValueNodePtr>(node)) {
    return false;
  }
  auto value_node = node->cast<ValueNodePtr>();
  auto value = value_node->value();
  return value->isa<Monad>();
}

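// Structural ops that only forward their inputs and carry no format information themselves.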
bool IsSpecialType(const CNodePtr &cnode) {
  return CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
         CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimMakeTupleV2) ||
         CheckPrimitiveType(cnode, prim::kPrimReturn);
}

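// Determine the format of the cnode's output at the given index, defaulting to NHWC when the
// primitive has no kOutputsFormat attribute or the index is outside the attribute's range.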
int DetermineCertainOutputFormat(const CNodePtr &cnode, int index, Format *format) {
  MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
  *format = mindspore::NHWC;
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "get primitive failed");
  auto value_ptr = prim->GetAttr(kOutputsFormat);
  if (value_ptr != nullptr) {
    MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "outputs_format attr should be a sequence.");
    auto formats = CastToInt(value_ptr);
    if (index >= 0 && static_cast<size_t>(index) < formats.size()) {
      MS_CHECK_TRUE_MSG(formats[index] >= NCHW && formats[index] <= NCW, RET_ERROR,
                        "format val is out of enum's range.");
      *format = static_cast<Format>(formats[index]);
    }
  }
  return RET_OK;
}

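// Determine the format of the cnode's variable input at the given index by tracing back to
// the cnode that actually produces it and querying that node's output format.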
int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format) {
  MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
  auto var_input_info = GetRealCertainVarInput(cnode, index);
  if (var_input_info.first == nullptr) {
    MS_LOG(DEBUG) << "cannot get the real var input.";
    return RET_OK;
  }
  auto real_input_cnode = var_input_info.first;
  auto item_index = var_input_info.second;
  return DetermineCertainOutputFormat(real_input_cnode, item_index, format);
}

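// Give the abstract a placeholder tensor value carrying its dtype and shape, so that later
// passes can fetch the tensor info; TensorList abstracts (kObjectTypeTensorType) are skipped.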
int SetAbstractTensorInfo(const AbstractBasePtr &abstract) {
  if (!utils::isa<abstract::AbstractTensor>(abstract)) {
    MS_LOG(ERROR) << "abstract is not an AbstractTensor";
    return RET_ERROR;
  }
  if (abstract->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "abstract already has a tensor value.";
    return RET_OK;
  }
  ShapeVector shape;
  if (opt::FetchShapeFromAbstract(abstract, &shape) != RET_OK) {
    MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
    return RET_ERROR;
  }
  TypeId type = lite::GetAbstractTensorDtype(abstract->cast<abstract::AbstractTensorPtr>());
  // For kObjectTypeTensorType, the abstract value is a TensorList and does not need to be reset.
  if (type != kObjectTypeTensorType) {
    auto tensor_info = std::make_shared<tensor::Tensor>(type, shape);
    if (tensor_info == nullptr) {
      MS_LOG(ERROR) << "new tensor::Tensor failed";
      return RET_ERROR;
    }
    abstract->set_value(tensor_info);
  }
  return RET_OK;
}

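// Collect the input indices of a format-sensitive op before which transposes must be inserted.
// An empty entry in ToNCHWOpMap means all inputs, except nearest-neighbor ResizeGrad, which
// only needs its first input handled.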
STATUS GetFormatSensitiveOpInsertIndex(const CNodePtr &cnode, std::vector<size_t> *insert_index) {
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
  MS_ERROR_IF_NULL_W_RET_VAL(insert_index, lite::RET_ERROR);
  insert_index->clear();
  if (ToNCHWOpMap.find(prim->name()) == ToNCHWOpMap.end()) {
    return lite::RET_OK;
  }

  *insert_index = ToNCHWOpMap.at(prim->name());
  if (insert_index->empty()) {
    if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index->push_back(1);
    } else {
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index->push_back(i);
      }
    }
  }
  return RET_OK;
}

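// Permute the abstract's shape between NHWC and NCHW; e.g. kNHWC2NCHW rewrites
// {1, 224, 224, 3} as {1, 3, 224, 224}. Shapes with fewer than three dimensions are
// left untouched.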
int ConvertAbstractFormatShape(const AbstractBasePtr &abstract, FormatTransNodeType perm) {
  ShapeVector shape;
  if (perm == kNONE) {
    return lite::RET_OK;
  }
  if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch shape failed.";
    return lite::RET_ERROR;
  }
  if (shape.size() < kInputSizeThree) {
    MS_LOG(DEBUG) << "shape doesn't need to be modified.";
    return lite::RET_OK;
  }
  auto shape_value = abstract->BuildValue();
  if (!shape_value->isa<tensor::Tensor>()) {
    MS_LOG(INFO) << "abstract value must be a tensor, but got: " << shape_value->ToString() << ".";
    return RET_ERROR;
  }
  auto input_tensor = shape_value->cast<tensor::TensorPtr>();
  MS_CHECK_FALSE(input_tensor == nullptr, RET_ERROR);
  if (perm == kNHWC2NCHW) {
    ShapeVector transfer_shape = shape;
    size_t shape_size = shape.size();
    transfer_shape[1] = shape[shape_size - 1];
    for (size_t i = kDim2; i < shape_size; i++) {
      transfer_shape[i] = shape[i - 1];
    }
    abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
  } else if (perm == kNCHW2NHWC) {
    ShapeVector transfer_shape = shape;
    size_t shape_size = shape.size();
    transfer_shape[shape_size - 1] = shape[1];
    for (size_t i = kDim1; i < shape_size - 1; i++) {
      transfer_shape[i] = shape[i + 1];
    }
    abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
  }

  return RET_OK;
}
}  // namespace opt
}  // namespace mindspore