/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/train/train_populate_parameter.h"
#include <algorithm>
#include "src/ops/populate/populate_register.h"
#include "src/ops/populate/default_populate.h"
#include "src/ops/populate/strided_slice_populate.h"
#include "nnacl/arithmetic.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/lstm_parameter.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/power_parameter.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32_grad/softmax_grad.h"
#include "nnacl/fp32_grad/optimizer.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "nnacl/fp32_grad/dropout_parameter.h"
#include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad.h"

using mindspore::lite::Registry;

namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputIndexTwo = 2;
constexpr int kInputIndexThree = 3;
}  // namespace
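// Each Populate*Parameter function below allocates the op-specific nnacl parameter
// struct with malloc, zero-initializes it, and fills it from the flatbuffer
// primitive; ownership of the returned OpParameter passes to the caller.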
OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
  SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SmoothL1LossParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SmoothL1Loss();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->beta_ = value->beta();
  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
  SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SmoothL1LossParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SmoothL1LossParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SmoothL1LossGrad();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->beta_ = value->beta();
  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateApplyMomentumParameter(const void *prim) {
  ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc ApplyMomentumParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(ApplyMomentumParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_ApplyMomentum();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->grad_scale_ = value->gradient_scale();
  p->use_nesterov_ = value->use_nesterov();
  return reinterpret_cast<OpParameter *>(p);
}

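// Note: the BinaryCrossEntropy populate functions allocate only a bare int32_t
// holding the reduction mode and return it cast to OpParameter *, rather than a
// full parameter struct.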
OpParameter *PopulateBCEParameter(const void *prim) {
  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
  if (reduction == nullptr) {
    MS_LOG(ERROR) << "malloc reduction failed.";
    return nullptr;
  }
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_BinaryCrossEntropy();
  MS_ASSERT(value != nullptr);
  *reduction = value->reduction();
  return reinterpret_cast<OpParameter *>(reduction);
}

OpParameter *PopulateBCEGradParameter(const void *prim) {
  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
  if (reduction == nullptr) {
    MS_LOG(ERROR) << "malloc reduction failed.";
    return nullptr;
  }
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_BinaryCrossEntropyGrad();
  MS_ASSERT(value != nullptr);
  *reduction = value->reduction();
  return reinterpret_cast<OpParameter *>(reduction);
}

OpParameter *PopulateAdamParameter(const void *prim) {
  AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc AdamParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(AdamParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Adam();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->use_nesterov_ = value->use_nesterov();
  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSgdParameter(const void *prim) {
  SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter)));
  if (p == nullptr) {
    MS_LOG(ERROR) << "malloc SgdParameter failed.";
    return nullptr;
  }
  memset(p, 0, sizeof(SgdParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SGD();
  MS_ASSERT(value != nullptr);
  p->op_parameter_.type_ = primitive->value_type();
  p->weight_decay_ = value->weight_decay();
  p->dampening_ = value->dampening();
  p->use_nesterov_ = value->nesterov();

  return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
  SoftmaxCrossEntropyParameter *sce_param =
    reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
  if (sce_param == nullptr) {
    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
    return nullptr;
  }
  memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
  MS_ASSERT(value != nullptr);
  sce_param->op_parameter_.type_ = primitive->value_type();
  sce_param->is_grad_ = value->is_grad();
  return reinterpret_cast<OpParameter *>(sce_param);
}

OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
  SoftmaxCrossEntropyParameter *sce_param =
    reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
  if (sce_param == nullptr) {
    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
    return nullptr;
  }
  memset(sce_param, 0, sizeof(SoftmaxCrossEntropyParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  sce_param->op_parameter_.type_ = primitive->value_type();
  sce_param->is_grad_ = 0;
  return reinterpret_cast<OpParameter *>(sce_param);
}

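// The pooling grad populate functions read kernel_size and strides as (H, W)
// pairs and map the schema pad mode to the nnacl enum: SAME -> Pad_same,
// VALID -> Pad_valid, anything else -> Pad_pad (explicit padding).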
OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
  if (pooling_param == nullptr) {
    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
    return nullptr;
  }
  memset(pooling_param, 0, sizeof(PoolingParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_MaxPoolGrad();
  MS_ASSERT(value != nullptr);
  pooling_param->op_parameter_.type_ = primitive->value_type();

  pooling_param->global_ = false;
  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));

  pooling_param->pad_u_ = 0;
  pooling_param->pad_d_ = 0;
  pooling_param->pad_l_ = 0;
  pooling_param->pad_r_ = 0;
  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));
  pooling_param->round_mode_ = RoundMode_No;
  pooling_param->pool_mode_ = PoolMode_MaxPool;
  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      pooling_param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      pooling_param->pad_mode_ = Pad_valid;
      break;
    default:
      pooling_param->pad_mode_ = Pad_pad;
      break;
  }

  return reinterpret_cast<OpParameter *>(pooling_param);
}

OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
  if (pooling_param == nullptr) {
    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
    return nullptr;
  }
  memset(pooling_param, 0, sizeof(PoolingParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_AvgPoolGrad();
  MS_ASSERT(value != nullptr);
  pooling_param->op_parameter_.type_ = primitive->value_type();

  pooling_param->global_ = false;
  pooling_param->window_w_ = static_cast<int>(value->kernel_size()->Get(1));
  pooling_param->window_h_ = static_cast<int>(value->kernel_size()->Get(0));

  pooling_param->pad_u_ = 0;
  pooling_param->pad_d_ = 0;
  pooling_param->pad_l_ = 0;
  pooling_param->pad_r_ = 0;
  pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1));
  pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0));

  switch (value->pad_mode()) {
    case schema::PadMode_SAME:
      pooling_param->pad_mode_ = Pad_same;
      break;
    case schema::PadMode_VALID:
      pooling_param->pad_mode_ = Pad_valid;
      break;
    default:
      pooling_param->pad_mode_ = Pad_pad;
      break;
  }
  pooling_param->round_mode_ = RoundMode_No;
  pooling_param->pool_mode_ = PoolMode_AvgPool;
  return reinterpret_cast<OpParameter *>(pooling_param);
}

OpParameter *PopulateActivationGradParameter(const void *prim) {
  ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
  if (act_param == nullptr) {
    MS_LOG(ERROR) << "malloc ActivationParameter failed.";
    return nullptr;
  }
  memset(act_param, 0, sizeof(ActivationParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_ActivationGrad();
  MS_ASSERT(value != nullptr);
  act_param->op_parameter_.type_ = primitive->value_type();
  act_param->type_ = static_cast<int>(value->activation_type());
  act_param->alpha_ = value->alpha();
  return reinterpret_cast<OpParameter *>(act_param);
}

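// The Conv2D backprop populate functions read kernel_size, stride, and dilation
// as (H, W) pairs and pad_list as (up, down, left, right); RELU/RELU6 fused
// activations map to the nnacl ActType enum, everything else stays ActType_No.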
OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
  ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(ConvParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Conv2DBackpropFilterFusion();
  MS_ASSERT(value != nullptr);
  param->op_parameter_.type_ = primitive->value_type();

  param->kernel_h_ = value->kernel_size()->Get(0);
  param->kernel_w_ = value->kernel_size()->Get(1);
  param->stride_h_ = value->stride()->Get(0);
  param->stride_w_ = value->stride()->Get(1);
  param->dilation_h_ = value->dilation()->Get(0);
  param->dilation_w_ = value->dilation()->Get(1);
  param->pad_u_ = value->pad_list()->Get(0);
  param->pad_d_ = value->pad_list()->Get(1);
  param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
  param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
  param->group_ = value->group();
  param->act_type_ = ActType_No;
  switch (value->activation_type()) {
    case schema::ActivationType_RELU:
      param->act_type_ = ActType_Relu;
      break;
    case schema::ActivationType_RELU6:
      param->act_type_ = ActType_Relu6;
      break;
    default:
      break;
  }

  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
  ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc Param for conv grad input failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(ConvParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Conv2DBackpropInputFusion();
  MS_ASSERT(value != nullptr);
  param->op_parameter_.type_ = primitive->value_type();

  param->kernel_h_ = value->kernel_size()->Get(0);
  param->kernel_w_ = value->kernel_size()->Get(1);
  param->stride_h_ = value->stride()->Get(0);
  param->stride_w_ = value->stride()->Get(1);
  param->dilation_h_ = value->dilation()->Get(0);
  param->dilation_w_ = value->dilation()->Get(1);
  param->pad_u_ = value->pad_list()->Get(0);
  param->pad_d_ = value->pad_list()->Get(1);
  param->pad_l_ = value->pad_list()->Get(kInputIndexTwo);
  param->pad_r_ = value->pad_list()->Get(kInputIndexThree);
  param->group_ = value->group();
  param->act_type_ = ActType_No;
  switch (value->activation_type()) {
    case schema::ActivationType_RELU:
      param->act_type_ = ActType_Relu;
      break;
    case schema::ActivationType_RELU6:
      param->act_type_ = ActType_Relu6;
      break;
    default:
      break;
  }

  return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulatePowerGradParameter(const void *prim) {
  PowerParameter *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
  if (power_param == nullptr) {
    MS_LOG(ERROR) << "malloc PowerParameter failed.";
    return nullptr;
  }
  memset(power_param, 0, sizeof(PowerParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_PowerGrad();
  MS_ASSERT(value != nullptr);
  power_param->op_parameter_.type_ = primitive->value_type();
  power_param->power_ = value->power();
  power_param->scale_ = value->scale();
  power_param->shift_ = value->shift();
  return reinterpret_cast<OpParameter *>(power_param);
}

OpParameter *PopulateBiasGradParameter(const void *prim) {
  ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
  if (arithmetic_param == nullptr) {
    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
    return nullptr;
  }
  memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  arithmetic_param->op_parameter_.type_ = primitive->value_type();
  return reinterpret_cast<OpParameter *>(arithmetic_param);
}

OpParameter *PopulateBNGradParameter(const void *prim) {
  BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
  if (bnGrad_param == nullptr) {
    MS_LOG(ERROR) << "malloc BNGradParameter failed.";
    return nullptr;
  }
  memset(bnGrad_param, 0, sizeof(BNGradParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_BatchNormGrad();
  MS_ASSERT(value != nullptr);
  bnGrad_param->op_parameter_.type_ = primitive->value_type();
  bnGrad_param->epsilon_ = value->epsilon();
  return reinterpret_cast<OpParameter *>(bnGrad_param);
}

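// The Dropout/DropoutGrad populate functions store keep_prob in ratio_ and
// reject values outside [0, 1], freeing the allocated parameter on failure.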
OpParameter *PopulateDropoutParameter(const void *prim) {
  DropoutParameter *dropout_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
  if (dropout_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc Dropout Parameter failed.";
    return nullptr;
  }
  memset(dropout_parameter, 0, sizeof(DropoutParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_Dropout();
  MS_ASSERT(value != nullptr);
  dropout_parameter->op_parameter_.type_ = primitive->value_type();
  dropout_parameter->ratio_ = value->keep_prob();
  if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
    MS_LOG(ERROR) << "Dropout ratio must be between 0 and 1, got " << dropout_parameter->ratio_;
    free(dropout_parameter);
    return nullptr;
  }
  return reinterpret_cast<OpParameter *>(dropout_parameter);
}

OpParameter *PopulateDropoutGradParameter(const void *prim) {
  DropoutParameter *dropoutgrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter)));
  if (dropoutgrad_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed.";
    return nullptr;
  }
  memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_DropoutGrad();
  MS_ASSERT(value != nullptr);
  dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
  dropoutgrad_parameter->ratio_ = value->keep_prob();
  if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
    MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 and 1, got " << dropoutgrad_parameter->ratio_;
    free(dropoutgrad_parameter);
    return nullptr;
  }
  return reinterpret_cast<OpParameter *>(dropoutgrad_parameter);
}

OpParameter *PopulateArithmeticGradParameter(const void *prim) {
  ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
  if (arithmetic_param == nullptr) {
    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
    return nullptr;
  }
  memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  arithmetic_param->op_parameter_.type_ = primitive->value_type();
  return reinterpret_cast<OpParameter *>(arithmetic_param);
}

OpParameter *PopulateResizeGradParameter(const void *prim) {
  ResizeGradParameter *resize_grad_param = reinterpret_cast<ResizeGradParameter *>(malloc(sizeof(ResizeGradParameter)));
  if (resize_grad_param == nullptr) {
    MS_LOG(ERROR) << "malloc resize grad parameter failed.";
    return nullptr;
  }
  memset(resize_grad_param, 0, sizeof(ResizeGradParameter));
  auto primitive = static_cast<const schema::Primitive *>(prim);
  resize_grad_param->op_parameter_.type_ = primitive->value_type();
  auto param = primitive->value_as_ResizeGrad();
  MS_ASSERT(param != nullptr);
  resize_grad_param->method = static_cast<int>(param->method());
  resize_grad_param->align_corners_ = param->align_corners();

  return reinterpret_cast<OpParameter *>(resize_grad_param);
}

OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
  StridedSliceParameter *strided_slice_param =
    reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
  if (strided_slice_param == nullptr) {
    MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
    return nullptr;
  }
  memset(strided_slice_param, 0, sizeof(StridedSliceParameter));

  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_StridedSliceGrad();
  MS_ASSERT(value != nullptr);
  strided_slice_param->op_parameter_.type_ = primitive->value_type();

  strided_slice_param->begins_mask_ = value->begin_mask();
  strided_slice_param->ends_mask_ = value->end_mask();
  strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
  strided_slice_param->newAxisMask_ = value->new_axis_mask();
  strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
  return reinterpret_cast<OpParameter *>(strided_slice_param);
}

OpParameter *PopulateLstmGradParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_LSTMGrad();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr.";
    return nullptr;
  }

  auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc LstmParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(LstmParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->bidirectional_ = value->bidirectional();
  param->zoneout_cell_ = value->zoneout_cell();
  param->zoneout_hidden_ = value->zoneout_hidden();
  return reinterpret_cast<OpParameter *>(param);
}

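// Registers every training-only populate function (and lite::DefaultPopulateParameter
// for ops that carry no attributes) with the populate registry for the current
// schema version (lite::SCHEMA_CUR).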
void PopulateTrainParameters() {
  lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
                                                lite::SCHEMA_CUR);
  lite::Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasAddGrad, PopulateBiasGradParameter,
                                           lite::SCHEMA_CUR);
  lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
                                                      PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
  lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
                                                            PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
                                                            lite::SCHEMA_CUR);
  lite::Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
                                             lite::SCHEMA_CUR);
  lite::Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter,
                                         lite::SCHEMA_CUR);
  lite::Registry Conv2DGradFilterParameterRegistry(schema::PrimitiveType_Conv2DBackpropFilterFusion,
                                                   PopulateConvolutionGradFilterParameter, lite::SCHEMA_CUR);
  lite::Registry Conv2DGradInputParameterRegistry(schema::PrimitiveType_Conv2DBackpropInputFusion,
                                                  PopulateConvolutionGradInputParameter, lite::SCHEMA_CUR);
  lite::Registry avgPoolParameterRegistry(schema::PrimitiveType_AvgPoolGrad, PopulateAvgPoolGradParameter,
                                          lite::SCHEMA_CUR);
  lite::Registry maxPoolParameterRegistry(schema::PrimitiveType_MaxPoolGrad, PopulateMaxPoolGradParameter,
                                          lite::SCHEMA_CUR);
  lite::Registry PowerGradParameterRegistry(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter,
                                            lite::SCHEMA_CUR);
  lite::Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
  lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter,
                                         lite::SCHEMA_CUR);
  lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
  lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter,
                                         lite::SCHEMA_CUR);
  lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
                                            lite::SCHEMA_CUR);
  lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter,
                                                     lite::SCHEMA_CUR);
  lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
                                                         PopulateBCEGradParameter, lite::SCHEMA_CUR);
  lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, lite::DefaultPopulateParameter,
                                           lite::SCHEMA_CUR);
  lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum,
                                                     lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  lite::Registry DropoutParameterRegistry(schema::PrimitiveType_Dropout, PopulateDropoutParameter, lite::SCHEMA_CUR);
  lite::Registry DropGradParameterRegistry(schema::PrimitiveType_DropoutGrad, PopulateDropoutGradParameter,
                                           lite::SCHEMA_CUR);
  lite::Registry MaximumGradParameterRegistry(schema::PrimitiveType_MaximumGrad, PopulateArithmeticGradParameter,
                                              lite::SCHEMA_CUR);
  lite::Registry MinimumGradParameterRegistry(schema::PrimitiveType_MinimumGrad, PopulateArithmeticGradParameter,
                                              lite::SCHEMA_CUR);
  lite::Registry SmoothL1LossRegistry(schema::PrimitiveType_SmoothL1Loss, PopulateSmoothL1LossParameter,
                                      lite::SCHEMA_CUR);
  lite::Registry SmoothL1LossGradRegistry(schema::PrimitiveType_SmoothL1LossGrad, PopulateSmoothL1LossGradParameter,
                                          lite::SCHEMA_CUR);
  lite::Registry SigmoidCrossEntropyWithLogitsRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
                                                       lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
                                                           lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, lite::DefaultPopulateParameter,
                                              lite::SCHEMA_CUR);
  lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad,
                                                   PopulateStridedSliceGradParameter, lite::SCHEMA_CUR);
  lite::Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter,
                                           lite::SCHEMA_CUR);
  lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
                                            lite::SCHEMA_CUR);
  Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
  Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
  Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
}
}  // namespace kernel
}  // namespace mindspore