1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include "test/av1_txfm_test.h"
13
14 #include <stdio.h>
15
16 #include <memory>
17 #include <new>
18
19 namespace libaom_test {
20
get_txfm1d_size(TX_SIZE tx_size)21 int get_txfm1d_size(TX_SIZE tx_size) { return tx_size_wide[tx_size]; }
22
get_txfm1d_type(TX_TYPE txfm2d_type,TYPE_TXFM * type0,TYPE_TXFM * type1)23 void get_txfm1d_type(TX_TYPE txfm2d_type, TYPE_TXFM *type0, TYPE_TXFM *type1) {
24 switch (txfm2d_type) {
25 case DCT_DCT:
26 *type0 = TYPE_DCT;
27 *type1 = TYPE_DCT;
28 break;
29 case ADST_DCT:
30 *type0 = TYPE_ADST;
31 *type1 = TYPE_DCT;
32 break;
33 case DCT_ADST:
34 *type0 = TYPE_DCT;
35 *type1 = TYPE_ADST;
36 break;
37 case ADST_ADST:
38 *type0 = TYPE_ADST;
39 *type1 = TYPE_ADST;
40 break;
41 case FLIPADST_DCT:
42 *type0 = TYPE_ADST;
43 *type1 = TYPE_DCT;
44 break;
45 case DCT_FLIPADST:
46 *type0 = TYPE_DCT;
47 *type1 = TYPE_ADST;
48 break;
49 case FLIPADST_FLIPADST:
50 *type0 = TYPE_ADST;
51 *type1 = TYPE_ADST;
52 break;
53 case ADST_FLIPADST:
54 *type0 = TYPE_ADST;
55 *type1 = TYPE_ADST;
56 break;
57 case FLIPADST_ADST:
58 *type0 = TYPE_ADST;
59 *type1 = TYPE_ADST;
60 break;
61 case IDTX:
62 *type0 = TYPE_IDTX;
63 *type1 = TYPE_IDTX;
64 break;
65 case H_DCT:
66 *type0 = TYPE_IDTX;
67 *type1 = TYPE_DCT;
68 break;
69 case V_DCT:
70 *type0 = TYPE_DCT;
71 *type1 = TYPE_IDTX;
72 break;
73 case H_ADST:
74 *type0 = TYPE_IDTX;
75 *type1 = TYPE_ADST;
76 break;
77 case V_ADST:
78 *type0 = TYPE_ADST;
79 *type1 = TYPE_IDTX;
80 break;
81 case H_FLIPADST:
82 *type0 = TYPE_IDTX;
83 *type1 = TYPE_ADST;
84 break;
85 case V_FLIPADST:
86 *type0 = TYPE_ADST;
87 *type1 = TYPE_IDTX;
88 break;
89 default:
90 *type0 = TYPE_DCT;
91 *type1 = TYPE_DCT;
92 assert(0);
93 break;
94 }
95 }
96
// sqrt(2), used as a scale factor by the reference transforms below.
double Sqrt2 = pow(2.0, 0.5);
// 1 / sqrt(2), used to scale the DC term in the reference (I)DCT.
double invSqrt2 = 1.0 / pow(2.0, 0.5);
99
dct_matrix(double n,double k,int size)100 double dct_matrix(double n, double k, int size) {
101 return cos(PI * (2 * n + 1) * k / (2 * size));
102 }
103
reference_dct_1d(const double * in,double * out,int size)104 void reference_dct_1d(const double *in, double *out, int size) {
105 for (int k = 0; k < size; ++k) {
106 out[k] = 0;
107 for (int n = 0; n < size; ++n) {
108 out[k] += in[n] * dct_matrix(n, k, size);
109 }
110 if (k == 0) out[k] = out[k] * invSqrt2;
111 }
112 }
113
reference_idct_1d(const double * in,double * out,int size)114 void reference_idct_1d(const double *in, double *out, int size) {
115 for (int k = 0; k < size; ++k) {
116 out[k] = 0;
117 for (int n = 0; n < size; ++n) {
118 if (n == 0)
119 out[k] += invSqrt2 * in[n] * dct_matrix(k, n, size);
120 else
121 out[k] += in[n] * dct_matrix(k, n, size);
122 }
123 }
124 }
125
126 // TODO(any): Copied from the old 'fadst4' (same as the new 'av1_fadst4'
127 // function). Should be replaced by a proper reference function that takes
128 // 'double' input & output.
fadst4_new(const tran_low_t * input,tran_low_t * output)129 static void fadst4_new(const tran_low_t *input, tran_low_t *output) {
130 tran_high_t x0, x1, x2, x3;
131 tran_high_t s0, s1, s2, s3, s4, s5, s6, s7;
132
133 x0 = input[0];
134 x1 = input[1];
135 x2 = input[2];
136 x3 = input[3];
137
138 if (!(x0 | x1 | x2 | x3)) {
139 output[0] = output[1] = output[2] = output[3] = 0;
140 return;
141 }
142
143 s0 = sinpi_1_9 * x0;
144 s1 = sinpi_4_9 * x0;
145 s2 = sinpi_2_9 * x1;
146 s3 = sinpi_1_9 * x1;
147 s4 = sinpi_3_9 * x2;
148 s5 = sinpi_4_9 * x3;
149 s6 = sinpi_2_9 * x3;
150 s7 = x0 + x1 - x3;
151
152 x0 = s0 + s2 + s5;
153 x1 = sinpi_3_9 * s7;
154 x2 = s1 - s3 + s6;
155 x3 = s4;
156
157 s0 = x0 + x3;
158 s1 = x1;
159 s2 = x2 - x3;
160 s3 = x2 - x0 + x3;
161
162 // 1-D transform scaling factor is sqrt(2).
163 output[0] = (tran_low_t)fdct_round_shift(s0);
164 output[1] = (tran_low_t)fdct_round_shift(s1);
165 output[2] = (tran_low_t)fdct_round_shift(s2);
166 output[3] = (tran_low_t)fdct_round_shift(s3);
167 }
168
reference_adst_1d(const double * in,double * out,int size)169 void reference_adst_1d(const double *in, double *out, int size) {
170 if (size == 4) { // Special case.
171 tran_low_t int_input[4];
172 for (int i = 0; i < 4; ++i) {
173 int_input[i] = static_cast<tran_low_t>(round(in[i]));
174 }
175 tran_low_t int_output[4];
176 fadst4_new(int_input, int_output);
177 for (int i = 0; i < 4; ++i) {
178 out[i] = int_output[i];
179 }
180 return;
181 }
182
183 for (int k = 0; k < size; ++k) {
184 out[k] = 0;
185 for (int n = 0; n < size; ++n) {
186 out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
187 }
188 }
189 }
190
reference_idtx_1d(const double * in,double * out,int size)191 void reference_idtx_1d(const double *in, double *out, int size) {
192 double scale = 0;
193 if (size == 4)
194 scale = Sqrt2;
195 else if (size == 8)
196 scale = 2;
197 else if (size == 16)
198 scale = 2 * Sqrt2;
199 else if (size == 32)
200 scale = 4;
201 else if (size == 64)
202 scale = 4 * Sqrt2;
203 for (int k = 0; k < size; ++k) {
204 out[k] = in[k] * scale;
205 }
206 }
207
reference_hybrid_1d(double * in,double * out,int size,int type)208 void reference_hybrid_1d(double *in, double *out, int size, int type) {
209 if (type == TYPE_DCT)
210 reference_dct_1d(in, out, size);
211 else if (type == TYPE_ADST)
212 reference_adst_1d(in, out, size);
213 else
214 reference_idtx_1d(in, out, size);
215 }
216
get_amplification_factor(TX_TYPE tx_type,TX_SIZE tx_size)217 double get_amplification_factor(TX_TYPE tx_type, TX_SIZE tx_size) {
218 TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
219 av1_get_fwd_txfm_cfg(tx_type, tx_size, &fwd_txfm_flip_cfg);
220 const int tx_width = tx_size_wide[fwd_txfm_flip_cfg.tx_size];
221 const int tx_height = tx_size_high[fwd_txfm_flip_cfg.tx_size];
222 const int8_t *shift = fwd_txfm_flip_cfg.shift;
223 const int amplify_bit = shift[0] + shift[1] + shift[2];
224 double amplify_factor =
225 amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit));
226
227 // For rectangular transforms, we need to multiply by an extra factor.
228 const int rect_type = get_rect_tx_log_ratio(tx_width, tx_height);
229 if (abs(rect_type) == 1) {
230 amplify_factor *= pow(2, 0.5);
231 }
232 return amplify_factor;
233 }
234
reference_hybrid_2d(double * in,double * out,TX_TYPE tx_type,TX_SIZE tx_size)235 void reference_hybrid_2d(double *in, double *out, TX_TYPE tx_type,
236 TX_SIZE tx_size) {
237 // Get transform type and size of each dimension.
238 TYPE_TXFM type0;
239 TYPE_TXFM type1;
240 get_txfm1d_type(tx_type, &type0, &type1);
241 const int tx_width = tx_size_wide[tx_size];
242 const int tx_height = tx_size_high[tx_size];
243
244 std::unique_ptr<double[]> temp_in(
245 new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
246 std::unique_ptr<double[]> temp_out(
247 new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
248 std::unique_ptr<double[]> out_interm(
249 new (std::nothrow) double[tx_width * tx_height]);
250 ASSERT_NE(temp_in, nullptr);
251 ASSERT_NE(temp_out, nullptr);
252 ASSERT_NE(out_interm, nullptr);
253 const int stride = tx_width;
254
255 // Transform columns.
256 for (int c = 0; c < tx_width; ++c) {
257 for (int r = 0; r < tx_height; ++r) {
258 temp_in[r] = in[r * stride + c];
259 }
260 reference_hybrid_1d(temp_in.get(), temp_out.get(), tx_height, type0);
261 for (int r = 0; r < tx_height; ++r) {
262 out_interm[r * stride + c] = temp_out[r];
263 }
264 }
265
266 // Transform rows.
267 for (int r = 0; r < tx_height; ++r) {
268 reference_hybrid_1d(out_interm.get() + r * stride, out + r * stride,
269 tx_width, type1);
270 }
271
272 // These transforms use an approximate 2D DCT transform, by only keeping the
273 // top-left quarter of the coefficients, and repacking them in the first
274 // quarter indices.
275 // TODO(urvang): Refactor this code.
276 if (tx_width == 64 && tx_height == 64) { // tx_size == TX_64X64
277 // Zero out top-right 32x32 area.
278 for (int row = 0; row < 32; ++row) {
279 memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
280 }
281 // Zero out the bottom 64x32 area.
282 memset(out + 32 * 64, 0, 32 * 64 * sizeof(*out));
283 // Re-pack non-zero coeffs in the first 32x32 indices.
284 for (int row = 1; row < 32; ++row) {
285 memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
286 }
287 } else if (tx_width == 32 && tx_height == 64) { // tx_size == TX_32X64
288 // Zero out the bottom 32x32 area.
289 memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out));
290 // Note: no repacking needed here.
291 } else if (tx_width == 64 && tx_height == 32) { // tx_size == TX_64X32
292 // Zero out right 32x32 area.
293 for (int row = 0; row < 32; ++row) {
294 memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
295 }
296 // Re-pack non-zero coeffs in the first 32x32 indices.
297 for (int row = 1; row < 32; ++row) {
298 memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
299 }
300 } else if (tx_width == 16 && tx_height == 64) { // tx_size == TX_16X64
301 // Zero out the bottom 16x32 area.
302 memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out));
303 // Note: no repacking needed here.
304 } else if (tx_width == 64 && tx_height == 16) { // tx_size == TX_64X16
305 // Zero out right 32x16 area.
306 for (int row = 0; row < 16; ++row) {
307 memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
308 }
309 // Re-pack non-zero coeffs in the first 32x16 indices.
310 for (int row = 1; row < 16; ++row) {
311 memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
312 }
313 }
314
315 // Apply appropriate scale.
316 const double amplify_factor = get_amplification_factor(tx_type, tx_size);
317 for (int c = 0; c < tx_width; ++c) {
318 for (int r = 0; r < tx_height; ++r) {
319 out[r * stride + c] *= amplify_factor;
320 }
321 }
322 }
323
// Mirrors a `width` x `height` block in place about its vertical axis
// (each row is reversed). `stride` is the distance between row starts.
template <typename Type>
void fliplr(Type *dest, int width, int height, int stride) {
  for (int row = 0; row < height; ++row) {
    const int base = row * stride;
    // Walk inwards from both ends of the row, swapping as we go.
    for (int lo = 0, hi = width - 1; lo < hi; ++lo, --hi) {
      const Type held = dest[base + lo];
      dest[base + lo] = dest[base + hi];
      dest[base + hi] = held;
    }
  }
}
334
// Mirrors a `width` x `height` block in place about its horizontal axis
// (each column is reversed). `stride` is the distance between row starts.
template <typename Type>
void flipud(Type *dest, int width, int height, int stride) {
  for (int col = 0; col < width; ++col) {
    // Walk inwards from the top and bottom of the column, swapping as we go.
    for (int top = 0, bottom = height - 1; top < bottom; ++top, --bottom) {
      const int upper = top * stride + col;
      const int lower = bottom * stride + col;
      const Type held = dest[upper];
      dest[upper] = dest[lower];
      dest[lower] = held;
    }
  }
}
345
// Mirrors a `width` x `height` block in place about both axes, i.e. element
// (r, c) trades places with (height - 1 - r, width - 1 - c). `stride` is the
// distance between row starts.
template <typename Type>
void fliplrud(Type *dest, int width, int height, int stride) {
  // Swap every row in the top half with its mirror row in the bottom half,
  // reversing the column order at the same time.
  for (int r = 0; r < height / 2; ++r) {
    const int top = r * stride;
    const int bottom = (height - 1 - r) * stride;
    for (int c = 0; c < width; ++c) {
      const Type tmp = dest[top + c];
      dest[top + c] = dest[bottom + width - 1 - c];
      dest[bottom + width - 1 - c] = tmp;
    }
  }

  // With an odd height the middle row is its own vertical mirror, so the loop
  // above never touches it; it still has to be reversed left-to-right.
  if (height % 2 == 1) {
    const int mid = (height / 2) * stride;
    for (int lo = 0, hi = width - 1; lo < hi; ++lo, --hi) {
      const Type tmp = dest[mid + lo];
      dest[mid + lo] = dest[mid + hi];
      dest[mid + hi] = tmp;
    }
  }
}
356
357 template void fliplr<double>(double *dest, int width, int height, int stride);
358 template void flipud<double>(double *dest, int width, int height, int stride);
359 template void fliplrud<double>(double *dest, int width, int height, int stride);
360
361 int bd_arr[BD_NUM] = { 8, 10, 12 };
362
363 int8_t low_range_arr[BD_NUM] = { 18, 32, 32 };
364 int8_t high_range_arr[BD_NUM] = { 32, 32, 32 };
365
txfm_stage_range_check(const int8_t * stage_range,int stage_num,int8_t cos_bit,int low_range,int high_range)366 void txfm_stage_range_check(const int8_t *stage_range, int stage_num,
367 int8_t cos_bit, int low_range, int high_range) {
368 for (int i = 0; i < stage_num; ++i) {
369 EXPECT_LE(stage_range[i], low_range);
370 ASSERT_LE(stage_range[i] + cos_bit, high_range) << "stage = " << i;
371 }
372 for (int i = 0; i < stage_num - 1; ++i) {
373 // make sure there is no overflow while doing half_btf()
374 ASSERT_LE(stage_range[i + 1] + cos_bit, high_range) << "stage = " << i;
375 }
376 }
377 } // namespace libaom_test
378