1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <immintrin.h>
13
14 #include "config/aom_dsp_rtcd.h"
15
sad32x32(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride)16 static unsigned int sad32x32(const uint8_t *src_ptr, int src_stride,
17 const uint8_t *ref_ptr, int ref_stride) {
18 __m256i s1, s2, r1, r2;
19 __m256i sum = _mm256_setzero_si256();
20 __m128i sum_i128;
21 int i;
22
23 for (i = 0; i < 16; ++i) {
24 r1 = _mm256_loadu_si256((__m256i const *)ref_ptr);
25 r2 = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride));
26 s1 = _mm256_sad_epu8(r1, _mm256_loadu_si256((__m256i const *)src_ptr));
27 s2 = _mm256_sad_epu8(
28 r2, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));
29 sum = _mm256_add_epi32(sum, _mm256_add_epi32(s1, s2));
30 ref_ptr += ref_stride << 1;
31 src_ptr += src_stride << 1;
32 }
33
34 sum = _mm256_add_epi32(sum, _mm256_srli_si256(sum, 8));
35 sum_i128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 1),
36 _mm256_castsi256_si128(sum));
37 return (unsigned int)_mm_cvtsi128_si32(sum_i128);
38 }
39
// SAD of a 64x32 block, computed as the sum of its left and right 32x32
// halves (the right half starts 32 bytes into each row).
static unsigned int sad64x32(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  const uint32_t left = sad32x32(src_ptr, src_stride, ref_ptr, ref_stride);
  const uint32_t right =
      sad32x32(src_ptr + 32, src_stride, ref_ptr + 32, ref_stride);
  return left + right;
}
49
// SAD of a 64x64 block, computed as two vertically stacked 64x32 halves.
//
// The lower half starts 32 rows down. The offset is computed as
// `32 * stride` rather than `stride << 5`, because left-shifting a negative
// int stride is undefined behaviour in C.
static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += 32 * src_stride;
  ref_ptr += 32 * ref_stride;
  sum += sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
58
// SAD of a 128x64 block, computed as the sum of its left and right 64x64
// halves (the right half starts 64 bytes into each row).
unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  const uint32_t left = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  const uint32_t right =
      sad64x64(src_ptr + 64, src_stride, ref_ptr + 64, ref_stride);
  return left + right;
}
68
// SAD of a 64x128 block, computed as two vertically stacked 64x64 halves.
//
// The lower half starts 64 rows down. The offset is computed as
// `64 * stride` rather than `stride << 6`, because left-shifting a negative
// int stride is undefined behaviour in C.
unsigned int aom_sad64x128_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += 64 * src_stride;
  ref_ptr += 64 * ref_stride;
  sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
77
// SAD of a 128x128 block, computed as two vertically stacked 128x64 halves.
//
// The lower half starts 64 rows down. The offset is computed as
// `64 * stride` rather than `stride << 6`, because left-shifting a negative
// int stride is undefined behaviour in C.
unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride,
                                 const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += 64 * src_stride;
  ref_ptr += 64 * ref_stride;
  sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
86
// Row-skipping SAD for a 128x64 block.
//
// Both strides are doubled so that only every other row (32 of the 64) is
// visited, using two side-by-side 64x32 kernels. The sampled SAD is then
// multiplied by 2 to approximate the SAD of the full block.
unsigned int aom_sad_skip_128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                      const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = src_stride * 2;
  const int ref_step = ref_stride * 2;
  uint32_t sampled = sad64x32(src_ptr, src_step, ref_ptr, ref_step);
  sampled += sad64x32(src_ptr + 64, src_step, ref_ptr + 64, ref_step);
  return sampled * 2;
}
96
// Row-skipping SAD for a 64x128 block.
//
// With both strides doubled, the 64x64 kernel visits every other row (64 of
// the 128). The result is multiplied by 2 to approximate the full-block SAD.
unsigned int aom_sad_skip_64x128_avx2(const uint8_t *src_ptr, int src_stride,
                                      const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = src_stride * 2;
  const int ref_step = ref_stride * 2;
  const uint32_t even_rows_sad = sad64x64(src_ptr, src_step, ref_ptr, ref_step);
  return even_rows_sad * 2;
}
103
// Row-skipping SAD for a 128x128 block.
//
// With both strides doubled, the 128x64 kernel visits every other row (64 of
// the 128). The result is multiplied by 2 to approximate the full-block SAD.
unsigned int aom_sad_skip_128x128_avx2(const uint8_t *src_ptr, int src_stride,
                                       const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = src_stride * 2;
  const int ref_step = ref_stride * 2;
  const uint32_t even_rows_sad =
      aom_sad128x64_avx2(src_ptr, src_step, ref_ptr, ref_step);
  return even_rows_sad * 2;
}
110
sad_w64_avg_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,const int h,const uint8_t * second_pred,const int second_pred_stride)111 static unsigned int sad_w64_avg_avx2(const uint8_t *src_ptr, int src_stride,
112 const uint8_t *ref_ptr, int ref_stride,
113 const int h, const uint8_t *second_pred,
114 const int second_pred_stride) {
115 int i;
116 __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
117 __m256i sum_sad = _mm256_setzero_si256();
118 __m256i sum_sad_h;
119 __m128i sum_sad128;
120 for (i = 0; i < h; i++) {
121 ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
122 ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));
123 ref1_reg = _mm256_avg_epu8(
124 ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));
125 ref2_reg = _mm256_avg_epu8(
126 ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32)));
127 sad1_reg =
128 _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
129 sad2_reg = _mm256_sad_epu8(
130 ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));
131 sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
132 ref_ptr += ref_stride;
133 src_ptr += src_stride;
134 second_pred += second_pred_stride;
135 }
136 sum_sad_h = _mm256_srli_si256(sum_sad, 8);
137 sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
138 sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
139 sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
140 return (unsigned int)_mm_cvtsi128_si32(sum_sad128);
141 }
142
// SAD between src and the rounded average of ref and second_pred over a
// 64x128 block, computed as two vertically stacked 64x64 halves.
//
// second_pred is read with a fixed stride of 64 (the block width), so its
// lower half starts 64 * 64 bytes in. The src/ref row offsets are computed as
// `64 * stride` rather than `stride << 6`, because left-shifting a negative
// int stride is undefined behaviour in C.
unsigned int aom_sad64x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  uint32_t sum = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                                  second_pred, 64);
  src_ptr += 64 * src_stride;
  ref_ptr += 64 * ref_stride;
  second_pred += 64 * 64;
  sum += sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                          second_pred, 64);
  return sum;
}
155
// SAD between src and the rounded average of ref and second_pred over a
// 128x64 block, computed as the sum of its left and right 64-wide halves.
//
// second_pred is read with a fixed stride of 128 (the block width). Its right
// half, like those of src and ref, begins 64 bytes into each row.
unsigned int aom_sad128x64_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  const uint32_t left = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr,
                                         ref_stride, 64, second_pred, 128);
  const uint32_t right =
      sad_w64_avg_avx2(src_ptr + 64, src_stride, ref_ptr + 64, ref_stride, 64,
                       second_pred + 64, 128);
  return left + right;
}
169
// SAD between src and the rounded average of ref and second_pred over a
// 128x128 block, computed as two vertically stacked 128x64 halves.
//
// second_pred is consumed with a stride of 128 by the 128x64 kernel, so its
// lower half starts 128 * 64 bytes in. The src/ref row offsets are computed as
// `64 * stride` rather than `stride << 6`, because left-shifting a negative
// int stride is undefined behaviour in C.
unsigned int aom_sad128x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *ref_ptr, int ref_stride,
                                     const uint8_t *second_pred) {
  uint32_t sum = aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr,
                                        ref_stride, second_pred);
  src_ptr += 64 * src_stride;
  ref_ptr += 64 * ref_stride;
  second_pred += 128 * 64;
  sum += aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride,
                                second_pred);
  return sum;
}
182