• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/train/train_populate_parameter.h"
17 #include <algorithm>
18 #include "src/ops/populate/populate_register.h"
19 #include "src/ops/populate/default_populate.h"
20 #include "src/ops/populate/strided_slice_populate.h"
21 #include "nnacl/arithmetic.h"
22 #include "nnacl/conv_parameter.h"
23 #include "nnacl/lstm_parameter.h"
24 #include "nnacl/pooling_parameter.h"
25 #include "nnacl/power_parameter.h"
26 #include "nnacl/fp32/activation_fp32.h"
27 #include "nnacl/fp32_grad/softmax_grad.h"
28 #include "nnacl/fp32_grad/optimizer.h"
29 #include "nnacl/fp32_grad/batch_norm.h"
30 #include "nnacl/fp32_grad/dropout_parameter.h"
31 #include "nnacl/fp32_grad/smooth_l1_loss.h"
32 #include "nnacl/fp32_grad/resize_grad.h"
33 
34 using mindspore::lite::Registry;
35 
36 namespace mindspore {
37 namespace kernel {
namespace {
// Element positions inside a 4-entry pad_list attribute. The convolution-gradient populate
// functions read pad_u_/pad_d_ from positions 0/1 and pad_l_/pad_r_ from these two.
constexpr int kInputIndexTwo{2};
constexpr int kInputIndexThree{3};
}  // namespace
PopulateSmoothL1LossParameter(const void * prim)42 OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
43   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
44   if (p == nullptr) {
45     MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
46     return nullptr;
47   }
48   memset(p, 0, sizeof(SmoothL1LossParameter));
49   auto primitive = static_cast<const schema::Primitive *>(prim);
50   auto value = primitive->value_as_SmoothL1Loss();
51   MS_ASSERT(value != nullptr);
52   p->op_parameter_.type_ = primitive->value_type();
53   p->beta_ = value->beta();
54   return reinterpret_cast<OpParameter *>(p);
55 }
56 
PopulateSmoothL1LossGradParameter(const void * prim)57 OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
58   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
59   if (p == nullptr) {
60     MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
61     return nullptr;
62   }
63   memset(p, 0, sizeof(SmoothL1LossParameter));
64   auto primitive = static_cast<const schema::Primitive *>(prim);
65   auto value = primitive->value_as_SmoothL1LossGrad();
66   MS_ASSERT(value != nullptr);
67   p->op_parameter_.type_ = primitive->value_type();
68   p->beta_ = value->beta();
69   return reinterpret_cast<OpParameter *>(p);
70 }
71 
PopulateApplyMomentumParameter(const void * prim)72 OpParameter *PopulateApplyMomentumParameter(const void *prim) {
73   ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter)));
74   if (p == nullptr) {
75     MS_LOG(ERROR) << "malloc ApplyMomentumParameter failed.";
76     return nullptr;
77   }
78   memset(p, 0, sizeof(ApplyMomentumParameter));
79   auto primitive = static_cast<const schema::Primitive *>(prim);
80   auto value = primitive->value_as_ApplyMomentum();
81   p->op_parameter_.type_ = primitive->value_type();
82   p->grad_scale_ = value->gradient_scale();
83   p->use_nesterov_ = value->use_nesterov();
84   return reinterpret_cast<OpParameter *>(p);
85 }
86 
PopulateBCEParameter(const void * prim)87 OpParameter *PopulateBCEParameter(const void *prim) {
88   int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
89   if (reduction == nullptr) {
90     MS_LOG(ERROR) << "malloc reduction failed.";
91     return nullptr;
92   }
93   auto primitive = static_cast<const schema::Primitive *>(prim);
94   auto value = primitive->value_as_BinaryCrossEntropy();
95   MS_ASSERT(value != nullptr);
96   *reduction = value->reduction();
97   return reinterpret_cast<OpParameter *>(reduction);
98 }
99 
PopulateBCEGradParameter(const void * prim)100 OpParameter *PopulateBCEGradParameter(const void *prim) {
101   int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
102   if (reduction == nullptr) {
103     MS_LOG(ERROR) << "malloc reduction failed.";
104     return nullptr;
105   }
106   auto primitive = static_cast<const schema::Primitive *>(prim);
107   auto value = primitive->value_as_BinaryCrossEntropyGrad();
108   MS_ASSERT(value != nullptr);
109   *reduction = value->reduction();
110   return reinterpret_cast<OpParameter *>(reduction);
111 }
112 
PopulateAdamParameter(const void * prim)113 OpParameter *PopulateAdamParameter(const void *prim) {
114   AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
115   if (p == nullptr) {
116     MS_LOG(ERROR) << "new AdamParameter failed.";
117     return nullptr;
118   }
119   memset(p, 0, sizeof(AdamParameter));
120   auto primitive = static_cast<const schema::Primitive *>(prim);
121   auto value = primitive->value_as_Adam();
122   MS_ASSERT(value != nullptr);
123   p->op_parameter_.type_ = primitive->value_type();
124   p->use_nesterov_ = value->use_nesterov();
125   return reinterpret_cast<OpParameter *>(p);
126 }
127 
PopulateSgdParameter(const void * prim)128 OpParameter *PopulateSgdParameter(const void *prim) {
129   SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter)));
130   if (p == nullptr) {
131     MS_LOG(ERROR) << "malloc SgdParameter failed.";
132     return nullptr;
133   }
134   memset(p, 0, sizeof(SgdParameter));
135   auto primitive = static_cast<const schema::Primitive *>(prim);
136   auto value = primitive->value_as_SGD();
137   MS_ASSERT(value != nullptr);
138   p->op_parameter_.type_ = primitive->value_type();
139   p->weight_decay_ = value->weight_decay();
140   p->dampening_ = value->dampening();
141   p->use_nesterov_ = value->nesterov();
142 
143   return reinterpret_cast<OpParameter *>(p);
144 }
145 
PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void * prim)146 OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
147   SoftmaxCrossEntropyParameter *sce_param =
148     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
149   if (sce_param == nullptr) {
150     MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
151     return nullptr;
152   }
153   memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
154   auto primitive = static_cast<const schema::Primitive *>(prim);
155   auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
156   MS_ASSERT(value != nullptr);
157   sce_param->op_parameter_.type_ = primitive->value_type();
158   sce_param->is_grad_ = value->is_grad();
159   return reinterpret_cast<OpParameter *>(sce_param);
160 }
161 
PopulateSoftmaxCrossEntropyParameter(const void * prim)162 OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
163   SoftmaxCrossEntropyParameter *sce_param =
164     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
165   if (sce_param == nullptr) {
166     MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
167     return nullptr;
168   }
169   memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
170   auto primitive = static_cast<const schema::Primitive *>(prim);
171   sce_param->op_parameter_.type_ = primitive->value_type();
172   sce_param->is_grad_ = 0;
173   return reinterpret_cast<OpParameter *>(sce_param);
174 }
175 
PopulateMaxPoolGradParameter(const void * prim)176 OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
177   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
178   if (pooling_param == nullptr) {
179     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
180     return nullptr;
181   }
182   memset(pooling_param, 0, sizeof(PoolingParameter));
183   auto primitive = static_cast<const schema::Primitive *>(prim);
184   auto value = primitive->value_as_MaxPoolGrad();
185   MS_ASSERT(value != nullptr);
186   pooling_param->op_parameter_.type_ = primitive->value_type();
187 
188   pooling_param->global_ = false;
189   pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
190   pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
191 
192   pooling_param->pad_u_ = 0;
193   pooling_param->pad_d_ = 0;
194   pooling_param->pad_l_ = 0;
195   pooling_param->pad_r_ = 0;
196   pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
197   pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
198   pooling_param->round_mode_ = RoundMode_No;
199   pooling_param->pool_mode_ = PoolMode_MaxPool;
200   switch (value->pad_mode()) {
201     case schema::PadMode_SAME:
202       pooling_param->pad_mode_ = Pad_same;
203       break;
204     case schema::PadMode_VALID:
205       pooling_param->pad_mode_ = Pad_valid;
206       break;
207     default:
208       pooling_param->pad_mode_ = Pad_pad;
209       break;
210   }
211 
212   return reinterpret_cast<OpParameter *>(pooling_param);
213 }
214 
PopulateAvgPoolGradParameter(const void * prim)215 OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
216   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
217   if (pooling_param == nullptr) {
218     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
219     return nullptr;
220   }
221   memset(pooling_param, 0, sizeof(PoolingParameter));
222   auto primitive = static_cast<const schema::Primitive *>(prim);
223   auto value = primitive->value_as_AvgPoolGrad();
224   MS_ASSERT(value != nullptr);
225   pooling_param->op_parameter_.type_ = primitive->value_type();
226 
227   pooling_param->global_ = false;
228   pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
229   pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));
230 
231   pooling_param->pad_u_ = 0;
232   pooling_param->pad_d_ = 0;
233   pooling_param->pad_l_ = 0;
234   pooling_param->pad_r_ = 0;
235   pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
236   pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
237 
238   switch (value->pad_mode()) {
239     case schema::PadMode_SAME:
240       pooling_param->pad_mode_ = Pad_same;
241       break;
242     case schema::PadMode_VALID:
243       pooling_param->pad_mode_ = Pad_valid;
244       break;
245     default:
246       pooling_param->pad_mode_ = Pad_pad;
247       break;
248   }
249   pooling_param->round_mode_ = RoundMode_No;
250   pooling_param->pool_mode_ = PoolMode_AvgPool;
251   switch (value->pad_mode()) {
252     case schema::PadMode_SAME:
253       pooling_param->pad_mode_ = Pad_same;
254       break;
255     case schema::PadMode_VALID:
256       pooling_param->pad_mode_ = Pad_valid;
257       break;
258     default:
259       pooling_param->pad_mode_ = Pad_pad;
260       break;
261   }
262   return reinterpret_cast<OpParameter *>(pooling_param);
263 }
264 
PopulateActivationGradParameter(const void * prim)265 OpParameter *PopulateActivationGradParameter(const void *prim) {
266   ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
267   if (act_param == nullptr) {
268     MS_LOG(ERROR) << "malloc ActivationParameter failed.";
269     return nullptr;
270   }
271   memset(act_param, 0, sizeof(ActivationParameter));
272   auto primitive = static_cast<const schema::Primitive *>(prim);
273   auto value = primitive->value_as_ActivationGrad();
274   MS_ASSERT(value != nullptr);
275   act_param->op_parameter_.type_ = primitive->value_type();
276   act_param->type_ = static_cast<int>(value->activation_type());
277   act_param->alpha_ = value->alpha();
278   return reinterpret_cast<OpParameter *>(act_param);
279 }
280 
PopulateConvolutionGradFilterParameter(const void * prim)281 OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
282   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
283   if (param == nullptr) {
284     MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
285     return nullptr;
286   }
287   memset(param, 0, sizeof(ConvParameter));
288   auto primitive = static_cast<const schema::Primitive *>(prim);
289   auto value = primitive->value_as_Conv2DBackpropFilterFusion();
290   MS_ASSERT(value != nullptr);
291   param->op_parameter_.type_ = primitive->value_type();
292 
293   param->kernel_h_ = value->kernel_size()->Get(0);
294   param->kernel_w_ = value->kernel_size()->Get(1);
295   param->stride_h_ = value->stride()->Get(0);
296   param->stride_w_ = value->stride()->Get(1);
297   param->dilation_h_ = value->dilation()->Get(0);
298   param->dilation_w_ = value->dilation()->Get(1);
299   param->pad_u_ = value->pad_list()->Get(0);
300   param->pad_d_ = value->pad_list()->Get(1);
301   param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
302   param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
303   param->group_ = value->group();
304   param->act_type_ = ActType_No;
305   switch (value->activation_type()) {
306     case schema::ActivationType_RELU:
307       param->act_type_ = ActType_Relu;
308       break;
309     case schema::ActivationType_RELU6:
310       param->act_type_ = ActType_Relu6;
311       break;
312     default:
313       break;
314   }
315 
316   return reinterpret_cast<OpParameter *>(param);
317 }
318 
PopulateConvolutionGradInputParameter(const void * prim)319 OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
320   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
321   if (param == nullptr) {
322     MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
323     return nullptr;
324   }
325   memset(param, 0, sizeof(ConvParameter));
326   auto primitive = static_cast<const schema::Primitive *>(prim);
327   auto value = primitive->value_as_Conv2DBackpropInputFusion();
328   MS_ASSERT(value != nullptr);
329   param->op_parameter_.type_ = primitive->value_type();
330 
331   param->kernel_h_ = value->kernel_size()->Get(0);
332   param->kernel_w_ = value->kernel_size()->Get(1);
333   param->stride_h_ = value->stride()->Get(0);
334   param->stride_w_ = value->stride()->Get(1);
335   param->dilation_h_ = value->dilation()->Get(0);
336   param->dilation_w_ = value->dilation()->Get(1);
337   param->pad_u_ = value->pad_list()->Get(0);
338   param->pad_d_ = value->pad_list()->Get(1);
339   param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
340   param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
341   param->group_ = value->group();
342   param->act_type_ = ActType_No;
343   switch (value->activation_type()) {
344     case schema::ActivationType_RELU:
345       param->act_type_ = ActType_Relu;
346       break;
347     case schema::ActivationType_RELU6:
348       param->act_type_ = ActType_Relu6;
349       break;
350     default:
351       break;
352   }
353 
354   return reinterpret_cast<OpParameter *>(param);
355 }
356 
PopulatePowerGradParameter(const void * prim)357 OpParameter *PopulatePowerGradParameter(const void *prim) {
358   PowerParameter *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
359   if (power_param == nullptr) {
360     MS_LOG(ERROR) << "malloc PowerParameter failed.";
361     return nullptr;
362   }
363   memset(power_param, 0, sizeof(PowerParameter));
364   auto primitive = static_cast<const schema::Primitive *>(prim);
365   auto value = primitive->value_as_PowerGrad();
366   MS_ASSERT(value != nullptr);
367   power_param->op_parameter_.type_ = primitive->value_type();
368   power_param->power_ = value->power();
369   power_param->scale_ = value->scale();
370   power_param->shift_ = value->shift();
371   return reinterpret_cast<OpParameter *>(power_param);
372 }
373 
PopulateBiasGradParameter(const void * prim)374 OpParameter *PopulateBiasGradParameter(const void *prim) {
375   ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
376   if (arithmetic_param == nullptr) {
377     MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
378     return nullptr;
379   }
380   memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
381   auto primitive = static_cast<const schema::Primitive *>(prim);
382   arithmetic_param->op_parameter_.type_ = primitive->value_type();
383   return reinterpret_cast<OpParameter *>(arithmetic_param);
384 }
385 
PopulateBNGradParameter(const void * prim)386 OpParameter *PopulateBNGradParameter(const void *prim) {
387   BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
388   if (bnGrad_param == nullptr) {
389     MS_LOG(ERROR) << "malloc BNGradParameter failed.";
390     return nullptr;
391   }
392   memset(bnGrad_param, 0, sizeof(BNGradParameter));
393   auto primitive = static_cast<const schema::Primitive *>(prim);
394   auto value = primitive->value_as_BatchNormGrad();
395   MS_ASSERT(value != nullptr);
396   bnGrad_param->op_parameter_.type_ = primitive->value_type();
397   bnGrad_param->epsilon_ = value->epsilon();
398   return reinterpret_cast<OpParameter *>(bnGrad_param);
399 }
400 
PopulateDropoutParameter(const void * prim)401 OpParameter *PopulateDropoutParameter(const void *prim) {
402   DropoutParameter *dropout_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
403   if (dropout_parameter == nullptr) {
404     MS_LOG(ERROR) << "malloc Dropout Parameter failed.";
405     return nullptr;
406   }
407   memset(dropout_parameter, 0, sizeof(DropoutParameter));
408   auto primitive = static_cast<const schema::Primitive *>(prim);
409   auto value = primitive->value_as_Dropout();
410   MS_ASSERT(value != nullptr);
411   dropout_parameter->op_parameter_.type_ = primitive->value_type();
412   dropout_parameter->ratio_ = value->keep_prob();
413   if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
414     MS_LOG(ERROR) << "Dropout ratio must be between 0 to 1, got " << dropout_parameter->ratio_;
415     free(dropout_parameter);
416     return nullptr;
417   }
418   return reinterpret_cast<OpParameter *>(dropout_parameter);
419 }
420 
PopulateDropoutGradParameter(const void * prim)421 OpParameter *PopulateDropoutGradParameter(const void *prim) {
422   DropoutParameter *dropoutgrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
423   if (dropoutgrad_parameter == nullptr) {
424     MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed.";
425     return nullptr;
426   }
427   memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
428   auto primitive = static_cast<const schema::Primitive *>(prim);
429   auto value = primitive->value_as_DropoutGrad();
430   MS_ASSERT(value != nullptr);
431   dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
432   dropoutgrad_parameter->ratio_ = value->keep_prob();
433   if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
434     MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutgrad_parameter->ratio_;
435     free(dropoutgrad_parameter);
436     return nullptr;
437   }
438   return reinterpret_cast<OpParameter *>(dropoutgrad_parameter);
439 }
440 
PopulateArithmeticGradParameter(const void * prim)441 OpParameter *PopulateArithmeticGradParameter(const void *prim) {
442   ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
443   if (arithmetic_param == nullptr) {
444     MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
445     return nullptr;
446   }
447   memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
448   auto primitive = static_cast<const schema::Primitive *>(prim);
449   arithmetic_param->op_parameter_.type_ = primitive->value_type();
450   return reinterpret_cast<OpParameter *>(arithmetic_param);
451 }
452 
PopulateResizeGradParameter(const void * prim)453 OpParameter *PopulateResizeGradParameter(const void *prim) {
454   ResizeGradParameter *resize_grad_param = reinterpret_cast<ResizeGradParameter *>(malloc(sizeof(ResizeGradParameter)));
455   if (resize_grad_param == nullptr) {
456     MS_LOG(ERROR) << "malloc resize grad parameter failed.";
457     return nullptr;
458   }
459   memset(resize_grad_param, 0, sizeof(ResizeGradParameter));
460   auto primitive = static_cast<const schema::Primitive *>(prim);
461   resize_grad_param->op_parameter_.type_ = primitive->value_type();
462   auto param = primitive->value_as_ResizeGrad();
463   MS_ASSERT(param != nullptr);
464   resize_grad_param->method = static_cast<int>(param->method());
465   resize_grad_param->align_corners_ = param->align_corners();
466 
467   return reinterpret_cast<OpParameter *>(resize_grad_param);
468 }
469 
PopulateStridedSliceGradParameter(const void * prim)470 OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
471   StridedSliceParameter *strided_slice_param =
472     reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
473   if (strided_slice_param == nullptr) {
474     MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
475     return nullptr;
476   }
477   memset(strided_slice_param, 0, sizeof(StridedSliceParameter));
478 
479   auto primitive = static_cast<const schema::Primitive *>(prim);
480   auto value = primitive->value_as_StridedSliceGrad();
481   MS_ASSERT(value != nullptr);
482   strided_slice_param->op_parameter_.type_ = primitive->value_type();
483 
484   strided_slice_param->begins_mask_ = value->begin_mask();
485   strided_slice_param->ends_mask_ = value->end_mask();
486   strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
487   strided_slice_param->newAxisMask_ = value->new_axis_mask();
488   strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
489   return reinterpret_cast<OpParameter *>(strided_slice_param);
490 }
491 
PopulateLstmGradParameter(const void * prim)492 OpParameter *PopulateLstmGradParameter(const void *prim) {
493   auto primitive = static_cast<const schema::Primitive *>(prim);
494   MS_ASSERT(primitive != nullptr);
495   auto value = primitive->value_as_LSTMGrad();
496   if (value == nullptr) {
497     MS_LOG(ERROR) << "value is nullptr.";
498     return nullptr;
499   }
500 
501   auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
502   if (param == nullptr) {
503     MS_LOG(ERROR) << "malloc LstmParameter failed.";
504     return nullptr;
505   }
506   memset(param, 0, sizeof(LstmParameter));
507 
508   param->op_parameter_.type_ = primitive->value_type();
509   param->bidirectional_ = value->bidirectional();
510   param->zoneout_cell_ = value->zoneout_cell();
511   param->zoneout_hidden_ = value->zoneout_hidden();
512   return reinterpret_cast<OpParameter *>(param);
513 }
514 
PopulateTrainParameters()515 void PopulateTrainParameters() {
516   lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
517                                                 lite::SCHEMA_CUR);
518   lite::Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateBiasGradParameter,
519                                            lite::SCHEMA_CUR);
520   lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
521                                                       PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
522   lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
523                                                             PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
524                                                             lite::SCHEMA_CUR);
525   lite::Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
526                                              lite::SCHEMA_CUR);
527   lite::Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter,
528                                          lite::SCHEMA_CUR);
529   lite::Registry Conv2DGradFilterParameterRegistry(schema::PrimitiveType_Conv2DBackpropFilterFusion,
530                                                    PopulateConvolutionGradFilterParameter, lite::SCHEMA_CUR);
531   lite::Registry Conv2DGradInputParameterRegistry(schema::PrimitiveType_Conv2DBackpropInputFusion,
532                                                   PopulateConvolutionGradInputParameter, lite::SCHEMA_CUR);
533   lite::Registry avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter,
534                                           lite::SCHEMA_CUR);
535   lite::Registry maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter,
536                                           lite::SCHEMA_CUR);
537   lite::Registry PowerGradParameterRegistry(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter,
538                                             lite::SCHEMA_CUR);
539   lite::Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
540   lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter,
541                                          lite::SCHEMA_CUR);
542   lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
543   lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter,
544                                          lite::SCHEMA_CUR);
545   lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
546                                             lite::SCHEMA_CUR);
547   lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter,
548                                                      lite::SCHEMA_CUR);
549   lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
550                                                          PopulateBCEGradParameter, lite::SCHEMA_CUR);
551   lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, lite::DefaultPopulateParameter,
552                                            lite::SCHEMA_CUR);
553   lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum,
554                                                      lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
555   lite::Registry DropoutParameterRegistry(schema::PrimitiveType_Dropout, PopulateDropoutParameter, lite::SCHEMA_CUR);
556   lite::Registry DropGradParameterRegistry(schema::PrimitiveType_DropoutGrad, PopulateDropoutGradParameter,
557                                            lite::SCHEMA_CUR);
558   lite::Registry MaximumGradParameterRegistry(schema::PrimitiveType_MaximumGrad, PopulateArithmeticGradParameter,
559                                               lite::SCHEMA_CUR);
560   lite::Registry MinimumGradParameterRegistry(schema::PrimitiveType_MinimumGrad, PopulateArithmeticGradParameter,
561                                               lite::SCHEMA_CUR);
562   lite::Registry SmoothL1LossRegistry(schema::PrimitiveType_SmoothL1Loss, PopulateSmoothL1LossParameter,
563                                       lite::SCHEMA_CUR);
564   lite::Registry SmoothL1LossGradRegistry(schema::PrimitiveType_SmoothL1LossGrad, PopulateSmoothL1LossGradParameter,
565                                           lite::SCHEMA_CUR);
566   lite::Registry SigmoidCrossEntropyWithLogitsRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
567                                                        lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
568   lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
569                                                            lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
570   lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, lite::DefaultPopulateParameter,
571                                               lite::SCHEMA_CUR);
572   lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad,
573                                                    PopulateStridedSliceGradParameter, lite::SCHEMA_CUR);
574   lite::Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter,
575                                            lite::SCHEMA_CUR);
576   lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
577                                             lite::SCHEMA_CUR);
578   Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
579   Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
580   Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
581 }
582 }  // namespace kernel
583 }  // namespace mindspore
584