• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
4  *
5  * This source code is subject to the terms of the BSD 2 Clause License and
6  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7  * was not distributed with this source code in the LICENSE file, you can
8  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
9  * Media Patent License 1.0 was not distributed with this source code in the
10  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 
13 #include <arm_neon.h>
14 #include <assert.h>
15 #include <stdbool.h>
16 
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/blend.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_ports/mem.h"
21 #include "av1/common/blockd.h"
22 #include "config/av1_rtcd.h"
23 
// Computes, per lane, |a - b| shifted by round_shift and then shifted right by
// DIFF_FACTOR_LOG2, narrowed to 8 bits (bits above the low byte are dropped).
// round_shift holds the negated shift count: a negative count makes
// vrshlq_u16 perform a rounding right shift.
static AOM_INLINE uint8x8_t diffwtd_d16_scaled_absdiff(uint16x8_t a,
                                                       uint16x8_t b,
                                                       int16x8_t round_shift) {
  const uint16x8_t rounded = vrshlq_u16(vabdq_u16(a, b), round_shift);
  return vshrn_n_u16(rounded, DIFF_FACTOR_LOG2);
}

// Turns eight scaled differences into mask weights.
//   inverse == false: min(diff + 38, 64)
//   inverse == true : (64 - 38) - diff, saturating at 0
static AOM_INLINE uint8x8_t diffwtd_d16_weight_u8x8(uint8x8_t diff,
                                                    const bool inverse) {
  if (inverse) {
    return vqsub_u8(vdup_n_u8(64 - 38), diff);  // Saturating to 0
  }
  return vmin_u8(vadd_u8(diff, vdup_n_u8(38)), vdup_n_u8(64));
}

// Builds a difference-weighted mask from two compound (16-bit) prediction
// buffers.
//
// mask         Output; written contiguously, w bytes per row (no stride).
// inverse      Selects the inverted weight mapping (see
//              diffwtd_d16_weight_u8x8).
// src0, src1   Input rows, stepped by src0_stride / src1_stride elements.
// h, w         Block size. Widths of 16 or more are processed in chunks of
//              16, and w == 8 and w == 4 have dedicated paths; any other
//              width writes nothing. The w == 4 path consumes two rows per
//              iteration.
// conv_params  Supplies round_0 / round_1 used to derive the rounding shift.
// bd           Bit depth; contributes (bd - 8) to the rounding shift.
static AOM_INLINE void diffwtd_mask_d16_neon(
    uint8_t *mask, const bool inverse, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  const int16x8_t round_shift = vdupq_n_s16((int16_t)(-round_bits));

  if (w >= 16) {
    int row = 0;
    do {
      int col = 0;
      do {
        const uint8x8_t diff_lo = diffwtd_d16_scaled_absdiff(
            vld1q_u16(src0 + col), vld1q_u16(src1 + col), round_shift);
        const uint8x8_t diff_hi = diffwtd_d16_scaled_absdiff(
            vld1q_u16(src0 + col + 8), vld1q_u16(src1 + col + 8), round_shift);

        vst1q_u8(mask,
                 vcombine_u8(diffwtd_d16_weight_u8x8(diff_lo, inverse),
                             diffwtd_d16_weight_u8x8(diff_hi, inverse)));

        mask += 16;
        col += 16;
      } while (col < w);

      src0 += src0_stride;
      src1 += src1_stride;
      ++row;
    } while (row < h);
    return;
  }

  if (w == 8) {
    int row = 0;
    do {
      const uint8x8_t diff = diffwtd_d16_scaled_absdiff(
          vld1q_u16(src0), vld1q_u16(src1), round_shift);

      vst1_u8(mask, diffwtd_d16_weight_u8x8(diff, inverse));

      mask += 8;
      src0 += src0_stride;
      src1 += src1_stride;
      ++row;
    } while (row < h);
    return;
  }

  if (w == 4) {
    int row = 0;
    do {
      // Pack two consecutive 4-wide rows into one 8-lane vector.
      const uint16x8_t rows0 =
          vcombine_u16(vld1_u16(src0), vld1_u16(src0 + src0_stride));
      const uint16x8_t rows1 =
          vcombine_u16(vld1_u16(src1), vld1_u16(src1 + src1_stride));
      const uint8x8_t diff =
          diffwtd_d16_scaled_absdiff(rows0, rows1, round_shift);

      vst1_u8(mask, diffwtd_d16_weight_u8x8(diff, inverse));

      mask += 8;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      row += 2;
    } while (row < h);
  }
}
110 
// NEON entry point for building a difference-weighted compound mask from two
// 16-bit compound prediction buffers. Dispatches to diffwtd_mask_d16_neon with
// a literal `inverse` flag per call site: DIFFWTD_38 selects the regular
// mapping, anything else (expected to be DIFFWTD_38_INV) selects the inverted
// one.
void av1_build_compound_diffwtd_mask_d16_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  assert(h >= 4);
  assert(w >= 4);
  assert((mask_type == DIFFWTD_38_INV) || (mask_type == DIFFWTD_38));

  if (mask_type != DIFFWTD_38) {
    // mask_type == DIFFWTD_38_INV
    diffwtd_mask_d16_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                          src1_stride, h, w, conv_params, bd);
    return;
  }

  diffwtd_mask_d16_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                        src1_stride, h, w, conv_params, bd);
}
127 
// Maps 16 pairs of 8-bit pixels to mask weights. The per-lane difference is
// |s0 - s1| >> DIFF_FACTOR_LOG2, then:
//   inverse == false: min(diff + 38, 64)
//   inverse == true : (64 - 38) - diff, saturating at 0
static AOM_INLINE uint8x16_t diffwtd_weights_u8x16(uint8x16_t s0,
                                                   uint8x16_t s1,
                                                   const bool inverse) {
  const uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
  if (inverse) {
    return vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
  }
  return vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
}

// Builds a difference-weighted mask from two 8-bit prediction buffers.
//
// mask        Output; written contiguously, w bytes per row (no stride).
// inverse     Selects the inverted weight mapping (see diffwtd_weights_u8x16).
// src0, src1  Input rows, each stepped by its own stride (src0_stride /
//             src1_stride); the two strides are independent and may differ.
// h, w        Block size. Widths of 16 or more are processed in chunks of 16
//             and w == 8 and w == 4 have dedicated paths; any other width
//             writes nothing. Every path stores 16 mask bytes per iteration,
//             so w == 8 consumes two rows and w == 4 consumes four rows at a
//             time.
static AOM_INLINE void diffwtd_mask_neon(uint8_t *mask, const bool inverse,
                                         const uint8_t *src0, int src0_stride,
                                         const uint8_t *src1, int src1_stride,
                                         int h, int w) {
  if (w >= 16) {
    int i = 0;
    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src0 + j);
        uint8x16_t s1 = vld1q_u8(src1 + j);

        vst1q_u8(mask, diffwtd_weights_u8x16(s0, s1, inverse));

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    int i = 0;
    do {
      // Pack two consecutive 8-wide rows into one vector. The second row of
      // each source must be addressed with that source's own stride.
      uint8x16_t s0 = vcombine_u8(vld1_u8(src0), vld1_u8(src0 + src0_stride));
      uint8x16_t s1 = vcombine_u8(vld1_u8(src1), vld1_u8(src1 + src1_stride));

      vst1q_u8(mask, diffwtd_weights_u8x16(s0, s1, inverse));

      mask += 16;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  } else if (w == 4) {
    int i = 0;
    do {
      // NOTE(review): load_unaligned_u8q presumably gathers four 4-byte rows
      // spaced by the given stride into one vector - consistent with the
      // 4-row pointer step below; confirm against its definition.
      uint8x16_t s0 = load_unaligned_u8q(src0, src0_stride);
      uint8x16_t s1 = load_unaligned_u8q(src1, src1_stride);

      vst1q_u8(mask, diffwtd_weights_u8x16(s0, s1, inverse));

      mask += 16;
      src0 += 4 * src0_stride;
      src1 += 4 * src1_stride;
      i += 4;
    } while (i < h);
  }
}
200 
// NEON entry point for building a difference-weighted compound mask from two
// 8-bit prediction buffers. Dispatches to diffwtd_mask_neon with a literal
// `inverse` flag per call site: DIFFWTD_38 selects the regular mapping,
// anything else (expected to be DIFFWTD_38_INV) selects the inverted one.
void av1_build_compound_diffwtd_mask_neon(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int src0_stride,
                                          const uint8_t *src1, int src1_stride,
                                          int h, int w) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  if (mask_type != DIFFWTD_38) {
    // mask_type == DIFFWTD_38_INV
    diffwtd_mask_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                      src1_stride, h, w);
    return;
  }

  diffwtd_mask_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                    src1_stride, h, w);
}
218