• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/warp.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstddef>
25 #include <cstdint>
26 #include <cstdlib>
27 #include <type_traits>
28 
29 #include "src/dsp/arm/common_neon.h"
30 #include "src/dsp/constants.h"
31 #include "src/dsp/dsp.h"
32 #include "src/utils/common.h"
33 #include "src/utils/constants.h"
34 
35 namespace libgav1 {
36 namespace dsp {
37 namespace {
38 
// Number of extra bits of precision in warped filtering. |sx|/|sy| carry this
// many fractional bits; they are removed with RightShiftWithRounding() before
// indexing the kWarpedFilters/kWarpedFilters8 tables.
constexpr int kWarpedDiffPrecisionBits = 10;
41 
42 }  // namespace
43 
44 namespace low_bitdepth {
45 namespace {
46 
// Offset added to every horizontal-pass accumulator so that |sum| stays
// non-negative and can be treated as uint16_t (see HorizontalFilter()).
constexpr int kFirstPassOffset = 1 << 14;
// The per-sample residue of kFirstPassOffset that survives the horizontal
// rounding shift, times 128 (presumably the sum of the vertical filter taps —
// kWarpedFilters rows sum to 1 << kFilterBits). Subtracted once when seeding
// the vertical-pass accumulators to cancel the first-pass offset.
constexpr int kOffsetRemoval =
    (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
50 
51 // Applies the horizontal filter to one source row and stores the result in
52 // |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8
53 // |intermediate_result| two-dimensional array.
54 //
55 // src_row_centered contains 16 "centered" samples of a source row. (We center
56 // the samples by subtracting 128 from the samples.)
void HorizontalFilter(const int sx4, const int16_t alpha,
                      const int8x16_t src_row_centered,
                      int16_t intermediate_result_row[8]) {
  // |sx4| is the filter phase at the block center; |alpha| is the per-column
  // phase step. Start four steps to the left so filter[0] belongs to x = -4.
  int sx = sx4 - MultiplyBy4(alpha);
  int8x8_t filter[8];
  for (auto& f : filter) {
    // Drop the extra precision bits and bias into the (non-negative) index
    // range of the kWarpedFilters8 table.
    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
                       kWarpedPixelPrecisionShifts;
    f = vld1_s8(kWarpedFilters8[offset]);
    sx += alpha;
  }
  // Transpose so that filter[k] holds tap k for all eight output columns,
  // letting each vmlal_s8 below apply one tap across the whole row.
  Transpose8x8(filter);
  // Add kFirstPassOffset to ensure |sum| stays within uint16_t.
  // Add 128 (offset) * 128 (filter sum) (also 1 << 14) to account for the
  // centering of the source samples. These combined are 1 << 15 or -32768.
  int16x8_t sum =
      vdupq_n_s16(static_cast<int16_t>(kFirstPassOffset + 128 * 128));
  // Unrolled k = 0..7 loop. We need to manually unroll the loop because the
  // third argument (an index value) to vextq_s8() must be a constant
  // (immediate). src_row_window is a sliding window of length 8 into
  // src_row_centered.
  // k = 0.
  int8x8_t src_row_window = vget_low_s8(src_row_centered);
  sum = vmlal_s8(sum, filter[0], src_row_window);
  // k = 1.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 1));
  sum = vmlal_s8(sum, filter[1], src_row_window);
  // k = 2.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 2));
  sum = vmlal_s8(sum, filter[2], src_row_window);
  // k = 3.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 3));
  sum = vmlal_s8(sum, filter[3], src_row_window);
  // k = 4.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 4));
  sum = vmlal_s8(sum, filter[4], src_row_window);
  // k = 5.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 5));
  sum = vmlal_s8(sum, filter[5], src_row_window);
  // k = 6.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 6));
  sum = vmlal_s8(sum, filter[6], src_row_window);
  // k = 7.
  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 7));
  sum = vmlal_s8(sum, filter[7], src_row_window);
  // End of unrolled k = 0..7 loop.
  // Due to the offset |sum| is guaranteed to be unsigned.
  uint16x8_t sum_unsigned = vreinterpretq_u16_s16(sum);
  // Rounding shift must be done in the unsigned domain; the offset keeps the
  // value non-negative so the reinterpret is safe.
  sum_unsigned = vrshrq_n_u16(sum_unsigned, kInterRoundBitsHorizontal);
  // After the shift |sum_unsigned| will fit into int16_t.
  vst1q_s16(intermediate_result_row, vreinterpretq_s16_u16(sum_unsigned));
}
109 
// Warped (affine) motion compensation for 8-bit sources, processed in 8x8
// tiles. A horizontal pass filters 15 source rows into |intermediate_result|
// and a vertical pass produces the 8x8 output tile. When |is_compound| is
// true the output is written as int16_t prediction values (rounded by
// kInterRoundBitsCompoundVertical); otherwise as uint8_t pixels.
// |dest_stride| is in units of the destination sample type.
template <bool is_compound>
void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
               const ptrdiff_t source_stride, const int source_width,
               const int source_height,
               const int* LIBGAV1_RESTRICT const warp_params,
               const int subsampling_x, const int subsampling_y,
               const int block_start_x, const int block_start_y,
               const int block_width, const int block_height,
               const int16_t alpha, const int16_t beta, const int16_t gamma,
               const int16_t delta, void* LIBGAV1_RESTRICT dest,
               const ptrdiff_t dest_stride) {
  constexpr int kRoundBitsVertical =
      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
  // The two members are alternatives for the same storage: the general path
  // fills the 15x8 array; the region 2 fast path needs only one sample per
  // row and uses the column vector.
  union {
    // Intermediate_result is the output of the horizontal filtering and
    // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
    // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
    // type so that we can multiply it by kWarpedFilters (which has signed
    // values) using vmlal_s16().
    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
    // In the simple special cases where the samples in each row are all the
    // same, store one sample per row in a column vector.
    int16_t intermediate_result_column[15];
  };

  const auto* const src = static_cast<const uint8_t*>(source);
  using DestType =
      typename std::conditional<is_compound, int16_t, uint8_t>::type;
  auto* dst = static_cast<DestType*>(dest);

  assert(block_width >= 8);
  assert(block_height >= 8);

  // Warp process applies for each 8x8 block.
  int start_y = block_start_y;
  do {
    int start_x = block_start_x;
    do {
      const int src_x = (start_x + 4) << subsampling_x;
      const int src_y = (start_y + 4) << subsampling_y;
      const WarpFilterParams filter_params = GetWarpFilterParams(
          src_x, src_y, subsampling_x, subsampling_y, warp_params);
      // A prediction block may fall outside the frame's boundaries. If a
      // prediction block is calculated using only samples outside the frame's
      // boundary, the filtering can be simplified. We can divide the plane
      // into several regions and handle them differently.
      //
      //                |           |
      //            1   |     3     |   1
      //                |           |
      //         -------+-----------+-------
      //                |***********|
      //            2   |*****4*****|   2
      //                |***********|
      //         -------+-----------+-------
      //                |           |
      //            1   |     3     |   1
      //                |           |
      //
      // At the center, region 4 represents the frame and is the general case.
      //
      // In regions 1 and 2, the prediction block is outside the frame's
      // boundary horizontally. Therefore the horizontal filtering can be
      // simplified. Furthermore, in the region 1 (at the four corners), the
      // prediction is outside the frame's boundary both horizontally and
      // vertically, so we get a constant prediction block.
      //
      // In region 3, the prediction block is outside the frame's boundary
      // vertically. Unfortunately because we apply the horizontal filters
      // first, by the time we apply the vertical filters, they no longer see
      // simple inputs. So the only simplification is that all the rows are
      // the same, but we still need to apply all the horizontal and vertical
      // filters.

      // Check for two simple special cases, where the horizontal filter can
      // be significantly simplified.
      //
      // In general, for each row, the horizontal filter is calculated as
      // follows:
      //   for (int x = -4; x < 4; ++x) {
      //     const int offset = ...;
      //     int sum = first_pass_offset;
      //     for (int k = 0; k < 8; ++k) {
      //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
      //       sum += kWarpedFilters[offset][k] * src_row[column];
      //     }
      //     ...
      //   }
      // The column index before clipping, ix4 + x + k - 3, varies in the range
      // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
      // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
      // border index (source_width - 1 or 0, respectively). Then for each x,
      // the inner for loop of the horizontal filter is reduced to multiplying
      // the border pixel by the sum of the filter coefficients.
      if (filter_params.ix4 - 7 >= source_width - 1 ||
          filter_params.ix4 + 7 <= 0) {
        // Regions 1 and 2.
        // Points to the left or right border of the first row of |src|.
        const uint8_t* first_row_border =
            (filter_params.ix4 + 7 <= 0) ? src : src + source_width - 1;
        // In general, for y in [-7, 8), the row number iy4 + y is clipped:
        //   const int row = Clip3(iy4 + y, 0, source_height - 1);
        // In two special cases, iy4 + y is clipped to either 0 or
        // source_height - 1 for all y. In the rest of the cases, iy4 + y is
        // bounded and we can avoid clipping iy4 + y by relying on a reference
        // frame's boundary extension on the top and bottom.
        if (filter_params.iy4 - 7 >= source_height - 1 ||
            filter_params.iy4 + 7 <= 0) {
          // Region 1.
          // Every sample used to calculate the prediction block has the same
          // value. So the whole prediction block has the same value.
          const int row = (filter_params.iy4 + 7 <= 0) ? 0 : source_height - 1;
          const uint8_t row_border_pixel =
              first_row_border[row * source_stride];

          DestType* dst_row = dst + start_x - block_start_x;
          for (int y = 0; y < 8; ++y) {
            if (is_compound) {
              // Scale the pixel up to the compound-prediction precision.
              const int16x8_t sum =
                  vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
                                                   kRoundBitsVertical));
              vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
            } else {
              memset(dst_row, row_border_pixel, 8);
            }
            dst_row += dest_stride;
          }
          // End of region 1. Continue the |start_x| do-while loop.
          start_x += 8;
          continue;
        }

        // Region 2.
        // Horizontal filter.
        // The input values in this region are generated by extending the border
        // which makes them identical in the horizontal direction. This
        // computation could be inlined in the vertical pass but most
        // implementations will need a transpose of some sort.
        // It is not necessary to use the offset values here because the
        // horizontal pass is a simple shift and the vertical pass will always
        // require using 32 bits.
        for (int y = -7; y < 8; ++y) {
          // We may over-read up to 13 pixels above the top source row, or up
          // to 13 pixels below the bottom source row. This is proved in
          // warp.cc.
          const int row = filter_params.iy4 + y;
          int sum = first_row_border[row * source_stride];
          // Equivalent to filtering a constant row: pixel * filter sum,
          // followed by the horizontal rounding shift.
          sum <<= (kFilterBits - kInterRoundBitsHorizontal);
          intermediate_result_column[y + 7] = sum;
        }
        // Vertical filter.
        DestType* dst_row = dst + start_x - block_start_x;
        int sy4 = (filter_params.y4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
                  MultiplyBy4(delta);
        for (int y = 0; y < 8; ++y) {
          int sy = sy4 - MultiplyBy4(gamma);
#if defined(__aarch64__)
          // The eight taps read intermediate_result_column[y .. y + 7].
          const int16x8_t intermediate =
              vld1q_s16(&intermediate_result_column[y]);
          int16_t tmp[8];
          for (int x = 0; x < 8; ++x) {
            const int offset =
                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
                kWarpedPixelPrecisionShifts;
            const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
            const int32x4_t product_low =
                vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
            const int32x4_t product_high =
                vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
            // vaddvq_s32 is only available on __aarch64__.
            const int32_t sum =
                vaddvq_s32(product_low) + vaddvq_s32(product_high);
            const int16_t sum_descale =
                RightShiftWithRounding(sum, kRoundBitsVertical);
            if (is_compound) {
              dst_row[x] = sum_descale;
            } else {
              tmp[x] = sum_descale;
            }
            sy += gamma;
          }
          if (!is_compound) {
            // Saturating narrow clamps the result to the uint8_t pixel range.
            const int16x8_t sum = vld1q_s16(tmp);
            vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
          }
#else   // !defined(__aarch64__)
          int16x8_t filter[8];
          for (int x = 0; x < 8; ++x) {
            const int offset =
                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
                kWarpedPixelPrecisionShifts;
            filter[x] = vld1q_s16(kWarpedFilters[offset]);
            sy += gamma;
          }
          // Transpose so filter[k] holds tap k for all eight output columns.
          Transpose8x8(filter);
          int32x4_t sum_low = vdupq_n_s32(0);
          int32x4_t sum_high = sum_low;
          for (int k = 0; k < 8; ++k) {
            const int16_t intermediate = intermediate_result_column[y + k];
            sum_low =
                vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
            sum_high =
                vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
          }
          const int16x8_t sum =
              vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
                           vrshrn_n_s32(sum_high, kRoundBitsVertical));
          if (is_compound) {
            vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
          } else {
            vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
          }
#endif  // defined(__aarch64__)
          dst_row += dest_stride;
          sy4 += delta;
        }
        // End of region 2. Continue the |start_x| do-while loop.
        start_x += 8;
        continue;
      }

      // Regions 3 and 4.
      // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.

      // In general, for y in [-7, 8), the row number iy4 + y is clipped:
      //   const int row = Clip3(iy4 + y, 0, source_height - 1);
      // In two special cases, iy4 + y is clipped to either 0 or
      // source_height - 1 for all y. In the rest of the cases, iy4 + y is
      // bounded and we can avoid clipping iy4 + y by relying on a reference
      // frame's boundary extension on the top and bottom.
      if (filter_params.iy4 - 7 >= source_height - 1 ||
          filter_params.iy4 + 7 <= 0) {
        // Region 3.
        // Horizontal filter.
        // All 15 intermediate rows are produced from the same clipped source
        // row, so it is loaded and centered once outside the loop.
        const int row = (filter_params.iy4 + 7 <= 0) ? 0 : source_height - 1;
        const uint8_t* const src_row = src + row * source_stride;
        // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
        // read but is ignored.
        //
        // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
        // bytes after src_row[source_width - 1]. We assume the source frame
        // has left and right borders of at least 13 bytes that extend the
        // frame boundary pixels. We also assume there is at least one extra
        // padding byte after the right border of the last source row.
        const uint8x16_t src_row_v = vld1q_u8(&src_row[filter_params.ix4 - 7]);
        // Convert src_row_v to int8 (subtract 128).
        const int8x16_t src_row_centered =
            vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
        int sx4 = (filter_params.x4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
                  beta * 7;
        for (int y = -7; y < 8; ++y) {
          HorizontalFilter(sx4, alpha, src_row_centered,
                           intermediate_result[y + 7]);
          sx4 += beta;
        }
      } else {
        // Region 4.
        // Horizontal filter.
        int sx4 = (filter_params.x4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
                  beta * 7;
        for (int y = -7; y < 8; ++y) {
          // We may over-read up to 13 pixels above the top source row, or up
          // to 13 pixels below the bottom source row. This is proved in
          // warp.cc.
          const int row = filter_params.iy4 + y;
          const uint8_t* const src_row = src + row * source_stride;
          // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
          // read but is ignored.
          //
          // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
          // bytes after src_row[source_width - 1]. We assume the source frame
          // has left and right borders of at least 13 bytes that extend the
          // frame boundary pixels. We also assume there is at least one extra
          // padding byte after the right border of the last source row.
          const uint8x16_t src_row_v =
              vld1q_u8(&src_row[filter_params.ix4 - 7]);
          // Convert src_row_v to int8 (subtract 128).
          const int8x16_t src_row_centered =
              vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
          HorizontalFilter(sx4, alpha, src_row_centered,
                           intermediate_result[y + 7]);
          sx4 += beta;
        }
      }

      // Regions 3 and 4.
      // Vertical filter.
      DestType* dst_row = dst + start_x - block_start_x;
      int sy4 = (filter_params.y4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
                MultiplyBy4(delta);
      for (int y = 0; y < 8; ++y) {
        int sy = sy4 - MultiplyBy4(gamma);
        int16x8_t filter[8];
        for (auto& f : filter) {
          const int offset =
              RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
              kWarpedPixelPrecisionShifts;
          f = vld1q_s16(kWarpedFilters[offset]);
          sy += gamma;
        }
        // Transpose so filter[k] holds tap k for all eight output columns.
        Transpose8x8(filter);
        // Seed with -kOffsetRemoval to cancel the kFirstPassOffset that
        // HorizontalFilter() folded into every intermediate sample.
        int32x4_t sum_low = vdupq_n_s32(-kOffsetRemoval);
        int32x4_t sum_high = sum_low;
        for (int k = 0; k < 8; ++k) {
          const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
          sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
                              vget_low_s16(intermediate));
          sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
                               vget_high_s16(intermediate));
        }
        const int16x8_t sum =
            vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
                         vrshrn_n_s32(sum_high, kRoundBitsVertical));
        if (is_compound) {
          vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
        } else {
          // Saturating narrow clamps the result to the uint8_t pixel range.
          vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
        }
        dst_row += dest_stride;
        sy4 += delta;
      }
      start_x += 8;
    } while (start_x < block_start_x + block_width);
    dst += 8 * dest_stride;
    start_y += 8;
  } while (start_y < block_start_y + block_height);
}
437 
Init8bpp()438 void Init8bpp() {
439   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
440   assert(dsp != nullptr);
441   dsp->warp = Warp_NEON</*is_compound=*/false>;
442   dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
443 }
444 
445 }  // namespace
446 }  // namespace low_bitdepth
447 
448 //------------------------------------------------------------------------------
449 #if LIBGAV1_MAX_BITDEPTH >= 10
450 namespace high_bitdepth {
451 namespace {
452 
// Loads 16 consecutive uint16_t samples as a pair of q registers.
// Clang/gcc emits a single ldp for the two adjacent loads.
LIBGAV1_ALWAYS_INLINE uint16x8x2_t LoadSrcRow(uint16_t const* ptr) {
  uint16x8x2_t row;
  row.val[0] = vld1q_u16(ptr);
  row.val[1] = vld1q_u16(ptr + 8);
  return row;
}
460 
// Applies the horizontal warp filter to one 16-sample source row (10-bit
// path) and stores the 8 filtered, rounded results in
// |intermediate_result_row|. |sx4| is the filter phase at the block center
// and |alpha| is the per-column phase step.
LIBGAV1_ALWAYS_INLINE void HorizontalFilter(
    const int sx4, const int16_t alpha, const uint16x8x2_t src_row,
    int16_t intermediate_result_row[8]) {
  // Start four steps to the left so filter8[0] belongs to x = -4.
  int sx = sx4 - MultiplyBy4(alpha);
  int8x8_t filter8[8];
  for (auto& f : filter8) {
    // Drop the extra precision bits and bias into the (non-negative) index
    // range of the kWarpedFilters8 table.
    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
                       kWarpedPixelPrecisionShifts;
    f = vld1_s8(kWarpedFilters8[offset]);
    sx += alpha;
  }

  // Transpose so that filter8[k] holds tap k for all eight output columns.
  Transpose8x8(filter8);

  // Widen the 8-bit taps to 16 bits so they can be multiplied against the
  // 16-bit source samples below.
  int16x8_t filter[8];
  for (int i = 0; i < 8; ++i) {
    filter[i] = vmovl_s8(filter8[i]);
  }

  // Unrolled k = 0..7 loop: the vextq_u16() index must be an immediate.
  // src_row_window is a sliding window of length 8 into |src_row|; sums are
  // kept in 32 bits (low/high halves) to avoid overflow.
  int32x4x2_t sum;
  int16x8_t src_row_window;
  // k = 0.
  src_row_window = vreinterpretq_s16_u16(src_row.val[0]);
  sum.val[0] = vmull_s16(vget_low_s16(filter[0]), vget_low_s16(src_row_window));
  sum.val[1] = VMullHighS16(filter[0], src_row_window);
  // k = 1.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 1));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[1]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[1], src_row_window);
  // k = 2.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 2));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[2]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[2], src_row_window);
  // k = 3.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 3));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[3]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[3], src_row_window);
  // k = 4.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 4));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[4]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[4], src_row_window);
  // k = 5.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 5));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[5]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[5], src_row_window);
  // k = 6.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 6));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[6]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[6], src_row_window);
  // k = 7.
  src_row_window =
      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 7));
  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[7]),
                         vget_low_s16(src_row_window));
  sum.val[1] = VMlalHighS16(sum.val[1], filter[7], src_row_window);
  // End of unrolled k = 0..7 loop.

  // Round, narrow back to 16 bits, and store the 8 results.
  vst1_s16(intermediate_result_row,
           vrshrn_n_s32(sum.val[0], kInterRoundBitsHorizontal));
  vst1_s16(intermediate_result_row + 4,
           vrshrn_n_s32(sum.val[1], kInterRoundBitsHorizontal));
}
535 
536 template <bool is_compound>
Warp_NEON(const void * LIBGAV1_RESTRICT const source,const ptrdiff_t source_stride,const int source_width,const int source_height,const int * LIBGAV1_RESTRICT const warp_params,const int subsampling_x,const int subsampling_y,const int block_start_x,const int block_start_y,const int block_width,const int block_height,const int16_t alpha,const int16_t beta,const int16_t gamma,const int16_t delta,void * LIBGAV1_RESTRICT dest,const ptrdiff_t dest_stride)537 void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
538                const ptrdiff_t source_stride, const int source_width,
539                const int source_height,
540                const int* LIBGAV1_RESTRICT const warp_params,
541                const int subsampling_x, const int subsampling_y,
542                const int block_start_x, const int block_start_y,
543                const int block_width, const int block_height,
544                const int16_t alpha, const int16_t beta, const int16_t gamma,
545                const int16_t delta, void* LIBGAV1_RESTRICT dest,
546                const ptrdiff_t dest_stride) {
547   constexpr int kRoundBitsVertical =
548       is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
549   union {
550     // Intermediate_result is the output of the horizontal filtering and
551     // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
552     // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
553     // type so that we can multiply it by kWarpedFilters (which has signed
554     // values) using vmlal_s16().
555     int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
556     // In the simple special cases where the samples in each row are all the
557     // same, store one sample per row in a column vector.
558     int16_t intermediate_result_column[15];
559   };
560 
561   const auto* const src = static_cast<const uint16_t*>(source);
562   const ptrdiff_t src_stride = source_stride >> 1;
563   using DestType =
564       typename std::conditional<is_compound, int16_t, uint16_t>::type;
565   auto* dst = static_cast<DestType*>(dest);
566   const ptrdiff_t dst_stride = is_compound ? dest_stride : dest_stride >> 1;
567   assert(block_width >= 8);
568   assert(block_height >= 8);
569 
570   // Warp process applies for each 8x8 block.
571   int start_y = block_start_y;
572   do {
573     int start_x = block_start_x;
574     do {
575       const int src_x = (start_x + 4) << subsampling_x;
576       const int src_y = (start_y + 4) << subsampling_y;
577       const WarpFilterParams filter_params = GetWarpFilterParams(
578           src_x, src_y, subsampling_x, subsampling_y, warp_params);
579       // A prediction block may fall outside the frame's boundaries. If a
580       // prediction block is calculated using only samples outside the frame's
581       // boundary, the filtering can be simplified. We can divide the plane
582       // into several regions and handle them differently.
583       //
584       //                |           |
585       //            1   |     3     |   1
586       //                |           |
587       //         -------+-----------+-------
588       //                |***********|
589       //            2   |*****4*****|   2
590       //                |***********|
591       //         -------+-----------+-------
592       //                |           |
593       //            1   |     3     |   1
594       //                |           |
595       //
596       // At the center, region 4 represents the frame and is the general case.
597       //
598       // In regions 1 and 2, the prediction block is outside the frame's
599       // boundary horizontally. Therefore the horizontal filtering can be
600       // simplified. Furthermore, in the region 1 (at the four corners), the
601       // prediction is outside the frame's boundary both horizontally and
602       // vertically, so we get a constant prediction block.
603       //
604       // In region 3, the prediction block is outside the frame's boundary
605       // vertically. Unfortunately because we apply the horizontal filters
606       // first, by the time we apply the vertical filters, they no longer see
607       // simple inputs. So the only simplification is that all the rows are
608       // the same, but we still need to apply all the horizontal and vertical
609       // filters.
610 
611       // Check for two simple special cases, where the horizontal filter can
612       // be significantly simplified.
613       //
614       // In general, for each row, the horizontal filter is calculated as
615       // follows:
616       //   for (int x = -4; x < 4; ++x) {
617       //     const int offset = ...;
618       //     int sum = first_pass_offset;
619       //     for (int k = 0; k < 8; ++k) {
620       //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
621       //       sum += kWarpedFilters[offset][k] * src_row[column];
622       //     }
623       //     ...
624       //   }
625       // The column index before clipping, ix4 + x + k - 3, varies in the range
626       // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
627       // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
628       // border index (source_width - 1 or 0, respectively). Then for each x,
629       // the inner for loop of the horizontal filter is reduced to multiplying
630       // the border pixel by the sum of the filter coefficients.
631       if (filter_params.ix4 - 7 >= source_width - 1 ||
632           filter_params.ix4 + 7 <= 0) {
633         // Regions 1 and 2.
634         // Points to the left or right border of the first row of |src|.
635         const uint16_t* first_row_border =
636             (filter_params.ix4 + 7 <= 0) ? src : src + source_width - 1;
637         // In general, for y in [-7, 8), the row number iy4 + y is clipped:
638         //   const int row = Clip3(iy4 + y, 0, source_height - 1);
639         // In two special cases, iy4 + y is clipped to either 0 or
640         // source_height - 1 for all y. In the rest of the cases, iy4 + y is
641         // bounded and we can avoid clipping iy4 + y by relying on a reference
642         // frame's boundary extension on the top and bottom.
643         if (filter_params.iy4 - 7 >= source_height - 1 ||
644             filter_params.iy4 + 7 <= 0) {
645           // Region 1.
646           // Every sample used to calculate the prediction block has the same
647           // value. So the whole prediction block has the same value.
648           const int row = (filter_params.iy4 + 7 <= 0) ? 0 : source_height - 1;
649           const uint16_t row_border_pixel = first_row_border[row * src_stride];
650 
651           DestType* dst_row = dst + start_x - block_start_x;
652           for (int y = 0; y < 8; ++y) {
653             if (is_compound) {
654               const int16x8_t sum =
655                   vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
656                                                    kRoundBitsVertical));
657               vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
658                         vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
659             } else {
660               vst1q_u16(reinterpret_cast<uint16_t*>(dst_row),
661                         vdupq_n_u16(row_border_pixel));
662             }
663             dst_row += dst_stride;
664           }
665           // End of region 1. Continue the |start_x| do-while loop.
666           start_x += 8;
667           continue;
668         }
669 
670         // Region 2.
671         // Horizontal filter.
672         // The input values in this region are generated by extending the border
673         // which makes them identical in the horizontal direction. This
674         // computation could be inlined in the vertical pass but most
675         // implementations will need a transpose of some sort.
676         // It is not necessary to use the offset values here because the
677         // horizontal pass is a simple shift and the vertical pass will always
678         // require using 32 bits.
679         for (int y = -7; y < 8; ++y) {
680           // We may over-read up to 13 pixels above the top source row, or up
681           // to 13 pixels below the bottom source row. This is proved in
682           // warp.cc.
683           const int row = filter_params.iy4 + y;
684           int sum = first_row_border[row * src_stride];
685           sum <<= (kFilterBits - kInterRoundBitsHorizontal);
686           intermediate_result_column[y + 7] = sum;
687         }
688         // Vertical filter.
689         DestType* dst_row = dst + start_x - block_start_x;
690         int sy4 = (filter_params.y4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
691                   MultiplyBy4(delta);
692         for (int y = 0; y < 8; ++y) {
693           int sy = sy4 - MultiplyBy4(gamma);
694 #if defined(__aarch64__)
695           const int16x8_t intermediate =
696               vld1q_s16(&intermediate_result_column[y]);
697           int16_t tmp[8];
698           for (int x = 0; x < 8; ++x) {
699             const int offset =
700                 RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
701                 kWarpedPixelPrecisionShifts;
702             const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
703             const int32x4_t product_low =
704                 vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
705             const int32x4_t product_high =
706                 vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
707             // vaddvq_s32 is only available on __aarch64__.
708             const int32_t sum =
709                 vaddvq_s32(product_low) + vaddvq_s32(product_high);
710             const int16_t sum_descale =
711                 RightShiftWithRounding(sum, kRoundBitsVertical);
712             if (is_compound) {
713               dst_row[x] = sum_descale + kCompoundOffset;
714             } else {
715               tmp[x] = sum_descale;
716             }
717             sy += gamma;
718           }
719           if (!is_compound) {
720             const uint16x8_t v_max_bitdepth =
721                 vdupq_n_u16((1 << kBitdepth10) - 1);
722             const int16x8_t sum = vld1q_s16(tmp);
723             const uint16x8_t d0 =
724                 vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(sum, vdupq_n_s16(0))),
725                           v_max_bitdepth);
726             vst1q_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
727           }
728 #else   // !defined(__aarch64__)
729           int16x8_t filter[8];
730           for (int x = 0; x < 8; ++x) {
731             const int offset =
732                 RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
733                 kWarpedPixelPrecisionShifts;
734             filter[x] = vld1q_s16(kWarpedFilters[offset]);
735             sy += gamma;
736           }
737           Transpose8x8(filter);
738           int32x4_t sum_low = vdupq_n_s32(0);
739           int32x4_t sum_high = sum_low;
740           for (int k = 0; k < 8; ++k) {
741             const int16_t intermediate = intermediate_result_column[y + k];
742             sum_low =
743                 vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
744             sum_high =
745                 vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
746           }
747           if (is_compound) {
748             const int16x8_t sum =
749                 vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
750                              vrshrn_n_s32(sum_high, kRoundBitsVertical));
751             vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
752                       vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
753           } else {
754             const uint16x4_t v_max_bitdepth =
755                 vdup_n_u16((1 << kBitdepth10) - 1);
756             const uint16x4_t d0 = vmin_u16(
757                 vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
758             const uint16x4_t d1 = vmin_u16(
759                 vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
760             vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
761             vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
762           }
763 #endif  // defined(__aarch64__)
764           dst_row += dst_stride;
765           sy4 += delta;
766         }
767         // End of region 2. Continue the |start_x| do-while loop.
768         start_x += 8;
769         continue;
770       }
771 
772       // Regions 3 and 4.
773       // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
774 
775       // In general, for y in [-7, 8), the row number iy4 + y is clipped:
776       //   const int row = Clip3(iy4 + y, 0, source_height - 1);
777       // In two special cases, iy4 + y is clipped to either 0 or
778       // source_height - 1 for all y. In the rest of the cases, iy4 + y is
779       // bounded and we can avoid clipping iy4 + y by relying on a reference
780       // frame's boundary extension on the top and bottom.
781       if (filter_params.iy4 - 7 >= source_height - 1 ||
782           filter_params.iy4 + 7 <= 0) {
783         // Region 3.
784         // Horizontal filter.
785         const int row = (filter_params.iy4 + 7 <= 0) ? 0 : source_height - 1;
786         const uint16_t* const src_row = src + row * src_stride;
787         // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
788         // read but is ignored.
789         //
790         // NOTE: This may read up to 13 pixels before src_row[0] or up to 14
791         // pixels after src_row[source_width - 1]. We assume the source frame
792         // has left and right borders of at least 13 pixels that extend the
793         // frame boundary pixels. We also assume there is at least one extra
794         // padding pixel after the right border of the last source row.
795         const uint16x8x2_t src_row_v =
796             LoadSrcRow(&src_row[filter_params.ix4 - 7]);
797         int sx4 = (filter_params.x4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
798                   beta * 7;
799         for (int y = -7; y < 8; ++y) {
800           HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
801           sx4 += beta;
802         }
803       } else {
804         // Region 4.
805         // Horizontal filter.
806         int sx4 = (filter_params.x4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
807                   beta * 7;
808         for (int y = -7; y < 8; ++y) {
809           // We may over-read up to 13 pixels above the top source row, or up
810           // to 13 pixels below the bottom source row. This is proved in
811           // warp.cc.
812           const int row = filter_params.iy4 + y;
813           const uint16_t* const src_row = src + row * src_stride;
814           // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
815           // read but is ignored.
816           //
          // NOTE: This may read up to 13 pixels before src_row[0] or up to
818           // 14 pixels after src_row[source_width - 1]. We assume the source
819           // frame has left and right borders of at least 13 pixels that extend
820           // the frame boundary pixels. We also assume there is at least one
821           // extra padding pixel after the right border of the last source row.
822           const uint16x8x2_t src_row_v =
823               LoadSrcRow(&src_row[filter_params.ix4 - 7]);
824           HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
825           sx4 += beta;
826         }
827       }
828 
829       // Regions 3 and 4.
830       // Vertical filter.
831       DestType* dst_row = dst + start_x - block_start_x;
832       int sy4 = (filter_params.y4 & ((1 << kWarpedModelPrecisionBits) - 1)) -
833                 MultiplyBy4(delta);
834       for (int y = 0; y < 8; ++y) {
835         int sy = sy4 - MultiplyBy4(gamma);
836         int16x8_t filter[8];
837         for (auto& f : filter) {
838           const int offset =
839               RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
840               kWarpedPixelPrecisionShifts;
841           f = vld1q_s16(kWarpedFilters[offset]);
842           sy += gamma;
843         }
844         Transpose8x8(filter);
845         int32x4_t sum_low = vdupq_n_s32(0);
846         int32x4_t sum_high = sum_low;
847         for (int k = 0; k < 8; ++k) {
848           const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
849           sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
850                               vget_low_s16(intermediate));
851           sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
852                                vget_high_s16(intermediate));
853         }
854         if (is_compound) {
855           const int16x8_t sum =
856               vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
857                            vrshrn_n_s32(sum_high, kRoundBitsVertical));
858           vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
859                     vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
860         } else {
861           const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
862           const uint16x4_t d0 = vmin_u16(
863               vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
864           const uint16x4_t d1 = vmin_u16(
865               vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
866           vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
867           vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
868         }
869         dst_row += dst_stride;
870         sy4 += delta;
871       }
872       start_x += 8;
873     } while (start_x < block_start_x + block_width);
874     dst += 8 * dst_stride;
875     start_y += 8;
876   } while (start_y < block_start_y + block_height);
877 }
878 
Init10bpp()879 void Init10bpp() {
880   Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
881   assert(dsp != nullptr);
882   dsp->warp = Warp_NEON</*is_compound=*/false>;
883   dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
884 }
885 
886 }  // namespace
887 }  // namespace high_bitdepth
888 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
889 
// Entry point: installs the NEON warp functions for 8bpp and, when the
// build supports it, 10bpp. Called from the dsp initialization code.
void WarpInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
896 
897 }  // namespace dsp
898 }  // namespace libgav1
899 #else   // !LIBGAV1_ENABLE_NEON
900 namespace libgav1 {
901 namespace dsp {
902 
// NEON is not enabled in this build; there is nothing to register.
void WarpInit_NEON() {}
904 
905 }  // namespace dsp
906 }  // namespace libgav1
907 #endif  // LIBGAV1_ENABLE_NEON
908