1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/conv_transform_fusion.h"
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "ops/fusion/conv2d_fusion.h"
24 #include "ops/fusion/conv2d_transpose_fusion.h"
25 #include "tools/common/tensor_util.h"
26 #include "tools/optimizer/common/gllo_utils.h"
27 #include "tools/converter/quantizer/quant_param_holder.h"
28 #include "securec/include/securec.h"
29 #include "nnacl/op_base.h"
30 #include "ops/op_utils.h"
31
32 namespace mindspore::opt {
33 namespace {
// Input slots of a convolution CNode used by this pass: input(0) holds the primitive, the weight and the
// optional bias follow at the indices below.
constexpr size_t kConvWeightIndex{2};
constexpr size_t kConvBiasIndex{3};
// Total CNode input counts (primitive slot included) for a convolution without / with a bias input.
constexpr size_t kConvNoBiasLen{kConvWeightIndex + 1};
constexpr size_t kConvWithBiasLen{kConvBiasIndex + 1};
GetOutChannels(const CNodePtr & conv_node)38 int64_t GetOutChannels(const CNodePtr &conv_node) {
39 MS_ASSERT(conv_node != nullptr);
40 auto value_node = conv_node->input(0);
41 MS_ASSERT(value_node != nullptr);
42 if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
43 auto conv_prim = ops::GetOperator<ops::Conv2DFusion>(value_node);
44 MS_ASSERT(conv_prim != nullptr);
45 auto conv_prim_c = conv_prim->GetPrim();
46 MS_ASSERT(conv_prim_c != nullptr);
47 if (conv_prim_c->GetAttr(ops::kOutChannel) == nullptr) {
48 return 0;
49 }
50 return conv_prim->get_out_channel();
51 } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
52 auto conv_prim = ops::GetOperator<ops::Conv2dTransposeFusion>(value_node);
53 MS_ASSERT(conv_prim != nullptr);
54 auto conv_prim_c = conv_prim->GetPrim();
55 MS_ASSERT(conv_prim_c != nullptr);
56 if (conv_prim_c->GetAttr(ops::kOutChannel) == nullptr) {
57 return 0;
58 }
59 return conv_prim->get_out_channel();
60 }
61 return 0;
62 }
63
GenerateNewWeightConv2D(float * dst_weight,const float * conv_weight,const float * scale_weight,size_t weight_shape_size,int kernel_num)64 void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const float *scale_weight,
65 size_t weight_shape_size, int kernel_num) {
66 MS_ASSERT(dst_weight != nullptr && conv_weight != nullptr && scale_weight != nullptr);
67 if (kernel_num <= 0) {
68 return;
69 }
70 auto kernel_size = weight_shape_size / static_cast<size_t>(kernel_num);
71 for (size_t i = 0; i < static_cast<size_t>(kernel_num); ++i) {
72 for (size_t j = 0; j < kernel_size; j++) {
73 dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
74 }
75 }
76 }
77
GenerateNewWeightConv2DTranspose(float * dst_weight,const float * scale_weight,const tensor::TensorPtr & weight_tensor,int64_t group,int kernel_num)78 void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weight,
79 const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
80 MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
81 if (group <= 0 || kernel_num <= 0) {
82 return;
83 }
84 MS_ASSERT(weight_tensor->data_c() != nullptr);
85 auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
86 auto cin_group = weight_tensor->shape()[0] / group;
87 int64_t area_size = weight_tensor->shape()[kNHWC_H] * weight_tensor->shape()[kNHWC_W];
88 for (int64_t k = 0; k < cin_group; ++k) {
89 for (int64_t j = 0; j < area_size; j++) {
90 for (int64_t i = 0; i < kernel_num; ++i) {
91 dst_weight[i + j * kernel_num + k * area_size * kernel_num] =
92 weight_data[i + j * kernel_num + k * area_size * kernel_num] * scale_weight[i];
93 }
94 }
95 }
96 }
97
98 // this function should replace GenerateNewWeightConv2DTranspose after all fusions support NCHW
GenerateNewWeightConv2DTranspose_NCHW(float * dst_weight,const float * scale_weight,const tensor::TensorPtr & weight_tensor,int64_t group,int kernel_num)99 void GenerateNewWeightConv2DTranspose_NCHW(float *dst_weight, const float *scale_weight,
100 const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
101 MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
102 if (group <= 0 || kernel_num <= 0) {
103 return;
104 }
105 auto cin_group = weight_tensor->shape()[0] / group;
106 MS_ASSERT(weight_tensor->data_c() != nullptr);
107 auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
108 int64_t area_size = weight_tensor->shape()[kNHWC_H] * weight_tensor->shape()[kNHWC_W];
109 for (int64_t k = 0; k < cin_group; ++k) {
110 for (int64_t i = 0; i < kernel_num; ++i) { // output channel num -> C
111 for (int64_t j = 0; j < area_size; j++) { // HW
112 dst_weight[i * area_size + j + k * area_size * kernel_num] =
113 weight_data[i * area_size + j + k * area_size * kernel_num] * scale_weight[i];
114 }
115 }
116 }
117 }
118 } // namespace
119
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr &) const120 const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
121 const EquivPtr &) const {
122 if (func_graph == nullptr || node == nullptr) {
123 return nullptr;
124 }
125 // transform node means scale,bn
126 auto transform_node = node->cast<CNodePtr>();
127 if (transform_node == nullptr || transform_node->size() < kInputSizeTwo) {
128 return nullptr;
129 }
130 if (IsMarkedTrainOp(transform_node)) {
131 return nullptr;
132 }
133
134 auto pre_node = transform_node->input(1);
135 auto conv_node = pre_node->cast<CNodePtr>();
136 MS_CHECK_TRUE_RET(conv_node != nullptr, nullptr);
137 if (!CheckCanFused(func_graph, conv_node)) {
138 return nullptr;
139 }
140
141 // Check the activation type of scale.
142 if (!AdjustActivationType(conv_node, transform_node)) {
143 return nullptr;
144 }
145 auto abstr = transform_node->abstract();
146 int kernel_nums = static_cast<int>(GetOutChannels(conv_node));
147 if (kernel_nums <= 0) {
148 MS_LOG(INFO) << "Unsupported conv node, " << conv_node->DebugString();
149 return node;
150 }
151 auto trans_scale = new (std::nothrow) float[kernel_nums];
152 if (trans_scale == nullptr) {
153 MS_LOG(ERROR) << "tensor_data is nullptr";
154 return nullptr;
155 }
156 auto trans_bias = new (std::nothrow) float[kernel_nums];
157 if (trans_bias == nullptr) {
158 MS_LOG(ERROR) << "tensor_data is nullptr";
159 delete[] trans_scale;
160 return nullptr;
161 }
162 if (GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
163 MS_LOG(DEBUG) << "cannot do fusion.";
164 delete[] trans_bias;
165 delete[] trans_scale;
166 return nullptr;
167 }
168 if (GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias) != lite::RET_OK) {
169 MS_LOG(WARNING) << "generate a new weight tensor failed.";
170 delete[] trans_bias;
171 delete[] trans_scale;
172 return nullptr;
173 }
174 delete[] trans_bias;
175 delete[] trans_scale;
176 pre_node->set_abstract(abstr);
177 return pre_node;
178 }
179
AdjustActivationType(const CNodePtr & conv_node,const CNodePtr & transform_node) const180 bool ConvTransformFusion::AdjustActivationType(const CNodePtr &conv_node, const CNodePtr &transform_node) const {
181 MS_ASSERT(conv_node != nullptr && transform_node != nullptr);
182 MS_CHECK_TRUE_RET(transform_node->input(0) != nullptr, false);
183 auto trans_prim = GetValueNode<PrimitivePtr>(transform_node->input(0));
184 MS_CHECK_TRUE_RET(trans_prim != nullptr, false);
185 auto trans_act_ptr = trans_prim->GetAttr(ops::kActivationType);
186 if (trans_act_ptr == nullptr || GetValue<int64_t>(trans_act_ptr) == ActivationType::NO_ACTIVATION) {
187 return true;
188 }
189 auto trans_act = GetValue<int64_t>(trans_act_ptr);
190 // convolution only supports RELU and RELU6.
191 if (trans_act != ActivationType::RELU && trans_act != ActivationType::RELU6) {
192 return false;
193 }
194 MS_CHECK_TRUE_RET(conv_node->input(0) != nullptr, false);
195 auto conv_prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
196 MS_CHECK_TRUE_RET(conv_prim != nullptr, false);
197 (void)conv_prim->AddAttr(ops::kActivationType, MakeValue(trans_act));
198 return true;
199 }
200
GenTransParam(const CNodePtr & transform_node,int kernel_nums,float * trans_scale,float * trans_bias) const201 int ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
202 float *trans_bias) const {
203 MS_ASSERT(transform_node != nullptr);
204 if (trans_scale == nullptr) {
205 MS_LOG(ERROR) << "new transScale failed";
206 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
207 return lite::RET_NULL_PTR;
208 }
209 if (trans_bias == nullptr) {
210 MS_LOG(ERROR) << "new transBias failed";
211 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
212 return lite::RET_NULL_PTR;
213 }
214 if (memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
215 MS_LOG(ERROR) << "memset transScale failed";
216 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
217 return lite::RET_ERROR;
218 }
219 if (memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float)) != EOK) {
220 MS_LOG(ERROR) << "memset transBias failed";
221 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
222 return lite::RET_ERROR;
223 }
224
225 return InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
226 }
227
GenNewConvTensor(const FuncGraphPtr & func_graph,const CNodePtr & conv_node,int kernel_num,const float * trans_scale,const float * trans_bias) const228 int ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
229 const float *trans_scale, const float *trans_bias) const {
230 MS_ASSERT(func_graph != nullptr && conv_node != nullptr);
231 MS_ASSERT(trans_scale != nullptr && trans_bias != nullptr);
232 auto manager = func_graph->manager();
233 MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_NULL_PTR, "manager is nullptr");
234 AnfNodePtr conv_weight_node = nullptr;
235 AnfNodePtr conv_bias_node = nullptr;
236 if (conv_node->size() == kConvNoBiasLen) {
237 conv_weight_node = conv_node->input(kConvWeightIndex);
238 } else if (conv_node->size() == kConvWithBiasLen) {
239 conv_weight_node = conv_node->input(kConvWeightIndex);
240 conv_bias_node = conv_node->input(kConvBiasIndex);
241 } else {
242 MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
243 return lite::RET_ERROR;
244 }
245 MS_CHECK_TRUE_RET(conv_weight_node != nullptr, lite::RET_ERROR);
246 if (!conv_weight_node->isa<Parameter>()) {
247 MS_LOG(ERROR) << "scale weight node not parameter node";
248 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
249 return lite::RET_ERROR;
250 }
251 if (conv_bias_node != nullptr && !conv_bias_node->isa<Parameter>()) {
252 MS_LOG(ERROR) << "scale bias node not parameter node";
253 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
254 return lite::RET_ERROR;
255 }
256 auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
257 auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
258 if (kernel_num <= 0) {
259 MS_LOG(ERROR) << "kernel num less than 0";
260 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
261 return lite::RET_ERROR;
262 }
263 auto new_weight_tensor = lite::CreateTensorInfo(weight_tensor->data_c(), weight_tensor->DataSize() * sizeof(float),
264 weight_tensor->shape(), weight_tensor->data_type());
265 if (new_weight_tensor == nullptr) {
266 MS_LOG(ERROR) << "create tensor info failed.";
267 return lite::RET_ERROR;
268 }
269 if (CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale) != lite::RET_OK) {
270 MS_LOG(WARNING) << "generate a new weight tensor failed.";
271 return lite::RET_ERROR;
272 }
273 float *bias_data = nullptr;
274 // conv has bias,bias_flag true
275 bool bias_flag = false;
276 if (conv_bias_node != nullptr) {
277 auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
278 auto bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
279 bias_data = reinterpret_cast<float *>(bias_tensor->data_c());
280 bias_flag = true;
281 } else {
282 bias_data = new (std::nothrow) float[kernel_num];
283 if (bias_data == nullptr) {
284 MS_LOG(ERROR) << "tensor_data is nullptr";
285 return lite::RET_ERROR;
286 }
287 if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
288 delete[] bias_data;
289 return lite::RET_ERROR;
290 }
291 }
292 if (CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias) != lite::RET_OK) {
293 MS_LOG(ERROR) << "generate a new bias failed.";
294 if (!bias_flag) {
295 delete[] bias_data;
296 }
297 return lite::RET_ERROR;
298 }
299 if (!bias_flag) {
300 auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor->data_type());
301 delete[] bias_data;
302 bias_data = nullptr;
303 if (bias_node == nullptr) {
304 MS_LOG(ERROR) << "generate a new bias node failed.";
305 return lite::RET_ERROR;
306 }
307 bias_node->set_name(conv_node->fullname_with_scope() + "_bias");
308 manager->AddEdge(conv_node, bias_node);
309 }
310 auto new_weight_paramter = func_graph->add_parameter();
311 if (new_weight_paramter == nullptr) {
312 MS_LOG(ERROR) << "new_weight_paramter is nullptr";
313 return lite::RET_ERROR;
314 }
315 new_weight_paramter->set_default_param(new_weight_tensor);
316 new_weight_paramter->set_abstract(conv_weight_node->abstract());
317 new_weight_paramter->set_name(conv_weight_node->fullname_with_scope());
318 manager->SetEdge(conv_node, kConvWeightIndex, new_weight_paramter);
319 return lite::RET_OK;
320 }
321
CalNewWeightTensor(const CNodePtr & conv_node,const tensor::TensorPtr & weight_tensor,int kernel_num,const float * trans_scale) const322 int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const tensor::TensorPtr &weight_tensor,
323 int kernel_num, const float *trans_scale) const {
324 MS_ASSERT(conv_node != nullptr);
325 MS_ASSERT(weight_tensor != nullptr);
326 MS_ASSERT(trans_scale != nullptr);
327 if (weight_tensor->shape().size() > kInputSizeFour) {
328 MS_LOG(ERROR) << "weight tensor shape error";
329 return lite::RET_ERROR;
330 }
331 auto weight_shape_size = weight_tensor->DataSize();
332 MS_CHECK_TRUE_RET(weight_shape_size > 0, lite::RET_ERROR);
333 auto tmp_weight_data = new (std::nothrow) float[weight_shape_size];
334 if (tmp_weight_data == nullptr) {
335 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
336 return lite::RET_ERROR;
337 }
338 auto data_size = weight_shape_size * sizeof(float);
339 if (memset_s(tmp_weight_data, data_size, 0, data_size) != EOK) {
340 MS_LOG(ERROR) << "memset newWeightData failed";
341 delete[] tmp_weight_data;
342 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
343 return lite::RET_ERROR;
344 }
345 auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
346 auto conv_prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
347 MS_ASSERT(conv_prim != nullptr);
348 bool is_depth_wise =
349 conv_prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(conv_prim->GetAttr(ops::kIsDepthWise));
350 if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
351 GenerateNewWeightConv2D(tmp_weight_data, weight_data, trans_scale, weight_shape_size, kernel_num);
352 } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
353 auto conv2d_prim = api::MakeShared<ops::Conv2dTransposeFusion>(conv_prim);
354 MS_ASSERT(conv2d_prim != nullptr);
355 auto conv2d_prim_c = conv2d_prim->GetPrim();
356 MS_ASSERT(conv2d_prim_c != nullptr);
357 auto group = conv2d_prim_c->GetAttr(ops::kGroup) == nullptr ? 1 : conv2d_prim->get_group();
358 if (!nchw_format_) {
359 GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
360 } else {
361 GenerateNewWeightConv2DTranspose_NCHW(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
362 }
363 }
364 auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
365 delete[] tmp_weight_data;
366 if (ret != EOK) {
367 MS_LOG(ERROR) << "memcpy error: " << ret;
368 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
369 return lite::RET_ERROR;
370 }
371 return lite::RET_OK;
372 }
373
CalNewBiasTensor(float * bias_data,int kernel_num,bool bias_flag,const float * trans_scale,const float * trans_bias) const374 int ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
375 const float *trans_bias) const {
376 MS_ASSERT(bias_data != nullptr);
377 MS_ASSERT(trans_bias != nullptr);
378 MS_ASSERT(trans_scale != nullptr);
379 if (bias_flag) {
380 auto tmp_bias_data = new (std::nothrow) float[kernel_num];
381 if (tmp_bias_data == nullptr) {
382 MS_LOG(ERROR) << "tensor_data is nullptr";
383 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
384 return lite::RET_NULL_PTR;
385 }
386 if (memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
387 MS_LOG(ERROR) << "memset bias data failed";
388 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
389 delete[] tmp_bias_data;
390 return lite::RET_MEMORY_FAILED;
391 }
392 for (int i = 0; i < kernel_num; i++) {
393 tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i];
394 }
395
396 auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float));
397 delete[] tmp_bias_data;
398 if (ret != EOK) {
399 MS_LOG(ERROR) << "memcpy error: " << ret;
400 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
401 return lite::RET_MEMORY_FAILED;
402 }
403 } else {
404 if (memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float)) != EOK) {
405 MS_LOG(ERROR) << "memset bias data failed";
406 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
407 return lite::RET_MEMORY_FAILED;
408 }
409 auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float));
410 if (ret != EOK) {
411 MS_LOG(ERROR) << "memcpy error: " << ret;
412 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
413 return lite::RET_MEMORY_FAILED;
414 }
415 }
416 return lite::RET_OK;
417 }
418
CheckCanFused(const FuncGraphPtr & func_graph,const CNodePtr & conv_node) const419 bool ConvTransformFusion::CheckCanFused(const FuncGraphPtr &func_graph, const CNodePtr &conv_node) const {
420 MS_ASSERT(func_graph != nullptr && conv_node != nullptr);
421 if (IsMultiOutputTensors(func_graph, conv_node) || IsMarkedTrainOp(conv_node)) {
422 return false;
423 }
424 MS_ASSERT(conv_node->size() >= kConvNoBiasLen);
425 auto conv_prim = GetValueNode<PrimitivePtr>(conv_node->input(kInputIndex));
426 auto quant_attr = conv_prim->GetAttr("quant_params");
427 if (quant_attr != nullptr) {
428 auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
429 MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
430 auto quant_params = quant_param_holder->get_input_quant_params();
431 bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> ¶ms) {
432 return !params.empty() && params.front().inited;
433 });
434 if (is_quant) {
435 return false;
436 }
437 }
438 auto conv_act_ptr = conv_prim->GetAttr(ops::kActivationType);
439 if (conv_act_ptr != nullptr && GetValue<int64_t>(conv_act_ptr) != ActivationType::NO_ACTIVATION) {
440 return false;
441 }
442 // Check weight is const.
443 auto conv_weight_node = conv_node->input(kConvWeightIndex);
444 bool is_value_node = conv_weight_node->isa<ValueNode>();
445 auto conv_weight_param =
446 conv_weight_node->isa<Parameter>() ? conv_weight_node->cast<ParameterPtr>()->default_param() : nullptr;
447 return is_value_node || conv_weight_param != nullptr;
448 }
449 } // namespace mindspore::opt
450