/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/cfl.h"
// Load eight u16 values from src + offset, subtract sub, and store the
// result as s16 at dst + offset.
static inline void vldsubstq_s16(int16_t *dst, const uint16_t *src, int offset,
                                 int16x8_t sub) {
  vst1q_s16(dst + offset,
            vsubq_s16(vreinterpretq_s16_u16(vld1q_u16(src + offset)), sub));
}

// Load two vectors of u16, offset elements apart, and add them lane-wise.
static inline uint16x8_t vldaddq_u16(const uint16_t *buf, size_t offset) {
  return vaddq_u16(vld1q_u16(buf), vld1q_u16(buf + offset));
}

// Load half of a vector and duplicate it into the other half.
static inline uint8x8_t vldh_dup_u8(const uint8_t *ptr) {
  return vreinterpret_u8_u32(vld1_dup_u32((const uint32_t *)ptr));
}

// Store half of a vector.
static inline void vsth_u16(uint16_t *ptr, uint16x4_t val) {
  vst1_lane_u32((uint32_t *)ptr, vreinterpret_u32_u16(val), 0);
}

// Store half of a vector.
static inline void vsth_u8(uint8_t *ptr, uint8x8_t val) {
  vst1_lane_u32((uint32_t *)ptr, vreinterpret_u32_u8(val), 0);
}

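// 4:2:0 subsampling: each CfL prediction sample is the average of a 2x2 luma
// block. Summing the four pixels gives 4 * avg; the additional left shift by
// 1 stores avg << 3, matching the Q3 format of pred_buf_q3.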
static void cfl_luma_subsampling_420_lbd_neon(const uint8_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + (height >> 1) * CFL_BUF_LINE;
  const int luma_stride = input_stride << 1;
  do {
    if (width == 4) {
      const uint16x4_t top = vpaddl_u8(vldh_dup_u8(input));
      const uint16x4_t sum = vpadal_u8(top, vldh_dup_u8(input + input_stride));
      vsth_u16(pred_buf_q3, vshl_n_u16(sum, 1));
    } else if (width == 8) {
      const uint16x4_t top = vpaddl_u8(vld1_u8(input));
      const uint16x4_t sum = vpadal_u8(top, vld1_u8(input + input_stride));
      vst1_u16(pred_buf_q3, vshl_n_u16(sum, 1));
    } else if (width == 16) {
      const uint16x8_t top = vpaddlq_u8(vld1q_u8(input));
      const uint16x8_t sum = vpadalq_u8(top, vld1q_u8(input + input_stride));
      vst1q_u16(pred_buf_q3, vshlq_n_u16(sum, 1));
    } else {
      const uint8x8x4_t top = vld4_u8(input);
      const uint8x8x4_t bot = vld4_u8(input + input_stride);
      // equivalent to a vpaddlq_u8 (because vld4 de-interleaves)
      const uint16x8_t top_0 = vaddl_u8(top.val[0], top.val[1]);
      // equivalent to a vpaddlq_u8 (because vld4 de-interleaves)
      const uint16x8_t bot_0 = vaddl_u8(bot.val[0], bot.val[1]);
      // equivalent to a vpaddlq_u8 (because vld4 de-interleaves)
      const uint16x8_t top_1 = vaddl_u8(top.val[2], top.val[3]);
      // equivalent to a vpaddlq_u8 (because vld4 de-interleaves)
      const uint16x8_t bot_1 = vaddl_u8(bot.val[2], bot.val[3]);
      uint16x8x2_t sum;
      sum.val[0] = vshlq_n_u16(vaddq_u16(top_0, bot_0), 1);
      sum.val[1] = vshlq_n_u16(vaddq_u16(top_1, bot_1), 1);
      vst2q_u16(pred_buf_q3, sum);
    }
    input += luma_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}

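// 4:2:2 subsampling: each CfL prediction sample is the average of two
// horizontally adjacent luma pixels. The pairwise sum is 2 * avg; the left
// shift by 2 stores avg << 3.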
static void cfl_luma_subsampling_422_lbd_neon(const uint8_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + height * CFL_BUF_LINE;
  do {
    if (width == 4) {
      const uint16x4_t top = vpaddl_u8(vldh_dup_u8(input));
      vsth_u16(pred_buf_q3, vshl_n_u16(top, 2));
    } else if (width == 8) {
      const uint16x4_t top = vpaddl_u8(vld1_u8(input));
      vst1_u16(pred_buf_q3, vshl_n_u16(top, 2));
    } else if (width == 16) {
      const uint16x8_t top = vpaddlq_u8(vld1q_u8(input));
      vst1q_u16(pred_buf_q3, vshlq_n_u16(top, 2));
    } else {
      const uint8x8x4_t top = vld4_u8(input);
      uint16x8x2_t sum;
      // vaddl_u8 is equivalent to a vpaddlq_u8 (because vld4 de-interleaves)
      sum.val[0] = vshlq_n_u16(vaddl_u8(top.val[0], top.val[1]), 2);
      sum.val[1] = vshlq_n_u16(vaddl_u8(top.val[2], top.val[3]), 2);
      vst2q_u16(pred_buf_q3, sum);
    }
    input += input_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}

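// 4:4:4: no subsampling; each luma pixel is promoted to Q3 by a left shift
// of 3.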
static void cfl_luma_subsampling_444_lbd_neon(const uint8_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + height * CFL_BUF_LINE;
  do {
    if (width == 4) {
      const uint16x8_t top = vshll_n_u8(vldh_dup_u8(input), 3);
      vst1_u16(pred_buf_q3, vget_low_u16(top));
    } else if (width == 8) {
      const uint16x8_t top = vshll_n_u8(vld1_u8(input), 3);
      vst1q_u16(pred_buf_q3, top);
    } else {
      const uint8x16_t top = vld1q_u8(input);
      vst1q_u16(pred_buf_q3, vshll_n_u8(vget_low_u8(top), 3));
      vst1q_u16(pred_buf_q3 + 8, vshll_n_u8(vget_high_u8(top), 3));
      if (width == 32) {
        const uint8x16_t next_top = vld1q_u8(input + 16);
        vst1q_u16(pred_buf_q3 + 16, vshll_n_u8(vget_low_u8(next_top), 3));
        vst1q_u16(pred_buf_q3 + 24, vshll_n_u8(vget_high_u8(next_top), 3));
      }
    }
    input += input_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}

#if CONFIG_AV1_HIGHBITDEPTH
#if !AOM_ARCH_AARCH64
// vpaddq_u16 is an AArch64-only intrinsic; emulate it on AArch32 with two
// 64-bit pairwise additions.
static uint16x8_t vpaddq_u16(uint16x8_t a, uint16x8_t b) {
  return vcombine_u16(vpadd_u16(vget_low_u16(a), vget_high_u16(a)),
                      vpadd_u16(vget_low_u16(b), vget_high_u16(b)));
}
#endif

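// The high-bitdepth subsampling variants below apply the same Q3 arithmetic
// as the low-bitdepth ones, operating on 16-bit instead of 8-bit pixels.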
static void cfl_luma_subsampling_420_hbd_neon(const uint16_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + (height >> 1) * CFL_BUF_LINE;
  const int luma_stride = input_stride << 1;
  do {
    if (width == 4) {
      const uint16x4_t top = vld1_u16(input);
      const uint16x4_t bot = vld1_u16(input + input_stride);
      const uint16x4_t sum = vadd_u16(top, bot);
      const uint16x4_t hsum = vpadd_u16(sum, sum);
      vsth_u16(pred_buf_q3, vshl_n_u16(hsum, 1));
    } else if (width < 32) {
      const uint16x8_t top = vld1q_u16(input);
      const uint16x8_t bot = vld1q_u16(input + input_stride);
      const uint16x8_t sum = vaddq_u16(top, bot);
      if (width == 8) {
        const uint16x4_t hsum = vget_low_u16(vpaddq_u16(sum, sum));
        vst1_u16(pred_buf_q3, vshl_n_u16(hsum, 1));
      } else {
        const uint16x8_t top_1 = vld1q_u16(input + 8);
        const uint16x8_t bot_1 = vld1q_u16(input + 8 + input_stride);
        const uint16x8_t sum_1 = vaddq_u16(top_1, bot_1);
        const uint16x8_t hsum = vpaddq_u16(sum, sum_1);
        vst1q_u16(pred_buf_q3, vshlq_n_u16(hsum, 1));
      }
    } else {
      const uint16x8x4_t top = vld4q_u16(input);
      const uint16x8x4_t bot = vld4q_u16(input + input_stride);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t top_0 = vaddq_u16(top.val[0], top.val[1]);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t bot_0 = vaddq_u16(bot.val[0], bot.val[1]);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t top_1 = vaddq_u16(top.val[2], top.val[3]);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t bot_1 = vaddq_u16(bot.val[2], bot.val[3]);
      uint16x8x2_t sum;
      sum.val[0] = vshlq_n_u16(vaddq_u16(top_0, bot_0), 1);
      sum.val[1] = vshlq_n_u16(vaddq_u16(top_1, bot_1), 1);
      vst2q_u16(pred_buf_q3, sum);
    }
    input += luma_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}

static void cfl_luma_subsampling_422_hbd_neon(const uint16_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + height * CFL_BUF_LINE;
  do {
    if (width == 4) {
      const uint16x4_t top = vld1_u16(input);
      const uint16x4_t hsum = vpadd_u16(top, top);
      vsth_u16(pred_buf_q3, vshl_n_u16(hsum, 2));
    } else if (width == 8) {
      const uint16x4x2_t top = vld2_u16(input);
      // equivalent to a vpadd_u16 (because vld2 de-interleaves)
      const uint16x4_t hsum = vadd_u16(top.val[0], top.val[1]);
      vst1_u16(pred_buf_q3, vshl_n_u16(hsum, 2));
    } else if (width == 16) {
      const uint16x8x2_t top = vld2q_u16(input);
      // equivalent to a vpaddq_u16 (because vld2q de-interleaves)
      const uint16x8_t hsum = vaddq_u16(top.val[0], top.val[1]);
      vst1q_u16(pred_buf_q3, vshlq_n_u16(hsum, 2));
    } else {
      const uint16x8x4_t top = vld4q_u16(input);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t hsum_0 = vaddq_u16(top.val[0], top.val[1]);
      // equivalent to a vpaddq_u16 (because vld4q de-interleaves)
      const uint16x8_t hsum_1 = vaddq_u16(top.val[2], top.val[3]);
      uint16x8x2_t result = { { vshlq_n_u16(hsum_0, 2),
                                vshlq_n_u16(hsum_1, 2) } };
      vst2q_u16(pred_buf_q3, result);
    }
    input += input_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}

static void cfl_luma_subsampling_444_hbd_neon(const uint16_t *input,
                                              int input_stride,
                                              uint16_t *pred_buf_q3, int width,
                                              int height) {
  const uint16_t *end = pred_buf_q3 + height * CFL_BUF_LINE;
  do {
    if (width == 4) {
      const uint16x4_t top = vld1_u16(input);
      vst1_u16(pred_buf_q3, vshl_n_u16(top, 3));
    } else if (width == 8) {
      const uint16x8_t top = vld1q_u16(input);
      vst1q_u16(pred_buf_q3, vshlq_n_u16(top, 3));
    } else if (width == 16) {
      uint16x8x2_t top = vld2q_u16(input);
      top.val[0] = vshlq_n_u16(top.val[0], 3);
      top.val[1] = vshlq_n_u16(top.val[1], 3);
      vst2q_u16(pred_buf_q3, top);
    } else {
      uint16x8x4_t top = vld4q_u16(input);
      top.val[0] = vshlq_n_u16(top.val[0], 3);
      top.val[1] = vshlq_n_u16(top.val[1], 3);
      top.val[2] = vshlq_n_u16(top.val[2], 3);
      top.val[3] = vshlq_n_u16(top.val[3], 3);
      vst4q_u16(pred_buf_q3, top);
    }
    input += input_stride;
  } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

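// CFL_GET_SUBSAMPLE_FUNCTION (see av1/common/cfl.h) instantiates the
// per-transform-size wrappers around the subsampling kernels above.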
CFL_GET_SUBSAMPLE_FUNCTION(neon)

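// Subtract the block average from every Q3 sample in the prediction buffer,
// leaving the zero-mean "AC" contribution used by the predictors below.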
static inline void subtract_average_neon(const uint16_t *src, int16_t *dst,
                                         int width, int height,
                                         int round_offset,
                                         const int num_pel_log2) {
  const uint16_t *const end = src + height * CFL_BUF_LINE;

  // round_offset is not needed because the rounding is performed by the
  // vqrshrn shift below.
  (void)round_offset;

  // To optimize the use of the CPU pipeline, we process 4 rows per iteration.
  const int step = 4 * CFL_BUF_LINE;

  // At this stage, the prediction buffer contains scaled reconstructed luma
  // pixels, which are positive integers requiring at most 15 bits. By using
  // an unsigned integer for the sum, we can do one addition in 16 bits
  // (8 lanes) before having to widen to 32 bits (4 lanes).
  const uint16_t *sum_buf = src;
  uint32x4_t sum_32x4 = vdupq_n_u32(0);
  do {
    // For all widths, we load, add and combine the data so it fits in 4 lanes.
    if (width == 4) {
      const uint16x4_t a0 =
          vadd_u16(vld1_u16(sum_buf), vld1_u16(sum_buf + CFL_BUF_LINE));
      const uint16x4_t a1 = vadd_u16(vld1_u16(sum_buf + 2 * CFL_BUF_LINE),
                                     vld1_u16(sum_buf + 3 * CFL_BUF_LINE));
      sum_32x4 = vaddq_u32(sum_32x4, vaddl_u16(a0, a1));
    } else if (width == 8) {
      const uint16x8_t a0 = vldaddq_u16(sum_buf, CFL_BUF_LINE);
      const uint16x8_t a1 =
          vldaddq_u16(sum_buf + 2 * CFL_BUF_LINE, CFL_BUF_LINE);
      sum_32x4 = vpadalq_u16(sum_32x4, a0);
      sum_32x4 = vpadalq_u16(sum_32x4, a1);
    } else {
      const uint16x8_t row0 = vldaddq_u16(sum_buf, 8);
      const uint16x8_t row1 = vldaddq_u16(sum_buf + CFL_BUF_LINE, 8);
      const uint16x8_t row2 = vldaddq_u16(sum_buf + 2 * CFL_BUF_LINE, 8);
      const uint16x8_t row3 = vldaddq_u16(sum_buf + 3 * CFL_BUF_LINE, 8);
      sum_32x4 = vpadalq_u16(sum_32x4, row0);
      sum_32x4 = vpadalq_u16(sum_32x4, row1);
      sum_32x4 = vpadalq_u16(sum_32x4, row2);
      sum_32x4 = vpadalq_u16(sum_32x4, row3);

      if (width == 32) {
        const uint16x8_t row0_1 = vldaddq_u16(sum_buf + 16, 8);
        const uint16x8_t row1_1 = vldaddq_u16(sum_buf + CFL_BUF_LINE + 16, 8);
        const uint16x8_t row2_1 =
            vldaddq_u16(sum_buf + 2 * CFL_BUF_LINE + 16, 8);
        const uint16x8_t row3_1 =
            vldaddq_u16(sum_buf + 3 * CFL_BUF_LINE + 16, 8);

        sum_32x4 = vpadalq_u16(sum_32x4, row0_1);
        sum_32x4 = vpadalq_u16(sum_32x4, row1_1);
        sum_32x4 = vpadalq_u16(sum_32x4, row2_1);
        sum_32x4 = vpadalq_u16(sum_32x4, row3_1);
      }
    }
    sum_buf += step;
  } while (sum_buf < end);

  // Permute and add in such a way that each lane contains the block sum.
  // [A+C+B+D, B+D+A+C, C+A+D+B, D+B+C+A]
#if AOM_ARCH_AARCH64
  sum_32x4 = vpaddq_u32(sum_32x4, sum_32x4);
  sum_32x4 = vpaddq_u32(sum_32x4, sum_32x4);
#else
  uint32x4_t flip =
      vcombine_u32(vget_high_u32(sum_32x4), vget_low_u32(sum_32x4));
  sum_32x4 = vaddq_u32(sum_32x4, flip);
  sum_32x4 = vaddq_u32(sum_32x4, vrev64q_u32(sum_32x4));
#endif

  // Computing the average could be done using scalars, but getting off the NEON
  // engine introduces latency, so we use vqrshrn.
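  // vqrshrn_n_u32 computes (sum + (1 << (n - 1))) >> n with narrowing
  // unsigned saturation, i.e. the rounded average of a block of 1 << n
  // pixels.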
  int16x4_t avg_16x4;
  // The shift amount of vqrshrn_n_u32 must be a compile-time constant, hence
  // the switch on num_pel_log2.
  switch (num_pel_log2) {
    case 4: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 4)); break;
    case 5: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 5)); break;
    case 6: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 6)); break;
    case 7: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 7)); break;
    case 8: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 8)); break;
    case 9: avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 9)); break;
    case 10:
      avg_16x4 = vreinterpret_s16_u16(vqrshrn_n_u32(sum_32x4, 10));
      break;
    default: assert(0);
  }

  if (width == 4) {
    do {
      vst1_s16(dst, vsub_s16(vreinterpret_s16_u16(vld1_u16(src)), avg_16x4));
      src += CFL_BUF_LINE;
      dst += CFL_BUF_LINE;
    } while (src < end);
  } else {
    const int16x8_t avg_16x8 = vcombine_s16(avg_16x4, avg_16x4);
    do {
      vldsubstq_s16(dst, src, 0, avg_16x8);
      vldsubstq_s16(dst, src, CFL_BUF_LINE, avg_16x8);
      vldsubstq_s16(dst, src, 2 * CFL_BUF_LINE, avg_16x8);
      vldsubstq_s16(dst, src, 3 * CFL_BUF_LINE, avg_16x8);

      if (width > 8) {
        vldsubstq_s16(dst, src, 8, avg_16x8);
        vldsubstq_s16(dst, src, 8 + CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 8 + 2 * CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 8 + 3 * CFL_BUF_LINE, avg_16x8);
      }
      if (width == 32) {
        vldsubstq_s16(dst, src, 16, avg_16x8);
        vldsubstq_s16(dst, src, 16 + CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 16 + 2 * CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 16 + 3 * CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 24, avg_16x8);
        vldsubstq_s16(dst, src, 24 + CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 24 + 2 * CFL_BUF_LINE, avg_16x8);
        vldsubstq_s16(dst, src, 24 + 3 * CFL_BUF_LINE, avg_16x8);
      }
      src += step;
      dst += step;
    } while (src < end);
  }
}

CFL_SUB_AVG_FN(neon)

// Negate 16-bit integers in a when the corresponding signed 16-bit integer
// in b is negative.
// Notes:
//   * Negating INT16_MIN results in INT16_MIN. However, this cannot occur in
//   practice, as scaled_luma is the multiplication of two absolute values.
//   * In the Intel equivalent (_mm_sign_epi16), elements in a are zeroed out
//   when the corresponding elements in b are zero. Because vsign is used
//   twice in a row, with b in the first call becoming a in the second call,
//   there's no impact from not zeroing out.
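// Implementation note: mask = b >> 15 (arithmetic shift) is all ones when b
// is negative and zero otherwise, so (a + mask) ^ mask computes
// ~(a - 1) = -a for negative b and leaves a unchanged for non-negative b.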
static int16x4_t vsign_s16(int16x4_t a, int16x4_t b) {
  const int16x4_t mask = vshr_n_s16(b, 15);
  return veor_s16(vadd_s16(a, mask), mask);
}

// Negate 16-bit integers in a when the corresponding signed 16-bit integer
// in b is negative.
// Notes:
//   * Negating INT16_MIN results in INT16_MIN. However, this cannot occur in
//   practice, as scaled_luma is the multiplication of two absolute values.
//   * In the Intel equivalent (_mm_sign_epi16), elements in a are zeroed out
//   when the corresponding elements in b are zero. Because vsignq is used
//   twice in a row, with b in the first call becoming a in the second call,
//   there's no impact from not zeroing out.
static int16x8_t vsignq_s16(int16x8_t a, int16x8_t b) {
  const int16x8_t mask = vshrq_n_s16(b, 15);
  return veorq_s16(vaddq_s16(a, mask), mask);
}

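// In the predict_w* helpers below, vqrdmulh computes
// (2 * a * b + (1 << 15)) >> 16 with rounding and saturation, so multiplying
// ac_q3 (Q3) by abs_alpha_q12 (Q12) yields a rounded Q0 result
// (Q3 * Q12 >> 15). The multiplication uses absolute values; the sign is
// restored afterwards with vsign_s16/vsignq_s16.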
static inline int16x4_t predict_w4(const int16_t *pred_buf_q3,
                                   int16x4_t alpha_sign, int abs_alpha_q12,
                                   int16x4_t dc) {
  const int16x4_t ac_q3 = vld1_s16(pred_buf_q3);
  const int16x4_t ac_sign = veor_s16(alpha_sign, ac_q3);
  int16x4_t scaled_luma = vqrdmulh_n_s16(vabs_s16(ac_q3), abs_alpha_q12);
  return vadd_s16(vsign_s16(scaled_luma, ac_sign), dc);
}

static inline int16x8_t predict_w8(const int16_t *pred_buf_q3,
                                   int16x8_t alpha_sign, int abs_alpha_q12,
                                   int16x8_t dc) {
  const int16x8_t ac_q3 = vld1q_s16(pred_buf_q3);
  const int16x8_t ac_sign = veorq_s16(alpha_sign, ac_q3);
  int16x8_t scaled_luma = vqrdmulhq_n_s16(vabsq_s16(ac_q3), abs_alpha_q12);
  return vaddq_s16(vsignq_s16(scaled_luma, ac_sign), dc);
}

static inline int16x8x2_t predict_w16(const int16_t *pred_buf_q3,
                                      int16x8_t alpha_sign, int abs_alpha_q12,
                                      int16x8_t dc) {
  const int16x8x2_t ac_q3 = vld1q_s16_x2(pred_buf_q3);
  const int16x8_t ac_sign_0 = veorq_s16(alpha_sign, ac_q3.val[0]);
  const int16x8_t ac_sign_1 = veorq_s16(alpha_sign, ac_q3.val[1]);
  const int16x8_t scaled_luma_0 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[0]), abs_alpha_q12);
  const int16x8_t scaled_luma_1 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[1]), abs_alpha_q12);
  int16x8x2_t result;
  result.val[0] = vaddq_s16(vsignq_s16(scaled_luma_0, ac_sign_0), dc);
  result.val[1] = vaddq_s16(vsignq_s16(scaled_luma_1, ac_sign_1), dc);
  return result;
}

static inline int16x8x4_t predict_w32(const int16_t *pred_buf_q3,
                                      int16x8_t alpha_sign, int abs_alpha_q12,
                                      int16x8_t dc) {
  const int16x8x4_t ac_q3 = vld1q_s16_x4(pred_buf_q3);
  const int16x8_t ac_sign_0 = veorq_s16(alpha_sign, ac_q3.val[0]);
  const int16x8_t ac_sign_1 = veorq_s16(alpha_sign, ac_q3.val[1]);
  const int16x8_t ac_sign_2 = veorq_s16(alpha_sign, ac_q3.val[2]);
  const int16x8_t ac_sign_3 = veorq_s16(alpha_sign, ac_q3.val[3]);
  const int16x8_t scaled_luma_0 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[0]), abs_alpha_q12);
  const int16x8_t scaled_luma_1 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[1]), abs_alpha_q12);
  const int16x8_t scaled_luma_2 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[2]), abs_alpha_q12);
  const int16x8_t scaled_luma_3 =
      vqrdmulhq_n_s16(vabsq_s16(ac_q3.val[3]), abs_alpha_q12);
  int16x8x4_t result;
  result.val[0] = vaddq_s16(vsignq_s16(scaled_luma_0, ac_sign_0), dc);
  result.val[1] = vaddq_s16(vsignq_s16(scaled_luma_1, ac_sign_1), dc);
  result.val[2] = vaddq_s16(vsignq_s16(scaled_luma_2, ac_sign_2), dc);
  result.val[3] = vaddq_s16(vsignq_s16(scaled_luma_3, ac_sign_3), dc);
  return result;
}

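// dc is loaded from the first destination pixel: the caller is expected to
// have filled dst with the (constant) DC prediction before the CfL
// adjustment is applied.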
static inline void cfl_predict_lbd_neon(const int16_t *pred_buf_q3,
                                        uint8_t *dst, int dst_stride,
                                        int alpha_q3, int width, int height) {
  const int16_t abs_alpha_q12 = abs(alpha_q3) << 9;
  const int16_t *const end = pred_buf_q3 + height * CFL_BUF_LINE;
  if (width == 4) {
    const int16x4_t alpha_sign = vdup_n_s16(alpha_q3);
    const int16x4_t dc = vdup_n_s16(*dst);
    do {
      const int16x4_t pred =
          predict_w4(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
      vsth_u8(dst, vqmovun_s16(vcombine_s16(pred, pred)));
      dst += dst_stride;
    } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
  } else {
    const int16x8_t alpha_sign = vdupq_n_s16(alpha_q3);
    const int16x8_t dc = vdupq_n_s16(*dst);
    do {
      if (width == 8) {
        vst1_u8(dst, vqmovun_s16(predict_w8(pred_buf_q3, alpha_sign,
                                            abs_alpha_q12, dc)));
      } else if (width == 16) {
        const int16x8x2_t pred =
            predict_w16(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
        const uint8x8x2_t predun = { { vqmovun_s16(pred.val[0]),
                                       vqmovun_s16(pred.val[1]) } };
        vst1_u8_x2(dst, predun);
      } else {
        const int16x8x4_t pred =
            predict_w32(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
        const uint8x8x4_t predun = {
          { vqmovun_s16(pred.val[0]), vqmovun_s16(pred.val[1]),
            vqmovun_s16(pred.val[2]), vqmovun_s16(pred.val[3]) }
        };
        vst1_u8_x4(dst, predun);
      }
      dst += dst_stride;
    } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
  }
}

CFL_PREDICT_FN(neon, lbd)

#if CONFIG_AV1_HIGHBITDEPTH
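// Helpers to clamp the signed prediction to the valid pixel range
// [0, (1 << bd) - 1] and reinterpret it as unsigned for storing.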
static inline uint16x4_t clamp_s16(int16x4_t a, int16x4_t max) {
  return vreinterpret_u16_s16(vmax_s16(vmin_s16(a, max), vdup_n_s16(0)));
}

static inline uint16x8_t clampq_s16(int16x8_t a, int16x8_t max) {
  return vreinterpretq_u16_s16(vmaxq_s16(vminq_s16(a, max), vdupq_n_s16(0)));
}

static inline uint16x8x2_t clamp2q_s16(int16x8x2_t a, int16x8_t max) {
  uint16x8x2_t result;
  result.val[0] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[0], max), vdupq_n_s16(0)));
  result.val[1] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[1], max), vdupq_n_s16(0)));
  return result;
}

static inline uint16x8x4_t clamp4q_s16(int16x8x4_t a, int16x8_t max) {
  uint16x8x4_t result;
  result.val[0] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[0], max), vdupq_n_s16(0)));
  result.val[1] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[1], max), vdupq_n_s16(0)));
  result.val[2] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[2], max), vdupq_n_s16(0)));
  result.val[3] = vreinterpretq_u16_s16(
      vmaxq_s16(vminq_s16(a.val[3], max), vdupq_n_s16(0)));
  return result;
}

static inline void cfl_predict_hbd_neon(const int16_t *pred_buf_q3,
                                        uint16_t *dst, int dst_stride,
                                        int alpha_q3, int bd, int width,
                                        int height) {
  const int max = (1 << bd) - 1;
  const int16_t abs_alpha_q12 = abs(alpha_q3) << 9;
  const int16_t *const end = pred_buf_q3 + height * CFL_BUF_LINE;
  if (width == 4) {
    const int16x4_t alpha_sign = vdup_n_s16(alpha_q3);
    const int16x4_t dc = vdup_n_s16(*dst);
    const int16x4_t max_16x4 = vdup_n_s16(max);
    do {
      const int16x4_t scaled_luma =
          predict_w4(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
      vst1_u16(dst, clamp_s16(scaled_luma, max_16x4));
      dst += dst_stride;
    } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
  } else {
    const int16x8_t alpha_sign = vdupq_n_s16(alpha_q3);
    const int16x8_t dc = vdupq_n_s16(*dst);
    const int16x8_t max_16x8 = vdupq_n_s16(max);
    do {
      if (width == 8) {
        const int16x8_t pred =
            predict_w8(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
        vst1q_u16(dst, clampq_s16(pred, max_16x8));
      } else if (width == 16) {
        const int16x8x2_t pred =
            predict_w16(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
        vst1q_u16_x2(dst, clamp2q_s16(pred, max_16x8));
      } else {
        const int16x8x4_t pred =
            predict_w32(pred_buf_q3, alpha_sign, abs_alpha_q12, dc);
        vst1q_u16_x4(dst, clamp4q_s16(pred, max_16x8));
      }
      dst += dst_stride;
    } while ((pred_buf_q3 += CFL_BUF_LINE) < end);
  }
}

CFL_PREDICT_FN(neon, hbd)
#endif  // CONFIG_AV1_HIGHBITDEPTH