1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_ENABLE_SSE4_1
19
#include <xmmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33 namespace {
34
35 constexpr int kInterPostRoundBit = 4;
36
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)37 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
38 const __m128i& pred1,
39 const __m128i& weights) {
40 // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
41 const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
42 const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
43 const __m128i result_lo =
44 RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
45
46 const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
47 const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
48 const __m128i result_hi =
49 RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
50
51 return _mm_packs_epi32(result_lo, result_hi);
52 }
53
54 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)55 inline void DistanceWeightedBlend4xH_SSE4_1(
56 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
57 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
58 auto* dst = static_cast<uint8_t*>(dest);
59 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
60
61 for (int y = 0; y < height; y += 4) {
62 // TODO(b/150326556): Use larger loads.
63 const __m128i src_00 = LoadLo8(pred_0);
64 const __m128i src_10 = LoadLo8(pred_1);
65 pred_0 += 4;
66 pred_1 += 4;
67 __m128i src_0 = LoadHi8(src_00, pred_0);
68 __m128i src_1 = LoadHi8(src_10, pred_1);
69 pred_0 += 4;
70 pred_1 += 4;
71 const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
72
73 const __m128i src_01 = LoadLo8(pred_0);
74 const __m128i src_11 = LoadLo8(pred_1);
75 pred_0 += 4;
76 pred_1 += 4;
77 src_0 = LoadHi8(src_01, pred_0);
78 src_1 = LoadHi8(src_11, pred_1);
79 pred_0 += 4;
80 pred_1 += 4;
81 const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
82
83 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
84 Store4(dst, result_pixels);
85 dst += dest_stride;
86 const int result_1 = _mm_extract_epi32(result_pixels, 1);
87 memcpy(dst, &result_1, sizeof(result_1));
88 dst += dest_stride;
89 const int result_2 = _mm_extract_epi32(result_pixels, 2);
90 memcpy(dst, &result_2, sizeof(result_2));
91 dst += dest_stride;
92 const int result_3 = _mm_extract_epi32(result_pixels, 3);
93 memcpy(dst, &result_3, sizeof(result_3));
94 dst += dest_stride;
95 }
96 }
97
98 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,void * const dest,const ptrdiff_t dest_stride)99 inline void DistanceWeightedBlend8xH_SSE4_1(
100 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
101 const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
102 auto* dst = static_cast<uint8_t*>(dest);
103 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
104
105 for (int y = 0; y < height; y += 2) {
106 const __m128i src_00 = LoadAligned16(pred_0);
107 const __m128i src_10 = LoadAligned16(pred_1);
108 pred_0 += 8;
109 pred_1 += 8;
110 const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
111
112 const __m128i src_01 = LoadAligned16(pred_0);
113 const __m128i src_11 = LoadAligned16(pred_1);
114 pred_0 += 8;
115 pred_1 += 8;
116 const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
117
118 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
119 StoreLo8(dst, result_pixels);
120 dst += dest_stride;
121 StoreHi8(dst, result_pixels);
122 dst += dest_stride;
123 }
124 }
125
DistanceWeightedBlendLarge_SSE4_1(const int16_t * pred_0,const int16_t * pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)126 inline void DistanceWeightedBlendLarge_SSE4_1(
127 const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
128 const uint8_t weight_1, const int width, const int height, void* const dest,
129 const ptrdiff_t dest_stride) {
130 auto* dst = static_cast<uint8_t*>(dest);
131 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
132
133 int y = height;
134 do {
135 int x = 0;
136 do {
137 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
138 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
139 const __m128i res_lo =
140 ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
141
142 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
143 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
144 const __m128i res_hi =
145 ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
146
147 StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
148 x += 16;
149 } while (x < width);
150 dst += dest_stride;
151 pred_0 += width;
152 pred_1 += width;
153 } while (--y != 0);
154 }
155
DistanceWeightedBlend_SSE4_1(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)156 void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
157 const void* prediction_1,
158 const uint8_t weight_0,
159 const uint8_t weight_1, const int width,
160 const int height, void* const dest,
161 const ptrdiff_t dest_stride) {
162 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
163 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
164 if (width == 4) {
165 if (height == 4) {
166 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
167 dest, dest_stride);
168 } else if (height == 8) {
169 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
170 dest, dest_stride);
171 } else {
172 assert(height == 16);
173 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
174 dest, dest_stride);
175 }
176 return;
177 }
178
179 if (width == 8) {
180 switch (height) {
181 case 4:
182 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
183 dest, dest_stride);
184 return;
185 case 8:
186 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
187 dest, dest_stride);
188 return;
189 case 16:
190 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
191 dest, dest_stride);
192 return;
193 default:
194 assert(height == 32);
195 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
196 dest, dest_stride);
197
198 return;
199 }
200 }
201
202 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
203 height, dest, dest_stride);
204 }
205
Init8bpp()206 void Init8bpp() {
207 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
208 assert(dsp != nullptr);
209 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
210 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
211 #endif
212 }
213
214 } // namespace
215
DistanceWeightedBlendInit_SSE4_1()216 void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
217
218 } // namespace dsp
219 } // namespace libgav1
220
221 #else // !LIBGAV1_ENABLE_SSE4_1
222
223 namespace libgav1 {
224 namespace dsp {
225
// SSE4.1 support is compiled out (LIBGAV1_ENABLE_SSE4_1 is false), so there is
// nothing to register. This stub keeps the public symbol available.
void DistanceWeightedBlendInit_SSE4_1() {
  // Intentionally empty.
}
227
228 } // namespace dsp
229 } // namespace libgav1
230 #endif // LIBGAV1_ENABLE_SSE4_1
231