1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/fused_batch_norm.h"
18 #include <math.h>
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/default_kernel_base.h"
21 #include "nnacl/tensor_c_utils.h"
22 #include "nnacl/batchnorm_parameter.h"
23 #include "nnacl/fp32/batchnorm_fp32.h"
24 #include "nnacl/fp32/scale_fp32.h"
25 #ifdef ENABLE_FP16
26 #include "nnacl/fp16/scale_fp16.h"
27 #include "nnacl/fp16/batchnorm_fp16.h"
28 #endif
29
// Prepare the ScaleStruct so the folded batch norm can run as a per-channel
// scale over an NHWC tensor. Only 4D inputs qualify; any other rank returns
// NNACL_FUSED_BATCH_NORM_NO_CHANGE so the caller keeps the plain BN path.
int FusedBatchNormInitScaleParam(FusedBatchNormStruct *fused_batch_norm) {
  ScaleStruct *scale = &fused_batch_norm->scale_param_;
  scale->base_.thread_nr_ = fused_batch_norm->bn_.base_.thread_nr_;
  scale->axis_ = kNHWC_C;

  TensorC *input = fused_batch_norm->bn_.base_.in_[FIRST_INPUT];
  if (input->shape_size_ != DIMENSION_4D) {
    return NNACL_FUSED_BATCH_NORM_NO_CHANGE;
  }

  // outer = N*H*W (all dims before the channel axis); inner is always 1.
  int outer = 1;
  for (int axis = 0; axis < scale->axis_; ++axis) {
    outer *= input->shape_[axis];
  }
  scale->outer_size_ = outer;
  scale->axis_size_ = input->shape_[Index3];
  scale->inner_size_ = 1;
  return NNACL_OK;
}
48
// Fold BN statistics into per-channel affine coefficients (fp32):
//   new_scale[c]  = gamma[c] / sqrt(var[c] + eps)
//   new_offset[c] = beta[c]  - mean[c] * new_scale[c]
// Results are written into fbn->scale_ / fbn->offset_ (kernel_num channels).
void FusedBatchNormCalculateScaleF32(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data,
                                     const void *mean_data, const void *var_data, float eps, int kernel_num) {
  const float *gamma = (const float *)scale_data;
  const float *beta = (const float *)bias_data;
  const float *mean = (const float *)mean_data;
  const float *var = (const float *)var_data;

  float *new_scale = (float *)fbn->scale_;
  float *new_offset = (float *)fbn->offset_;
  for (int c = 0; c < kernel_num; c++) {
    new_scale[c] = gamma[c] / sqrtf(var[c] + eps);
    new_offset[c] = beta[c] - mean[c] * new_scale[c];
  }
}
66
// Fp16 variant of the BN-to-affine fold; compiled to a no-op body when
// ENABLE_FP16 is off (parameters are then intentionally unused).
void FusedBatchNormCalculateScaleF16(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data,
                                     const void *mean_data, const void *var_data, float eps, int kernel_num) {
#ifdef ENABLE_FP16
  const float16_t *gamma = (const float16_t *)scale_data;
  const float16_t *beta = (const float16_t *)bias_data;
  const float16_t *mean = (const float16_t *)mean_data;
  const float16_t *var = (const float16_t *)var_data;

  float16_t *new_scale = (float16_t *)fbn->scale_;
  float16_t *new_offset = (float16_t *)fbn->offset_;
  for (int c = 0; c < kernel_num; c++) {
    // sqrtf promotes through float; the quotient narrows back to float16_t.
    new_scale[c] = gamma[c] / sqrtf(var[c] + eps);
    new_offset[c] = beta[c] - mean[c] * new_scale[c];
  }
#endif
}
86
// Per-task fp16 execution: either the folded Scale fast path or the full
// fused batch-norm kernel, depending on whether BN-to-Scale folding succeeded.
void FusedBatchNormRunFp16(FusedBatchNormStruct *fused_batch_norm, int task_id) {
#ifdef ENABLE_FP16
  float16_t *input = (float16_t *)fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_;
  float16_t *output = (float16_t *)fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_;
  float16_t *scale = (float16_t *)fused_batch_norm->scale_;
  float16_t *offset = (float16_t *)fused_batch_norm->offset_;

  if (!fused_batch_norm->is_scale_) {
    FusedBatchNormFp16(input, scale, offset, (float16_t *)fused_batch_norm->bn_.mean_,
                       (float16_t *)fused_batch_norm->bn_.variance_, &fused_batch_norm->bn_, task_id,
                       fused_batch_norm->bn_.base_.thread_nr_, output);
  } else {
    DoScaleFp16(input, output, scale, offset, task_id, &fused_batch_norm->scale_param_);
  }
#endif
}
103
// Fold an inference-mode batch norm into a per-channel Scale op:
//   y = gamma * (x - mean) / sqrt(var + eps) + beta
//     = new_scale * x + new_offset
// Allocates fbn->scale_ / fbn->offset_ and fills them with the folded
// coefficients. Returns NNACL_FUSED_BATCH_NORM_NO_CHANGE when the input is
// not 4D (the Scale fast path only handles NHWC tensors); the caller then
// falls back to the plain batch-norm path.
int FusedBatchNormBatchnorm2Scale(FusedBatchNormStruct *fused_batch_norm, const void *scale_data, const void *bias_data,
                                  const void *mean_data, const void *var_data, float eps, int kernel_num) {
  int ret = FusedBatchNormInitScaleParam(fused_batch_norm);
  if (ret != NNACL_OK) {
    return ret;
  }

  ExecEnv *env = fused_batch_norm->bn_.base_.env_;
  TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT];
  fused_batch_norm->scale_ = env->Alloc(env->allocator_, GetSize(scale_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_);
  TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT];
  fused_batch_norm->offset_ = env->Alloc(env->allocator_, GetSize(offset_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_);

  // new scale:  gamma / sqrt(variance + eps)
  // new offset: bias - mean * new_scale
  if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) {
    FusedBatchNormCalculateScaleF16(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num);
  } else {
    FusedBatchNormCalculateScaleF32(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num);
  }

  // From here on Compute() dispatches to the DoScale fast path.
  fused_batch_norm->is_scale_ = true;
  return NNACL_OK;
}
130
// Copy the four constant inputs (gamma, beta, mean, variance) into
// kernel-owned buffers. For inference graphs it first tries to fold the
// whole batch norm into a Scale op; only a NNACL_FUSED_BATCH_NORM_NO_CHANGE
// result (non-4D input) falls through to the generic copy path below.
int FusedBatchNormInitConstTensor(FusedBatchNormStruct *fused_batch_norm) {
  TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT];
  TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT];
  TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT];
  TensorC *variance_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT];

  if (!fused_batch_norm->bn_.base_.train_session_) {
    int ret = FusedBatchNormBatchnorm2Scale(
      fused_batch_norm, (float *)scale_tensor->data_, (float *)offset_tensor->data_, (float *)mean_tensor->data_,
      (float *)variance_tensor->data_, fused_batch_norm->bn_.epsilon_, GetElementNum(scale_tensor));
    if (ret == NNACL_OK) {
      return NNACL_OK;
    } else {
      // Folding failed: release any partially allocated scale/offset buffers
      // before either erroring out or taking the generic path below.
      fused_batch_norm->bn_.base_.Release(&fused_batch_norm->bn_.base_);
      if (ret != NNACL_FUSED_BATCH_NORM_NO_CHANGE) {
        return NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED;
      }
    }
  }

  // Generic path: snapshot all four const tensors so the kernel does not
  // depend on the input tensors' lifetime afterwards.
  ExecEnv *env = fused_batch_norm->bn_.base_.env_;
  fused_batch_norm->scale_ = env->Alloc(env->allocator_, GetSize(scale_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_);
  (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, GetSize(scale_tensor));
  fused_batch_norm->offset_ = env->Alloc(env->allocator_, GetSize(offset_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_);
  (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, GetSize(offset_tensor));
  fused_batch_norm->bn_.mean_ = env->Alloc(env->allocator_, GetSize(mean_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.mean_);
  (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, GetSize(mean_tensor));
  fused_batch_norm->bn_.variance_ = env->Alloc(env->allocator_, GetSize(variance_tensor));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.variance_);
  (void)memcpy(fused_batch_norm->bn_.variance_, variance_tensor->data_, GetSize(variance_tensor));
  return NNACL_OK;
}
166
// Per-task fp32 execution: the folded Scale fast path when BN-to-Scale
// folding succeeded, the full fused batch-norm kernel otherwise.
void FusedBatchNormRunFp32(FusedBatchNormStruct *fused_batch_norm, int task_id) {
  float *input = (float *)fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_;
  float *output = (float *)fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_;
  float *scale = (float *)fused_batch_norm->scale_;
  float *offset = (float *)fused_batch_norm->offset_;

  if (!fused_batch_norm->is_scale_) {
    FusedBatchNormFp32(input, scale, offset, (float *)fused_batch_norm->bn_.mean_,
                       (float *)fused_batch_norm->bn_.variance_, &fused_batch_norm->bn_, task_id,
                       fused_batch_norm->bn_.base_.thread_nr_, output);
  } else {
    DoScale(input, output, scale, offset, task_id, &fused_batch_norm->scale_param_);
  }
}
180
FusedBatchNormRun(void * cdata,int task_id,float l,float r)181 int FusedBatchNormRun(void *cdata, int task_id, float l, float r) {
182 FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)cdata;
183 NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm);
184 if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) {
185 FusedBatchNormRunFp16(fused_batch_norm, task_id);
186 } else if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat32) {
187 FusedBatchNormRunFp32(fused_batch_norm, task_id);
188 }
189 return NNACL_OK;
190 }
191
// Training-mode bookkeeping run before each Compute. With fewer than five
// outputs there is nothing to update and the function is a no-op. Otherwise:
//  - in active training (train_mode_ set, schema says is_training_, and the
//    five stats inputs are present) it recomputes batch mean/variance into
//    the kernel's mean_/variance_ buffers and publishes scale/offset/mean/var
//    through outputs 2..5;
//  - in a train session that is not actively training it just mirrors the
//    current kernel state to those outputs.
int FusedBatchNormTrainComputeInit(FusedBatchNormStruct *fbn) {
  if (fbn->bn_.base_.out_size_ < Num5) {
    return NNACL_OK;
  }

  // Outputs 2..5 carry the updated scale/offset/mean/variance.
  // NOTE(review): out_*->data_ is copied into without a NULL check below —
  // presumably the framework pre-allocates these outputs; confirm.
  TensorC *out_scale = fbn->bn_.base_.out_[SECOND_INPUT];
  TensorC *out_offset = fbn->bn_.base_.out_[THIRD_INPUT];
  TensorC *out_mean = fbn->bn_.base_.out_[FOURTH_INPUT];
  TensorC *out_var = fbn->bn_.base_.out_[FIFTH_INPUT];

  // Kernel-owned running-statistics buffers (allocated in InitConstTensor).
  void *current_mean = fbn->bn_.mean_;
  void *current_var = fbn->bn_.variance_;

  bool schema_trained = ((BatchNormParameter *)fbn->bn_.base_.param_)->is_training_;
  if (fbn->train_mode_ && schema_trained && fbn->bn_.base_.in_size_ >= Num5) {
    TensorC *in_tensor = fbn->bn_.base_.in_[FIRST_INPUT];
    TensorC *scale_tensor = fbn->bn_.base_.in_[SECOND_INPUT];
    TensorC *offset_tensor = fbn->bn_.base_.in_[THIRD_INPUT];
    TensorC *mean_tensor = fbn->bn_.base_.in_[FOURTH_INPUT];
    TensorC *var_tensor = fbn->bn_.base_.in_[FIFTH_INPUT];
    if (in_tensor->data_ == NULL || scale_tensor->data_ == NULL || offset_tensor->data_ == NULL ||
        mean_tensor->data_ == NULL || var_tensor->data_ == NULL) {
      return NNACL_FUSED_BATCH_TRAIN_DATA_INVALID;
    }

    // Recompute batch statistics from scratch for this step.
    memset(current_mean, 0, GetSize(mean_tensor));
    memset(current_var, 0, GetSize(var_tensor));

    // Rank-2 input means (N, C) layout rather than NHWC.
    bool isBatch2d = true;
    if (fbn->bn_.base_.in_[FIRST_INPUT]->shape_size_ == Num2) isBatch2d = false;

    if (fbn->bn_.data_type_ == kNumberTypeFloat16) {
#ifdef ENABLE_FP16
      FusedBatchNormFp16MeanVar((float16_t *)in_tensor->data_, (float16_t *)current_mean, current_var, &fbn->bn_,
                                (float16_t *)mean_tensor->data_, (float16_t *)var_tensor->data_);
#endif
    } else {
      FusedBatchNormFp32MeanVar((float *)in_tensor->data_, (float *)current_mean, current_var, &fbn->bn_,
                                (float *)mean_tensor->data_, (float *)var_tensor->data_, isBatch2d);
    }

    // Publish the (possibly updated) parameters and fresh statistics.
    (void)memcpy(out_scale->data_, scale_tensor->data_, GetSize(out_scale));
    (void)memcpy(out_offset->data_, offset_tensor->data_, GetSize(out_offset));
    (void)memcpy(out_mean->data_, current_mean, GetSize(out_mean));
    (void)memcpy(out_var->data_, current_var, GetSize(out_var));

    // Copy to local variables
    (void)memcpy(fbn->scale_, scale_tensor->data_, GetSize(scale_tensor));
    (void)memcpy(fbn->offset_, offset_tensor->data_, GetSize(offset_tensor));

    fbn->trained_ = true;  // trained at least once
    return NNACL_OK;
  }

  // Train session but not actively training: echo the current kernel state.
  if (fbn->bn_.base_.train_session_) {
    (void)memcpy(out_scale->data_, fbn->scale_, GetSize(out_scale));
    (void)memcpy(out_offset->data_, fbn->offset_, GetSize(out_offset));
    (void)memcpy(out_mean->data_, current_mean, GetSize(out_mean));
    (void)memcpy(out_var->data_, current_var, GetSize(out_var));
  }

  return NNACL_OK;
}
255
// Kernel entry point: run training-mode bookkeeping, then execute the
// batch-norm (or folded Scale) across the thread pool.
int FusedBatchNormCompute(KernelBase *self) {
  FusedBatchNormStruct *fbn = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fbn);

  int ret = FusedBatchNormTrainComputeInit(fbn);
  if (ret != NNACL_OK) {
    return ret;
  }
  return self->env_->ParallelLaunch(self->env_->thread_pool_, FusedBatchNormRun, self, self->thread_nr_);
}
271
// Resize hook: refresh shape-derived batch-norm parameters, drop any buffers
// built for the previous shape, and re-copy the const tensors.
int FusedBatchNormReSize(KernelBase *self) {
  FusedBatchNormStruct *kernel = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  int ret = BatchNormFillParam(&kernel->bn_);
  if (ret != NNACL_OK) {
    return ret;
  }

  (void)self->Release(self);
  return FusedBatchNormInitConstTensor(kernel);
}
285
// Prepare hook: validate tensor counts and cache momentum/epsilon from the
// op parameter. Fix: the original dereferenced self->param_ without a NULL
// check, unlike every other entry point in this file; check it first.
int FusedBatchNormPrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ < FIVE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);

  FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm);

  BatchNormParameter *param = (BatchNormParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);
  fused_batch_norm->bn_.momentum_ = param->momentum_;
  fused_batch_norm->bn_.epsilon_ = param->epsilon_;
  return NNACL_OK;
}
296
// Release hook: free the base batch-norm buffers (mean_/variance_) and then
// the fused kernel's own scale_/offset_ buffers, nulling the pointers so a
// second Release is harmless.
int FusedBatchNormRelease(KernelBase *self) {
  FusedBatchNormStruct *kernel = (FusedBatchNormStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  (void)BatchNormRelease(&kernel->bn_.base_);

  ExecEnv *env = self->env_;
  if (kernel->scale_ != NULL) {
    env->Free(env->allocator_, kernel->scale_);
    kernel->scale_ = NULL;
  }
  if (kernel->offset_ != NULL) {
    env->Free(env->allocator_, kernel->offset_);
    kernel->offset_ = NULL;
  }
  return NNACL_OK;
}
313
CreateFusedBatchNorm(OpParameter * param,int data_type)314 KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type) {
315 FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)malloc(sizeof(FusedBatchNormStruct));
316 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fused_batch_norm);
317 memset(fused_batch_norm, 0, sizeof(FusedBatchNormStruct));
318 fused_batch_norm->bn_.data_type_ = data_type;
319 fused_batch_norm->bn_.base_.Prepare = FusedBatchNormPrepare;
320 fused_batch_norm->bn_.base_.Resize = FusedBatchNormReSize;
321 fused_batch_norm->bn_.base_.Release = FusedBatchNormRelease;
322 fused_batch_norm->bn_.base_.Compute = FusedBatchNormCompute;
323 return (KernelBase *)fused_batch_norm;
324 }
325
326 REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat16, CreateFusedBatchNorm)
327 REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat32, CreateFusedBatchNorm)
328