/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"

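// Number of bits needed to descale the compound intermediate back to pixel
// precision: the two filter passes add 2 * FILTER_BITS of precision, of which
// ROUND0_BITS and COMPOUND_ROUND1_BITS have already been rounded away.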
#define ROUND_SHIFT (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS)

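// Unweighted compound average for 12-bit input: average the second prediction
// (src_ptr) with the compound reference buffer (conv_params->dst), remove the
// compound offset, round back down to pixel precision and clamp to
// [0, (1 << bd) - 1]. The shift is ROUND_SHIFT - 2 because, for 12-bit input,
// round_0 is raised by 2 when the convolve parameters are set up (to keep the
// intermediate within 16 bits), leaving two fewer bits to remove here.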
static INLINE void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
                                           int src_stride, uint16_t *dst_ptr,
                                           int dst_stride, int w, int h,
                                           ConvolveParams *conv_params,
                                           const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT - 2),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT - 2));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

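// Unweighted compound average for 8- and 10-bit input, where the full
// ROUND_SHIFT is removed when converting back to pixel precision.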
static INLINE void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
                                        uint16_t *dst_ptr, int dst_stride,
                                        int w, int h,
                                        ConvolveParams *conv_params,
                                        const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

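// Distance-weighted compound average for 12-bit input: each output pixel is
// (fwd_offset * ref + bck_offset * src) >> DIST_PRECISION_BITS, with the two
// weights summing to (1 << DIST_PRECISION_BITS), followed by the same offset
// removal, rounding and clamping as the unweighted path above.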
static INLINE void highbd_12_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT - 2),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT - 2));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}

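// Distance-weighted compound average for 8- and 10-bit input; identical to
// the 12-bit variant above except for the final descaling shift.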
static INLINE void highbd_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}
