• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 The WebM project authors. All rights reserved.
3  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
4  *
5  * This source code is subject to the terms of the BSD 2 Clause License and
6  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7  * was not distributed with this source code in the LICENSE file, you can
8  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
9  * Media Patent License 1.0 was not distributed with this source code in the
10  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 
13 #include <arm_neon.h>
14 #include <assert.h>
15 
16 #include "config/aom_config.h"
17 #include "config/aom_dsp_rtcd.h"
18 
19 #include "aom_dsp/arm/blend_neon.h"
20 #include "aom_dsp/arm/dist_wtd_avg_neon.h"
21 #include "aom_dsp/arm/mem_neon.h"
22 #include "aom_dsp/blend.h"
23 
aom_highbd_comp_avg_pred_neon(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride)24 void aom_highbd_comp_avg_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8,
25                                    int width, int height, const uint8_t *ref8,
26                                    int ref_stride) {
27   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
28   const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
29   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
30 
31   int i = height;
32   if (width > 8) {
33     do {
34       int j = 0;
35       do {
36         const uint16x8_t p = vld1q_u16(pred + j);
37         const uint16x8_t r = vld1q_u16(ref + j);
38 
39         uint16x8_t avg = vrhaddq_u16(p, r);
40         vst1q_u16(comp_pred + j, avg);
41 
42         j += 8;
43       } while (j < width);
44 
45       comp_pred += width;
46       pred += width;
47       ref += ref_stride;
48     } while (--i != 0);
49   } else if (width == 8) {
50     do {
51       const uint16x8_t p = vld1q_u16(pred);
52       const uint16x8_t r = vld1q_u16(ref);
53 
54       uint16x8_t avg = vrhaddq_u16(p, r);
55       vst1q_u16(comp_pred, avg);
56 
57       comp_pred += width;
58       pred += width;
59       ref += ref_stride;
60     } while (--i != 0);
61   } else {
62     assert(width == 4);
63     do {
64       const uint16x4_t p = vld1_u16(pred);
65       const uint16x4_t r = vld1_u16(ref);
66 
67       uint16x4_t avg = vrhadd_u16(p, r);
68       vst1_u16(comp_pred, avg);
69 
70       comp_pred += width;
71       pred += width;
72       ref += ref_stride;
73     } while (--i != 0);
74   }
75 }
76 
aom_highbd_comp_mask_pred_neon(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)77 void aom_highbd_comp_mask_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8,
78                                     int width, int height, const uint8_t *ref8,
79                                     int ref_stride, const uint8_t *mask,
80                                     int mask_stride, int invert_mask) {
81   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
82   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
83   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
84 
85   const uint16_t *src0 = invert_mask ? pred : ref;
86   const uint16_t *src1 = invert_mask ? ref : pred;
87   const int src_stride0 = invert_mask ? width : ref_stride;
88   const int src_stride1 = invert_mask ? ref_stride : width;
89 
90   if (width >= 8) {
91     do {
92       int j = 0;
93 
94       do {
95         const uint16x8_t s0 = vld1q_u16(src0 + j);
96         const uint16x8_t s1 = vld1q_u16(src1 + j);
97         const uint16x8_t m0 = vmovl_u8(vld1_u8(mask + j));
98 
99         uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, s0, s1);
100 
101         vst1q_u16(comp_pred + j, blend_u16);
102 
103         j += 8;
104       } while (j < width);
105 
106       src0 += src_stride0;
107       src1 += src_stride1;
108       mask += mask_stride;
109       comp_pred += width;
110     } while (--height != 0);
111   } else {
112     assert(width == 4);
113 
114     do {
115       const uint16x4_t s0 = vld1_u16(src0);
116       const uint16x4_t s1 = vld1_u16(src1);
117       const uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(mask)));
118 
119       uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, s0, s1);
120 
121       vst1_u16(comp_pred, blend_u16);
122 
123       src0 += src_stride0;
124       src1 += src_stride1;
125       mask += mask_stride;
126       comp_pred += 4;
127     } while (--height != 0);
128   }
129 }
130