1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_TARGETING_SSE4_1
19
#include <smmintrin.h>
#include <xmmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
25
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35
36 constexpr int kInterPostRoundBit = 4;
37
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)38 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
39 const __m128i& pred1,
40 const __m128i& weights) {
41 // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
42 const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
43 const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
44 const __m128i result_lo =
45 RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
46
47 const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
48 const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
49 const __m128i result_hi =
50 RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
51
52 return _mm_packs_epi32(result_lo, result_hi);
53 }
54
55 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)56 inline void DistanceWeightedBlend4xH_SSE4_1(
57 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
58 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
59 auto* dst = static_cast<uint8_t*>(dest);
60 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
61
62 for (int y = 0; y < height; y += 4) {
63 // TODO(b/150326556): Use larger loads.
64 const __m128i src_00 = LoadLo8(pred_0);
65 const __m128i src_10 = LoadLo8(pred_1);
66 pred_0 += 4;
67 pred_1 += 4;
68 __m128i src_0 = LoadHi8(src_00, pred_0);
69 __m128i src_1 = LoadHi8(src_10, pred_1);
70 pred_0 += 4;
71 pred_1 += 4;
72 const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
73
74 const __m128i src_01 = LoadLo8(pred_0);
75 const __m128i src_11 = LoadLo8(pred_1);
76 pred_0 += 4;
77 pred_1 += 4;
78 src_0 = LoadHi8(src_01, pred_0);
79 src_1 = LoadHi8(src_11, pred_1);
80 pred_0 += 4;
81 pred_1 += 4;
82 const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
83
84 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
85 Store4(dst, result_pixels);
86 dst += dest_stride;
87 const int result_1 = _mm_extract_epi32(result_pixels, 1);
88 memcpy(dst, &result_1, sizeof(result_1));
89 dst += dest_stride;
90 const int result_2 = _mm_extract_epi32(result_pixels, 2);
91 memcpy(dst, &result_2, sizeof(result_2));
92 dst += dest_stride;
93 const int result_3 = _mm_extract_epi32(result_pixels, 3);
94 memcpy(dst, &result_3, sizeof(result_3));
95 dst += dest_stride;
96 }
97 }
98
99 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)100 inline void DistanceWeightedBlend8xH_SSE4_1(
101 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
102 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
103 auto* dst = static_cast<uint8_t*>(dest);
104 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
105
106 for (int y = 0; y < height; y += 2) {
107 const __m128i src_00 = LoadAligned16(pred_0);
108 const __m128i src_10 = LoadAligned16(pred_1);
109 pred_0 += 8;
110 pred_1 += 8;
111 const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
112
113 const __m128i src_01 = LoadAligned16(pred_0);
114 const __m128i src_11 = LoadAligned16(pred_1);
115 pred_0 += 8;
116 pred_1 += 8;
117 const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
118
119 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
120 StoreLo8(dst, result_pixels);
121 dst += dest_stride;
122 StoreHi8(dst, result_pixels);
123 dst += dest_stride;
124 }
125 }
126
DistanceWeightedBlendLarge_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)127 inline void DistanceWeightedBlendLarge_SSE4_1(
128 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
129 const uint8_t weight_1, const int width, const int height, void* const dest,
130 const ptrdiff_t dest_stride) {
131 auto* dst = static_cast<uint8_t*>(dest);
132 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
133
134 int y = height;
135 do {
136 int x = 0;
137 do {
138 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
139 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
140 const __m128i res_lo =
141 ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
142
143 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
144 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
145 const __m128i res_hi =
146 ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
147
148 StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
149 x += 16;
150 } while (x < width);
151 dst += dest_stride;
152 pred_0 += width;
153 pred_1 += width;
154 } while (--y != 0);
155 }
156
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)157 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
158 const void* prediction_1,
159 const uint8_t weight_0,
160 const uint8_t weight_1, const int width,
161 const int height, void* const dest,
162 const ptrdiff_t dest_stride) {
163 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
164 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
165 if (width == 4) {
166 if (height == 4) {
167 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
168 dest, dest_stride);
169 } else if (height == 8) {
170 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
171 dest, dest_stride);
172 } else {
173 assert(height == 16);
174 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
175 dest, dest_stride);
176 }
177 return;
178 }
179
180 if (width == 8) {
181 switch (height) {
182 case 4:
183 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
184 dest, dest_stride);
185 return;
186 case 8:
187 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
188 dest, dest_stride);
189 return;
190 case 16:
191 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
192 dest, dest_stride);
193 return;
194 default:
195 assert(height == 32);
196 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
197 dest, dest_stride);
198
199 return;
200 }
201 }
202
203 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
204 height, dest, dest_stride);
205 }
206
Init8bpp()207 void Init8bpp() {
208 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
209 assert(dsp != nullptr);
210 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
211 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
212 #endif
213 }
214
215 } // namespace
216 } // namespace low_bitdepth
217
218 #if LIBGAV1_MAX_BITDEPTH >= 10
219 namespace high_bitdepth {
220 namespace {
221
222 constexpr int kMax10bppSample = (1 << 10) - 1;
223 constexpr int kInterPostRoundBit = 4;
224
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)225 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
226 const __m128i& pred1,
227 const __m128i& weight0,
228 const __m128i& weight1) {
229 // This offset is a combination of round_factor and round_offset
230 // which are to be added and subtracted respectively.
231 // Here kInterPostRoundBit + 4 is considering bitdepth=10.
232 constexpr int offset =
233 (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
234 const __m128i zero = _mm_setzero_si128();
235 const __m128i bias = _mm_set1_epi32(offset);
236 const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
237
238 __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
239 __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
240 __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
241 __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
242 __m128i sum = _mm_add_epi32(mult0, mult1);
243 sum = _mm_add_epi32(sum, bias);
244 const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
245
246 prediction0 = _mm_unpackhi_epi16(pred0, zero);
247 mult0 = _mm_mullo_epi32(prediction0, weight0);
248 prediction1 = _mm_unpackhi_epi16(pred1, zero);
249 mult1 = _mm_mullo_epi32(prediction1, weight1);
250 sum = _mm_add_epi32(mult0, mult1);
251 sum = _mm_add_epi32(sum, bias);
252 const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
253 const __m128i pack = _mm_packus_epi32(result0, result1);
254
255 return _mm_min_epi16(pack, clip_high);
256 }
257
258 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)259 inline void DistanceWeightedBlend4xH_SSE4_1(
260 const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
261 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
262 auto* dst = static_cast<uint16_t*>(dest);
263 const __m128i weight0 = _mm_set1_epi32(weight_0);
264 const __m128i weight1 = _mm_set1_epi32(weight_1);
265
266 int y = height;
267 do {
268 const __m128i src_00 = LoadLo8(pred_0);
269 const __m128i src_10 = LoadLo8(pred_1);
270 pred_0 += 4;
271 pred_1 += 4;
272 __m128i src_0 = LoadHi8(src_00, pred_0);
273 __m128i src_1 = LoadHi8(src_10, pred_1);
274 pred_0 += 4;
275 pred_1 += 4;
276 const __m128i res0 =
277 ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
278
279 const __m128i src_01 = LoadLo8(pred_0);
280 const __m128i src_11 = LoadLo8(pred_1);
281 pred_0 += 4;
282 pred_1 += 4;
283 src_0 = LoadHi8(src_01, pred_0);
284 src_1 = LoadHi8(src_11, pred_1);
285 pred_0 += 4;
286 pred_1 += 4;
287 const __m128i res1 =
288 ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
289
290 StoreLo8(dst, res0);
291 dst += dest_stride;
292 StoreHi8(dst, res0);
293 dst += dest_stride;
294 StoreLo8(dst, res1);
295 dst += dest_stride;
296 StoreHi8(dst, res1);
297 dst += dest_stride;
298 y -= 4;
299 } while (y != 0);
300 }
301
302 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)303 inline void DistanceWeightedBlend8xH_SSE4_1(
304 const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
305 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
306 auto* dst = static_cast<uint16_t*>(dest);
307 const __m128i weight0 = _mm_set1_epi32(weight_0);
308 const __m128i weight1 = _mm_set1_epi32(weight_1);
309
310 int y = height;
311 do {
312 const __m128i src_00 = LoadAligned16(pred_0);
313 const __m128i src_10 = LoadAligned16(pred_1);
314 pred_0 += 8;
315 pred_1 += 8;
316 const __m128i res0 =
317 ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
318
319 const __m128i src_01 = LoadAligned16(pred_0);
320 const __m128i src_11 = LoadAligned16(pred_1);
321 pred_0 += 8;
322 pred_1 += 8;
323 const __m128i res1 =
324 ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
325
326 StoreUnaligned16(dst, res0);
327 dst += dest_stride;
328 StoreUnaligned16(dst, res1);
329 dst += dest_stride;
330 y -= 2;
331 } while (y != 0);
332 }
333
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)334 inline void DistanceWeightedBlendLarge_SSE4_1(
335 const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
336 const uint8_t weight_1, const int width, const int height, void* const dest,
337 const ptrdiff_t dest_stride) {
338 auto* dst = static_cast<uint16_t*>(dest);
339 const __m128i weight0 = _mm_set1_epi32(weight_0);
340 const __m128i weight1 = _mm_set1_epi32(weight_1);
341
342 int y = height;
343 do {
344 int x = 0;
345 do {
346 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
347 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
348 const __m128i res_lo =
349 ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
350
351 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
352 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
353 const __m128i res_hi =
354 ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
355
356 StoreUnaligned16(dst + x, res_lo);
357 x += 8;
358 StoreUnaligned16(dst + x, res_hi);
359 x += 8;
360 } while (x < width);
361 dst += dest_stride;
362 pred_0 += width;
363 pred_1 += width;
364 } while (--y != 0);
365 }
366
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)367 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
368 const void* prediction_1,
369 const uint8_t weight_0,
370 const uint8_t weight_1, const int width,
371 const int height, void* const dest,
372 const ptrdiff_t dest_stride) {
373 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
374 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
375 const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
376 if (width == 4) {
377 if (height == 4) {
378 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
379 dest, dst_stride);
380 } else if (height == 8) {
381 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
382 dest, dst_stride);
383 } else {
384 assert(height == 16);
385 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
386 dest, dst_stride);
387 }
388 return;
389 }
390
391 if (width == 8) {
392 switch (height) {
393 case 4:
394 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
395 dest, dst_stride);
396 return;
397 case 8:
398 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
399 dest, dst_stride);
400 return;
401 case 16:
402 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
403 dest, dst_stride);
404 return;
405 default:
406 assert(height == 32);
407 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
408 dest, dst_stride);
409
410 return;
411 }
412 }
413
414 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
415 height, dest, dst_stride);
416 }
417
Init10bpp()418 void Init10bpp() {
419 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
420 assert(dsp != nullptr);
421 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
422 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
423 #endif
424 }
425
426 } // namespace
427 } // namespace high_bitdepth
428 #endif // LIBGAV1_MAX_BITDEPTH >= 10
429
DistanceWeightedBlendInit_SSE4_1()430 void DistanceWeightedBlendInit_SSE4_1() {
431 low_bitdepth::Init8bpp();
432 #if LIBGAV1_MAX_BITDEPTH >= 10
433 high_bitdepth::Init10bpp();
434 #endif
435 }
436
437 } // namespace dsp
438 } // namespace libgav1
439
440 #else // !LIBGAV1_TARGETING_SSE4_1
441
442 namespace libgav1 {
443 namespace dsp {
444
// SSE4.1 is not being targeted (LIBGAV1_TARGETING_SSE4_1 is false), so this
// initializer intentionally does nothing.
void DistanceWeightedBlendInit_SSE4_1() {}
446
447 } // namespace dsp
448 } // namespace libgav1
449 #endif // LIBGAV1_TARGETING_SSE4_1
450