1 /*
2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include "tools/txfm_analyzer/txfm_graph.h"
13
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <math.h>
17
// Alias so the rest of this file can refer to the struct as plain `Node`.
typedef struct Node Node;
19
get_fun_name(char * str_fun_name,int str_buf_size,const TYPE_TXFM type,const int txfm_size)20 void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type,
21 const int txfm_size) {
22 if (type == TYPE_DCT)
23 snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size);
24 else if (type == TYPE_ADST)
25 snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size);
26 else if (type == TYPE_IDCT)
27 snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size);
28 else if (type == TYPE_IADST)
29 snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size);
30 }
31
get_txfm_type_name(char * str_fun_name,int str_buf_size,const TYPE_TXFM type,const int txfm_size)32 void get_txfm_type_name(char *str_fun_name, int str_buf_size,
33 const TYPE_TXFM type, const int txfm_size) {
34 if (type == TYPE_DCT)
35 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
36 else if (type == TYPE_ADST)
37 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
38 else if (type == TYPE_IDCT)
39 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
40 else if (type == TYPE_IADST)
41 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
42 }
43
get_hybrid_2d_type_name(char * buf,int buf_size,const TYPE_TXFM type0,const TYPE_TXFM type1,const int txfm_size0,const int txfm_size1)44 void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0,
45 const TYPE_TXFM type1, const int txfm_size0,
46 const int txfm_size1) {
47 if (type0 == TYPE_DCT && type1 == TYPE_DCT)
48 snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0);
49 else if (type0 == TYPE_DCT && type1 == TYPE_ADST)
50 snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0);
51 else if (type0 == TYPE_ADST && type1 == TYPE_ADST)
52 snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0);
53 else if (type0 == TYPE_ADST && type1 == TYPE_DCT)
54 snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0);
55 }
56
get_inv_type(TYPE_TXFM type)57 TYPE_TXFM get_inv_type(TYPE_TXFM type) {
58 if (type == TYPE_DCT)
59 return TYPE_IDCT;
60 else if (type == TYPE_ADST)
61 return TYPE_IADST;
62 else if (type == TYPE_IDCT)
63 return TYPE_DCT;
64 else if (type == TYPE_IADST)
65 return TYPE_ADST;
66 else
67 return TYPE_LAST;
68 }
69
reference_dct_1d(double * in,double * out,int size)70 void reference_dct_1d(double *in, double *out, int size) {
71 const double kInvSqrt2 = 0.707106781186547524400844362104;
72 for (int k = 0; k < size; k++) {
73 out[k] = 0; // initialize out[k]
74 for (int n = 0; n < size; n++) {
75 out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size));
76 }
77 if (k == 0) out[k] = out[k] * kInvSqrt2;
78 }
79 }
80
reference_dct_2d(double * in,double * out,int size)81 void reference_dct_2d(double *in, double *out, int size) {
82 double *tempOut = new double[size * size];
83 // dct each row: in -> out
84 for (int r = 0; r < size; r++) {
85 reference_dct_1d(in + r * size, out + r * size, size);
86 }
87
88 for (int r = 0; r < size; r++) {
89 // out ->tempOut
90 for (int c = 0; c < size; c++) {
91 tempOut[r * size + c] = out[c * size + r];
92 }
93 }
94 for (int r = 0; r < size; r++) {
95 reference_dct_1d(tempOut + r * size, out + r * size, size);
96 }
97 delete[] tempOut;
98 }
99
reference_adst_1d(double * in,double * out,int size)100 void reference_adst_1d(double *in, double *out, int size) {
101 for (int k = 0; k < size; k++) {
102 out[k] = 0; // initialize out[k]
103 for (int n = 0; n < size; n++) {
104 out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
105 }
106 }
107 }
108
reference_hybrid_2d(double * in,double * out,int size,int type0,int type1)109 void reference_hybrid_2d(double *in, double *out, int size, int type0,
110 int type1) {
111 double *tempOut = new double[size * size];
112 // dct each row: in -> out
113 for (int r = 0; r < size; r++) {
114 if (type0 == TYPE_DCT)
115 reference_dct_1d(in + r * size, out + r * size, size);
116 else
117 reference_adst_1d(in + r * size, out + r * size, size);
118 }
119
120 for (int r = 0; r < size; r++) {
121 // out ->tempOut
122 for (int c = 0; c < size; c++) {
123 tempOut[r * size + c] = out[c * size + r];
124 }
125 }
126 for (int r = 0; r < size; r++) {
127 if (type1 == TYPE_DCT)
128 reference_dct_1d(tempOut + r * size, out + r * size, size);
129 else
130 reference_adst_1d(tempOut + r * size, out + r * size, size);
131 }
132 delete[] tempOut;
133 }
134
reference_hybrid_2d_new(double * in,double * out,int size0,int size1,int type0,int type1)135 void reference_hybrid_2d_new(double *in, double *out, int size0, int size1,
136 int type0, int type1) {
137 double *tempOut = new double[size0 * size1];
138 // dct each row: in -> out
139 for (int r = 0; r < size1; r++) {
140 if (type0 == TYPE_DCT)
141 reference_dct_1d(in + r * size0, out + r * size0, size0);
142 else
143 reference_adst_1d(in + r * size0, out + r * size0, size0);
144 }
145
146 for (int r = 0; r < size1; r++) {
147 // out ->tempOut
148 for (int c = 0; c < size0; c++) {
149 tempOut[c * size1 + r] = out[r * size0 + c];
150 }
151 }
152 for (int r = 0; r < size0; r++) {
153 if (type1 == TYPE_DCT)
154 reference_dct_1d(tempOut + r * size1, out + r * size1, size1);
155 else
156 reference_adst_1d(tempOut + r * size1, out + r * size1, size1);
157 }
158 delete[] tempOut;
159 }
160
// Returns the zero-based position of the highest set bit of `x`
// (e.g. 1 -> 0, 8 -> 3).
// For x == 0 the internal counter stays at -1, which is converted to the
// unsigned return type (i.e. UINT_MAX).
unsigned int get_max_bit(unsigned int x) {
  int highest = -1;
  for (; x != 0; x >>= 1) {
    ++highest;
  }
  return highest;
}
169
// Reverses the order of the low (max_bit + 1) bits of `x`.
// The full 32-bit word is reversed by swapping halves, bytes, nibbles, bit
// pairs and single bits in turn; the result is then shifted right by
// (31 - max_bit) so that the reversed field lands in the low bits.
// `max_bit` is expected to be in [0, 31].
unsigned int bitwise_reverse(unsigned int x, int max_bit) {
  static const unsigned int kSwapMask[5] = { 0x0000ffff, 0x00ff00ff,
                                             0x0f0f0f0f, 0x33333333,
                                             0x55555555 };
  int width = 16;
  for (int step = 0; step < 5; ++step) {
    const unsigned int mask = kSwapMask[step];
    x = ((x >> width) & mask) | ((x & mask) << width);
    width >>= 1;
  }
  return x >> (31 - max_bit);
}
179
// Flattens a (row, column) pair into a linear index for a row-major table
// with `cSize` columns.
int get_idx(int ri, int ci, int cSize) {
  const int row_start = ri * cSize;
  return row_start + ci;
}
181
add_node(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int in,double w)182 void add_node(Node *node, int stage_num, int node_num, int stage_idx,
183 int node_idx, int in, double w) {
184 int outIdx = get_idx(stage_idx, node_idx, node_num);
185 int inIdx = get_idx(stage_idx - 1, in, node_num);
186 int idx = node[outIdx].inNodeNum;
187 if (idx < 2) {
188 node[outIdx].inNode[idx] = &node[inIdx];
189 node[outIdx].inNodeIdx[idx] = in;
190 node[outIdx].inWeight[idx] = w;
191 idx++;
192 node[outIdx].inNodeNum = idx;
193 } else {
194 printf("Error: inNode is full");
195 }
196 }
197
connect_node(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int in0,double w0,int in1,double w1)198 void connect_node(Node *node, int stage_num, int node_num, int stage_idx,
199 int node_idx, int in0, double w0, int in1, double w1) {
200 int outIdx = get_idx(stage_idx, node_idx, node_num);
201 int inIdx0 = get_idx(stage_idx - 1, in0, node_num);
202 int inIdx1 = get_idx(stage_idx - 1, in1, node_num);
203
204 int idx = 0;
205 // if(w0 != 0) {
206 node[outIdx].inNode[idx] = &node[inIdx0];
207 node[outIdx].inNodeIdx[idx] = in0;
208 node[outIdx].inWeight[idx] = w0;
209 idx++;
210 //}
211
212 // if(w1 != 0) {
213 node[outIdx].inNode[idx] = &node[inIdx1];
214 node[outIdx].inNodeIdx[idx] = in1;
215 node[outIdx].inWeight[idx] = w1;
216 idx++;
217 //}
218
219 node[outIdx].inNodeNum = idx;
220 }
221
propagate(Node * node,int stage_num,int node_num,int stage_idx)222 void propagate(Node *node, int stage_num, int node_num, int stage_idx) {
223 for (int ni = 0; ni < node_num; ni++) {
224 int outIdx = get_idx(stage_idx, ni, node_num);
225 node[outIdx].value = 0;
226 for (int k = 0; k < node[outIdx].inNodeNum; k++) {
227 node[outIdx].value +=
228 node[outIdx].inNode[k]->value * node[outIdx].inWeight[k];
229 }
230 }
231 }
232
// Shifts `value` by `bit` positions with rounding.
//   bit > 0 : right shift by `bit`, rounding half away from zero
//             (negative values are handled by negating, rounding, negating).
//   bit <= 0: scales up by 2^(-bit); no rounding is needed.
// The rounding offset is built in 64 bits so that bit >= 32 does not overflow
// an int, and the scale-up is done by multiplication because left-shifting a
// negative value is undefined behaviour.
int64_t round_shift(int64_t value, int bit) {
  if (bit <= 0) {
    return value * ((int64_t)1 << (-bit));
  }
  if (value < 0) {
    return -round_shift(-value, bit);
  }
  const int64_t half = (int64_t)1 << (bit - 1);
  return (value + half) >> bit;
}
244
round_shift_array(int32_t * arr,int size,int bit)245 void round_shift_array(int32_t *arr, int size, int bit) {
246 if (bit == 0) {
247 return;
248 } else {
249 for (int i = 0; i < size; i++) {
250 arr[i] = round_shift(arr[i], bit);
251 }
252 }
253 }
254
graph_reset_visited(Node * node,int stage_num,int node_num)255 void graph_reset_visited(Node *node, int stage_num, int node_num) {
256 for (int si = 0; si < stage_num; si++) {
257 for (int ni = 0; ni < node_num; ni++) {
258 int idx = get_idx(si, ni, node_num);
259 node[idx].visited = 0;
260 }
261 }
262 }
263
estimate_value(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int estimate_bit)264 void estimate_value(Node *node, int stage_num, int node_num, int stage_idx,
265 int node_idx, int estimate_bit) {
266 if (stage_idx > 0) {
267 int outIdx = get_idx(stage_idx, node_idx, node_num);
268 int64_t out = 0;
269 node[outIdx].value = 0;
270 for (int k = 0; k < node[outIdx].inNodeNum; k++) {
271 int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit));
272 int64_t v = round(node[outIdx].inNode[k]->value);
273 out += v * w;
274 }
275 node[outIdx].value = round_shift(out, estimate_bit);
276 }
277 }
278
amplify_value(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int amplify_bit)279 void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
280 int node_idx, int amplify_bit) {
281 int outIdx = get_idx(stage_idx, node_idx, node_num);
282 node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit);
283 }
284
propagate_estimate_amlify(Node * node,int stage_num,int node_num,int stage_idx,int amplify_bit,int estimate_bit)285 void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
286 int stage_idx, int amplify_bit,
287 int estimate_bit) {
288 for (int ni = 0; ni < node_num; ni++) {
289 estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit);
290 amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit);
291 }
292 }
293
init_graph(Node * node,int stage_num,int node_num)294 void init_graph(Node *node, int stage_num, int node_num) {
295 for (int si = 0; si < stage_num; si++) {
296 for (int ni = 0; ni < node_num; ni++) {
297 int outIdx = get_idx(si, ni, node_num);
298 node[outIdx].stageIdx = si;
299 node[outIdx].nodeIdx = ni;
300 node[outIdx].value = 0;
301 node[outIdx].inNodeNum = 0;
302 if (si >= 1) {
303 connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0);
304 }
305 }
306 }
307 }
308
gen_B_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N,int star)309 void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
310 int node_idx, int N, int star) {
311 for (int i = 0; i < N / 2; i++) {
312 int out = node_idx + i;
313 int in1 = node_idx + N - 1 - i;
314 if (star == 1) {
315 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
316 1);
317 } else {
318 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
319 1);
320 }
321 }
322 for (int i = N / 2; i < N; i++) {
323 int out = node_idx + i;
324 int in1 = node_idx + N - 1 - i;
325 if (star == 1) {
326 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
327 1);
328 } else {
329 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
330 1);
331 }
332 }
333 }
334
gen_P_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)335 void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
336 int node_idx, int N) {
337 int max_bit = get_max_bit(N - 1);
338 for (int i = 0; i < N; i++) {
339 int out = node_idx + bitwise_reverse(i, max_bit);
340 int in = node_idx + i;
341 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
342 }
343 }
344
gen_type1_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)345 void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
346 int node_idx, int N) {
347 int max_bit = get_max_bit(N);
348 for (int ni = 0; ni < N / 2; ni++) {
349 int ai = bitwise_reverse(N + ni, max_bit);
350 int out = node_idx + ni;
351 int in1 = node_idx + N - ni - 1;
352 connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
353 sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N)));
354 }
355 for (int ni = N / 2; ni < N; ni++) {
356 int ai = bitwise_reverse(N + ni, max_bit);
357 int out = node_idx + ni;
358 int in1 = node_idx + N - ni - 1;
359 connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
360 cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N)));
361 }
362 }
363
gen_type2_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)364 void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
365 int node_idx, int N) {
366 for (int ni = 0; ni < N / 4; ni++) {
367 int out = node_idx + ni;
368 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
369 }
370
371 for (int ni = N / 4; ni < N / 2; ni++) {
372 int out = node_idx + ni;
373 int in1 = node_idx + N - ni - 1;
374 connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
375 -cos(PI / 4), in1, cos(-PI / 4));
376 }
377
378 for (int ni = N / 2; ni < N * 3 / 4; ni++) {
379 int out = node_idx + ni;
380 int in1 = node_idx + N - ni - 1;
381 connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
382 cos(-PI / 4), in1, cos(PI / 4));
383 }
384
385 for (int ni = N * 3 / 4; ni < N; ni++) {
386 int out = node_idx + ni;
387 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
388 }
389 }
390
// Wires N consecutive nodes of stage (stage_idx + 1), starting at node_idx.
// The N nodes are processed in groups of N_over_i = 2 << (idx / 2) nodes.
// Each group is split into four equal segments of N_over_i / 4 nodes:
//   segment 0 and 3: pass-through of the same-position node;
//   segment 1 and 2: a weighted pair made of the same-position node and the
//                    mirrored node (N - position - 1) of the previous stage,
//                    with cos/sin weights of the angle kj * PI / i.
// Groups in the first half of the N nodes and groups in the second half use
// different sign patterns for those weights (see the per-segment comments).
// kj is obtained by bit-reversing (i / 4 + j), where j is the group number.
void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int idx, int N) {
  // TODO(angiebird): Simplify and clarify this function

  // Denominator of the rotation angle; halves every two steps of idx.
  int i = 2 * N / (1 << (idx / 2));
  int max_bit =
      get_max_bit(i / 2) - 1;  // the max_bit counts on i/2 instead of N here
  // Number of nodes per group; doubles every two steps of idx.
  int N_over_i = 2 << (idx / 2);

  // Groups in the first half of the N nodes.
  for (int nj = 0; nj < N / 2; nj += N_over_i) {
    int j = nj / (N_over_i);
    int kj = bitwise_reverse(i / 4 + j, max_bit);
    // printf("kj = %d\n", kj);

    // I_N/2i --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // -C_Kj/i --- S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -cos(kj * PI / i);
      // Mirrored position within the N nodes.
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i --- -C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = -cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }

  // Groups in the second half of the N nodes: same layout, different signs.
  for (int nj = N / 2; nj < N; nj += N_over_i) {
    int j = nj / N_over_i;
    int kj = bitwise_reverse(i / 4 + j, max_bit);

    // I_N/2i --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // C_kj/i --- -S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = cos(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = -sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i --- C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }
}
491
gen_type4_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int idx,int N)492 void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
493 int node_idx, int idx, int N) {
494 int B_size = 1 << ((idx + 1) / 2);
495 for (int ni = 0; ni < N; ni += B_size) {
496 gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size,
497 (ni / B_size) % 2);
498 }
499 }
500
gen_R_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)501 void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
502 int node_idx, int N) {
503 int max_idx = 2 * (get_max_bit(N) + 1) - 3;
504 for (int idx = 0; idx < max_idx; idx++) {
505 int s = stage_idx + max_idx - idx - 1;
506 if (idx == 0) {
507 // type 1
508 gen_type1_graph(node, stage_num, node_num, s, node_idx, N);
509 } else if (idx == max_idx - 1) {
510 // type 2
511 gen_type2_graph(node, stage_num, node_num, s, node_idx, N);
512 } else if ((idx + 1) % 2 == 0) {
513 // type 4
514 gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N);
515 } else if ((idx + 1) % 2 == 1) {
516 // type 3
517 gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N);
518 } else {
519 printf("check gen_R_graph()\n");
520 }
521 }
522 }
523
gen_DCT_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)524 void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
525 int node_idx, int N) {
526 if (N > 2) {
527 gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0);
528 gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2);
529 gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2,
530 N / 2);
531 } else {
532 // generate dct_2
533 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
534 cos(PI / 4), node_idx + 1, cos(PI / 4));
535 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
536 node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4));
537 }
538 }
539
get_dct_stage_num(int size)540 int get_dct_stage_num(int size) { return 2 * get_max_bit(size); }
541
gen_DCT_graph_1d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)542 void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
543 int node_idx, int dct_node_num) {
544 gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num);
545 int dct_stage_num = get_dct_stage_num(dct_node_num);
546 gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2,
547 node_idx, dct_node_num);
548 }
549
gen_adst_B_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx)550 void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
551 int node_idx, int adst_idx) {
552 int size = 1 << (adst_idx + 1);
553 for (int ni = 0; ni < size / 2; ni++) {
554 int nOut = node_idx + ni;
555 int nIn = nOut + size / 2;
556 connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn,
557 1);
558 // printf("nOut: %d nIn: %d\n", nOut, nIn);
559 }
560 for (int ni = size / 2; ni < size; ni++) {
561 int nOut = node_idx + ni;
562 int nIn = nOut - size / 2;
563 connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn,
564 1);
565 // printf("ndctOut: %d nIn: %d\n", nOut, nIn);
566 }
567 }
568
gen_adst_U_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx,int adst_node_num)569 void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
570 int node_idx, int adst_idx, int adst_node_num) {
571 int size = 1 << (adst_idx + 1);
572 for (int ni = 0; ni < adst_node_num; ni += size) {
573 gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
574 adst_idx);
575 }
576 }
577
gen_adst_T_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,double freq)578 void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
579 int node_idx, double freq) {
580 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
581 cos(freq * PI), node_idx + 1, sin(freq * PI));
582 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
583 node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI));
584 }
585
gen_adst_E_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx)586 void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
587 int node_idx, int adst_idx) {
588 int size = 1 << (adst_idx);
589 for (int i = 0; i < size / 2; i++) {
590 int ni = i * 2;
591 double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1));
592 gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
593 }
594 }
595
gen_adst_V_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx,int adst_node_num)596 void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
597 int node_idx, int adst_idx, int adst_node_num) {
598 int size = 1 << (adst_idx);
599 for (int i = 0; i < adst_node_num / size; i++) {
600 if (i % 2 == 1) {
601 int ni = i * size;
602 gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
603 adst_idx);
604 }
605 }
606 }
gen_adst_VJ_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)607 void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
608 int node_idx, int adst_node_num) {
609 for (int i = 0; i < adst_node_num / 2; i++) {
610 int ni = i * 2;
611 double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num);
612 gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
613 }
614 }
gen_adst_Q_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)615 void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
616 int node_idx, int adst_node_num) {
617 // reverse order when idx is 1, 3, 5, 7 ...
618 // example of adst_node_num = 8:
619 // 0 1 2 3 4 5 6 7
620 // --> 0 7 2 5 4 3 6 1
621 for (int ni = 0; ni < adst_node_num; ni++) {
622 if (ni % 2 == 0) {
623 int out = node_idx + ni;
624 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out,
625 0);
626 } else {
627 int out = node_idx + ni;
628 int in = node_idx + adst_node_num - ni;
629 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
630 }
631 }
632 }
gen_adst_Ibar_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)633 void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
634 int node_idx, int adst_node_num) {
635 // reverse order
636 // 0 1 2 3 --> 3 2 1 0
637 for (int ni = 0; ni < adst_node_num; ni++) {
638 int out = node_idx + ni;
639 int in = node_idx + adst_node_num - ni - 1;
640 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
641 }
642 }
643
// Returns the input position feeding output position `out` in the Q
// permutation: even positions map to themselves, any other position maps to
// adst_node_num - out.
int get_Q_out2in(int adst_node_num, int out) {
  return (out % 2 == 0) ? out : adst_node_num - out;
}
653
// Returns the input position feeding output position `out` in the order
// reversal: adst_node_num - out - 1.
int get_Ibar_out2in(int adst_node_num, int out) {
  const int distance_from_end = adst_node_num - out;
  return distance_from_end - 1;
}
657
gen_adst_IbarQ_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)658 void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num,
659 int stage_idx, int node_idx, int adst_node_num) {
660 // in -> Ibar -> Q -> out
661 for (int ni = 0; ni < adst_node_num; ni++) {
662 int out = node_idx + ni;
663 int in = node_idx +
664 get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni));
665 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
666 }
667 }
668
gen_adst_D_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)669 void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
670 int node_idx, int adst_node_num) {
671 // reverse order
672 for (int ni = 0; ni < adst_node_num; ni++) {
673 int out = node_idx + ni;
674 int in = out;
675 if (ni % 2 == 0) {
676 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
677 } else {
678 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in,
679 0);
680 }
681 }
682 }
683
get_hadamard_idx(int x,int adst_node_num)684 int get_hadamard_idx(int x, int adst_node_num) {
685 int max_bit = get_max_bit(adst_node_num - 1);
686 x = bitwise_reverse(x, max_bit);
687
688 // gray code
689 int c = x & 1;
690 int p = x & 1;
691 int y = c;
692
693 for (int i = 1; i <= max_bit; i++) {
694 p = c;
695 c = (x >> i) & 1;
696 y += (c ^ p) << i;
697 }
698 return y;
699 }
700
gen_adst_Ht_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)701 void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
702 int node_idx, int adst_node_num) {
703 for (int ni = 0; ni < adst_node_num; ni++) {
704 int out = node_idx + ni;
705 int in = node_idx + get_hadamard_idx(ni, adst_node_num);
706 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
707 }
708 }
709
gen_adst_HtD_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)710 void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx,
711 int node_idx, int adst_node_num) {
712 for (int ni = 0; ni < adst_node_num; ni++) {
713 int out = node_idx + ni;
714 int in = node_idx + get_hadamard_idx(ni, adst_node_num);
715 double inW;
716 if (ni % 2 == 0)
717 inW = 1;
718 else
719 inW = -1;
720 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0);
721 }
722 }
723
get_adst_stage_num(int adst_node_num)724 int get_adst_stage_num(int adst_node_num) {
725 return 2 * get_max_bit(adst_node_num) + 2;
726 }
727
gen_iadst_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)728 int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
729 int node_idx, int adst_node_num) {
730 int max_bit = get_max_bit(adst_node_num);
731 int si = 0;
732 gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
733 adst_node_num);
734 si++;
735 gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
736 adst_node_num);
737 si++;
738 for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) {
739 gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx,
740 adst_idx, adst_node_num);
741 si++;
742 gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx,
743 adst_idx, adst_node_num);
744 si++;
745 }
746 gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx,
747 adst_node_num);
748 si++;
749 return si + 1;
750 }
751
gen_adst_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)752 int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
753 int node_idx, int adst_node_num) {
754 int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num);
755 // generate a adst tempNode
756 Node *tempNode = new Node[hybrid_stage_num * adst_node_num];
757 init_graph(tempNode, hybrid_stage_num, adst_node_num);
758 int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0,
759 adst_node_num);
760
761 // tempNode's inverse graph to node[stage_idx][node_idx]
762 gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num,
763 node_num, stage_idx, node_idx);
764 delete[] tempNode;
765 return si;
766 }
767
connect_layer_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)768 void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
769 int node_idx, int dct_node_num) {
770 for (int first = 0; first < dct_node_num; first++) {
771 for (int second = 0; second < dct_node_num; second++) {
772 // int sIn = stage_idx;
773 int sOut = stage_idx + 1;
774 int nIn = node_idx + first * dct_node_num + second;
775 int nOut = node_idx + second * dct_node_num + first;
776
777 // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
778
779 connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
780 }
781 }
782 }
783
connect_layer_2d_new(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num0,int dct_node_num1)784 void connect_layer_2d_new(Node *node, int stage_num, int node_num,
785 int stage_idx, int node_idx, int dct_node_num0,
786 int dct_node_num1) {
787 for (int i = 0; i < dct_node_num1; i++) {
788 for (int j = 0; j < dct_node_num0; j++) {
789 // int sIn = stage_idx;
790 int sOut = stage_idx + 1;
791 int nIn = node_idx + i * dct_node_num0 + j;
792 int nOut = node_idx + j * dct_node_num1 + i;
793
794 // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
795
796 connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
797 }
798 }
799 }
800
gen_DCT_graph_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)801 void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
802 int node_idx, int dct_node_num) {
803 int dct_stage_num = get_dct_stage_num(dct_node_num);
804 // put 2 layers of dct_node_num DCTs on the graph
805 for (int ni = 0; ni < dct_node_num; ni++) {
806 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx,
807 node_idx + ni * dct_node_num, dct_node_num);
808 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num,
809 node_idx + ni * dct_node_num, dct_node_num);
810 }
811 // connect first layer and second layer
812 connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1,
813 node_idx, dct_node_num);
814 }
815
get_hybrid_stage_num(int type,int hybrid_node_num)816 int get_hybrid_stage_num(int type, int hybrid_node_num) {
817 if (type == TYPE_DCT || type == TYPE_IDCT) {
818 return get_dct_stage_num(hybrid_node_num);
819 } else if (type == TYPE_ADST || type == TYPE_IADST) {
820 return get_adst_stage_num(hybrid_node_num);
821 }
822 return 0;
823 }
824
get_hybrid_2d_stage_num(int type0,int type1,int hybrid_node_num)825 int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) {
826 int stage_num = 0;
827 stage_num += get_hybrid_stage_num(type0, hybrid_node_num);
828 stage_num += get_hybrid_stage_num(type1, hybrid_node_num);
829 return stage_num;
830 }
831
get_hybrid_2d_stage_num_new(int type0,int type1,int hybrid_node_num0,int hybrid_node_num1)832 int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
833 int hybrid_node_num1) {
834 int stage_num = 0;
835 stage_num += get_hybrid_stage_num(type0, hybrid_node_num0);
836 stage_num += get_hybrid_stage_num(type1, hybrid_node_num1);
837 return stage_num;
838 }
839
get_hybrid_amplify_factor(int type,int hybrid_node_num)840 int get_hybrid_amplify_factor(int type, int hybrid_node_num) {
841 return get_max_bit(hybrid_node_num) - 1;
842 }
843
gen_hybrid_graph_1d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num,int type)844 void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
845 int node_idx, int hybrid_node_num, int type) {
846 if (type == TYPE_DCT) {
847 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx,
848 hybrid_node_num);
849 } else if (type == TYPE_ADST) {
850 gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx,
851 hybrid_node_num);
852 } else if (type == TYPE_IDCT) {
853 int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
854 // generate a dct tempNode
855 Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
856 init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
857 gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
858 hybrid_node_num);
859
860 // tempNode's inverse graph to node[stage_idx][node_idx]
861 gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
862 node_num, stage_idx, node_idx);
863 delete[] tempNode;
864 } else if (type == TYPE_IADST) {
865 int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
866 // generate a adst tempNode
867 Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
868 init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
869 gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
870 hybrid_node_num);
871
872 // tempNode's inverse graph to node[stage_idx][node_idx]
873 gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
874 node_num, stage_idx, node_idx);
875 delete[] tempNode;
876 }
877 }
878
gen_hybrid_graph_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num,int type0,int type1)879 void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
880 int node_idx, int hybrid_node_num, int type0,
881 int type1) {
882 int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num);
883
884 for (int ni = 0; ni < hybrid_node_num; ni++) {
885 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
886 node_idx + ni * hybrid_node_num, hybrid_node_num,
887 type0);
888 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num,
889 node_idx + ni * hybrid_node_num, hybrid_node_num,
890 type1);
891 }
892
893 // connect first layer and second layer
894 connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1,
895 node_idx, hybrid_node_num);
896 }
897
gen_hybrid_graph_2d_new(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num0,int hybrid_node_num1,int type0,int type1)898 void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
899 int stage_idx, int node_idx, int hybrid_node_num0,
900 int hybrid_node_num1, int type0, int type1) {
901 int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0);
902
903 for (int ni = 0; ni < hybrid_node_num1; ni++) {
904 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
905 node_idx + ni * hybrid_node_num0, hybrid_node_num0,
906 type0);
907 }
908 for (int ni = 0; ni < hybrid_node_num0; ni++) {
909 gen_hybrid_graph_1d(
910 node, stage_num, node_num, stage_idx + hybrid_stage_num0,
911 node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1);
912 }
913
914 // connect first layer and second layer
915 connect_layer_2d_new(node, stage_num, node_num,
916 stage_idx + hybrid_stage_num0 - 1, node_idx,
917 hybrid_node_num0, hybrid_node_num1);
918 }
919
gen_inv_graph(Node * node,int stage_num,int node_num,Node * invNode,int inv_stage_num,int inv_node_num,int inv_stage_idx,int inv_node_idx)920 void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
921 int inv_stage_num, int inv_node_num, int inv_stage_idx,
922 int inv_node_idx) {
923 // clean up inNodeNum in invNode because of add_node
924 for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) {
925 for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) {
926 int idx = get_idx(si, ni, inv_node_num);
927 invNode[idx].inNodeNum = 0;
928 }
929 }
930 // generate inverse graph of node on invNode
931 for (int si = 1; si < stage_num; si++) {
932 for (int ni = 0; ni < node_num; ni++) {
933 int invSi = stage_num - si;
934 int idx = get_idx(si, ni, node_num);
935 for (int k = 0; k < node[idx].inNodeNum; k++) {
936 int invNi = node[idx].inNodeIdx[k];
937 add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx,
938 invNi + inv_node_idx, ni + inv_node_idx,
939 node[idx].inWeight[k]);
940 }
941 }
942 }
943 }
944