• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/train/train_populate_parameter.h"
17 #include "src/common/ops/populate/populate_register.h"
18 #include "src/common/ops/populate/default_populate.h"
19 #include "nnacl/strided_slice_parameter.h"
20 #include "nnacl/arithmetic_parameter.h"
21 #include "nnacl/conv_parameter.h"
22 #include "nnacl/pooling_parameter.h"
23 #include "nnacl/pow_parameter.h"
24 #include "nnacl/activation_parameter.h"
25 #include "nnacl/fp32_grad/softmax_crossentropy_parameter.h"
26 #include "nnacl/fp32_grad/optimizer.h"
27 #include "nnacl/fp32_grad/batch_norm_parameter.h"
28 #include "nnacl/fp32_grad/dropout_parameter.h"
29 #include "nnacl/fp32_grad/smooth_l1_loss.h"
30 #include "nnacl/fp32_grad/resize_grad_parameter.h"
31 #include "nnacl/fp32_grad/lstm_grad_fp32.h"
32 #include "nnacl/fp32_grad/binary_cross_entropy.h"
33 #include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
34 
35 using mindspore::lite::Registry;
36 
37 namespace mindspore {
38 namespace kernel {
namespace {
// Small positional constants used when indexing the stride / pad_list vectors
// in the convolution-gradient populators. Each is one past the previous.
constexpr int kInputIndexOne = 1;
constexpr int kInputIndexTwo = kInputIndexOne + 1;
constexpr int kInputIndexThree = kInputIndexTwo + 1;
}  // namespace
PopulateSmoothL1LossParameter(const void * prim)44 OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
45   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
46   if (p == nullptr) {
47     MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
48     return nullptr;
49   }
50   memset(p, 0, sizeof(SmoothL1LossParameter));
51   auto primitive = static_cast<const schema::Primitive *>(prim);
52   auto value = primitive->value_as_SmoothL1Loss();
53   MS_ASSERT(value != nullptr);
54   p->op_parameter_.type_ = primitive->value_type();
55   p->beta_ = value->beta();
56   return reinterpret_cast<OpParameter *>(p);
57 }
58 
PopulateSmoothL1LossGradParameter(const void * prim)59 OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
60   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
61   if (p == nullptr) {
62     MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
63     return nullptr;
64   }
65   memset(p, 0, sizeof(SmoothL1LossParameter));
66   auto primitive = static_cast<const schema::Primitive *>(prim);
67   auto value = primitive->value_as_SmoothL1LossGrad();
68   MS_ASSERT(value != nullptr);
69   p->op_parameter_.type_ = primitive->value_type();
70   p->beta_ = value->beta();
71   return reinterpret_cast<OpParameter *>(p);
72 }
73 
PopulateApplyMomentumParameter(const void * prim)74 OpParameter *PopulateApplyMomentumParameter(const void *prim) {
75   ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter)));
76   if (p == nullptr) {
77     MS_LOG(ERROR) << "malloc ApplyMomentumParameter failed.";
78     return nullptr;
79   }
80   memset(p, 0, sizeof(ApplyMomentumParameter));
81   auto primitive = static_cast<const schema::Primitive *>(prim);
82   auto value = primitive->value_as_ApplyMomentum();
83   MS_ASSERT(value != nullptr);
84   p->op_parameter_.type_ = primitive->value_type();
85   p->grad_scale_ = value->gradient_scale();
86   p->use_nesterov_ = value->use_nesterov();
87   return reinterpret_cast<OpParameter *>(p);
88 }
89 
PopulateBCEParameter(const void * prim)90 OpParameter *PopulateBCEParameter(const void *prim) {
91   auto primitive = static_cast<const schema::Primitive *>(prim);
92   MS_ASSERT(primitive != nullptr);
93   auto value = primitive->value_as_BinaryCrossEntropy();
94   if (value == nullptr) {
95     MS_LOG(ERROR) << "value is nullptr";
96     return nullptr;
97   }
98 
99   auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
100   if (param == nullptr) {
101     MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
102     return nullptr;
103   }
104   memset(param, 0, sizeof(BinaryCrossEntropyParameter));
105 
106   param->op_parameter_.type_ = primitive->value_type();
107   param->reduction = value->reduction();
108   return reinterpret_cast<OpParameter *>(param);
109 }
110 
PopulateBCEGradParameter(const void * prim)111 OpParameter *PopulateBCEGradParameter(const void *prim) {
112   auto *primitive = static_cast<const schema::Primitive *>(prim);
113   MS_ASSERT(primitive != nullptr);
114   auto value = primitive->value_as_BinaryCrossEntropyGrad();
115   if (value == nullptr) {
116     MS_LOG(ERROR) << "param is nullptr";
117     return nullptr;
118   }
119 
120   auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
121   if (param == nullptr) {
122     MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
123     return nullptr;
124   }
125   memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
126 
127   param->op_parameter_.type_ = primitive->value_type();
128   param->reduction = value->reduction();
129   return reinterpret_cast<OpParameter *>(param);
130 }
131 
PopulateAdamParameter(const void * prim)132 OpParameter *PopulateAdamParameter(const void *prim) {
133   AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
134   if (p == nullptr) {
135     MS_LOG(ERROR) << "new AdamParameter failed.";
136     return nullptr;
137   }
138   memset(p, 0, sizeof(AdamParameter));
139   auto primitive = static_cast<const schema::Primitive *>(prim);
140   auto value = primitive->value_as_Adam();
141   MS_ASSERT(value != nullptr);
142   p->op_parameter_.type_ = primitive->value_type();
143   p->use_nesterov_ = value->use_nesterov();
144   return reinterpret_cast<OpParameter *>(p);
145 }
146 
PopulateSgdParameter(const void * prim)147 OpParameter *PopulateSgdParameter(const void *prim) {
148   SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter)));
149   if (p == nullptr) {
150     MS_LOG(ERROR) << "malloc SgdParameter failed.";
151     return nullptr;
152   }
153   memset(p, 0, sizeof(SgdParameter));
154   auto primitive = static_cast<const schema::Primitive *>(prim);
155   auto value = primitive->value_as_SGD();
156   MS_ASSERT(value != nullptr);
157   p->op_parameter_.type_ = primitive->value_type();
158   p->weight_decay_ = value->weight_decay();
159   p->dampening_ = value->dampening();
160   p->use_nesterov_ = value->nesterov();
161 
162   return reinterpret_cast<OpParameter *>(p);
163 }
164 
PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void * prim)165 OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
166   SoftmaxCrossEntropyParameter *sce_param =
167     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
168   if (sce_param == nullptr) {
169     MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
170     return nullptr;
171   }
172   memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
173   auto primitive = static_cast<const schema::Primitive *>(prim);
174   auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
175   MS_ASSERT(value != nullptr);
176   sce_param->op_parameter_.type_ = primitive->value_type();
177   sce_param->is_grad_ = value->is_grad();
178   return reinterpret_cast<OpParameter *>(sce_param);
179 }
180 
PopulateSoftmaxCrossEntropyParameter(const void * prim)181 OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
182   SoftmaxCrossEntropyParameter *sce_param =
183     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
184   if (sce_param == nullptr) {
185     MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
186     return nullptr;
187   }
188   memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
189   auto primitive = static_cast<const schema::Primitive *>(prim);
190   sce_param->op_parameter_.type_ = primitive->value_type();
191   sce_param->is_grad_ = false;
192   return reinterpret_cast<OpParameter *>(sce_param);
193 }
194 
PopulateMaxPoolGradParameter(const void * prim)195 OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
196   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
197   if (pooling_param == nullptr) {
198     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
199     return nullptr;
200   }
201   memset(pooling_param, 0, sizeof(PoolingParameter));
202   auto primitive = static_cast<const schema::Primitive *>(prim);
203   auto value = primitive->value_as_MaxPoolGrad();
204   MS_ASSERT(value != nullptr);
205   pooling_param->op_parameter_.type_ = primitive->value_type();
206 
207   pooling_param->global_ = false;
208   pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
209   pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
210 
211   pooling_param->pad_u_ = 0;
212   pooling_param->pad_d_ = 0;
213   pooling_param->pad_l_ = 0;
214   pooling_param->pad_r_ = 0;
215   pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
216   pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
217   pooling_param->round_type_ = RoundType_No;
218   pooling_param->pool_mode_ = PoolMode_MaxPool;
219   switch (value->pad_mode()) {
220     case schema::PadMode_SAME:
221       pooling_param->pad_mode_ = Pad_same;
222       break;
223     case schema::PadMode_VALID:
224       pooling_param->pad_mode_ = Pad_valid;
225       break;
226     default:
227       pooling_param->pad_mode_ = Pad_pad;
228       break;
229   }
230 
231   return reinterpret_cast<OpParameter *>(pooling_param);
232 }
233 
PopulateAvgPoolGradParameter(const void * prim)234 OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
235   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
236   if (pooling_param == nullptr) {
237     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
238     return nullptr;
239   }
240   memset(pooling_param, 0, sizeof(PoolingParameter));
241   auto primitive = static_cast<const schema::Primitive *>(prim);
242   auto value = primitive->value_as_AvgPoolGrad();
243   MS_ASSERT(value != nullptr);
244   pooling_param->op_parameter_.type_ = primitive->value_type();
245 
246   pooling_param->global_ = false;
247   pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
248   pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
249 
250   pooling_param->pad_u_ = 0;
251   pooling_param->pad_d_ = 0;
252   pooling_param->pad_l_ = 0;
253   pooling_param->pad_r_ = 0;
254   pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
255   pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
256 
257   switch (value->pad_mode()) {
258     case schema::PadMode_SAME:
259       pooling_param->pad_mode_ = Pad_same;
260       break;
261     case schema::PadMode_VALID:
262       pooling_param->pad_mode_ = Pad_valid;
263       break;
264     default:
265       pooling_param->pad_mode_ = Pad_pad;
266       break;
267   }
268   pooling_param->round_type_ = RoundType_No;
269   pooling_param->pool_mode_ = PoolMode_AvgPool;
270   return reinterpret_cast<OpParameter *>(pooling_param);
271 }
272 
PopulateActivationGradParameter(const void * prim)273 OpParameter *PopulateActivationGradParameter(const void *prim) {
274   ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
275   if (act_param == nullptr) {
276     MS_LOG(ERROR) << "malloc ActivationParameter failed.";
277     return nullptr;
278   }
279   memset(act_param, 0, sizeof(ActivationParameter));
280   auto primitive = static_cast<const schema::Primitive *>(prim);
281   auto value = primitive->value_as_ActivationGrad();
282   MS_ASSERT(value != nullptr);
283   act_param->op_parameter_.type_ = primitive->value_type();
284   act_param->type_ = static_cast<int>(value->activation_type());
285   act_param->alpha_ = value->alpha();
286   return reinterpret_cast<OpParameter *>(act_param);
287 }
288 
PopulateConvolutionGradFilterParameter(const void * prim)289 OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
290   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
291   if (param == nullptr) {
292     MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
293     return nullptr;
294   }
295   memset(param, 0, sizeof(ConvParameter));
296   auto primitive = static_cast<const schema::Primitive *>(prim);
297   auto value = primitive->value_as_Conv2DBackpropFilterFusion();
298   MS_ASSERT(value != nullptr);
299   param->op_parameter_.type_ = primitive->value_type();
300 
301   param->kernel_h_ = value->kernel_size()->Get(0);
302   param->kernel_w_ = value->kernel_size()->Get(1);
303   param->stride_h_ = value->stride()->Get((value->stride()->size()) - kInputIndexTwo);
304   param->stride_w_ = value->stride()->Get((value->stride()->size()) - kInputIndexOne);
305   param->dilation_h_ = value->dilation()->Get(0);
306   param->dilation_w_ = value->dilation()->Get(1);
307   param->pad_u_ = value->pad_list()->Get(0);
308   param->pad_d_ = value->pad_list()->Get(1);
309   param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
310   param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
311   param->group_ = value->group();
312   param->act_type_ = ActType_No;
313   switch (value->activation_type()) {
314     case schema::ActivationType_RELU:
315       param->act_type_ = ActType_Relu;
316       break;
317     case schema::ActivationType_RELU6:
318       param->act_type_ = ActType_Relu6;
319       break;
320     default:
321       break;
322   }
323   switch (value->pad_mode()) {
324     case schema::PadMode_SAME:
325       param->pad_mode_ = Pad_same;
326       break;
327     case schema::PadMode_VALID:
328       param->pad_mode_ = Pad_valid;
329       break;
330     default:
331       param->pad_mode_ = Pad_pad;
332       break;
333   }
334 
335   return reinterpret_cast<OpParameter *>(param);
336 }
337 
PopulateConvolutionGradInputParameter(const void * prim)338 OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
339   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
340   if (param == nullptr) {
341     MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
342     return nullptr;
343   }
344   memset(param, 0, sizeof(ConvParameter));
345   auto primitive = static_cast<const schema::Primitive *>(prim);
346   auto value = primitive->value_as_Conv2DBackpropInputFusion();
347   MS_ASSERT(value != nullptr);
348   param->op_parameter_.type_ = primitive->value_type();
349 
350   param->kernel_h_ = value->kernel_size()->Get(0);
351   param->kernel_w_ = value->kernel_size()->Get(1);
352   param->stride_h_ = value->stride()->Get((value->stride()->size()) - kInputIndexTwo);
353   param->stride_w_ = value->stride()->Get((value->stride()->size()) - kInputIndexOne);
354   param->dilation_h_ = value->dilation()->Get(0);
355   param->dilation_w_ = value->dilation()->Get(1);
356   param->pad_u_ = value->pad_list()->Get(0);
357   param->pad_d_ = value->pad_list()->Get(1);
358   param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
359   param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
360   param->group_ = value->group();
361   param->act_type_ = ActType_No;
362   switch (value->activation_type()) {
363     case schema::ActivationType_RELU:
364       param->act_type_ = ActType_Relu;
365       break;
366     case schema::ActivationType_RELU6:
367       param->act_type_ = ActType_Relu6;
368       break;
369     default:
370       break;
371   }
372   switch (value->pad_mode()) {
373     case schema::PadMode_SAME:
374       param->pad_mode_ = Pad_same;
375       break;
376     case schema::PadMode_VALID:
377       param->pad_mode_ = Pad_valid;
378       break;
379     default:
380       param->pad_mode_ = Pad_pad;
381       break;
382   }
383 
384   return reinterpret_cast<OpParameter *>(param);
385 }
386 
PopulatePowerGradParameter(const void * prim)387 OpParameter *PopulatePowerGradParameter(const void *prim) {
388   PowParameter *power_param = reinterpret_cast<PowParameter *>(malloc(sizeof(PowParameter)));
389   if (power_param == nullptr) {
390     MS_LOG(ERROR) << "malloc PowParameter failed.";
391     return nullptr;
392   }
393   memset(power_param, 0, sizeof(PowParameter));
394   auto primitive = static_cast<const schema::Primitive *>(prim);
395   auto value = primitive->value_as_PowerGrad();
396   MS_ASSERT(value != nullptr);
397   power_param->op_parameter_.type_ = primitive->value_type();
398   power_param->power_ = value->power();
399   power_param->scale_ = value->scale();
400   power_param->shift_ = value->shift();
401   return reinterpret_cast<OpParameter *>(power_param);
402 }
403 
PopulateBiasGradParameter(const void * prim)404 OpParameter *PopulateBiasGradParameter(const void *prim) {
405   ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
406   if (arithmetic_param == nullptr) {
407     MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
408     return nullptr;
409   }
410   memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
411   auto primitive = static_cast<const schema::Primitive *>(prim);
412   arithmetic_param->op_parameter_.type_ = primitive->value_type();
413   return reinterpret_cast<OpParameter *>(arithmetic_param);
414 }
415 
PopulateBNGradParameter(const void * prim)416 OpParameter *PopulateBNGradParameter(const void *prim) {
417   BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
418   if (bnGrad_param == nullptr) {
419     MS_LOG(ERROR) << "malloc BNGradParameter failed.";
420     return nullptr;
421   }
422   memset(bnGrad_param, 0, sizeof(BNGradParameter));
423   auto primitive = static_cast<const schema::Primitive *>(prim);
424   auto value = primitive->value_as_BatchNormGrad();
425   MS_ASSERT(value != nullptr);
426   bnGrad_param->op_parameter_.type_ = primitive->value_type();
427   bnGrad_param->epsilon_ = value->epsilon();
428   bnGrad_param->is_training_ = value->is_training();
429   return reinterpret_cast<OpParameter *>(bnGrad_param);
430 }
431 
PopulateDropoutParameter(const void * prim)432 OpParameter *PopulateDropoutParameter(const void *prim) {
433   DropoutParameter *dropout_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
434   if (dropout_parameter == nullptr) {
435     MS_LOG(ERROR) << "malloc Dropout Parameter failed.";
436     return nullptr;
437   }
438   memset(dropout_parameter, 0, sizeof(DropoutParameter));
439   auto primitive = static_cast<const schema::Primitive *>(prim);
440   auto value = primitive->value_as_Dropout();
441   MS_ASSERT(value != nullptr);
442   dropout_parameter->op_parameter_.type_ = primitive->value_type();
443   dropout_parameter->ratio_ = value->keep_prob();
444   if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
445     MS_LOG(ERROR) << "Dropout ratio must be between 0 to 1, got " << dropout_parameter->ratio_;
446     free(dropout_parameter);
447     return nullptr;
448   }
449   return reinterpret_cast<OpParameter *>(dropout_parameter);
450 }
451 
PopulateDropoutGradParameter(const void * prim)452 OpParameter *PopulateDropoutGradParameter(const void *prim) {
453   DropoutParameter *dropoutgrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
454   if (dropoutgrad_parameter == nullptr) {
455     MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed.";
456     return nullptr;
457   }
458   memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
459   auto primitive = static_cast<const schema::Primitive *>(prim);
460   auto value = primitive->value_as_DropoutGrad();
461   MS_ASSERT(value != nullptr);
462   dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
463   dropoutgrad_parameter->ratio_ = value->keep_prob();
464   if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
465     MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutgrad_parameter->ratio_;
466     free(dropoutgrad_parameter);
467     return nullptr;
468   }
469   return reinterpret_cast<OpParameter *>(dropoutgrad_parameter);
470 }
471 
PopulateArithmeticGradParameter(const void * prim)472 OpParameter *PopulateArithmeticGradParameter(const void *prim) {
473   ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
474   if (arithmetic_param == nullptr) {
475     MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
476     return nullptr;
477   }
478   memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
479   auto primitive = static_cast<const schema::Primitive *>(prim);
480   arithmetic_param->op_parameter_.type_ = primitive->value_type();
481   return reinterpret_cast<OpParameter *>(arithmetic_param);
482 }
483 
PopulateResizeGradParameter(const void * prim)484 OpParameter *PopulateResizeGradParameter(const void *prim) {
485   ResizeGradParameter *resize_grad_param = reinterpret_cast<ResizeGradParameter *>(malloc(sizeof(ResizeGradParameter)));
486   if (resize_grad_param == nullptr) {
487     MS_LOG(ERROR) << "malloc resize grad parameter failed.";
488     return nullptr;
489   }
490   memset(resize_grad_param, 0, sizeof(ResizeGradParameter));
491   auto primitive = static_cast<const schema::Primitive *>(prim);
492   resize_grad_param->op_parameter_.type_ = primitive->value_type();
493   auto param = primitive->value_as_ResizeGrad();
494   MS_ASSERT(param != nullptr);
495   resize_grad_param->method = static_cast<int>(param->method());
496   resize_grad_param->align_corners_ = param->align_corners();
497 
498   return reinterpret_cast<OpParameter *>(resize_grad_param);
499 }
500 
PopulateStridedSliceGradParameter(const void * prim)501 OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
502   StridedSliceParameter *strided_slice_param =
503     reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
504   if (strided_slice_param == nullptr) {
505     MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
506     return nullptr;
507   }
508   memset(strided_slice_param, 0, sizeof(StridedSliceParameter));
509 
510   auto primitive = static_cast<const schema::Primitive *>(prim);
511   auto value = primitive->value_as_StridedSliceGrad();
512   MS_ASSERT(value != nullptr);
513   strided_slice_param->op_parameter_.type_ = primitive->value_type();
514 
515   strided_slice_param->begins_mask_ = value->begin_mask();
516   strided_slice_param->ends_mask_ = value->end_mask();
517   strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
518   strided_slice_param->newAxisMask_ = value->new_axis_mask();
519   strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
520   return reinterpret_cast<OpParameter *>(strided_slice_param);
521 }
522 
PopulateLstmGradParameter(const void * prim)523 OpParameter *PopulateLstmGradParameter(const void *prim) {
524   auto primitive = static_cast<const schema::Primitive *>(prim);
525   MS_ASSERT(primitive != nullptr);
526   auto value = primitive->value_as_LSTMGrad();
527   if (value == nullptr) {
528     MS_LOG(ERROR) << "value is nullptr.";
529     return nullptr;
530   }
531 
532   auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
533   if (param == nullptr) {
534     MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
535     return nullptr;
536   }
537   memset(param, 0, sizeof(LstmGradParameter));
538 
539   param->op_parameter_.type_ = primitive->value_type();
540   param->bidirectional_ = value->bidirectional();
541   param->zoneout_cell_ = value->zoneout_cell();
542   param->zoneout_hidden_ = value->zoneout_hidden();
543   param->input_size_ = value->input_size();
544   param->has_bias_ = static_cast<int>(value->has_bias());
545   param->hidden_size_ = value->hidden_size();
546 
547   return reinterpret_cast<OpParameter *>(param);
548 }
549 
PopulateLstmGradDataParameter(const void * prim)550 OpParameter *PopulateLstmGradDataParameter(const void *prim) {
551   auto primitive = static_cast<const schema::Primitive *>(prim);
552   MS_ASSERT(primitive != nullptr);
553   auto value = primitive->value_as_LSTMGradData();
554   if (value == nullptr) {
555     MS_LOG(ERROR) << "value is nullptr.";
556     return nullptr;
557   }
558 
559   auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
560   if (param == nullptr) {
561     MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
562     return nullptr;
563   }
564   memset(param, 0, sizeof(LstmGradParameter));
565 
566   param->op_parameter_.type_ = primitive->value_type();
567   param->bidirectional_ = value->bidirectional();
568   param->zoneout_cell_ = value->zoneout_cell();
569   param->zoneout_hidden_ = value->zoneout_hidden();
570   param->input_size_ = value->input_size();
571   param->has_bias_ = static_cast<int>(value->has_bias());
572   param->hidden_size_ = value->hidden_size();
573   return reinterpret_cast<OpParameter *>(param);
574 }
575 
PopulateLstmGradWeightParameter(const void * prim)576 OpParameter *PopulateLstmGradWeightParameter(const void *prim) {
577   auto primitive = static_cast<const schema::Primitive *>(prim);
578   MS_ASSERT(primitive != nullptr);
579   auto value = primitive->value_as_LSTMGradWeight();
580   if (value == nullptr) {
581     MS_LOG(ERROR) << "value is nullptr.";
582     return nullptr;
583   }
584 
585   auto *param = reinterpret_cast<LstmGradParameter *>(malloc(sizeof(LstmGradParameter)));
586   if (param == nullptr) {
587     MS_LOG(ERROR) << "malloc LstmGradParameter failed.";
588     return nullptr;
589   }
590   memset(param, 0, sizeof(LstmGradParameter));
591 
592   param->op_parameter_.type_ = primitive->value_type();
593   param->input_size_ = value->input_size();
594   param->hidden_size_ = value->hidden_size();
595   param->bidirectional_ = value->bidirectional();
596   param->zoneout_cell_ = value->zoneout_cell();
597   param->zoneout_hidden_ = value->zoneout_hidden();
598   param->has_bias_ = static_cast<int>(value->has_bias());
599   return reinterpret_cast<OpParameter *>(param);
600 }
601 
PopulateTrainParameters()602 void PopulateTrainParameters() {
603   Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
604                                           lite::SCHEMA_CUR);
605   Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateBiasGradParameter, lite::SCHEMA_CUR);
606   Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
607                                                 PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
608   Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
609                                                       PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
610                                                       lite::SCHEMA_CUR);
611   Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
612                                        lite::SCHEMA_CUR);
613   Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
614   Registry Conv2DGradFilterParameterRegistry(schema::PrimitiveType_Conv2DBackpropFilterFusion,
615                                              PopulateConvolutionGradFilterParameter, lite::SCHEMA_CUR);
616   Registry Conv2DGradInputParameterRegistry(schema::PrimitiveType_Conv2DBackpropInputFusion,
617                                             PopulateConvolutionGradInputParameter, lite::SCHEMA_CUR);
618   Registry avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter, lite::SCHEMA_CUR);
619   Registry maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter, lite::SCHEMA_CUR);
620   Registry PowerGradParameterRegistry(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter, lite::SCHEMA_CUR);
621   Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
622   Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
623   Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
624   Registry AdamWeightDecayParameterRegistry(schema::PrimitiveType_AdamWeightDecay, lite::DefaultPopulateParameter,
625                                             lite::SCHEMA_CUR);
626   Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
627   Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
628                                       lite::SCHEMA_CUR);
629   Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter,
630                                                lite::SCHEMA_CUR);
631   Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
632                                                    PopulateBCEGradParameter, lite::SCHEMA_CUR);
633   Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
634   Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum, lite::DefaultPopulateParameter,
635                                                lite::SCHEMA_CUR);
636   Registry DropoutParameterRegistry(schema::PrimitiveType_Dropout, PopulateDropoutParameter, lite::SCHEMA_CUR);
637   Registry DropGradParameterRegistry(schema::PrimitiveType_DropoutGrad, PopulateDropoutGradParameter, lite::SCHEMA_CUR);
638   Registry MaximumGradParameterRegistry(schema::PrimitiveType_MaximumGrad, PopulateArithmeticGradParameter,
639                                         lite::SCHEMA_CUR);
640   Registry MinimumGradParameterRegistry(schema::PrimitiveType_MinimumGrad, PopulateArithmeticGradParameter,
641                                         lite::SCHEMA_CUR);
642   Registry SmoothL1LossRegistry(schema::PrimitiveType_SmoothL1Loss, PopulateSmoothL1LossParameter, lite::SCHEMA_CUR);
643   Registry SmoothL1LossGradRegistry(schema::PrimitiveType_SmoothL1LossGrad, PopulateSmoothL1LossGradParameter,
644                                     lite::SCHEMA_CUR);
645   Registry SigmoidCrossEntropyWithLogitsRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
646                                                  lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
647   Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
648                                                      lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
649   Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, lite::DefaultPopulateParameter,
650                                         lite::SCHEMA_CUR);
651   Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, PopulateStridedSliceGradParameter,
652                                              lite::SCHEMA_CUR);
653   Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
654   Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
655                                       lite::SCHEMA_CUR);
656   Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
657   Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
658   Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
659   Registry LSTMGradDataParameterRegistry(schema::PrimitiveType_LSTMGradData, PopulateLstmGradDataParameter,
660                                          lite::SCHEMA_CUR);
661   Registry LSTMGradWeightParameterRegistry(schema::PrimitiveType_LSTMGradWeight, PopulateLstmGradWeightParameter,
662                                            lite::SCHEMA_CUR);
663 }
664 }  // namespace kernel
665 }  // namespace mindspore
666