• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/fp16/exp_fp16.h"
17 #include <math.h>
18 #include <string.h>
19 #include "nnacl/errorcode.h"
20 
21 #if defined(ENABLE_NEON)
/*
 * Vectorised exp for 8 fp16 lanes: clamps each lane into [kLowerBound, kUpperBound]
 * and stores VexpFp16() of the clamped vector to dst[0..7].
 * NOTE(review): the bounds are numerically ~ln(FLT_MAX) and ~ln(FLT_MIN), i.e. the fp32
 * exp range; they do not by themselves keep the fp16 result finite -- confirm VexpFp16
 * handles that.
 */
static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) {
  static const float16x8_t kUpperBound = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f,
                                          88.72283935546875f, 88.72283935546875f, 88.72283935546875f,
                                          88.72283935546875f, 88.72283935546875f};
  static const float16x8_t kLowerBound = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f,
                                          -87.3365478515625f, -87.3365478515625f, -87.3365478515625f,
                                          -87.3365478515625f, -87.3365478515625f};
  /* Cap from above first, then raise from below (same operand order as before). */
  float16x8_t capped = vminq_f16(input, kUpperBound);
  float16x8_t clamped = vmaxq_f16(kLowerBound, capped);
  float16x8_t result = VexpFp16(clamped);
  vst1q_f16(dst, result);
}
30 #endif
31 
/*
 * Element-wise natural exponential over an fp16 buffer: dst[k] = exp(src[k]) for k in [0, num).
 * With ENABLE_NEON, full groups of C8NUM elements go through simd_exp_fp16 (8-lane vectors);
 * the tail -- or the whole buffer when NEON is disabled -- goes through single_exp_fp16.
 *
 * @param src  input values (num elements)
 * @param dst  output values (num elements)
 * @param num  element count; values <= 0 write nothing
 */
void ExpFp16(const float16_t *src, float16_t *dst, int num) {
  int idx = 0;
#ifdef ENABLE_NEON
  /* Largest multiple of C8NUM not exceeding num. */
  const int vec_end = num - num % C8NUM;
  while (idx < vec_end) {
    simd_exp_fp16(vld1q_f16(src + idx), dst + idx);
    idx += C8NUM;
  }
#endif
  while (idx < num) {
    single_exp_fp16(src[idx], dst + idx);
    idx++;
  }
}
44 
ExpFusionFp16(const void * src_data,void * dst_data,const ExpStruct * exp,int task_id)45 int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) {
46   NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_);
47   ExpParameter *param = (ExpParameter *)exp->base_.param_;
48   NNACL_CHECK_NULL_RETURN_ERR(param);
49 
50   float16_t *src = (float16_t *)src_data;
51   float16_t *dst = (float16_t *)dst_data;
52   int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_);
53   int start = stride * task_id;
54   int end = MSMIN(exp->element_num_, start + stride);
55   int num = end - start;
56 
57   if (param->scale_ == 1) {
58     ExpFp16(src + start, dst + start, num);
59   } else {
60     int i = 0;
61 #ifdef ENABLE_ARM64
62     MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->in_scale_);
63     int count = (num / C8NUM) * C8NUM;
64     for (; i < count; i += C8NUM) {
65       simd_exp_fp16(MS_MULQ_F16(MS_LDQ_F16(src + i), scale), dst + i);
66     }
67 #endif
68     for (; i < num; ++i) {
69       single_exp_fp16(src[i] * exp->in_scale_, dst + i);
70     }
71   }
72   if (exp->out_scale_ != 1) {
73     int i = 0;
74 #ifdef ENABLE_ARM64
75     MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->out_scale_);
76     int count = (num / C8NUM) * C8NUM;
77     for (; i < count; i += C8NUM) {
78       simd_exp_fp16(MS_LDQ_F16(src + i), dst + i);
79       MS_STQ_F16(dst + i, MS_MULQ_F16(MS_LDQ_F16(dst + i), scale));
80     }
81 #endif
82     for (; i < num; ++i) {
83       single_exp_fp16(src[i], dst + i);
84       dst[i] *= exp->out_scale_;
85     }
86   }
87   return NNACL_OK;
88 }
89