/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"

#define ROUND_SHIFT (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS)

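// Unweighted compound averaging for 12-bit input: halving-add the two
// convolve buffers, remove the compound round offset, then narrow with
// rounding back to pixel precision and clamp to the bit-depth maximum.
// The final shift is ROUND_SHIFT - 2 because the 12-bit convolve path
// rounds two extra bits in its first stage.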
static INLINE void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
                                           int src_stride, uint16_t *dst_ptr,
                                           int dst_stride, int w, int h,
                                           ConvolveParams *conv_params,
                                           const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT - 2),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT - 2));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

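// As above, but for 8- and 10-bit input, where the full ROUND_SHIFT applies.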
static INLINE void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
                                        uint16_t *dst_ptr, int dst_stride,
                                        int w, int h,
                                        ConvolveParams *conv_params,
                                        const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

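// Distance-weighted compound averaging for 12-bit input: blend the two
// convolve buffers with the fwd_offset/bck_offset weights, drop the
// DIST_PRECISION_BITS of weighting precision, remove the compound round
// offset, then narrow with rounding and clamp to the bit-depth maximum.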
static INLINE void highbd_12_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT - 2),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT - 2));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}

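// As above, but for 8- and 10-bit input, where the full ROUND_SHIFT applies.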
static INLINE void highbd_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}