/*
 *
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom/aom_integer.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/blockd.h"
#include "config/av1_rtcd.h"

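// Compute the DIFFWTD_38 blend mask from two CONV_BUF_TYPE prediction
// buffers that still carry the intermediate convolution precision. Each
// output value is clamp(38 + (|src0 - src1| >> DIFF_FACTOR_LOG2), 0, 64),
// with the absolute difference first rounded back down to the 8-bit pixel
// domain; when 'inverse' is set the mask is 64 minus that value,
// saturating at 0.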
static AOM_INLINE void diffwtd_mask_d16_neon(
    uint8_t *mask, const bool inverse, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
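  // Number of bits needed to bring the CONV_BUF_TYPE intermediates back to
  // the 8-bit pixel domain: undo the two convolution rounding stages plus
  // the extra precision carried for bit depths above 8. The shift is applied
  // below as a rounding right shift (vrshlq_u16 with a negated amount).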
  const int round =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  const int16x8_t round_vec = vdupq_n_s16((int16_t)(-round));

  if (w >= 16) {
    int i = 0;
    do {
      int j = 0;
      do {
        uint16x8_t s0_lo = vld1q_u16(src0 + j);
        uint16x8_t s1_lo = vld1q_u16(src1 + j);
        uint16x8_t s0_hi = vld1q_u16(src0 + j + 8);
        uint16x8_t s1_hi = vld1q_u16(src1 + j + 8);

        uint16x8_t diff_lo_u16 =
            vrshlq_u16(vabdq_u16(s0_lo, s1_lo), round_vec);
        uint16x8_t diff_hi_u16 =
            vrshlq_u16(vabdq_u16(s0_hi, s1_hi), round_vec);
        uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
        uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
        uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

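        // DIFFWTD_38: m = min(diff + 38, 64). The inverse mask is
        // 64 - min(diff + 38, 64), computed as the saturating subtract
        // max((64 - 38) - diff, 0).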
        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff); // Saturating to 0
        } else {
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    int i = 0;
    do {
      uint16x8_t s0 = vld1q_u16(src0);
      uint16x8_t s1 = vld1q_u16(src1);

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8); // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 4) {
    int i = 0;
    do {
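      // Too narrow for a full vector per row; process two rows of four
      // 16-bit values per iteration.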
      uint16x8_t s0 =
          vcombine_u16(vld1_u16(src0), vld1_u16(src0 + src0_stride));
      uint16x8_t s1 =
          vcombine_u16(vld1_u16(src1), vld1_u16(src1 + src1_stride));

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8); // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  }
}

void av1_build_compound_diffwtd_mask_d16_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  assert(h >= 4);
  assert(w >= 4);
  assert((mask_type == DIFFWTD_38_INV) || (mask_type == DIFFWTD_38));

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_d16_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                          src1_stride, h, w, conv_params, bd);
  } else { // mask_type == DIFFWTD_38_INV
    diffwtd_mask_d16_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                          src1_stride, h, w, conv_params, bd);
  }
}

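// Compute the DIFFWTD_38 blend mask directly from two 8-bit prediction
// buffers. Same formula as the d16 variant above, but no rounding shift is
// needed since the inputs are already in the pixel domain.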
static AOM_INLINE void diffwtd_mask_neon(uint8_t *mask, const bool inverse,
                                         const uint8_t *src0, int src0_stride,
                                         const uint8_t *src1, int src1_stride,
                                         int h, int w) {
  if (w >= 16) {
    int i = 0;
    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src0 + j);
        uint8x16_t s1 = vld1q_u8(src1 + j);

        uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff); // Saturating to 0
        } else {
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    int i = 0;
    do {
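      // Process two 8-pixel rows per iteration so full 16-byte vectors can
      // be used throughout.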
      uint8x16_t s0 = vcombine_u8(vld1_u8(src0), vld1_u8(src0 + src0_stride));
      uint8x16_t s1 = vcombine_u8(vld1_u8(src1), vld1_u8(src1 + src1_stride));

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff); // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  } else if (w == 4) {
    int i = 0;
    do {
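      // load_unaligned_u8q gathers four 4-byte rows into one 16-byte vector,
      // so four rows are masked per iteration.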
      uint8x16_t s0 = load_unaligned_u8q(src0, src0_stride);
      uint8x16_t s1 = load_unaligned_u8q(src1, src1_stride);

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff); // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 4 * src0_stride;
      src1 += 4 * src1_stride;
      i += 4;
    } while (i < h);
  }
}

void av1_build_compound_diffwtd_mask_neon(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int src0_stride,
                                          const uint8_t *src1, int src1_stride,
                                          int h, int w) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                      src1_stride, h, w);
  } else { // mask_type == DIFFWTD_38_INV
    diffwtd_mask_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                      src1_stride, h, w);
  }
}