1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "nnacl/fp16/exp_fp16.h"
17 #include <math.h>
18 #include <string.h>
19 #include "nnacl/errorcode.h"
20
#if defined(ENABLE_NEON)
// Vector exp helper: limits every lane of `input` to the range given by the two
// literals below (after their conversion to fp16), evaluates VexpFp16 on the
// limited vector and writes the resulting lanes to `dst`.
static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) {
  const float16x8_t upper_bound = vdupq_n_f16(88.72283935546875f);
  const float16x8_t lower_bound = vdupq_n_f16(-87.3365478515625f);
  // Cap from above first, then from below (same operand order as before so that
  // lane-wise min/max semantics stay untouched).
  const float16x8_t capped = vminq_f16(input, upper_bound);
  const float16x8_t clamped = vmaxq_f16(lower_bound, capped);
  vst1q_f16(dst, VexpFp16(clamped));
}
#endif
31
// Element-wise exp over `num` fp16 values: reads src[0..num) and writes dst[0..num).
// With ENABLE_NEON the leading whole C8NUM-element chunks go through
// simd_exp_fp16; whatever is left is handled one element at a time by
// single_exp_fp16.
void ExpFp16(const float16_t *src, float16_t *dst, int num) {
  int idx = 0;
#ifdef ENABLE_NEON
  // Largest multiple of C8NUM that does not exceed num.
  const int vec_end = num - num % C8NUM;
  while (idx < vec_end) {
    simd_exp_fp16(vld1q_f16(src + idx), dst + idx);
    idx += C8NUM;
  }
#endif
  while (idx < num) {
    single_exp_fp16(src[idx], dst + idx);
    ++idx;
  }
}
44
ExpFusionFp16(const void * src_data,void * dst_data,const ExpStruct * exp,int task_id)45 int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) {
46 NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_);
47 ExpParameter *param = (ExpParameter *)exp->base_.param_;
48 NNACL_CHECK_NULL_RETURN_ERR(param);
49
50 float16_t *src = (float16_t *)src_data;
51 float16_t *dst = (float16_t *)dst_data;
52 int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_);
53 int start = stride * task_id;
54 int end = MSMIN(exp->element_num_, start + stride);
55 int num = end - start;
56
57 if (param->scale_ == 1) {
58 ExpFp16(src + start, dst + start, num);
59 } else {
60 int i = 0;
61 #ifdef ENABLE_ARM64
62 MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->in_scale_);
63 int count = (num / C8NUM) * C8NUM;
64 for (; i < count; i += C8NUM) {
65 simd_exp_fp16(MS_MULQ_F16(MS_LDQ_F16(src + i), scale), dst + i);
66 }
67 #endif
68 for (; i < num; ++i) {
69 single_exp_fp16(src[i] * exp->in_scale_, dst + i);
70 }
71 }
72 if (exp->out_scale_ != 1) {
73 int i = 0;
74 #ifdef ENABLE_ARM64
75 MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->out_scale_);
76 int count = (num / C8NUM) * C8NUM;
77 for (; i < count; i += C8NUM) {
78 simd_exp_fp16(MS_LDQ_F16(src + i), dst + i);
79 MS_STQ_F16(dst + i, MS_MULQ_F16(MS_LDQ_F16(dst + i), scale));
80 }
81 #endif
82 for (; i < num; ++i) {
83 single_exp_fp16(src[i], dst + i);
84 dst[i] *= exp->out_scale_;
85 }
86 }
87 return NNACL_OK;
88 }
89