1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "nnacl/infer/layer_norm_grad_infer.h"
17 #include "nnacl/infer/common_infer.h"
18 #include "nnacl/fp32_grad/layernormgrad_parameter.h"
19 #include "nnacl/infer/infer_register.h"
20
LayerNormGradInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)21 int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
22 OpParameter *parameter) {
23 int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3);
24 if (check_ret != NNACL_OK) {
25 return check_ret;
26 }
27 LayerNormGradParameter *param = (LayerNormGradParameter *)parameter;
28 const TensorC *input_x = inputs[0];
29 TensorC *output_dx = outputs[0];
30 TensorC *output_dg = outputs[1];
31 TensorC *output_db = outputs[2];
32 SetDataTypeFormat(output_dx, input_x);
33 SetDataTypeFormat(output_dg, input_x);
34 SetDataTypeFormat(output_db, input_x);
35 SetShapeTensor(output_dx, input_x);
36 int begin_params_axis = param->begin_params_axis_;
37 if (param->begin_params_axis_ < 0) {
38 begin_params_axis += (int)(input_x->shape_size_);
39 }
40 size_t size = 0;
41 if (input_x->shape_size_ > MAX_SHAPE_SIZE) {
42 return NNACL_INPUT_TENSOR_ERROR;
43 }
44 for (int i = begin_params_axis; i < input_x->shape_size_; i++) {
45 if (size >= MAX_SHAPE_SIZE) {
46 return NNACL_ERR;
47 }
48 output_dg->shape_[size] = input_x->shape_[i];
49 output_db->shape_[size] = input_x->shape_[i];
50 size++;
51 }
52 output_db->shape_size_ = size;
53 output_dg->shape_size_ = size;
54 return NNACL_OK;
55 }
56
57 REG_INFER(LayerNormGrad, PrimType_LayerNormGrad, LayerNormGradInferShape)
58