• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/infer/layer_norm_grad_infer.h"
17 #include "nnacl/infer/common_infer.h"
18 #include "nnacl/fp32_grad/layernormgrad_parameter.h"
19 #include "nnacl/infer/infer_register.h"
20 
LayerNormGradInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)21 int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
22                             OpParameter *parameter) {
23   int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3);
24   if (check_ret != NNACL_OK) {
25     return check_ret;
26   }
27   LayerNormGradParameter *param = (LayerNormGradParameter *)parameter;
28   const TensorC *input_x = inputs[0];
29   TensorC *output_dx = outputs[0];
30   TensorC *output_dg = outputs[1];
31   TensorC *output_db = outputs[2];
32   SetDataTypeFormat(output_dx, input_x);
33   SetDataTypeFormat(output_dg, input_x);
34   SetDataTypeFormat(output_db, input_x);
35   SetShapeTensor(output_dx, input_x);
36   int begin_params_axis = param->begin_params_axis_;
37   if (param->begin_params_axis_ < 0) {
38     begin_params_axis += (int)(input_x->shape_size_);
39   }
40   size_t size = 0;
41   if (input_x->shape_size_ > MAX_SHAPE_SIZE) {
42     return NNACL_INPUT_TENSOR_ERROR;
43   }
44   for (int i = begin_params_axis; i < input_x->shape_size_; i++) {
45     if (size >= MAX_SHAPE_SIZE) {
46       return NNACL_ERR;
47     }
48     output_dg->shape_[size] = input_x->shape_[i];
49     output_db->shape_[size] = input_x->shape_[i];
50     size++;
51   }
52   output_db->shape_size_ = size;
53   output_dg->shape_size_ = size;
54   return NNACL_OK;
55 }
56 
57 REG_INFER(LayerNormGrad, PrimType_LayerNormGrad, LayerNormGradInferShape)
58