/*
 *  Copyright (c) 2014 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef VPX_VPX_DSP_ARM_VPX_CONVOLVE8_NEON_H_
#define VPX_VPX_DSP_ARM_VPX_CONVOLVE8_NEON_H_

#include <arm_neon.h>

#include "./vpx_config.h"
#include "./vpx_dsp_rtcd.h"
#include "vpx_dsp/vpx_filter.h"

#if VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)

void vpx_convolve8_2d_horiz_neon_dotprod(const uint8_t *src,
                                         ptrdiff_t src_stride, uint8_t *dst,
                                         ptrdiff_t dst_stride,
                                         const InterpKernel *filter, int x0_q4,
                                         int x_step_q4, int y0_q4,
                                         int y_step_q4, int w, int h);

static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo,
                                                 const int8x16_t samples_hi,
                                                 const int32x4_t correction,
                                                 const int8x8_t filters) {
  /* Sample range-clamping and permutation are performed by the caller. */
  int32x4_t sum;

  /* Accumulate dot product into 'correction' to account for range clamp. */
  sum = vdotq_lane_s32(correction, samples_lo, filters, 0);
  sum = vdotq_lane_s32(sum, samples_hi, filters, 1);

  /* Further narrowing and packing is performed by the caller. */
  return vqmovn_s32(sum);
}

static INLINE int16x4_t convolve8_4_sdot(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x2_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[2];
  int32x4_t sum;

  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
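  /* Subtracting 'range_limit' (128) maps the unsigned samples [0, 255] onto
   * the signed range [-128, 127]. The per-sample bias of -128 this introduces
   * is undone by seeding the accumulator with 'correction', i.e.
   * 128 * (sum of the filter taps). */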

  /* Permute samples ready for dot product. */
  /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);

  /* Accumulate dot product into 'correction' to account for range clamp. */
  sum = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
  sum = vdotq_lane_s32(sum, permuted_samples[1], filters, 1);

  /* Further narrowing and packing is performed by the caller. */
  return vqmovn_s32(sum);
}

static INLINE uint8x8_t convolve8_8_sdot_partial(const int8x16_t samples0_lo,
                                                 const int8x16_t samples0_hi,
                                                 const int8x16_t samples1_lo,
                                                 const int8x16_t samples1_hi,
                                                 const int32x4_t correction,
                                                 const int8x8_t filters) {
  /* Sample range-clamping and permutation are performed by the caller. */
  int32x4_t sum0, sum1;
  int16x8_t sum;

  /* Accumulate dot product into 'correction' to account for range clamp. */
  /* First 4 output values. */
  sum0 = vdotq_lane_s32(correction, samples0_lo, filters, 0);
  sum0 = vdotq_lane_s32(sum0, samples0_hi, filters, 1);
  /* Second 4 output values. */
  sum1 = vdotq_lane_s32(correction, samples1_lo, filters, 0);
  sum1 = vdotq_lane_s32(sum1, samples1_hi, filters, 1);

  /* Narrow and re-pack. */
  sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}

static INLINE uint8x8_t convolve8_8_sdot(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum0, sum1;
  int16x8_t sum;

  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  /* Permute samples ready for dot product. */
  /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  /* Accumulate dot product into 'correction' to account for range clamp. */
  /* First 4 output values. */
  sum0 = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
  sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
  /* Second 4 output values. */
  sum1 = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
  sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filters, 1);

  /* Narrow and re-pack. */
  sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}
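
/* Illustrative sketch (added for clarity, not part of the original libvpx
 * header): one plausible way a caller could derive the 'filters',
 * 'correction' and 'range_limit' arguments for the *_sdot helpers above,
 * given the 8 16-bit taps of one kernel. The helper name is hypothetical and
 * the narrowing assumes the taps fit in int8. */
static INLINE void example_sdot_filter_setup(const int16_t *filter,
                                             int8x8_t *filters,
                                             int32x4_t *correction,
                                             uint8x16_t *range_limit) {
  const int16x8_t filter_s16 = vld1q_s16(filter);
  /* Narrow the taps to int8 for vdotq_lane_s32. */
  *filters = vmovn_s16(filter_s16);
  /* Samples are biased by -128 before the signed dot product, so start the
   * accumulator at 128 * (sum of taps) to compensate. */
  *correction = vdupq_n_s32(128 * (int32_t)vaddlvq_s16(filter_s16));
  /* Constant subtracted from each unsigned sample to clamp it to
   * [-128, 127]. */
  *range_limit = vdupq_n_u8(128);
}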

#endif  // VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)

#if VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_MATMUL_INT8)

void vpx_convolve8_2d_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
                                      uint8_t *dst, ptrdiff_t dst_stride,
                                      const InterpKernel *filter, int x0_q4,
                                      int x_step_q4, int y0_q4, int y_step_q4,
                                      int w, int h);

static INLINE int16x4_t convolve8_4_usdot_partial(const uint8x16_t samples_lo,
                                                  const uint8x16_t samples_hi,
                                                  const int8x8_t filters) {
  /* Sample permutation is performed by the caller. */
  int32x4_t sum;

  sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filters, 0);
  sum = vusdotq_lane_s32(sum, samples_hi, filters, 1);

  /* Further narrowing and packing is performed by the caller. */
  return vqmovn_s32(sum);
}

static INLINE int16x4_t convolve8_4_usdot(uint8x16_t samples,
                                          const int8x8_t filters,
                                          const uint8x16x2_t permute_tbl) {
  uint8x16_t permuted_samples[2];
  int32x4_t sum;

  /* Permute samples ready for dot product. */
  /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);

  /* Accumulate the dot product. The mixed-sign USDOT consumes the unsigned
   * samples directly, so no range-clamp correction term is needed. */
  sum = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
  sum = vusdotq_lane_s32(sum, permuted_samples[1], filters, 1);

  /* Further narrowing and packing is performed by the caller. */
  return vqmovn_s32(sum);
}

static INLINE uint8x8_t convolve8_8_usdot_partial(const uint8x16_t samples0_lo,
                                                  const uint8x16_t samples0_hi,
                                                  const uint8x16_t samples1_lo,
                                                  const uint8x16_t samples1_hi,
                                                  const int8x8_t filters) {
  /* Sample permutation is performed by the caller. */
  int32x4_t sum0, sum1;
  int16x8_t sum;

  /* First 4 output values. */
  sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filters, 0);
  sum0 = vusdotq_lane_s32(sum0, samples0_hi, filters, 1);
  /* Second 4 output values. */
  sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filters, 0);
  sum1 = vusdotq_lane_s32(sum1, samples1_hi, filters, 1);

  /* Narrow and re-pack. */
  sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}

static INLINE uint8x8_t convolve8_8_usdot(uint8x16_t samples,
                                          const int8x8_t filters,
                                          const uint8x16x3_t permute_tbl) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum0, sum1;
  int16x8_t sum;

  /* Permute samples ready for dot product. */
  /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  /* First 4 output values. */
  sum0 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
  sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
  /* Second 4 output values. */
  sum1 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
  sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1);

  /* Narrow and re-pack. */
  sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}
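
/* Illustrative usage sketch (added for clarity, not part of the original
 * libvpx header): producing one row of 8 horizontally filtered pixels with
 * convolve8_8_usdot. The table simply spells out the byte indices documented
 * in the comments above; the helper name and the in-function table are
 * hypothetical (libvpx keeps an equivalent aligned constant elsewhere).
 * 'src' must have at least 16 readable bytes and the taps are assumed to fit
 * in int8. */
static INLINE void example_usdot_horiz_8px(const uint8_t *src,
                                           const int16_t *filter,
                                           uint8_t *dst) {
  static const uint8_t perm_data[48] = {
    0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
    4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
    8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
  };
  const int8x8_t filters = vmovn_s16(vld1q_s16(filter));
  const uint8x16_t samples = vld1q_u8(src);
  uint8x16x3_t perm_tbl;

  perm_tbl.val[0] = vld1q_u8(perm_data + 0);
  perm_tbl.val[1] = vld1q_u8(perm_data + 16);
  perm_tbl.val[2] = vld1q_u8(perm_data + 32);

  vst1_u8(dst, convolve8_8_usdot(samples, filters, perm_tbl));
}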

#endif  // VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_MATMUL_INT8)

static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
                                    const int16x4_t s2, const int16x4_t s3,
                                    const int16x4_t s4, const int16x4_t s5,
                                    const int16x4_t s6, const int16x4_t s7,
                                    const int16x8_t filters) {
  const int16x4_t filters_lo = vget_low_s16(filters);
  const int16x4_t filters_hi = vget_high_s16(filters);
  int16x4_t sum;

  sum = vmul_lane_s16(s0, filters_lo, 0);
  sum = vmla_lane_s16(sum, s1, filters_lo, 1);
  sum = vmla_lane_s16(sum, s2, filters_lo, 2);
  sum = vmla_lane_s16(sum, s5, filters_hi, 1);
  sum = vmla_lane_s16(sum, s6, filters_hi, 2);
  sum = vmla_lane_s16(sum, s7, filters_hi, 3);
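  /* The products for the two largest taps (filter indices 3 and 4) are added
   * last with saturating adds to guard against 16-bit intermediate
   * overflow. */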
  sum = vqadd_s16(sum, vmul_lane_s16(s3, filters_lo, 3));
  sum = vqadd_s16(sum, vmul_lane_s16(s4, filters_hi, 0));
  return sum;
}

static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
                                    const int16x8_t s2, const int16x8_t s3,
                                    const int16x8_t s4, const int16x8_t s5,
                                    const int16x8_t s6, const int16x8_t s7,
                                    const int16x8_t filters) {
  const int16x4_t filters_lo = vget_low_s16(filters);
  const int16x4_t filters_hi = vget_high_s16(filters);
  int16x8_t sum;

  sum = vmulq_lane_s16(s0, filters_lo, 0);
  sum = vmlaq_lane_s16(sum, s1, filters_lo, 1);
  sum = vmlaq_lane_s16(sum, s2, filters_lo, 2);
  sum = vmlaq_lane_s16(sum, s5, filters_hi, 1);
  sum = vmlaq_lane_s16(sum, s6, filters_hi, 2);
  sum = vmlaq_lane_s16(sum, s7, filters_hi, 3);
  sum = vqaddq_s16(sum, vmulq_lane_s16(s3, filters_lo, 3));
  sum = vqaddq_s16(sum, vmulq_lane_s16(s4, filters_hi, 0));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}

static INLINE uint8x8_t scale_filter_8(const uint8x8_t *const s,
                                       const int16x8_t filters) {
  int16x8_t ss[8];

  ss[0] = vreinterpretq_s16_u16(vmovl_u8(s[0]));
  ss[1] = vreinterpretq_s16_u16(vmovl_u8(s[1]));
  ss[2] = vreinterpretq_s16_u16(vmovl_u8(s[2]));
  ss[3] = vreinterpretq_s16_u16(vmovl_u8(s[3]));
  ss[4] = vreinterpretq_s16_u16(vmovl_u8(s[4]));
  ss[5] = vreinterpretq_s16_u16(vmovl_u8(s[5]));
  ss[6] = vreinterpretq_s16_u16(vmovl_u8(s[6]));
  ss[7] = vreinterpretq_s16_u16(vmovl_u8(s[7]));

  return convolve8_8(ss[0], ss[1], ss[2], ss[3], ss[4], ss[5], ss[6], ss[7],
                     filters);
}
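
/* Illustrative usage sketch (added for clarity, not part of the original
 * libvpx header): an 8-tap vertical filter of one 8-pixel-wide output row
 * using scale_filter_8. The helper name is hypothetical; 'src' must expose
 * 8 rows of at least 8 readable bytes each and 'filter' holds the 8 16-bit
 * taps. */
static INLINE void example_vert_filter_8px(const uint8_t *src,
                                           ptrdiff_t src_stride,
                                           const int16_t *filter,
                                           uint8_t *dst) {
  const int16x8_t filters = vld1q_s16(filter);
  uint8x8_t s[8];

  /* Gather the 8 source rows that contribute to this output row. */
  s[0] = vld1_u8(src + 0 * src_stride);
  s[1] = vld1_u8(src + 1 * src_stride);
  s[2] = vld1_u8(src + 2 * src_stride);
  s[3] = vld1_u8(src + 3 * src_stride);
  s[4] = vld1_u8(src + 4 * src_stride);
  s[5] = vld1_u8(src + 5 * src_stride);
  s[6] = vld1_u8(src + 6 * src_stride);
  s[7] = vld1_u8(src + 7 * src_stride);

  vst1_u8(dst, scale_filter_8(s, filters));
}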

#endif  // VPX_VPX_DSP_ARM_VPX_CONVOLVE8_NEON_H_