• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 
13 #include "config/aom_dsp_rtcd.h"
14 #include "aom_dsp/arm/mem_neon.h"
15 #include "aom_dsp/arm/sum_neon.h"
16 
sse_16x1_neon_dotprod(const uint8_t * src,const uint8_t * ref,uint32x4_t * sse)17 static INLINE void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
18                                          uint32x4_t *sse) {
19   uint8x16_t s = vld1q_u8(src);
20   uint8x16_t r = vld1q_u8(ref);
21 
22   uint8x16_t abs_diff = vabdq_u8(s, r);
23 
24   *sse = vdotq_u32(*sse, abs_diff, abs_diff);
25 }
26 
sse_8x1_neon_dotprod(const uint8_t * src,const uint8_t * ref,uint32x2_t * sse)27 static INLINE void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
28                                         uint32x2_t *sse) {
29   uint8x8_t s = vld1_u8(src);
30   uint8x8_t r = vld1_u8(ref);
31 
32   uint8x8_t abs_diff = vabd_u8(s, r);
33 
34   *sse = vdot_u32(*sse, abs_diff, abs_diff);
35 }
36 
sse_4x2_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,uint32x2_t * sse)37 static INLINE void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
38                                         const uint8_t *ref, int ref_stride,
39                                         uint32x2_t *sse) {
40   uint8x8_t s = load_unaligned_u8(src, src_stride);
41   uint8x8_t r = load_unaligned_u8(ref, ref_stride);
42 
43   uint8x8_t abs_diff = vabd_u8(s, r);
44 
45   *sse = vdot_u32(*sse, abs_diff, abs_diff);
46 }
47 
sse_wxh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int width,int height)48 static INLINE uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
49                                             const uint8_t *ref, int ref_stride,
50                                             int width, int height) {
51   uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
52 
53   if ((width & 0x07) && ((width & 0x07) < 5)) {
54     int i = height;
55     do {
56       int j = 0;
57       do {
58         sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
59         sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
60                              &sse[1]);
61         j += 8;
62       } while (j + 4 < width);
63 
64       sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]);
65       src += 2 * src_stride;
66       ref += 2 * ref_stride;
67       i -= 2;
68     } while (i != 0);
69   } else {
70     int i = height;
71     do {
72       int j = 0;
73       do {
74         sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
75         sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
76                              &sse[1]);
77         j += 8;
78       } while (j < width);
79 
80       src += 2 * src_stride;
81       ref += 2 * ref_stride;
82       i -= 2;
83     } while (i != 0);
84   }
85   return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
86 }
87 
sse_128xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)88 static INLINE uint32_t sse_128xh_neon_dotprod(const uint8_t *src,
89                                               int src_stride,
90                                               const uint8_t *ref,
91                                               int ref_stride, int height) {
92   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
93 
94   int i = height;
95   do {
96     sse_16x1_neon_dotprod(src, ref, &sse[0]);
97     sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
98     sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
99     sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
100     sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]);
101     sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]);
102     sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]);
103     sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]);
104 
105     src += src_stride;
106     ref += ref_stride;
107   } while (--i != 0);
108 
109   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
110 }
111 
sse_64xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)112 static INLINE uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
113                                              const uint8_t *ref, int ref_stride,
114                                              int height) {
115   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
116 
117   int i = height;
118   do {
119     sse_16x1_neon_dotprod(src, ref, &sse[0]);
120     sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
121     sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
122     sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
123 
124     src += src_stride;
125     ref += ref_stride;
126   } while (--i != 0);
127 
128   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
129 }
130 
sse_32xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)131 static INLINE uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
132                                              const uint8_t *ref, int ref_stride,
133                                              int height) {
134   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
135 
136   int i = height;
137   do {
138     sse_16x1_neon_dotprod(src, ref, &sse[0]);
139     sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
140 
141     src += src_stride;
142     ref += ref_stride;
143   } while (--i != 0);
144 
145   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
146 }
147 
sse_16xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)148 static INLINE uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
149                                              const uint8_t *ref, int ref_stride,
150                                              int height) {
151   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
152 
153   int i = height;
154   do {
155     sse_16x1_neon_dotprod(src, ref, &sse[0]);
156     src += src_stride;
157     ref += ref_stride;
158     sse_16x1_neon_dotprod(src, ref, &sse[1]);
159     src += src_stride;
160     ref += ref_stride;
161     i -= 2;
162   } while (i != 0);
163 
164   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
165 }
166 
sse_8xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)167 static INLINE uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
168                                             const uint8_t *ref, int ref_stride,
169                                             int height) {
170   uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
171 
172   int i = height;
173   do {
174     sse_8x1_neon_dotprod(src, ref, &sse[0]);
175     src += src_stride;
176     ref += ref_stride;
177     sse_8x1_neon_dotprod(src, ref, &sse[1]);
178     src += src_stride;
179     ref += ref_stride;
180     i -= 2;
181   } while (i != 0);
182 
183   return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
184 }
185 
sse_4xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)186 static INLINE uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
187                                             const uint8_t *ref, int ref_stride,
188                                             int height) {
189   uint32x2_t sse = vdup_n_u32(0);
190 
191   int i = height;
192   do {
193     sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse);
194 
195     src += 2 * src_stride;
196     ref += 2 * ref_stride;
197     i -= 2;
198   } while (i != 0);
199 
200   return horizontal_add_u32x2(sse);
201 }
202 
// Returns the sum of squared differences between the width x height blocks at
// src and ref.
//
// Widths 4, 8, 16, 32, 64 and 128 are routed to their dedicated kernels; any
// other width goes through the generic width-by-height kernel. Every kernel
// produces a 32-bit unsigned total, which is widened to int64_t on return.
int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride,
                             const uint8_t *ref, int ref_stride, int width,
                             int height) {
  uint32_t total;

  switch (width) {
    case 4:
      total = sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 8:
      total = sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 16:
      total = sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 32:
      total = sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 64:
      total = sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 128:
      total = sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    default:
      total = sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width,
                                   height);
      break;
  }

  return total;
}
224