/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <assert.h>
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"

static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
                                                     const __m256i *sum_ref0,
                                                     const __m256i *sum_ref1,
                                                     const __m256i *sum_ref2,
                                                     const __m256i *sum_ref3) {
  // In each 64-bit lane of sum_ref<i>, the partial SAD is held in the low
  // 32 bits and the upper 32 bits are zero.
  // Merge sum_ref0 with sum_ref1, and sum_ref2 with sum_ref3.
  // 0, 0, 1, 1
  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
      _MM_SHUFFLE(2, 0, 2, 0)));
  // 2, 2, 3, 3
  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
      _MM_SHUFFLE(2, 0, 2, 0)));
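  // Note: _mm256_shuffle_ps is used on integer data because AVX2 has no
  // two-source 32-bit integer shuffle with an immediate selector; the
  // ps/si256 casts are bitwise reinterpretations and emit no instructions.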

  // Sum adjacent 32-bit integers; each 128-bit lane then holds the four
  // per-ref partial sums in order [ref0, ref1, ref2, ref3].
  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);

  // Add the low 128 bits to the high 128 bits.
  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
                              _mm256_extractf128_si256(sum_ref0123, 1));

  _mm_storeu_si128((__m128i *)(res), sum);
}

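// In scalar terms (an illustrative sketch, not library code): each
// accumulator holds four per-lane partial sums, all of which fit in 32 bits
// for the block sizes in this file, so the reduction above computes
//   res[i] = sum_ref<i>[lane 0] + sum_ref<i>[lane 1] +
//            sum_ref<i>[lane 2] + sum_ref<i>[lane 3]
// for all four accumulators at once.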
static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2, *ref3;

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref row.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));

      // Sum of absolute differences between each ref row and the src row.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
      // Accumulate the partial sums for each ref.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
    ref3 += ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

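// Scalar reference for the kernel above (an illustrative sketch, not the
// library API; M is a multiple of 32 for every instantiation below):
//
//   for (int k = 0; k < 4; k++) {
//     uint32_t sad = 0;
//     for (int i = 0; i < N; i++)
//       for (int j = 0; j < M; j++)
//         sad += abs(src[i * src_stride + j] - ref[k][i * ref_stride + j]);
//     res[k] = sad;
//   }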
static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref row.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));

      // Sum of absolute differences between each ref row and the src row.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      // Accumulate the partial sums for each ref.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
  }
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

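// The x3d kernel above compares only three references; feeding the zero
// vector into aggregate_and_store_sum() means res[3] is written as 0.
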
#define SADMXN_AVX2(m, n)                                                      \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }                                                                            \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }

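// Each instantiation below defines a pair of entry points; for example,
// SADMXN_AVX2(32, 8) expands to aom_sad32x8x4d_avx2() and
// aom_sad32x8x3d_avx2().
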
SADMXN_AVX2(32, 8)
SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

SADMXN_AVX2(64, 16)
SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)

#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }

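// SAD_SKIP_MXN_AVX2 estimates the full-block SAD from the even rows only:
// doubling both strides visits every other row, halving n keeps the row
// count consistent, and the final left shift scales the half-height SAD
// back up to approximate the full-block value.
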
SAD_SKIP_MXN_AVX2(32, 8)
SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

SAD_SKIP_MXN_AVX2(64, 16)
SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)

static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

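  // Two rows are processed per iteration: yy_loadu2_128(p + stride, p)
  // (a helper from synonyms_avx2.h) packs row i+1 into the upper 128 bits
  // and row i into the lower 128 bits of one 256-bit register, so a
  // 16-byte-wide block still fills a full AVX2 register.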
  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);

    // Sum of absolute differences between each ref row pair and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);

    // Accumulate the partial sums for each ref.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  const uint8_t *ref0, *ref1, *ref2, *ref3;
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];

  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);

    // Sum of absolute differences between each ref row pair and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);

    // Accumulate the partial sums for each ref.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
    ref3 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

#define SAD16XNX3_AVX2(n)                                                   \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }
#define SAD16XNX4_AVX2(n)                                                   \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }

SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)

#endif  // !CONFIG_REALTIME_ONLY

#define SAD_SKIP_16XN_AVX2(n)                                                 \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,      \
                                     const uint8_t *const ref[4],             \
                                     int ref_stride, uint32_t res[4]) {       \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \
                        res);                                                 \
    res[0] <<= 1;                                                             \
    res[1] <<= 1;                                                             \
    res[2] <<= 1;                                                             \
    res[3] <<= 1;                                                             \
  }

SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)
SAD_SKIP_16XN_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
SAD_SKIP_16XN_AVX2(4)
#endif  // !CONFIG_REALTIME_ONLY