• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace {
34 
35 constexpr int kInterPostRoundBit = 4;
36 
ComputeWeightedAverage8(const int16x8_t pred0,const int16x8_t pred1,const int16x4_t weights[2])37 inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
38                                          const int16x8_t pred1,
39                                          const int16x4_t weights[2]) {
40   // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
41   const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0));
42   const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0));
43   const int32x4_t blended_lo =
44       vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1));
45   const int32x4_t blended_hi =
46       vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1));
47 
48   return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4),
49                       vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4));
50 }
51 
52 template <int width, int height>
DistanceWeightedBlendSmall_NEON(const int16_t * prediction_0,const int16_t * prediction_1,const int16x4_t weights[2],void * const dest,const ptrdiff_t dest_stride)53 inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0,
54                                             const int16_t* prediction_1,
55                                             const int16x4_t weights[2],
56                                             void* const dest,
57                                             const ptrdiff_t dest_stride) {
58   auto* dst = static_cast<uint8_t*>(dest);
59   constexpr int step = 16 / width;
60 
61   for (int y = 0; y < height; y += step) {
62     const int16x8_t src_00 = vld1q_s16(prediction_0);
63     const int16x8_t src_10 = vld1q_s16(prediction_1);
64     prediction_0 += 8;
65     prediction_1 += 8;
66     const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
67 
68     const int16x8_t src_01 = vld1q_s16(prediction_0);
69     const int16x8_t src_11 = vld1q_s16(prediction_1);
70     prediction_0 += 8;
71     prediction_1 += 8;
72     const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
73 
74     const uint8x8_t result0 = vqmovun_s16(res0);
75     const uint8x8_t result1 = vqmovun_s16(res1);
76     if (width == 4) {
77       StoreLo4(dst, result0);
78       dst += dest_stride;
79       StoreHi4(dst, result0);
80       dst += dest_stride;
81       StoreLo4(dst, result1);
82       dst += dest_stride;
83       StoreHi4(dst, result1);
84       dst += dest_stride;
85     } else {
86       assert(width == 8);
87       vst1_u8(dst, result0);
88       dst += dest_stride;
89       vst1_u8(dst, result1);
90       dst += dest_stride;
91     }
92   }
93 }
94 
DistanceWeightedBlendLarge_NEON(const int16_t * prediction_0,const int16_t * prediction_1,const int16x4_t weights[2],const int width,const int height,void * const dest,const ptrdiff_t dest_stride)95 inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0,
96                                             const int16_t* prediction_1,
97                                             const int16x4_t weights[2],
98                                             const int width, const int height,
99                                             void* const dest,
100                                             const ptrdiff_t dest_stride) {
101   auto* dst = static_cast<uint8_t*>(dest);
102 
103   int y = height;
104   do {
105     int x = 0;
106     do {
107       const int16x8_t src0_lo = vld1q_s16(prediction_0 + x);
108       const int16x8_t src1_lo = vld1q_s16(prediction_1 + x);
109       const int16x8_t res_lo =
110           ComputeWeightedAverage8(src0_lo, src1_lo, weights);
111 
112       const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8);
113       const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8);
114       const int16x8_t res_hi =
115           ComputeWeightedAverage8(src0_hi, src1_hi, weights);
116 
117       const uint8x16_t result =
118           vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi));
119       vst1q_u8(dst + x, result);
120       x += 16;
121     } while (x < width);
122     dst += dest_stride;
123     prediction_0 += width;
124     prediction_1 += width;
125   } while (--y != 0);
126 }
127 
DistanceWeightedBlend_NEON(const void * prediction_0,const void * prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * const dest,const ptrdiff_t dest_stride)128 inline void DistanceWeightedBlend_NEON(const void* prediction_0,
129                                        const void* prediction_1,
130                                        const uint8_t weight_0,
131                                        const uint8_t weight_1, const int width,
132                                        const int height, void* const dest,
133                                        const ptrdiff_t dest_stride) {
134   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
135   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
136   int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
137   // TODO(johannkoenig): Investigate the branching. May be fine to call with a
138   // variable height.
139   if (width == 4) {
140     if (height == 4) {
141       DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest,
142                                             dest_stride);
143     } else if (height == 8) {
144       DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest,
145                                             dest_stride);
146     } else {
147       assert(height == 16);
148       DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest,
149                                              dest_stride);
150     }
151     return;
152   }
153 
154   if (width == 8) {
155     switch (height) {
156       case 4:
157         DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest,
158                                               dest_stride);
159         return;
160       case 8:
161         DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest,
162                                               dest_stride);
163         return;
164       case 16:
165         DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest,
166                                                dest_stride);
167         return;
168       default:
169         assert(height == 32);
170         DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest,
171                                                dest_stride);
172 
173         return;
174     }
175   }
176 
177   DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest,
178                                   dest_stride);
179 }
180 
Init8bpp()181 void Init8bpp() {
182   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
183   assert(dsp != nullptr);
184   dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
185 }
186 
187 }  // namespace
188 
DistanceWeightedBlendInit_NEON()189 void DistanceWeightedBlendInit_NEON() { Init8bpp(); }
190 
191 }  // namespace dsp
192 }  // namespace libgav1
193 
194 #else  // !LIBGAV1_ENABLE_NEON
195 
196 namespace libgav1 {
197 namespace dsp {
198 
// Stub used when LIBGAV1_ENABLE_NEON is off: there is nothing to register, but
// the symbol is still defined so callers can invoke it unconditionally.
void DistanceWeightedBlendInit_NEON() {
  // Intentionally empty.
}
200 
201 }  // namespace dsp
202 }  // namespace libgav1
203 #endif  // LIBGAV1_ENABLE_NEON
204