/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/conv_transform_fusion.h"
#include <memory>
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "tools/common/tensor_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"

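// Folds a per-channel affine transform node (Scale/BatchNorm) that follows a convolution
// into the convolution's weight and bias tensors, so the transform node can be removed.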
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
constexpr size_t kConvBiasIndex = 3;
constexpr size_t kConvNoBiasLen = 3;
constexpr size_t kConvWithBiasLen = 4;
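// Returns the out_channel attribute of a Conv2DFusion or Conv2dTransposeFusion node, or 0 when absent.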
int64_t GetOutChannels(const CNodePtr &conv_node) {
  MS_ASSERT(conv_node != nullptr);
  auto value_node = conv_node->input(0);
  MS_ASSERT(value_node != nullptr);
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(value_node);
    MS_ASSERT(conv_prim != nullptr);
    if (conv_prim->GetAttr(ops::kOutChannel) == nullptr) {
      return 0;
    }
    return conv_prim->get_out_channel();
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
    auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2dTransposeFusion>>(value_node);
    MS_ASSERT(conv_prim != nullptr);
    if (conv_prim->GetAttr(ops::kOutChannel) == nullptr) {
      return 0;
    }
    return conv_prim->get_out_channel();
  }
  return 0;
}

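// Scales a Conv2D weight per output channel: each of the kernel_num kernels is multiplied by its scale.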
void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const float *scale_weight,
                             int weight_shape_size, int kernel_num) {
  MS_ASSERT(dst_weight != nullptr && conv_weight != nullptr && scale_weight != nullptr);
  if (kernel_num <= 0) {
    return;
  }
  auto kernel_size = weight_shape_size / kernel_num;
  for (int i = 0; i < kernel_num; i++) {
    for (int j = 0; j < kernel_size; j++) {
      dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
    }
  }
}

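// Scales a Conv2dTranspose weight per output channel. The indexing treats the weight as
// (Cin / group, area, Cout) per group, with the output channel as the innermost dimension.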
void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weight,
                                      const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
  MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
  if (group <= 0 || kernel_num <= 0) {
    return;
  }
  MS_ASSERT(weight_tensor->data_c() != nullptr);
  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
  auto cin_group = weight_tensor->shape()[0] / group;
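  // Kernel spatial area. Note this multiplies shape[2] by itself, which is only correct for square kernels.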
  int64_t area_size = weight_tensor->shape()[kInputIndexTwo] * weight_tensor->shape()[kInputIndexTwo];
  for (int64_t k = 0; k < cin_group; ++k) {
    for (int64_t j = 0; j < area_size; j++) {
      for (int64_t i = 0; i < kernel_num; ++i) {
        dst_weight[i + j * kernel_num + k * area_size * kernel_num] =
          weight_data[i + j * kernel_num + k * area_size * kernel_num] * scale_weight[i];
      }
    }
  }
}
}  // namespace

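// Pattern-match entry point: given a Scale/BatchNorm-style transform node, folds it into the
// preceding convolution and returns the conv node, which then replaces the transform node in the graph.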
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    return nullptr;
  }
  // The transform node is the following Scale or BatchNorm node.
  auto transform_node = node->cast<CNodePtr>();
  if (transform_node == nullptr || transform_node->size() < kInputSizeTwo) {
    return nullptr;
  }
  if (IsMarkedTrainOp(transform_node)) {
    return nullptr;
  }

  auto pre_node = transform_node->input(1);
  auto conv_node = pre_node->cast<CNodePtr>();
  if (conv_node == nullptr || IsMultiOutputTensors(func_graph, conv_node) || IsVariableWeightConv(conv_node)) {
    return nullptr;
  }
  if (IsMarkedTrainOp(conv_node)) {
    return nullptr;
  }
  auto abstr = transform_node->abstract();
  int kernel_nums = static_cast<int>(GetOutChannels(conv_node));
  if (kernel_nums <= 0) {
    MS_LOG(INFO) << "Unsupported conv node, " << conv_node->DebugString();
    return node;
  }
  auto trans_scale = new (std::nothrow) float[kernel_nums];
  if (trans_scale == nullptr) {
    MS_LOG(ERROR) << "new trans_scale failed";
    return nullptr;
  }
  auto trans_bias = new (std::nothrow) float[kernel_nums];
  if (trans_bias == nullptr) {
    MS_LOG(ERROR) << "new trans_bias failed";
    delete[] trans_scale;
    return nullptr;
  }
  if (GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(DEBUG) << "cannot do fusion.";
    delete[] trans_bias;
    delete[] trans_scale;
    return nullptr;
  }
  if (GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(WARNING) << "generate a new weight tensor failed.";
    delete[] trans_bias;
    delete[] trans_scale;
    return nullptr;
  }
  delete[] trans_bias;
  delete[] trans_scale;
  pre_node->set_abstract(abstr);
  return pre_node;
}

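// Zero-initializes the transform buffers, then fills them from the transform node's parameters.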
int ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
                                       float *trans_bias) const {
  MS_ASSERT(transform_node != nullptr);
  if (trans_scale == nullptr) {
    MS_LOG(ERROR) << "new transScale failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  if (trans_bias == nullptr) {
    MS_LOG(ERROR) << "new transBias failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  if (memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
    MS_LOG(ERROR) << "memset transScale failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  if (memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
    MS_LOG(ERROR) << "memset transBias failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }

  return InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
}

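// Rewrites the conv node in place: scales its weight tensor and creates or updates its bias so
// that the rewritten conv alone computes what conv followed by the transform used to compute.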
int ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
                                          const float *trans_scale, const float *trans_bias) const {
  MS_ASSERT(func_graph != nullptr && conv_node != nullptr);
  MS_ASSERT(trans_scale != nullptr && trans_bias != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_NULL_PTR, "manager is nullptr.");
  AnfNodePtr conv_weight_node = nullptr;
  AnfNodePtr conv_bias_node = nullptr;
  if (conv_node->inputs().size() == kConvNoBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
  } else if (conv_node->inputs().size() == kConvWithBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
    conv_bias_node = conv_node->input(kConvBiasIndex);
  } else {
    MS_LOG(ERROR) << "conv node: " << conv_node->DebugString() << " inputs size must be 3 or 4";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(conv_weight_node != nullptr, lite::RET_ERROR);
  if (!conv_weight_node->isa<Parameter>()) {
    MS_LOG(ERROR) << "conv weight node is not a parameter node";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  if (conv_bias_node != nullptr && !conv_bias_node->isa<Parameter>()) {
    MS_LOG(ERROR) << "conv bias node is not a parameter node";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
  auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
  MS_CHECK_TRUE_MSG(weight_tensor != nullptr, lite::RET_ERROR, "weight tensor is nullptr.");
  if (kernel_num <= 0) {
    MS_LOG(ERROR) << "kernel num is less than or equal to 0";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  auto new_weight_tensor = lite::CreateTensorInfo(weight_tensor->data_c(), weight_tensor->DataSize() * sizeof(float),
                                                  weight_tensor->shape(), weight_tensor->data_type());
  if (new_weight_tensor == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return lite::RET_ERROR;
  }
  if (CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale) != lite::RET_OK) {
    MS_LOG(WARNING) << "generate a new weight tensor failed.";
    return lite::RET_ERROR;
  }
  float *bias_data = nullptr;
  // If the conv already has a bias, transform it in place; otherwise allocate a zero-filled one.
  bool bias_flag = false;
  if (conv_bias_node != nullptr) {
    auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
    auto bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
    MS_CHECK_TRUE_MSG(bias_tensor != nullptr, lite::RET_ERROR, "bias tensor is nullptr.");
    bias_data = reinterpret_cast<float *>(bias_tensor->data_c());
    bias_flag = true;
  } else {
    bias_data = new (std::nothrow) float[kernel_num];
    if (bias_data == nullptr) {
      MS_LOG(ERROR) << "new bias_data failed";
      return lite::RET_ERROR;
    }
    if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      MS_LOG(ERROR) << "memset bias_data failed";
      delete[] bias_data;
      return lite::RET_ERROR;
    }
  }
  if (CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(ERROR) << "generate a new bias failed.";
    if (!bias_flag) {
      delete[] bias_data;
    }
    return lite::RET_ERROR;
  }
  if (!bias_flag) {
    auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor->data_type());
    delete[] bias_data;
    bias_data = nullptr;
    if (bias_node == nullptr) {
      MS_LOG(ERROR) << "generate a new bias node failed.";
      return lite::RET_ERROR;
    }
    bias_node->set_name(conv_node->fullname_with_scope() + "_bias");
    manager->AddEdge(conv_node, bias_node);
  }
  auto new_weight_parameter = func_graph->add_parameter();
  if (new_weight_parameter == nullptr) {
    MS_LOG(ERROR) << "new_weight_parameter is nullptr";
    return lite::RET_ERROR;
  }
  new_weight_parameter->set_default_param(new_weight_tensor);
  new_weight_parameter->set_abstract(conv_weight_node->abstract());
  new_weight_parameter->set_name(conv_node->fullname_with_scope() + conv_weight_node->fullname_with_scope());
  manager->SetEdge(conv_node, kConvWeightIndex, new_weight_parameter);
  return lite::RET_OK;
}

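// Computes the scaled weight into a temporary buffer, then copies it back over the original weight data.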
int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const tensor::TensorPtr &weight_tensor,
                                            int kernel_num, const float *trans_scale) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(weight_tensor != nullptr);
  MS_ASSERT(trans_scale != nullptr);
  if (weight_tensor->shape().size() > kInputSizeFour) {
    MS_LOG(ERROR) << "weight tensor shape error";
    return lite::RET_ERROR;
  }
  auto weight_shape_size = weight_tensor->DataSize();
  MS_CHECK_TRUE_RET(weight_shape_size > 0, lite::RET_ERROR);
  auto tmp_weight_data = new (std::nothrow) float[weight_shape_size];
  if (tmp_weight_data == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  auto data_size = weight_shape_size * sizeof(float);
  if (memset_s(tmp_weight_data, data_size, 0, data_size) != EOK) {
    MS_LOG(ERROR) << "memset newWeightData failed";
    delete[] tmp_weight_data;
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
  auto conv_prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  MS_ASSERT(conv_prim != nullptr);
  bool is_depth_wise =
    conv_prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(conv_prim->GetAttr(ops::kIsDepthWise));
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    GenerateNewWeightConv2D(tmp_weight_data, weight_data, trans_scale, static_cast<int>(weight_shape_size),
                            kernel_num);
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    auto conv_primc = conv_prim->cast<std::shared_ptr<ops::Conv2dTransposeFusion>>();
    MS_ASSERT(conv_primc != nullptr);
    auto group = conv_primc->GetAttr(ops::kGroup) == nullptr ? 1 : conv_primc->get_group();
    GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
  }
  auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
  delete[] tmp_weight_data;
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy error: " << ret;
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

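// Writes the fused bias into bias_data, handling both the existing-bias and newly created bias cases.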
int ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
                                          const float *trans_bias) const {
  MS_ASSERT(bias_data != nullptr);
  MS_ASSERT(trans_bias != nullptr);
  MS_ASSERT(trans_scale != nullptr);
  if (bias_flag) {
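    // Fuse into the existing bias: new_bias = old_bias * trans_scale + trans_bias.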
    auto tmp_bias_data = new (std::nothrow) float[kernel_num];
    if (tmp_bias_data == nullptr) {
      MS_LOG(ERROR) << "new tmp_bias_data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return lite::RET_NULL_PTR;
    }
    if (memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      MS_LOG(ERROR) << "memset bias data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      delete[] tmp_bias_data;
      return lite::RET_MEMORY_FAILED;
    }
    for (int i = 0; i < kernel_num; i++) {
      tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i];
    }

    auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float));
    delete[] tmp_bias_data;
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
  } else {
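    // No existing bias: the new bias is just trans_bias.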
    if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      MS_LOG(ERROR) << "memset bias data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
    auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
  }
  return lite::RET_OK;
}

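// A conv weight is variable (computed at runtime) when it is neither a value node
// nor a parameter with default data; such convs cannot be fused at convert time.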
bool ConvTransformFusion::IsVariableWeightConv(const CNodePtr &conv_node) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(conv_node->inputs().size() >= kConvNoBiasLen);
  auto conv_weight_node = conv_node->input(kConvWeightIndex);
  bool is_value_node = conv_weight_node->isa<ValueNode>();
  auto conv_weight_param =
    conv_weight_node->isa<Parameter>() ? conv_weight_node->cast<ParameterPtr>()->default_param() : nullptr;
  return !is_value_node && conv_weight_param == nullptr;
}
}  // namespace mindspore::opt