• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE >= 1
7#include <assert.h>
8#include <math.h>
9
10#include <xnnpack/common.h>
11#include <xnnpack/math.h>
12#include <xnnpack/vcvt.h>
13
14#include <fp16.h>
15
16
// Converts an array of 32-bit floats into 16-bit values (f16, per the kernel
// name) using only scalar operations, with fabsf() used to separate the sign
// from the magnitude.
//
// This file is a template: dollar-prefixed directive lines and BATCH_TILE
// placeholders are expanded by the generator, where BATCH_TILE is the number
// of elements converted per main-loop iteration.
//
// Parameters:
//   n      - size of the input in BYTES (not elements); asserted to be
//            non-zero and a multiple of sizeof(float).
//   input  - source floats; asserted non-NULL.
//   output - destination, written as consecutive uint16_t values, one per
//            input float; asserted non-NULL.
//   params - conversion constants, all read from the scalar_fabsf member.
17void xnn_f32_f16_vcvt_ukernel__scalar_fabsf_x${BATCH_TILE}(
18    size_t n,
19    const float* input,
20    void* output,
21    const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
22{
23  assert(n != 0);
24  assert(n % sizeof(float) == 0);
25  assert(input != NULL);
26  assert(output != NULL);
27
  // Read every conversion constant once, ahead of the loops.
  // NOTE(review): the concrete values are set wherever params is initialized,
  // which is not visible here; the comments below describe only what the code
  // does with them.
28  const float vscale_to_inf = params->scalar_fabsf.scale_to_inf;
29  const uint32_t vexp_bias = params->scalar_fabsf.exp_bias;
30  const float vscale_to_zero = params->scalar_fabsf.scale_to_zero;
31  const uint32_t vexpw_max = params->scalar_fabsf.expw_max;
32  const uint32_t vbias_min = params->scalar_fabsf.bias_min;
33  const uint16_t vexph_mask = params->scalar_fabsf.exph_mask;
34  const uint16_t vmanth_mask = params->scalar_fabsf.manth_mask;
35  const uint16_t vnanh = params->scalar_fabsf.nanh;
36
  // Results are stored as raw 16-bit patterns.
37  uint16_t* o = (uint16_t*) output;
38  $if BATCH_TILE > 1:
    // Main loop: converts BATCH_TILE elements per iteration. Every step is
    // applied to all elements of the tile before the next step starts.
39    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
40      $for N in range(BATCH_TILE):
41        const float vx${N} = input[${N}];
42      input += ${BATCH_TILE};
43
      // Magnitude of each input, and the raw bits of the signed input.
44      $for N in range(BATCH_TILE):
45        const float vabsx${N} = fabsf(vx${N});
46      $for N in range(BATCH_TILE):
47        uint32_t vsignw${N} = fp32_to_bits(vx${N});
48
      // Raw bits of the magnitude, and the magnitude multiplied by scale_to_inf.
49      $for N in range(BATCH_TILE):
50        const uint32_t vnonsignw${N} = fp32_to_bits(vabsx${N});
51      $for N in range(BATCH_TILE):
52        float vf${N} = vabsx${N} * vscale_to_inf;
53
      // Start the bias word from the magnitude bits plus exp_bias. XOR-ing the
      // magnitude bits out of vsignw leaves only the bits in which x and |x|
      // differ, i.e. the sign bit.
54      $for N in range(BATCH_TILE):
55        uint32_t vbias${N} = vnonsignw${N} + vexp_bias;
56      $for N in range(BATCH_TILE):
57        vsignw${N} ^= vnonsignw${N};
58
      // Second scaling step (by scale_to_zero), and masking of the bias word
      // with expw_max.
59      $for N in range(BATCH_TILE):
60        vf${N} *= vscale_to_zero;
61      $for N in range(BATCH_TILE):
62        vbias${N} &= vexpw_max;
63
      // Clamp the bias word from below at bias_min.
64      $for N in range(BATCH_TILE):
65        vbias${N} = math_max_u32(vbias${N}, vbias_min);
66
      // Reinterpret the bias word as a float and add it to the scaled value.
      // NOTE(review): presumably this addition is what rounds the mantissa to
      // the 16-bit format's width - confirm against the params initialization.
67      $for N in range(BATCH_TILE):
68        vf${N} += fp32_from_bits(vbias${N});
69
70      $for N in range(BATCH_TILE):
71        const uint32_t vbits${N} = fp32_to_bits(vf${N});
72
      // Extract the 16-bit exponent and mantissa fields from the sum's bits.
      // The cast binds tighter than &, so each value is truncated to 16 bits
      // first and masked afterwards. The sign is the upper half of vsignw.
73      $for N in range(BATCH_TILE):
74        const uint16_t vexph${N} = (uint16_t) (vbits${N} >> 13) & vexph_mask;
75      $for N in range(BATCH_TILE):
76        const uint16_t vmanth${N} = (uint16_t) vbits${N} & vmanth_mask;
77      $for N in range(BATCH_TILE):
78        const uint16_t vsignh${N} = (uint16_t) (vsignw${N} >> 16);
79
      // Combine exponent and mantissa by addition. Inputs whose magnitude bits
      // exceed expw_max get the fixed pattern nanh instead (presumably NaN
      // inputs, given the name - confirm the value of expw_max). The sign is
      // OR-ed in last, so it is applied to the nanh pattern as well.
80      $for N in range(BATCH_TILE):
81        uint16_t vh${N} = vexph${N} + vmanth${N};
82      $for N in range(BATCH_TILE):
83        if XNN_UNPREDICTABLE(vnonsignw${N} > vexpw_max) {
84          vh${N} = vnanh;
85        }
86      $for N in range(BATCH_TILE):
87        vh${N} |= vsignh${N};
88
89      $for N in range(BATCH_TILE):
90        o[${N}] = vh${N};
91      o += ${BATCH_TILE};
92    }
93  $if BATCH_TILE == 1:
    // Single-element variant: no unrolled loop is generated, so this loop
    // handles every element. The do-while form relies on the n != 0 assertion
    // at function entry.
94    do {
95      const float vx = *input++;
96
      // Split into magnitude and raw bits, as described in the header comment.
97      const float vabsx = fabsf(vx);
98      uint32_t vsignw = fp32_to_bits(vx);
99
100      const uint32_t vnonsignw = fp32_to_bits(vabsx);
101      float vf = vabsx * vscale_to_inf;
102
      // Bias word from the magnitude bits; vsignw is reduced to the sign bit.
103      uint32_t vbias = vnonsignw + vexp_bias;
104      vsignw ^= vnonsignw;
105
106      vf *= vscale_to_zero;
107      vbias &= vexpw_max;
108
109      vbias = math_max_u32(vbias, vbias_min);
110
      // Add the bias (reinterpreted as a float) to the twice-scaled magnitude.
111      vf += fp32_from_bits(vbias);
112
113      const uint32_t vbits = fp32_to_bits(vf);
114
      // Cast-then-mask: the cast binds tighter than &.
115      const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
116      const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
117      const uint16_t vsignh = (uint16_t) (vsignw >> 16);
118
      // Magnitude bits above expw_max select the nanh pattern; sign goes on last.
119      uint16_t vh = vexph + vmanth;
120      if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
121        vh = vnanh;
122      }
123      vh |= vsignh;
124
125      *o++ = vh;
126
127      n -= sizeof(float);
128    } while (n != 0);
129  $elif BATCH_TILE == 2:
    // Remainder for a tile of two: the main loop exits with n below
    // 2 * sizeof(float), and n is a multiple of sizeof(float), so at most one
    // element is left. It is converted without a loop and without advancing
    // the pointers, since nothing follows it.
130    if XNN_UNLIKELY(n != 0) {
131      const float vx = *input;
132
133      const float vabsx = fabsf(vx);
134      uint32_t vsignw = fp32_to_bits(vx);
135
136      const uint32_t vnonsignw = fp32_to_bits(vabsx);
137      float vf = vabsx * vscale_to_inf;
138
      // Bias word from the magnitude bits; vsignw is reduced to the sign bit.
139      uint32_t vbias = vnonsignw + vexp_bias;
140      vsignw ^= vnonsignw;
141
142      vf *= vscale_to_zero;
143      vbias &= vexpw_max;
144
145      vbias = math_max_u32(vbias, vbias_min);
146
147      vf += fp32_from_bits(vbias);
148
149      const uint32_t vbits = fp32_to_bits(vf);
150
      // Cast-then-mask: the cast binds tighter than &.
151      const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
152      const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
153      const uint16_t vsignh = (uint16_t) (vsignw >> 16);
154
      // Magnitude bits above expw_max select the nanh pattern; sign goes on last.
155      uint16_t vh = vexph + vmanth;
156      if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
157        vh = vnanh;
158      }
159      vh |= vsignh;
160
161      *o = vh;
162    }
163  $else:
    // Remainder for larger tiles: fewer than BATCH_TILE elements are left
    // after the main loop; convert them one at a time until n reaches zero.
164    if XNN_UNLIKELY(n != 0) {
165      do {
166        const float vx = *input++;
167
168        const float vabsx = fabsf(vx);
169        uint32_t vsignw = fp32_to_bits(vx);
170
171        const uint32_t vnonsignw = fp32_to_bits(vabsx);
172        float vf = vabsx * vscale_to_inf;
173
        // Bias word from the magnitude bits; vsignw is reduced to the sign bit.
174        uint32_t vbias = vnonsignw + vexp_bias;
175        vsignw ^= vnonsignw;
176
177        vf *= vscale_to_zero;
178        vbias &= vexpw_max;
179
180        vbias = math_max_u32(vbias, vbias_min);
181
182        vf += fp32_from_bits(vbias);
183
184        const uint32_t vbits = fp32_to_bits(vf);
185
        // Cast-then-mask: the cast binds tighter than &.
186        const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
187        const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
188        const uint16_t vsignh = (uint16_t) (vsignw >> 16);
189
        // Magnitude bits above expw_max select the nanh pattern; sign goes on last.
190        uint16_t vh = vexph + vmanth;
191        if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
192          vh = vnanh;
193        }
194        vh |= vsignh;
195
196        *o++ = vh;
197
198        n -= sizeof(float);
199      } while (n != 0);
200    }
201}
202