• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_SSE4_1
19 
20 #include <xmmintrin.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace {
34 
35 constexpr int kInterPostRoundBit = 4;
36 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)37 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
38                                        const __m128i& pred1,
39                                        const __m128i& weights) {
40   // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
41   const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
42   const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
43   const __m128i result_lo =
44       RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
45 
46   const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
47   const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
48   const __m128i result_hi =
49       RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
50 
51   return _mm_packs_epi32(result_lo, result_hi);
52 }
53 
54 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)55 inline void DistanceWeightedBlend4xH_SSE4_1(
56     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
57     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
58   auto* dst = static_cast<uint8_t*>(dest);
59   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
60 
61   for (int y = 0; y < height; y += 4) {
62     // TODO(b/150326556): Use larger loads.
63     const __m128i src_00 = LoadLo8(pred_0);
64     const __m128i src_10 = LoadLo8(pred_1);
65     pred_0 += 4;
66     pred_1 += 4;
67     __m128i src_0 = LoadHi8(src_00, pred_0);
68     __m128i src_1 = LoadHi8(src_10, pred_1);
69     pred_0 += 4;
70     pred_1 += 4;
71     const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
72 
73     const __m128i src_01 = LoadLo8(pred_0);
74     const __m128i src_11 = LoadLo8(pred_1);
75     pred_0 += 4;
76     pred_1 += 4;
77     src_0 = LoadHi8(src_01, pred_0);
78     src_1 = LoadHi8(src_11, pred_1);
79     pred_0 += 4;
80     pred_1 += 4;
81     const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
82 
83     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
84     Store4(dst, result_pixels);
85     dst += dest_stride;
86     const int result_1 = _mm_extract_epi32(result_pixels, 1);
87     memcpy(dst, &result_1, sizeof(result_1));
88     dst += dest_stride;
89     const int result_2 = _mm_extract_epi32(result_pixels, 2);
90     memcpy(dst, &result_2, sizeof(result_2));
91     dst += dest_stride;
92     const int result_3 = _mm_extract_epi32(result_pixels, 3);
93     memcpy(dst, &result_3, sizeof(result_3));
94     dst += dest_stride;
95   }
96 }
97 
98 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)99 inline void DistanceWeightedBlend8xH_SSE4_1(
100     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
101     const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
102   auto* dst = static_cast<uint8_t*>(dest);
103   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
104 
105   for (int y = 0; y < height; y += 2) {
106     const __m128i src_00 = LoadAligned16(pred_0);
107     const __m128i src_10 = LoadAligned16(pred_1);
108     pred_0 += 8;
109     pred_1 += 8;
110     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
111 
112     const __m128i src_01 = LoadAligned16(pred_0);
113     const __m128i src_11 = LoadAligned16(pred_1);
114     pred_0 += 8;
115     pred_1 += 8;
116     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
117 
118     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
119     StoreLo8(dst, result_pixels);
120     dst += dest_stride;
121     StoreHi8(dst, result_pixels);
122     dst += dest_stride;
123   }
124 }
125 
DistanceWeightedBlendLarge_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)126 inline void DistanceWeightedBlendLarge_SSE4_1(
127     const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
128     const uint8_t weight_1, const int width, const int height, void* const dest,
129     const ptrdiff_t dest_stride) {
130   auto* dst = static_cast<uint8_t*>(dest);
131   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
132 
133   int y = height;
134   do {
135     int x = 0;
136     do {
137       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
138       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
139       const __m128i res_lo =
140           ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
141 
142       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
143       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
144       const __m128i res_hi =
145           ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
146 
147       StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
148       x += 16;
149     } while (x < width);
150     dst += dest_stride;
151     pred_0 += width;
152     pred_1 += width;
153   } while (--y != 0);
154 }
155 
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)156 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
157                                   const void* prediction_1,
158                                   const uint8_t weight_0,
159                                   const uint8_t weight_1, const int width,
160                                   const int height, void* const dest,
161                                   const ptrdiff_t dest_stride) {
162   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
163   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
164   if (width == 4) {
165     if (height == 4) {
166       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
167                                          dest, dest_stride);
168     } else if (height == 8) {
169       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
170                                          dest, dest_stride);
171     } else {
172       assert(height == 16);
173       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
174                                           dest, dest_stride);
175     }
176     return;
177   }
178 
179   if (width == 8) {
180     switch (height) {
181       case 4:
182         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
183                                            dest, dest_stride);
184         return;
185       case 8:
186         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
187                                            dest, dest_stride);
188         return;
189       case 16:
190         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
191                                             dest, dest_stride);
192         return;
193       default:
194         assert(height == 32);
195         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
196                                             dest, dest_stride);
197 
198         return;
199     }
200   }
201 
202   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
203                                     height, dest, dest_stride);
204 }
205 
Init8bpp()206 void Init8bpp() {
207   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
208   assert(dsp != nullptr);
209 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
210   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
211 #endif
212 }
213 
214 }  // namespace
215 
DistanceWeightedBlendInit_SSE4_1()216 void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
217 
218 }  // namespace dsp
219 }  // namespace libgav1
220 
221 #else  // !LIBGAV1_ENABLE_SSE4_1
222 
223 namespace libgav1 {
224 namespace dsp {
225 
// SSE4.1 support is compiled out: there is nothing to register.
void DistanceWeightedBlendInit_SSE4_1() {
}
227 
228 }  // namespace dsp
229 }  // namespace libgav1
230 #endif  // LIBGAV1_ENABLE_SSE4_1
231