// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/mask_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <cassert>
#include <cstddef>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

template <int subsampling_y>
inline uint8x8_t GetMask4x2(const uint8_t* mask) {
  if (subsampling_y == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
    const uint32x2_t row_01 = vreinterpret_u32_u8(vget_low_u8(combined_horz));
    const uint32x2_t row_23 = vreinterpret_u32_u8(vget_high_u8(combined_horz));

    const uint32x2x2_t row_02_13 = vtrn_u32(row_01, row_23);
    // Use a halving add to work around the case where all |mask| values are 64.
    return vrshr_n_u8(vhadd_u8(vreinterpret_u8_u32(row_02_13.val[0]),
                               vreinterpret_u8_u32(row_02_13.val[1])),
                      1);
  }
  // subsampling_x == 1
  const uint8x8x2_t mask_val = vld2_u8(mask);
  return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
}

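// Worked example of the halving-add trick above (illustrative only, not
// library code): mask weights lie in [0, 64], so a horizontal pair sum is
// at most 128 and still fits a uint8 lane, but adding the two row sums can
// reach 256 and overflow. vhadd_u8 computes (a + b) >> 1 without widening,
// and vrshr_n_u8(x, 1) computes (x + 1) >> 1, so for row sums a and b the
// result is (((a + b) >> 1) + 1) >> 1, which equals the reference
// (a + b + 2) >> 2 for every a, b in [0, 128]. E.g. all-64 masks:
// a = b = 128, vhadd -> 128, vrshr -> 64, matching (256 + 2) >> 2 = 64.
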
template <int subsampling_x, int subsampling_y>
inline uint8x8_t GetMask8(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
    // Use a halving add to work around the case where all |mask| values are 64.
    return vrshr_n_u8(
        vhadd_u8(vget_low_u8(combined_horz), vget_high_u8(combined_horz)), 1);
  }
  if (subsampling_x == 1) {
    const uint8x8x2_t mask_val = vld2_u8(mask);
    return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  return vld1_u8(mask);
}

inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                  const int16x8_t pred_mask_0,
                                  const int16x8_t pred_mask_1,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
  // int res = (mask_value * prediction_0[x] +
  //      (64 - mask_value) * prediction_1[x]) >> 6;
  const int32x4_t weighted_pred_0_lo =
      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
  const int32x4_t weighted_pred_0_hi =
      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
  const int32x4_t weighted_combo_lo = vmlal_s16(
      weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
  const int32x4_t weighted_combo_hi =
      vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
                vget_high_s16(pred_val_1));
  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //         (1 << kBitdepth8) - 1));
  const uint8x8_t result =
      vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
                                  vshrn_n_s32(weighted_combo_hi, 6)),
                     4);
  StoreLo4(dst, result);
  StoreHi4(dst + dst_stride, result);
}

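// The two shifts above split the reference arithmetic exactly as the
// comments describe (a sketch, not library code): vshrn_n_s32(x, 6) narrows
// the 32-bit accumulators with the blend's >> 6, and vqrshrun_n_s16(x, 4)
// applies RightShiftWithRounding(res, 4) while its unsigned saturation
// supplies the Clip3(..., 0, 255) for free. E.g. m = 64, pred_0[x] = 4000:
// res = (64 * 4000) >> 6 = 4000, then (4000 + 8) >> 4 = 250.
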
template <int subsampling_y>
inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT mask,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  constexpr int subsampling_x = 1;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  // Compound predictors use int16_t values and need to multiply long because
  // the Convolve range * 64 is 20 bits. Unfortunately there is no multiply
  // int16_t by int8_t and accumulate into int32_t instruction.
  int16x8_t pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
  int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
  pred_0 += 4 << subsampling_x;
  pred_1 += 4 << subsampling_x;
  mask += mask_stride << (subsampling_x + subsampling_y);
  dst += dst_stride << subsampling_x;

  pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
  pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
}

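// On the ranges in the comment above (a sketch of the arithmetic, not a
// spec statement): per that comment the 8bpp compound Convolve output times
// 64 occupies 20 bits, i.e. roughly 14 significant bits before weighting.
// NEON has no int16_t-by-int8_t multiply-accumulate into int32_t, so the
// 8-bit mask is zero-extended to int16_t and vmull_s16/vmlal_s16 widen the
// products into 32-bit lanes instead.
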
template <int subsampling_y>
inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlending4x4_NEON<subsampling_y>(pred_0, pred_1, mask, dst, dst_stride);
    return;
  }
  constexpr int subsampling_x = 1;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = 0;
  do {
    int16x8_t pred_mask_0 =
        vreinterpretq_s16_u16(vmovl_u8(GetMask4x2<subsampling_y>(mask)));
    int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;

    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << subsampling_x;
    pred_1 += 4 << subsampling_x;
    mask += mask_stride << (subsampling_x + subsampling_y);
    dst += dst_stride << subsampling_x;
    y += 8;
  } while (y < height);
}

inline uint8x8_t CombinePred8(const int16_t* LIBGAV1_RESTRICT pred_0,
                              const int16_t* LIBGAV1_RESTRICT pred_1,
                              const int16x8_t pred_mask_0,
                              const int16x8_t pred_mask_1) {
  // First 8 values.
  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
  // int res = (mask_value * prediction_0[x] +
  //      (64 - mask_value) * prediction_1[x]) >> 6;
  const int32x4_t weighted_pred_lo =
      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
  const int32x4_t weighted_pred_hi =
      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
  const int32x4_t weighted_combo_lo = vmlal_s16(
      weighted_pred_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
  const int32x4_t weighted_combo_hi = vmlal_s16(
      weighted_pred_hi, vget_high_s16(pred_mask_1), vget_high_s16(pred_val_1));

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
  return vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
                                     vshrn_n_s32(weighted_combo_hi, 6)),
                        4);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending8xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                 const int16_t* LIBGAV1_RESTRICT pred_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = height;
  do {
    const int16x8_t pred_mask_0 =
        ZeroExtend(GetMask8<subsampling_x, subsampling_y>(mask));
    // 64 - mask
    const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
    const uint8x8_t result =
        CombinePred8(pred_0, pred_1, pred_mask_0, pred_mask_1);
    vst1_u8(dst, result);
    dst += dst_stride;
    mask += 8 << (subsampling_x + subsampling_y);
    pred_0 += 8;
    pred_1 += 8;
  } while (--y != 0);
}

template <int subsampling_x, int subsampling_y>
inline uint8x16_t GetMask16(const uint8_t* mask, const ptrdiff_t mask_stride) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const uint8x16x2_t mask_val0 = vld2q_u8(mask);
    const uint8x16x2_t mask_val1 = vld2q_u8(mask + mask_stride);
    const uint8x16_t combined_horz0 =
        vaddq_u8(mask_val0.val[0], mask_val0.val[1]);
    const uint8x16_t combined_horz1 =
        vaddq_u8(mask_val1.val[0], mask_val1.val[1]);
    // Use a halving add to work around the case where all |mask| values are 64.
    return vrshrq_n_u8(vhaddq_u8(combined_horz0, combined_horz1), 1);
  }
  if (subsampling_x == 1) {
    const uint8x16x2_t mask_val = vld2q_u8(mask);
    return vrhaddq_u8(mask_val.val[0], mask_val.val[1]);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  return vld1q_u8(mask);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                           const void* LIBGAV1_RESTRICT prediction_1,
                           const ptrdiff_t /*prediction_stride_1*/,
                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                           const ptrdiff_t mask_stride, const int width,
                           const int height, void* LIBGAV1_RESTRICT dest,
                           const ptrdiff_t dst_stride) {
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
  if (width == 4) {
    MaskBlending4xH_NEON<subsampling_y>(pred_0, pred_1, mask_ptr, height, dst,
                                        dst_stride);
    return;
  }
  if (width == 8) {
    MaskBlending8xH_NEON<subsampling_x, subsampling_y>(pred_0, pred_1, mask_ptr,
                                                       height, dst, dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const int16x8_t mask_inverter = vdupq_n_s16(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const uint8x16_t pred_mask_0 = GetMask16<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const int16x8_t pred_mask_0_lo = ZeroExtend(vget_low_u8(pred_mask_0));
      const int16x8_t pred_mask_0_hi = ZeroExtend(vget_high_u8(pred_mask_0));
      // 64 - mask
      const int16x8_t pred_mask_1_lo = vsubq_s16(mask_inverter, pred_mask_0_lo);
      const int16x8_t pred_mask_1_hi = vsubq_s16(mask_inverter, pred_mask_0_hi);

      uint8x8_t result;
      result =
          CombinePred8(pred_0 + x, pred_1 + x, pred_mask_0_lo, pred_mask_1_lo);
      vst1_u8(dst + x, result);

      result = CombinePred8(pred_0 + x + 8, pred_1 + x + 8, pred_mask_0_hi,
                            pred_mask_1_hi);
      vst1_u8(dst + x + 8, result);

      x += 16;
    } while (x < width);
    dst += dst_stride;
    pred_0 += width;
    pred_1 += width;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

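// For orientation, the scalar compound blend all of the paths above
// vectorize (an illustrative sketch, not library code; |m| stands for the
// mask already collapsed to one weight per destination pixel):
//
//   for (int x = 0; x < width; ++x) {
//     const int res = (m[x] * pred_0[x] + (64 - m[x]) * pred_1[x]) >> 6;
//     dst[x] = Clip3(RightShiftWithRounding(res, 4), 0, 255);
//   }
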
template <int subsampling_x, int subsampling_y>
inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
                                      ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    return GetMask4x2<subsampling_y>(mask);
  }
  // When using intra or difference weighted masks, the function doesn't use
  // subsampling, so |mask_stride| may be 4 or 8.
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val0 = Load4(mask);
  return Load4<1>(mask + mask_stride, mask_val0);
}

inline void InterIntraWriteMaskBlendLine8bpp4x2(
    const uint8_t* LIBGAV1_RESTRICT const pred_0,
    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
    const uint8x8_t pred_mask_0, const uint8x8_t pred_mask_1) {
  const uint8x8_t pred_val_0 = vld1_u8(pred_0);
  uint8x8_t pred_val_1 = Load4(pred_1);
  pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);

  const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
  const uint16x8_t weighted_combo =
      vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
  const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
  StoreLo4(pred_1, result);
  StoreHi4(pred_1 + pred_stride_1, result);
}

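// Inter-intra blending writes its result back over |pred_1| in place. With
// mask value m (the callers pass pred_mask_1 = m and pred_mask_0 = 64 - m),
// the line above computes, per pixel (a sketch of the arithmetic):
//
//   pred_1[x] = ((64 - m) * pred_0[x] + m * pred_1[x] + 32) >> 6;
//
// Everything fits in uint16: the worst case is 64 * 255 = 16320, so the
// vmull_u8/vmlal_u8 lanes never overflow and vrshrn_n_u16 performs the
// rounded >> 6 narrow.
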
template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4x4_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride) {
  const uint8x8_t mask_inverter = vdup_n_u8(64);
  uint8x8_t pred_mask_1 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);

  pred_mask_1 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4xH_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride, const int height) {
  if (height == 4) {
    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    return;
  }
  int y = 0;
  do {
    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);

    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp8xH_NEON(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride, const int height) {
  const uint8x8_t mask_inverter = vdup_n_u8(64);
  int y = height;
  do {
    const uint8x8_t pred_mask_1 = GetMask8<subsampling_x, subsampling_y>(mask);
    // 64 - mask
    const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
    const uint8x8_t pred_val_0 = vld1_u8(pred_0);
    const uint8x8_t pred_val_1 = vld1_u8(pred_1);
    const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
    // weighted_pred0 + weighted_pred1
    const uint16x8_t weighted_combo =
        vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
    const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
    vst1_u8(pred_1, result);

    pred_0 += 8;
    pred_1 += pred_stride_1;
    mask += mask_stride << subsampling_y;
  } while (--y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend8bpp_NEON(
    const uint8_t* LIBGAV1_RESTRICT prediction_0,
    uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height) {
  if (width == 4) {
    InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  if (width == 8) {
    InterIntraMaskBlending8bpp8xH_NEON<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const uint8x16_t mask_inverter = vdupq_n_u8(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const uint8x16_t pred_mask_1 = GetMask16<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const uint8x16_t pred_mask_0 = vsubq_u8(mask_inverter, pred_mask_1);
      const uint8x8_t pred_val_0_lo = vld1_u8(prediction_0);
      prediction_0 += 8;
      const uint8x8_t pred_val_0_hi = vld1_u8(prediction_0);
      prediction_0 += 8;
      // Ensure armv7 build combines the load.
      const uint8x16_t pred_val_1 = vld1q_u8(prediction_1 + x);
      const uint8x8_t pred_val_1_lo = vget_low_u8(pred_val_1);
      const uint8x8_t pred_val_1_hi = vget_high_u8(pred_val_1);
      const uint16x8_t weighted_pred_0_lo =
          vmull_u8(vget_low_u8(pred_mask_0), pred_val_0_lo);
      // weighted_pred0 + weighted_pred1
      const uint16x8_t weighted_combo_lo =
          vmlal_u8(weighted_pred_0_lo, vget_low_u8(pred_mask_1), pred_val_1_lo);
      const uint8x8_t result_lo = vrshrn_n_u16(weighted_combo_lo, 6);
      vst1_u8(prediction_1 + x, result_lo);
      const uint16x8_t weighted_pred_0_hi =
          vmull_u8(vget_high_u8(pred_mask_0), pred_val_0_hi);
      // weighted_pred0 + weighted_pred1
      const uint16x8_t weighted_combo_hi = vmlal_u8(
          weighted_pred_0_hi, vget_high_u8(pred_mask_1), pred_val_1_hi);
      const uint8x8_t result_hi = vrshrn_n_u16(weighted_combo_hi, 6);
      vst1_u8(prediction_1 + x + 8, result_hi);

      x += 16;
    } while (x < width);
    prediction_1 += prediction_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>;
  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>;
  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>;
  // The is_inter_intra index of mask_blend[][] is replaced by
  // inter_intra_mask_blend_8bpp[] in 8-bit.
  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>;
  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>;
  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>;
}

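// Dispatch sketch (illustrative, assuming a caller that already holds the
// dsp table): the first index selects the subsampling pair, so a 4:2:0
// chroma compound blend resolves to
//
//   dsp->mask_blend[2][0](pred_0, pred_1, prediction_stride_1, mask,
//                         mask_stride, width, height, dest, dst_stride);
//
// while 8bpp inter-intra blends go through inter_intra_mask_blend_8bpp[]
// instead of the is_inter_intra column of mask_blend[][].
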
}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

template <int subsampling_x, int subsampling_y>
inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    const uint8x8_t mask_val0 = vld1_u8(mask);
    const uint8x8_t mask_val1 = vld1_u8(mask + (mask_stride << subsampling_y));
    uint16x8_t final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1));
    if (subsampling_y == 1) {
      const uint8x8_t next_mask_val0 = vld1_u8(mask + mask_stride);
      const uint8x8_t next_mask_val1 = vld1_u8(mask + mask_stride * 3);
      final_val = vaddq_u16(
          final_val, vpaddlq_u8(vcombine_u8(next_mask_val0, next_mask_val1)));
    }
    return vrshrq_n_u16(final_val, subsampling_y + 1);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val0 = Load4(mask);
  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
  return vmovl_u8(mask_val);
}

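// Unlike the 8bpp helpers, the high bitdepth path widens while it sums:
// vpaddlq_u8 adds horizontal byte pairs into uint16 lanes, so no halving-add
// workaround is needed. A sketch of the arithmetic for one output weight
// with subsampling_x == 1 and subsampling_y == 1:
//
//   final = m[0] + m[1]                 // vpaddlq_u8, row 0
//         + m[stride] + m[stride + 1];  // vpaddlq_u8, row 1, then vaddq_u16
//   weight = (final + 2) >> 2;          // vrshrq_n_u16(final, 2)
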
template <int subsampling_x, int subsampling_y>
inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
    if (subsampling_y == 1) {
      const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
      mask_val = vaddq_u16(mask_val, next_mask_val);
    }
    return vrshrq_n_u16(mask_val, 1 + subsampling_y);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const uint8x8_t mask_val = vld1_u8(mask);
  return vmovl_u8(mask_val);
}

template <bool is_inter_intra>
uint16x8_t SumWeightedPred(const uint16x8_t pred_mask_0,
                           const uint16x8_t pred_mask_1,
                           const uint16x8_t pred_val_0,
                           const uint16x8_t pred_val_1) {
  if (is_inter_intra) {
    // dst[x] = static_cast<Pixel>(RightShiftWithRounding(
    //     mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
    uint16x8_t sum = vmulq_u16(pred_mask_1, pred_val_0);
    sum = vmlaq_u16(sum, pred_mask_0, pred_val_1);
    return vrshrq_n_u16(sum, 6);
  } else {
    // int res = (mask_value * prediction_0[x] +
    //      (64 - mask_value) * prediction_1[x]) >> 6;
    const uint32x4_t weighted_pred_0_lo =
        vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
    const uint32x4_t weighted_pred_0_hi = VMullHighU16(pred_mask_0, pred_val_0);
    uint32x4x2_t sum;
    sum.val[0] = vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
                           vget_low_u16(pred_val_1));
    sum.val[1] = VMlalHighU16(weighted_pred_0_hi, pred_mask_1, pred_val_1);
    return vcombine_u16(vshrn_n_u32(sum.val[0], 6), vshrn_n_u32(sum.val[1], 6));
  }
}

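// The branch split above is a range question (a sketch of the arithmetic,
// not a spec statement): for inter-intra the inputs are 10-bit pixels, and
// the worst case 64 * 1023 = 65472 still fits a uint16 lane, so plain
// vmulq_u16/vmlaq_u16 suffice. Compound inputs carry the kCompoundOffset
// bias from the convolve stage and exceed 16 bits once weighted, so the
// products are widened into uint32 lanes with vmull_u16/vmlal_u16 before
// the >> 6 narrow.
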
template <bool is_inter_intra, int width, int bitdepth = 10>
inline void StoreShiftedResult(uint8_t* dst, const uint16x8_t result,
                               const ptrdiff_t dst_stride = 0) {
  if (is_inter_intra) {
    if (width == 4) {
      // Store 2 lines of width 4.
      assert(dst_stride != 0);
      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(result));
      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
               vget_high_u16(result));
    } else {
      // Store 1 line of width 8.
      vst1q_u16(reinterpret_cast<uint16_t*>(dst), result);
    }
  } else {
    // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
    // dst[x] = static_cast<Pixel>(
    //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
    //           (1 << bitdepth) - 1));
    constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
    const uint16x8_t compound_result =
        vminq_u16(vrshrq_n_u16(vqsubq_u16(result, vdupq_n_u16(kCompoundOffset)),
                               inter_post_round_bits),
                  vdupq_n_u16((1 << bitdepth) - 1));
    if (width == 4) {
      // Store 2 lines of width 4.
      assert(dst_stride != 0);
      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(compound_result));
      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
               vget_high_u16(compound_result));
    } else {
      // Store 1 line of width 8.
      vst1q_u16(reinterpret_cast<uint16_t*>(dst), compound_result);
    }
  }
}

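// Worked compound example at 10-bit (illustrative): the saturating
// vqsubq_u16 removes the kCompoundOffset bias and clamps negative values to
// 0, supplying the lower half of the Clip3; vrshrq_n_u16 then applies
// RightShiftWithRounding(res, 4) and vminq_u16 caps the result at
// (1 << 10) - 1 = 1023 for the upper half.
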
template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlend4x2_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                              const uint16_t* LIBGAV1_RESTRICT pred_1,
                              const ptrdiff_t pred_stride_1,
                              const uint8_t* LIBGAV1_RESTRICT mask,
                              const uint16x8_t mask_inverter,
                              const ptrdiff_t mask_stride,
                              uint8_t* LIBGAV1_RESTRICT dst,
                              const ptrdiff_t dst_stride) {
  // This works because stride == width == 4.
  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
  const uint16x8_t pred_val_1 =
      is_inter_intra
          ? vcombine_u16(vld1_u16(pred_1), vld1_u16(pred_1 + pred_stride_1))
          : vld1q_u16(pred_1);
  const uint16x8_t pred_mask_0 =
      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);

  StoreShiftedResult<is_inter_intra, 4>(dst, weighted_pred_sum, dst_stride);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlending4x4_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
                                 const ptrdiff_t pred_stride_1,
                                 const uint8_t* LIBGAV1_RESTRICT mask,
                                 const ptrdiff_t mask_stride,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  // Double stride because the function works on 2 lines at a time.
  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
  const ptrdiff_t dst_stride_y = dst_stride << 1;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);

  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
      dst_stride);

  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride_y;
  dst += dst_stride_y;

  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
      dst_stride);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlending4xH_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
                                 const ptrdiff_t pred_stride_1,
                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                 const ptrdiff_t mask_stride, const int height,
                                 uint8_t* LIBGAV1_RESTRICT dst,
                                 const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlending4x4_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  // Double stride because the function works on 2 lines at a time.
  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
  const ptrdiff_t dst_stride_y = dst_stride << 1;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);
  int y = 0;
  do {
    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;

    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
        dst_stride);
    pred_0 += 4 << 1;
    pred_1 += pred_stride_1 << 1;
    mask += mask_stride_y;
    dst += dst_stride_y;
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
void MaskBlend8_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
                     const uint16_t* LIBGAV1_RESTRICT pred_1,
                     const uint8_t* LIBGAV1_RESTRICT mask,
                     const uint16x8_t mask_inverter,
                     const ptrdiff_t mask_stride,
                     uint8_t* LIBGAV1_RESTRICT dst) {
  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
  const uint16x8_t pred_val_1 = vld1q_u16(pred_1);
  const uint16x8_t pred_mask_0 =
      GetMask8<subsampling_x, subsampling_y>(mask, mask_stride);
  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);

  StoreShiftedResult<is_inter_intra, 8>(dst, weighted_pred_sum);
}

template <int subsampling_x, int subsampling_y, bool is_inter_intra>
inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                           const void* LIBGAV1_RESTRICT prediction_1,
                           const ptrdiff_t prediction_stride_1,
                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                           const ptrdiff_t mask_stride, const int width,
                           const int height, void* LIBGAV1_RESTRICT dest,
                           const ptrdiff_t dst_stride) {
  if (!is_inter_intra) {
    assert(prediction_stride_1 == width);
  }
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  if (width == 4) {
    MaskBlending4xH_NEON<subsampling_x, subsampling_y, is_inter_intra>(
        pred_0, pred_1, prediction_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const ptrdiff_t mask_stride_y = mask_stride << subsampling_y;
  const uint8_t* mask = mask_ptr;
  const uint16x8_t mask_inverter = vdupq_n_u16(64);
  int y = 0;
  do {
    int x = 0;
    do {
      MaskBlend8_NEON<subsampling_x, subsampling_y, is_inter_intra>(
          pred_0 + x, pred_1 + x, mask + (x << subsampling_x), mask_inverter,
          mask_stride,
          reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(dst) + x));
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += width;
    pred_1 += prediction_stride_1;
    mask += mask_stride_y;
  } while (++y < height);
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0, false>;
  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0, false>;
  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1, false>;

  dsp->mask_blend[0][1] = MaskBlend_NEON<0, 0, true>;
  dsp->mask_blend[1][1] = MaskBlend_NEON<1, 0, true>;
  dsp->mask_blend[2][1] = MaskBlend_NEON<1, 1, true>;
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void MaskBlendInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_ENABLE_NEON

namespace libgav1 {
namespace dsp {

void MaskBlendInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON