• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/fused_batch_norm.h"
18 #include <math.h>
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/default_kernel_base.h"
21 #include "nnacl/tensor_c_utils.h"
22 #include "nnacl/batchnorm_parameter.h"
23 #include "nnacl/fp32/batchnorm_fp32.h"
24 #include "nnacl/fp32/scale_fp32.h"
25 #ifdef ENABLE_FP16
26 #include "nnacl/fp16/scale_fp16.h"
27 #include "nnacl/fp16/batchnorm_fp16.h"
28 #endif
29 
/* Prepare the embedded Scale parameter so the batchnorm can run as a
 * per-channel scale over an NHWC tensor.  Returns
 * NNACL_FUSED_BATCH_NORM_NO_CHANGE when the input is not 4-D (no folding). */
int FusedBatchNormInitScaleParam(FusedBatchNormStruct *fused_batch_norm) {
  ScaleStruct *scale = &fused_batch_norm->scale_param_;
  scale->base_.thread_nr_ = fused_batch_norm->bn_.base_.thread_nr_;
  scale->axis_ = kNHWC_C;

  TensorC *input = fused_batch_norm->bn_.base_.in_[FIRST_INPUT];
  if (input->shape_size_ != DIMENSION_4D) {
    return NNACL_FUSED_BATCH_NORM_NO_CHANGE;
  }

  /* outer = product of the N/H/W dims that precede the channel axis. */
  int outer = 1;
  for (int axis = 0; axis < scale->axis_; ++axis) {
    outer *= input->shape_[axis];
  }
  scale->outer_size_ = outer;
  scale->axis_size_ = input->shape_[Index3];
  scale->inner_size_ = 1;
  return NNACL_OK;
}
48 
/* Fold batchnorm parameters into a per-channel scale/offset pair (FP32):
 *   fbn->scale_[i]  = scale[i] / sqrt(var[i] + eps)
 *   fbn->offset_[i] = bias[i] - mean[i] * fbn->scale_[i]
 * fbn->scale_ / fbn->offset_ must already hold kernel_num floats.
 * Fix: the read-only inputs were cast to non-const float*, discarding the
 * const qualifier of the parameters; they are now const-correct. */
void FusedBatchNormCalculateScaleF32(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data,
                                     const void *mean_data, const void *var_data, float eps, int kernel_num) {
  const float *fp32_scale_origin = (const float *)scale_data;
  const float *fp32_var_origin = (const float *)var_data;
  const float *fp32_bias_origin = (const float *)bias_data;
  const float *fp32_mean_origin = (const float *)mean_data;

  float *fp32_scale = (float *)fbn->scale_;
  for (int i = 0; i < kernel_num; i++) {
    fp32_scale[i] = fp32_scale_origin[i] / sqrtf(fp32_var_origin[i] + eps);
  }

  float *fp32_offset = (float *)fbn->offset_;
  for (int i = 0; i < kernel_num; i++) {
    fp32_offset[i] = fp32_bias_origin[i] - fp32_mean_origin[i] * fp32_scale[i];
  }
}
66 
/* FP16 variant of the batchnorm→scale folding:
 *   fbn->scale_[i]  = scale[i] / sqrt(var[i] + eps)
 *   fbn->offset_[i] = bias[i] - mean[i] * fbn->scale_[i]
 * No-op when built without ENABLE_FP16.
 * Fix: read-only inputs are now const-correct instead of casting away const. */
void FusedBatchNormCalculateScaleF16(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data,
                                     const void *mean_data, const void *var_data, float eps, int kernel_num) {
#ifdef ENABLE_FP16
  const float16_t *fp16_scale_origin = (const float16_t *)scale_data;
  const float16_t *fp16_var_origin = (const float16_t *)var_data;
  const float16_t *fp16_bias_origin = (const float16_t *)bias_data;
  const float16_t *fp16_mean_origin = (const float16_t *)mean_data;

  float16_t *fp16_scale = (float16_t *)fbn->scale_;
  for (int i = 0; i < kernel_num; i++) {
    /* Arithmetic promotes to float; sqrtf keeps the FP32 intermediate. */
    fp16_scale[i] = fp16_scale_origin[i] / sqrtf(fp16_var_origin[i] + eps);
  }

  float16_t *fp16_offset = (float16_t *)fbn->offset_;
  for (int i = 0; i < kernel_num; i++) {
    fp16_offset[i] = fp16_bias_origin[i] - fp16_mean_origin[i] * fp16_scale[i];
  }
#endif
}
86 
/* Per-task FP16 execution: either the folded Scale path or the full
 * fused-batchnorm kernel.  Compiled out without ENABLE_FP16. */
void FusedBatchNormRunFp16(FusedBatchNormStruct *fused_batch_norm, int task_id) {
#ifdef ENABLE_FP16
  KernelBase *base = &fused_batch_norm->bn_.base_;
  float16_t *input = (float16_t *)base->in_[FIRST_INPUT]->data_;
  float16_t *output = (float16_t *)base->out_[OUTPUT_INDEX]->data_;

  if (fused_batch_norm->is_scale_) {
    /* Folded path: out = input * scale_ + offset_ per channel. */
    DoScaleFp16(input, output, (float16_t *)fused_batch_norm->scale_, (float16_t *)fused_batch_norm->offset_, task_id,
                &fused_batch_norm->scale_param_);
    return;
  }
  FusedBatchNormFp16(input, (float16_t *)fused_batch_norm->scale_, (float16_t *)fused_batch_norm->offset_,
                     (float16_t *)fused_batch_norm->bn_.mean_, (float16_t *)fused_batch_norm->bn_.variance_,
                     &fused_batch_norm->bn_, task_id, base->thread_nr_, output);
#endif
}
103 
/* Converts inference-mode fused batchnorm into an equivalent per-channel
 * Scale operation.  Allocates fbn->scale_ / fbn->offset_ and fills them for
 * the kernel's data type.  Returns NNACL_FUSED_BATCH_NORM_NO_CHANGE when the
 * input is not 4-D (conversion skipped), NNACL_OK on success. */
int FusedBatchNormBatchnorm2Scale(FusedBatchNormStruct *fused_batch_norm, const void *scale_data, const void *bias_data,
                                  const void *mean_data, const void *var_data, float eps, int kernel_num) {
  int ret = FusedBatchNormInitScaleParam(fused_batch_norm);
  if (ret != NNACL_OK) {
    return ret;
  }

  ExecEnv *env = fused_batch_norm->bn_.base_.env_;
  TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT];
  fused_batch_norm->scale_ = env->Alloc(env->allocator_, GetSize(scale_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_);
  TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT];
  fused_batch_norm->offset_ = env->Alloc(env->allocator_, GetSize(offset_tensor));
  /* NOTE(review): if this alloc fails, scale_ is not freed here — presumably
   * the caller's Release() path reclaims it; verify against the call site. */
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_);

  // fused scale:  scale / sqrt(variance + eps)
  // fused offset: bias - scale * mean / sqrt(variance + eps)
  if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) {
    FusedBatchNormCalculateScaleF16(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num);
  } else {
    FusedBatchNormCalculateScaleF32(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num);
  }

  fused_batch_norm->is_scale_ = true;
  return NNACL_OK;
}
130 
/* Allocates the kernel's private copies of the constant tensors
 * (scale, offset, mean, variance).  For non-training sessions it first
 * attempts to fold batchnorm into a Scale op; on success the private copies
 * are unnecessary and the function returns early. */
int FusedBatchNormInitConstTensor(FusedBatchNormStruct *fused_batch_norm) {
  TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT];
  TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT];
  TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT];
  TensorC *variance_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT];

  if (!fused_batch_norm->bn_.base_.train_session_) {
    /* NOTE(review): the (float *) casts only convert to the const void *
     * parameters; for FP16 kernels the bytes are reinterpreted downstream. */
    int ret = FusedBatchNormBatchnorm2Scale(
      fused_batch_norm, (float *)scale_tensor->data_, (float *)offset_tensor->data_, (float *)mean_tensor->data_,
      (float *)variance_tensor->data_, fused_batch_norm->bn_.epsilon_, GetElementNum(scale_tensor));
    if (ret == NNACL_OK) {
      return NNACL_OK;
    } else {
      /* Folding failed or was skipped: release any partially created buffers
       * before falling back to the plain batchnorm path below. */
      fused_batch_norm->bn_.base_.Release(&fused_batch_norm->bn_.base_);
      if (ret != NNACL_FUSED_BATCH_NORM_NO_CHANGE) {
        return NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED;
      }
    }
  }

  /* Plain path: deep-copy all four constant tensors into kernel-owned
   * buffers so later training steps may update them in place. */
  ExecEnv *env = fused_batch_norm->bn_.base_.env_;
  fused_batch_norm->scale_ = env->Alloc(env->allocator_, GetSize(scale_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_);
  (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, GetSize(scale_tensor));
  fused_batch_norm->offset_ = env->Alloc(env->allocator_, GetSize(offset_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_);
  (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, GetSize(offset_tensor));
  fused_batch_norm->bn_.mean_ = env->Alloc(env->allocator_, GetSize(mean_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.mean_);
  (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, GetSize(mean_tensor));
  fused_batch_norm->bn_.variance_ = env->Alloc(env->allocator_, GetSize(variance_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.variance_);
  (void)memcpy(fused_batch_norm->bn_.variance_, variance_tensor->data_, GetSize(variance_tensor));
  return NNACL_OK;
}
166 
/* Per-task FP32 execution: either the folded Scale path or the full
 * fused-batchnorm kernel. */
void FusedBatchNormRunFp32(FusedBatchNormStruct *fused_batch_norm, int task_id) {
  KernelBase *base = &fused_batch_norm->bn_.base_;
  float *input = (float *)base->in_[FIRST_INPUT]->data_;
  float *output = (float *)base->out_[OUTPUT_INDEX]->data_;

  if (fused_batch_norm->is_scale_) {
    /* Folded path: out = input * scale_ + offset_ per channel. */
    DoScale(input, output, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, task_id,
            &fused_batch_norm->scale_param_);
    return;
  }
  FusedBatchNormFp32(input, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_,
                     (float *)fused_batch_norm->bn_.mean_, (float *)fused_batch_norm->bn_.variance_,
                     &fused_batch_norm->bn_, task_id, base->thread_nr_, output);
}
180 
FusedBatchNormRun(void * cdata,int task_id,float l,float r)181 int FusedBatchNormRun(void *cdata, int task_id, float l, float r) {
182   FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)cdata;
183   NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm);
184   if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) {
185     FusedBatchNormRunFp16(fused_batch_norm, task_id);
186   } else if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat32) {
187     FusedBatchNormRunFp32(fused_batch_norm, task_id);
188   }
189   return NNACL_OK;
190 }
191 
/* Training-path bookkeeping executed before the kernel run.  When the node
 * has 5 outputs and is actively training, recomputes the running mean and
 * variance from the current batch and publishes scale/offset/mean/variance
 * to the extra outputs; in a train session that is not actively training it
 * only mirrors the current parameters to the outputs. */
int FusedBatchNormTrainComputeInit(FusedBatchNormStruct *fbn) {
  if (fbn->bn_.base_.out_size_ < Num5) {
    return NNACL_OK;
  }

  /* NOTE: output slots reuse the input-index names (SECOND_INPUT..FIFTH_INPUT)
   * purely as indices 1..4 into out_. */
  TensorC *out_scale = fbn->bn_.base_.out_[SECOND_INPUT];
  TensorC *out_offset = fbn->bn_.base_.out_[THIRD_INPUT];
  TensorC *out_mean = fbn->bn_.base_.out_[FOURTH_INPUT];
  TensorC *out_var = fbn->bn_.base_.out_[FIFTH_INPUT];

  void *current_mean = fbn->bn_.mean_;
  void *current_var = fbn->bn_.variance_;

  bool schema_trained = ((BatchNormParameter *)fbn->bn_.base_.param_)->is_training_;
  if (fbn->train_mode_ && schema_trained && fbn->bn_.base_.in_size_ >= Num5) {
    TensorC *in_tensor = fbn->bn_.base_.in_[FIRST_INPUT];
    TensorC *scale_tensor = fbn->bn_.base_.in_[SECOND_INPUT];
    TensorC *offset_tensor = fbn->bn_.base_.in_[THIRD_INPUT];
    TensorC *mean_tensor = fbn->bn_.base_.in_[FOURTH_INPUT];
    TensorC *var_tensor = fbn->bn_.base_.in_[FIFTH_INPUT];
    if (in_tensor->data_ == NULL || scale_tensor->data_ == NULL || offset_tensor->data_ == NULL ||
        mean_tensor->data_ == NULL || var_tensor->data_ == NULL) {
      return NNACL_FUSED_BATCH_TRAIN_DATA_INVALID;
    }

    /* Accumulators must start at zero before the MeanVar kernels add in. */
    memset(current_mean, 0, GetSize(mean_tensor));
    memset(current_var, 0, GetSize(var_tensor));

    /* 2-D input means per-feature batchnorm (no spatial dims). */
    bool isBatch2d = true;
    if (fbn->bn_.base_.in_[FIRST_INPUT]->shape_size_ == Num2) isBatch2d = false;

    if (fbn->bn_.data_type_ == kNumberTypeFloat16) {
#ifdef ENABLE_FP16
      FusedBatchNormFp16MeanVar((float16_t *)in_tensor->data_, (float16_t *)current_mean, current_var, &fbn->bn_,
                                (float16_t *)mean_tensor->data_, (float16_t *)var_tensor->data_);
#endif
    } else {
      FusedBatchNormFp32MeanVar((float *)in_tensor->data_, (float *)current_mean, current_var, &fbn->bn_,
                                (float *)mean_tensor->data_, (float *)var_tensor->data_, isBatch2d);
    }

    /* Publish the (possibly updated) parameters and batch statistics. */
    (void)memcpy(out_scale->data_, scale_tensor->data_, GetSize(out_scale));
    (void)memcpy(out_offset->data_, offset_tensor->data_, GetSize(out_offset));
    (void)memcpy(out_mean->data_, current_mean, GetSize(out_mean));
    (void)memcpy(out_var->data_, current_var, GetSize(out_var));

    // Copy to local variables
    (void)memcpy(fbn->scale_, scale_tensor->data_, GetSize(scale_tensor));
    (void)memcpy(fbn->offset_, offset_tensor->data_, GetSize(offset_tensor));

    fbn->trained_ = true;  // trained at least once
    return NNACL_OK;
  }

  /* Not training this step: just mirror the current state to the outputs. */
  if (fbn->bn_.base_.train_session_) {
    (void)memcpy(out_scale->data_, fbn->scale_, GetSize(out_scale));
    (void)memcpy(out_offset->data_, fbn->offset_, GetSize(out_offset));
    (void)memcpy(out_mean->data_, current_mean, GetSize(out_mean));
    (void)memcpy(out_var->data_, current_var, GetSize(out_var));
  }

  return NNACL_OK;
}
255 
/* Kernel Compute entry: run training bookkeeping, then launch the
 * per-task runner across the thread pool. */
int FusedBatchNormCompute(KernelBase *self) {
  FusedBatchNormStruct *fbn = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fbn);

  int ret = FusedBatchNormTrainComputeInit(fbn);
  if (ret == NNACL_OK) {
    ret = self->env_->ParallelLaunch(self->env_->thread_pool_, FusedBatchNormRun, self, self->thread_nr_);
  }
  return ret;
}
271 
/* Kernel Resize entry: refresh batchnorm shape parameters, drop previously
 * allocated buffers, and rebuild the constant-tensor copies. */
int FusedBatchNormReSize(KernelBase *self) {
  FusedBatchNormStruct *fbn = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fbn);

  int ret = BatchNormFillParam(&fbn->bn_);
  if (ret != NNACL_OK) {
    return ret;
  }
  (void)self->Release(self);
  return FusedBatchNormInitConstTensor(fbn);
}
285 
/* Kernel Prepare entry: validate tensor counts and cache momentum/epsilon
 * from the op parameter. */
int FusedBatchNormPrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ < FIVE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);

  FusedBatchNormStruct *fbn = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fbn);

  BatchNormParameter *bn_param = (BatchNormParameter *)self->param_;
  fbn->bn_.momentum_ = bn_param->momentum_;
  fbn->bn_.epsilon_ = bn_param->epsilon_;
  return NNACL_OK;
}
296 
/* Kernel Release entry: free the base batchnorm buffers plus the folded
 * scale/offset buffers owned by this kernel.  Pointers are nulled so a
 * second Release is harmless. */
int FusedBatchNormRelease(KernelBase *self) {
  FusedBatchNormStruct *fbn = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fbn);

  (void)BatchNormRelease(&fbn->bn_.base_);

  ExecEnv *env = self->env_;
  if (fbn->scale_ != NULL) {
    env->Free(env->allocator_, fbn->scale_);
    fbn->scale_ = NULL;
  }
  if (fbn->offset_ != NULL) {
    env->Free(env->allocator_, fbn->offset_);
    fbn->offset_ = NULL;
  }
  return NNACL_OK;
}
313 
CreateFusedBatchNorm(OpParameter * param,int data_type)314 KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type) {
315   FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)malloc(sizeof(FusedBatchNormStruct));
316   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fused_batch_norm);
317   memset(fused_batch_norm, 0, sizeof(FusedBatchNormStruct));
318   fused_batch_norm->bn_.data_type_ = data_type;
319   fused_batch_norm->bn_.base_.Prepare = FusedBatchNormPrepare;
320   fused_batch_norm->bn_.base_.Resize = FusedBatchNormReSize;
321   fused_batch_norm->bn_.base_.Release = FusedBatchNormRelease;
322   fused_batch_norm->bn_.base_.Compute = FusedBatchNormCompute;
323   return (KernelBase *)fused_batch_norm;
324 }
325 
326 REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat16, CreateFusedBatchNorm)
327 REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat32, CreateFusedBatchNorm)
328