/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/train/train_populate_parameter.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/common/ops/populate/default_populate.h"
#include "nnacl/strided_slice_parameter.h"
#include "nnacl/arithmetic_parameter.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/pow_parameter.h"
#include "nnacl/activation_parameter.h"
#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h"
#include "nnacl/fp32_grad/optimizer.h"
#include "nnacl/fp32_grad/batch_norm_parameter.h"
#include "nnacl/fp32_grad/dropout_parameter.h"
#include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad_parameter.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"

using mindspore::lite::Registry;

namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputIndexOne = 1;
constexpr int kInputIndexTwo = 2;
constexpr int kInputIndexThree = 3;
} // namespace
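// Each Populate*Parameter function below decodes one flatbuffer Primitive from the
// schema into its nnacl parameter struct; the caller takes ownership of the
// malloc'ed OpParameter and is expected to free it.
// SmoothL1Loss / SmoothL1LossGrad: beta_ is the threshold at which the loss switches
// between the quadratic and the linear regime.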
OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
  SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SmoothL1LossParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SmoothL1Loss();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->beta_ = value->beta();
  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
  SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SmoothL1LossParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SmoothL1LossGrad();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->beta_ = value->beta();
  return reinterpret_cast<OpParameter *>(p);
}

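// Optimizer primitives (ApplyMomentum, Adam, SGD) only carry the hyper-parameter
// flags stored in the schema; values such as the learning rate and moment buffers
// are expected to arrive as kernel input tensors at runtime.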
OpParameter *PopulateApplyMomentumParameter(const void *prim) {
  ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc ApplyMomentumParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(ApplyMomentumParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_ApplyMomentum();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->grad_scale_ = value->gradient_scale();
  p->use_nesterov_ = value->use_nesterov();
  return reinterpret_cast<OpParameter *>(p);
}

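// BinaryCrossEntropy / BinaryCrossEntropyGrad: reduction selects none / mean / sum
// as defined by the schema Reduction enum.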
OpParameter *PopulateBCEParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_BinaryCrossEntropy();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr";
    return nullptr;
  }

  auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(BinaryCrossEntropyParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->reduction = value->reduction();
  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateBCEGradParameter(const void *prim) {
  auto *primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_BinaryCrossEntropyGrad();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr";
    return nullptr;
  }

  auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->reduction = value->reduction();
  return reinterpret_cast<OpParameter *>(param);
}

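// Adam keeps only the use_nesterov flag in its parameter struct; the remaining
// hyper-parameters (beta1, beta2, epsilon, learning rate) are assumed to be supplied
// as input tensors to the kernel.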
OpParameter *PopulateAdamParameter(const void *prim) {
  AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc AdamParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(AdamParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Adam();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->use_nesterov_ = value->use_nesterov();
  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSgdParameter(const void *prim) {
  SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SgdParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SgdParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SGD();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->weight_decay_ = value->weight_decay();
  p->dampening_ = value->dampening();
  p->use_nesterov_ = value->nesterov();

  return reinterpret_cast<OpParameter *>(p);
}

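// Both softmax cross-entropy variants reuse SoftmaxCrossEntropyParameter; is_grad_
// distinguishes the gradient computation from the plain forward loss.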
OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
  SoftmaxCrossEntropyParameter *sce_param =
    reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
  if (sce_param == nullptr) {
    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
    return nullptr;
  }
  memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
  MS_ASSERT(value != nullptr);
  sce_param->op_parameter_.type_ = primitive->value_type();
  sce_param->is_grad_ = value->is_grad();
  return reinterpret_cast<OpParameter *>(sce_param);
}

OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
  SoftmaxCrossEntropyParameter *sce_param =
    reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
  if (sce_param == nullptr) {
    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
    return nullptr;
  }
  memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  sce_param->op_parameter_.type_ = primitive->value_type();
  sce_param->is_grad_ = false;
  return reinterpret_cast<OpParameter *>(sce_param);
}

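// Pooling grad populate: kernel_size and strides are read in (h, w) order; the
// explicit pads are zeroed here because SAME/VALID padding is resolved through
// pad_mode_ when the kernel runs.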
OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
  if (pooling_param == nullptr) {
    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
    return nullptr;
  }
  memset(pooling_param, 0, sizeof(PoolingParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_MaxPoolGrad();
  MS_ASSERT(value != nullptr);
  pooling_param->op_parameter_.type_ = primitive->value_type();

  pooling_param->global_ = false;
  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));

  pooling_param->pad_u_ = 0;
  pooling_param->pad_d_ = 0;
  pooling_param->pad_l_ = 0;
  pooling_param->pad_r_ = 0;
  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
  pooling_param->round_type_ = RoundType_No;
  pooling_param->pool_mode_ = PoolMode_MaxPool;
  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      pooling_param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      pooling_param->pad_mode_ = Pad_valid;
      break;
    default:
      pooling_param->pad_mode_ = Pad_pad;
      break;
  }

  return reinterpret_cast<OpParameter *>(pooling_param);
}

OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
  if (pooling_param == nullptr) {
    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
    return nullptr;
  }
  memset(pooling_param, 0, sizeof(PoolingParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_AvgPoolGrad();
  MS_ASSERT(value != nullptr);
  pooling_param->op_parameter_.type_ = primitive->value_type();

  pooling_param->global_ = false;
  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));

  pooling_param->pad_u_ = 0;
  pooling_param->pad_d_ = 0;
  pooling_param->pad_l_ = 0;
  pooling_param->pad_r_ = 0;
  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));

  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      pooling_param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      pooling_param->pad_mode_ = Pad_valid;
      break;
    default:
      pooling_param->pad_mode_ = Pad_pad;
      break;
  }
  pooling_param->round_type_ = RoundType_No;
  pooling_param->pool_mode_ = PoolMode_AvgPool;
  return reinterpret_cast<OpParameter *>(pooling_param);
}

OpParameter *PopulateActivationGradParameter(const void *prim) {
  ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
  if (act_param == nullptr) {
    MS_LOG(ERROR) << "malloc ActivationParameter failed.";
    return nullptr;
  }
  memset(act_param, 0, sizeof(ActivationParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_ActivationGrad();
  MS_ASSERT(value != nullptr);
  act_param->op_parameter_.type_ = primitive->value_type();
  act_param->type_ = static_cast<int>(value->activation_type());
  act_param->alpha_ = value->alpha();
  return reinterpret_cast<OpParameter *>(act_param);
}

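// Conv2DBackpropFilterFusion / Conv2DBackpropInputFusion: strides are taken from the
// tail of the stride vector so both short and full-length layouts are handled, and
// pad_list is decoded in (up, down, left, right) order.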
OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
  ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(ConvParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Conv2DBackpropFilterFusion();
  MS_ASSERT(value != nullptr);
  param->op_parameter_.type_ = primitive->value_type();

  param->kernel_h_ = value->kernel_size()->Get(0);
  param->kernel_w_ = value->kernel_size()->Get(1);
  param->stride_h_ = value->stride()->Get((value->stride()->size()) - kInputIndexTwo);
  param->stride_w_ = value->stride()->Get((value->stride()->size()) - kInputIndexOne);
  param->dilation_h_ = value->dilation()->Get(0);
  param->dilation_w_ = value->dilation()->Get(1);
  param->pad_u_ = value->pad_list()->Get(0);
  param->pad_d_ = value->pad_list()->Get(1);
  param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
  param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
  param->group_ = value->group();
  param->act_type_ = ActType_No;
  switch (value->activation_type()) {
    case schema::ActivationType_RELU:
      param->act_type_ = ActType_Relu;
      break;
    case schema::ActivationType_RELU6:
      param->act_type_ = ActType_Relu6;
      break;
    default:
      break;
  }
  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      param->pad_mode_ = Pad_valid;
      break;
    default:
      param->pad_mode_ = Pad_pad;
      break;
  }

  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
  ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc Param for conv grad input failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(ConvParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Conv2DBackpropInputFusion();
  MS_ASSERT(value != nullptr);
  param->op_parameter_.type_ = primitive->value_type();

  param->kernel_h_ = value->kernel_size()->Get(0);
  param->kernel_w_ = value->kernel_size()->Get(1);
  param->stride_h_ = value->stride()->Get((value->stride()->size()) - kInputIndexTwo);
  param->stride_w_ = value->stride()->Get((value->stride()->size()) - kInputIndexOne);
  param->dilation_h_ = value->dilation()->Get(0);
  param->dilation_w_ = value->dilation()->Get(1);
  param->pad_u_ = value->pad_list()->Get(0);
  param->pad_d_ = value->pad_list()->Get(1);
  param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
  param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
  param->group_ = value->group();
  param->act_type_ = ActType_No;
  switch (value->activation_type()) {
    case schema::ActivationType_RELU:
      param->act_type_ = ActType_Relu;
      break;
    case schema::ActivationType_RELU6:
      param->act_type_ = ActType_Relu6;
      break;
    default:
      break;
  }
  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      param->pad_mode_ = Pad_valid;
      break;
    default:
      param->pad_mode_ = Pad_pad;
      break;
  }

  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulatePowerGradParameter(const void *prim) {
  PowParameter *power_param = reinterpret_cast<PowParameter *>(malloc(sizeof(PowParameter)));
  if (power_param == nullptr) {
    MS_LOG(ERROR) << "malloc PowParameter failed.";
    return nullptr;
  }
  memset(power_param, 0, sizeof(PowParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_PowerGrad();
  MS_ASSERT(value != nullptr);
  power_param->op_parameter_.type_ = primitive->value_type();
  power_param->power_ = value->power();
  power_param->scale_ = value->scale();
  power_param->shift_ = value->shift();
  return reinterpret_cast<OpParameter *>(power_param);
}

OpParameter *PopulateBiasGradParameter(const void *prim) {
  ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
  if (arithmetic_param == nullptr) {
    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
    return nullptr;
  }
  memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  arithmetic_param->op_parameter_.type_ = primitive->value_type();
  return reinterpret_cast<OpParameter *>(arithmetic_param);
}

OpParameter *PopulateBNGradParameter(const void *prim) {
  BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
  if (bnGrad_param == nullptr) {
    MS_LOG(ERROR) << "malloc BNGradParameter failed.";
    return nullptr;
  }
  memset(bnGrad_param, 0, sizeof(BNGradParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_BatchNormGrad();
  MS_ASSERT(value != nullptr);
  bnGrad_param->op_parameter_.type_ = primitive->value_type();
  bnGrad_param->epsilon_ = value->epsilon();
  bnGrad_param->is_training_ = value->is_training();
  return reinterpret_cast<OpParameter *>(bnGrad_param);
}

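// Dropout / DropoutGrad: the schema attribute is keep_prob; it is stored in ratio_
// and validated to lie in [0, 1] before the parameter is returned.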
OpParameter *PopulateDropoutParameter(const void *prim) {
  DropoutParameter *dropout_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
  if (dropout_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc Dropout Parameter failed.";
    return nullptr;
  }
  memset(dropout_parameter, 0, sizeof(DropoutParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Dropout();
  MS_ASSERT(value != nullptr);
  dropout_parameter->op_parameter_.type_ = primitive->value_type();
  dropout_parameter->ratio_ = value->keep_prob();
  if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
    MS_LOG(ERROR) << "Dropout ratio must be between 0 and 1, got " << dropout_parameter->ratio_;
    free(dropout_parameter);
    return nullptr;
  }
  return reinterpret_cast<OpParameter *>(dropout_parameter);
}

OpParameter *PopulateDropoutGradParameter(const void *prim) {
  DropoutParameter *dropoutgrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
  if (dropoutgrad_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed.";
    return nullptr;
  }
  memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_DropoutGrad();
  MS_ASSERT(value != nullptr);
  dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
  dropoutgrad_parameter->ratio_ = value->keep_prob();
  if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
    MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 and 1, got " << dropoutgrad_parameter->ratio_;
    free(dropoutgrad_parameter);
    return nullptr;
  }
  return reinterpret_cast<OpParameter *>(dropoutgrad_parameter);
}

OpParameter *PopulateArithmeticGradParameter(const void *prim) {
  ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
  if (arithmetic_param == nullptr) {
    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
    return nullptr;
  }
  memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  arithmetic_param->op_parameter_.type_ = primitive->value_type();
  return reinterpret_cast<OpParameter *>(arithmetic_param);
}

OpParameter *PopulateResizeGradParameter(const void *prim) {
  ResizeGradParameter *resize_grad_param = reinterpret_cast<ResizeGradParameter *>(malloc(sizeof(ResizeGradParameter)));
  if (resize_grad_param == nullptr) {
    MS_LOG(ERROR) << "malloc resize grad parameter failed.";
    return nullptr;
  }
  memset(resize_grad_param, 0, sizeof(ResizeGradParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  resize_grad_param->op_parameter_.type_ = primitive->value_type();
  auto param = primitive->value_as_ResizeGrad();
  MS_ASSERT(param != nullptr);
  resize_grad_param->method = static_cast<int>(param->method());
  resize_grad_param->align_corners_ = param->align_corners();

  return reinterpret_cast<OpParameter *>(resize_grad_param);
}

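// StridedSliceGrad: only the bit masks are decoded here; the begin/end/stride index
// vectors are assumed to arrive as kernel input tensors.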
OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
  StridedSliceParameter *strided_slice_param =
    reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
  if (strided_slice_param == nullptr) {
    MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
    return nullptr;
  }
  memset(strided_slice_param, 0, sizeof(StridedSliceParameter));

  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_StridedSliceGrad();
  MS_ASSERT(value != nullptr);
  strided_slice_param->op_parameter_.type_ = primitive->value_type();

  strided_slice_param->begins_mask_ = value->begin_mask();
  strided_slice_param->ends_mask_ = value->end_mask();
  strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
  strided_slice_param->newAxisMask_ = value->new_axis_mask();
  strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
  return reinterpret_cast<OpParameter *>(strided_slice_param);
}

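// LSTMGrad, LSTMGradData and LSTMGradWeight share LstmGradParameter; all three copy
// the same set of attributes from their respective schema tables.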
OpParameter *PopulateLstmGradParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_LSTMGrad();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr.";
    return nullptr;
  }

  auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(LstmGradParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->bidirectional_ = value->bidirectional();
  param->zoneout_cell_ = value->zoneout_cell();
  param->zoneout_hidden_ = value->zoneout_hidden();
  param->input_size_ = value->input_size();
  param->has_bias_ = static_cast<int>(value->has_bias());
  param->hidden_size_ = value->hidden_size();

  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateLstmGradDataParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_LSTMGradData();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr.";
    return nullptr;
  }

  auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(LstmGradParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->bidirectional_ = value->bidirectional();
  param->zoneout_cell_ = value->zoneout_cell();
  param->zoneout_hidden_ = value->zoneout_hidden();
  param->input_size_ = value->input_size();
  param->has_bias_ = static_cast<int>(value->has_bias());
  param->hidden_size_ = value->hidden_size();
  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateLstmGradWeightParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_LSTMGradWeight();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr.";
    return nullptr;
  }

  auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(LstmGradParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->input_size_ = value->input_size();
  param->hidden_size_ = value->hidden_size();
  param->bidirectional_ = value->bidirectional();
  param->zoneout_cell_ = value->zoneout_cell();
  param->zoneout_hidden_ = value->zoneout_hidden();
  param->has_bias_ = static_cast<int>(value->has_bias());
  return reinterpret_cast<OpParameter *>(param);
}

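// Registers every training-related populate function with the populate registry for
// the current schema version; attribute-free ops fall back to
// lite::DefaultPopulateParameter.
// A minimal usage sketch (assuming the registry exposes a GetParameterCreator-style
// lookup; the names below are illustrative, not defined in this file):
//   PopulateTrainParameters();
//   auto creator = lite::PopulateRegistry::GetInstance()->GetParameterCreator(
//     schema::PrimitiveType_SGD, lite::SCHEMA_CUR);
//   OpParameter *param = creator(primitive);  // primitive: const schema::Primitive *
//   ...
//   free(param);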
void PopulateTrainParameters() {
  Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
                                          lite::SCHEMA_CUR);
  Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateBiasGradParameter, lite::SCHEMA_CUR);
  Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
                                                PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
  Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
                                                      PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
                                                      lite::SCHEMA_CUR);
  Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
                                       lite::SCHEMA_CUR);
  Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry Conv2DGradFilterParameterRegistry(schema::PrimitiveType_Conv2DBackpropFilterFusion,
                                             PopulateConvolutionGradFilterParameter, lite::SCHEMA_CUR);
  Registry Conv2DGradInputParameterRegistry(schema::PrimitiveType_Conv2DBackpropInputFusion,
                                            PopulateConvolutionGradInputParameter, lite::SCHEMA_CUR);
  Registry avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter, lite::SCHEMA_CUR);
  Registry maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter, lite::SCHEMA_CUR);
  Registry PowerGradParameterRegistry(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter, lite::SCHEMA_CUR);
  Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
  Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
  Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
  Registry AdamWeightDecayParameterRegistry(schema::PrimitiveType_AdamWeightDecay, lite::DefaultPopulateParameter,
                                            lite::SCHEMA_CUR);
  Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
                                      lite::SCHEMA_CUR);
  Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter,
                                               lite::SCHEMA_CUR);
  Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
                                                   PopulateBCEGradParameter, lite::SCHEMA_CUR);
  Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum, lite::DefaultPopulateParameter,
                                               lite::SCHEMA_CUR);
  Registry DropoutParameterRegistry(schema::PrimitiveType_Dropout, PopulateDropoutParameter, lite::SCHEMA_CUR);
  Registry DropGradParameterRegistry(schema::PrimitiveType_DropoutGrad, PopulateDropoutGradParameter, lite::SCHEMA_CUR);
  Registry MaximumGradParameterRegistry(schema::PrimitiveType_MaximumGrad, PopulateArithmeticGradParameter,
                                        lite::SCHEMA_CUR);
  Registry MinimumGradParameterRegistry(schema::PrimitiveType_MinimumGrad, PopulateArithmeticGradParameter,
                                        lite::SCHEMA_CUR);
  Registry SmoothL1LossRegistry(schema::PrimitiveType_SmoothL1Loss, PopulateSmoothL1LossParameter, lite::SCHEMA_CUR);
  Registry SmoothL1LossGradRegistry(schema::PrimitiveType_SmoothL1LossGrad, PopulateSmoothL1LossGradParameter,
                                    lite::SCHEMA_CUR);
  Registry SigmoidCrossEntropyWithLogitsRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
                                                 lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
                                                     lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, lite::DefaultPopulateParameter,
                                        lite::SCHEMA_CUR);
  Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, PopulateStridedSliceGradParameter,
                                             lite::SCHEMA_CUR);
  Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
                                      lite::SCHEMA_CUR);
  Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
  Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
  Registry LSTMGradDataParameterRegistry(schema::PrimitiveType_LSTMGradData, PopulateLstmGradDataParameter,
                                         lite::SCHEMA_CUR);
  Registry LSTMGradWeightParameterRegistry(schema::PrimitiveType_LSTMGradWeight, PopulateLstmGradWeightParameter,
                                           lite::SCHEMA_CUR);
}
} // namespace kernel
} // namespace mindspore