/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <assert.h>
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"

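// Reduces four per-reference accumulators to four 32-bit totals.
// _mm256_sad_epu8() leaves each partial sum in the low bits of a 64-bit
// lane, so every sum_ref-i holds four partial sums that must be added
// together to produce res[i].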
static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
                                                     const __m256i *sum_ref0,
                                                     const __m256i *sum_ref1,
                                                     const __m256i *sum_ref2,
                                                     const __m256i *sum_ref3) {
  // In each 64-bit lane of sum_ref-i the result sits in the low 4 bytes;
  // the high 4 bytes are zero.
  // Merge sum_ref0 with sum_ref1, and sum_ref2 with sum_ref3, by keeping
  // the even (low) 32-bit element of every 64-bit lane.
  // 0, 0, 1, 1
  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
      _MM_SHUFFLE(2, 0, 2, 0)));
  // 2, 2, 3, 3
  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
      _MM_SHUFFLE(2, 0, 2, 0)));

  // Sum adjacent 32-bit integers.
  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);

  // Add the low 128 bits to the high 128 bits.
  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
                              _mm256_extractf128_si256(sum_ref0123, 1));

  _mm_storeu_si128((__m128i *)(res), sum);
}

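// Computes the SAD of one MxN source block against four reference blocks in
// a single pass. M must be a multiple of 32 so every row is covered by
// whole 256-bit loads.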
static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2, *ref3;

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));

      // Sum of the absolute differences between each ref and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
      // Accumulate each ref's partial sums.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
    ref3 += ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

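// Three-reference variant: identical to the 4d kernel, but a zero vector is
// passed as the fourth accumulator, so res[3] is written as 0.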
static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));

      // Sum of the absolute differences between each ref and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      // Accumulate each ref's partial sums.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
  }
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

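// Instantiates the public x4d and x3d entry points for an m-wide, n-tall
// block size.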
#define SADMXN_AVX2(m, n)                                                      \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,         \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                          \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);          \
  }                                                                           \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,         \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                          \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);          \
  }

SADMXN_AVX2(32, 8)
SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

SADMXN_AVX2(64, 16)
SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)

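// The "skip" variants evaluate only every other row: the kernel runs over
// half the rows with doubled strides, and each result is doubled to
// approximate the full-height SAD.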
#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }

SAD_SKIP_MXN_AVX2(32, 8)
SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

SAD_SKIP_MXN_AVX2(64, 16)
SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)

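// 16-wide blocks fill only half a 256-bit register per row, so two rows are
// packed into each register with yy_loadu2_128() and processed per loop
// iteration; N must therefore be even.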
static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);

    // Sum of the absolute differences between each ref and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);

    // Accumulate each ref's partial sums.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

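// Four-reference counterpart of aom_sad16xNx3d_avx2() above.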
static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  const uint8_t *ref0, *ref1, *ref2, *ref3;
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];

  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);

    // Sum of the absolute differences between each ref and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);

    // Accumulate each ref's partial sums.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
    ref3 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

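// Instantiates the public 16-wide x3d and x4d entry points.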
#define SAD16XNX3_AVX2(n)                                                    \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,           \
                               const uint8_t *const ref[4], int ref_stride,  \
                               uint32_t res[4]) {                            \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);           \
  }
#define SAD16XNX4_AVX2(n)                                                    \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,           \
                               const uint8_t *const ref[4], int ref_stride,  \
                               uint32_t res[4]) {                            \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);           \
  }

SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)
#endif  // !CONFIG_REALTIME_ONLY

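// 16-wide skip variant: every other row is evaluated and the result is
// doubled, as with SAD_SKIP_MXN_AVX2 above.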
#define SAD_SKIP_16XN_AVX2(n)                                                \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,     \
                                     const uint8_t *const ref[4],            \
                                     int ref_stride, uint32_t res[4]) {      \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \
                        res);                                                \
    res[0] <<= 1;                                                            \
    res[1] <<= 1;                                                            \
    res[2] <<= 1;                                                            \
    res[3] <<= 1;                                                            \
  }

SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)
SAD_SKIP_16XN_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
SAD_SKIP_16XN_AVX2(4)
#endif  // !CONFIG_REALTIME_ONLY