/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <arm_neon.h>
#include <memory.h>
#include <math.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"
#include "av1/common/warped_motion.h"
#include "av1/common/scale.h"

/* This is a modified version of 'warped_filter' from warped_motion.c:
   * Each coefficient is stored in 8 bits instead of 16 bits
   * The coefficients are rearranged in the column order 0, 2, 4, 6, 1, 3, 5, 7

     This is done in order to avoid overflow: Since the tap with the largest
     coefficient could be any of taps 2, 3, 4 or 5, we can't use the summation
     order ((0 + 1) + (4 + 5)) + ((2 + 3) + (6 + 7)) used in the regular
     convolve functions.

     Instead, we use the summation order
     ((0 + 2) + (4 + 6)) + ((1 + 3) + (5 + 7)).
     The rearrangement of coefficients in this table is so that we can get the
     coefficients into the correct order more quickly.
*/
/* clang-format off */
DECLARE_ALIGNED(8, static const int8_t,
                filter_8bit_neon[WARPEDPIXEL_PREC_SHIFTS * 3 + 1][8]) = {
#if WARPEDPIXEL_PREC_BITS == 6
  // [-1, 0)
  { 0, 127, 0, 0, 0, 1, 0, 0}, { 0, 127, 0, 0, -1, 2, 0, 0},
  { 1, 127, -1, 0, -3, 4, 0, 0}, { 1, 126, -2, 0, -4, 6, 1, 0},
  { 1, 126, -3, 0, -5, 8, 1, 0}, { 1, 125, -4, 0, -6, 11, 1, 0},
  { 1, 124, -4, 0, -7, 13, 1, 0}, { 2, 123, -5, 0, -8, 15, 1, 0},
  { 2, 122, -6, 0, -9, 18, 1, 0}, { 2, 121, -6, 0, -10, 20, 1, 0},
  { 2, 120, -7, 0, -11, 22, 2, 0}, { 2, 119, -8, 0, -12, 25, 2, 0},
  { 3, 117, -8, 0, -13, 27, 2, 0}, { 3, 116, -9, 0, -13, 29, 2, 0},
  { 3, 114, -10, 0, -14, 32, 3, 0}, { 3, 113, -10, 0, -15, 35, 2, 0},
  { 3, 111, -11, 0, -15, 37, 3, 0}, { 3, 109, -11, 0, -16, 40, 3, 0},
  { 3, 108, -12, 0, -16, 42, 3, 0}, { 4, 106, -13, 0, -17, 45, 3, 0},
  { 4, 104, -13, 0, -17, 47, 3, 0}, { 4, 102, -14, 0, -17, 50, 3, 0},
  { 4, 100, -14, 0, -17, 52, 3, 0}, { 4, 98, -15, 0, -18, 55, 4, 0},
  { 4, 96, -15, 0, -18, 58, 3, 0}, { 4, 94, -16, 0, -18, 60, 4, 0},
  { 4, 91, -16, 0, -18, 63, 4, 0}, { 4, 89, -16, 0, -18, 65, 4, 0},
  { 4, 87, -17, 0, -18, 68, 4, 0}, { 4, 85, -17, 0, -18, 70, 4, 0},
  { 4, 82, -17, 0, -18, 73, 4, 0}, { 4, 80, -17, 0, -18, 75, 4, 0},
  { 4, 78, -18, 0, -18, 78, 4, 0}, { 4, 75, -18, 0, -17, 80, 4, 0},
  { 4, 73, -18, 0, -17, 82, 4, 0}, { 4, 70, -18, 0, -17, 85, 4, 0},
  { 4, 68, -18, 0, -17, 87, 4, 0}, { 4, 65, -18, 0, -16, 89, 4, 0},
  { 4, 63, -18, 0, -16, 91, 4, 0}, { 4, 60, -18, 0, -16, 94, 4, 0},
  { 3, 58, -18, 0, -15, 96, 4, 0}, { 4, 55, -18, 0, -15, 98, 4, 0},
  { 3, 52, -17, 0, -14, 100, 4, 0}, { 3, 50, -17, 0, -14, 102, 4, 0},
  { 3, 47, -17, 0, -13, 104, 4, 0}, { 3, 45, -17, 0, -13, 106, 4, 0},
  { 3, 42, -16, 0, -12, 108, 3, 0}, { 3, 40, -16, 0, -11, 109, 3, 0},
  { 3, 37, -15, 0, -11, 111, 3, 0}, { 2, 35, -15, 0, -10, 113, 3, 0},
  { 3, 32, -14, 0, -10, 114, 3, 0}, { 2, 29, -13, 0, -9, 116, 3, 0},
  { 2, 27, -13, 0, -8, 117, 3, 0}, { 2, 25, -12, 0, -8, 119, 2, 0},
  { 2, 22, -11, 0, -7, 120, 2, 0}, { 1, 20, -10, 0, -6, 121, 2, 0},
  { 1, 18, -9, 0, -6, 122, 2, 0}, { 1, 15, -8, 0, -5, 123, 2, 0},
  { 1, 13, -7, 0, -4, 124, 1, 0}, { 1, 11, -6, 0, -4, 125, 1, 0},
  { 1, 8, -5, 0, -3, 126, 1, 0}, { 1, 6, -4, 0, -2, 126, 1, 0},
  { 0, 4, -3, 0, -1, 127, 1, 0}, { 0, 2, -1, 0, 0, 127, 0, 0},
  // [0, 1)
  { 0, 0, 1, 0, 0, 127, 0, 0}, { 0, -1, 2, 0, 0, 127, 0, 0},
  { 0, -3, 4, 1, 1, 127, -2, 0}, { 0, -5, 6, 1, 1, 127, -2, 0},
  { 0, -6, 8, 1, 2, 126, -3, 0}, {-1, -7, 11, 2, 2, 126, -4, -1},
  {-1, -8, 13, 2, 3, 125, -5, -1}, {-1, -10, 16, 3, 3, 124, -6, -1},
  {-1, -11, 18, 3, 4, 123, -7, -1}, {-1, -12, 20, 3, 4, 122, -7, -1},
  {-1, -13, 23, 3, 4, 121, -8, -1}, {-2, -14, 25, 4, 5, 120, -9, -1},
  {-1, -15, 27, 4, 5, 119, -10, -1}, {-1, -16, 30, 4, 5, 118, -11, -1},
  {-2, -17, 33, 5, 6, 116, -12, -1}, {-2, -17, 35, 5, 6, 114, -12, -1},
  {-2, -18, 38, 5, 6, 113, -13, -1}, {-2, -19, 41, 6, 7, 111, -14, -2},
  {-2, -19, 43, 6, 7, 110, -15, -2}, {-2, -20, 46, 6, 7, 108, -15, -2},
  {-2, -20, 49, 6, 7, 106, -16, -2}, {-2, -21, 51, 7, 7, 104, -16, -2},
  {-2, -21, 54, 7, 7, 102, -17, -2}, {-2, -21, 56, 7, 8, 100, -18, -2},
  {-2, -22, 59, 7, 8, 98, -18, -2}, {-2, -22, 62, 7, 8, 96, -19, -2},
  {-2, -22, 64, 7, 8, 94, -19, -2}, {-2, -22, 67, 8, 8, 91, -20, -2},
  {-2, -22, 69, 8, 8, 89, -20, -2}, {-2, -22, 72, 8, 8, 87, -21, -2},
  {-2, -21, 74, 8, 8, 84, -21, -2}, {-2, -22, 77, 8, 8, 82, -21, -2},
  {-2, -21, 79, 8, 8, 79, -21, -2}, {-2, -21, 82, 8, 8, 77, -22, -2},
  {-2, -21, 84, 8, 8, 74, -21, -2}, {-2, -21, 87, 8, 8, 72, -22, -2},
  {-2, -20, 89, 8, 8, 69, -22, -2}, {-2, -20, 91, 8, 8, 67, -22, -2},
  {-2, -19, 94, 8, 7, 64, -22, -2}, {-2, -19, 96, 8, 7, 62, -22, -2},
  {-2, -18, 98, 8, 7, 59, -22, -2}, {-2, -18, 100, 8, 7, 56, -21, -2},
  {-2, -17, 102, 7, 7, 54, -21, -2}, {-2, -16, 104, 7, 7, 51, -21, -2},
  {-2, -16, 106, 7, 6, 49, -20, -2}, {-2, -15, 108, 7, 6, 46, -20, -2},
  {-2, -15, 110, 7, 6, 43, -19, -2}, {-2, -14, 111, 7, 6, 41, -19, -2},
  {-1, -13, 113, 6, 5, 38, -18, -2}, {-1, -12, 114, 6, 5, 35, -17, -2},
  {-1, -12, 116, 6, 5, 33, -17, -2}, {-1, -11, 118, 5, 4, 30, -16, -1},
  {-1, -10, 119, 5, 4, 27, -15, -1}, {-1, -9, 120, 5, 4, 25, -14, -2},
  {-1, -8, 121, 4, 3, 23, -13, -1}, {-1, -7, 122, 4, 3, 20, -12, -1},
  {-1, -7, 123, 4, 3, 18, -11, -1}, {-1, -6, 124, 3, 3, 16, -10, -1},
  {-1, -5, 125, 3, 2, 13, -8, -1}, {-1, -4, 126, 2, 2, 11, -7, -1},
  { 0, -3, 126, 2, 1, 8, -6, 0}, { 0, -2, 127, 1, 1, 6, -5, 0},
  { 0, -2, 127, 1, 1, 4, -3, 0}, { 0, 0, 127, 0, 0, 2, -1, 0},
  // [1, 2)
  { 0, 0, 127, 0, 0, 1, 0, 0}, { 0, 0, 127, 0, 0, -1, 2, 0},
  { 0, 1, 127, -1, 0, -3, 4, 0}, { 0, 1, 126, -2, 0, -4, 6, 1},
  { 0, 1, 126, -3, 0, -5, 8, 1}, { 0, 1, 125, -4, 0, -6, 11, 1},
  { 0, 1, 124, -4, 0, -7, 13, 1}, { 0, 2, 123, -5, 0, -8, 15, 1},
  { 0, 2, 122, -6, 0, -9, 18, 1}, { 0, 2, 121, -6, 0, -10, 20, 1},
  { 0, 2, 120, -7, 0, -11, 22, 2}, { 0, 2, 119, -8, 0, -12, 25, 2},
  { 0, 3, 117, -8, 0, -13, 27, 2}, { 0, 3, 116, -9, 0, -13, 29, 2},
  { 0, 3, 114, -10, 0, -14, 32, 3}, { 0, 3, 113, -10, 0, -15, 35, 2},
  { 0, 3, 111, -11, 0, -15, 37, 3}, { 0, 3, 109, -11, 0, -16, 40, 3},
  { 0, 3, 108, -12, 0, -16, 42, 3}, { 0, 4, 106, -13, 0, -17, 45, 3},
  { 0, 4, 104, -13, 0, -17, 47, 3}, { 0, 4, 102, -14, 0, -17, 50, 3},
  { 0, 4, 100, -14, 0, -17, 52, 3}, { 0, 4, 98, -15, 0, -18, 55, 4},
  { 0, 4, 96, -15, 0, -18, 58, 3}, { 0, 4, 94, -16, 0, -18, 60, 4},
  { 0, 4, 91, -16, 0, -18, 63, 4}, { 0, 4, 89, -16, 0, -18, 65, 4},
  { 0, 4, 87, -17, 0, -18, 68, 4}, { 0, 4, 85, -17, 0, -18, 70, 4},
  { 0, 4, 82, -17, 0, -18, 73, 4}, { 0, 4, 80, -17, 0, -18, 75, 4},
  { 0, 4, 78, -18, 0, -18, 78, 4}, { 0, 4, 75, -18, 0, -17, 80, 4},
  { 0, 4, 73, -18, 0, -17, 82, 4}, { 0, 4, 70, -18, 0, -17, 85, 4},
  { 0, 4, 68, -18, 0, -17, 87, 4}, { 0, 4, 65, -18, 0, -16, 89, 4},
  { 0, 4, 63, -18, 0, -16, 91, 4}, { 0, 4, 60, -18, 0, -16, 94, 4},
  { 0, 3, 58, -18, 0, -15, 96, 4}, { 0, 4, 55, -18, 0, -15, 98, 4},
  { 0, 3, 52, -17, 0, -14, 100, 4}, { 0, 3, 50, -17, 0, -14, 102, 4},
  { 0, 3, 47, -17, 0, -13, 104, 4}, { 0, 3, 45, -17, 0, -13, 106, 4},
  { 0, 3, 42, -16, 0, -12, 108, 3}, { 0, 3, 40, -16, 0, -11, 109, 3},
  { 0, 3, 37, -15, 0, -11, 111, 3}, { 0, 2, 35, -15, 0, -10, 113, 3},
  { 0, 3, 32, -14, 0, -10, 114, 3}, { 0, 2, 29, -13, 0, -9, 116, 3},
  { 0, 2, 27, -13, 0, -8, 117, 3}, { 0, 2, 25, -12, 0, -8, 119, 2},
  { 0, 2, 22, -11, 0, -7, 120, 2}, { 0, 1, 20, -10, 0, -6, 121, 2},
  { 0, 1, 18, -9, 0, -6, 122, 2}, { 0, 1, 15, -8, 0, -5, 123, 2},
  { 0, 1, 13, -7, 0, -4, 124, 1}, { 0, 1, 11, -6, 0, -4, 125, 1},
  { 0, 1, 8, -5, 0, -3, 126, 1}, { 0, 1, 6, -4, 0, -2, 126, 1},
  { 0, 0, 4, -3, 0, -1, 127, 1}, { 0, 0, 2, -1, 0, 0, 127, 0},
  // dummy (replicate row index 191)
  { 0, 0, 2, -1, 0, 0, 127, 0},

#else
  // [-1, 0)
  { 0, 127, 0, 0, 0, 1, 0, 0}, { 1, 127, -1, 0, -3, 4, 0, 0},
  { 1, 126, -3, 0, -5, 8, 1, 0}, { 1, 124, -4, 0, -7, 13, 1, 0},
  { 2, 122, -6, 0, -9, 18, 1, 0}, { 2, 120, -7, 0, -11, 22, 2, 0},
  { 3, 117, -8, 0, -13, 27, 2, 0}, { 3, 114, -10, 0, -14, 32, 3, 0},
  { 3, 111, -11, 0, -15, 37, 3, 0}, { 3, 108, -12, 0, -16, 42, 3, 0},
  { 4, 104, -13, 0, -17, 47, 3, 0}, { 4, 100, -14, 0, -17, 52, 3, 0},
  { 4, 96, -15, 0, -18, 58, 3, 0}, { 4, 91, -16, 0, -18, 63, 4, 0},
  { 4, 87, -17, 0, -18, 68, 4, 0}, { 4, 82, -17, 0, -18, 73, 4, 0},
  { 4, 78, -18, 0, -18, 78, 4, 0}, { 4, 73, -18, 0, -17, 82, 4, 0},
  { 4, 68, -18, 0, -17, 87, 4, 0}, { 4, 63, -18, 0, -16, 91, 4, 0},
  { 3, 58, -18, 0, -15, 96, 4, 0}, { 3, 52, -17, 0, -14, 100, 4, 0},
  { 3, 47, -17, 0, -13, 104, 4, 0}, { 3, 42, -16, 0, -12, 108, 3, 0},
  { 3, 37, -15, 0, -11, 111, 3, 0}, { 3, 32, -14, 0, -10, 114, 3, 0},
  { 2, 27, -13, 0, -8, 117, 3, 0}, { 2, 22, -11, 0, -7, 120, 2, 0},
  { 1, 18, -9, 0, -6, 122, 2, 0}, { 1, 13, -7, 0, -4, 124, 1, 0},
  { 1, 8, -5, 0, -3, 126, 1, 0}, { 0, 4, -3, 0, -1, 127, 1, 0},
  // [0, 1)
  { 0, 0, 1, 0, 0, 127, 0, 0}, { 0, -3, 4, 1, 1, 127, -2, 0},
  { 0, -6, 8, 1, 2, 126, -3, 0}, {-1, -8, 13, 2, 3, 125, -5, -1},
  {-1, -11, 18, 3, 4, 123, -7, -1}, {-1, -13, 23, 3, 4, 121, -8, -1},
  {-1, -15, 27, 4, 5, 119, -10, -1}, {-2, -17, 33, 5, 6, 116, -12, -1},
  {-2, -18, 38, 5, 6, 113, -13, -1}, {-2, -19, 43, 6, 7, 110, -15, -2},
  {-2, -20, 49, 6, 7, 106, -16, -2}, {-2, -21, 54, 7, 7, 102, -17, -2},
  {-2, -22, 59, 7, 8, 98, -18, -2}, {-2, -22, 64, 7, 8, 94, -19, -2},
  {-2, -22, 69, 8, 8, 89, -20, -2}, {-2, -21, 74, 8, 8, 84, -21, -2},
  {-2, -21, 79, 8, 8, 79, -21, -2}, {-2, -21, 84, 8, 8, 74, -21, -2},
  {-2, -20, 89, 8, 8, 69, -22, -2}, {-2, -19, 94, 8, 7, 64, -22, -2},
  {-2, -18, 98, 8, 7, 59, -22, -2}, {-2, -17, 102, 7, 7, 54, -21, -2},
  {-2, -16, 106, 7, 6, 49, -20, -2}, {-2, -15, 110, 7, 6, 43, -19, -2},
  {-1, -13, 113, 6, 5, 38, -18, -2}, {-1, -12, 116, 6, 5, 33, -17, -2},
  {-1, -10, 119, 5, 4, 27, -15, -1}, {-1, -8, 121, 4, 3, 23, -13, -1},
  {-1, -7, 123, 4, 3, 18, -11, -1}, {-1, -5, 125, 3, 2, 13, -8, -1},
  { 0, -3, 126, 2, 1, 8, -6, 0}, { 0, -2, 127, 1, 1, 4, -3, 0},
  // [1, 2)
  { 0, 0, 127, 0, 0, 1, 0, 0}, { 0, 1, 127, -1, 0, -3, 4, 0},
  { 0, 1, 126, -3, 0, -5, 8, 1}, { 0, 1, 124, -4, 0, -7, 13, 1},
  { 0, 2, 122, -6, 0, -9, 18, 1}, { 0, 2, 120, -7, 0, -11, 22, 2},
  { 0, 3, 117, -8, 0, -13, 27, 2}, { 0, 3, 114, -10, 0, -14, 32, 3},
  { 0, 3, 111, -11, 0, -15, 37, 3}, { 0, 3, 108, -12, 0, -16, 42, 3},
  { 0, 4, 104, -13, 0, -17, 47, 3}, { 0, 4, 100, -14, 0, -17, 52, 3},
  { 0, 4, 96, -15, 0, -18, 58, 3}, { 0, 4, 91, -16, 0, -18, 63, 4},
  { 0, 4, 87, -17, 0, -18, 68, 4}, { 0, 4, 82, -17, 0, -18, 73, 4},
  { 0, 4, 78, -18, 0, -18, 78, 4}, { 0, 4, 73, -18, 0, -17, 82, 4},
  { 0, 4, 68, -18, 0, -17, 87, 4}, { 0, 4, 63, -18, 0, -16, 91, 4},
  { 0, 3, 58, -18, 0, -15, 96, 4}, { 0, 3, 52, -17, 0, -14, 100, 4},
  { 0, 3, 47, -17, 0, -13, 104, 4}, { 0, 3, 42, -16, 0, -12, 108, 3},
  { 0, 3, 37, -15, 0, -11, 111, 3}, { 0, 3, 32, -14, 0, -10, 114, 3},
  { 0, 2, 27, -13, 0, -8, 117, 3}, { 0, 2, 22, -11, 0, -7, 120, 2},
  { 0, 1, 18, -9, 0, -6, 122, 2}, { 0, 1, 13, -7, 0, -4, 124, 1},
  { 0, 1, 8, -5, 0, -3, 126, 1}, { 0, 0, 4, -3, 0, -1, 127, 1},
  // dummy (replicate row index 95)
  { 0, 0, 4, -3, 0, -1, 127, 1},
#endif  // WARPEDPIXEL_PREC_BITS == 6
};
/* clang-format on */

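// Widen two 8-pixel source vectors to 16 bits, multiply-accumulate them
// against the coefficient pairs packed in x0/x1, and pairwise-add the
// products down to four 16-bit partial sums in *res.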
static INLINE void convolve(int32x2x2_t x0, int32x2x2_t x1, uint8x8_t src_0,
                            uint8x8_t src_1, int16x4_t *res) {
  int16x8_t coeff_0, coeff_1;
  int16x8_t pix_0, pix_1;

  coeff_0 = vcombine_s16(vreinterpret_s16_s32(x0.val[0]),
                         vreinterpret_s16_s32(x1.val[0]));
  coeff_1 = vcombine_s16(vreinterpret_s16_s32(x0.val[1]),
                         vreinterpret_s16_s32(x1.val[1]));

  pix_0 = vreinterpretq_s16_u16(vmovl_u8(src_0));
  pix_0 = vmulq_s16(coeff_0, pix_0);

  pix_1 = vreinterpretq_s16_u16(vmovl_u8(src_1));
  pix_0 = vmlaq_s16(pix_0, coeff_1, pix_1);

  *res = vpadd_s16(vget_low_s16(pix_0), vget_high_s16(pix_0));
}

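// Filter one row horizontally: apply the 8-tap warp filter at the eight
// fractional positions sx + i * alpha (i = 0..7), add the horizontal offset,
// round by reduce_bits_horiz, and store the row of results in tmp_dst[k + 7].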
static INLINE void horizontal_filter_neon(uint8x16_t src_1, uint8x16_t src_2,
                                          uint8x16_t src_3, uint8x16_t src_4,
                                          int16x8_t *tmp_dst, int sx, int alpha,
                                          int k, const int offset_bits_horiz,
                                          const int reduce_bits_horiz) {
  const uint8x16_t mask = { 255, 0, 255, 0, 255, 0, 255, 0,
                            255, 0, 255, 0, 255, 0, 255, 0 };
  const int32x4_t add_const = vdupq_n_s32((int32_t)(1 << offset_bits_horiz));
  const int16x8_t shift = vdupq_n_s16(-(int16_t)reduce_bits_horiz);

  int16x8_t f0, f1, f2, f3, f4, f5, f6, f7;
  int32x2x2_t b0, b1;
  uint8x8_t src_1_low, src_2_low, src_3_low, src_4_low, src_5_low, src_6_low;
  int32x4_t tmp_res_low, tmp_res_high;
  uint16x8_t res;
  int16x4_t res_0246_even, res_0246_odd, res_1357_even, res_1357_odd;

  uint8x16_t tmp_0 = vandq_u8(src_1, mask);
  uint8x16_t tmp_1 = vandq_u8(src_2, mask);
  uint8x16_t tmp_2 = vandq_u8(src_3, mask);
  uint8x16_t tmp_3 = vandq_u8(src_4, mask);

  tmp_2 = vextq_u8(tmp_0, tmp_0, 1);
  tmp_3 = vextq_u8(tmp_1, tmp_1, 1);

  src_1 = vaddq_u8(tmp_0, tmp_2);
  src_2 = vaddq_u8(tmp_1, tmp_3);

  src_1_low = vget_low_u8(src_1);
  src_2_low = vget_low_u8(src_2);
  src_3_low = vget_low_u8(vextq_u8(src_1, src_1, 4));
  src_4_low = vget_low_u8(vextq_u8(src_2, src_2, 4));
  src_5_low = vget_low_u8(vextq_u8(src_1, src_1, 2));
  src_6_low = vget_low_u8(vextq_u8(src_1, src_1, 6));

  // Loading the 8 filter taps
  f0 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 0 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f1 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 1 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f2 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 2 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f3 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 3 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f4 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 4 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f5 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 5 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f6 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 6 * alpha) >> WARPEDDIFF_PREC_BITS]));
  f7 = vmovl_s8(
      vld1_s8(filter_8bit_neon[(sx + 7 * alpha) >> WARPEDDIFF_PREC_BITS]));

  b0 = vtrn_s32(vreinterpret_s32_s16(vget_low_s16(f0)),
                vreinterpret_s32_s16(vget_low_s16(f2)));
  b1 = vtrn_s32(vreinterpret_s32_s16(vget_low_s16(f4)),
                vreinterpret_s32_s16(vget_low_s16(f6)));
  convolve(b0, b1, src_1_low, src_3_low, &res_0246_even);

  b0 = vtrn_s32(vreinterpret_s32_s16(vget_low_s16(f1)),
                vreinterpret_s32_s16(vget_low_s16(f3)));
  b1 = vtrn_s32(vreinterpret_s32_s16(vget_low_s16(f5)),
                vreinterpret_s32_s16(vget_low_s16(f7)));
  convolve(b0, b1, src_2_low, src_4_low, &res_0246_odd);

  b0 = vtrn_s32(vreinterpret_s32_s16(vget_high_s16(f0)),
                vreinterpret_s32_s16(vget_high_s16(f2)));
  b1 = vtrn_s32(vreinterpret_s32_s16(vget_high_s16(f4)),
                vreinterpret_s32_s16(vget_high_s16(f6)));
  convolve(b0, b1, src_2_low, src_4_low, &res_1357_even);

  b0 = vtrn_s32(vreinterpret_s32_s16(vget_high_s16(f1)),
                vreinterpret_s32_s16(vget_high_s16(f3)));
  b1 = vtrn_s32(vreinterpret_s32_s16(vget_high_s16(f5)),
                vreinterpret_s32_s16(vget_high_s16(f7)));
  convolve(b0, b1, src_5_low, src_6_low, &res_1357_odd);

  tmp_res_low = vaddl_s16(res_0246_even, res_1357_even);
  tmp_res_high = vaddl_s16(res_0246_odd, res_1357_odd);

  tmp_res_low = vaddq_s32(tmp_res_low, add_const);
  tmp_res_high = vaddq_s32(tmp_res_high, add_const);

  res = vcombine_u16(vqmovun_s32(tmp_res_low), vqmovun_s32(tmp_res_high));
  res = vqrshlq_u16(res, shift);

  tmp_dst[k + 7] = vreinterpretq_s16_u16(res);
}

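// For each of the eight output columns, apply the 8-tap vertical warp filter
// (taps selected by sy + i * gamma) across the eight intermediate rows in
// src. The 32-bit sums are returned as res_low (columns 0-3) and res_high
// (columns 4-7).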
static INLINE void vertical_filter_neon(const int16x8_t *src,
                                        int32x4_t *res_low, int32x4_t *res_high,
                                        int sy, int gamma) {
  int16x4_t src_0, src_1, fltr_0, fltr_1;
  int32x4_t res_0, res_1;
  int32x2_t res_0_im, res_1_im;
  int32x4_t res_even, res_odd, im_res_0, im_res_1;

  int16x8_t f0, f1, f2, f3, f4, f5, f6, f7;
  int16x8x2_t b0, b1, b2, b3;
  int32x4x2_t c0, c1, c2, c3;
  int32x4x2_t d0, d1, d2, d3;

  b0 = vtrnq_s16(src[0], src[1]);
  b1 = vtrnq_s16(src[2], src[3]);
  b2 = vtrnq_s16(src[4], src[5]);
  b3 = vtrnq_s16(src[6], src[7]);

  c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
                 vreinterpretq_s32_s16(b0.val[1]));
  c1 = vtrnq_s32(vreinterpretq_s32_s16(b1.val[0]),
                 vreinterpretq_s32_s16(b1.val[1]));
  c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
                 vreinterpretq_s32_s16(b2.val[1]));
  c3 = vtrnq_s32(vreinterpretq_s32_s16(b3.val[0]),
                 vreinterpretq_s32_s16(b3.val[1]));

  f0 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 0 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f1 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 1 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f2 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 2 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f3 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 3 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f4 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 4 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f5 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 5 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f6 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 6 * gamma) >> WARPEDDIFF_PREC_BITS)));
  f7 = vld1q_s16(
      (int16_t *)(warped_filter + ((sy + 7 * gamma) >> WARPEDDIFF_PREC_BITS)));

  d0 = vtrnq_s32(vreinterpretq_s32_s16(f0), vreinterpretq_s32_s16(f2));
  d1 = vtrnq_s32(vreinterpretq_s32_s16(f4), vreinterpretq_s32_s16(f6));
  d2 = vtrnq_s32(vreinterpretq_s32_s16(f1), vreinterpretq_s32_s16(f3));
  d3 = vtrnq_s32(vreinterpretq_s32_s16(f5), vreinterpretq_s32_s16(f7));

  // row:0,1 even_col:0,2
  src_0 = vget_low_s16(vreinterpretq_s16_s32(c0.val[0]));
  fltr_0 = vget_low_s16(vreinterpretq_s16_s32(d0.val[0]));
  res_0 = vmull_s16(src_0, fltr_0);

  // row:0,1,2,3 even_col:0,2
  src_0 = vget_low_s16(vreinterpretq_s16_s32(c1.val[0]));
  fltr_0 = vget_low_s16(vreinterpretq_s16_s32(d0.val[1]));
  res_0 = vmlal_s16(res_0, src_0, fltr_0);
  res_0_im = vpadd_s32(vget_low_s32(res_0), vget_high_s32(res_0));

  // row:0,1 even_col:4,6
  src_1 = vget_low_s16(vreinterpretq_s16_s32(c0.val[1]));
  fltr_1 = vget_low_s16(vreinterpretq_s16_s32(d1.val[0]));
  res_1 = vmull_s16(src_1, fltr_1);

  // row:0,1,2,3 even_col:4,6
  src_1 = vget_low_s16(vreinterpretq_s16_s32(c1.val[1]));
  fltr_1 = vget_low_s16(vreinterpretq_s16_s32(d1.val[1]));
  res_1 = vmlal_s16(res_1, src_1, fltr_1);
  res_1_im = vpadd_s32(vget_low_s32(res_1), vget_high_s32(res_1));

  // row:0,1,2,3 even_col:0,2,4,6
  im_res_0 = vcombine_s32(res_0_im, res_1_im);

  // row:4,5 even_col:0,2
  src_0 = vget_low_s16(vreinterpretq_s16_s32(c2.val[0]));
  fltr_0 = vget_high_s16(vreinterpretq_s16_s32(d0.val[0]));
  res_0 = vmull_s16(src_0, fltr_0);

  // row:4,5,6,7 even_col:0,2
  src_0 = vget_low_s16(vreinterpretq_s16_s32(c3.val[0]));
  fltr_0 = vget_high_s16(vreinterpretq_s16_s32(d0.val[1]));
  res_0 = vmlal_s16(res_0, src_0, fltr_0);
  res_0_im = vpadd_s32(vget_low_s32(res_0), vget_high_s32(res_0));

  // row:4,5 even_col:4,6
  src_1 = vget_low_s16(vreinterpretq_s16_s32(c2.val[1]));
  fltr_1 = vget_high_s16(vreinterpretq_s16_s32(d1.val[0]));
  res_1 = vmull_s16(src_1, fltr_1);

  // row:4,5,6,7 even_col:4,6
  src_1 = vget_low_s16(vreinterpretq_s16_s32(c3.val[1]));
  fltr_1 = vget_high_s16(vreinterpretq_s16_s32(d1.val[1]));
  res_1 = vmlal_s16(res_1, src_1, fltr_1);
  res_1_im = vpadd_s32(vget_low_s32(res_1), vget_high_s32(res_1));

  // row:4,5,6,7 even_col:0,2,4,6
  im_res_1 = vcombine_s32(res_0_im, res_1_im);

  // row:0-7 even_col:0,2,4,6
  res_even = vaddq_s32(im_res_0, im_res_1);

  // row:0,1 odd_col:1,3
  src_0 = vget_high_s16(vreinterpretq_s16_s32(c0.val[0]));
  fltr_0 = vget_low_s16(vreinterpretq_s16_s32(d2.val[0]));
  res_0 = vmull_s16(src_0, fltr_0);

  // row:0,1,2,3 odd_col:1,3
  src_0 = vget_high_s16(vreinterpretq_s16_s32(c1.val[0]));
  fltr_0 = vget_low_s16(vreinterpretq_s16_s32(d2.val[1]));
  res_0 = vmlal_s16(res_0, src_0, fltr_0);
  res_0_im = vpadd_s32(vget_low_s32(res_0), vget_high_s32(res_0));

  // row:0,1 odd_col:5,7
  src_1 = vget_high_s16(vreinterpretq_s16_s32(c0.val[1]));
  fltr_1 = vget_low_s16(vreinterpretq_s16_s32(d3.val[0]));
  res_1 = vmull_s16(src_1, fltr_1);

  // row:0,1,2,3 odd_col:5,7
  src_1 = vget_high_s16(vreinterpretq_s16_s32(c1.val[1]));
  fltr_1 = vget_low_s16(vreinterpretq_s16_s32(d3.val[1]));
  res_1 = vmlal_s16(res_1, src_1, fltr_1);
  res_1_im = vpadd_s32(vget_low_s32(res_1), vget_high_s32(res_1));

  // row:0,1,2,3 odd_col:1,3,5,7
  im_res_0 = vcombine_s32(res_0_im, res_1_im);

  // row:4,5 odd_col:1,3
  src_0 = vget_high_s16(vreinterpretq_s16_s32(c2.val[0]));
  fltr_0 = vget_high_s16(vreinterpretq_s16_s32(d2.val[0]));
  res_0 = vmull_s16(src_0, fltr_0);

  // row:4,5,6,7 odd_col:1,3
  src_0 = vget_high_s16(vreinterpretq_s16_s32(c3.val[0]));
  fltr_0 = vget_high_s16(vreinterpretq_s16_s32(d2.val[1]));
  res_0 = vmlal_s16(res_0, src_0, fltr_0);
  res_0_im = vpadd_s32(vget_low_s32(res_0), vget_high_s32(res_0));

  // row:4,5 odd_col:5,7
  src_1 = vget_high_s16(vreinterpretq_s16_s32(c2.val[1]));
  fltr_1 = vget_high_s16(vreinterpretq_s16_s32(d3.val[0]));
  res_1 = vmull_s16(src_1, fltr_1);

  // row:4,5,6,7 odd_col:5,7
  src_1 = vget_high_s16(vreinterpretq_s16_s32(c3.val[1]));
  fltr_1 = vget_high_s16(vreinterpretq_s16_s32(d3.val[1]));
  res_1 = vmlal_s16(res_1, src_1, fltr_1);
  res_1_im = vpadd_s32(vget_low_s32(res_1), vget_high_s32(res_1));

  // row:4,5,6,7 odd_col:1,3,5,7
  im_res_1 = vcombine_s32(res_0_im, res_1_im);

  // row:0-7 odd_col:1,3,5,7
  res_odd = vaddq_s32(im_res_0, im_res_1);

  // reordering as 0 1 2 3 | 4 5 6 7
  c0 = vtrnq_s32(res_even, res_odd);

  // Final store
  *res_low = vcombine_s32(vget_low_s32(c0.val[0]), vget_low_s32(c0.val[1]));
  *res_high = vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[1]));
}

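// 8-bit NEON implementation of av1_warp_affine(): warp the reference block
// described by the affine model 'mat' into the prediction, using a horizontal
// pass into 'tmp' followed by a vertical pass, with optional compound
// (dist-wtd) averaging.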
void av1_warp_affine_neon(const int32_t *mat, const uint8_t *ref, int width,
                          int height, int stride, uint8_t *pred, int p_col,
                          int p_row, int p_width, int p_height, int p_stride,
                          int subsampling_x, int subsampling_y,
                          ConvolveParams *conv_params, int16_t alpha,
                          int16_t beta, int16_t gamma, int16_t delta) {
  int16x8_t tmp[15];
  const int bd = 8;
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const int32x4_t fwd = vdupq_n_s32((int32_t)w0);
  const int32x4_t bwd = vdupq_n_s32((int32_t)w1);
  const int16x8_t sub_constant = vdupq_n_s16((1 << (bd - 1)) + (1 << bd));

  int limit = 0;
  uint8x16_t vec_dup, mask_val;
  int32x4_t res_lo, res_hi;
  int16x8_t result_final;
  uint8x16_t src_1, src_2, src_3, src_4;
  uint8x16_t indx_vec = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  };
  uint8x16_t cmp_vec;

  const int reduce_bits_horiz = conv_params->round_0;
  const int reduce_bits_vert = conv_params->is_compound
                                   ? conv_params->round_1
                                   : 2 * FILTER_BITS - reduce_bits_horiz;
  const int32x4_t shift_vert = vdupq_n_s32(-(int32_t)reduce_bits_vert);
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  assert(IMPLIES(conv_params->is_compound, conv_params->dst != NULL));

  const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
  int32x4_t add_const_vert = vdupq_n_s32((int32_t)(1 << offset_bits_vert));
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int16x4_t round_bits_vec = vdup_n_s16(-(int16_t)round_bits);
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int16x4_t res_sub_const =
      vdup_n_s16(-((1 << (offset_bits - conv_params->round_1)) +
                   (1 << (offset_bits - conv_params->round_1 - 1))));
  int k;

  assert(IMPLIES(conv_params->do_average, conv_params->is_compound));

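  // Process the output in 8x8 tiles: map each tile's centre through the
  // affine model 'mat' to find the source position (ix4, iy4) and the
  // fractional offsets (sx4, sy4) that select the filter taps.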
  for (int i = 0; i < p_height; i += 8) {
    for (int j = 0; j < p_width; j += 8) {
      const int32_t src_x = (p_col + j + 4) << subsampling_x;
      const int32_t src_y = (p_row + i + 4) << subsampling_y;
      const int32_t dst_x = mat[2] * src_x + mat[3] * src_y + mat[0];
      const int32_t dst_y = mat[4] * src_x + mat[5] * src_y + mat[1];
      const int32_t x4 = dst_x >> subsampling_x;
      const int32_t y4 = dst_y >> subsampling_y;

      int32_t ix4 = x4 >> WARPEDMODEL_PREC_BITS;
      int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
      int32_t iy4 = y4 >> WARPEDMODEL_PREC_BITS;
      int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);

      sx4 += alpha * (-4) + beta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
             (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
      sy4 += gamma * (-4) + delta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
             (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);

      sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
      sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
      // horizontal
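      // Four cases for the horizontal window: entirely off the left edge,
      // entirely off the right edge (both reduce to a constant row built from
      // the clamped edge pixel), partially out of bounds (replicate the edge
      // pixels into the loaded vector), or fully inside the frame.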
      if (ix4 <= -7) {
        for (k = -7; k < AOMMIN(8, p_height - i); ++k) {
          int iy = iy4 + k;
          if (iy < 0)
            iy = 0;
          else if (iy > height - 1)
            iy = height - 1;
          int16_t dup_val =
              (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
              ref[iy * stride] * (1 << (FILTER_BITS - reduce_bits_horiz));

          tmp[k + 7] = vdupq_n_s16(dup_val);
        }
      } else if (ix4 >= width + 6) {
        for (k = -7; k < AOMMIN(8, p_height - i); ++k) {
          int iy = iy4 + k;
          if (iy < 0)
            iy = 0;
          else if (iy > height - 1)
            iy = height - 1;
          int16_t dup_val = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
                            ref[iy * stride + (width - 1)] *
                                (1 << (FILTER_BITS - reduce_bits_horiz));
          tmp[k + 7] = vdupq_n_s16(dup_val);
        }
      } else if (((ix4 - 7) < 0) || ((ix4 + 9) > width)) {
        const int out_of_boundary_left = -(ix4 - 6);
        const int out_of_boundary_right = (ix4 + 8) - width;

        for (k = -7; k < AOMMIN(8, p_height - i); ++k) {
          int iy = iy4 + k;
          if (iy < 0)
            iy = 0;
          else if (iy > height - 1)
            iy = height - 1;
          int sx = sx4 + beta * (k + 4);

          const uint8_t *src = ref + iy * stride + ix4 - 7;
          src_1 = vld1q_u8(src);

          if (out_of_boundary_left >= 0) {
            limit = out_of_boundary_left + 1;
            cmp_vec = vdupq_n_u8(out_of_boundary_left);
            vec_dup = vdupq_n_u8(*(src + limit));
            mask_val = vcleq_u8(indx_vec, cmp_vec);
            src_1 = vbslq_u8(mask_val, vec_dup, src_1);
          }
          if (out_of_boundary_right >= 0) {
            limit = 15 - (out_of_boundary_right + 1);
            cmp_vec = vdupq_n_u8(15 - out_of_boundary_right);
            vec_dup = vdupq_n_u8(*(src + limit));
            mask_val = vcgeq_u8(indx_vec, cmp_vec);
            src_1 = vbslq_u8(mask_val, vec_dup, src_1);
          }
          src_2 = vextq_u8(src_1, src_1, 1);
          src_3 = vextq_u8(src_2, src_2, 1);
          src_4 = vextq_u8(src_3, src_3, 1);

          horizontal_filter_neon(src_1, src_2, src_3, src_4, tmp, sx, alpha, k,
                                 offset_bits_horiz, reduce_bits_horiz);
        }
      } else {
        for (k = -7; k < AOMMIN(8, p_height - i); ++k) {
          int iy = iy4 + k;
          if (iy < 0)
            iy = 0;
          else if (iy > height - 1)
            iy = height - 1;
          int sx = sx4 + beta * (k + 4);

          const uint8_t *src = ref + iy * stride + ix4 - 7;
          src_1 = vld1q_u8(src);
          src_2 = vextq_u8(src_1, src_1, 1);
          src_3 = vextq_u8(src_2, src_2, 1);
          src_4 = vextq_u8(src_3, src_3, 1);

          horizontal_filter_neon(src_1, src_2, src_3, src_4, tmp, sx, alpha, k,
                                 offset_bits_horiz, reduce_bits_horiz);
        }
      }

      // vertical
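      // For each output row, run the vertical filter over eight intermediate
      // rows, then either store 16-bit values to the compound buffer
      // (optionally dist-wtd averaged and narrowed back to 8 bits in 'pred')
      // or round straight to 8-bit prediction samples.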
      for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
        int sy = sy4 + delta * (k + 4);

        const int16x8_t *v_src = tmp + (k + 4);

        vertical_filter_neon(v_src, &res_lo, &res_hi, sy, gamma);

        res_lo = vaddq_s32(res_lo, add_const_vert);
        res_hi = vaddq_s32(res_hi, add_const_vert);

        if (conv_params->is_compound) {
          uint16_t *const p =
              (uint16_t *)&conv_params
                  ->dst[(i + k + 4) * conv_params->dst_stride + j];

          res_lo = vrshlq_s32(res_lo, shift_vert);
          if (conv_params->do_average) {
            uint8_t *const dst8 = &pred[(i + k + 4) * p_stride + j];
            uint16x4_t tmp16_lo = vld1_u16(p);
            int32x4_t tmp32_lo = vreinterpretq_s32_u32(vmovl_u16(tmp16_lo));
            int16x4_t tmp16_low;
            if (conv_params->use_dist_wtd_comp_avg) {
              res_lo = vmulq_s32(res_lo, bwd);
              tmp32_lo = vmulq_s32(tmp32_lo, fwd);
              tmp32_lo = vaddq_s32(tmp32_lo, res_lo);
              tmp16_low = vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS);
            } else {
              tmp32_lo = vaddq_s32(tmp32_lo, res_lo);
              tmp16_low = vshrn_n_s32(tmp32_lo, 1);
            }
            int16x4_t res_low = vadd_s16(tmp16_low, res_sub_const);
            res_low = vqrshl_s16(res_low, round_bits_vec);
            int16x8_t final_res_low = vcombine_s16(res_low, res_low);
            uint8x8_t res_8_low = vqmovun_s16(final_res_low);

            vst1_lane_u32((uint32_t *)dst8, vreinterpret_u32_u8(res_8_low), 0);
          } else {
            uint16x4_t res_u16_low = vqmovun_s32(res_lo);
            vst1_u16(p, res_u16_low);
          }
          if (p_width > 4) {
            uint16_t *const p4 =
                (uint16_t *)&conv_params
                    ->dst[(i + k + 4) * conv_params->dst_stride + j + 4];

            res_hi = vrshlq_s32(res_hi, shift_vert);
            if (conv_params->do_average) {
              uint8_t *const dst8_4 = &pred[(i + k + 4) * p_stride + j + 4];

              uint16x4_t tmp16_hi = vld1_u16(p4);
              int32x4_t tmp32_hi = vreinterpretq_s32_u32(vmovl_u16(tmp16_hi));
              int16x4_t tmp16_high;
              if (conv_params->use_dist_wtd_comp_avg) {
                res_hi = vmulq_s32(res_hi, bwd);
                tmp32_hi = vmulq_s32(tmp32_hi, fwd);
                tmp32_hi = vaddq_s32(tmp32_hi, res_hi);
                tmp16_high = vshrn_n_s32(tmp32_hi, DIST_PRECISION_BITS);
              } else {
                tmp32_hi = vaddq_s32(tmp32_hi, res_hi);
                tmp16_high = vshrn_n_s32(tmp32_hi, 1);
              }
              int16x4_t res_high = vadd_s16(tmp16_high, res_sub_const);
              res_high = vqrshl_s16(res_high, round_bits_vec);
              int16x8_t final_res_high = vcombine_s16(res_high, res_high);
              uint8x8_t res_8_high = vqmovun_s16(final_res_high);

              vst1_lane_u32((uint32_t *)dst8_4, vreinterpret_u32_u8(res_8_high),
                            0);
            } else {
              uint16x4_t res_u16_high = vqmovun_s32(res_hi);
              vst1_u16(p4, res_u16_high);
            }
          }
        } else {
          res_lo = vrshlq_s32(res_lo, shift_vert);
          res_hi = vrshlq_s32(res_hi, shift_vert);

          result_final = vcombine_s16(vmovn_s32(res_lo), vmovn_s32(res_hi));
          result_final = vsubq_s16(result_final, sub_constant);

          uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
          uint8x8_t val = vqmovun_s16(result_final);

          if (p_width == 4) {
            vst1_lane_u32((uint32_t *)p, vreinterpret_u32_u8(val), 0);
          } else {
            vst1_u8(p, val);
          }
        }
      }
    }
  }
}