• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_SSE4_1
19 
20 #include <xmmintrin.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35 
36 constexpr int kInterPostRoundBit = 4;
37 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)38 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
39                                        const __m128i& pred1,
40                                        const __m128i& weights) {
41   // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
42   const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
43   const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
44   const __m128i result_lo =
45       RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
46 
47   const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
48   const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
49   const __m128i result_hi =
50       RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
51 
52   return _mm_packs_epi32(result_lo, result_hi);
53 }
54 
55 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)56 inline void DistanceWeightedBlend4xH_SSE4_1(
57     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
58     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
59   auto* dst = static_cast<uint8_t*>(dest);
60   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
61 
62   for (int y = 0; y < height; y += 4) {
63     // TODO(b/150326556): Use larger loads.
64     const __m128i src_00 = LoadLo8(pred_0);
65     const __m128i src_10 = LoadLo8(pred_1);
66     pred_0 += 4;
67     pred_1 += 4;
68     __m128i src_0 = LoadHi8(src_00, pred_0);
69     __m128i src_1 = LoadHi8(src_10, pred_1);
70     pred_0 += 4;
71     pred_1 += 4;
72     const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
73 
74     const __m128i src_01 = LoadLo8(pred_0);
75     const __m128i src_11 = LoadLo8(pred_1);
76     pred_0 += 4;
77     pred_1 += 4;
78     src_0 = LoadHi8(src_01, pred_0);
79     src_1 = LoadHi8(src_11, pred_1);
80     pred_0 += 4;
81     pred_1 += 4;
82     const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
83 
84     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
85     Store4(dst, result_pixels);
86     dst += dest_stride;
87     const int result_1 = _mm_extract_epi32(result_pixels, 1);
88     memcpy(dst, &result_1, sizeof(result_1));
89     dst += dest_stride;
90     const int result_2 = _mm_extract_epi32(result_pixels, 2);
91     memcpy(dst, &result_2, sizeof(result_2));
92     dst += dest_stride;
93     const int result_3 = _mm_extract_epi32(result_pixels, 3);
94     memcpy(dst, &result_3, sizeof(result_3));
95     dst += dest_stride;
96   }
97 }
98 
99 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)100 inline void DistanceWeightedBlend8xH_SSE4_1(
101     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
102     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
103   auto* dst = static_cast<uint8_t*>(dest);
104   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
105 
106   for (int y = 0; y < height; y += 2) {
107     const __m128i src_00 = LoadAligned16(pred_0);
108     const __m128i src_10 = LoadAligned16(pred_1);
109     pred_0 += 8;
110     pred_1 += 8;
111     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
112 
113     const __m128i src_01 = LoadAligned16(pred_0);
114     const __m128i src_11 = LoadAligned16(pred_1);
115     pred_0 += 8;
116     pred_1 += 8;
117     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
118 
119     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
120     StoreLo8(dst, result_pixels);
121     dst += dest_stride;
122     StoreHi8(dst, result_pixels);
123     dst += dest_stride;
124   }
125 }
126 
DistanceWeightedBlendLarge_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)127 inline void DistanceWeightedBlendLarge_SSE4_1(
128     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
129     const uint8_t weight_1, const int width, const int height, void* const dest,
130     const ptrdiff_t dest_stride) {
131   auto* dst = static_cast<uint8_t*>(dest);
132   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
133 
134   int y = height;
135   do {
136     int x = 0;
137     do {
138       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
139       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
140       const __m128i res_lo =
141           ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
142 
143       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
144       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
145       const __m128i res_hi =
146           ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
147 
148       StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
149       x += 16;
150     } while (x < width);
151     dst += dest_stride;
152     pred_0 += width;
153     pred_1 += width;
154   } while (--y != 0);
155 }
156 
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)157 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
158                                   const void* prediction_1,
159                                   const uint8_t weight_0,
160                                   const uint8_t weight_1, const int width,
161                                   const int height, void* const dest,
162                                   const ptrdiff_t dest_stride) {
163   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
164   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
165   if (width == 4) {
166     if (height == 4) {
167       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
168                                          dest, dest_stride);
169     } else if (height == 8) {
170       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
171                                          dest, dest_stride);
172     } else {
173       assert(height == 16);
174       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
175                                           dest, dest_stride);
176     }
177     return;
178   }
179 
180   if (width == 8) {
181     switch (height) {
182       case 4:
183         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
184                                            dest, dest_stride);
185         return;
186       case 8:
187         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
188                                            dest, dest_stride);
189         return;
190       case 16:
191         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
192                                             dest, dest_stride);
193         return;
194       default:
195         assert(height == 32);
196         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
197                                             dest, dest_stride);
198 
199         return;
200     }
201   }
202 
203   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
204                                     height, dest, dest_stride);
205 }
206 
Init8bpp()207 void Init8bpp() {
208   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
209   assert(dsp != nullptr);
210 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
211   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
212 #endif
213 }
214 
215 }  // namespace
216 }  // namespace low_bitdepth
217 
218 #if LIBGAV1_MAX_BITDEPTH >= 10
219 namespace high_bitdepth {
220 namespace {
221 
222 constexpr int kMax10bppSample = (1 << 10) - 1;
223 constexpr int kInterPostRoundBit = 4;
224 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)225 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
226                                        const __m128i& pred1,
227                                        const __m128i& weight0,
228                                        const __m128i& weight1) {
229   // This offset is a combination of round_factor and round_offset
230   // which are to be added and subtracted respectively.
231   // Here kInterPostRoundBit + 4 is considering bitdepth=10.
232   constexpr int offset =
233       (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
234   const __m128i zero = _mm_setzero_si128();
235   const __m128i bias = _mm_set1_epi32(offset);
236   const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
237 
238   __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
239   __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
240   __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
241   __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
242   __m128i sum = _mm_add_epi32(mult0, mult1);
243   sum = _mm_add_epi32(sum, bias);
244   const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
245 
246   prediction0 = _mm_unpackhi_epi16(pred0, zero);
247   mult0 = _mm_mullo_epi32(prediction0, weight0);
248   prediction1 = _mm_unpackhi_epi16(pred1, zero);
249   mult1 = _mm_mullo_epi32(prediction1, weight1);
250   sum = _mm_add_epi32(mult0, mult1);
251   sum = _mm_add_epi32(sum, bias);
252   const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
253   const __m128i pack = _mm_packus_epi32(result0, result1);
254 
255   return _mm_min_epi16(pack, clip_high);
256 }
257 
258 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)259 inline void DistanceWeightedBlend4xH_SSE4_1(
260     const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
261     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
262   auto* dst = static_cast<uint16_t*>(dest);
263   const __m128i weight0 = _mm_set1_epi32(weight_0);
264   const __m128i weight1 = _mm_set1_epi32(weight_1);
265 
266   int y = height;
267   do {
268     const __m128i src_00 = LoadLo8(pred_0);
269     const __m128i src_10 = LoadLo8(pred_1);
270     pred_0 += 4;
271     pred_1 += 4;
272     __m128i src_0 = LoadHi8(src_00, pred_0);
273     __m128i src_1 = LoadHi8(src_10, pred_1);
274     pred_0 += 4;
275     pred_1 += 4;
276     const __m128i res0 =
277         ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
278 
279     const __m128i src_01 = LoadLo8(pred_0);
280     const __m128i src_11 = LoadLo8(pred_1);
281     pred_0 += 4;
282     pred_1 += 4;
283     src_0 = LoadHi8(src_01, pred_0);
284     src_1 = LoadHi8(src_11, pred_1);
285     pred_0 += 4;
286     pred_1 += 4;
287     const __m128i res1 =
288         ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
289 
290     StoreLo8(dst, res0);
291     dst += dest_stride;
292     StoreHi8(dst, res0);
293     dst += dest_stride;
294     StoreLo8(dst, res1);
295     dst += dest_stride;
296     StoreHi8(dst, res1);
297     dst += dest_stride;
298     y -= 4;
299   } while (y != 0);
300 }
301 
302 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)303 inline void DistanceWeightedBlend8xH_SSE4_1(
304     const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
305     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
306   auto* dst = static_cast<uint16_t*>(dest);
307   const __m128i weight0 = _mm_set1_epi32(weight_0);
308   const __m128i weight1 = _mm_set1_epi32(weight_1);
309 
310   int y = height;
311   do {
312     const __m128i src_00 = LoadAligned16(pred_0);
313     const __m128i src_10 = LoadAligned16(pred_1);
314     pred_0 += 8;
315     pred_1 += 8;
316     const __m128i res0 =
317         ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
318 
319     const __m128i src_01 = LoadAligned16(pred_0);
320     const __m128i src_11 = LoadAligned16(pred_1);
321     pred_0 += 8;
322     pred_1 += 8;
323     const __m128i res1 =
324         ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
325 
326     StoreUnaligned16(dst, res0);
327     dst += dest_stride;
328     StoreUnaligned16(dst, res1);
329     dst += dest_stride;
330     y -= 2;
331   } while (y != 0);
332 }
333 
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * pred_0,const uint16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)334 inline void DistanceWeightedBlendLarge_SSE4_1(
335     const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
336     const uint8_t weight_1, const int width, const int height, void* const dest,
337     const ptrdiff_t dest_stride) {
338   auto* dst = static_cast<uint16_t*>(dest);
339   const __m128i weight0 = _mm_set1_epi32(weight_0);
340   const __m128i weight1 = _mm_set1_epi32(weight_1);
341 
342   int y = height;
343   do {
344     int x = 0;
345     do {
346       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
347       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
348       const __m128i res_lo =
349           ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
350 
351       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
352       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
353       const __m128i res_hi =
354           ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
355 
356       StoreUnaligned16(dst + x, res_lo);
357       x += 8;
358       StoreUnaligned16(dst + x, res_hi);
359       x += 8;
360     } while (x < width);
361     dst += dest_stride;
362     pred_0 += width;
363     pred_1 += width;
364   } while (--y != 0);
365 }
366 
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)367 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
368                                   const void* prediction_1,
369                                   const uint8_t weight_0,
370                                   const uint8_t weight_1, const int width,
371                                   const int height, void* const dest,
372                                   const ptrdiff_t dest_stride) {
373   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
374   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
375   const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
376   if (width == 4) {
377     if (height == 4) {
378       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
379                                          dest, dst_stride);
380     } else if (height == 8) {
381       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
382                                          dest, dst_stride);
383     } else {
384       assert(height == 16);
385       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
386                                           dest, dst_stride);
387     }
388     return;
389   }
390 
391   if (width == 8) {
392     switch (height) {
393       case 4:
394         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
395                                            dest, dst_stride);
396         return;
397       case 8:
398         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
399                                            dest, dst_stride);
400         return;
401       case 16:
402         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
403                                             dest, dst_stride);
404         return;
405       default:
406         assert(height == 32);
407         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
408                                             dest, dst_stride);
409 
410         return;
411     }
412   }
413 
414   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
415                                     height, dest, dst_stride);
416 }
417 
Init10bpp()418 void Init10bpp() {
419   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
420   assert(dsp != nullptr);
421 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
422   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
423 #endif
424 }
425 
426 }  // namespace
427 }  // namespace high_bitdepth
428 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
429 
DistanceWeightedBlendInit_SSE4_1()430 void DistanceWeightedBlendInit_SSE4_1() {
431   low_bitdepth::Init8bpp();
432 #if LIBGAV1_MAX_BITDEPTH >= 10
433   high_bitdepth::Init10bpp();
434 #endif
435 }
436 
437 }  // namespace dsp
438 }  // namespace libgav1
439 
440 #else   // !LIBGAV1_TARGETING_SSE4_1
441 
442 namespace libgav1 {
443 namespace dsp {
444 
// No-op stand-in used when the build does not target SSE4.1, so callers can
// invoke the initializer unconditionally.
void DistanceWeightedBlendInit_SSE4_1() {
  // Intentionally empty.
}
446 
447 }  // namespace dsp
448 }  // namespace libgav1
449 #endif  // LIBGAV1_TARGETING_SSE4_1
450