1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/common/format_utils.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <unordered_map>
22 #include "ops/adam.h"
23 #include "ops/addn.h"
24 #include "ops/apply_momentum.h"
25 #include "ops/batch_norm.h"
26 #include "ops/batch_to_space.h"
27 #include "ops/bias_add.h"
28 #include "ops/cast.h"
29 #include "ops/concat.h"
30 #include "ops/crop.h"
31 #include "ops/depth_to_space.h"
32 #include "ops/fused_batch_norm.h"
33 #include "ops/fusion/activation.h"
34 #include "ops/fusion/add_fusion.h"
35 #include "ops/fusion/avg_pool_fusion.h"
36 #include "ops/fusion/conv2d_backprop_input_fusion.h"
37 #include "ops/fusion/conv2d_backprop_filter_fusion.h"
38 #include "ops/fusion/conv2d_fusion.h"
39 #include "ops/fusion/conv2d_transpose_fusion.h"
40 #include "ops/fusion/div_fusion.h"
41 #include "ops/fusion/max_pool_fusion.h"
42 #include "ops/fusion/mul_fusion.h"
43 #include "ops/fusion/pad_fusion.h"
44 #include "ops/fusion/pow_fusion.h"
45 #include "ops/fusion/prelu_fusion.h"
46 #include "ops/fusion/slice_fusion.h"
47 #include "ops/fusion/topk_fusion.h"
48 #include "ops/eltwise.h"
49 #include "ops/grad/activation_grad.h"
50 #include "ops/grad/avg_pool_grad.h"
51 #include "ops/grad/batch_norm_grad.h"
52 #include "ops/grad/bias_add_grad.h"
53 #include "ops/grad/max_pool_grad.h"
54 #include "ops/grad/resize_grad.h"
55 #include "ops/instance_norm.h"
56 #include "ops/lrn.h"
57 #include "ops/maximum.h"
58 #include "ops/op_utils.h"
59 #include "ops/quant_dtype_cast.h"
60 #include "ops/resize.h"
61 #include "ops/roi_pooling.h"
62 #include "ops/sgd.h"
63 #include "ops/space_to_batch.h"
64 #include "ops/space_to_batch_nd.h"
65 #include "ops/space_to_depth.h"
66 #include "ops/split.h"
67 #include "ops/strided_slice.h"
68 #include "tools/anf_exporter/fetch_content.h"
69 #include "nnacl/op_base.h"
70
71 namespace mindspore {
72 namespace opt {
// Operators registered as NHWC-format ops, keyed by op type name (the ops::kName* constants).
// Exposed to other translation units only through GetNHWCOpMap().
// Each value is a list of input indices attached to the op.
// NOTE(review): these look like the CNode input indices of the tensors whose layout matters
// (e.g. {1} for the data input of convolutions/pools, {10} for Adam, {4} for ApplyMomentum)
// — TODO confirm against the callers of GetNHWCOpMap(). What an empty list means
// (AvgPoolGrad, MaxPoolGrad, ResizeGrad) is decided by those callers and is not visible here.
static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {ops::kNameBatchNorm, {1}},
  {ops::kNameBatchNormGrad, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameROIPooling, {1}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}},
  {ops::kNameTopKFusion, {1}}};
101
102 static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}};
103
// Ops whose input format is not fixed, keyed by op type name.
// The bool value records whether the op has an axis attribute (true) or not (false);
// IsDynamicFormatOpWithAxis() reports that flag.
// NOTE(review): presumably the axis must be remapped when the surrounding layout changes — TODO confirm.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
  {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
  {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
  {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
  {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
  {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
  {ops::kNameQuantDTypeCast, false}, {ops::kNameCast, false}};
112
GetNHWCOpMap()113 const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
GetNCHWOpMap()114 const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
IsDynamicFormatOp(const std::string & op_type)115 bool IsDynamicFormatOp(const std::string &op_type) {
116 return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
117 }
IsDynamicFormatOpWithAxis(const std::string & op_type)118 bool IsDynamicFormatOpWithAxis(const std::string &op_type) {
119 auto iter = DynamicFormatOpList.find(op_type);
120 return iter != DynamicFormatOpList.end() && iter->second;
121 }
122
GetTransposePerm(const CNodePtr & cnode,std::vector<int> * perm)123 STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
124 MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_NULL_PTR);
125 MS_CHECK_TRUE_RET(perm != nullptr, lite::RET_NULL_PTR);
126 if (cnode->size() != kInputSizeThree) {
127 MS_LOG(ERROR) << "transpose op input size must be three.";
128 return lite::RET_ERROR;
129 }
130 if (utils::isa<CNodePtr>(cnode->input(kInputIndexTwo))) {
131 return lite::RET_OK;
132 }
133 lite::DataInfo data_info;
134 int status;
135 if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
136 status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
137 } else {
138 status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
139 }
140 if (status != lite::RET_OK) {
141 MS_LOG(ERROR) << "fetch transpose perm data failed.";
142 return lite::RET_ERROR;
143 }
144 if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
145 data_info.shape_.size() != 1) {
146 MS_LOG(ERROR) << "transpose perm data is invalid.";
147 return lite::RET_ERROR;
148 }
149 perm->resize(data_info.shape_[0]);
150 if (!data_info.data_.empty() &&
151 memcpy_s(perm->data(), perm->size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
152 MS_LOG(ERROR) << "memcpy data failed.";
153 return lite::RET_ERROR;
154 }
155 return lite::RET_OK;
156 }
157
RemoveIfMonad(const CNodePtr & cnode)158 void RemoveIfMonad(const CNodePtr &cnode) {
159 MS_ASSERT(cnode != nullptr);
160 std::vector<AnfNodePtr> inputs{cnode->input(0)};
161 for (size_t i = 1; i < cnode->size(); ++i) {
162 if (utils::isa<ValueNodePtr>(cnode->input(i))) {
163 auto value_node = cnode->input(i)->cast<ValueNodePtr>();
164 auto value = value_node->value();
165 if (value->isa<Monad>()) {
166 continue;
167 }
168 }
169 inputs.push_back(cnode->input(i));
170 }
171 cnode->set_inputs(inputs);
172 }
173
IsMonadNode(const AnfNodePtr & node)174 bool IsMonadNode(const AnfNodePtr &node) {
175 if (node == nullptr) {
176 MS_LOG(ERROR) << "input parameter is nullptr.";
177 return false;
178 }
179 if (!utils::isa<ValueNodePtr>(node)) {
180 return false;
181 }
182 auto value_node = node->cast<ValueNodePtr>();
183 auto value = value_node->value();
184 if (value->isa<Monad>()) {
185 return true;
186 }
187 return false;
188 }
189
IsSpecialType(const CNodePtr & cnode)190 bool IsSpecialType(const CNodePtr &cnode) {
191 return CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
192 CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) ||
193 CheckPrimitiveType(cnode, prim::kPrimReturn);
194 }
195 } // namespace opt
196 } // namespace mindspore
197