/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/common/format_utils.h"
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "ops/adam.h"
#include "ops/addn.h"
#include "ops/apply_momentum.h"
#include "ops/batch_norm.h"
#include "ops/batch_to_space.h"
#include "ops/bias_add.h"
#include "ops/cast.h"
#include "ops/concat.h"
#include "ops/crop.h"
#include "ops/depth_to_space.h"
#include "ops/fused_batch_norm.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "ops/fusion/conv2d_backprop_filter_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/div_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pad_fusion.h"
#include "ops/fusion/pow_fusion.h"
#include "ops/fusion/prelu_fusion.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/fusion/topk_fusion.h"
#include "ops/eltwise.h"
#include "ops/grad/activation_grad.h"
#include "ops/grad/avg_pool_grad.h"
#include "ops/grad/batch_norm_grad.h"
#include "ops/grad/bias_add_grad.h"
#include "ops/grad/max_pool_grad.h"
#include "ops/grad/resize_grad.h"
#include "ops/instance_norm.h"
#include "ops/lrn.h"
#include "ops/maximum.h"
#include "ops/op_utils.h"
#include "ops/quant_dtype_cast.h"
#include "ops/resize.h"
#include "ops/roi_pooling.h"
#include "ops/sgd.h"
#include "ops/space_to_batch.h"
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "ops/split.h"
#include "ops/strided_slice.h"
#include "tools/anf_exporter/fetch_content.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
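// Input indices, per op type, whose data is expected in NHWC format.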
static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {ops::kNameBatchNorm, {1}},
  {ops::kNameBatchNormGrad, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameROIPooling, {1}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}},
  {ops::kNameTopKFusion, {1}}};

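// Input indices, per op type, whose data is expected in NCHW format.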
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}};

// Ops whose input format is not fixed; the bool value indicates whether the op has an axis attribute.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
  {ops::kNameAddN, false},           {ops::kNameCrop, true},         {ops::kNameSplit, true},
  {ops::kNameConcat, true},          {ops::kNameEltwise, false},     {ops::kNameMaximum, false},
  {ops::kNameAddFusion, false},      {ops::kNameDivFusion, false},   {ops::kNameMulFusion, false},
  {ops::kNamePadFusion, false},      {ops::kNamePowFusion, false},   {ops::kNameActivation, false},
  {ops::kNameSliceFusion, true},     {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
  {ops::kNameQuantDTypeCast, false}, {ops::kNameCast, false}};

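// Read-only accessors for the format tables above.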
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
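// Returns true if op_type is in DynamicFormatOpList, i.e. its input format is not fixed.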
bool IsDynamicFormatOp(const std::string &op_type) {
  return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
}
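// Returns true if the op's input format is not fixed and the op carries an axis attribute.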
bool IsDynamicFormatOpWithAxis(const std::string &op_type) {
  auto iter = DynamicFormatOpList.find(op_type);
  return iter != DynamicFormatOpList.end() && iter->second;
}

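// Fetches the permutation of a transpose cnode from its constant second input.
// When the perm is produced by another cnode (i.e. it is not constant), RET_OK
// is returned and perm is left untouched. A minimal usage sketch
// (transpose_cnode is a hypothetical caller-side variable):
//   std::vector<int> perm;
//   if (GetTransposePerm(transpose_cnode, &perm) != lite::RET_OK) { /* handle the error */ }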
STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
  MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
  MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "transpose op input size must be three.";
    return lite::RET_ERROR;
  }
  if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  int status;
  if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
  } else {
    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
  }
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch transpose perm data failed.";
    return lite::RET_ERROR;
  }
  if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
      data_info.shape_.size() != 1) {
    MS_LOG(ERROR) << "transpose perm data is invalid.";
    return lite::RET_ERROR;
  }
  perm->resize(data_info.shape_[0]);
  if (!data_info.data_.empty() &&
      memcpy_s(perm->data(), perm->size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

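// Rebuilds the cnode's input list without monad value-node inputs, preserving the order of the rest.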
void RemoveIfMonad(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  std::vector<AnfNodePtr> inputs{cnode->input(0)};
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (utils::isa<ValueNodePtr>(cnode->input(i))) {
      auto value_node = cnode->input(i)->cast<ValueNodePtr>();
      auto value = value_node->value();
      if (value->isa<Monad>()) {
        continue;
      }
    }
    inputs.push_back(cnode->input(i));
  }
  cnode->set_inputs(inputs);
}

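// Returns true if the node is a value node holding a Monad.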
bool IsMonadNode(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr.";
    return false;
  }
  if (!utils::isa<ValueNodePtr>(node)) {
    return false;
  }
  auto value_node = node->cast<ValueNodePtr>();
  auto value = value_node->value();
  if (value->isa<Monad>()) {
    return true;
  }
  return false;
}

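// Returns true for structural primitives: TupleGetItem, Depend, MakeTuple (or MakeTupleV2) and Return.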
bool IsSpecialType(const CNodePtr &cnode) {
  return CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
         CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) ||
         CheckPrimitiveType(cnode, prim::kPrimReturn);
}
}  // namespace opt
}  // namespace mindspore