/*
 * Copyright (c) 2022, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
#include "config/av1_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_dsp/intrapred_common.h"

// -----------------------------------------------------------------------------
// DC

static INLINE void highbd_dc_store_4xh(uint16_t *dst, ptrdiff_t stride, int h,
                                       uint16x4_t dc) {
  for (int i = 0; i < h; ++i) {
    vst1_u16(dst + i * stride, dc);
  }
}

static INLINE void highbd_dc_store_8xh(uint16_t *dst, ptrdiff_t stride, int h,
                                       uint16x8_t dc) {
  for (int i = 0; i < h; ++i) {
    vst1q_u16(dst + i * stride, dc);
  }
}

static INLINE void highbd_dc_store_16xh(uint16_t *dst, ptrdiff_t stride, int h,
                                        uint16x8_t dc) {
  for (int i = 0; i < h; ++i) {
    vst1q_u16(dst + i * stride, dc);
    vst1q_u16(dst + i * stride + 8, dc);
  }
}

static INLINE void highbd_dc_store_32xh(uint16_t *dst, ptrdiff_t stride, int h,
                                        uint16x8_t dc) {
  for (int i = 0; i < h; ++i) {
    vst1q_u16(dst + i * stride, dc);
    vst1q_u16(dst + i * stride + 8, dc);
    vst1q_u16(dst + i * stride + 16, dc);
    vst1q_u16(dst + i * stride + 24, dc);
  }
}

static INLINE void highbd_dc_store_64xh(uint16_t *dst, ptrdiff_t stride, int h,
                                        uint16x8_t dc) {
  for (int i = 0; i < h; ++i) {
    vst1q_u16(dst + i * stride, dc);
    vst1q_u16(dst + i * stride + 8, dc);
    vst1q_u16(dst + i * stride + 16, dc);
    vst1q_u16(dst + i * stride + 24, dc);
    vst1q_u16(dst + i * stride + 32, dc);
    vst1q_u16(dst + i * stride + 40, dc);
    vst1q_u16(dst + i * stride + 48, dc);
    vst1q_u16(dst + i * stride + 56, dc);
  }
}

static INLINE uint32x4_t horizontal_add_and_broadcast_long_u16x8(uint16x8_t a) {
  // Need to assume input is up to 16 bits wide from dc 64x64 partial sum, so
  // promote first.
  const uint32x4_t b = vpaddlq_u16(a);
#if AOM_ARCH_AARCH64
  const uint32x4_t c = vpaddq_u32(b, b);
  return vpaddq_u32(c, c);
#else
  const uint32x2_t c = vadd_u32(vget_low_u32(b), vget_high_u32(b));
  const uint32x2_t d = vpadd_u32(c, c);
  return vcombine_u32(d, d);
#endif
}
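
// Both paths above leave the full horizontal sum broadcast to all four
// 32-bit lanes: on AArch64 two vpaddq_u32 pairwise additions fold the vector
// onto itself, while on AArch32 (where vpaddq_u32 is unavailable) the same
// reduction is done on 64-bit halves with vadd_u32/vpadd_u32.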

static INLINE uint16x8_t highbd_dc_load_partial_sum_4(const uint16_t *left) {
  // Nothing to do since sum is already one vector, but saves needing to
  // special case w=4 or h=4 cases. The combine will be zero cost for a sane
  // compiler since vld1 already sets the top half of a vector to zero as part
  // of the operation.
  return vcombine_u16(vld1_u16(left), vdup_n_u16(0));
}

static INLINE uint16x8_t highbd_dc_load_partial_sum_8(const uint16_t *left) {
  // Nothing to do since sum is already one vector, but saves needing to
  // special case w=8 or h=8 cases.
  return vld1q_u16(left);
}

static INLINE uint16x8_t highbd_dc_load_partial_sum_16(const uint16_t *left) {
  const uint16x8_t a0 = vld1q_u16(left + 0);  // up to 12 bits
  const uint16x8_t a1 = vld1q_u16(left + 8);
  return vaddq_u16(a0, a1);  // up to 13 bits
}

static INLINE uint16x8_t highbd_dc_load_partial_sum_32(const uint16_t *left) {
  const uint16x8_t a0 = vld1q_u16(left + 0);  // up to 12 bits
  const uint16x8_t a1 = vld1q_u16(left + 8);
  const uint16x8_t a2 = vld1q_u16(left + 16);
  const uint16x8_t a3 = vld1q_u16(left + 24);
  const uint16x8_t b0 = vaddq_u16(a0, a1);  // up to 13 bits
  const uint16x8_t b1 = vaddq_u16(a2, a3);
  return vaddq_u16(b0, b1);  // up to 14 bits
}

static INLINE uint16x8_t highbd_dc_load_partial_sum_64(const uint16_t *left) {
  const uint16x8_t a0 = vld1q_u16(left + 0);  // up to 12 bits
  const uint16x8_t a1 = vld1q_u16(left + 8);
  const uint16x8_t a2 = vld1q_u16(left + 16);
  const uint16x8_t a3 = vld1q_u16(left + 24);
  const uint16x8_t a4 = vld1q_u16(left + 32);
  const uint16x8_t a5 = vld1q_u16(left + 40);
  const uint16x8_t a6 = vld1q_u16(left + 48);
  const uint16x8_t a7 = vld1q_u16(left + 56);
  const uint16x8_t b0 = vaddq_u16(a0, a1);  // up to 13 bits
  const uint16x8_t b1 = vaddq_u16(a2, a3);
  const uint16x8_t b2 = vaddq_u16(a4, a5);
  const uint16x8_t b3 = vaddq_u16(a6, a7);
  const uint16x8_t c0 = vaddq_u16(b0, b1);  // up to 14 bits
  const uint16x8_t c1 = vaddq_u16(b2, b3);
  return vaddq_u16(c0, c1);  // up to 15 bits
}

#define HIGHBD_DC_PREDICTOR(w, h, shift)                               \
  void aom_highbd_dc_predictor_##w##x##h##_neon(                       \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,          \
      const uint16_t *left, int bd) {                                  \
    (void)bd;                                                          \
    const uint16x8_t a = highbd_dc_load_partial_sum_##w(above);        \
    const uint16x8_t l = highbd_dc_load_partial_sum_##h(left);         \
    const uint32x4_t sum =                                             \
        horizontal_add_and_broadcast_long_u16x8(vaddq_u16(a, l));      \
    const uint16x4_t dc0 = vrshrn_n_u32(sum, shift);                   \
    highbd_dc_store_##w##xh(dst, stride, (h), vdupq_lane_u16(dc0, 0)); \
  }
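
// For the square cases below shift == log2(w + h), so the rounding narrow in
// the macro computes the rounded average of the w + h edge samples, e.g.
// 16x16 sums 32 values and uses shift == 5.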

void aom_highbd_dc_predictor_4x4_neon(uint16_t *dst, ptrdiff_t stride,
                                      const uint16_t *above,
                                      const uint16_t *left, int bd) {
  // In the rectangular cases we simply extend the shorter vector to uint16x8
  // in order to accumulate, however in the 4x4 case there is no shorter vector
  // to extend so it is beneficial to do the whole calculation in uint16x4
  // instead.
  (void)bd;
  const uint16x4_t a = vld1_u16(above);  // up to 12 bits
  const uint16x4_t l = vld1_u16(left);
  uint16x4_t sum = vpadd_u16(a, l);  // up to 13 bits
  sum = vpadd_u16(sum, sum);         // up to 14 bits
  sum = vpadd_u16(sum, sum);
  const uint16x4_t dc = vrshr_n_u16(sum, 3);
  highbd_dc_store_4xh(dst, stride, 4, dc);
}

HIGHBD_DC_PREDICTOR(8, 8, 4)
HIGHBD_DC_PREDICTOR(16, 16, 5)
HIGHBD_DC_PREDICTOR(32, 32, 6)
HIGHBD_DC_PREDICTOR(64, 64, 7)

#undef HIGHBD_DC_PREDICTOR

static INLINE int divide_using_multiply_shift(int num, int shift1,
                                              int multiplier, int shift2) {
  const int interm = num >> shift1;
  return interm * multiplier >> shift2;
}
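
// divide_using_multiply_shift() approximates num / d as
// (num >> shift1) * multiplier >> shift2, with multiplier chosen so that
// multiplier / 2^shift2 is close to (1 << shift1) / d. For example, a 4x8
// block averages w + h = 12 samples via shift1 = 2 and multiplier 0xAAAB,
// since 0xAAAB / 2^17 is just over 1/3.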

#define HIGHBD_DC_MULTIPLIER_1X2 0xAAAB
#define HIGHBD_DC_MULTIPLIER_1X4 0x6667
#define HIGHBD_DC_SHIFT2 17

static INLINE int highbd_dc_predictor_rect(int bw, int bh, int sum, int shift1,
                                           uint32_t multiplier) {
  return divide_using_multiply_shift(sum + ((bw + bh) >> 1), shift1, multiplier,
                                     HIGHBD_DC_SHIFT2);
}

#undef HIGHBD_DC_SHIFT2

#define HIGHBD_DC_PREDICTOR_RECT(w, h, q, shift, mult)                  \
  void aom_highbd_dc_predictor_##w##x##h##_neon(                        \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,           \
      const uint16_t *left, int bd) {                                   \
    (void)bd;                                                           \
    uint16x8_t sum_above = highbd_dc_load_partial_sum_##w(above);       \
    uint16x8_t sum_left = highbd_dc_load_partial_sum_##h(left);         \
    uint16x8_t sum_vec = vaddq_u16(sum_left, sum_above);                \
    int sum = horizontal_add_u16x8(sum_vec);                            \
    int dc0 = highbd_dc_predictor_rect((w), (h), sum, (shift), (mult)); \
    highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_n_u16(dc0));    \
  }

HIGHBD_DC_PREDICTOR_RECT(4, 8, , 2, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(4, 16, , 2, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(8, 4, q, 2, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(8, 16, q, 3, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(8, 32, q, 3, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(16, 4, q, 2, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(16, 8, q, 3, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(16, 32, q, 4, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(16, 64, q, 4, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(32, 8, q, 3, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(32, 16, q, 4, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(32, 64, q, 5, HIGHBD_DC_MULTIPLIER_1X2)
HIGHBD_DC_PREDICTOR_RECT(64, 16, q, 4, HIGHBD_DC_MULTIPLIER_1X4)
HIGHBD_DC_PREDICTOR_RECT(64, 32, q, 5, HIGHBD_DC_MULTIPLIER_1X2)

#undef HIGHBD_DC_PREDICTOR_RECT
#undef HIGHBD_DC_MULTIPLIER_1X2
#undef HIGHBD_DC_MULTIPLIER_1X4

// -----------------------------------------------------------------------------
// DC_128

#define HIGHBD_DC_PREDICTOR_128(w, h, q)                        \
  void aom_highbd_dc_128_predictor_##w##x##h##_neon(            \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,   \
      const uint16_t *left, int bd) {                           \
    (void)above;                                                \
    (void)bd;                                                   \
    (void)left;                                                 \
    highbd_dc_store_##w##xh(dst, stride, (h),                   \
                            vdup##q##_n_u16(0x80 << (bd - 8))); \
  }
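
// 0x80 << (bd - 8) is the mid-range value for the given bit depth: 128 at
// 8-bit, 512 at 10-bit and 2048 at 12-bit.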

HIGHBD_DC_PREDICTOR_128(4, 4, )
HIGHBD_DC_PREDICTOR_128(4, 8, )
HIGHBD_DC_PREDICTOR_128(4, 16, )
HIGHBD_DC_PREDICTOR_128(8, 4, q)
HIGHBD_DC_PREDICTOR_128(8, 8, q)
HIGHBD_DC_PREDICTOR_128(8, 16, q)
HIGHBD_DC_PREDICTOR_128(8, 32, q)
HIGHBD_DC_PREDICTOR_128(16, 4, q)
HIGHBD_DC_PREDICTOR_128(16, 8, q)
HIGHBD_DC_PREDICTOR_128(16, 16, q)
HIGHBD_DC_PREDICTOR_128(16, 32, q)
HIGHBD_DC_PREDICTOR_128(16, 64, q)
HIGHBD_DC_PREDICTOR_128(32, 8, q)
HIGHBD_DC_PREDICTOR_128(32, 16, q)
HIGHBD_DC_PREDICTOR_128(32, 32, q)
HIGHBD_DC_PREDICTOR_128(32, 64, q)
HIGHBD_DC_PREDICTOR_128(64, 16, q)
HIGHBD_DC_PREDICTOR_128(64, 32, q)
HIGHBD_DC_PREDICTOR_128(64, 64, q)

#undef HIGHBD_DC_PREDICTOR_128

// -----------------------------------------------------------------------------
// DC_LEFT

static INLINE uint32x4_t highbd_dc_load_sum_4(const uint16_t *left) {
  const uint16x4_t a = vld1_u16(left);   // up to 12 bits
  const uint16x4_t b = vpadd_u16(a, a);  // up to 13 bits
  return vcombine_u32(vpaddl_u16(b), vdup_n_u32(0));
}

static INLINE uint32x4_t highbd_dc_load_sum_8(const uint16_t *left) {
  return horizontal_add_and_broadcast_long_u16x8(vld1q_u16(left));
}

static INLINE uint32x4_t highbd_dc_load_sum_16(const uint16_t *left) {
  return horizontal_add_and_broadcast_long_u16x8(
      highbd_dc_load_partial_sum_16(left));
}

static INLINE uint32x4_t highbd_dc_load_sum_32(const uint16_t *left) {
  return horizontal_add_and_broadcast_long_u16x8(
      highbd_dc_load_partial_sum_32(left));
}

static INLINE uint32x4_t highbd_dc_load_sum_64(const uint16_t *left) {
  return horizontal_add_and_broadcast_long_u16x8(
      highbd_dc_load_partial_sum_64(left));
}

#define DC_PREDICTOR_LEFT(w, h, shift, q)                                  \
  void aom_highbd_dc_left_predictor_##w##x##h##_neon(                      \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,              \
      const uint16_t *left, int bd) {                                      \
    (void)above;                                                           \
    (void)bd;                                                              \
    const uint32x4_t sum = highbd_dc_load_sum_##h(left);                   \
    const uint16x4_t dc0 = vrshrn_n_u32(sum, (shift));                     \
    highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u16(dc0, 0)); \
  }
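
// DC_LEFT averages only the left column, so shift == log2(h) here.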

DC_PREDICTOR_LEFT(4, 4, 2, )
DC_PREDICTOR_LEFT(4, 8, 3, )
DC_PREDICTOR_LEFT(4, 16, 4, )
DC_PREDICTOR_LEFT(8, 4, 2, q)
DC_PREDICTOR_LEFT(8, 8, 3, q)
DC_PREDICTOR_LEFT(8, 16, 4, q)
DC_PREDICTOR_LEFT(8, 32, 5, q)
DC_PREDICTOR_LEFT(16, 4, 2, q)
DC_PREDICTOR_LEFT(16, 8, 3, q)
DC_PREDICTOR_LEFT(16, 16, 4, q)
DC_PREDICTOR_LEFT(16, 32, 5, q)
DC_PREDICTOR_LEFT(16, 64, 6, q)
DC_PREDICTOR_LEFT(32, 8, 3, q)
DC_PREDICTOR_LEFT(32, 16, 4, q)
DC_PREDICTOR_LEFT(32, 32, 5, q)
DC_PREDICTOR_LEFT(32, 64, 6, q)
DC_PREDICTOR_LEFT(64, 16, 4, q)
DC_PREDICTOR_LEFT(64, 32, 5, q)
DC_PREDICTOR_LEFT(64, 64, 6, q)

#undef DC_PREDICTOR_LEFT

// -----------------------------------------------------------------------------
// DC_TOP

#define DC_PREDICTOR_TOP(w, h, shift, q)                                   \
  void aom_highbd_dc_top_predictor_##w##x##h##_neon(                       \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,              \
      const uint16_t *left, int bd) {                                      \
    (void)bd;                                                              \
    (void)left;                                                            \
    const uint32x4_t sum = highbd_dc_load_sum_##w(above);                  \
    const uint16x4_t dc0 = vrshrn_n_u32(sum, (shift));                     \
    highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u16(dc0, 0)); \
  }
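
// Likewise DC_TOP averages only the above row, so shift == log2(w).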

DC_PREDICTOR_TOP(4, 4, 2, )
DC_PREDICTOR_TOP(4, 8, 2, )
DC_PREDICTOR_TOP(4, 16, 2, )
DC_PREDICTOR_TOP(8, 4, 3, q)
DC_PREDICTOR_TOP(8, 8, 3, q)
DC_PREDICTOR_TOP(8, 16, 3, q)
DC_PREDICTOR_TOP(8, 32, 3, q)
DC_PREDICTOR_TOP(16, 4, 4, q)
DC_PREDICTOR_TOP(16, 8, 4, q)
DC_PREDICTOR_TOP(16, 16, 4, q)
DC_PREDICTOR_TOP(16, 32, 4, q)
DC_PREDICTOR_TOP(16, 64, 4, q)
DC_PREDICTOR_TOP(32, 8, 5, q)
DC_PREDICTOR_TOP(32, 16, 5, q)
DC_PREDICTOR_TOP(32, 32, 5, q)
DC_PREDICTOR_TOP(32, 64, 5, q)
DC_PREDICTOR_TOP(64, 16, 6, q)
DC_PREDICTOR_TOP(64, 32, 6, q)
DC_PREDICTOR_TOP(64, 64, 6, q)

#undef DC_PREDICTOR_TOP

// -----------------------------------------------------------------------------
// V_PRED

#define HIGHBD_V_NXM(W, H)                                    \
  void aom_highbd_v_predictor_##W##x##H##_neon(               \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \
      const uint16_t *left, int bd) {                         \
    (void)left;                                               \
    (void)bd;                                                 \
    vertical##W##xh_neon(dst, stride, above, H);              \
  }

static INLINE uint16x8x2_t load_uint16x8x2(uint16_t const *ptr) {
  uint16x8x2_t x;
  // Clang/gcc uses ldp here.
  x.val[0] = vld1q_u16(ptr);
  x.val[1] = vld1q_u16(ptr + 8);
  return x;
}

static INLINE void store_uint16x8x2(uint16_t *ptr, uint16x8x2_t x) {
  vst1q_u16(ptr, x.val[0]);
  vst1q_u16(ptr + 8, x.val[1]);
}

static INLINE void vertical4xh_neon(uint16_t *dst, ptrdiff_t stride,
                                    const uint16_t *const above, int height) {
  const uint16x4_t row = vld1_u16(above);
  int y = height;
  do {
    vst1_u16(dst, row);
    vst1_u16(dst + stride, row);
    dst += stride << 1;
    y -= 2;
  } while (y != 0);
}
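
// The vertical helpers store two rows per iteration, which is safe since
// every supported block height is a multiple of 2.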

static INLINE void vertical8xh_neon(uint16_t *dst, ptrdiff_t stride,
                                    const uint16_t *const above, int height) {
  const uint16x8_t row = vld1q_u16(above);
  int y = height;
  do {
    vst1q_u16(dst, row);
    vst1q_u16(dst + stride, row);
    dst += stride << 1;
    y -= 2;
  } while (y != 0);
}

static INLINE void vertical16xh_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *const above, int height) {
  const uint16x8x2_t row = load_uint16x8x2(above);
  int y = height;
  do {
    store_uint16x8x2(dst, row);
    store_uint16x8x2(dst + stride, row);
    dst += stride << 1;
    y -= 2;
  } while (y != 0);
}

static INLINE uint16x8x4_t load_uint16x8x4(uint16_t const *ptr) {
  uint16x8x4_t x;
  // Clang/gcc uses ldp here.
  x.val[0] = vld1q_u16(ptr);
  x.val[1] = vld1q_u16(ptr + 8);
  x.val[2] = vld1q_u16(ptr + 16);
  x.val[3] = vld1q_u16(ptr + 24);
  return x;
}

static INLINE void store_uint16x8x4(uint16_t *ptr, uint16x8x4_t x) {
  vst1q_u16(ptr, x.val[0]);
  vst1q_u16(ptr + 8, x.val[1]);
  vst1q_u16(ptr + 16, x.val[2]);
  vst1q_u16(ptr + 24, x.val[3]);
}

static INLINE void vertical32xh_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *const above, int height) {
  const uint16x8x4_t row = load_uint16x8x4(above);
  int y = height;
  do {
    store_uint16x8x4(dst, row);
    store_uint16x8x4(dst + stride, row);
    dst += stride << 1;
    y -= 2;
  } while (y != 0);
}

static INLINE void vertical64xh_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *const above, int height) {
  uint16_t *dst32 = dst + 32;
  const uint16x8x4_t row = load_uint16x8x4(above);
  const uint16x8x4_t row32 = load_uint16x8x4(above + 32);
  int y = height;
  do {
    store_uint16x8x4(dst, row);
    store_uint16x8x4(dst32, row32);
    store_uint16x8x4(dst + stride, row);
    store_uint16x8x4(dst32 + stride, row32);
    dst += stride << 1;
    dst32 += stride << 1;
    y -= 2;
  } while (y != 0);
}

HIGHBD_V_NXM(4, 4)
HIGHBD_V_NXM(4, 8)
HIGHBD_V_NXM(4, 16)

HIGHBD_V_NXM(8, 4)
HIGHBD_V_NXM(8, 8)
HIGHBD_V_NXM(8, 16)
HIGHBD_V_NXM(8, 32)

HIGHBD_V_NXM(16, 4)
HIGHBD_V_NXM(16, 8)
HIGHBD_V_NXM(16, 16)
HIGHBD_V_NXM(16, 32)
HIGHBD_V_NXM(16, 64)

HIGHBD_V_NXM(32, 8)
HIGHBD_V_NXM(32, 16)
HIGHBD_V_NXM(32, 32)
HIGHBD_V_NXM(32, 64)

HIGHBD_V_NXM(64, 16)
HIGHBD_V_NXM(64, 32)
HIGHBD_V_NXM(64, 64)

// -----------------------------------------------------------------------------
// H_PRED

static INLINE void highbd_h_store_4x4(uint16_t *dst, ptrdiff_t stride,
                                      uint16x4_t left) {
  vst1_u16(dst + 0 * stride, vdup_lane_u16(left, 0));
  vst1_u16(dst + 1 * stride, vdup_lane_u16(left, 1));
  vst1_u16(dst + 2 * stride, vdup_lane_u16(left, 2));
  vst1_u16(dst + 3 * stride, vdup_lane_u16(left, 3));
}

static INLINE void highbd_h_store_8x4(uint16_t *dst, ptrdiff_t stride,
                                      uint16x4_t left) {
  vst1q_u16(dst + 0 * stride, vdupq_lane_u16(left, 0));
  vst1q_u16(dst + 1 * stride, vdupq_lane_u16(left, 1));
  vst1q_u16(dst + 2 * stride, vdupq_lane_u16(left, 2));
  vst1q_u16(dst + 3 * stride, vdupq_lane_u16(left, 3));
}

static INLINE void highbd_h_store_16x1(uint16_t *dst, uint16x8_t left) {
  vst1q_u16(dst + 0, left);
  vst1q_u16(dst + 8, left);
}

static INLINE void highbd_h_store_16x4(uint16_t *dst, ptrdiff_t stride,
                                       uint16x4_t left) {
  highbd_h_store_16x1(dst + 0 * stride, vdupq_lane_u16(left, 0));
  highbd_h_store_16x1(dst + 1 * stride, vdupq_lane_u16(left, 1));
  highbd_h_store_16x1(dst + 2 * stride, vdupq_lane_u16(left, 2));
  highbd_h_store_16x1(dst + 3 * stride, vdupq_lane_u16(left, 3));
}

static INLINE void highbd_h_store_32x1(uint16_t *dst, uint16x8_t left) {
  vst1q_u16(dst + 0, left);
  vst1q_u16(dst + 8, left);
  vst1q_u16(dst + 16, left);
  vst1q_u16(dst + 24, left);
}

static INLINE void highbd_h_store_32x4(uint16_t *dst, ptrdiff_t stride,
                                       uint16x4_t left) {
  highbd_h_store_32x1(dst + 0 * stride, vdupq_lane_u16(left, 0));
  highbd_h_store_32x1(dst + 1 * stride, vdupq_lane_u16(left, 1));
  highbd_h_store_32x1(dst + 2 * stride, vdupq_lane_u16(left, 2));
  highbd_h_store_32x1(dst + 3 * stride, vdupq_lane_u16(left, 3));
}

static INLINE void highbd_h_store_64x1(uint16_t *dst, uint16x8_t left) {
  vst1q_u16(dst + 0, left);
  vst1q_u16(dst + 8, left);
  vst1q_u16(dst + 16, left);
  vst1q_u16(dst + 24, left);
  vst1q_u16(dst + 32, left);
  vst1q_u16(dst + 40, left);
  vst1q_u16(dst + 48, left);
  vst1q_u16(dst + 56, left);
}

static INLINE void highbd_h_store_64x4(uint16_t *dst, ptrdiff_t stride,
                                       uint16x4_t left) {
  highbd_h_store_64x1(dst + 0 * stride, vdupq_lane_u16(left, 0));
  highbd_h_store_64x1(dst + 1 * stride, vdupq_lane_u16(left, 1));
  highbd_h_store_64x1(dst + 2 * stride, vdupq_lane_u16(left, 2));
  highbd_h_store_64x1(dst + 3 * stride, vdupq_lane_u16(left, 3));
}

void aom_highbd_h_predictor_4x4_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *above,
                                     const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  highbd_h_store_4x4(dst, stride, vld1_u16(left));
}

void aom_highbd_h_predictor_4x8_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *above,
                                     const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  uint16x8_t l = vld1q_u16(left);
  highbd_h_store_4x4(dst + 0 * stride, stride, vget_low_u16(l));
  highbd_h_store_4x4(dst + 4 * stride, stride, vget_high_u16(l));
}

void aom_highbd_h_predictor_8x4_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *above,
                                     const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  highbd_h_store_8x4(dst, stride, vld1_u16(left));
}

void aom_highbd_h_predictor_8x8_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *above,
                                     const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  uint16x8_t l = vld1q_u16(left);
  highbd_h_store_8x4(dst + 0 * stride, stride, vget_low_u16(l));
  highbd_h_store_8x4(dst + 4 * stride, stride, vget_high_u16(l));
}

void aom_highbd_h_predictor_16x4_neon(uint16_t *dst, ptrdiff_t stride,
                                      const uint16_t *above,
                                      const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  highbd_h_store_16x4(dst, stride, vld1_u16(left));
}

void aom_highbd_h_predictor_16x8_neon(uint16_t *dst, ptrdiff_t stride,
                                      const uint16_t *above,
                                      const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  uint16x8_t l = vld1q_u16(left);
  highbd_h_store_16x4(dst + 0 * stride, stride, vget_low_u16(l));
  highbd_h_store_16x4(dst + 4 * stride, stride, vget_high_u16(l));
}

void aom_highbd_h_predictor_32x8_neon(uint16_t *dst, ptrdiff_t stride,
                                      const uint16_t *above,
                                      const uint16_t *left, int bd) {
  (void)above;
  (void)bd;
  uint16x8_t l = vld1q_u16(left);
  highbd_h_store_32x4(dst + 0 * stride, stride, vget_low_u16(l));
  highbd_h_store_32x4(dst + 4 * stride, stride, vget_high_u16(l));
}

// For cases where height >= 16 we use pairs of loads to get LDP instructions.
#define HIGHBD_H_WXH_LARGE(w, h)                                            \
  void aom_highbd_h_predictor_##w##x##h##_neon(                             \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,               \
      const uint16_t *left, int bd) {                                       \
    (void)above;                                                            \
    (void)bd;                                                               \
    for (int i = 0; i < (h) / 16; ++i) {                                    \
      uint16x8_t l0 = vld1q_u16(left + 0);                                  \
      uint16x8_t l1 = vld1q_u16(left + 8);                                  \
      highbd_h_store_##w##x4(dst + 0 * stride, stride, vget_low_u16(l0));   \
      highbd_h_store_##w##x4(dst + 4 * stride, stride, vget_high_u16(l0));  \
      highbd_h_store_##w##x4(dst + 8 * stride, stride, vget_low_u16(l1));   \
      highbd_h_store_##w##x4(dst + 12 * stride, stride, vget_high_u16(l1)); \
      left += 16;                                                           \
      dst += 16 * stride;                                                   \
    }                                                                       \
  }
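
// Each iteration consumes 16 left-column values, so this is only
// instantiated for heights that are multiples of 16.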

HIGHBD_H_WXH_LARGE(4, 16)
HIGHBD_H_WXH_LARGE(8, 16)
HIGHBD_H_WXH_LARGE(8, 32)
HIGHBD_H_WXH_LARGE(16, 16)
HIGHBD_H_WXH_LARGE(16, 32)
HIGHBD_H_WXH_LARGE(16, 64)
HIGHBD_H_WXH_LARGE(32, 16)
HIGHBD_H_WXH_LARGE(32, 32)
HIGHBD_H_WXH_LARGE(32, 64)
HIGHBD_H_WXH_LARGE(64, 16)
HIGHBD_H_WXH_LARGE(64, 32)
HIGHBD_H_WXH_LARGE(64, 64)

#undef HIGHBD_H_WXH_LARGE

// -----------------------------------------------------------------------------
// PAETH
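
// The Paeth predictor computes base = left + top - top_left and selects
// whichever of left, top and top_left is closest to it. Note that
// |base - left| == |top - top_left|, |base - top| == |left - top_left| and
// |base - top_left| == |(top + left) - 2 * top_left|, which is why left_dist
// below is measured between top and top_left (and vice versa).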

static INLINE void highbd_paeth_4or8_x_h_neon(uint16_t *dest, ptrdiff_t stride,
                                              const uint16_t *const top_row,
                                              const uint16_t *const left_column,
                                              int width, int height) {
  const uint16x8_t top_left = vdupq_n_u16(top_row[-1]);
  const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]);
  uint16x8_t top;
  if (width == 4) {
    top = vcombine_u16(vld1_u16(top_row), vdup_n_u16(0));
  } else {  // width == 8
    top = vld1q_u16(top_row);
  }

  for (int y = 0; y < height; ++y) {
    const uint16x8_t left = vdupq_n_u16(left_column[y]);

    const uint16x8_t left_dist = vabdq_u16(top, top_left);
    const uint16x8_t top_dist = vabdq_u16(left, top_left);
    const uint16x8_t top_left_dist =
        vabdq_u16(vaddq_u16(top, left), top_left_x2);

    const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist);
    const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist);
    const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist);

    // if (left_dist <= top_dist && left_dist <= top_left_dist)
    const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left);
    //   dest[x] = left_column[y];
    // Fill all the unused spaces with 'top'. They will be overwritten when
    // the positions for top_left are known.
    uint16x8_t result = vbslq_u16(left_mask, left, top);
    // else if (top_dist <= top_left_dist)
    //   dest[x] = top_row[x];
    // Add these values to the mask. They were already set.
    const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left);
    // else
    //   dest[x] = top_left;
    result = vbslq_u16(left_or_top_mask, result, top_left);

    if (width == 4) {
      vst1_u16(dest, vget_low_u16(result));
    } else {  // width == 8
      vst1q_u16(dest, result);
    }
    dest += stride;
  }
}

#define HIGHBD_PAETH_NXM(W, H)                                  \
  void aom_highbd_paeth_predictor_##W##x##H##_neon(             \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,   \
      const uint16_t *left, int bd) {                           \
    (void)bd;                                                   \
    highbd_paeth_4or8_x_h_neon(dst, stride, above, left, W, H); \
  }

HIGHBD_PAETH_NXM(4, 4)
HIGHBD_PAETH_NXM(4, 8)
HIGHBD_PAETH_NXM(4, 16)
HIGHBD_PAETH_NXM(8, 4)
HIGHBD_PAETH_NXM(8, 8)
HIGHBD_PAETH_NXM(8, 16)
HIGHBD_PAETH_NXM(8, 32)

// Select the closest values and collect them.
static INLINE uint16x8_t select_paeth(const uint16x8_t top,
                                      const uint16x8_t left,
                                      const uint16x8_t top_left,
                                      const uint16x8_t left_le_top,
                                      const uint16x8_t left_le_top_left,
                                      const uint16x8_t top_le_top_left) {
  // if (left_dist <= top_dist && left_dist <= top_left_dist)
  const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left);
  //   dest[x] = left_column[y];
  // Fill all the unused spaces with 'top'. They will be overwritten when
  // the positions for top_left are known.
  const uint16x8_t result = vbslq_u16(left_mask, left, top);
  // else if (top_dist <= top_left_dist)
  //   dest[x] = top_row[x];
  // Add these values to the mask. They were already set.
  const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left);
  // else
  //   dest[x] = top_left;
  return vbslq_u16(left_or_top_mask, result, top_left);
}

#define PAETH_PREDICTOR(num)                                                  \
  do {                                                                        \
    const uint16x8_t left_dist = vabdq_u16(top[num], top_left);               \
    const uint16x8_t top_left_dist =                                          \
        vabdq_u16(vaddq_u16(top[num], left), top_left_x2);                    \
    const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist);            \
    const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist);  \
    const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist);    \
    const uint16x8_t result =                                                 \
        select_paeth(top[num], left, top_left, left_le_top, left_le_top_left, \
                     top_le_top_left);                                        \
    vst1q_u16(dest + (num * 8), result);                                      \
  } while (0)

#define LOAD_TOP_ROW(num) vld1q_u16(top_row + (num * 8))

static INLINE void highbd_paeth16_plus_x_h_neon(
    uint16_t *dest, ptrdiff_t stride, const uint16_t *const top_row,
    const uint16_t *const left_column, int width, int height) {
  const uint16x8_t top_left = vdupq_n_u16(top_row[-1]);
  const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]);
  uint16x8_t top[8];
  top[0] = LOAD_TOP_ROW(0);
  top[1] = LOAD_TOP_ROW(1);
  if (width > 16) {
    top[2] = LOAD_TOP_ROW(2);
    top[3] = LOAD_TOP_ROW(3);
    if (width == 64) {
      top[4] = LOAD_TOP_ROW(4);
      top[5] = LOAD_TOP_ROW(5);
      top[6] = LOAD_TOP_ROW(6);
      top[7] = LOAD_TOP_ROW(7);
    }
  }

  for (int y = 0; y < height; ++y) {
    const uint16x8_t left = vdupq_n_u16(left_column[y]);
    const uint16x8_t top_dist = vabdq_u16(left, top_left);
    PAETH_PREDICTOR(0);
    PAETH_PREDICTOR(1);
    if (width > 16) {
      PAETH_PREDICTOR(2);
      PAETH_PREDICTOR(3);
      if (width == 64) {
        PAETH_PREDICTOR(4);
        PAETH_PREDICTOR(5);
        PAETH_PREDICTOR(6);
        PAETH_PREDICTOR(7);
      }
    }
    dest += stride;
  }
}

#define HIGHBD_PAETH_NXM_WIDE(W, H)                               \
  void aom_highbd_paeth_predictor_##W##x##H##_neon(               \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,     \
      const uint16_t *left, int bd) {                             \
    (void)bd;                                                     \
    highbd_paeth16_plus_x_h_neon(dst, stride, above, left, W, H); \
  }

HIGHBD_PAETH_NXM_WIDE(16, 4)
HIGHBD_PAETH_NXM_WIDE(16, 8)
HIGHBD_PAETH_NXM_WIDE(16, 16)
HIGHBD_PAETH_NXM_WIDE(16, 32)
HIGHBD_PAETH_NXM_WIDE(16, 64)
HIGHBD_PAETH_NXM_WIDE(32, 8)
HIGHBD_PAETH_NXM_WIDE(32, 16)
HIGHBD_PAETH_NXM_WIDE(32, 32)
HIGHBD_PAETH_NXM_WIDE(32, 64)
HIGHBD_PAETH_NXM_WIDE(64, 16)
HIGHBD_PAETH_NXM_WIDE(64, 32)
HIGHBD_PAETH_NXM_WIDE(64, 64)

// -----------------------------------------------------------------------------
// SMOOTH
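
// SMOOTH blends each pixel from both edges:
//   pred(x, y) = (w_y[y] * top[x] + (256 - w_y[y]) * bottom_left +
//                 w_x[x] * left[y] + (256 - w_x[x]) * top_right + 256) >> 9
// with the weights taken from smooth_weights_u16 and the scale fixed by
// SMOOTH_WEIGHT_LOG2_SCALE == 8.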

// 256 - v = vneg_s8(v), valid while the weights are in [1, 255]: negating
// the low byte of each lane as a signed 8-bit value yields 256 - v and the
// high byte stays zero.
static INLINE uint16x4_t negate_s8(const uint16x4_t v) {
  return vreinterpret_u16_s8(vneg_s8(vreinterpret_s8_u16(v)));
}

static INLINE void highbd_smooth_4xh_neon(uint16_t *dst, ptrdiff_t stride,
                                          const uint16_t *const top_row,
                                          const uint16_t *const left_column,
                                          const int height) {
  const uint16_t top_right = top_row[3];
  const uint16_t bottom_left = left_column[height - 1];
  const uint16_t *const weights_y = smooth_weights_u16 + height - 4;

  const uint16x4_t top_v = vld1_u16(top_row);
  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
  const uint16x4_t weights_x_v = vld1_u16(smooth_weights_u16);
  const uint16x4_t scaled_weights_x = negate_s8(weights_x_v);
  const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right);

  for (int y = 0; y < height; ++y) {
    // Each variable in the running summation is named for the last item to be
    // accumulated.
    const uint32x4_t weighted_top =
        vmlal_n_u16(weighted_tr, top_v, weights_y[y]);
    const uint32x4_t weighted_left =
        vmlal_n_u16(weighted_top, weights_x_v, left_column[y]);
    const uint32x4_t weighted_bl =
        vmlal_n_u16(weighted_left, bottom_left_v, 256 - weights_y[y]);

    const uint16x4_t pred =
        vrshrn_n_u32(weighted_bl, SMOOTH_WEIGHT_LOG2_SCALE + 1);
    vst1_u16(dst, pred);
    dst += stride;
  }
}

// Common code between 8xH and [16|32|64]xH.
static INLINE void highbd_calculate_pred8(
    uint16_t *dst, const uint32x4_t weighted_corners_low,
    const uint32x4_t weighted_corners_high, const uint16x4x2_t top_vals,
    const uint16x4x2_t weights_x, const uint16_t left_y,
    const uint16_t weight_y) {
  // Each variable in the running summation is named for the last item to be
  // accumulated.
  const uint32x4_t weighted_top_low =
      vmlal_n_u16(weighted_corners_low, top_vals.val[0], weight_y);
  const uint32x4_t weighted_edges_low =
      vmlal_n_u16(weighted_top_low, weights_x.val[0], left_y);

  const uint16x4_t pred_low =
      vrshrn_n_u32(weighted_edges_low, SMOOTH_WEIGHT_LOG2_SCALE + 1);
  vst1_u16(dst, pred_low);

  const uint32x4_t weighted_top_high =
      vmlal_n_u16(weighted_corners_high, top_vals.val[1], weight_y);
  const uint32x4_t weighted_edges_high =
      vmlal_n_u16(weighted_top_high, weights_x.val[1], left_y);

  const uint16x4_t pred_high =
      vrshrn_n_u32(weighted_edges_high, SMOOTH_WEIGHT_LOG2_SCALE + 1);
  vst1_u16(dst + 4, pred_high);
}

static void highbd_smooth_8xh_neon(uint16_t *dst, ptrdiff_t stride,
                                   const uint16_t *const top_row,
                                   const uint16_t *const left_column,
                                   const int height) {
  const uint16_t top_right = top_row[7];
  const uint16_t bottom_left = left_column[height - 1];
  const uint16_t *const weights_y = smooth_weights_u16 + height - 4;

  const uint16x4x2_t top_vals = { { vld1_u16(top_row),
                                    vld1_u16(top_row + 4) } };
  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
  const uint16x4x2_t weights_x = { { vld1_u16(smooth_weights_u16 + 4),
                                     vld1_u16(smooth_weights_u16 + 8) } };
  const uint32x4_t weighted_tr_low =
      vmull_n_u16(negate_s8(weights_x.val[0]), top_right);
  const uint32x4_t weighted_tr_high =
      vmull_n_u16(negate_s8(weights_x.val[1]), top_right);

  for (int y = 0; y < height; ++y) {
    const uint32x4_t weighted_bl =
        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
    const uint32x4_t weighted_corners_low =
        vaddq_u32(weighted_bl, weighted_tr_low);
    const uint32x4_t weighted_corners_high =
        vaddq_u32(weighted_bl, weighted_tr_high);
    highbd_calculate_pred8(dst, weighted_corners_low, weighted_corners_high,
                           top_vals, weights_x, left_column[y], weights_y[y]);
    dst += stride;
  }
}

#define HIGHBD_SMOOTH_NXM(W, H)                                 \
  void aom_highbd_smooth_predictor_##W##x##H##_neon(            \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \
      const uint16_t *left, int bd) {                           \
    (void)bd;                                                   \
    highbd_smooth_##W##xh_neon(dst, y_stride, above, left, H);  \
  }

HIGHBD_SMOOTH_NXM(4, 4)
HIGHBD_SMOOTH_NXM(4, 8)
HIGHBD_SMOOTH_NXM(8, 4)
HIGHBD_SMOOTH_NXM(8, 8)
HIGHBD_SMOOTH_NXM(4, 16)
HIGHBD_SMOOTH_NXM(8, 16)
HIGHBD_SMOOTH_NXM(8, 32)

#undef HIGHBD_SMOOTH_NXM

// For width 16 and above.
#define HIGHBD_SMOOTH_PREDICTOR(W)                                             \
  static void highbd_smooth_##W##xh_neon(                                      \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *const top_row,          \
      const uint16_t *const left_column, const int height) {                   \
    const uint16_t top_right = top_row[(W)-1];                                 \
    const uint16_t bottom_left = left_column[height - 1];                      \
    const uint16_t *const weights_y = smooth_weights_u16 + height - 4;         \
                                                                               \
    /* Precompute weighted values that don't vary with |y|. */                 \
    uint32x4_t weighted_tr_low[(W) >> 3];                                      \
    uint32x4_t weighted_tr_high[(W) >> 3];                                     \
    for (int i = 0; i < (W) >> 3; ++i) {                                       \
      const int x = i << 3;                                                    \
      const uint16x4_t weights_x_low =                                         \
          vld1_u16(smooth_weights_u16 + (W)-4 + x);                            \
      weighted_tr_low[i] = vmull_n_u16(negate_s8(weights_x_low), top_right);   \
      const uint16x4_t weights_x_high =                                        \
          vld1_u16(smooth_weights_u16 + (W) + x);                              \
      weighted_tr_high[i] = vmull_n_u16(negate_s8(weights_x_high), top_right); \
    }                                                                          \
                                                                               \
    const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);                  \
    for (int y = 0; y < height; ++y) {                                         \
      const uint32x4_t weighted_bl =                                           \
          vmull_n_u16(bottom_left_v, 256 - weights_y[y]);                      \
      uint16_t *dst_x = dst;                                                   \
      for (int i = 0; i < (W) >> 3; ++i) {                                     \
        const int x = i << 3;                                                  \
        const uint16x4x2_t top_vals = { { vld1_u16(top_row + x),               \
                                          vld1_u16(top_row + x + 4) } };       \
        const uint32x4_t weighted_corners_low =                                \
            vaddq_u32(weighted_bl, weighted_tr_low[i]);                        \
        const uint32x4_t weighted_corners_high =                               \
            vaddq_u32(weighted_bl, weighted_tr_high[i]);                       \
        /* Accumulate weighted edge values and store. */                       \
        const uint16x4x2_t weights_x = {                                       \
          { vld1_u16(smooth_weights_u16 + (W)-4 + x),                          \
            vld1_u16(smooth_weights_u16 + (W) + x) }                           \
        };                                                                     \
        highbd_calculate_pred8(dst_x, weighted_corners_low,                    \
                               weighted_corners_high, top_vals, weights_x,     \
                               left_column[y], weights_y[y]);                  \
        dst_x += 8;                                                            \
      }                                                                        \
      dst += stride;                                                           \
    }                                                                          \
  }

HIGHBD_SMOOTH_PREDICTOR(16)
HIGHBD_SMOOTH_PREDICTOR(32)
HIGHBD_SMOOTH_PREDICTOR(64)

#undef HIGHBD_SMOOTH_PREDICTOR

#define HIGHBD_SMOOTH_NXM_WIDE(W, H)                            \
  void aom_highbd_smooth_predictor_##W##x##H##_neon(            \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \
      const uint16_t *left, int bd) {                           \
    (void)bd;                                                   \
    highbd_smooth_##W##xh_neon(dst, y_stride, above, left, H);  \
  }

HIGHBD_SMOOTH_NXM_WIDE(16, 4)
HIGHBD_SMOOTH_NXM_WIDE(16, 8)
HIGHBD_SMOOTH_NXM_WIDE(16, 16)
HIGHBD_SMOOTH_NXM_WIDE(16, 32)
HIGHBD_SMOOTH_NXM_WIDE(16, 64)
HIGHBD_SMOOTH_NXM_WIDE(32, 8)
HIGHBD_SMOOTH_NXM_WIDE(32, 16)
HIGHBD_SMOOTH_NXM_WIDE(32, 32)
HIGHBD_SMOOTH_NXM_WIDE(32, 64)
HIGHBD_SMOOTH_NXM_WIDE(64, 16)
HIGHBD_SMOOTH_NXM_WIDE(64, 32)
HIGHBD_SMOOTH_NXM_WIDE(64, 64)

#undef HIGHBD_SMOOTH_NXM_WIDE
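
// SMOOTH_V drops the horizontal terms and blends only down the block:
//   pred(x, y) = (w_y[y] * top[x] + (256 - w_y[y]) * bottom_left + 128) >> 8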

static void highbd_smooth_v_4xh_neon(uint16_t *dst, ptrdiff_t stride,
                                     const uint16_t *const top_row,
                                     const uint16_t *const left_column,
                                     const int height) {
  const uint16_t bottom_left = left_column[height - 1];
  const uint16_t *const weights_y = smooth_weights_u16 + height - 4;

  const uint16x4_t top_v = vld1_u16(top_row);
  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);

  for (int y = 0; y < height; ++y) {
    const uint32x4_t weighted_bl =
        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
    const uint32x4_t weighted_top =
        vmlal_n_u16(weighted_bl, top_v, weights_y[y]);
    vst1_u16(dst, vrshrn_n_u32(weighted_top, SMOOTH_WEIGHT_LOG2_SCALE));

    dst += stride;
  }
}

static void highbd_smooth_v_8xh_neon(uint16_t *dst, const ptrdiff_t stride,
                                     const uint16_t *const top_row,
                                     const uint16_t *const left_column,
                                     const int height) {
  const uint16_t bottom_left = left_column[height - 1];
  const uint16_t *const weights_y = smooth_weights_u16 + height - 4;

  const uint16x4_t top_low = vld1_u16(top_row);
  const uint16x4_t top_high = vld1_u16(top_row + 4);
  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);

  for (int y = 0; y < height; ++y) {
    const uint32x4_t weighted_bl =
        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);

    const uint32x4_t weighted_top_low =
        vmlal_n_u16(weighted_bl, top_low, weights_y[y]);
    vst1_u16(dst, vrshrn_n_u32(weighted_top_low, SMOOTH_WEIGHT_LOG2_SCALE));

    const uint32x4_t weighted_top_high =
        vmlal_n_u16(weighted_bl, top_high, weights_y[y]);
    vst1_u16(dst + 4,
             vrshrn_n_u32(weighted_top_high, SMOOTH_WEIGHT_LOG2_SCALE));
    dst += stride;
  }
}

#define HIGHBD_SMOOTH_V_NXM(W, H)                                \
  void aom_highbd_smooth_v_predictor_##W##x##H##_neon(           \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above,  \
      const uint16_t *left, int bd) {                            \
    (void)bd;                                                    \
    highbd_smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \
  }

HIGHBD_SMOOTH_V_NXM(4, 4)
HIGHBD_SMOOTH_V_NXM(4, 8)
HIGHBD_SMOOTH_V_NXM(4, 16)
HIGHBD_SMOOTH_V_NXM(8, 4)
HIGHBD_SMOOTH_V_NXM(8, 8)
HIGHBD_SMOOTH_V_NXM(8, 16)
HIGHBD_SMOOTH_V_NXM(8, 32)

#undef HIGHBD_SMOOTH_V_NXM

// For width 16 and above.
#define HIGHBD_SMOOTH_V_PREDICTOR(W)                                         \
  static void highbd_smooth_v_##W##xh_neon(                                  \
      uint16_t *dst, const ptrdiff_t stride, const uint16_t *const top_row,  \
      const uint16_t *const left_column, const int height) {                 \
    const uint16_t bottom_left = left_column[height - 1];                    \
    const uint16_t *const weights_y = smooth_weights_u16 + height - 4;       \
                                                                             \
    uint16x4x2_t top_vals[(W) >> 3];                                         \
    for (int i = 0; i < (W) >> 3; ++i) {                                     \
      const int x = i << 3;                                                  \
      top_vals[i].val[0] = vld1_u16(top_row + x);                            \
      top_vals[i].val[1] = vld1_u16(top_row + x + 4);                        \
    }                                                                        \
                                                                             \
    const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);                \
    for (int y = 0; y < height; ++y) {                                       \
      const uint32x4_t weighted_bl =                                         \
          vmull_n_u16(bottom_left_v, 256 - weights_y[y]);                    \
                                                                             \
      uint16_t *dst_x = dst;                                                 \
      for (int i = 0; i < (W) >> 3; ++i) {                                   \
        const uint32x4_t weighted_top_low =                                  \
            vmlal_n_u16(weighted_bl, top_vals[i].val[0], weights_y[y]);      \
        vst1_u16(dst_x,                                                      \
                 vrshrn_n_u32(weighted_top_low, SMOOTH_WEIGHT_LOG2_SCALE));  \
                                                                             \
        const uint32x4_t weighted_top_high =                                 \
            vmlal_n_u16(weighted_bl, top_vals[i].val[1], weights_y[y]);      \
        vst1_u16(dst_x + 4,                                                  \
                 vrshrn_n_u32(weighted_top_high, SMOOTH_WEIGHT_LOG2_SCALE)); \
        dst_x += 8;                                                          \
      }                                                                      \
      dst += stride;                                                         \
    }                                                                        \
  }

HIGHBD_SMOOTH_V_PREDICTOR(16)
HIGHBD_SMOOTH_V_PREDICTOR(32)
HIGHBD_SMOOTH_V_PREDICTOR(64)

#undef HIGHBD_SMOOTH_V_PREDICTOR

#define HIGHBD_SMOOTH_V_NXM_WIDE(W, H)                           \
  void aom_highbd_smooth_v_predictor_##W##x##H##_neon(           \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above,  \
      const uint16_t *left, int bd) {                            \
    (void)bd;                                                    \
    highbd_smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \
  }

HIGHBD_SMOOTH_V_NXM_WIDE(16, 4)
HIGHBD_SMOOTH_V_NXM_WIDE(16, 8)
HIGHBD_SMOOTH_V_NXM_WIDE(16, 16)
HIGHBD_SMOOTH_V_NXM_WIDE(16, 32)
HIGHBD_SMOOTH_V_NXM_WIDE(16, 64)
HIGHBD_SMOOTH_V_NXM_WIDE(32, 8)
HIGHBD_SMOOTH_V_NXM_WIDE(32, 16)
HIGHBD_SMOOTH_V_NXM_WIDE(32, 32)
HIGHBD_SMOOTH_V_NXM_WIDE(32, 64)
HIGHBD_SMOOTH_V_NXM_WIDE(64, 16)
HIGHBD_SMOOTH_V_NXM_WIDE(64, 32)
HIGHBD_SMOOTH_V_NXM_WIDE(64, 64)

#undef HIGHBD_SMOOTH_V_NXM_WIDE
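
// SMOOTH_H blends only across the block:
//   pred(x, y) = (w_x[x] * left[y] + (256 - w_x[x]) * top_right + 128) >> 8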
1137 
highbd_smooth_h_4xh_neon(uint16_t * dst,ptrdiff_t stride,const uint16_t * const top_row,const uint16_t * const left_column,const int height)1138 static INLINE void highbd_smooth_h_4xh_neon(uint16_t *dst, ptrdiff_t stride,
1139                                             const uint16_t *const top_row,
1140                                             const uint16_t *const left_column,
1141                                             const int height) {
1142   const uint16_t top_right = top_row[3];
1143 
1144   const uint16x4_t weights_x = vld1_u16(smooth_weights_u16);
  const uint16x4_t scaled_weights_x = negate_s8(weights_x);

  const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right);
  for (int y = 0; y < height; ++y) {
    const uint32x4_t weighted_left =
        vmlal_n_u16(weighted_tr, weights_x, left_column[y]);
    vst1_u16(dst, vrshrn_n_u32(weighted_left, SMOOTH_WEIGHT_LOG2_SCALE));
    dst += stride;
  }
}

static INLINE void highbd_smooth_h_8xh_neon(uint16_t *dst, ptrdiff_t stride,
                                            const uint16_t *const top_row,
                                            const uint16_t *const left_column,
                                            const int height) {
  const uint16_t top_right = top_row[7];

  const uint16x4x2_t weights_x = { { vld1_u16(smooth_weights_u16 + 4),
                                     vld1_u16(smooth_weights_u16 + 8) } };

  const uint32x4_t weighted_tr_low =
      vmull_n_u16(negate_s8(weights_x.val[0]), top_right);
  const uint32x4_t weighted_tr_high =
      vmull_n_u16(negate_s8(weights_x.val[1]), top_right);

  for (int y = 0; y < height; ++y) {
    const uint16_t left_y = left_column[y];
    const uint32x4_t weighted_left_low =
        vmlal_n_u16(weighted_tr_low, weights_x.val[0], left_y);
    vst1_u16(dst, vrshrn_n_u32(weighted_left_low, SMOOTH_WEIGHT_LOG2_SCALE));

    const uint32x4_t weighted_left_high =
        vmlal_n_u16(weighted_tr_high, weights_x.val[1], left_y);
    vst1_u16(dst + 4,
             vrshrn_n_u32(weighted_left_high, SMOOTH_WEIGHT_LOG2_SCALE));
    dst += stride;
  }
}

#define HIGHBD_SMOOTH_H_NXM(W, H)                                \
  void aom_highbd_smooth_h_predictor_##W##x##H##_neon(           \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above,  \
      const uint16_t *left, int bd) {                            \
    (void)bd;                                                    \
    highbd_smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \
  }

HIGHBD_SMOOTH_H_NXM(4, 4)
HIGHBD_SMOOTH_H_NXM(4, 8)
HIGHBD_SMOOTH_H_NXM(4, 16)
HIGHBD_SMOOTH_H_NXM(8, 4)
HIGHBD_SMOOTH_H_NXM(8, 8)
HIGHBD_SMOOTH_H_NXM(8, 16)
HIGHBD_SMOOTH_H_NXM(8, 32)

#undef HIGHBD_SMOOTH_H_NXM

// For width 16 and above.
#define HIGHBD_SMOOTH_H_PREDICTOR(W)                                          \
  static void highbd_smooth_h_##W##xh_neon(                                   \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *const top_row,         \
      const uint16_t *const left_column, const int height) {                  \
    const uint16_t top_right = top_row[(W)-1];                                \
                                                                              \
    uint16x4_t weights_x_low[(W) >> 3];                                       \
    uint16x4_t weights_x_high[(W) >> 3];                                      \
    uint32x4_t weighted_tr_low[(W) >> 3];                                     \
    uint32x4_t weighted_tr_high[(W) >> 3];                                    \
    for (int i = 0; i < (W) >> 3; ++i) {                                      \
      const int x = i << 3;                                                   \
      weights_x_low[i] = vld1_u16(smooth_weights_u16 + (W)-4 + x);            \
      weighted_tr_low[i] =                                                    \
          vmull_n_u16(negate_s8(weights_x_low[i]), top_right);                \
      weights_x_high[i] = vld1_u16(smooth_weights_u16 + (W) + x);             \
      weighted_tr_high[i] =                                                   \
          vmull_n_u16(negate_s8(weights_x_high[i]), top_right);               \
    }                                                                         \
                                                                              \
    for (int y = 0; y < height; ++y) {                                        \
      uint16_t *dst_x = dst;                                                  \
      const uint16_t left_y = left_column[y];                                 \
      for (int i = 0; i < (W) >> 3; ++i) {                                    \
        const uint32x4_t weighted_left_low =                                  \
            vmlal_n_u16(weighted_tr_low[i], weights_x_low[i], left_y);        \
        vst1_u16(dst_x,                                                       \
                 vrshrn_n_u32(weighted_left_low, SMOOTH_WEIGHT_LOG2_SCALE));  \
                                                                              \
        const uint32x4_t weighted_left_high =                                 \
            vmlal_n_u16(weighted_tr_high[i], weights_x_high[i], left_y);      \
        vst1_u16(dst_x + 4,                                                   \
                 vrshrn_n_u32(weighted_left_high, SMOOTH_WEIGHT_LOG2_SCALE)); \
        dst_x += 8;                                                           \
      }                                                                       \
      dst += stride;                                                          \
    }                                                                         \
  }

HIGHBD_SMOOTH_H_PREDICTOR(16)
HIGHBD_SMOOTH_H_PREDICTOR(32)
HIGHBD_SMOOTH_H_PREDICTOR(64)

#undef HIGHBD_SMOOTH_H_PREDICTOR

#define HIGHBD_SMOOTH_H_NXM_WIDE(W, H)                           \
  void aom_highbd_smooth_h_predictor_##W##x##H##_neon(           \
      uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above,  \
      const uint16_t *left, int bd) {                            \
    (void)bd;                                                    \
    highbd_smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \
  }

HIGHBD_SMOOTH_H_NXM_WIDE(16, 4)
HIGHBD_SMOOTH_H_NXM_WIDE(16, 8)
HIGHBD_SMOOTH_H_NXM_WIDE(16, 16)
HIGHBD_SMOOTH_H_NXM_WIDE(16, 32)
HIGHBD_SMOOTH_H_NXM_WIDE(16, 64)
HIGHBD_SMOOTH_H_NXM_WIDE(32, 8)
HIGHBD_SMOOTH_H_NXM_WIDE(32, 16)
HIGHBD_SMOOTH_H_NXM_WIDE(32, 32)
HIGHBD_SMOOTH_H_NXM_WIDE(32, 64)
HIGHBD_SMOOTH_H_NXM_WIDE(64, 16)
HIGHBD_SMOOTH_H_NXM_WIDE(64, 32)
HIGHBD_SMOOTH_H_NXM_WIDE(64, 64)

#undef HIGHBD_SMOOTH_H_NXM_WIDE

// -----------------------------------------------------------------------------
// Z1

static const int16_t iota1_s16[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8 };
static const int16_t iota2_s16[] = { 0, 2, 4, 6, 8, 10, 12, 14 };
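// Lane-index ramps used to compute per-lane coordinates: iota1_s16 for unit
// steps (loaded at an offset of +1 where a 1..8 ramp is needed), iota2_s16
// for the stride-2 steps used when the above row is upsampled.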

static AOM_FORCE_INLINE uint16x4_t highbd_dr_z1_apply_shift_x4(uint16x4_t a0,
                                                               uint16x4_t a1,
                                                               int shift) {
  // The C implementation of the z1 predictor uses (32 - shift) and a right
  // shift by 5; we instead use a doubled shift value with (64 - shift) and a
  // right shift by 6, avoiding an unnecessary right shift by 1 when deriving
  // `shift`.
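  // i.e. per lane: (a0 * (64 - shift) + a1 * shift + 32) >> 6, where the
  // doubled shift is even and in [0, 62].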
  uint32x4_t res = vmull_n_u16(a1, shift);
  res = vmlal_n_u16(res, a0, 64 - shift);
  return vrshrn_n_u32(res, 6);
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_z1_apply_shift_x8(uint16x8_t a0,
                                                               uint16x8_t a1,
                                                               int shift) {
  return vcombine_u16(
      highbd_dr_z1_apply_shift_x4(vget_low_u16(a0), vget_low_u16(a1), shift),
      highbd_dr_z1_apply_shift_x4(vget_high_u16(a0), vget_high_u16(a1), shift));
}

// clang-format off
static const uint8_t kLoadMaxShuffles[] = {
  14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
  12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
  10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
   8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
   6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15,
   4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15,
   2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15,
   0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
};
// clang-format on

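// Load eight elements ending at ptr[7] and clamp any lanes that would
// otherwise read past the end of the array: row n of kLoadMaxShuffles moves
// the last n + 1 elements down to lanes 0..n and fills the remaining lanes
// with copies of the final element (byte pair 14, 15).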
static INLINE uint16x8_t zn_load_masked_neon(const uint16_t *ptr,
                                             int shuffle_idx) {
  uint8x16_t shuffle = vld1q_u8(&kLoadMaxShuffles[16 * shuffle_idx]);
  uint8x16_t src = vreinterpretq_u8_u16(vld1q_u16(ptr));
#if AOM_ARCH_AARCH64
  return vreinterpretq_u16_u8(vqtbl1q_u8(src, shuffle));
#else
  uint8x8x2_t src2 = { { vget_low_u8(src), vget_high_u8(src) } };
  uint8x8_t lo = vtbl2_u8(src2, vget_low_u8(shuffle));
  uint8x8_t hi = vtbl2_u8(src2, vget_high_u8(shuffle));
  return vreinterpretq_u16_u8(vcombine_u8(lo, hi));
#endif
}

static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst,
                                                   ptrdiff_t stride, int bw,
                                                   int bh,
                                                   const uint16_t *above,
                                                   int dx) {
  assert(bw % 4 == 0);
  assert(bh % 4 == 0);
  assert(dx > 0);

  const int max_base_x = (bw + bh) - 1;
  const int above_max = above[max_base_x];

  const int16x8_t iota1x8 = vld1q_s16(iota1_s16);
  const int16x4_t iota1x4 = vget_low_s16(iota1x8);

  int x = dx;
  int r = 0;
  do {
    const int base = x >> 6;
    if (base >= max_base_x) {
      for (int i = r; i < bh; ++i) {
        aom_memset16(dst, above_max, bw);
        dst += stride;
      }
      return;
    }

    // The C implementation of the z1 predictor when not upsampling uses:
    // ((x & 0x3f) >> 1)
    // The right shift is unnecessary here since we instead shift by +1 later,
    // so adjust the mask to 0x3e to ensure we don't consider the extra bit.
    const int shift = x & 0x3e;

    if (bw == 4) {
      const uint16x4_t a0 = vld1_u16(&above[base]);
      const uint16x4_t a1 = vld1_u16(&above[base + 1]);
      const uint16x4_t val = highbd_dr_z1_apply_shift_x4(a0, a1, shift);
      const uint16x4_t cmp = vcgt_s16(vdup_n_s16(max_base_x - base), iota1x4);
      const uint16x4_t res = vbsl_u16(cmp, val, vdup_n_u16(above_max));
      vst1_u16(dst, res);
    } else {
      int c = 0;
      do {
        uint16x8_t a0;
        uint16x8_t a1;
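        // Clamp the two 8-element loads against above[max_base_x]: if even
        // the first lane is out of range splat the max value, otherwise use a
        // masked load for whichever of a0/a1 would read past the end.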
        if (base + c >= max_base_x) {
          a0 = a1 = vdupq_n_u16(above_max);
        } else {
          if (base + c + 7 >= max_base_x) {
            int shuffle_idx = max_base_x - base - c;
            a0 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
          } else {
            a0 = vld1q_u16(above + base + c);
          }
          if (base + c + 8 >= max_base_x) {
            int shuffle_idx = max_base_x - base - c - 1;
            a1 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
          } else {
            a1 = vld1q_u16(above + base + c + 1);
          }
        }

        vst1q_u16(dst + c, highbd_dr_z1_apply_shift_x8(a0, a1, shift));
        c += 8;
      } while (c < bw);
    }

    dst += stride;
    x += dx;
  } while (++r < bh);
}

static void highbd_dr_prediction_z1_upsample1_neon(uint16_t *dst,
                                                   ptrdiff_t stride, int bw,
                                                   int bh,
                                                   const uint16_t *above,
                                                   int dx) {
  assert(bw % 4 == 0);
  assert(bh % 4 == 0);
  assert(dx > 0);

  const int max_base_x = ((bw + bh) - 1) << 1;
  const int above_max = above[max_base_x];

  const int16x8_t iota2x8 = vld1q_s16(iota2_s16);
  const int16x4_t iota2x4 = vget_low_s16(iota2x8);

  int x = dx;
  int r = 0;
  do {
    const int base = x >> 5;
    if (base >= max_base_x) {
      for (int i = r; i < bh; ++i) {
        aom_memset16(dst, above_max, bw);
        dst += stride;
      }
      return;
    }

    // The C implementation of the z1 predictor when upsampling uses:
    // (((x << 1) & 0x3f) >> 1)
    // The right shift is unnecessary here since we instead shift by +1 later,
    // so adjust the mask to 0x3e to ensure we don't consider the extra bit.
    const int shift = (x << 1) & 0x3e;

    if (bw == 4) {
      const uint16x4x2_t a01 = vld2_u16(&above[base]);
      const uint16x4_t val =
          highbd_dr_z1_apply_shift_x4(a01.val[0], a01.val[1], shift);
      const uint16x4_t cmp = vcgt_s16(vdup_n_s16(max_base_x - base), iota2x4);
      const uint16x4_t res = vbsl_u16(cmp, val, vdup_n_u16(above_max));
      vst1_u16(dst, res);
    } else {
      int c = 0;
      do {
        const uint16x8x2_t a01 = vld2q_u16(&above[base + 2 * c]);
        const uint16x8_t val =
            highbd_dr_z1_apply_shift_x8(a01.val[0], a01.val[1], shift);
        const uint16x8_t cmp =
            vcgtq_s16(vdupq_n_s16(max_base_x - base - 2 * c), iota2x8);
        const uint16x8_t res = vbslq_u16(cmp, val, vdupq_n_u16(above_max));
        vst1q_u16(dst + c, res);
        c += 8;
      } while (c < bw);
    }

    dst += stride;
    x += dx;
  } while (++r < bh);
}

// Directional prediction, zone 1: 0 < angle < 90
void av1_highbd_dr_prediction_z1_neon(uint16_t *dst, ptrdiff_t stride, int bw,
                                      int bh, const uint16_t *above,
                                      const uint16_t *left, int upsample_above,
                                      int dx, int dy, int bd) {
  (void)left;
  (void)dy;
  (void)bd;
  assert(dy == 1);

  if (upsample_above) {
    highbd_dr_prediction_z1_upsample1_neon(dst, stride, bw, bh, above, dx);
  } else {
    highbd_dr_prediction_z1_upsample0_neon(dst, stride, bw, bh, above, dx);
  }
}

// -----------------------------------------------------------------------------
// Z2

#if AOM_ARCH_AARCH64
// Incrementally shift more elements from `above` into the result, merging with
// existing `left` elements.
// X0, X1, X2, X3
// Y0, X0, X1, X2
// Y0, Y1, X0, X1
// Y0, Y1, Y2, X0
// Y0, Y1, Y2, Y3
// clang-format off
static const uint8_t z2_merge_shuffles_u16x4[5][8] = {
  {  8,  9, 10, 11, 12, 13, 14, 15 },
  {  0,  1,  8,  9, 10, 11, 12, 13 },
  {  0,  1,  2,  3,  8,  9, 10, 11 },
  {  0,  1,  2,  3,  4,  5,  8,  9 },
  {  0,  1,  2,  3,  4,  5,  6,  7 },
};
// clang-format on

// Incrementally shift more elements from `above` into the result, merging with
// existing `left` elements.
// X0, X1, X2, X3, X4, X5, X6, X7
// Y0, X0, X1, X2, X3, X4, X5, X6
// Y0, Y1, X0, X1, X2, X3, X4, X5
// Y0, Y1, Y2, X0, X1, X2, X3, X4
// Y0, Y1, Y2, Y3, X0, X1, X2, X3
// Y0, Y1, Y2, Y3, Y4, X0, X1, X2
// Y0, Y1, Y2, Y3, Y4, Y5, X0, X1
// Y0, Y1, Y2, Y3, Y4, Y5, Y6, X0
// Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7
// clang-format off
static const uint8_t z2_merge_shuffles_u16x8[9][16] = {
  { 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 },
  {  0,  1, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29 },
  {  0,  1,  2,  3, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27 },
  {  0,  1,  2,  3,  4,  5, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 },
  {  0,  1,  2,  3,  4,  5,  6,  7, 16, 17, 18, 19, 20, 21, 22, 23 },
  {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 16, 17, 18, 19, 20, 21 },
  {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 16, 17, 18, 19 },
  {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 16, 17 },
  {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15 },
};
// clang-format on

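// Row n of the mask tables below keeps only the first n lanes, zeroing any
// lanes beyond the number of `left` elements gathered for the current
// iteration.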
// clang-format off
static const uint16_t z2_y_iter_masks_u16x4[5][4] = {
  {      0U,      0U,      0U,      0U },
  { 0xffffU,      0U,      0U,      0U },
  { 0xffffU, 0xffffU,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU },
};
// clang-format on

// clang-format off
static const uint16_t z2_y_iter_masks_u16x8[9][8] = {
  {      0U,      0U,      0U,      0U,      0U,      0U,      0U,      0U },
  { 0xffffU,      0U,      0U,      0U,      0U,      0U,      0U,      0U },
  { 0xffffU, 0xffffU,      0U,      0U,      0U,      0U,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU,      0U,      0U,      0U,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU,      0U,      0U,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU,      0U,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU,      0U,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU,      0U },
  { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU },
};
// clang-format on

static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_tbl_left_x4_from_x8(
    const uint16x8_t left_data, const int16x4_t indices, int base, int n) {
  // Need to adjust indices to operate on 0-based indices rather than
  // `base`-based indices and then adjust from uint16x4 indices to uint8x8
  // indices so we can use a tbl instruction (which only operates on bytes).
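  // e.g. an index value of 3 (bytes [3, 0] in little-endian lanes) becomes
  // [3, 3] after the trn, [6, 6] after doubling, and [6, 7] after adding
  // 0x0100: exactly the two bytes of element 3 in the reinterpreted data.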
  uint8x8_t left_indices =
      vreinterpret_u8_s16(vsub_s16(indices, vdup_n_s16(base)));
  left_indices = vtrn1_u8(left_indices, left_indices);
  left_indices = vadd_u8(left_indices, left_indices);
  left_indices = vadd_u8(left_indices, vreinterpret_u8_u16(vdup_n_u16(0x0100)));
  const uint16x4_t ret = vreinterpret_u16_u8(
      vqtbl1_u8(vreinterpretq_u8_u16(left_data), left_indices));
  return vand_u16(ret, vld1_u16(z2_y_iter_masks_u16x4[n]));
}

static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_tbl_left_x4_from_x16(
    const uint16x8x2_t left_data, const int16x4_t indices, int base, int n) {
  // Need to adjust indices to operate on 0-based indices rather than
  // `base`-based indices and then adjust from uint16x4 indices to uint8x8
  // indices so we can use a tbl instruction (which only operates on bytes).
  uint8x8_t left_indices =
      vreinterpret_u8_s16(vsub_s16(indices, vdup_n_s16(base)));
  left_indices = vtrn1_u8(left_indices, left_indices);
  left_indices = vadd_u8(left_indices, left_indices);
  left_indices = vadd_u8(left_indices, vreinterpret_u8_u16(vdup_n_u16(0x0100)));
  uint8x16x2_t data_u8 = { { vreinterpretq_u8_u16(left_data.val[0]),
                             vreinterpretq_u8_u16(left_data.val[1]) } };
  const uint16x4_t ret = vreinterpret_u16_u8(vqtbl2_u8(data_u8, left_indices));
  return vand_u16(ret, vld1_u16(z2_y_iter_masks_u16x4[n]));
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_tbl_left_x8_from_x8(
    const uint16x8_t left_data, const int16x8_t indices, int base, int n) {
  // Need to adjust indices to operate on 0-based indices rather than
  // `base`-based indices and then adjust from uint16x8 indices to uint8x16
  // indices so we can use a tbl instruction (which only operates on bytes).
  uint8x16_t left_indices =
      vreinterpretq_u8_s16(vsubq_s16(indices, vdupq_n_s16(base)));
  left_indices = vtrn1q_u8(left_indices, left_indices);
  left_indices = vaddq_u8(left_indices, left_indices);
  left_indices =
      vaddq_u8(left_indices, vreinterpretq_u8_u16(vdupq_n_u16(0x0100)));
  const uint16x8_t ret = vreinterpretq_u16_u8(
      vqtbl1q_u8(vreinterpretq_u8_u16(left_data), left_indices));
  return vandq_u16(ret, vld1q_u16(z2_y_iter_masks_u16x8[n]));
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_tbl_left_x8_from_x16(
    const uint16x8x2_t left_data, const int16x8_t indices, int base, int n) {
  // Need to adjust indices to operate on 0-based indices rather than
  // `base`-based indices and then adjust from uint16x8 indices to uint8x16
  // indices so we can use a tbl instruction (which only operates on bytes).
  uint8x16_t left_indices =
      vreinterpretq_u8_s16(vsubq_s16(indices, vdupq_n_s16(base)));
  left_indices = vtrn1q_u8(left_indices, left_indices);
  left_indices = vaddq_u8(left_indices, left_indices);
  left_indices =
      vaddq_u8(left_indices, vreinterpretq_u8_u16(vdupq_n_u16(0x0100)));
  uint8x16x2_t data_u8 = { { vreinterpretq_u8_u16(left_data.val[0]),
                             vreinterpretq_u8_u16(left_data.val[1]) } };
  const uint16x8_t ret =
      vreinterpretq_u16_u8(vqtbl2q_u8(data_u8, left_indices));
  return vandq_u16(ret, vld1q_u16(z2_y_iter_masks_u16x8[n]));
}
#endif  // AOM_ARCH_AARCH64

// TODO(aomedia:349428506): enable this for armv7 after SIGBUS is fixed.
#if AOM_ARCH_AARCH64
static AOM_FORCE_INLINE uint16x4x2_t highbd_dr_prediction_z2_gather_left_x4(
    const uint16_t *left, const int16x4_t indices, int n) {
  assert(n > 0);
  assert(n <= 4);
  // Load two elements at a time and then uzp them into separate vectors, to
  // reduce the number of memory accesses.
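  // Each u32 lane load fetches the pair left[idx] and left[idx + 1] in a
  // single access; the final uzp separates the pairs so that val[0] holds the
  // left[idx] values and val[1] the left[idx + 1] values.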
  uint32x2_t ret0_u32 = vdup_n_u32(0);
  uint32x2_t ret1_u32 = vdup_n_u32(0);

  // Use a single vget_lane_u64 to minimize vector to general purpose register
  // transfers and then mask off the bits we actually want.
  const uint64_t indices0123 = vget_lane_u64(vreinterpret_u64_s16(indices), 0);
  const int idx0 = (int16_t)((indices0123 >> 0) & 0xffffU);
  const int idx1 = (int16_t)((indices0123 >> 16) & 0xffffU);
  const int idx2 = (int16_t)((indices0123 >> 32) & 0xffffU);
  const int idx3 = (int16_t)((indices0123 >> 48) & 0xffffU);

  // At time of writing both Clang and GCC produced better code with these
  // nested if-statements compared to a switch statement with fallthrough.
  ret0_u32 = vld1_lane_u32((const uint32_t *)(left + idx0), ret0_u32, 0);
  if (n > 1) {
    ret0_u32 = vld1_lane_u32((const uint32_t *)(left + idx1), ret0_u32, 1);
    if (n > 2) {
      ret1_u32 = vld1_lane_u32((const uint32_t *)(left + idx2), ret1_u32, 0);
      if (n > 3) {
        ret1_u32 = vld1_lane_u32((const uint32_t *)(left + idx3), ret1_u32, 1);
      }
    }
  }
  return vuzp_u16(vreinterpret_u16_u32(ret0_u32),
                  vreinterpret_u16_u32(ret1_u32));
}

static AOM_FORCE_INLINE uint16x8x2_t highbd_dr_prediction_z2_gather_left_x8(
    const uint16_t *left, const int16x8_t indices, int n) {
  assert(n > 0);
  assert(n <= 8);
  // Load two elements at a time and then uzp them into separate vectors, to
  // reduce the number of memory accesses.
  uint32x4_t ret0_u32 = vdupq_n_u32(0);
  uint32x4_t ret1_u32 = vdupq_n_u32(0);

  // Use a pair of vget_lane_u64 to minimize vector to general purpose register
  // transfers and then mask off the bits we actually want.
  const uint64_t indices0123 =
      vgetq_lane_u64(vreinterpretq_u64_s16(indices), 0);
  const uint64_t indices4567 =
      vgetq_lane_u64(vreinterpretq_u64_s16(indices), 1);
  const int idx0 = (int16_t)((indices0123 >> 0) & 0xffffU);
  const int idx1 = (int16_t)((indices0123 >> 16) & 0xffffU);
  const int idx2 = (int16_t)((indices0123 >> 32) & 0xffffU);
  const int idx3 = (int16_t)((indices0123 >> 48) & 0xffffU);
  const int idx4 = (int16_t)((indices4567 >> 0) & 0xffffU);
  const int idx5 = (int16_t)((indices4567 >> 16) & 0xffffU);
  const int idx6 = (int16_t)((indices4567 >> 32) & 0xffffU);
  const int idx7 = (int16_t)((indices4567 >> 48) & 0xffffU);

  // At time of writing both Clang and GCC produced better code with these
  // nested if-statements compared to a switch statement with fallthrough.
  ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx0), ret0_u32, 0);
  if (n > 1) {
    ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx1), ret0_u32, 1);
    if (n > 2) {
      ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx2), ret0_u32, 2);
      if (n > 3) {
        ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx3), ret0_u32, 3);
        if (n > 4) {
          ret1_u32 =
              vld1q_lane_u32((const uint32_t *)(left + idx4), ret1_u32, 0);
          if (n > 5) {
            ret1_u32 =
                vld1q_lane_u32((const uint32_t *)(left + idx5), ret1_u32, 1);
            if (n > 6) {
              ret1_u32 =
                  vld1q_lane_u32((const uint32_t *)(left + idx6), ret1_u32, 2);
              if (n > 7) {
                ret1_u32 = vld1q_lane_u32((const uint32_t *)(left + idx7),
                                          ret1_u32, 3);
              }
            }
          }
        }
      }
    }
  }
  return vuzpq_u16(vreinterpretq_u16_u32(ret0_u32),
                   vreinterpretq_u16_u32(ret1_u32));
}

static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_merge_x4(
    uint16x4_t out_x, uint16x4_t out_y, int base_shift) {
  assert(base_shift >= 0);
  assert(base_shift <= 4);
  // On AArch64 we can permute the data from the `above` and `left` vectors
  // into a single vector in a single load (of the permute vector) + tbl.
#if AOM_ARCH_AARCH64
  const uint8x8x2_t out_yx = { { vreinterpret_u8_u16(out_y),
                                 vreinterpret_u8_u16(out_x) } };
  return vreinterpret_u16_u8(
      vtbl2_u8(out_yx, vld1_u8(z2_merge_shuffles_u16x4[base_shift])));
#else
  uint16x4_t out = out_y;
  for (int c2 = base_shift, x_idx = 0; c2 < 4; ++c2, ++x_idx) {
    out[c2] = out_x[x_idx];
  }
  return out;
#endif
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_merge_x8(
    uint16x8_t out_x, uint16x8_t out_y, int base_shift) {
  assert(base_shift >= 0);
  assert(base_shift <= 8);
  // On AArch64 we can permute the data from the `above` and `left` vectors
  // into a single vector in a single load (of the permute vector) + tbl.
#if AOM_ARCH_AARCH64
  const uint8x16x2_t out_yx = { { vreinterpretq_u8_u16(out_y),
                                  vreinterpretq_u8_u16(out_x) } };
  return vreinterpretq_u16_u8(
      vqtbl2q_u8(out_yx, vld1q_u8(z2_merge_shuffles_u16x8[base_shift])));
#else
  uint16x8_t out = out_y;
  for (int c2 = base_shift, x_idx = 0; c2 < 8; ++c2, ++x_idx) {
    out[c2] = out_x[x_idx];
  }
  return out;
#endif
}

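// Computes the two-tap interpolation
// (a0 * (32 - shift) + a1 * shift + 16) >> 5 independently in each lane.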
static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_apply_shift_x4(
    uint16x4_t a0, uint16x4_t a1, int16x4_t shift) {
  uint32x4_t res = vmull_u16(a1, vreinterpret_u16_s16(shift));
  res =
      vmlal_u16(res, a0, vsub_u16(vdup_n_u16(32), vreinterpret_u16_s16(shift)));
  return vrshrn_n_u32(res, 5);
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_apply_shift_x8(
    uint16x8_t a0, uint16x8_t a1, int16x8_t shift) {
  return vcombine_u16(
      highbd_dr_prediction_z2_apply_shift_x4(vget_low_u16(a0), vget_low_u16(a1),
                                             vget_low_s16(shift)),
      highbd_dr_prediction_z2_apply_shift_x4(
          vget_high_u16(a0), vget_high_u16(a1), vget_high_s16(shift)));
}

static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_step_x4(
    const uint16_t *above, const uint16x4_t above0, const uint16x4_t above1,
    const uint16_t *left, int dx, int dy, int r, int c) {
  const int16x4_t iota = vld1_s16(iota1_s16);

  const int x0 = (c << 6) - (r + 1) * dx;
  const int y0 = (r << 6) - (c + 1) * dy;

  const int16x4_t x0123 = vadd_s16(vdup_n_s16(x0), vshl_n_s16(iota, 6));
  const int16x4_t y0123 = vsub_s16(vdup_n_s16(y0), vmul_n_s16(iota, dy));
  const int16x4_t shift_x0123 =
      vshr_n_s16(vand_s16(x0123, vdup_n_s16(0x3F)), 1);
  const int16x4_t shift_y0123 =
      vshr_n_s16(vand_s16(y0123, vdup_n_s16(0x3F)), 1);
  const int16x4_t base_y0123 = vshr_n_s16(y0123, 6);

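  // base_shift is the number of leading lanes in this row whose source
  // position falls strictly left of above[-1] and so must be predicted from
  // `left` instead.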
  const int base_shift = ((((r + 1) * dx) - 1) >> 6) - c;

  // Based on the value of `base_shift` there are three possible cases to
  // compute the result:
  // 1) base_shift <= 0: We can load and operate entirely on data from the
  //                     `above` input vector.
  // 2) base_shift < vl: We can load from `above[-1]` and shift
  //                     `vl - base_shift` elements across to the end of the
  //                     vector, then compute the remainder from `left`.
  // 3) base_shift >= vl: We can load and operate entirely on data from the
  //                      `left` input vector.

  if (base_shift <= 0) {
    const int base_x = x0 >> 6;
    const uint16x4_t a0 = vld1_u16(above + base_x);
    const uint16x4_t a1 = vld1_u16(above + base_x + 1);
    return highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123);
  } else if (base_shift < 4) {
    const uint16x4x2_t l01 = highbd_dr_prediction_z2_gather_left_x4(
        left + 1, base_y0123, base_shift);
    const uint16x4_t out16_y = highbd_dr_prediction_z2_apply_shift_x4(
        l01.val[0], l01.val[1], shift_y0123);

    // No need to reload from above in the loop, just use pre-loaded constants.
    const uint16x4_t out16_x =
        highbd_dr_prediction_z2_apply_shift_x4(above0, above1, shift_x0123);

    return highbd_dr_prediction_z2_merge_x4(out16_x, out16_y, base_shift);
  } else {
    const uint16x4x2_t l01 =
        highbd_dr_prediction_z2_gather_left_x4(left + 1, base_y0123, 4);
    return highbd_dr_prediction_z2_apply_shift_x4(l01.val[0], l01.val[1],
                                                  shift_y0123);
  }
}

static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_step_x8(
    const uint16_t *above, const uint16x8_t above0, const uint16x8_t above1,
    const uint16_t *left, int dx, int dy, int r, int c) {
  const int16x8_t iota = vld1q_s16(iota1_s16);

  const int x0 = (c << 6) - (r + 1) * dx;
  const int y0 = (r << 6) - (c + 1) * dy;

  const int16x8_t x01234567 = vaddq_s16(vdupq_n_s16(x0), vshlq_n_s16(iota, 6));
  const int16x8_t y01234567 = vsubq_s16(vdupq_n_s16(y0), vmulq_n_s16(iota, dy));
  const int16x8_t shift_x01234567 =
      vshrq_n_s16(vandq_s16(x01234567, vdupq_n_s16(0x3F)), 1);
  const int16x8_t shift_y01234567 =
      vshrq_n_s16(vandq_s16(y01234567, vdupq_n_s16(0x3F)), 1);
  const int16x8_t base_y01234567 = vshrq_n_s16(y01234567, 6);

  const int base_shift = ((((r + 1) * dx) - 1) >> 6) - c;

  // Based on the value of `base_shift` there are three possible cases to
  // compute the result:
  // 1) base_shift <= 0: We can load and operate entirely on data from the
  //                     `above` input vector.
  // 2) base_shift < vl: We can load from `above[-1]` and shift
  //                     `vl - base_shift` elements across to the end of the
  //                     vector, then compute the remainder from `left`.
  // 3) base_shift >= vl: We can load and operate entirely on data from the
  //                      `left` input vector.

  if (base_shift <= 0) {
    const int base_x = x0 >> 6;
    const uint16x8_t a0 = vld1q_u16(above + base_x);
    const uint16x8_t a1 = vld1q_u16(above + base_x + 1);
    return highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567);
  } else if (base_shift < 8) {
    const uint16x8x2_t l01 = highbd_dr_prediction_z2_gather_left_x8(
        left + 1, base_y01234567, base_shift);
    const uint16x8_t out16_y = highbd_dr_prediction_z2_apply_shift_x8(
        l01.val[0], l01.val[1], shift_y01234567);

    // No need to reload from above in the loop, just use pre-loaded constants.
    const uint16x8_t out16_x =
        highbd_dr_prediction_z2_apply_shift_x8(above0, above1, shift_x01234567);

    return highbd_dr_prediction_z2_merge_x8(out16_x, out16_y, base_shift);
  } else {
    const uint16x8x2_t l01 =
        highbd_dr_prediction_z2_gather_left_x8(left + 1, base_y01234567, 8);
    return highbd_dr_prediction_z2_apply_shift_x8(l01.val[0], l01.val[1],
                                                  shift_y01234567);
  }
}

// Left array is accessed from -1 through `bh - 1` inclusive.
// Above array is accessed from -1 through `bw - 1` inclusive.
#define HIGHBD_DR_PREDICTOR_Z2_WXH(bw, bh)                                 \
  static void highbd_dr_prediction_z2_##bw##x##bh##_neon(                  \
      uint16_t *dst, ptrdiff_t stride, const uint16_t *above,              \
      const uint16_t *left, int upsample_above, int upsample_left, int dx, \
      int dy, int bd) {                                                    \
    (void)bd;                                                              \
    (void)upsample_above;                                                  \
    (void)upsample_left;                                                   \
    assert(!upsample_above);                                               \
    assert(!upsample_left);                                                \
    assert(bw % 4 == 0);                                                   \
    assert(bh % 4 == 0);                                                   \
    assert(dx > 0);                                                        \
    assert(dy > 0);                                                        \
                                                                           \
    uint16_t left_data[bh + 1];                                            \
    memcpy(left_data, left - 1, (bh + 1) * sizeof(uint16_t));              \
                                                                           \
    uint16x8_t a0, a1;                                                     \
    if (bw == 4) {                                                         \
      a0 = vcombine_u16(vld1_u16(above - 1), vdup_n_u16(0));               \
      a1 = vcombine_u16(vld1_u16(above + 0), vdup_n_u16(0));               \
    } else {                                                               \
      a0 = vld1q_u16(above - 1);                                           \
      a1 = vld1q_u16(above + 0);                                           \
    }                                                                      \
                                                                           \
    int r = 0;                                                             \
    do {                                                                   \
      if (bw == 4) {                                                       \
        vst1_u16(dst, highbd_dr_prediction_z2_step_x4(                     \
                          above, vget_low_u16(a0), vget_low_u16(a1),       \
                          left_data, dx, dy, r, 0));                       \
      } else {                                                             \
        int c = 0;                                                         \
        do {                                                               \
          vst1q_u16(dst + c, highbd_dr_prediction_z2_step_x8(              \
                                 above, a0, a1, left_data, dx, dy, r, c)); \
          c += 8;                                                          \
        } while (c < bw);                                                  \
      }                                                                    \
      dst += stride;                                                       \
    } while (++r < bh);                                                    \
  }

HIGHBD_DR_PREDICTOR_Z2_WXH(4, 16)
HIGHBD_DR_PREDICTOR_Z2_WXH(8, 16)
HIGHBD_DR_PREDICTOR_Z2_WXH(8, 32)
HIGHBD_DR_PREDICTOR_Z2_WXH(16, 4)
HIGHBD_DR_PREDICTOR_Z2_WXH(16, 8)
HIGHBD_DR_PREDICTOR_Z2_WXH(16, 16)
HIGHBD_DR_PREDICTOR_Z2_WXH(16, 32)
HIGHBD_DR_PREDICTOR_Z2_WXH(16, 64)
HIGHBD_DR_PREDICTOR_Z2_WXH(32, 8)
HIGHBD_DR_PREDICTOR_Z2_WXH(32, 16)
HIGHBD_DR_PREDICTOR_Z2_WXH(32, 32)
HIGHBD_DR_PREDICTOR_Z2_WXH(32, 64)
HIGHBD_DR_PREDICTOR_Z2_WXH(64, 16)
HIGHBD_DR_PREDICTOR_Z2_WXH(64, 32)
HIGHBD_DR_PREDICTOR_Z2_WXH(64, 64)

#undef HIGHBD_DR_PREDICTOR_Z2_WXH

typedef void (*highbd_dr_prediction_z2_ptr)(uint16_t *dst, ptrdiff_t stride,
                                            const uint16_t *above,
                                            const uint16_t *left,
                                            int upsample_above,
                                            int upsample_left, int dx, int dy,
                                            int bd);

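// Upsampling of the intra edge is only enabled for blocks with
// width + height <= 16, so only the 4x4, 4x8, 8x4 and 8x8 specializations
// below need to handle upsample_above/upsample_left.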
static void highbd_dr_prediction_z2_4x4_neon(uint16_t *dst, ptrdiff_t stride,
                                             const uint16_t *above,
                                             const uint16_t *left,
                                             int upsample_above,
                                             int upsample_left, int dx, int dy,
                                             int bd) {
  (void)bd;
  assert(dx > 0);
  assert(dy > 0);

  const int frac_bits_x = 6 - upsample_above;
  const int frac_bits_y = 6 - upsample_left;
  const int min_base_x = -(1 << (upsample_above + frac_bits_x));
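  // Since frac_bits_x == 6 - upsample_above, the shift amount above is always
  // 6, so min_base_x is always -64: one full pixel to the left in 1/64 pel
  // units.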

  // If `upsample_left` then we need -2 through 6 inclusive from `left`,
  // otherwise we only need -1 through 3 inclusive.

#if AOM_ARCH_AARCH64
  uint16x8_t left_data0, left_data1;
  if (upsample_left) {
    left_data0 = vld1q_u16(left - 2);
    left_data1 = vld1q_u16(left - 1);
  } else {
    left_data0 = vcombine_u16(vld1_u16(left - 1), vdup_n_u16(0));
    left_data1 = vcombine_u16(vld1_u16(left + 0), vdup_n_u16(0));
  }
#endif

  const int16x4_t iota0123 = vld1_s16(iota1_s16);
  const int16x4_t iota1234 = vld1_s16(iota1_s16 + 1);

  for (int r = 0; r < 4; ++r) {
    const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6;
    const int x0 = (r + 1) * dx;
    const int16x4_t x0123 = vsub_s16(vshl_n_s16(iota0123, 6), vdup_n_s16(x0));
    const int base_x0 = (-x0) >> frac_bits_x;
    if (base_shift <= 0) {
      uint16x4_t a0, a1;
      int16x4_t shift_x0123;
      if (upsample_above) {
        const uint16x4x2_t a01 = vld2_u16(above + base_x0);
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F));
      } else {
        a0 = vld1_u16(above + base_x0);
        a1 = vld1_u16(above + base_x0 + 1);
        shift_x0123 = vshr_n_s16(vand_s16(x0123, vdup_n_s16(0x3F)), 1);
      }
      vst1_u16(dst,
               highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123));
    } else if (base_shift < 4) {
      // Calculate Y component from `left`.
      const int y_iters = base_shift;
      const int16x4_t y0123 =
          vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy));
      const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y));
      const int16x4_t shift_y0123 = vshr_n_s16(
          vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1);
      uint16x4_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data0, base_y0123,
                                                       left_data_base, y_iters);
      l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data1, base_y0123,
                                                       left_data_base, y_iters);
#else
      const uint16x4x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, y_iters);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      const uint16x4_t out_y =
          highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123);

      // Calculate X component from `above`.
      const int16x4_t shift_x0123 = vshr_n_s16(
          vand_s16(vmul_n_s16(x0123, 1 << upsample_above), vdup_n_s16(0x3F)),
          1);
      uint16x4_t a0, a1;
      if (upsample_above) {
        const uint16x4x2_t a01 = vld2_u16(above + (base_x0 % 2 == 0 ? -2 : -1));
        a0 = a01.val[0];
        a1 = a01.val[1];
      } else {
        a0 = vld1_u16(above - 1);
        a1 = vld1_u16(above + 0);
      }
      const uint16x4_t out_x =
          highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123);

      // Combine X and Y vectors.
      const uint16x4_t out =
          highbd_dr_prediction_z2_merge_x4(out_x, out_y, base_shift);
      vst1_u16(dst, out);
    } else {
      const int16x4_t y0123 =
          vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy));
      const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y));
      const int16x4_t shift_y0123 = vshr_n_s16(
          vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1);
      uint16x4_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data0, base_y0123,
                                                       left_data_base, 4);
      l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data1, base_y0123,
                                                       left_data_base, 4);
#else
      const uint16x4x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, 4);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif
      vst1_u16(dst,
               highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123));
    }
    dst += stride;
  }
}

static void highbd_dr_prediction_z2_4x8_neon(uint16_t *dst, ptrdiff_t stride,
                                             const uint16_t *above,
                                             const uint16_t *left,
                                             int upsample_above,
                                             int upsample_left, int dx, int dy,
                                             int bd) {
  (void)bd;
  assert(dx > 0);
  assert(dy > 0);

  const int frac_bits_x = 6 - upsample_above;
  const int frac_bits_y = 6 - upsample_left;
  const int min_base_x = -(1 << (upsample_above + frac_bits_x));

  // If `upsample_left` then we need -2 through 14 inclusive from `left`,
  // otherwise we only need -1 through 6 inclusive.

#if AOM_ARCH_AARCH64
  uint16x8x2_t left_data0, left_data1;
  if (upsample_left) {
    left_data0 = vld1q_u16_x2(left - 2);
    left_data1 = vld1q_u16_x2(left - 1);
  } else {
    left_data0 = (uint16x8x2_t){ { vld1q_u16(left - 1), vdupq_n_u16(0) } };
    left_data1 = (uint16x8x2_t){ { vld1q_u16(left + 0), vdupq_n_u16(0) } };
  }
#endif

  const int16x4_t iota0123 = vld1_s16(iota1_s16);
  const int16x4_t iota1234 = vld1_s16(iota1_s16 + 1);

  for (int r = 0; r < 8; ++r) {
    const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6;
    const int x0 = (r + 1) * dx;
    const int16x4_t x0123 = vsub_s16(vshl_n_s16(iota0123, 6), vdup_n_s16(x0));
    const int base_x0 = (-x0) >> frac_bits_x;
    if (base_shift <= 0) {
      uint16x4_t a0, a1;
      int16x4_t shift_x0123;
      if (upsample_above) {
        const uint16x4x2_t a01 = vld2_u16(above + base_x0);
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F));
      } else {
        a0 = vld1_u16(above + base_x0);
        a1 = vld1_u16(above + base_x0 + 1);
        shift_x0123 = vand_s16(vshr_n_s16(x0123, 1), vdup_n_s16(0x1F));
      }
      vst1_u16(dst,
               highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123));
    } else if (base_shift < 4) {
      // Calculate Y component from `left`.
      const int y_iters = base_shift;
      const int16x4_t y0123 =
          vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy));
      const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y));
      const int16x4_t shift_y0123 = vshr_n_s16(
          vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1);

      uint16x4_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(
          left_data0, base_y0123, left_data_base, y_iters);
      l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(
          left_data1, base_y0123, left_data_base, y_iters);
#else
      const uint16x4x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, y_iters);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      const uint16x4_t out_y =
          highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123);

      // Calculate X component from `above`.
      uint16x4_t a0, a1;
      int16x4_t shift_x0123;
      if (upsample_above) {
        const uint16x4x2_t a01 = vld2_u16(above + (base_x0 % 2 == 0 ? -2 : -1));
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F));
      } else {
        a0 = vld1_u16(above - 1);
        a1 = vld1_u16(above + 0);
        shift_x0123 = vand_s16(vshr_n_s16(x0123, 1), vdup_n_s16(0x1F));
      }
      const uint16x4_t out_x =
          highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123);

      // Combine X and Y vectors.
      const uint16x4_t out =
          highbd_dr_prediction_z2_merge_x4(out_x, out_y, base_shift);
      vst1_u16(dst, out);
    } else {
      const int16x4_t y0123 =
          vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy));
      const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y));
      const int16x4_t shift_y0123 = vshr_n_s16(
          vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1);

      uint16x4_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(left_data0, base_y0123,
                                                        left_data_base, 4);
      l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(left_data1, base_y0123,
                                                        left_data_base, 4);
#else
      const uint16x4x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, 4);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      vst1_u16(dst,
               highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123));
    }
    dst += stride;
  }
}

static void highbd_dr_prediction_z2_8x4_neon(uint16_t *dst, ptrdiff_t stride,
                                             const uint16_t *above,
                                             const uint16_t *left,
                                             int upsample_above,
                                             int upsample_left, int dx, int dy,
                                             int bd) {
  (void)bd;
  assert(dx > 0);
  assert(dy > 0);

  const int frac_bits_x = 6 - upsample_above;
  const int frac_bits_y = 6 - upsample_left;
  const int min_base_x = -(1 << (upsample_above + frac_bits_x));

  // If `upsample_left` then we need -2 through 6 inclusive from `left`,
  // otherwise we only need -1 through 3 inclusive.

#if AOM_ARCH_AARCH64
  uint16x8_t left_data0, left_data1;
  if (upsample_left) {
    left_data0 = vld1q_u16(left - 2);
    left_data1 = vld1q_u16(left - 1);
  } else {
    left_data0 = vcombine_u16(vld1_u16(left - 1), vdup_n_u16(0));
    left_data1 = vcombine_u16(vld1_u16(left + 0), vdup_n_u16(0));
  }
#endif

  const int16x8_t iota01234567 = vld1q_s16(iota1_s16);
  const int16x8_t iota12345678 = vld1q_s16(iota1_s16 + 1);

  for (int r = 0; r < 4; ++r) {
    const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6;
    const int x0 = (r + 1) * dx;
    const int16x8_t x01234567 =
        vsubq_s16(vshlq_n_s16(iota01234567, 6), vdupq_n_s16(x0));
    const int base_x0 = (-x0) >> frac_bits_x;
    if (base_shift <= 0) {
      uint16x8_t a0, a1;
      int16x8_t shift_x01234567;
      if (upsample_above) {
        const uint16x8x2_t a01 = vld2q_u16(above + base_x0);
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F));
      } else {
        a0 = vld1q_u16(above + base_x0);
        a1 = vld1q_u16(above + base_x0 + 1);
        shift_x01234567 =
            vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F));
      }
      vst1q_u16(
          dst, highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567));
    } else if (base_shift < 8) {
      // Calculate Y component from `left`.
      const int y_iters = base_shift;
      const int16x8_t y01234567 =
          vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy));
      const int16x8_t base_y01234567 =
          vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y));
      const int16x8_t shift_y01234567 =
          vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left),
                                vdupq_n_s16(0x3F)),
                      1);

      uint16x8_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x8(
          left_data0, base_y01234567, left_data_base, y_iters);
      l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x8(
          left_data1, base_y01234567, left_data_base, y_iters);
#else
      const uint16x8x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, y_iters);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      const uint16x8_t out_y =
          highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567);

      // Calculate X component from `above`.
      uint16x8_t a0, a1;
      int16x8_t shift_x01234567;
      if (upsample_above) {
        const uint16x8x2_t a01 =
            vld2q_u16(above + (base_x0 % 2 == 0 ? -2 : -1));
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F));
      } else {
        a0 = vld1q_u16(above - 1);
        a1 = vld1q_u16(above + 0);
        shift_x01234567 =
            vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F));
      }
      const uint16x8_t out_x =
          highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567);

      // Combine X and Y vectors.
      const uint16x8_t out =
          highbd_dr_prediction_z2_merge_x8(out_x, out_y, base_shift);
      vst1q_u16(dst, out);
    } else {
      const int16x8_t y01234567 =
          vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy));
      const int16x8_t base_y01234567 =
          vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y));
      const int16x8_t shift_y01234567 =
          vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left),
                                vdupq_n_s16(0x3F)),
                      1);

      uint16x8_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x8(
          left_data0, base_y01234567, left_data_base, 8);
      l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x8(
          left_data1, base_y01234567, left_data_base, 8);
#else
      const uint16x8x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, 8);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      vst1q_u16(
          dst, highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567));
    }
    dst += stride;
  }
}

static void highbd_dr_prediction_z2_8x8_neon(uint16_t *dst, ptrdiff_t stride,
                                             const uint16_t *above,
                                             const uint16_t *left,
                                             int upsample_above,
                                             int upsample_left, int dx, int dy,
                                             int bd) {
  (void)bd;
  assert(dx > 0);
  assert(dy > 0);

  const int frac_bits_x = 6 - upsample_above;
  const int frac_bits_y = 6 - upsample_left;
  const int min_base_x = -(1 << (upsample_above + frac_bits_x));

  // If `upsample_left` is set we need elements -2 through 14 inclusive from
  // `left`; otherwise we only need -1 through 6 inclusive.

#if AOM_ARCH_AARCH64
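  // Pre-load the `left` samples so that the row loop can gather them with
  // table-lookup (TBL) instructions. left_data0/left_data1 hold the
  // interpolation pair left[i] and left[i + 1] for each gathered index.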
  uint16x8x2_t left_data0, left_data1;
  if (upsample_left) {
    left_data0 = vld1q_u16_x2(left - 2);
    left_data1 = vld1q_u16_x2(left - 1);
  } else {
    left_data0 = (uint16x8x2_t){ { vld1q_u16(left - 1), vdupq_n_u16(0) } };
    left_data1 = (uint16x8x2_t){ { vld1q_u16(left + 0), vdupq_n_u16(0) } };
  }
#endif

  const int16x8_t iota01234567 = vld1q_s16(iota1_s16);
  const int16x8_t iota12345678 = vld1q_s16(iota1_s16 + 1);

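  // For each row, base_shift is the number of leading output pixels whose
  // source position falls off the left end of `above`; those pixels are
  // instead predicted from `left`, and the two partial results are merged.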
  for (int r = 0; r < 8; ++r) {
    const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6;
    const int x0 = (r + 1) * dx;
    const int16x8_t x01234567 =
        vsubq_s16(vshlq_n_s16(iota01234567, 6), vdupq_n_s16(x0));
    const int base_x0 = (-x0) >> frac_bits_x;
    if (base_shift <= 0) {
      uint16x8_t a0, a1;
      int16x8_t shift_x01234567;
      if (upsample_above) {
        const uint16x8x2_t a01 = vld2q_u16(above + base_x0);
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F));
      } else {
        a0 = vld1q_u16(above + base_x0);
        a1 = vld1q_u16(above + base_x0 + 1);
        shift_x01234567 =
            vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F));
      }
      vst1q_u16(
          dst, highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567));
    } else if (base_shift < 8) {
      // Calculate Y component from `left`.
      const int y_iters = base_shift;
      const int16x8_t y01234567 =
          vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy));
      const int16x8_t base_y01234567 =
          vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y));
      const int16x8_t shift_y01234567 =
          vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left),
                                vdupq_n_s16(0x3F)),
                      1);

      uint16x8_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x16(
          left_data0, base_y01234567, left_data_base, y_iters);
      l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x16(
          left_data1, base_y01234567, left_data_base, y_iters);
#else
      const uint16x8x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, y_iters);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      const uint16x8_t out_y =
          highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567);

      // Calculate X component from `above`.
      uint16x8_t a0, a1;
      int16x8_t shift_x01234567;
      if (upsample_above) {
        const uint16x8x2_t a01 =
            vld2q_u16(above + (base_x0 % 2 == 0 ? -2 : -1));
        a0 = a01.val[0];
        a1 = a01.val[1];
        shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F));
      } else {
        a0 = vld1q_u16(above - 1);
        a1 = vld1q_u16(above + 0);
        shift_x01234567 =
            vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F));
      }
      const uint16x8_t out_x =
          highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567);

      // Combine X and Y vectors. The first `base_shift` lanes come from the Y
      // (left) prediction; the rest come from the X (above) prediction.
      const uint16x8_t out =
          highbd_dr_prediction_z2_merge_x8(out_x, out_y, base_shift);
      vst1q_u16(dst, out);
    } else {
      const int16x8_t y01234567 =
          vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy));
      const int16x8_t base_y01234567 =
          vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y));
      const int16x8_t shift_y01234567 =
          vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left),
                                vdupq_n_s16(0x3F)),
                      1);

      uint16x8_t l0, l1;
#if AOM_ARCH_AARCH64
      const int left_data_base = upsample_left ? -2 : -1;
      l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x16(
          left_data0, base_y01234567, left_data_base, 8);
      l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x16(
          left_data1, base_y01234567, left_data_base, 8);
#else
      const uint16x8x2_t l01 =
          highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, 8);
      l0 = l01.val[0];
      l1 = l01.val[1];
#endif

      vst1q_u16(
          dst, highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567));
    }
    dst += stride;
  }
}

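// Block-size-specific z2 predictors, indexed by [log2(bw)][log2(bh)] via
// get_msb(). NULL entries correspond to block shapes that AV1 does not use.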
static highbd_dr_prediction_z2_ptr dr_predictor_z2_arr_neon[7][7] = {
  { NULL, NULL, NULL, NULL, NULL, NULL, NULL },
  { NULL, NULL, NULL, NULL, NULL, NULL, NULL },
  { NULL, NULL, &highbd_dr_prediction_z2_4x4_neon,
    &highbd_dr_prediction_z2_4x8_neon, &highbd_dr_prediction_z2_4x16_neon, NULL,
    NULL },
  { NULL, NULL, &highbd_dr_prediction_z2_8x4_neon,
    &highbd_dr_prediction_z2_8x8_neon, &highbd_dr_prediction_z2_8x16_neon,
    &highbd_dr_prediction_z2_8x32_neon, NULL },
  { NULL, NULL, &highbd_dr_prediction_z2_16x4_neon,
    &highbd_dr_prediction_z2_16x8_neon, &highbd_dr_prediction_z2_16x16_neon,
    &highbd_dr_prediction_z2_16x32_neon, &highbd_dr_prediction_z2_16x64_neon },
  { NULL, NULL, NULL, &highbd_dr_prediction_z2_32x8_neon,
    &highbd_dr_prediction_z2_32x16_neon, &highbd_dr_prediction_z2_32x32_neon,
    &highbd_dr_prediction_z2_32x64_neon },
  { NULL, NULL, NULL, NULL, &highbd_dr_prediction_z2_64x16_neon,
    &highbd_dr_prediction_z2_64x32_neon, &highbd_dr_prediction_z2_64x64_neon },
};

// Directional prediction, zone 2: 90 < angle < 180
void av1_highbd_dr_prediction_z2_neon(uint16_t *dst, ptrdiff_t stride, int bw,
                                      int bh, const uint16_t *above,
                                      const uint16_t *left, int upsample_above,
                                      int upsample_left, int dx, int dy,
                                      int bd) {
  highbd_dr_prediction_z2_ptr f =
      dr_predictor_z2_arr_neon[get_msb(bw)][get_msb(bh)];
  assert(f != NULL);
  f(dst, stride, above, left, upsample_above, upsample_left, dx, dy, bd);
}
#endif  // AOM_ARCH_AARCH64

// -----------------------------------------------------------------------------
// Z3

// Both the lane to use and the shift amount must be immediates.
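// Computes four interpolated output lanes as the weighted average
//   (in0 * s0[lane] + in1 * s1[lane] + rounding) >> shift,
// replacing any lane whose source index (iota + base) has reached max_base_y
// with the left_max fill value.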
#define HIGHBD_DR_PREDICTOR_Z3_STEP_X4(out, iota, base, in0, in1, s0, s1, \
                                       lane, shift)                       \
  do {                                                                    \
    uint32x4_t val = vmull_lane_u16((in0), (s0), (lane));                 \
    val = vmlal_lane_u16(val, (in1), (s1), (lane));                       \
    const uint16x4_t cmp = vadd_u16((iota), vdup_n_u16(base));            \
    const uint16x4_t res = vrshrn_n_u32(val, (shift));                    \
    *(out) = vbsl_u16(vclt_u16(cmp, vdup_n_u16(max_base_y)), res,         \
                      vdup_n_u16(left_max));                              \
  } while (0)

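// Eight-lane variant of the step above. Note that this variant performs no
// in-register clamp against max_base_y; out-of-range elements are assumed to
// be handled by the caller (e.g. via the masked loads in z3_load_left_neon
// below).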
#define HIGHBD_DR_PREDICTOR_Z3_STEP_X8(out, iota, base, in0, in1, s0, s1, \
                                       lane, shift)                       \
  do {                                                                    \
    uint32x4_t val_lo = vmull_lane_u16(vget_low_u16(in0), (s0), (lane));  \
    val_lo = vmlal_lane_u16(val_lo, vget_low_u16(in1), (s1), (lane));     \
    uint32x4_t val_hi = vmull_lane_u16(vget_high_u16(in0), (s0), (lane)); \
    val_hi = vmlal_lane_u16(val_hi, vget_high_u16(in1), (s1), (lane));    \
    *(out) = vcombine_u16(vrshrn_n_u32(val_lo, (shift)),                  \
                          vrshrn_n_u32(val_hi, (shift)));                 \
  } while (0)

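// Load the interpolation pair left0[ofs .. ofs + 7] and
// left0[ofs + 1 .. ofs + 8], switching to masked loads near the end of the
// array so that no lane reads beyond left0[max_ofs].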
static INLINE uint16x8x2_t z3_load_left_neon(const uint16_t *left0, int ofs,
                                             int max_ofs) {
  uint16x8_t r0;
  uint16x8_t r1;
  if (ofs + 7 >= max_ofs) {
    int shuffle_idx = max_ofs - ofs;
    r0 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
  } else {
    r0 = vld1q_u16(left0 + ofs);
  }
  if (ofs + 8 >= max_ofs) {
    int shuffle_idx = max_ofs - ofs - 1;
    r1 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
  } else {
    r1 = vld1q_u16(left0 + ofs + 1);
  }
  return (uint16x8x2_t){ { r0, r1 } };
}

static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst,
                                                   ptrdiff_t stride, int bw,
                                                   int bh, const uint16_t *left,
                                                   int dy) {
  assert(bw % 4 == 0);
  assert(bh % 4 == 0);
  assert(dy > 0);

  // Factor out left + 1 to give the compiler a better chance of recognising
  // that the offsets used for the loads from left and left + 1 are otherwise
  // identical.
  const uint16_t *left1 = left + 1;

  const int max_base_y = (bw + bh - 1);
  const int left_max = left[max_base_y];
  const int frac_bits = 6;

  const uint16x8_t iota1x8 = vreinterpretq_u16_s16(vld1q_s16(iota1_s16));
  const uint16x4_t iota1x4 = vget_low_u16(iota1x8);

  // The C implementation of the z3 predictor when not upsampling uses:
  // ((y & 0x3f) >> 1)
  // The right shift is unnecessary here since we instead shift by +1 later,
  // so adjust the mask to 0x3e to ensure we don't consider the extra bit.
  const uint16x4_t shift_mask = vdup_n_u16(0x3e);

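  // z3 predicts from `left`, so compute the block in strips of four columns
  // and transpose them into rows before storing.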
  if (bh == 4) {
    int y = dy;
    int c = 0;
    do {
      // Fully unroll the 4x4 block to allow us to use immediate lane-indexed
      // multiply instructions.
      const uint16x4_t shifts1 =
          vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask);
      const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(64), shifts1);
      const int base0 = (y + 0 * dy) >> frac_bits;
      const int base1 = (y + 1 * dy) >> frac_bits;
      const int base2 = (y + 2 * dy) >> frac_bits;
      const int base3 = (y + 3 * dy) >> frac_bits;
      uint16x4_t out[4];
      if (base0 >= max_base_y) {
        out[0] = vdup_n_u16(left_max);
      } else {
        const uint16x4_t l00 = vld1_u16(left + base0);
        const uint16x4_t l01 = vld1_u16(left1 + base0);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[0], iota1x4, base0, l00, l01,
                                       shifts0, shifts1, 0, 6);
      }
      if (base1 >= max_base_y) {
        out[1] = vdup_n_u16(left_max);
      } else {
        const uint16x4_t l10 = vld1_u16(left + base1);
        const uint16x4_t l11 = vld1_u16(left1 + base1);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[1], iota1x4, base1, l10, l11,
                                       shifts0, shifts1, 1, 6);
      }
      if (base2 >= max_base_y) {
        out[2] = vdup_n_u16(left_max);
      } else {
        const uint16x4_t l20 = vld1_u16(left + base2);
        const uint16x4_t l21 = vld1_u16(left1 + base2);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[2], iota1x4, base2, l20, l21,
                                       shifts0, shifts1, 2, 6);
      }
      if (base3 >= max_base_y) {
        out[3] = vdup_n_u16(left_max);
      } else {
        const uint16x4_t l30 = vld1_u16(left + base3);
        const uint16x4_t l31 = vld1_u16(left1 + base3);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[3], iota1x4, base3, l30, l31,
                                       shifts0, shifts1, 3, 6);
      }
      transpose_array_inplace_u16_4x4(out);
      for (int r2 = 0; r2 < 4; ++r2) {
        vst1_u16(dst + r2 * stride + c, out[r2]);
      }
      y += 4 * dy;
      c += 4;
    } while (c < bw);
  } else {
    int y = dy;
    int c = 0;
    do {
      int r = 0;
      do {
        // Fully unroll the 4x8 block to allow us to use immediate lane-indexed
        // multiply instructions.
        const uint16x4_t shifts1 =
            vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask);
        const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(64), shifts1);
        const int base0 = ((y + 0 * dy) >> frac_bits) + r;
        const int base1 = ((y + 1 * dy) >> frac_bits) + r;
        const int base2 = ((y + 2 * dy) >> frac_bits) + r;
        const int base3 = ((y + 3 * dy) >> frac_bits) + r;
        uint16x8_t out[4];
        if (base0 >= max_base_y) {
          out[0] = vdupq_n_u16(left_max);
        } else {
          const uint16x8x2_t l0 = z3_load_left_neon(left, base0, max_base_y);
          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l0.val[0],
                                         l0.val[1], shifts0, shifts1, 0, 6);
        }
        if (base1 >= max_base_y) {
          out[1] = vdupq_n_u16(left_max);
        } else {
          const uint16x8x2_t l1 = z3_load_left_neon(left, base1, max_base_y);
          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l1.val[0],
                                         l1.val[1], shifts0, shifts1, 1, 6);
        }
        if (base2 >= max_base_y) {
          out[2] = vdupq_n_u16(left_max);
        } else {
          const uint16x8x2_t l2 = z3_load_left_neon(left, base2, max_base_y);
          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l2.val[0],
                                         l2.val[1], shifts0, shifts1, 2, 6);
        }
        if (base3 >= max_base_y) {
          out[3] = vdupq_n_u16(left_max);
        } else {
          const uint16x8x2_t l3 = z3_load_left_neon(left, base3, max_base_y);
          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l3.val[0],
                                         l3.val[1], shifts0, shifts1, 3, 6);
        }
        transpose_array_inplace_u16_4x8(out);
        for (int r2 = 0; r2 < 4; ++r2) {
          vst1_u16(dst + (r + r2) * stride + c, vget_low_u16(out[r2]));
        }
        for (int r2 = 0; r2 < 4; ++r2) {
          vst1_u16(dst + (r + r2 + 4) * stride + c, vget_high_u16(out[r2]));
        }
        r += 8;
      } while (r < bh);
      y += 4 * dy;
      c += 4;
    } while (c < bw);
  }
}

static void highbd_dr_prediction_z3_upsample1_neon(uint16_t *dst,
                                                   ptrdiff_t stride, int bw,
                                                   int bh, const uint16_t *left,
                                                   int dy) {
  assert(bw % 4 == 0);
  assert(bh % 4 == 0);
  assert(dy > 0);

  const int max_base_y = (bw + bh - 1) << 1;
  const int left_max = left[max_base_y];
  const int frac_bits = 5;

  const uint16x4_t iota1x4 = vreinterpret_u16_s16(vld1_s16(iota1_s16));
  const uint16x8_t iota2x8 = vreinterpretq_u16_s16(vld1q_s16(iota2_s16));
  const uint16x4_t iota2x4 = vget_low_u16(iota2x8);

  // The C implementation of the z3 predictor when upsampling uses:
  // (((x << 1) & 0x3f) >> 1)
  // The two shifts are unnecessary here since the lowest bit is guaranteed to
  // be zero when the mask is applied, so adjust the mask to 0x1f to avoid
  // needing the shifts at all.
  const uint16x4_t shift_mask = vdup_n_u16(0x1F);

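  // With upsample_left the `left` array holds twice as many samples (note
  // frac_bits is 5 and max_base_y is doubled), so vld2 can deinterleave each
  // load directly into the two interpolation inputs.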
  if (bh == 4) {
    int y = dy;
    int c = 0;
    do {
      // Fully unroll the 4x4 block to allow us to use immediate lane-indexed
      // multiply instructions.
      const uint16x4_t shifts1 =
          vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask);
      const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(32), shifts1);
      const int base0 = (y + 0 * dy) >> frac_bits;
      const int base1 = (y + 1 * dy) >> frac_bits;
      const int base2 = (y + 2 * dy) >> frac_bits;
      const int base3 = (y + 3 * dy) >> frac_bits;
      const uint16x4x2_t l0 = vld2_u16(left + base0);
      const uint16x4x2_t l1 = vld2_u16(left + base1);
      const uint16x4x2_t l2 = vld2_u16(left + base2);
      const uint16x4x2_t l3 = vld2_u16(left + base3);
      uint16x4_t out[4];
      HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[0], iota2x4, base0, l0.val[0],
                                     l0.val[1], shifts0, shifts1, 0, 5);
      HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[1], iota2x4, base1, l1.val[0],
                                     l1.val[1], shifts0, shifts1, 1, 5);
      HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[2], iota2x4, base2, l2.val[0],
                                     l2.val[1], shifts0, shifts1, 2, 5);
      HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[3], iota2x4, base3, l3.val[0],
                                     l3.val[1], shifts0, shifts1, 3, 5);
      transpose_array_inplace_u16_4x4(out);
      for (int r2 = 0; r2 < 4; ++r2) {
        vst1_u16(dst + r2 * stride + c, out[r2]);
      }
      y += 4 * dy;
      c += 4;
    } while (c < bw);
  } else {
    assert(bh % 8 == 0);

    int y = dy;
    int c = 0;
    do {
      int r = 0;
      do {
        // Fully unroll the 4x8 block to allow us to use immediate lane-indexed
        // multiply instructions.
        const uint16x4_t shifts1 =
            vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask);
        const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(32), shifts1);
        const int base0 = ((y + 0 * dy) >> frac_bits) + (r * 2);
        const int base1 = ((y + 1 * dy) >> frac_bits) + (r * 2);
        const int base2 = ((y + 2 * dy) >> frac_bits) + (r * 2);
        const int base3 = ((y + 3 * dy) >> frac_bits) + (r * 2);
        const uint16x8x2_t l0 = vld2q_u16(left + base0);
        const uint16x8x2_t l1 = vld2q_u16(left + base1);
        const uint16x8x2_t l2 = vld2q_u16(left + base2);
        const uint16x8x2_t l3 = vld2q_u16(left + base3);
        uint16x8_t out[4];
        HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota2x8, base0, l0.val[0],
                                       l0.val[1], shifts0, shifts1, 0, 5);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota2x8, base1, l1.val[0],
                                       l1.val[1], shifts0, shifts1, 1, 5);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota2x8, base2, l2.val[0],
                                       l2.val[1], shifts0, shifts1, 2, 5);
        HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota2x8, base3, l3.val[0],
                                       l3.val[1], shifts0, shifts1, 3, 5);
        transpose_array_inplace_u16_4x8(out);
        for (int r2 = 0; r2 < 4; ++r2) {
          vst1_u16(dst + (r + r2) * stride + c, vget_low_u16(out[r2]));
        }
        for (int r2 = 0; r2 < 4; ++r2) {
          vst1_u16(dst + (r + r2 + 4) * stride + c, vget_high_u16(out[r2]));
        }
        r += 8;
      } while (r < bh);
      y += 4 * dy;
      c += 4;
    } while (c < bw);
  }
}

// Directional prediction, zone 3: 180 < angle < 270
void av1_highbd_dr_prediction_z3_neon(uint16_t *dst, ptrdiff_t stride, int bw,
                                      int bh, const uint16_t *above,
                                      const uint16_t *left, int upsample_left,
                                      int dx, int dy, int bd) {
  (void)above;
  (void)dx;
  (void)bd;
  assert(bw % 4 == 0);
  assert(bh % 4 == 0);
  assert(dx == 1);
  assert(dy > 0);

  if (upsample_left) {
    highbd_dr_prediction_z3_upsample1_neon(dst, stride, bw, bh, left, dy);
  } else {
    highbd_dr_prediction_z3_upsample0_neon(dst, stride, bw, bh, left, dy);
  }
}

#undef HIGHBD_DR_PREDICTOR_Z3_STEP_X4
#undef HIGHBD_DR_PREDICTOR_Z3_STEP_X8