1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_ENABLE_NEON
19
20 #include <arm_neon.h>
21
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33 namespace {
34
35 constexpr int kInterPostRoundBit = 4;
36
// Computes the weighted blend of eight 16-bit predictions.
//
// Each lane of |pred0| is multiplied by |weights[0]| and each lane of |pred1|
// by |weights[1]|. The products are summed in 32 bits, then shifted right by
// kInterPostRoundBit + 4 with rounding and narrowed back to 16 bits with
// saturation. The four low and four high lanes are handled separately because
// the widening multiplies operate on 64-bit (four lane) inputs.
inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
                                         const int16x8_t pred1,
                                         const int16x4_t weights[2]) {
  // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
  constexpr int kTotalShift = kInterPostRoundBit + 4;

  // Low half: weights[0] * pred0 + weights[1] * pred1.
  int32x4_t sum_lo = vmull_s16(weights[0], vget_low_s16(pred0));
  sum_lo = vmlal_s16(sum_lo, weights[1], vget_low_s16(pred1));

  // High half, same computation.
  int32x4_t sum_hi = vmull_s16(weights[0], vget_high_s16(pred0));
  sum_hi = vmlal_s16(sum_hi, weights[1], vget_high_s16(pred1));

  const int16x4_t narrowed_lo = vqrshrn_n_s32(sum_lo, kTotalShift);
  const int16x4_t narrowed_hi = vqrshrn_n_s32(sum_hi, kTotalShift);
  return vcombine_s16(narrowed_lo, narrowed_hi);
}
51
52 template <int width, int height>
DistanceWeightedBlendSmall_NEON(const int16_t * prediction_0,const int16_t * prediction_1,const int16x4_t weights[2],void * const dest,const ptrdiff_t dest_stride)53 inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0,
54 const int16_t* prediction_1,
55 const int16x4_t weights[2],
56 void* const dest,
57 const ptrdiff_t dest_stride) {
58 auto* dst = static_cast<uint8_t*>(dest);
59 constexpr int step = 16 / width;
60
61 for (int y = 0; y < height; y += step) {
62 const int16x8_t src_00 = vld1q_s16(prediction_0);
63 const int16x8_t src_10 = vld1q_s16(prediction_1);
64 prediction_0 += 8;
65 prediction_1 += 8;
66 const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
67
68 const int16x8_t src_01 = vld1q_s16(prediction_0);
69 const int16x8_t src_11 = vld1q_s16(prediction_1);
70 prediction_0 += 8;
71 prediction_1 += 8;
72 const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
73
74 const uint8x8_t result0 = vqmovun_s16(res0);
75 const uint8x8_t result1 = vqmovun_s16(res1);
76 if (width == 4) {
77 StoreLo4(dst, result0);
78 dst += dest_stride;
79 StoreHi4(dst, result0);
80 dst += dest_stride;
81 StoreLo4(dst, result1);
82 dst += dest_stride;
83 StoreHi4(dst, result1);
84 dst += dest_stride;
85 } else {
86 assert(width == 8);
87 vst1_u8(dst, result0);
88 dst += dest_stride;
89 vst1_u8(dst, result1);
90 dst += dest_stride;
91 }
92 }
93 }
94
DistanceWeightedBlendLarge_NEON(const int16_t * prediction_0,const int16_t * prediction_1,const int16x4_t weights[2],const int width,const int height,void * const dest,const ptrdiff_t dest_stride)95 inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0,
96 const int16_t* prediction_1,
97 const int16x4_t weights[2],
98 const int width, const int height,
99 void* const dest,
100 const ptrdiff_t dest_stride) {
101 auto* dst = static_cast<uint8_t*>(dest);
102
103 int y = height;
104 do {
105 int x = 0;
106 do {
107 const int16x8_t src0_lo = vld1q_s16(prediction_0 + x);
108 const int16x8_t src1_lo = vld1q_s16(prediction_1 + x);
109 const int16x8_t res_lo =
110 ComputeWeightedAverage8(src0_lo, src1_lo, weights);
111
112 const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8);
113 const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8);
114 const int16x8_t res_hi =
115 ComputeWeightedAverage8(src0_hi, src1_hi, weights);
116
117 const uint8x16_t result =
118 vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi));
119 vst1q_u8(dst + x, result);
120 x += 16;
121 } while (x < width);
122 dst += dest_stride;
123 prediction_0 += width;
124 prediction_1 += width;
125 } while (--y != 0);
126 }
127
DistanceWeightedBlend_NEON(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)128 inline void DistanceWeightedBlend_NEON(const void* prediction_0,
129 const void* prediction_1,
130 const uint8_t weight_0,
131 const uint8_t weight_1, const int width,
132 const int height, void* const dest,
133 const ptrdiff_t dest_stride) {
134 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
135 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
136 int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
137 // TODO(johannkoenig): Investigate the branching. May be fine to call with a
138 // variable height.
139 if (width == 4) {
140 if (height == 4) {
141 DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest,
142 dest_stride);
143 } else if (height == 8) {
144 DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest,
145 dest_stride);
146 } else {
147 assert(height == 16);
148 DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest,
149 dest_stride);
150 }
151 return;
152 }
153
154 if (width == 8) {
155 switch (height) {
156 case 4:
157 DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest,
158 dest_stride);
159 return;
160 case 8:
161 DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest,
162 dest_stride);
163 return;
164 case 16:
165 DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest,
166 dest_stride);
167 return;
168 default:
169 assert(height == 32);
170 DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest,
171 dest_stride);
172
173 return;
174 }
175 }
176
177 DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest,
178 dest_stride);
179 }
180
Init8bpp()181 void Init8bpp() {
182 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
183 assert(dsp != nullptr);
184 dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
185 }
186
187 } // namespace
188
DistanceWeightedBlendInit_NEON()189 void DistanceWeightedBlendInit_NEON() { Init8bpp(); }
190
191 } // namespace dsp
192 } // namespace libgav1
193
194 #else // !LIBGAV1_ENABLE_NEON
195
196 namespace libgav1 {
197 namespace dsp {
198
// Stand-in compiled when LIBGAV1_ENABLE_NEON is not set, so the symbol still
// exists for callers; it performs no work.
void DistanceWeightedBlendInit_NEON() {
  // Intentionally empty.
}
200
201 } // namespace dsp
202 } // namespace libgav1
203 #endif // LIBGAV1_ENABLE_NEON
204