/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/conv_transform_fusion.h"
#include <memory>
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "tools/common/tensor_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"

namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
constexpr size_t kConvBiasIndex = 3;
constexpr size_t kConvNoBiasLen = 3;
constexpr size_t kConvWithBiasLen = 4;
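// Reads the out_channel attribute of a Conv2DFusion or Conv2dTransposeFusion node;
// returns 0 when the node is neither type or the attribute is unset.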
int64_t GetOutChannels(const CNodePtr &conv_node) {
  MS_ASSERT(conv_node != nullptr);
  auto value_node = conv_node->input(0);
  MS_ASSERT(value_node != nullptr);
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(value_node);
    MS_ASSERT(conv_prim != nullptr);
    if (conv_prim->GetAttr(ops::kOutChannel) == nullptr) {
      return 0;
    }
    return conv_prim->get_out_channel();
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
    auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2dTransposeFusion>>(value_node);
    MS_ASSERT(conv_prim != nullptr);
    if (conv_prim->GetAttr(ops::kOutChannel) == nullptr) {
      return 0;
    }
    return conv_prim->get_out_channel();
  }
  return 0;
}

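// Folds the per-output-channel scale into a Conv2D weight laid out kernel-by-kernel:
// dst[i][j] = src[i][j] * scale[i], where i indexes the output channel (kernel)
// and j the elements within one kernel.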
void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const float *scale_weight,
                             int weight_shape_size, int kernel_num) {
  MS_ASSERT(dst_weight != nullptr && conv_weight != nullptr && scale_weight != nullptr);
  if (kernel_num <= 0) {
    return;
  }
  auto kernel_size = weight_shape_size / kernel_num;
  for (int i = 0; i < kernel_num; i++) {
    for (int j = 0; j < kernel_size; j++) {
      dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
    }
  }
}

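// Folds the per-output-channel scale into a (non-depthwise) Conv2DTranspose weight.
// Here the output channel is the innermost index, so the scale factor follows the
// innermost loop variable i instead of the outer kernel index.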
void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weight,
                                      const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
  MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
  if (group <= 0 || kernel_num <= 0) {
    return;
  }
  MS_ASSERT(weight_tensor->data_c() != nullptr);
  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
  auto cin_group = weight_tensor->shape()[0] / group;
  int64_t area_size = weight_tensor->shape()[kInputIndexTwo] * weight_tensor->shape()[kInputIndexTwo];
  for (int64_t k = 0; k < cin_group; ++k) {
    for (int64_t j = 0; j < area_size; j++) {
      for (int64_t i = 0; i < kernel_num; ++i) {
        dst_weight[i + j * kernel_num + k * area_size * kernel_num] =
          weight_data[i + j * kernel_num + k * area_size * kernel_num] * scale_weight[i];
      }
    }
  }
}
}  // namespace

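// Pattern-match entry point: folds a scale/batchnorm transform node that directly
// follows a convolution into the convolution's weight and bias, then returns the
// conv node so that it replaces the transform node in the graph.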
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    return nullptr;
  }
  // the transform node is a scale or batchnorm node
  auto transform_node = node->cast<CNodePtr>();
  if (transform_node == nullptr || transform_node->size() < kInputSizeTwo) {
    return nullptr;
  }
  if (IsMarkedTrainOp(transform_node)) {
    return nullptr;
  }

  auto pre_node = transform_node->input(1);
  auto conv_node = pre_node->cast<CNodePtr>();
  if (conv_node == nullptr || IsMultiOutputTensors(func_graph, conv_node) || IsVariableWeightConv(conv_node)) {
    return nullptr;
  }
  if (IsMarkedTrainOp(conv_node)) {
    return nullptr;
  }
  auto abstr = transform_node->abstract();
  int kernel_nums = static_cast<int>(GetOutChannels(conv_node));
  if (kernel_nums <= 0) {
    MS_LOG(INFO) << "Unsupported conv node, " << conv_node->DebugString();
    return node;
  }
  auto trans_scale = new (std::nothrow) float[kernel_nums];
  if (trans_scale == nullptr) {
    MS_LOG(ERROR) << "new trans_scale failed";
    return nullptr;
  }
  auto trans_bias = new (std::nothrow) float[kernel_nums];
  if (trans_bias == nullptr) {
    MS_LOG(ERROR) << "new trans_bias failed";
    delete[] trans_scale;
    return nullptr;
  }
  if (GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(DEBUG) << "cannot do fusion.";
    delete[] trans_bias;
    delete[] trans_scale;
    return nullptr;
  }
  if (GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(WARNING) << "generate a new weight tensor failed.";
    delete[] trans_bias;
    delete[] trans_scale;
    return nullptr;
  }
  delete[] trans_bias;
  delete[] trans_scale;
  pre_node->set_abstract(abstr);
  return pre_node;
}

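// Zero-initializes the per-channel scale/bias buffers, then lets InitTransParam
// fill them from the transform node's parameters.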
int ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
                                       float *trans_bias) const {
  MS_ASSERT(transform_node != nullptr);
  if (trans_scale == nullptr) {
    MS_LOG(ERROR) << "new transScale failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  if (trans_bias == nullptr) {
    MS_LOG(ERROR) << "new transBias failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  if (memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
    MS_LOG(ERROR) << "memset transScale failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  if (memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
    MS_LOG(ERROR) << "memset transBias failed";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }

  return InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
}

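// Rebuilds the convolution's constant inputs: copies the weight tensor, rescales
// it with trans_scale, and folds trans_bias into the existing bias, creating a
// fresh bias parameter when the conv node has none.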
int ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
                                          const float *trans_scale, const float *trans_bias) const {
  MS_ASSERT(func_graph != nullptr && conv_node != nullptr);
  MS_ASSERT(trans_scale != nullptr && trans_bias != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_NULL_PTR, "manager is nullptr.");
  AnfNodePtr conv_weight_node = nullptr;
  AnfNodePtr conv_bias_node = nullptr;
  if (conv_node->inputs().size() == kConvNoBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
  } else if (conv_node->inputs().size() == kConvWithBiasLen) {
    conv_weight_node = conv_node->input(kConvWeightIndex);
    conv_bias_node = conv_node->input(kConvBiasIndex);
  } else {
    MS_LOG(ERROR) << "conv node: " << conv_node->DebugString() << " inputs size must be 3 or 4";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(conv_weight_node != nullptr, lite::RET_ERROR);
  if (!conv_weight_node->isa<Parameter>()) {
    MS_LOG(ERROR) << "conv weight node is not a parameter node";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  if (conv_bias_node != nullptr && !conv_bias_node->isa<Parameter>()) {
    MS_LOG(ERROR) << "conv bias node is not a parameter node";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
  auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
  MS_CHECK_TRUE_MSG(weight_tensor != nullptr, lite::RET_ERROR, "weight tensor is nullptr.");
  if (kernel_num <= 0) {
    MS_LOG(ERROR) << "kernel num must be larger than 0";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_ERROR;
  }
  auto new_weight_tensor = lite::CreateTensorInfo(weight_tensor->data_c(), weight_tensor->DataSize() * sizeof(float),
                                                  weight_tensor->shape(), weight_tensor->data_type());
  if (new_weight_tensor == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return lite::RET_ERROR;
  }
  if (CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale) != lite::RET_OK) {
    MS_LOG(WARNING) << "generate a new weight tensor failed.";
    return lite::RET_ERROR;
  }
  float *bias_data = nullptr;
  // bias_flag is true when the conv node already has a bias input
  bool bias_flag = false;
  if (conv_bias_node != nullptr) {
    auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
    auto bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
    MS_CHECK_TRUE_MSG(bias_tensor != nullptr, lite::RET_ERROR, "bias tensor is nullptr.");
    bias_data = reinterpret_cast<float *>(bias_tensor->data_c());
    bias_flag = true;
  } else {
    bias_data = new (std::nothrow) float[kernel_num];
    if (bias_data == nullptr) {
      MS_LOG(ERROR) << "new bias_data failed";
      return lite::RET_ERROR;
    }
    if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      delete[] bias_data;
      return lite::RET_ERROR;
    }
  }
  if (CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias) != lite::RET_OK) {
    MS_LOG(ERROR) << "generate a new bias failed.";
    if (!bias_flag) {
      delete[] bias_data;
    }
    return lite::RET_ERROR;
  }
  if (!bias_flag) {
    auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor->data_type());
    delete[] bias_data;
    bias_data = nullptr;
    if (bias_node == nullptr) {
      MS_LOG(ERROR) << "generate a new bias node failed.";
      return lite::RET_ERROR;
    }
    bias_node->set_name(conv_node->fullname_with_scope() + "_bias");
    manager->AddEdge(conv_node, bias_node);
  }
  auto new_weight_parameter = func_graph->add_parameter();
  if (new_weight_parameter == nullptr) {
    MS_LOG(ERROR) << "new_weight_parameter is nullptr";
    return lite::RET_ERROR;
  }
  new_weight_parameter->set_default_param(new_weight_tensor);
  new_weight_parameter->set_abstract(conv_weight_node->abstract());
  new_weight_parameter->set_name(conv_node->fullname_with_scope() + conv_weight_node->fullname_with_scope());
  manager->SetEdge(conv_node, kConvWeightIndex, new_weight_parameter);
  return lite::RET_OK;
}

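// Computes the rescaled weight into a temporary buffer and copies the result back
// over the weight tensor's data, dispatching on the node's primitive type to pick
// the Conv2D or Conv2DTranspose weight layout.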
int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const tensor::TensorPtr &weight_tensor,
                                            int kernel_num, const float *trans_scale) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(weight_tensor != nullptr);
  MS_ASSERT(trans_scale != nullptr);
  if (weight_tensor->shape().size() > kInputSizeFour) {
    MS_LOG(ERROR) << "weight tensor shape error";
    return lite::RET_ERROR;
  }
  auto weight_shape_size = weight_tensor->DataSize();
  MS_CHECK_TRUE_RET(weight_shape_size > 0, lite::RET_ERROR);
  auto tmp_weight_data = new (std::nothrow) float[weight_shape_size];
  if (tmp_weight_data == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  auto data_size = weight_shape_size * sizeof(float);
  if (memset_s(tmp_weight_data, data_size, 0, data_size) != EOK) {
    MS_LOG(ERROR) << "memset newWeightData failed";
    delete[] tmp_weight_data;
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
  auto conv_prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  MS_ASSERT(conv_prim != nullptr);
  bool is_depth_wise =
    conv_prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(conv_prim->GetAttr(ops::kIsDepthWise));
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    GenerateNewWeightConv2D(tmp_weight_data, weight_data, trans_scale, static_cast<int>(weight_shape_size),
                            kernel_num);
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    auto conv_primc = conv_prim->cast<std::shared_ptr<ops::Conv2dTransposeFusion>>();
    MS_ASSERT(conv_primc != nullptr);
    auto group = conv_primc->GetAttr(ops::kGroup) == nullptr ? 1 : conv_primc->get_group();
    GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
  }
  auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
  delete[] tmp_weight_data;
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy error: " << ret;
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

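// Computes the fused bias in place: new_bias = old_bias * trans_scale + trans_bias
// when the conv already had a bias, otherwise new_bias = trans_bias.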
int ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
                                          const float *trans_bias) const {
  MS_ASSERT(bias_data != nullptr);
  MS_ASSERT(trans_bias != nullptr);
  MS_ASSERT(trans_scale != nullptr);
  if (bias_flag) {
    auto tmp_bias_data = new (std::nothrow) float[kernel_num];
    if (tmp_bias_data == nullptr) {
      MS_LOG(ERROR) << "new tmp_bias_data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return lite::RET_NULL_PTR;
    }
    if (memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      MS_LOG(ERROR) << "memset bias data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      delete[] tmp_bias_data;
      return lite::RET_MEMORY_FAILED;
    }
    for (int i = 0; i < kernel_num; i++) {
      tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i];
    }

    auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float));
    delete[] tmp_bias_data;
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
  } else {
    if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
      MS_LOG(ERROR) << "memset bias data failed";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
    auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
      return lite::RET_MEMORY_FAILED;
    }
  }
  return lite::RET_OK;
}

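// A conv weight counts as variable, and the fusion is skipped, when it is neither
// a value node nor a parameter with default data, i.e. it is produced at runtime.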
bool ConvTransformFusion::IsVariableWeightConv(const CNodePtr &conv_node) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(conv_node->inputs().size() >= kConvNoBiasLen);
  auto conv_weight_node = conv_node->input(kConvWeightIndex);
  bool is_value_node = conv_weight_node->isa<ValueNode>();
  auto conv_weight_param =
    conv_weight_node->isa<Parameter>() ? conv_weight_node->cast<ParameterPtr>()->default_param() : nullptr;
  return !is_value_node && conv_weight_param == nullptr;
}
}  // namespace mindspore::opt