/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <assert.h>
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
                                                     const __m256i *sum_ref0,
                                                     const __m256i *sum_ref1,
                                                     const __m256i *sum_ref2,
                                                     const __m256i *sum_ref3) {
  // In each 64-bit element of sum_ref-i, the partial sum is held in the low
  // 4 bytes and the high 4 bytes are zero.
  // Merge sum_ref0 with sum_ref1, and sum_ref2 with sum_ref3.
  // Per 128-bit lane: ref0, ref0, ref1, ref1
  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
      _MM_SHUFFLE(2, 0, 2, 0)));
  // Per 128-bit lane: ref2, ref2, ref3, ref3
  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
      _MM_SHUFFLE(2, 0, 2, 0)));

  // Sum adjacent 32-bit integers, leaving one partial sum per reference in
  // each 128-bit lane.
  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);

  // Add the low 128 bits to the high 128 bits to get the four final SADs.
  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
                              _mm256_extractf128_si256(sum_ref0123, 1));

  _mm_storeu_si128((__m128i *)(res), sum);
}
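
// Illustrative dword layout for aggregate_and_store_sum() (low 128-bit lane
// to the left of '|'; aI, bI, cI, dI are the per-lane partial sums of
// ref0..ref3):
//   sum_ref0    = [ a0  0  a1  0 | a2  0  a3  0 ]
//   sum_ref01   = [ a0 a1  b0 b1 | a2 a3  b2 b3 ]  (after the shuffle)
//   sum_ref0123 = [ a0+a1  b0+b1  c0+c1  d0+d1 |
//                   a2+a3  b2+b3  c2+c3  d2+d3 ]   (after the hadd)
//   sum         = [ SAD0  SAD1  SAD2  SAD3 ]       (low lane + high lane)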

static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2, *ref3;

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();
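  // Each _mm256_sad_epu8() result is at most 8 * 255 = 2040 per 64-bit lane,
  // so even for 128x128 blocks the per-lane accumulation below stays well
  // under 2^31: the 32-bit adds cannot overflow, and the high dword of every
  // 64-bit element remains zero.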
  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load src and all refs.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));

      // Sum of the absolute differences between each ref-i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
      // Accumulate each ref-i.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
    ref3 += ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

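// Same as aom_sadMxNx4d_avx2(), but for three references: a zero vector
// stands in for the unused fourth reference so the 4-way aggregation can be
// reused, and res[3] is written as 0.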
static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load src and all refs.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));

      // Sum of the absolute differences between each ref-i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      // Accumulate each ref-i.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
  }
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

#define SADMXN_AVX2(m, n)                                                      \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }                                                                            \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }
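
// For example, SADMXN_AVX2(32, 16) defines aom_sad32x16x4d_avx2() and
// aom_sad32x16x3d_avx2(). A usage sketch (the caller-side buffer names are
// illustrative only):
//   const uint8_t *refs[4] = { ref0, ref1, ref2, ref3 };
//   uint32_t sad[4];
//   aom_sad32x16x4d_avx2(src, src_stride, refs, ref_stride, sad);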

SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SADMXN_AVX2(32, 8)
SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

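// The "skip" variants estimate the SAD from every other row: they run the
// full kernel over half the rows with doubled strides, then double the
// result to approximate the full-height SAD at roughly half the cost.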
#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }

SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

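// For 16-wide blocks a row is only 16 bytes, so the kernels below use
// yy_loadu2_128() (from synonyms_avx2.h) to pack two consecutive rows into
// one 256-bit register and process two rows per iteration; hence the height
// N must be even.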
static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);

    // Sum of the absolute differences between each ref-i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);

    // Accumulate each ref-i.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  const uint8_t *ref0, *ref1, *ref2, *ref3;
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];

  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);

    // Sum of the absolute differences between each ref-i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);

    // Accumulate each ref-i.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
    ref3 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

#define SAD16XNX3_AVX2(n)                                                   \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }
#define SAD16XNX4_AVX2(n)                                                   \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }

SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)

#endif  // !CONFIG_REALTIME_ONLY

#define SAD_SKIP_16XN_AVX2(n)                                                 \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,      \
                                     const uint8_t *const ref[4],             \
                                     int ref_stride, uint32_t res[4]) {       \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \
                        res);                                                 \
    res[0] <<= 1;                                                             \
    res[1] <<= 1;                                                             \
    res[2] <<= 1;                                                             \
    res[3] <<= 1;                                                             \
  }

SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
#endif  // !CONFIG_REALTIME_ONLY