/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <emmintrin.h>
#include "aom_dsp/x86/synonyms.h"

#include "config/av1_rtcd.h"
#include "av1/common/restoration.h"
#include "av1/encoder/pickrst.h"

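// Accumulate eight int32 partial sums: shuffle 16 bytes of src so adjacent
// pixel pairs are interleaved, widen them to 16 bits, multiply-add them
// against the broadcast pixel pair in *kl, and add the results into
// dst[0..7].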
static INLINE void acc_stat_sse41(int32_t *dst, const uint8_t *src,
                                  const __m128i *shuffle, const __m128i *kl) {
  const __m128i s = _mm_shuffle_epi8(xx_loadu_128(src), *shuffle);
  const __m128i d0 = _mm_madd_epi16(*kl, _mm_cvtepu8_epi16(s));
  const __m128i d1 =
      _mm_madd_epi16(*kl, _mm_cvtepu8_epi16(_mm_srli_si128(s, 8)));
  const __m128i dst0 = xx_loadu_128(dst);
  const __m128i dst1 = xx_loadu_128(dst + 4);
  const __m128i r0 = _mm_add_epi32(dst0, d0);
  const __m128i r1 = _mm_add_epi32(dst1, d1);
  xx_storeu_128(dst, r0);
  xx_storeu_128(dst + 4, r1);
}

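// Accumulate the Wiener filter statistics (sumX, sumY, M_int and H_int) for
// one row of the 7x7 window, processing two horizontally adjacent pixels per
// iteration.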
static INLINE void acc_stat_win7_one_line_sse4_1(
    const uint8_t *dgd, const uint8_t *src, int h_start, int h_end,
    int dgd_stride, const __m128i *shuffle, int32_t *sumX,
    int32_t sumY[WIENER_WIN][WIENER_WIN], int32_t M_int[WIENER_WIN][WIENER_WIN],
    int32_t H_int[WIENER_WIN2][WIENER_WIN * 8]) {
  const int wiener_win = 7;
  int j, k, l;
  for (j = h_start; j < h_end; j += 2) {
    const uint8_t *dgd_ij = dgd + j;
    const uint8_t X1 = src[j];
    const uint8_t X2 = src[j + 1];
    *sumX += X1 + X2;
    for (k = 0; k < wiener_win; k++) {
      const uint8_t *dgd_ijk = dgd_ij + k * dgd_stride;
      for (l = 0; l < wiener_win; l++) {
        int32_t *H_ = &H_int[(l * wiener_win + k)][0];
        const uint8_t D1 = dgd_ijk[l];
        const uint8_t D2 = dgd_ijk[l + 1];
        sumY[k][l] += D1 + D2;
        M_int[k][l] += D1 * X1 + D2 * X2;

        const __m128i kl =
            _mm_cvtepu8_epi16(_mm_set1_epi16(*((uint16_t *)(dgd_ijk + l))));
        acc_stat_sse41(H_ + 0 * 8, dgd_ij + 0 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 1 * 8, dgd_ij + 1 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 2 * 8, dgd_ij + 2 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 3 * 8, dgd_ij + 3 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 4 * 8, dgd_ij + 4 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 5 * 8, dgd_ij + 5 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 6 * 8, dgd_ij + 6 * dgd_stride, shuffle, &kl);
      }
    }
  }
}

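// Compute the Wiener filter statistics M and H for the 7x7 (luma) window
// over the given region of an 8-bit frame. The 32-bit accumulators are
// flushed into 64-bit totals every 64 rows to avoid overflow.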
static INLINE void compute_stats_win7_opt_sse4_1(
    const uint8_t *dgd, const uint8_t *src, int h_start, int h_end, int v_start,
    int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H) {
  int i, j, k, l, m, n;
  const int wiener_win = WIENER_WIN;
  const int pixel_count = (h_end - h_start) * (v_end - v_start);
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  const uint8_t avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  int32_t M_int32[WIENER_WIN][WIENER_WIN] = { { 0 } };
  int64_t M_int64[WIENER_WIN][WIENER_WIN] = { { 0 } };
  int32_t H_int32[WIENER_WIN2][WIENER_WIN * 8] = { { 0 } };
  int64_t H_int64[WIENER_WIN2][WIENER_WIN * 8] = { { 0 } };
  int32_t sumY[WIENER_WIN][WIENER_WIN] = { { 0 } };
  int32_t sumX = 0;
  const uint8_t *dgd_win = dgd - wiener_halfwin * dgd_stride - wiener_halfwin;

  const __m128i shuffle = xx_loadu_128(g_shuffle_stats_data);
  for (j = v_start; j < v_end; j += 64) {
    const int vert_end = AOMMIN(64, v_end - j) + j;
    for (i = j; i < vert_end; i++) {
      acc_stat_win7_one_line_sse4_1(
          dgd_win + i * dgd_stride, src + i * src_stride, h_start, h_end,
          dgd_stride, &shuffle, &sumX, sumY, M_int32, H_int32);
    }
    for (k = 0; k < wiener_win; ++k) {
      for (l = 0; l < wiener_win; ++l) {
        M_int64[k][l] += M_int32[k][l];
        M_int32[k][l] = 0;
      }
    }
    for (k = 0; k < WIENER_WIN2; ++k) {
      for (l = 0; l < WIENER_WIN * 8; ++l) {
        H_int64[k][l] += H_int32[k][l];
        H_int32[k][l] = 0;
      }
    }
  }

  const int64_t avg_square_sum = (int64_t)avg * (int64_t)avg * pixel_count;
  for (k = 0; k < wiener_win; k++) {
    for (l = 0; l < wiener_win; l++) {
      const int32_t idx0 = l * wiener_win + k;
      M[idx0] =
          M_int64[k][l] + (avg_square_sum - (int64_t)avg * (sumX + sumY[k][l]));
      int64_t *H_ = H + idx0 * wiener_win2;
      int64_t *H_int_ = &H_int64[idx0][0];
      for (m = 0; m < wiener_win; m++) {
        for (n = 0; n < wiener_win; n++) {
          H_[m * wiener_win + n] = H_int_[n * 8 + m] + avg_square_sum -
                                   (int64_t)avg * (sumY[k][l] + sumY[n][m]);
        }
      }
    }
  }
}

#if CONFIG_AV1_HIGHBITDEPTH
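// High bit-depth counterpart of acc_stat_sse41: accumulate eight 64-bit
// partial sums from 16-bit pixels, multiplying the shuffled dgd window by
// the broadcast pixel pair in *dgd_ijkl.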
static INLINE void acc_stat_highbd_sse41(int64_t *dst, const uint16_t *dgd,
                                         const __m128i *shuffle,
                                         const __m128i *dgd_ijkl) {
  // Load 256 bits from dgd in two chunks
  const __m128i s0l = xx_loadu_128(dgd);
  const __m128i s0h = xx_loadu_128(dgd + 4);
  // s0l = [7 6 5 4 3 2 1 0] as u16 values (dgd indices)
  // s0h = [11 10 9 8 7 6 5 4] as u16 values (dgd indices)
  // (Slightly strange order so we can apply the same shuffle to both halves)

  // Shuffle the u16 values in each half (actually using 8-bit shuffle mask)
  const __m128i s1l = _mm_shuffle_epi8(s0l, *shuffle);
  const __m128i s1h = _mm_shuffle_epi8(s0h, *shuffle);
  // s1l = [4 3 3 2 2 1 1 0] as u16 values (dgd indices)
  // s1h = [8 7 7 6 6 5 5 4] as u16 values (dgd indices)

  // Multiply s1 by dgd_ijkl resulting in 8x u32 values
  // Horizontally add pairs of u32 resulting in 4x u32
  const __m128i dl = _mm_madd_epi16(*dgd_ijkl, s1l);
  const __m128i dh = _mm_madd_epi16(*dgd_ijkl, s1h);
  // dl = [d c b a] as u32 values
  // dh = [h g f e] as u32 values

  // Add these 8x u32 results on to dst in four parts
  const __m128i dll = _mm_cvtepu32_epi64(dl);
  const __m128i dlh = _mm_cvtepu32_epi64(_mm_srli_si128(dl, 8));
  const __m128i dhl = _mm_cvtepu32_epi64(dh);
  const __m128i dhh = _mm_cvtepu32_epi64(_mm_srli_si128(dh, 8));
  // dll = [b a] as u64 values, etc.

  const __m128i rll = _mm_add_epi64(xx_loadu_128(dst), dll);
  xx_storeu_128(dst, rll);
  const __m128i rlh = _mm_add_epi64(xx_loadu_128(dst + 2), dlh);
  xx_storeu_128(dst + 2, rlh);
  const __m128i rhl = _mm_add_epi64(xx_loadu_128(dst + 4), dhl);
  xx_storeu_128(dst + 4, rhl);
  const __m128i rhh = _mm_add_epi64(xx_loadu_128(dst + 6), dhh);
  xx_storeu_128(dst + 6, rhh);
}

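// Accumulate the high bit-depth Wiener filter statistics for one row of the
// 7x7 window, two pixels at a time.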
static INLINE void acc_stat_highbd_win7_one_line_sse4_1(
    const uint16_t *dgd, const uint16_t *src, int h_start, int h_end,
    int dgd_stride, const __m128i *shuffle, int32_t *sumX,
    int32_t sumY[WIENER_WIN][WIENER_WIN], int64_t M_int[WIENER_WIN][WIENER_WIN],
    int64_t H_int[WIENER_WIN2][WIENER_WIN * 8]) {
  int j, k, l;
  const int wiener_win = WIENER_WIN;
  for (j = h_start; j < h_end; j += 2) {
    const uint16_t X1 = src[j];
    const uint16_t X2 = src[j + 1];
    *sumX += X1 + X2;
    const uint16_t *dgd_ij = dgd + j;
    for (k = 0; k < wiener_win; k++) {
      const uint16_t *dgd_ijk = dgd_ij + k * dgd_stride;
      for (l = 0; l < wiener_win; l++) {
        int64_t *H_ = &H_int[(l * wiener_win + k)][0];
        const uint16_t D1 = dgd_ijk[l];
        const uint16_t D2 = dgd_ijk[l + 1];
        sumY[k][l] += D1 + D2;
        M_int[k][l] += D1 * X1 + D2 * X2;

        // Load two u16 values from dgd as a single u32
        // Then broadcast to 4x u32 slots of a 128
        const __m128i dgd_ijkl = _mm_set1_epi32(*((uint32_t *)(dgd_ijk + l)));
        // dgd_ijkl = [y x y x y x y x] as u16

        acc_stat_highbd_sse41(H_ + 0 * 8, dgd_ij + 0 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 1 * 8, dgd_ij + 1 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 2 * 8, dgd_ij + 2 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 3 * 8, dgd_ij + 3 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 4 * 8, dgd_ij + 4 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 5 * 8, dgd_ij + 5 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 6 * 8, dgd_ij + 6 * dgd_stride, shuffle,
                              &dgd_ijkl);
      }
    }
  }
}

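// Compute the Wiener filter statistics M and H for the 7x7 window of a high
// bit-depth frame. The accumulated statistics are divided by a bit-depth
// dependent factor (1, 4 or 16 for 8, 10 and 12 bit) before being stored.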
static INLINE void compute_stats_highbd_win7_opt_sse4_1(
    const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end,
    int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M,
    int64_t *H, aom_bit_depth_t bit_depth) {
  int i, j, k, l, m, n;
  const int wiener_win = WIENER_WIN;
  const int pixel_count = (h_end - h_start) * (v_end - v_start);
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const uint16_t avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  int64_t M_int[WIENER_WIN][WIENER_WIN] = { { 0 } };
  int64_t H_int[WIENER_WIN2][WIENER_WIN * 8] = { { 0 } };
  int32_t sumY[WIENER_WIN][WIENER_WIN] = { { 0 } };
  int32_t sumX = 0;
  const uint16_t *dgd_win = dgd - wiener_halfwin * dgd_stride - wiener_halfwin;

  // Load just half of the 256-bit shuffle control used for the AVX2 version
  const __m128i shuffle = xx_loadu_128(g_shuffle_stats_highbd_data);
  for (j = v_start; j < v_end; j += 64) {
    const int vert_end = AOMMIN(64, v_end - j) + j;
    for (i = j; i < vert_end; i++) {
      acc_stat_highbd_win7_one_line_sse4_1(
          dgd_win + i * dgd_stride, src + i * src_stride, h_start, h_end,
          dgd_stride, &shuffle, &sumX, sumY, M_int, H_int);
    }
  }

  uint8_t bit_depth_divider = 1;
  if (bit_depth == AOM_BITS_12)
    bit_depth_divider = 16;
  else if (bit_depth == AOM_BITS_10)
    bit_depth_divider = 4;

  const int64_t avg_square_sum = (int64_t)avg * (int64_t)avg * pixel_count;
  for (k = 0; k < wiener_win; k++) {
    for (l = 0; l < wiener_win; l++) {
      const int32_t idx0 = l * wiener_win + k;
      M[idx0] = (M_int[k][l] +
                 (avg_square_sum - (int64_t)avg * (sumX + sumY[k][l]))) /
                bit_depth_divider;
      int64_t *H_ = H + idx0 * wiener_win2;
      int64_t *H_int_ = &H_int[idx0][0];
      for (m = 0; m < wiener_win; m++) {
        for (n = 0; n < wiener_win; n++) {
          H_[m * wiener_win + n] =
              (H_int_[n * 8 + m] +
               (avg_square_sum - (int64_t)avg * (sumY[k][l] + sumY[n][m]))) /
              bit_depth_divider;
        }
      }
    }
  }
}

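// Same as acc_stat_highbd_win7_one_line_sse4_1, but for the 5x5 (chroma)
// Wiener window.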
static INLINE void acc_stat_highbd_win5_one_line_sse4_1(
    const uint16_t *dgd, const uint16_t *src, int h_start, int h_end,
    int dgd_stride, const __m128i *shuffle, int32_t *sumX,
    int32_t sumY[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA],
    int64_t M_int[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA],
    int64_t H_int[WIENER_WIN2_CHROMA][WIENER_WIN_CHROMA * 8]) {
  int j, k, l;
  const int wiener_win = WIENER_WIN_CHROMA;
  for (j = h_start; j < h_end; j += 2) {
    const uint16_t X1 = src[j];
    const uint16_t X2 = src[j + 1];
    *sumX += X1 + X2;
    const uint16_t *dgd_ij = dgd + j;
    for (k = 0; k < wiener_win; k++) {
      const uint16_t *dgd_ijk = dgd_ij + k * dgd_stride;
      for (l = 0; l < wiener_win; l++) {
        int64_t *H_ = &H_int[(l * wiener_win + k)][0];
        const uint16_t D1 = dgd_ijk[l];
        const uint16_t D2 = dgd_ijk[l + 1];
        sumY[k][l] += D1 + D2;
        M_int[k][l] += D1 * X1 + D2 * X2;

        // Load two u16 values from dgd as a single u32
        // then broadcast to 4x u32 slots of a 128
        const __m128i dgd_ijkl = _mm_set1_epi32(*((uint32_t *)(dgd_ijk + l)));
        // dgd_ijkl = [y x y x y x y x] as u16

        acc_stat_highbd_sse41(H_ + 0 * 8, dgd_ij + 0 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 1 * 8, dgd_ij + 1 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 2 * 8, dgd_ij + 2 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 3 * 8, dgd_ij + 3 * dgd_stride, shuffle,
                              &dgd_ijkl);
        acc_stat_highbd_sse41(H_ + 4 * 8, dgd_ij + 4 * dgd_stride, shuffle,
                              &dgd_ijkl);
      }
    }
  }
}

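// Same as compute_stats_highbd_win7_opt_sse4_1, but for the 5x5 (chroma)
// Wiener window.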
static INLINE void compute_stats_highbd_win5_opt_sse4_1(
    const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end,
    int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M,
    int64_t *H, aom_bit_depth_t bit_depth) {
  int i, j, k, l, m, n;
  const int wiener_win = WIENER_WIN_CHROMA;
  const int pixel_count = (h_end - h_start) * (v_end - v_start);
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const uint16_t avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  int64_t M_int[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA] = { { 0 } };
  int64_t H_int[WIENER_WIN2_CHROMA][WIENER_WIN_CHROMA * 8] = { { 0 } };
  int32_t sumY[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA] = { { 0 } };
  int32_t sumX = 0;
  const uint16_t *dgd_win = dgd - wiener_halfwin * dgd_stride - wiener_halfwin;

  // Load just half of the 256-bit shuffle control used for the AVX2 version
  const __m128i shuffle = xx_loadu_128(g_shuffle_stats_highbd_data);
  for (j = v_start; j < v_end; j += 64) {
    const int vert_end = AOMMIN(64, v_end - j) + j;
    for (i = j; i < vert_end; i++) {
      acc_stat_highbd_win5_one_line_sse4_1(
          dgd_win + i * dgd_stride, src + i * src_stride, h_start, h_end,
          dgd_stride, &shuffle, &sumX, sumY, M_int, H_int);
    }
  }

  uint8_t bit_depth_divider = 1;
  if (bit_depth == AOM_BITS_12)
    bit_depth_divider = 16;
  else if (bit_depth == AOM_BITS_10)
    bit_depth_divider = 4;

  const int64_t avg_square_sum = (int64_t)avg * (int64_t)avg * pixel_count;
  for (k = 0; k < wiener_win; k++) {
    for (l = 0; l < wiener_win; l++) {
      const int32_t idx0 = l * wiener_win + k;
      M[idx0] = (M_int[k][l] +
                 (avg_square_sum - (int64_t)avg * (sumX + sumY[k][l]))) /
                bit_depth_divider;
      int64_t *H_ = H + idx0 * wiener_win2;
      int64_t *H_int_ = &H_int[idx0][0];
      for (m = 0; m < wiener_win; m++) {
        for (n = 0; n < wiener_win; n++) {
          H_[m * wiener_win + n] =
              (H_int_[n * 8 + m] +
               (avg_square_sum - (int64_t)avg * (sumY[k][l] + sumY[n][m]))) /
              bit_depth_divider;
        }
      }
    }
  }
}

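// SSE4.1 variant of av1_compute_stats_highbd_c. Dispatches to the 7x7 or 5x5
// optimized path and falls back to the C implementation for other window
// sizes.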
void av1_compute_stats_highbd_sse4_1(int wiener_win, const uint8_t *dgd8,
                                     const uint8_t *src8, int h_start,
                                     int h_end, int v_start, int v_end,
                                     int dgd_stride, int src_stride, int64_t *M,
                                     int64_t *H, aom_bit_depth_t bit_depth) {
  if (wiener_win == WIENER_WIN) {
    compute_stats_highbd_win7_opt_sse4_1(dgd8, src8, h_start, h_end, v_start,
                                         v_end, dgd_stride, src_stride, M, H,
                                         bit_depth);
  } else if (wiener_win == WIENER_WIN_CHROMA) {
    compute_stats_highbd_win5_opt_sse4_1(dgd8, src8, h_start, h_end, v_start,
                                         v_end, dgd_stride, src_stride, M, H,
                                         bit_depth);
  } else {
    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start,
                               v_end, dgd_stride, src_stride, M, H, bit_depth);
  }
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

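// Same as acc_stat_win7_one_line_sse4_1, but for the 5x5 (chroma) Wiener
// window.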
static INLINE void acc_stat_win5_one_line_sse4_1(
    const uint8_t *dgd, const uint8_t *src, int h_start, int h_end,
    int dgd_stride, const __m128i *shuffle, int32_t *sumX,
    int32_t sumY[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA],
    int32_t M_int[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA],
    int32_t H_int[WIENER_WIN2_CHROMA][WIENER_WIN_CHROMA * 8]) {
  const int wiener_win = WIENER_WIN_CHROMA;
  int j, k, l;
  for (j = h_start; j < h_end; j += 2) {
    const uint8_t *dgd_ij = dgd + j;
    const uint8_t X1 = src[j];
    const uint8_t X2 = src[j + 1];
    *sumX += X1 + X2;
    for (k = 0; k < wiener_win; k++) {
      const uint8_t *dgd_ijk = dgd_ij + k * dgd_stride;
      for (l = 0; l < wiener_win; l++) {
        int32_t *H_ = &H_int[(l * wiener_win + k)][0];
        const uint8_t D1 = dgd_ijk[l];
        const uint8_t D2 = dgd_ijk[l + 1];
        sumY[k][l] += D1 + D2;
        M_int[k][l] += D1 * X1 + D2 * X2;

        const __m128i kl =
            _mm_cvtepu8_epi16(_mm_set1_epi16(*((uint16_t *)(dgd_ijk + l))));
        acc_stat_sse41(H_ + 0 * 8, dgd_ij + 0 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 1 * 8, dgd_ij + 1 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 2 * 8, dgd_ij + 2 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 3 * 8, dgd_ij + 3 * dgd_stride, shuffle, &kl);
        acc_stat_sse41(H_ + 4 * 8, dgd_ij + 4 * dgd_stride, shuffle, &kl);
      }
    }
  }
}

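// Same as compute_stats_win7_opt_sse4_1, but for the 5x5 (chroma) Wiener
// window.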
static INLINE void compute_stats_win5_opt_sse4_1(
    const uint8_t *dgd, const uint8_t *src, int h_start, int h_end, int v_start,
    int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H) {
  int i, j, k, l, m, n;
  const int wiener_win = WIENER_WIN_CHROMA;
  const int pixel_count = (h_end - h_start) * (v_end - v_start);
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  const uint8_t avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  int32_t M_int32[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA] = { { 0 } };
  int64_t M_int64[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA] = { { 0 } };
  int32_t H_int32[WIENER_WIN2_CHROMA][WIENER_WIN_CHROMA * 8] = { { 0 } };
  int64_t H_int64[WIENER_WIN2_CHROMA][WIENER_WIN_CHROMA * 8] = { { 0 } };
  int32_t sumY[WIENER_WIN_CHROMA][WIENER_WIN_CHROMA] = { { 0 } };
  int32_t sumX = 0;
  const uint8_t *dgd_win = dgd - wiener_halfwin * dgd_stride - wiener_halfwin;

  const __m128i shuffle = xx_loadu_128(g_shuffle_stats_data);
  for (j = v_start; j < v_end; j += 64) {
    const int vert_end = AOMMIN(64, v_end - j) + j;
    for (i = j; i < vert_end; i++) {
      acc_stat_win5_one_line_sse4_1(
          dgd_win + i * dgd_stride, src + i * src_stride, h_start, h_end,
          dgd_stride, &shuffle, &sumX, sumY, M_int32, H_int32);
    }
    for (k = 0; k < wiener_win; ++k) {
      for (l = 0; l < wiener_win; ++l) {
        M_int64[k][l] += M_int32[k][l];
        M_int32[k][l] = 0;
      }
    }
    for (k = 0; k < WIENER_WIN_CHROMA * WIENER_WIN_CHROMA; ++k) {
      for (l = 0; l < WIENER_WIN_CHROMA * 8; ++l) {
        H_int64[k][l] += H_int32[k][l];
        H_int32[k][l] = 0;
      }
    }
  }

  const int64_t avg_square_sum = (int64_t)avg * (int64_t)avg * pixel_count;
  for (k = 0; k < wiener_win; k++) {
    for (l = 0; l < wiener_win; l++) {
      const int32_t idx0 = l * wiener_win + k;
      M[idx0] =
          M_int64[k][l] + (avg_square_sum - (int64_t)avg * (sumX + sumY[k][l]));
      int64_t *H_ = H + idx0 * wiener_win2;
      int64_t *H_int_ = &H_int64[idx0][0];
      for (m = 0; m < wiener_win; m++) {
        for (n = 0; n < wiener_win; n++) {
          H_[m * wiener_win + n] = H_int_[n * 8 + m] + avg_square_sum -
                                   (int64_t)avg * (sumY[k][l] + sumY[n][m]);
        }
      }
    }
  }
}
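
// SSE4.1 variant of av1_compute_stats_c. Dispatches to the 7x7 or 5x5
// optimized path and falls back to the C implementation for other window
// sizes.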
void av1_compute_stats_sse4_1(int wiener_win, const uint8_t *dgd,
                              const uint8_t *src, int h_start, int h_end,
                              int v_start, int v_end, int dgd_stride,
                              int src_stride, int64_t *M, int64_t *H) {
  if (wiener_win == WIENER_WIN) {
    compute_stats_win7_opt_sse4_1(dgd, src, h_start, h_end, v_start, v_end,
                                  dgd_stride, src_stride, M, H);
  } else if (wiener_win == WIENER_WIN_CHROMA) {
    compute_stats_win5_opt_sse4_1(dgd, src, h_start, h_end, v_start, v_end,
                                  dgd_stride, src_stride, M, H);
  } else {
    av1_compute_stats_c(wiener_win, dgd, src, h_start, h_end, v_start, v_end,
                        dgd_stride, src_stride, M, H);
  }
}

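// Pack a into the low 16 bits and b into the high 16 bits of every 32-bit
// lane, matching the interleaved operand layout expected by _mm_madd_epi16.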
static INLINE __m128i pair_set_epi16(int a, int b) {
  return _mm_set1_epi32((int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16)));
}

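// SSE4.1 variant of av1_lowbd_pixel_proj_error_c: returns the sum of squared
// errors between the source and the projected self-guided restoration, with
// separate code paths depending on which of the two SGR filters are enabled.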
int64_t av1_lowbd_pixel_proj_error_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  int i, j, k;
  const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
  const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
  __m128i sum64 = _mm_setzero_si128();
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  int64_t err = 0;
  if (params->r[0] > 0 && params->r[1] > 0) {
    __m128i xq_coeff = pair_set_epi16(xq[0], xq[1]);
    for (i = 0; i < height; ++i) {
      __m128i sum32 = _mm_setzero_si128();
      for (j = 0; j <= width - 8; j += 8) {
        const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j));
        const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j));
        const __m128i flt0_16b =
            _mm_packs_epi32(xx_loadu_128(flt0 + j), xx_loadu_128(flt0 + j + 4));
        const __m128i flt1_16b =
            _mm_packs_epi32(xx_loadu_128(flt1 + j), xx_loadu_128(flt1 + j + 4));
        const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS);
        const __m128i flt0_0_sub_u = _mm_sub_epi16(flt0_16b, u0);
        const __m128i flt1_0_sub_u = _mm_sub_epi16(flt1_16b, u0);
        const __m128i v0 = _mm_madd_epi16(
            xq_coeff, _mm_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u));
        const __m128i v1 = _mm_madd_epi16(
            xq_coeff, _mm_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u));
        const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift);
        const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift);
        const __m128i e0 =
            _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0);
        const __m128i err0 = _mm_madd_epi16(e0, e0);
        sum32 = _mm_add_epi32(sum32, err0);
      }
      for (k = j; k < width; ++k) {
        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
      const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
      const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
      sum64 = _mm_add_epi64(sum64, sum64_0);
      sum64 = _mm_add_epi64(sum64, sum64_1);
    }
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    const int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
    const __m128i xq_coeff =
        pair_set_epi16(xq_active, -(xq_active << SGRPROJ_RST_BITS));
    const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    for (i = 0; i < height; ++i) {
      __m128i sum32 = _mm_setzero_si128();
      for (j = 0; j <= width - 8; j += 8) {
        const __m128i d0 = _mm_cvtepu8_epi16(xx_loadl_64(dat + j));
        const __m128i s0 = _mm_cvtepu8_epi16(xx_loadl_64(src + j));
        const __m128i flt_16b =
            _mm_packs_epi32(xx_loadu_128(flt + j), xx_loadu_128(flt + j + 4));
        const __m128i v0 =
            _mm_madd_epi16(xq_coeff, _mm_unpacklo_epi16(flt_16b, d0));
        const __m128i v1 =
            _mm_madd_epi16(xq_coeff, _mm_unpackhi_epi16(flt_16b, d0));
        const __m128i vr0 = _mm_srai_epi32(_mm_add_epi32(v0, rounding), shift);
        const __m128i vr1 = _mm_srai_epi32(_mm_add_epi32(v1, rounding), shift);
        const __m128i e0 =
            _mm_sub_epi16(_mm_add_epi16(_mm_packs_epi32(vr0, vr1), d0), s0);
        const __m128i err0 = _mm_madd_epi16(e0, e0);
        sum32 = _mm_add_epi32(sum32, err0);
      }
      for (k = j; k < width; ++k) {
        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
        int32_t v = xq_active * (flt[k] - u);
        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt += flt_stride;
      const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
      const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
      sum64 = _mm_add_epi64(sum64, sum64_0);
      sum64 = _mm_add_epi64(sum64, sum64_1);
    }
  } else {
    __m128i sum32 = _mm_setzero_si128();
    for (i = 0; i < height; ++i) {
      for (j = 0; j <= width - 16; j += 16) {
        const __m128i d = xx_loadu_128(dat + j);
        const __m128i s = xx_loadu_128(src + j);
        const __m128i d0 = _mm_cvtepu8_epi16(d);
        const __m128i d1 = _mm_cvtepu8_epi16(_mm_srli_si128(d, 8));
        const __m128i s0 = _mm_cvtepu8_epi16(s);
        const __m128i s1 = _mm_cvtepu8_epi16(_mm_srli_si128(s, 8));
        const __m128i diff0 = _mm_sub_epi16(d0, s0);
        const __m128i diff1 = _mm_sub_epi16(d1, s1);
        const __m128i err0 = _mm_madd_epi16(diff0, diff0);
        const __m128i err1 = _mm_madd_epi16(diff1, diff1);
        sum32 = _mm_add_epi32(sum32, err0);
        sum32 = _mm_add_epi32(sum32, err1);
      }
      for (k = j; k < width; ++k) {
        const int32_t e = (int32_t)(dat[k]) - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
    }
    const __m128i sum64_0 = _mm_cvtepi32_epi64(sum32);
    const __m128i sum64_1 = _mm_cvtepi32_epi64(_mm_srli_si128(sum32, 8));
    sum64 = _mm_add_epi64(sum64_0, sum64_1);
  }
  int64_t sum[2];
  xx_storeu_128(sum, sum64);
  err += sum[0] + sum[1];
  return err;
}

// When params->r[0] > 0 and params->r[1] > 0. In this case all elements of
// C and H need to be computed.
static AOM_INLINE void calc_proj_params_r0_r1_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  __m128i h00, h01, h11, c0, c1;
  const __m128i zero = _mm_setzero_si128();
  h01 = h11 = c0 = c1 = h00 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(dat + i * dat_stride + j))));
      const __m128i s_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(src + i * src_stride + j))));
      __m128i f1 = _mm_loadu_si128((__m128i *)(flt0 + i * flt0_stride + j));
      __m128i f2 = _mm_loadu_si128((__m128i *)(flt1 + i * flt1_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f1 = _mm_sub_epi32(f1, d);
      f2 = _mm_sub_epi32(f2, d);

      const __m128i h00_even = _mm_mul_epi32(f1, f1);
      const __m128i h00_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f1, 32));
      h00 = _mm_add_epi64(h00, h00_even);
      h00 = _mm_add_epi64(h00, h00_odd);

      const __m128i h01_even = _mm_mul_epi32(f1, f2);
      const __m128i h01_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f2, 32));
      h01 = _mm_add_epi64(h01, h01_even);
      h01 = _mm_add_epi64(h01, h01_odd);

      const __m128i h11_even = _mm_mul_epi32(f2, f2);
      const __m128i h11_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(f2, 32));
      h11 = _mm_add_epi64(h11, h11_even);
      h11 = _mm_add_epi64(h11, h11_odd);

      const __m128i c0_even = _mm_mul_epi32(f1, s);
      const __m128i c0_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(s, 32));
      c0 = _mm_add_epi64(c0, c0_even);
      c0 = _mm_add_epi64(c0, c0_odd);

      const __m128i c1_even = _mm_mul_epi32(f2, s);
      const __m128i c1_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(s, 32));
      c1 = _mm_add_epi64(c1, c1_even);
      c1 = _mm_add_epi64(c1, c1_odd);
    }
  }

  __m128i c_low = _mm_unpacklo_epi64(c0, c1);
  const __m128i c_high = _mm_unpackhi_epi64(c0, c1);
  c_low = _mm_add_epi64(c_low, c_high);

  __m128i h0x_low = _mm_unpacklo_epi64(h00, h01);
  const __m128i h0x_high = _mm_unpackhi_epi64(h00, h01);
  h0x_low = _mm_add_epi64(h0x_low, h0x_high);

  // Using the symmetric properties of H, calculations of H[1][0] are not
  // needed.
  __m128i h1x_low = _mm_unpacklo_epi64(zero, h11);
  const __m128i h1x_high = _mm_unpackhi_epi64(zero, h11);
  h1x_low = _mm_add_epi64(h1x_low, h1x_high);

  xx_storeu_128(C, c_low);
  xx_storeu_128(H[0], h0x_low);
  xx_storeu_128(H[1], h1x_low);

  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;

  // Since H is a symmetric matrix
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
}

// When only params->r[0] > 0. In this case only H[0][0] and C[0] are
// non-zero and need to be computed.
static AOM_INLINE void calc_proj_params_r0_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  __m128i h00, c0;
  const __m128i zero = _mm_setzero_si128();
  c0 = h00 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(dat + i * dat_stride + j))));
      const __m128i s_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(src + i * src_stride + j))));
      __m128i f1 = _mm_loadu_si128((__m128i *)(flt0 + i * flt0_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f1 = _mm_sub_epi32(f1, d);

      const __m128i h00_even = _mm_mul_epi32(f1, f1);
      const __m128i h00_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f1, 32));
      h00 = _mm_add_epi64(h00, h00_even);
      h00 = _mm_add_epi64(h00, h00_odd);

      const __m128i c0_even = _mm_mul_epi32(f1, s);
      const __m128i c0_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(s, 32));
      c0 = _mm_add_epi64(c0, c0_even);
      c0 = _mm_add_epi64(c0, c0_odd);
    }
  }
  const __m128i h00_val = _mm_add_epi64(h00, _mm_srli_si128(h00, 8));

  const __m128i c0_val = _mm_add_epi64(c0, _mm_srli_si128(c0, 8));

  const __m128i c = _mm_unpacklo_epi64(c0_val, zero);
  const __m128i h0x = _mm_unpacklo_epi64(h00_val, zero);

  xx_storeu_128(C, c);
  xx_storeu_128(H[0], h0x);

  H[0][0] /= size;
  C[0] /= size;
}

// When only params->r[1] > 0. In this case only H[1][1] and C[1] are
// non-zero and need to be computed.
static AOM_INLINE void calc_proj_params_r1_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  __m128i h11, c1;
  const __m128i zero = _mm_setzero_si128();
  c1 = h11 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(dat + i * dat_stride + j))));
      const __m128i s_load = _mm_cvtepu8_epi32(
          _mm_cvtsi32_si128(*((int *)(src + i * src_stride + j))));
      __m128i f2 = _mm_loadu_si128((__m128i *)(flt1 + i * flt1_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f2 = _mm_sub_epi32(f2, d);

      const __m128i h11_even = _mm_mul_epi32(f2, f2);
      const __m128i h11_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(f2, 32));
      h11 = _mm_add_epi64(h11, h11_even);
      h11 = _mm_add_epi64(h11, h11_odd);

      const __m128i c1_even = _mm_mul_epi32(f2, s);
      const __m128i c1_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(s, 32));
      c1 = _mm_add_epi64(c1, c1_even);
      c1 = _mm_add_epi64(c1, c1_odd);
    }
  }

  const __m128i h11_val = _mm_add_epi64(h11, _mm_srli_si128(h11, 8));

  const __m128i c1_val = _mm_add_epi64(c1, _mm_srli_si128(c1, 8));

  const __m128i c = _mm_unpacklo_epi64(zero, c1_val);
  const __m128i h1x = _mm_unpacklo_epi64(zero, h11_val);

  xx_storeu_128(C, c);
  xx_storeu_128(H[1], h1x);

  H[1][1] /= size;
  C[1] /= size;
}

// SSE4.1 variant of av1_calc_proj_params_c.
void av1_calc_proj_params_sse4_1(const uint8_t *src8, int width, int height,
                                 int src_stride, const uint8_t *dat8,
                                 int dat_stride, int32_t *flt0, int flt0_stride,
                                 int32_t *flt1, int flt1_stride,
                                 int64_t H[2][2], int64_t C[2],
                                 const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_sse4_1(src8, width, height, src_stride, dat8,
                                  dat_stride, flt0, flt0_stride, flt1,
                                  flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_sse4_1(src8, width, height, src_stride, dat8,
                               dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_sse4_1(src8, width, height, src_stride, dat8,
                               dat_stride, flt1, flt1_stride, H, C);
  }
}

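// When params->r[0] > 0 and params->r[1] > 0. In this case all elements of
// C and H need to be computed.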
static AOM_INLINE void calc_proj_params_r0_r1_high_bd_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  __m128i h00, h01, h11, c0, c1;
  const __m128i zero = _mm_setzero_si128();
  h01 = h11 = c0 = c1 = h00 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(dat + i * dat_stride + j)));
      const __m128i s_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(src + i * src_stride + j)));
      __m128i f1 = _mm_loadu_si128((__m128i *)(flt0 + i * flt0_stride + j));
      __m128i f2 = _mm_loadu_si128((__m128i *)(flt1 + i * flt1_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f1 = _mm_sub_epi32(f1, d);
      f2 = _mm_sub_epi32(f2, d);

      const __m128i h00_even = _mm_mul_epi32(f1, f1);
      const __m128i h00_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f1, 32));
      h00 = _mm_add_epi64(h00, h00_even);
      h00 = _mm_add_epi64(h00, h00_odd);

      const __m128i h01_even = _mm_mul_epi32(f1, f2);
      const __m128i h01_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f2, 32));
      h01 = _mm_add_epi64(h01, h01_even);
      h01 = _mm_add_epi64(h01, h01_odd);

      const __m128i h11_even = _mm_mul_epi32(f2, f2);
      const __m128i h11_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(f2, 32));
      h11 = _mm_add_epi64(h11, h11_even);
      h11 = _mm_add_epi64(h11, h11_odd);

      const __m128i c0_even = _mm_mul_epi32(f1, s);
      const __m128i c0_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(s, 32));
      c0 = _mm_add_epi64(c0, c0_even);
      c0 = _mm_add_epi64(c0, c0_odd);

      const __m128i c1_even = _mm_mul_epi32(f2, s);
      const __m128i c1_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(s, 32));
      c1 = _mm_add_epi64(c1, c1_even);
      c1 = _mm_add_epi64(c1, c1_odd);
    }
  }

  __m128i c_low = _mm_unpacklo_epi64(c0, c1);
  const __m128i c_high = _mm_unpackhi_epi64(c0, c1);
  c_low = _mm_add_epi64(c_low, c_high);

  __m128i h0x_low = _mm_unpacklo_epi64(h00, h01);
  const __m128i h0x_high = _mm_unpackhi_epi64(h00, h01);
  h0x_low = _mm_add_epi64(h0x_low, h0x_high);

  // Using the symmetric properties of H, calculations of H[1][0] are not
  // needed.
  __m128i h1x_low = _mm_unpacklo_epi64(zero, h11);
  const __m128i h1x_high = _mm_unpackhi_epi64(zero, h11);
  h1x_low = _mm_add_epi64(h1x_low, h1x_high);

  xx_storeu_128(C, c_low);
  xx_storeu_128(H[0], h0x_low);
  xx_storeu_128(H[1], h1x_low);

  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;

  // Since H is a symmetric matrix
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
}

// When only params->r[0] > 0. In this case only H[0][0] and C[0] are
// non-zero and need to be computed.
static AOM_INLINE void calc_proj_params_r0_high_bd_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  __m128i h00, c0;
  const __m128i zero = _mm_setzero_si128();
  c0 = h00 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(dat + i * dat_stride + j)));
      const __m128i s_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(src + i * src_stride + j)));
      __m128i f1 = _mm_loadu_si128((__m128i *)(flt0 + i * flt0_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f1 = _mm_sub_epi32(f1, d);

      const __m128i h00_even = _mm_mul_epi32(f1, f1);
      const __m128i h00_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(f1, 32));
      h00 = _mm_add_epi64(h00, h00_even);
      h00 = _mm_add_epi64(h00, h00_odd);

      const __m128i c0_even = _mm_mul_epi32(f1, s);
      const __m128i c0_odd =
          _mm_mul_epi32(_mm_srli_epi64(f1, 32), _mm_srli_epi64(s, 32));
      c0 = _mm_add_epi64(c0, c0_even);
      c0 = _mm_add_epi64(c0, c0_odd);
    }
  }
  const __m128i h00_val = _mm_add_epi64(h00, _mm_srli_si128(h00, 8));

  const __m128i c0_val = _mm_add_epi64(c0, _mm_srli_si128(c0, 8));

  const __m128i c = _mm_unpacklo_epi64(c0_val, zero);
  const __m128i h0x = _mm_unpacklo_epi64(h00_val, zero);

  xx_storeu_128(C, c);
  xx_storeu_128(H[0], h0x);

  H[0][0] /= size;
  C[0] /= size;
}

// When only params->r[1] > 0. In this case only H[1][1] and C[1] are
// non-zero and need to be computed.
static AOM_INLINE void calc_proj_params_r1_high_bd_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  __m128i h11, c1;
  const __m128i zero = _mm_setzero_si128();
  c1 = h11 = zero;

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += 4) {
      const __m128i u_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(dat + i * dat_stride + j)));
      const __m128i s_load = _mm_cvtepu16_epi32(
          _mm_loadl_epi64((__m128i *)(src + i * src_stride + j)));
      __m128i f2 = _mm_loadu_si128((__m128i *)(flt1 + i * flt1_stride + j));
      __m128i d = _mm_slli_epi32(u_load, SGRPROJ_RST_BITS);
      __m128i s = _mm_slli_epi32(s_load, SGRPROJ_RST_BITS);
      s = _mm_sub_epi32(s, d);
      f2 = _mm_sub_epi32(f2, d);

      const __m128i h11_even = _mm_mul_epi32(f2, f2);
      const __m128i h11_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(f2, 32));
      h11 = _mm_add_epi64(h11, h11_even);
      h11 = _mm_add_epi64(h11, h11_odd);

      const __m128i c1_even = _mm_mul_epi32(f2, s);
      const __m128i c1_odd =
          _mm_mul_epi32(_mm_srli_epi64(f2, 32), _mm_srli_epi64(s, 32));
      c1 = _mm_add_epi64(c1, c1_even);
      c1 = _mm_add_epi64(c1, c1_odd);
    }
  }

  const __m128i h11_val = _mm_add_epi64(h11, _mm_srli_si128(h11, 8));

  const __m128i c1_val = _mm_add_epi64(c1, _mm_srli_si128(c1, 8));

  const __m128i c = _mm_unpacklo_epi64(zero, c1_val);
  const __m128i h1x = _mm_unpacklo_epi64(zero, h11_val);

  xx_storeu_128(C, c);
  xx_storeu_128(H[1], h1x);

  H[1][1] /= size;
  C[1] /= size;
}

// SSE4.1 variant of av1_calc_proj_params_high_bd_c.
void av1_calc_proj_params_high_bd_sse4_1(const uint8_t *src8, int width,
                                         int height, int src_stride,
                                         const uint8_t *dat8, int dat_stride,
                                         int32_t *flt0, int flt0_stride,
                                         int32_t *flt1, int flt1_stride,
                                         int64_t H[2][2], int64_t C[2],
                                         const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_high_bd_sse4_1(src8, width, height, src_stride, dat8,
                                          dat_stride, flt0, flt0_stride, flt1,
                                          flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_high_bd_sse4_1(src8, width, height, src_stride, dat8,
                                       dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_high_bd_sse4_1(src8, width, height, src_stride, dat8,
                                       dat_stride, flt1, flt1_stride, H, C);
  }
}

#if CONFIG_AV1_HIGHBITDEPTH
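// High bit-depth variant of av1_lowbd_pixel_proj_error_sse4_1: returns the
// sum of squared errors between the source and the projected self-guided
// restoration.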
int64_t av1_highbd_pixel_proj_error_sse4_1(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  int i, j, k;
  const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;
  const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
  __m128i sum64 = _mm_setzero_si128();
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  int64_t err = 0;
  if (params->r[0] > 0 && params->r[1] > 0) {  // Both filters are enabled
    const __m128i xq0 = _mm_set1_epi32(xq[0]);
    const __m128i xq1 = _mm_set1_epi32(xq[1]);

    for (i = 0; i < height; ++i) {
      __m128i sum32 = _mm_setzero_si128();
      for (j = 0; j <= width - 8; j += 8) {
        // Load 8x pixels from source image
        const __m128i s0 = xx_loadu_128(src + j);
        // s0 = [7 6 5 4 3 2 1 0] as i16 (indices of src[])

        // Load 8x pixels from corrupted image
        const __m128i d0 = xx_loadu_128(dat + j);
        // d0 = [7 6 5 4 3 2 1 0] as i16 (indices of dat[])

        // Shift each pixel value up by SGRPROJ_RST_BITS
        const __m128i u0 = _mm_slli_epi16(d0, SGRPROJ_RST_BITS);

        // Split u0 into two halves and pad each from u16 to i32
        const __m128i u0l = _mm_cvtepu16_epi32(u0);
        const __m128i u0h = _mm_cvtepu16_epi32(_mm_srli_si128(u0, 8));
        // u0h = [7 6 5 4] as i32, u0l = [3 2 1 0] as i32, all dat[] indices

        // Load 8 pixels from first and second filtered images
        const __m128i flt0l = xx_loadu_128(flt0 + j);
        const __m128i flt0h = xx_loadu_128(flt0 + j + 4);
        const __m128i flt1l = xx_loadu_128(flt1 + j);
        const __m128i flt1h = xx_loadu_128(flt1 + j + 4);
        // flt0 = [7 6 5 4] [3 2 1 0] as i32 (indices of flt0+j)
        // flt1 = [7 6 5 4] [3 2 1 0] as i32 (indices of flt1+j)

        // Subtract shifted corrupt image from each filtered image
        // This gives our two basis vectors for the projection
        const __m128i flt0l_subu = _mm_sub_epi32(flt0l, u0l);
        const __m128i flt0h_subu = _mm_sub_epi32(flt0h, u0h);
        const __m128i flt1l_subu = _mm_sub_epi32(flt1l, u0l);
        const __m128i flt1h_subu = _mm_sub_epi32(flt1h, u0h);
        // flt?h_subu = [ f[7]-u[7] f[6]-u[6] f[5]-u[5] f[4]-u[4] ] as i32
        // flt?l_subu = [ f[3]-u[3] f[2]-u[2] f[1]-u[1] f[0]-u[0] ] as i32

        // Multiply each basis vector by the corresponding coefficient
        const __m128i v0l = _mm_mullo_epi32(flt0l_subu, xq0);
        const __m128i v0h = _mm_mullo_epi32(flt0h_subu, xq0);
        const __m128i v1l = _mm_mullo_epi32(flt1l_subu, xq1);
        const __m128i v1h = _mm_mullo_epi32(flt1h_subu, xq1);

        // Add together the contribution from each scaled basis vector
        const __m128i vl = _mm_add_epi32(v0l, v1l);
        const __m128i vh = _mm_add_epi32(v0h, v1h);

        // Right-shift v with appropriate rounding
        const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);
        const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);

        // Saturate each i32 value to i16 and combine lower and upper halves
        const __m128i vr = _mm_packs_epi32(vrl, vrh);

        // Add twin-subspace-sgr-filter to corrupt image then subtract source
        const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);

        // Calculate squared error and add adjacent values
        const __m128i err0 = _mm_madd_epi16(e0, e0);

        sum32 = _mm_add_epi32(sum32, err0);
      }

      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
      sum64 = _mm_add_epi64(sum64, sum32l);
      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
      sum64 = _mm_add_epi64(sum64, sum32h);

      // Process remaining pixels in this row (modulo 8)
      for (k = j; k < width; ++k) {
        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    }
  } else if (params->r[0] > 0 || params->r[1] > 0) {  // Only one filter enabled
    const int32_t xq_on = (params->r[0] > 0) ? xq[0] : xq[1];
    const __m128i xq_active = _mm_set1_epi32(xq_on);
    const __m128i xq_inactive =
        _mm_set1_epi32(-xq_on * (1 << SGRPROJ_RST_BITS));
    const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    for (i = 0; i < height; ++i) {
      __m128i sum32 = _mm_setzero_si128();
      for (j = 0; j <= width - 8; j += 8) {
        // Load 8x pixels from source image
        const __m128i s0 = xx_loadu_128(src + j);
        // s0 = [7 6 5 4 3 2 1 0] as u16 (indices of src[])

        // Load 8x pixels from corrupted image and pad each u16 to i32
        const __m128i d0 = xx_loadu_128(dat + j);
        const __m128i d0h = _mm_cvtepu16_epi32(_mm_srli_si128(d0, 8));
        const __m128i d0l = _mm_cvtepu16_epi32(d0);
        // d0h, d0l = [7 6 5 4], [3 2 1 0] as u32 (indices of dat[])

        // Load 8 pixels from the filtered image
        const __m128i flth = xx_loadu_128(flt + j + 4);
        const __m128i fltl = xx_loadu_128(flt + j);
        // flth, fltl = [7 6 5 4], [3 2 1 0] as i32 (indices of flt+j)

        const __m128i flth_xq = _mm_mullo_epi32(flth, xq_active);
        const __m128i fltl_xq = _mm_mullo_epi32(fltl, xq_active);
        const __m128i d0h_xq = _mm_mullo_epi32(d0h, xq_inactive);
        const __m128i d0l_xq = _mm_mullo_epi32(d0l, xq_inactive);

        const __m128i vh = _mm_add_epi32(flth_xq, d0h_xq);
        const __m128i vl = _mm_add_epi32(fltl_xq, d0l_xq);
        // vh = [ xq0(f[7]-d[7]) xq0(f[6]-d[6]) xq0(f[5]-d[5]) xq0(f[4]-d[4]) ]
        // vl = [ xq0(f[3]-d[3]) xq0(f[2]-d[2]) xq0(f[1]-d[1]) xq0(f[0]-d[0]) ]

        // Shift this down with appropriate rounding
        const __m128i vrh = _mm_srai_epi32(_mm_add_epi32(vh, rounding), shift);
        const __m128i vrl = _mm_srai_epi32(_mm_add_epi32(vl, rounding), shift);

        // Saturate vr0 and vr1 from i32 to i16 then pack together
        const __m128i vr = _mm_packs_epi32(vrl, vrh);

        // Subtract twin-subspace-sgr filtered from source image to get error
        const __m128i e0 = _mm_sub_epi16(_mm_add_epi16(vr, d0), s0);

        // Calculate squared error and add adjacent values
        const __m128i err0 = _mm_madd_epi16(e0, e0);

        sum32 = _mm_add_epi32(sum32, err0);
      }

      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
      sum64 = _mm_add_epi64(sum64, sum32l);
      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
      sum64 = _mm_add_epi64(sum64, sum32h);

      // Process remaining pixels in this row (modulo 8)
      for (k = j; k < width; ++k) {
        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
        int32_t v = xq_on * (flt[k] - u);
        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt += flt_stride;
    }
  } else {  // Neither filter is enabled
    for (i = 0; i < height; ++i) {
      __m128i sum32 = _mm_setzero_si128();
      for (j = 0; j <= width - 16; j += 16) {
        // Load 2x8 u16 from source image
        const __m128i s0 = xx_loadu_128(src + j);
        const __m128i s1 = xx_loadu_128(src + j + 8);
        // Load 2x8 u16 from corrupted image
        const __m128i d0 = xx_loadu_128(dat + j);
        const __m128i d1 = xx_loadu_128(dat + j + 8);

        // Subtract corrupted image from source image
        const __m128i diff0 = _mm_sub_epi16(d0, s0);
        const __m128i diff1 = _mm_sub_epi16(d1, s1);

        // Square error and add adjacent values
        const __m128i err0 = _mm_madd_epi16(diff0, diff0);
        const __m128i err1 = _mm_madd_epi16(diff1, diff1);

        sum32 = _mm_add_epi32(sum32, err0);
        sum32 = _mm_add_epi32(sum32, err1);
      }

      const __m128i sum32l = _mm_cvtepu32_epi64(sum32);
      sum64 = _mm_add_epi64(sum64, sum32l);
      const __m128i sum32h = _mm_cvtepu32_epi64(_mm_srli_si128(sum32, 8));
      sum64 = _mm_add_epi64(sum64, sum32h);

      // Process remaining pixels (modulo 8)
      for (k = j; k < width; ++k) {
        const int32_t e = (int32_t)(dat[k]) - src[k];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
    }
  }

  // Add the two 64-bit partial sums in sum64 into err
  int64_t sum[2];
  xx_storeu_128(sum, sum64);
  err += sum[0] + sum[1];
  return err;
}
#endif  // CONFIG_AV1_HIGHBITDEPTH