/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "nnacl/kernel/layer_norm.h"
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/fp32/layer_norm_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/layer_norm_fp16.h"
#endif

LayerNormRun(void * cdata,int task_id,float l,float r)25 int LayerNormRun(void *cdata, int task_id, float l, float r) {
26   LayerNormStruct *ln = (LayerNormStruct *)cdata;
27   NNACL_CHECK_NULL_RETURN_ERR(ln);
28   if (ln->data_type_ == kNumberTypeFloat16) {
29 #ifdef ENABLE_FP16
30     return LayerNormFp16(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_,
31                          &ln->compute_, task_id, ln->base_.thread_nr_);
32 #endif
33   }
34   return LayerNorm(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_,
35                    &ln->compute_, task_id, ln->base_.thread_nr_);
36 }
37 
/* Recomputes outer/inner element counts for the normalization and affine
 * parameter axes from the current input shape, then retunes the thread
 * count. Must be called whenever the input shape changes.
 * Returns NNACL_OK on success or an NNACL error code. */
int LayerNormResize(KernelBase *self) {
  LayerNormStruct *layer_norm = (LayerNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(layer_norm);
  LayerNormComputeParam *compute = &layer_norm->compute_;

  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);

  /* Negative axes count back from the last dimension, numpy-style. */
  if (compute->begin_norm_axis_ < 0) {
    compute->begin_norm_axis_ = compute->begin_norm_axis_ + (int)input->shape_size_;
  }
  if (compute->begin_params_axis_ < 0) {
    compute->begin_params_axis_ = compute->begin_params_axis_ + (int)input->shape_size_;
  }

  /* Validate the adjusted axes. Previously a still-negative axis was
   * implicitly converted to size_t in the loops below, silently skipping
   * the inner-size computation and leaving the size at 1. */
  if (compute->begin_norm_axis_ < 0 || compute->begin_norm_axis_ > (int)input->shape_size_) {
    return NNACL_ERR;
  }
  if (compute->begin_params_axis_ < 0 || compute->begin_params_axis_ > (int)input->shape_size_) {
    return NNACL_ERR;
  }

  /* Elements before the norm axis: independent rows to normalize. */
  compute->norm_outer_size_ = 1;
  for (int i = 0; i < compute->begin_norm_axis_; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_outer_size_, input->shape_[i], NNACL_ERR);
    compute->norm_outer_size_ *= input->shape_[i];
  }

  /* Elements from the norm axis onward: length of each normalized row. */
  compute->norm_inner_size_ = 1;
  for (int i = compute->begin_norm_axis_; i < (int)input->shape_size_; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_inner_size_, input->shape_[i], NNACL_ERR);
    compute->norm_inner_size_ *= input->shape_[i];
  }

  /* Same split for the gamma/beta parameter axes. */
  compute->params_outer_size_ = 1;
  for (int i = 0; i < compute->begin_params_axis_; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_outer_size_, input->shape_[i], NNACL_ERR);
    compute->params_outer_size_ *= input->shape_[i];
  }

  compute->params_inner_size_ = 1;
  for (int i = compute->begin_params_axis_; i < (int)input->shape_size_; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_inner_size_, input->shape_[i], NNACL_ERR);
    compute->params_inner_size_ *= input->shape_[i];
  }

  int out_num = GetElementNum(self->out_[OUTPUT_INDEX]);
  self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_LayerNormFusion), compute->norm_inner_size_,
                                        compute->norm_inner_size_, out_num, self->thread_nr_);
  /* Work is split by outer row, so more threads than rows is wasted. */
  self->thread_nr_ = NNACL_MIN(compute->norm_outer_size_, self->thread_nr_);
  return NNACL_OK;
}
84 
/* Binds the kernel's input/output buffers from the tensor list and
 * launches LayerNormRun across the thread pool.
 * Returns the ParallelLaunch status, or an NNACL error on bad tensors. */
int LayerNormCompute(KernelBase *self) {
  LayerNormStruct *ln = (LayerNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(ln);

  /* Required tensors: source, gamma, beta in; normalized result out. */
  ln->src_data_ = self->in_[FIRST_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(ln->src_data_);
  ln->gamma_data_ = self->in_[SECOND_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(ln->gamma_data_);
  ln->beta_data_ = self->in_[THIRD_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(ln->beta_data_);
  ln->dst_data_ = self->out_[OUTPUT_INDEX]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(ln->dst_data_);

  /* Three outputs means mean and variance are exported as well; exactly
   * one output means they are not. Any other count is invalid. */
  if (ln->base_.out_size_ == THREE_TENSOR) {
    ln->mean_data_ = self->out_[Index1]->data_;
    NNACL_CHECK_NULL_RETURN_ERR(ln->mean_data_);
    ln->var_data_ = self->out_[Index2]->data_;
    NNACL_CHECK_NULL_RETURN_ERR(ln->var_data_);
  } else if (ln->base_.out_size_ != ONE_TENSOR) {
    return NNACL_LAYER_NORM_OUTPUT_NUM_INVALID;
  }

  return self->env_->ParallelLaunch(self->env_->thread_pool_, LayerNormRun, self, self->thread_nr_);
}
109 
CreateLayerNorm(OpParameter * param,int data_type)110 KernelBase *CreateLayerNorm(OpParameter *param, int data_type) {
111   LayerNormStruct *layer_norm = (LayerNormStruct *)malloc(sizeof(LayerNormStruct));
112   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(layer_norm);
113   memset(layer_norm, 0, sizeof(LayerNormStruct));
114   layer_norm->data_type_ = data_type;
115 
116   LayerNormParameter *layer_norm_param = (LayerNormParameter *)param;
117   layer_norm->compute_.epsilon_ = layer_norm_param->epsilon_;
118   layer_norm->compute_.elementwise_affine_ = layer_norm_param->elementwise_affine_;
119   layer_norm->compute_.begin_norm_axis_ = layer_norm_param->begin_norm_axis_;
120   layer_norm->compute_.begin_params_axis_ = layer_norm_param->begin_params_axis_;
121 
122   layer_norm->base_.Prepare = DefaultPrepare3In1Out;
123   layer_norm->base_.Release = DefaultRelease;
124   layer_norm->base_.Resize = LayerNormResize;
125   layer_norm->base_.Compute = LayerNormCompute;
126   return (KernelBase *)layer_norm;
127 }
128 
REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat16, CreateLayerNorm)
REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat32, CreateLayerNorm)