• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/common/format_utils.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <unordered_map>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/image_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ops/nn_op_name.h"
26 #include "ops/auto_generate/gen_lite_ops.h"
27 #include "ops/adam.h"
28 #include "ops/apply_momentum.h"
29 #include "ops/batch_to_space.h"
30 #include "ops/crop.h"
31 #include "ops/depth_to_space.h"
32 #include "ops/fused_batch_norm.h"
33 #include "ops/fusion/activation.h"
34 #include "ops/fusion/add_fusion.h"
35 #include "ops/fusion/avg_pool_fusion.h"
36 #include "ops/fusion/conv2d_backprop_input_fusion.h"
37 #include "ops/fusion/conv2d_backprop_filter_fusion.h"
38 #include "ops/fusion/conv2d_fusion.h"
39 #include "ops/fusion/conv2d_transpose_fusion.h"
40 #include "ops/fusion/div_fusion.h"
41 #include "ops/fusion/max_pool_fusion.h"
42 #include "ops/fusion/mul_fusion.h"
43 #include "ops/fusion/pad_fusion.h"
44 #include "ops/fusion/pow_fusion.h"
45 #include "ops/fusion/prelu_fusion.h"
46 #include "ops/fusion/sub_fusion.h"
47 #include "ops/fusion/scale_fusion.h"
48 #include "ops/fusion/slice_fusion.h"
49 #include "ops/fusion/topk_fusion.h"
50 #include "ops/eltwise.h"
51 #include "ops/grad/activation_grad.h"
52 #include "ops/grad/max_pool_grad.h"
53 #include "ops/grad/resize_grad.h"
54 #include "ops/instance_norm.h"
55 #include "ops/lrn.h"
56 #include "ops/op_utils.h"
57 #include "ops/quant_dtype_cast.h"
58 #include "ops/resize.h"
59 #include "ops/roi_pooling.h"
60 #include "ops/sgd.h"
61 #include "ops/space_to_batch.h"
62 #include "ops/space_to_batch_nd.h"
63 #include "ops/space_to_depth.h"
64 #include "ops/deformable_conv2d.h"
65 #include "ops/roi_align.h"
66 #include "tools/lite_exporter/fetch_content.h"
67 #include "nnacl/op_base.h"
68 #include "tools/common/graph_util.h"
69 
70 namespace mindspore {
71 namespace opt {
72 // treat the weight of deformableConv2d as an input instead of a const because of the ops infershape only support nchw.
73 static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
74   {ops::kNameAdam, {10}},
75   {ops::kNameApplyMomentum, {4}},
76   {ops::kNameAvgPoolFusion, {1}},
77   {ops::kNameAvgPoolGrad, {}},
78   {kBatchNormOpName, {1}},
79   {kBatchNormGradOpName, {1, 2}},
80   {ops::kNameBatchToSpace, {1}},
81   {ops::kNameBiasAdd, {1}},
82   {ops::kNameBiasAddGrad, {1}},
83   {ops::kNameConv2DBackpropInputFusion, {1}},
84   {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
85   {ops::kNameConv2DFusion, {1}},
86   {ops::kNameConv2dTransposeFusion, {1}},
87   {ops::kNameDepthToSpace, {1}},
88   {ops::kNameDeformableConv2d, {1, 2}},
89   {ops::kNameFusedBatchNorm, {1}},
90   {ops::kNameInstanceNorm, {1}},
91   {ops::kNameGridSampler2D, {1}},
92   {ops::kNameLRN, {1}},
93   {ops::kNameMaxPoolFusion, {1}},
94   {ops::kNameMaxPoolGrad, {}},
95   {ops::kNamePReLUFusion, {1}},
96   {ops::kNameResize, {1}},
97   {ops::kNameResizeGrad, {}},
98   {ops::kNameROIAlign, {1}},
99   {ops::kNameROIPooling, {1}},
100   {ops::kNameSGD, {2}},
101   {ops::kNameSpaceToBatch, {1}},
102   {ops::kNameSpaceToBatchND, {1}},
103   {ops::kNameSpaceToDepth, {1}}};
104 
// Ops that require NCHW data, keyed by op name (same value convention as NHWCOpMap).
// No op is registered here at the moment, so the table is intentionally empty.
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap{};
106 
107 // treat the weight of deformableConv2d as an input instead of a const because of the ops infershape only support nchw.
108 static const std::unordered_map<std::string, std::vector<size_t>> ToNCHWOpMap = {
109   {ops::kNameAdam, {10}},
110   {ops::kNameApplyMomentum, {4}},
111   {ops::kNameAvgPoolFusion, {1}},
112   {ops::kNameAvgPoolGrad, {}},
113   {kBatchNormOpName, {1}},
114   {kBatchNormGradOpName, {1, 2}},
115   {ops::kNameBatchToSpace, {1}},
116   {ops::kNameBiasAdd, {1}},
117   {ops::kNameBiasAddGrad, {1}},
118   {ops::kNameConv2DBackpropInputFusion, {1}},
119   {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
120   {ops::kNameConv2DFusion, {1}},
121   {ops::kNameConv2dTransposeFusion, {1}},
122   {ops::kNameDepthToSpace, {1}},
123   {ops::kNameDeformableConv2d, {1, 2}},
124   {ops::kNameFusedBatchNorm, {1}},
125   {ops::kNameGridSampler2D, {1}},
126   {ops::kNameInstanceNorm, {1}},
127   {ops::kNameLRN, {1}},
128   {ops::kNameMaxPoolFusion, {1}},
129   {ops::kNameMaxPoolGrad, {}},
130   {ops::kNamePReLUFusion, {1}},
131   {ops::kNameResize, {1}},
132   {ops::kNameResizeGrad, {}},
133   {ops::kNameROIAlign, {1}},
134   {ops::kNameROIPooling, {1}},
135   {ops::kNameSGD, {2}},
136   {ops::kNameSpaceToBatch, {1}},
137   {ops::kNameSpaceToBatchND, {1}},
138   {ops::kNameSpaceToDepth, {1}}};
139 
140 // a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not.
141 static const std::unordered_map<std::string, bool> DynamicFormatOpList = {{ops::kNameAddN, false},
142                                                                           {ops::kNameCrop, true},
143                                                                           {ops::kNameSplit, true},
144                                                                           {ops::kNameConcat, true},
145                                                                           {ops::kNameEltwise, false},
146                                                                           {ops::kNameMaximum, false},
147                                                                           {ops::kNameAddFusion, false},
148                                                                           {ops::kNameDivFusion, false},
149                                                                           {ops::kNameMulFusion, false},
150                                                                           {ops::kNamePadFusion, false},
151                                                                           {ops::kNamePowFusion, false},
152                                                                           {ops::kNameActivation, false},
153                                                                           {ops::kNameSliceFusion, true},
154                                                                           {ops::kNameStridedSlice, true},
155                                                                           {ops::kNameActivationGrad, false},
156                                                                           {ops::kNameQuantDTypeCast, false},
157                                                                           {ops::kNameCast, false},
158                                                                           {ops::kNameSubFusion, false},
159                                                                           {ops::kNameErf, false}};
160 
GetNHWCOpMap()161 const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
GetNCHWOpMap()162 const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
GetToNCHWOpMap()163 const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap() { return ToNCHWOpMap; }
IsDynamicFormatOp(const std::string & op_type)164 bool IsDynamicFormatOp(const std::string &op_type) {
165   return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
166 }
IsDynamicFormatOpWithAxis(const std::string & op_type)167 bool IsDynamicFormatOpWithAxis(const std::string &op_type) {
168   auto iter = DynamicFormatOpList.find(op_type);
169   return iter != DynamicFormatOpList.end() && iter->second;
170 }
171 
GetCastDstDataType(const CNodePtr & cnode,int * perm)172 STATUS GetCastDstDataType(const CNodePtr &cnode, int *perm) {
173   MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
174   MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
175   if (cnode->size() != kInputSizeThree) {
176     MS_LOG(ERROR) << "cast op input size must be three.";
177     return lite::RET_ERROR;
178   }
179   if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
180     return lite::RET_OK;
181   }
182   lite::DataInfo data_info;
183   int status;
184   if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
185     status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
186   } else {
187     status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
188   }
189   if (status != lite::RET_OK) {
190     MS_LOG(ERROR) << "fetch cast dst data type failed.";
191     return lite::RET_ERROR;
192   }
193   if (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) {
194     MS_LOG(ERROR) << "cast data type is invalid.";
195     return lite::RET_ERROR;
196   }
197   if (data_info.data_.size() != sizeof(int32_t)) {
198     MS_LOG(ERROR) << "Data and datatype of data-info not match.";
199     return false;
200   }
201   *perm = reinterpret_cast<int *>(data_info.data_.data())[0];
202   return lite::RET_OK;
203 }
204 
GetTransposePerm(const CNodePtr & cnode,std::vector<int> * perm)205 STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
206   MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
207   MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
208   if (cnode->size() != kInputSizeThree) {
209     MS_LOG(ERROR) << "transpose op input size must be three.";
210     return lite::RET_ERROR;
211   }
212   if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
213     return lite::RET_OK;
214   }
215   lite::DataInfo data_info;
216   int status;
217   if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
218     status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
219   } else {
220     status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
221   }
222   if (status != lite::RET_OK) {
223     MS_LOG(ERROR) << "fetch transpose perm data failed.";
224     return lite::RET_ERROR;
225   }
226   if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
227       data_info.shape_.size() != 1) {
228     MS_LOG(ERROR) << "transpose perm data is invalid.";
229     return lite::RET_ERROR;
230   }
231   perm->resize(data_info.shape_[0]);
232   if (!data_info.data_.empty() &&
233       memcpy_s(perm->data(), perm->size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
234     MS_LOG(ERROR) << "memcpy data failed.";
235     return lite::RET_ERROR;
236   }
237   return lite::RET_OK;
238 }
239 
RemoveIfMonad(const CNodePtr & cnode)240 void RemoveIfMonad(const CNodePtr &cnode) {
241   MS_ASSERT(cnode != nullptr);
242   std::vector<AnfNodePtr> inputs{cnode->input(0)};
243   for (size_t i = 1; i < cnode->size(); ++i) {
244     if (utils::isa<ValueNodePtr>(cnode->input(i))) {
245       auto value_node = cnode->input(i)->cast<ValueNodePtr>();
246       auto value = value_node->value();
247       if (value->isa<Monad>()) {
248         continue;
249       }
250     }
251     inputs.push_back(cnode->input(i));
252   }
253   cnode->set_inputs(inputs);
254 }
255 
IsMonadNode(const AnfNodePtr & node)256 bool IsMonadNode(const AnfNodePtr &node) {
257   if (node == nullptr) {
258     MS_LOG(ERROR) << "input parameter is nullptr.";
259     return false;
260   }
261   if (!utils::isa<ValueNodePtr>(node)) {
262     return false;
263   }
264   auto value_node = node->cast<ValueNodePtr>();
265   auto value = value_node->value();
266   if (value->isa<Monad>()) {
267     return true;
268   }
269   return false;
270 }
271 
IsSpecialType(const CNodePtr & cnode)272 bool IsSpecialType(const CNodePtr &cnode) {
273   return CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
274          CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimMakeTupleV2) ||
275          CheckPrimitiveType(cnode, prim::kPrimReturn);
276 }
277 
DetermineCertainOutputFormat(const CNodePtr & cnode,int index,Format * format)278 int DetermineCertainOutputFormat(const CNodePtr &cnode, int index, Format *format) {
279   MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
280   *format = mindspore::NHWC;
281   auto prim = GetCNodePrimitive(cnode);
282   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "get primitive failed");
283   auto value_ptr = prim->GetAttr(kOutputsFormat);
284   if (value_ptr != nullptr) {
285     MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "outputs_format attr should be sequence.");
286     auto formats = CastToInt(value_ptr);
287     if (index >= 0 && static_cast<size_t>(index) < formats.size()) {
288       MS_CHECK_TRUE_MSG(formats[index] >= NCHW && formats[index] <= NCW, RET_ERROR,
289                         "format val is out of enum's range.");
290       *format = static_cast<Format>(formats[index]);
291     }
292   }
293   return RET_OK;
294 }
295 
DetermineCertainVarInputFormat(const CNodePtr & cnode,size_t index,Format * format)296 int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format) {
297   MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
298   auto var_input_info = GetRealCertainVarInput(cnode, index);
299   if (var_input_info.first == nullptr) {
300     MS_LOG(DEBUG) << "cannot get the real var input.";
301     return RET_OK;
302   }
303   auto real_input_cnode = var_input_info.first;
304   auto item_index = var_input_info.second;
305   return DetermineCertainOutputFormat(real_input_cnode, item_index, format);
306 }
307 
SetAbstractTensorInfo(const AbstractBasePtr & abstract)308 int SetAbstractTensorInfo(const AbstractBasePtr &abstract) {
309   if (!utils::isa<abstract::AbstractTensor>(abstract)) {
310     MS_LOG(ERROR) << "abstract is not a AbstractTensor";
311     return RET_ERROR;
312   }
313   if (abstract->isa<tensor::Tensor>()) {
314     MS_LOG(DEBUG) << "abstract have a tensor value.";
315     return RET_OK;
316   }
317   ShapeVector shape;
318   if (opt::FetchShapeFromAbstract(abstract, &shape) != RET_OK) {
319     MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
320     return RET_ERROR;
321   }
322   TypeId type = lite::GetAbstractTensorDtype(abstract->cast<abstract::AbstractTensorPtr>());
323   // For kObjectTypeTensorType, the abstract value is TensorList amd does not need to reset.
324   if (type != kObjectTypeTensorType) {
325     auto tensor_info = std::make_shared<tensor::Tensor>(type, shape);
326     if (tensor_info == nullptr) {
327       MS_LOG(ERROR) << "new tensor::Tensor failed";
328       return RET_ERROR;
329     }
330     abstract->set_value(tensor_info);
331   }
332   return RET_OK;
333 }
334 
GetFormatSensitiveOpInsertIndex(const CNodePtr & cnode,std::vector<size_t> * insert_index)335 STATUS GetFormatSensitiveOpInsertIndex(const CNodePtr &cnode, std::vector<size_t> *insert_index) {
336   auto prim_node = cnode->input(0);
337   auto prim = GetValueNode<PrimitivePtr>(prim_node);
338   MS_ERROR_IF_NULL_W_RET_VAL(prim, lite::RET_ERROR);
339   MS_ERROR_IF_NULL_W_RET_VAL(insert_index, lite::RET_ERROR);
340   insert_index->clear();
341   if (ToNCHWOpMap.find(prim->name()) == ToNCHWOpMap.end()) {
342     return lite::RET_OK;
343   }
344 
345   *insert_index = ToNCHWOpMap.at(prim->name());
346   if (insert_index->empty()) {
347     if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
348         GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
349       insert_index->push_back(1);
350     } else {
351       for (size_t i = 1; i < cnode->size(); ++i) {
352         insert_index->push_back(i);
353       }
354     }
355   }
356   return RET_OK;
357 }
358 
ConvertAbstractFormatShape(const AbstractBasePtr & abstract,FormatTransNodeType perm)359 int ConvertAbstractFormatShape(const AbstractBasePtr &abstract, FormatTransNodeType perm) {
360   ShapeVector shape;
361   if (perm == kNONE) {
362     return lite::RET_OK;
363   }
364   if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
365     MS_LOG(ERROR) << "fetch shape failed.";
366     return lite::RET_ERROR;
367   }
368   if (shape.size() < kInputSizeThree) {
369     MS_LOG(DEBUG) << "shape don't need to modify.";
370     return lite::RET_OK;
371   }
372   auto shape_value = abstract->BuildValue();
373   if (!shape_value->isa<tensor::Tensor>()) {
374     MS_LOG(INFO) << "abstract must be a tensor, but got: " << shape_value->ToString() << ".";
375     return RET_ERROR;
376   }
377   auto input_tensor = shape_value->cast<tensor::TensorPtr>();
378   MS_CHECK_FALSE(input_tensor == nullptr, RET_ERROR);
379   if (perm == kNHWC2NCHW) {
380     ShapeVector transfer_shape = shape;
381     size_t shape_size = shape.size();
382     transfer_shape[1] = shape[shape_size - 1];
383     for (size_t i = kDim2; i < shape_size; i++) {
384       transfer_shape[i] = shape[i - 1];
385     }
386     abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
387   } else if (perm == kNCHW2NHWC) {
388     ShapeVector transfer_shape = shape;
389     size_t shape_size = shape.size();
390     transfer_shape[shape_size - 1] = shape[1];
391     for (size_t i = kDim1; i < shape_size - 1; i++) {
392       transfer_shape[i] = shape[i + 1];
393     }
394     abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
395   }
396 
397   return RET_OK;
398 }
399 }  // namespace opt
400 }  // namespace mindspore
401