// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/mask_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <cassert>
#include <cstddef>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

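// Collects the mask values covering two rows of a width-4 block. The mask is
// always subsampled in x on this path (see the callers), so each returned
// lane is the rounded average of the 2 (4:2:2) or 4 (4:2:0) |mask| entries
// belonging to one output pixel.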
template <int subsampling_y>
inline uint8x8_t GetMask4x2(const uint8_t* mask) {
  if (subsampling_y == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    const uint8x16_t combined_horz =
        vaddq_u8(mask_val.val[0], mask_val.val[1]);
    const uint32x2_t row_01 = vreinterpret_u32_u8(vget_low_u8(combined_horz));
    const uint32x2_t row_23 = vreinterpret_u32_u8(vget_high_u8(combined_horz));

    const uint32x2x2_t row_02_13 = vtrn_u32(row_01, row_23);
    // Use a halving add to work around the case where all |mask| values are
    // 64.
    return vrshr_n_u8(vhadd_u8(vreinterpret_u8_u32(row_02_13.val[0]),
                               vreinterpret_u8_u32(row_02_13.val[1])),
                      1);
  }
  // subsampling_x == 1
  const uint8x8x2_t mask_val = vld2_u8(mask);
  return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
}

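// Collects the mask values for one row of a width-8 block, averaging the 1, 2
// or 4 |mask| entries per output pixel as dictated by the subsampling.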
template <int subsampling_x, int subsampling_y>
inline uint8x8_t GetMask8(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    const uint8x16_t combined_horz =
        vaddq_u8(mask_val.val[0], mask_val.val[1]);
    // Use a halving add to work around the case where all |mask| values are
    // 64.
    return vrshr_n_u8(
        vhadd_u8(vget_low_u8(combined_horz), vget_high_u8(combined_horz)), 1);
  }
  if (subsampling_x == 1) {
    const uint8x8x2_t mask_val = vld2_u8(mask);
    return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  return vld1_u8(mask);
}

inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                  const int16x8_t pred_mask_0,
                                  const int16x8_t pred_mask_1,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
  // int res = (mask_value * prediction_0[x] +
  //            (64 - mask_value) * prediction_1[x]) >> 6;
  const int32x4_t weighted_pred_0_lo =
      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
  const int32x4_t weighted_pred_0_hi =
      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
  const int32x4_t weighted_combo_lo = vmlal_s16(
      weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
  const int32x4_t weighted_combo_hi =
      vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
                vget_high_s16(pred_val_1));
  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
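  // The narrowing shift by 6 removes the mask weighting; vqrshrun_n_s16 then
  // applies the rounded shift by inter_post_round_bits (4 for 8bpp) and
  // saturates the result to [0, 255].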
  const uint8x8_t result =
      vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
                                  vshrn_n_s32(weighted_combo_hi, 6)),
                     4);
  StoreLo4(dst, result);
  StoreHi4(dst + dst_stride, result);
}

template <int subsampling_y>
inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT mask,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  constexpr int subsampling_x = 1;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  // Compound predictors use int16_t values and need to multiply long because
  // the Convolve range * 64 is 20 bits. Unfortunately there is no multiply
  // int16_t by int8_t and accumulate into int32_t instruction.
  int16x8_t pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
  int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
  pred_0 += 4 << subsampling_x;
  pred_1 += 4 << subsampling_x;
  mask += mask_stride << (subsampling_x + subsampling_y);
  dst += dst_stride << subsampling_x;

  pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
  pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
}

template <int subsampling_y>
inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlending4x4_NEON<subsampling_y>(pred_0, pred_1, mask, dst, dst_stride);
    return;
  }
  constexpr int subsampling_x = 1;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = 0;
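  // The loop is unrolled to process four 4x2 blocks (8 rows) per iteration;
  // after the height == 4 early return the remaining height is assumed to be
  // a multiple of 8.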
  do {
    int16x8_t pred_mask_0 =
        vreinterpretq_s16_u16(vmovl_u8(GetMask4x2<subsampling_y>(mask)));
    int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;
    y += 8;
  } while (y < height);
}

inline uint8x8_t CombinePred8(const int16_t* LIBGAV1_RESTRICT pred_0,
                              const int16_t* LIBGAV1_RESTRICT pred_1,
                              const int16x8_t pred_mask_0,
                              const int16x8_t pred_mask_1) {
  // First 8 values.
  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
  // int res = (mask_value * prediction_0[x] +
  //            (64 - mask_value) * prediction_1[x]) >> 6;
  const int32x4_t weighted_pred_lo =
      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
  const int32x4_t weighted_pred_hi =
      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
  const int32x4_t weighted_combo_lo = vmlal_s16(
      weighted_pred_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
  const int32x4_t weighted_combo_hi = vmlal_s16(
      weighted_pred_hi, vget_high_s16(pred_mask_1), vget_high_s16(pred_val_1));

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
  return vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
                                     vshrn_n_s32(weighted_combo_hi, 6)),
                        4);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending8xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = height;
  do {
    const int16x8_t pred_mask_0 =
        ZeroExtend(GetMask8<subsampling_x, subsampling_y>(mask));
    // 64 - mask
    const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    const uint8x8_t result =
        CombinePred8(pred_0, pred_1, pred_mask_0, pred_mask_1);
    vst1_u8(dst, result);
    dst += dst_stride;
    mask += 8 << (subsampling_x + subsampling_y);
    pred_0 += 8;
    pred_1 += 8;
  } while (--y != 0);
}

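// Collects 16 mask values for one row of a width-16 block. For 4:2:0 the
// averaging spans two mask rows, which is why |mask_stride| is needed here.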
template <int subsampling_x, int subsampling_y>
inline uint8x16_t GetMask16(const uint8_t* mask, const ptrdiff_t mask_stride) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const uint8x16x2_t mask_val0 = vld2q_u8(mask);
    const uint8x16x2_t mask_val1 = vld2q_u8(mask + mask_stride);
    const uint8x16_t combined_horz0 =
        vaddq_u8(mask_val0.val[0], mask_val0.val[1]);
    const uint8x16_t combined_horz1 =
        vaddq_u8(mask_val1.val[0], mask_val1.val[1]);
    // Use a halving add to work around the case where all |mask| values are
    // 64.
    return vrshrq_n_u8(vhaddq_u8(combined_horz0, combined_horz1), 1);
  }
  if (subsampling_x == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    return vrhaddq_u8(mask_val.val[0], mask_val.val[1]);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  return vld1q_u8(mask);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                           const void* LIBGAV1_RESTRICT prediction_1,
                           const ptrdiff_t /*prediction_stride_1*/,
                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                           const ptrdiff_t mask_stride, const int width,
                           const int height, void* LIBGAV1_RESTRICT dest,
                           const ptrdiff_t dst_stride) {
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
  if (width == 4) {
    MaskBlending4xH_NEON<subsampling_y>(pred_0, pred_1, mask_ptr, height, dst,
                                        dst_stride);
    return;
  }
  if (width == 8) {
    MaskBlending8xH_NEON<subsampling_x, subsampling_y>(pred_0, pred_1, mask_ptr,
                                                       height, dst, dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = 0;
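  // For width >= 16, each inner-loop iteration blends 16 pixels as two 8-wide
  // CombinePred8() calls.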
  do {
    int x = 0;
    do {
      const uint8x16_t pred_mask_0 = GetMask16<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const int16x8_t pred_mask_0_lo = ZeroExtend(vget_low_u8(pred_mask_0));
      const int16x8_t pred_mask_0_hi = ZeroExtend(vget_high_u8(pred_mask_0));
      // 64 - mask
      const int16x8_t pred_mask_1_lo = vsubq_s16(mask_inverter, pred_mask_0_lo);
      const int16x8_t pred_mask_1_hi = vsubq_s16(mask_inverter, pred_mask_0_hi);

      uint8x8_t result;
      result =
          CombinePred8(pred_0 + x, pred_1 + x, pred_mask_0_lo, pred_mask_1_lo);
      vst1_u8(dst + x, result);

      result = CombinePred8(pred_0 + x + 8, pred_1 + x + 8, pred_mask_0_hi,
                            pred_mask_1_hi);
      vst1_u8(dst + x + 8, result);

      x += 16;
    } while (x < width);
    dst += dst_stride;
    pred_0 += width;
    pred_1 += width;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

template <int subsampling_x, int subsampling_y>
inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
                                      ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    return GetMask4x2<subsampling_y>(mask);
  }
  // When using intra or difference weighted masks, the function doesn't use
  // subsampling, so |mask_stride| may be 4 or 8.
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val0 = Load4(mask);
  return Load4<1>(mask + mask_stride, mask_val0);
}

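// Blends two rows of width 4 and writes the result back over |pred_1|, which
// doubles as the output buffer for the 8bpp inter-intra path.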
inline void InterIntraWriteMaskBlendLine8bpp4x2(
    const uint8_t* LIBGAV1_RESTRICT const pred_0,
    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
    const uint8x8_t pred_mask_0, const uint8x8_t pred_mask_1) {
  const uint8x8_t pred_val_0 = vld1_u8(pred_0);
  uint8x8_t pred_val_1 = Load4(pred_1);
  pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);

  const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
  const uint16x8_t weighted_combo =
      vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
  const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
  StoreLo4(pred_1, result);
  StoreHi4(pred_1 + pred_stride_1, result);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4x4_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride) {
  const uint8x8_t mask_inverter = vdup_n_u8(64);
  uint8x8_t pred_mask_1 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);

  pred_mask_1 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4xH_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride, const int height) {
  if (height == 4) {
    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    return;
  }
  int y = 0;
  do {
    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);

    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp8xH_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride, const int height) {
  const uint8x8_t mask_inverter = vdup_n_u8(64);
  int y = height;
  do {
    const uint8x8_t pred_mask_1 = GetMask8<subsampling_x, subsampling_y>(mask);
    // 64 - mask
    const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
    const uint8x8_t pred_val_0 = vld1_u8(pred_0);
    const uint8x8_t pred_val_1 = vld1_u8(pred_1);
    const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
    // weighted_pred0 + weighted_pred1
    const uint16x8_t weighted_combo =
        vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
    const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
    vst1_u8(pred_1, result);

    pred_0 += 8;
    pred_1 += pred_stride_1;
    mask += mask_stride << subsampling_y;
  } while (--y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend8bpp_NEON(
    const uint8_t* LIBGAV1_RESTRICT prediction_0,
    uint8_t* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
    const ptrdiff_t mask_stride, const int width, const int height) {
  if (width == 4) {
    InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  if (width == 8) {
    InterIntraMaskBlending8bpp8xH_NEON<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const uint8x16_t mask_inverter = vdupq_n_u8(64);
  int y = 0;
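  // For width >= 16 the blended output overwrites |prediction_1| in place, 16
  // pixels per inner-loop iteration.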
  do {
    int x = 0;
    do {
      const uint8x16_t pred_mask_1 = GetMask16<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const uint8x16_t pred_mask_0 = vsubq_u8(mask_inverter, pred_mask_1);
      const uint8x8_t pred_val_0_lo = vld1_u8(prediction_0);
      prediction_0 += 8;
      const uint8x8_t pred_val_0_hi = vld1_u8(prediction_0);
      prediction_0 += 8;
      // Ensure armv7 build combines the load.
      const uint8x16_t pred_val_1 = vld1q_u8(prediction_1 + x);
      const uint8x8_t pred_val_1_lo = vget_low_u8(pred_val_1);
      const uint8x8_t pred_val_1_hi = vget_high_u8(pred_val_1);
      const uint16x8_t weighted_pred_0_lo =
          vmull_u8(vget_low_u8(pred_mask_0), pred_val_0_lo);
      // weighted_pred0 + weighted_pred1
      const uint16x8_t weighted_combo_lo =
          vmlal_u8(weighted_pred_0_lo, vget_low_u8(pred_mask_1), pred_val_1_lo);
      const uint8x8_t result_lo = vrshrn_n_u16(weighted_combo_lo, 6);
      vst1_u8(prediction_1 + x, result_lo);
      const uint16x8_t weighted_pred_0_hi =
          vmull_u8(vget_high_u8(pred_mask_0), pred_val_0_hi);
      // weighted_pred0 + weighted_pred1
      const uint16x8_t weighted_combo_hi = vmlal_u8(
          weighted_pred_0_hi, vget_high_u8(pred_mask_1), pred_val_1_hi);
      const uint8x8_t result_hi = vrshrn_n_u16(weighted_combo_hi, 6);
      vst1_u8(prediction_1 + x + 8, result_hi);

      x += 16;
    } while (x < width);
    prediction_1 += prediction_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>;
  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>;
  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>;
  // The is_inter_intra index of mask_blend[][] is replaced by
  // inter_intra_mask_blend_8bpp[] in 8-bit.
  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>;
  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>;
  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>;
}

}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

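// Collects the mask values covering two rows of a width-4 block, widened to
// uint16_t. With subsampling in x, vpaddlq_u8 sums horizontal pairs; for
// 4:2:0 a second pairwise sum is accumulated before the rounded shift
// averages the 2 or 4 contributing |mask| entries.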
template <int subsampling_x, int subsampling_y>
inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    const uint8x8_t mask_val0 = vld1_u8(mask);
    const uint8x8_t mask_val1 = vld1_u8(mask + (mask_stride << subsampling_y));
    uint16x8_t final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1));
    if (subsampling_y == 1) {
      const uint8x8_t next_mask_val0 = vld1_u8(mask + mask_stride);
      const uint8x8_t next_mask_val1 = vld1_u8(mask + mask_stride * 3);
      final_val = vaddq_u16(
          final_val, vpaddlq_u8(vcombine_u8(next_mask_val0, next_mask_val1)));
    }
    return vrshrq_n_u16(final_val, subsampling_y + 1);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val0 = Load4(mask);
  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
  return vmovl_u8(mask_val);
}

template <int subsampling_x, int subsampling_y>
inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
    if (subsampling_y == 1) {
      const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
      mask_val = vaddq_u16(mask_val, next_mask_val);
    }
    return vrshrq_n_u16(mask_val, 1 + subsampling_y);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val = vld1_u8(mask);
  return vmovl_u8(mask_val);
}

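// Computes the weighted sum of the two predictors. For inter-intra the inputs
// are 10-bit pixels weighted by 6-bit mask values, so each product fits in 16
// bits and vmulq_u16/vmlaq_u16 suffice. The compound path blends offset
// predictor values that are too wide for that, so it widens to 32 bits with
// vmull_u16/vmlal_u16 before the shift by 6.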
template <bool is_inter_intra>
uint16x8_t SumWeightedPred(const uint16x8_t pred_mask_0,
                           const uint16x8_t pred_mask_1,
                           const uint16x8_t pred_val_0,
                           const uint16x8_t pred_val_1) {
  if (is_inter_intra) {
    // dst[x] = static_cast<Pixel>(RightShiftWithRounding(
    //     mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
    uint16x8_t sum = vmulq_u16(pred_mask_1, pred_val_0);
    sum = vmlaq_u16(sum, pred_mask_0, pred_val_1);
    return vrshrq_n_u16(sum, 6);
  } else {
    // int res = (mask_value * prediction_0[x] +
    //            (64 - mask_value) * prediction_1[x]) >> 6;
    const uint32x4_t weighted_pred_0_lo =
        vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
    const uint32x4_t weighted_pred_0_hi = VMullHighU16(pred_mask_0, pred_val_0);
    uint32x4x2_t sum;
    sum.val[0] = vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
                           vget_low_u16(pred_val_1));
    sum.val[1] = VMlalHighU16(weighted_pred_0_hi, pred_mask_1, pred_val_1);
    return vcombine_u16(vshrn_n_u32(sum.val[0], 6), vshrn_n_u32(sum.val[1], 6));
  }
}

template <bool is_inter_intra, int width, int bitdepth = 10>
inline void StoreShiftedResult(uint8_t* dst, const uint16x8_t result,
                               const ptrdiff_t dst_stride = 0) {
  if (is_inter_intra) {
    if (width == 4) {
      // Store 2 lines of width 4.
      assert(dst_stride != 0);
      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(result));
      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
               vget_high_u16(result));
    } else {
      // Store 1 line of width 8.
      vst1q_u16(reinterpret_cast<uint16_t*>(dst), result);
    }
  } else {
    // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
    // dst[x] = static_cast<Pixel>(
    //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
    //           (1 << kBitdepth8) - 1));
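    // vqsubq_u16 saturates at zero, so removing kCompoundOffset cannot wrap;
    // vminq_u16 supplies the upper clamp to the maximum pixel value.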
    constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
    const uint16x8_t compound_result =
        vminq_u16(vrshrq_n_u16(vqsubq_u16(result, vdupq_n_u16(kCompoundOffset)),
                               inter_post_round_bits),
                  vdupq_n_u16((1 << bitdepth) - 1));
    if (width == 4) {
      // Store 2 lines of width 4.
      assert(dst_stride != 0);
      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(compound_result));
      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
               vget_high_u16(compound_result));
    } else {
      // Store 1 line of width 8.
      vst1q_u16(reinterpret_cast<uint16_t*>(dst), compound_result);
    }
  }
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlend4x2_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                              const uint16_t* LIBGAV1_RESTRICT pred_1,
                              const ptrdiff_t pred_stride_1,
                              const uint8_t* LIBGAV1_RESTRICT mask,
                              const uint16x8_t mask_inverter,
                              const ptrdiff_t mask_stride,
                              uint8_t* LIBGAV1_RESTRICT dst,
                              const ptrdiff_t dst_stride) {
  // This works because stride == width == 4.
  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
  const uint16x8_t pred_val_1 =
      is_inter_intra
          ? vcombine_u16(vld1_u16(pred_1), vld1_u16(pred_1 + pred_stride_1))
          : vld1q_u16(pred_1);
  const uint16x8_t pred_mask_0 =
      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);

  StoreShiftedResult<is_inter_intra, 4>(dst, weighted_pred_sum, dst_stride);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlending4x4_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
                                 const ptrdiff_t pred_stride_1,
                                 const uint8_t* LIBGAV1_RESTRICT mask,
                                 const ptrdiff_t mask_stride,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  // Double stride because the function works on 2 lines at a time.
  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
  const ptrdiff_t dst_stride_y = dst_stride << 1;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);

  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
      dst_stride);

  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride_y;
  dst += dst_stride_y;

  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
      dst_stride);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlending4xH_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
                                 const ptrdiff_t pred_stride_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const ptrdiff_t mask_stride, const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlending4x4_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  // Double stride because the function works on 2 lines at a time.
  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
  const ptrdiff_t dst_stride_y = dst_stride << 1;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);
  int y = 0;
  do {
    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
void MaskBlend8_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                     const uint16_t* LIBGAV1_RESTRICT pred_1,
                     const uint8_t* LIBGAV1_RESTRICT mask,
                     const uint16x8_t mask_inverter,
                     const ptrdiff_t mask_stride,
                     uint8_t* LIBGAV1_RESTRICT dst) {
  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
  const uint16x8_t pred_val_1 = vld1q_u16(pred_1);
  const uint16x8_t pred_mask_0 =
      GetMask8<subsampling_x, subsampling_y>(mask, mask_stride);
  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);

  StoreShiftedResult<is_inter_intra, 8>(dst, weighted_pred_sum);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                           const void* LIBGAV1_RESTRICT prediction_1,
                           const ptrdiff_t prediction_stride_1,
                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                           const ptrdiff_t mask_stride, const int width,
                           const int height, void* LIBGAV1_RESTRICT dest,
                           const ptrdiff_t dst_stride) {
  if (!is_inter_intra) {
    assert(prediction_stride_1 == width);
  }
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  if (width == 4) {
    MaskBlending4xH_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, prediction_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const ptrdiff_t mask_stride_y = mask_stride << subsampling_y;
  const uint8_t* mask = mask_ptr;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);
  int y = 0;
  do {
    int x = 0;
    do {
      MaskBlend8_NEON<subsampling_x, subsampling_y, is_inter_intra>(
          pred_0 + x, pred_1 + x, mask + (x << subsampling_x), mask_inverter,
          mask_stride,
          reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(dst) + x));
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += width;
    pred_1 += prediction_stride_1;
    mask += mask_stride_y;
  } while (++y < height);
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0, false>;
  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0, false>;
  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1, false>;

  dsp->mask_blend[0][1] = MaskBlend_NEON<0, 0, true>;
  dsp->mask_blend[1][1] = MaskBlend_NEON<1, 0, true>;
  dsp->mask_blend[2][1] = MaskBlend_NEON<1, 1, true>;
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void MaskBlendInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}

}  // namespace dsp
}  // namespace libgav1

#else  // !LIBGAV1_ENABLE_NEON

namespace libgav1 {
namespace dsp {

void MaskBlendInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON