• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * function: kernel_3d_denoise_slm
3 *     3D Noise Reduction
4 * gain:        The parameter determines the filtering strength for the reference block
5 * threshold:   Noise variances of observed image
6 * restoredPrev: The previous restored image, image2d_t as read only
7 * output:      restored image, image2d_t as write only
8 * input:       observed image, image2d_t as read only
9 * inputPrev1:  reference image, image2d_t as read only
10 * inputPrev2:  reference image, image2d_t as read only
11 */
12
/*
 * Number of frames taken into account when ENABLE_IIR_FILERING is 0:
 * values > 1 add inputPrev1 and values > 2 add inputPrev2 as references.
 * May be overridden from the build options.
 */
#ifndef REFERENCE_FRAME_COUNT
#define REFERENCE_FRAME_COUNT 2
#endif

/*
 * When non-zero the previously restored frame (restoredPrev) is used as the
 * only extra reference and REFERENCE_FRAME_COUNT is ignored.
 * The identifier is misspelled ("FILERING"); it is kept as is because it is
 * an externally settable build flag -- presumably the host passes it via -D.
 */
#ifndef ENABLE_IIR_FILERING
#define ENABLE_IIR_FILERING 1
#endif

/*
 * Work-group shape assumed by the cooperative loads below.
 * The code uses one work-item per block column, i.e. it relies on
 * WORK_GROUP_WIDTH == WORK_BLOCK_WIDTH and WORK_GROUP_HEIGHT == 1.
 */
#define WORK_GROUP_WIDTH    8
#define WORK_GROUP_HEIGHT   1

/*
 * Size, in texels, of the block restored by one work-group.
 * weighted_average() and the kernel are written for exactly 8 rows.
 */
#define WORK_BLOCK_WIDTH    8
#define WORK_BLOCK_HEIGHT   8

/* Extra texels loaded on each side of the block for the candidate search. */
#define REF_BLOCK_X_OFFSET  1
#define REF_BLOCK_Y_OFFSET  4

/* Size, in texels, of the reference tile held in local memory. */
#define REF_BLOCK_WIDTH  (WORK_BLOCK_WIDTH + 2 * REF_BLOCK_X_OFFSET)
#define REF_BLOCK_HEIGHT (WORK_BLOCK_HEIGHT + 2 * REF_BLOCK_Y_OFFSET)
32
33
/*
 * accumulate_candidates
 *     Helper of weighted_average(). For one 4-row half of this work-item's
 *     column it visits 3 x 3 candidate positions in ref_cache (three 4-row
 *     bands starting at first_band, three horizontal shifts), weights each
 *     candidate by native_exp(-gain * squared distance to the observed
 *     texels) and adds the weighted candidate texels to restore[0..3].
 *
 * ref_cache:     reference tile, REF_BLOCK_WIDTH texels per row
 * observe_cache: observed block, WORK_BLOCK_WIDTH texels per row
 * first_band:    index of the first 4-row band of ref_cache to visit
 * observe_row:   first of the four observed rows to compare against
 * restore:       four accumulators, one per row of the half-column
 * weight_sum:    running sum of weights on entry
 * gain:          in/out. Doubled in place whenever a candidate's mean
 *                gradient is not below threshold; the doubled value is then
 *                used for every later candidate as well.
 *                NOTE(review): this cumulative doubling is preserved from the
 *                existing behaviour; confirm it is intended rather than a
 *                per-candidate adjustment.
 * threshold:     gradient level from which gain is doubled
 *
 * Returns weight_sum plus the weights of the nine candidates, added one by
 * one in visiting order.
 */
inline float accumulate_candidates (__local float4* ref_cache,
                                    __local float4* observe_cache,
                                    int first_band,
                                    int observe_row,
                                    float4* restore,
                                    float weight_sum,
                                    float* gain,
                                    float threshold)
{
    const int local_id_x = get_local_id(0);

    /* The observed texels do not depend on the candidate: read them once. */
    const float4 observe0 = observe_cache[(observe_row + 0) * WORK_BLOCK_WIDTH + local_id_x];
    const float4 observe1 = observe_cache[(observe_row + 1) * WORK_BLOCK_WIDTH + local_id_x];
    const float4 observe2 = observe_cache[(observe_row + 2) * WORK_BLOCK_WIDTH + local_id_x];
    const float4 observe3 = observe_cache[(observe_row + 3) * WORK_BLOCK_WIDTH + local_id_x];

#pragma unroll
    for (int band = first_band; band < first_band + 3; band++) {
#pragma unroll
        for (int shift = 0; shift < 3; shift++) {
            /* Top texel of the candidate: bands are 4 rows apart. */
            const int top = mad24(band, 4 * REF_BLOCK_WIDTH, local_id_x + shift);

            const float4 ref0 = ref_cache[top];
            const float4 ref1 = ref_cache[top + REF_BLOCK_WIDTH];
            const float4 ref2 = ref_cache[top + 2 * REF_BLOCK_WIDTH];
            const float4 ref3 = ref_cache[top + 3 * REF_BLOCK_WIDTH];

            /* Per-component squared distance between candidate and observation. */
            float4 diff = ref0 - observe0;
            float4 dist = diff * diff;
            diff = ref1 - observe1;
            dist = mad(diff, diff, dist);
            diff = ref2 - observe2;
            dist = mad(diff, diff, dist);
            diff = ref3 - observe3;
            dist = mad(diff, diff, dist);

            /*
             * Mean difference between component s2 of the second row and all
             * 16 components of the candidate; that component differs from
             * itself by zero, hence the division by 15.
             */
            float4 gradient = (float4)(ref1.s2, ref1.s2, ref1.s2, ref1.s2);
            gradient = (gradient - ref0) +
                       (gradient - ref1) +
                       (gradient - ref2) +
                       (gradient - ref3);
            const float mean_gradient =
                (gradient.s0 + gradient.s1 + gradient.s2 + gradient.s3) / 15.0f;
            *gain = (mean_gradient < threshold) ? *gain : 2.0f * (*gain);

            float weight = native_exp(-(*gain) * (dist.s0 + dist.s1 + dist.s2 + dist.s3));
            weight = (weight < 0) ? 0 : weight;
            weight_sum = weight_sum + weight;

            restore[0] = mad(weight, ref0, restore[0]);
            restore[1] = mad(weight, ref1, restore[1]);
            restore[2] = mad(weight, ref2, restore[2]);
            restore[3] = mad(weight, ref3, restore[3]);
        }
    }

    return weight_sum;
}

/*
 * weighted_average
 *     Loads the REF_BLOCK_WIDTH x REF_BLOCK_HEIGHT tile of `input` around the
 *     work-group's block into ref_cache, then, for the column owned by this
 *     work-item, accumulates weighted candidate texels for eight rows.
 *     Must be called by every work-item of the work-group: it contains
 *     work-group barriers.
 *
 * input:         image the reference tile is read from (coordinates are
 *                clamped to the edge by the sampler)
 * ref_cache:     local scratch of REF_BLOCK_WIDTH * REF_BLOCK_HEIGHT texels,
 *                overwritten by every call
 * load_observe:  when true, the centre of the tile is copied into
 *                observe_cache; when false observe_cache is used as is
 * observe_cache: local buffer of WORK_BLOCK_WIDTH * WORK_BLOCK_HEIGHT texels
 * restore:       eight accumulators (rows 0..7 of the column), added to
 * sum_weight:    s0 accumulates the weights of rows 0..3, s1 those of rows 4..7
 * gain:          filtering strength, see accumulate_candidates()
 * threshold:     gradient level from which gain is doubled
 */
inline void weighted_average (__read_only image2d_t input,
                              __local float4* ref_cache,
                              bool load_observe,
                              __local float4* observe_cache,
                              float4* restore,
                              float2* sum_weight,
                              float gain,
                              float threshold)
{
    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

    const int local_id_x = get_local_id(0);
    const int local_id_y = get_local_id(1);
    const int group_id_x = get_group_id(0);
    const int group_id_y = get_group_id(1);

    /*
     * ref_cache is reused by successive calls. Other work-items may still be
     * reading the tile left by the previous call, so wait for them before
     * overwriting it.
     */
    barrier(CLK_LOCAL_MEM_FENCE);

    /* Cooperative load: the work-items stride through the tile together. */
    const int first_cell = mad24(local_id_y, WORK_GROUP_WIDTH, local_id_x);
    const int origin_x = mad24(group_id_x, WORK_BLOCK_WIDTH, -REF_BLOCK_X_OFFSET);
    const int origin_y = mad24(group_id_y, WORK_BLOCK_HEIGHT, -REF_BLOCK_Y_OFFSET);
    for (int cell = first_cell;
            cell < REF_BLOCK_WIDTH * REF_BLOCK_HEIGHT;
            cell += (WORK_GROUP_WIDTH * WORK_GROUP_HEIGHT)) {
        const int coord_x = origin_x + (cell % REF_BLOCK_WIDTH);
        const int coord_y = origin_y + (cell / REF_BLOCK_WIDTH);
        ref_cache[cell] = read_imagef(input, sampler, (int2)(coord_x, coord_y));
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (load_observe) {
        /*
         * Each work-item copies, and later reads, only its own column of
         * observe_cache, so no barrier is needed after this copy.
         */
        for (int row = 0; row < WORK_BLOCK_HEIGHT; row++) {
            observe_cache[row * WORK_BLOCK_WIDTH + local_id_x] =
                ref_cache[(row + REF_BLOCK_Y_OFFSET) * REF_BLOCK_WIDTH
                          + local_id_x + REF_BLOCK_X_OFFSET];
        }
    }

    /* Rows 0..3: bands 0..2 of the tile, weights summed in s0. */
    (*sum_weight).s0 = accumulate_candidates (ref_cache, observe_cache, 0, 0,
                                              restore, (*sum_weight).s0,
                                              &gain, threshold);

    /* Rows 4..7: bands 1..3 of the tile, weights summed in s1. */
    (*sum_weight).s1 = accumulate_candidates (ref_cache, observe_cache, 1, 4,
                                              restore + 4, (*sum_weight).s1,
                                              &gain, threshold);
}
164
/*
 * kernel_3d_denoise_slm
 *     Every work-item restores one column of WORK_BLOCK_HEIGHT texels:
 *     it accumulates weighted candidates from the observed frame and from
 *     the configured reference frame(s), normalises by the summed weights
 *     and writes the result to `output`.
 *
 * gain:         filtering strength, forwarded to weighted_average()
 * threshold:    gradient threshold, forwarded to weighted_average()
 * restoredPrev: previous restored image, used when ENABLE_IIR_FILERING != 0
 * output:       restored image
 * input:        observed image
 * inputPrev1:   reference image, used when ENABLE_IIR_FILERING == 0 and
 *               REFERENCE_FRAME_COUNT > 1
 * inputPrev2:   reference image, used when ENABLE_IIR_FILERING == 0 and
 *               REFERENCE_FRAME_COUNT > 2
 *
 * NOTE(review): the writes are not bounds-checked; presumably the host sizes
 * the NDRange so that every written coordinate lies inside `output` --
 * confirm against the host code.
 */
__kernel void kernel_3d_denoise_slm( float gain,
                                     float threshold,
                                     __read_only image2d_t restoredPrev,
                                     __write_only image2d_t output,
                                     __read_only image2d_t input,
                                     __read_only image2d_t inputPrev1,
                                     __read_only image2d_t inputPrev2)
{
    /* One accumulator per row of the column; s0/s1 hold the weight sums of
     * the upper and lower half of the column respectively. */
    float4 restore[WORK_BLOCK_HEIGHT];
    float2 sum_weight = (float2)(0.0f, 0.0f);

    __local float4 ref_cache[REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH];
    __local float4 observe_cache[WORK_BLOCK_HEIGHT * WORK_BLOCK_WIDTH];

#pragma unroll
    for (int row = 0; row < WORK_BLOCK_HEIGHT; row++) {
        restore[row] = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    }

    /* The observed frame is processed first; this call also fills observe_cache. */
    weighted_average (input, ref_cache, true, observe_cache, restore, &sum_weight, gain, threshold);

#if ENABLE_IIR_FILERING
    weighted_average (restoredPrev, ref_cache, false, observe_cache, restore, &sum_weight, gain, threshold);
#else
#if REFERENCE_FRAME_COUNT > 1
    weighted_average (inputPrev1, ref_cache, false, observe_cache, restore, &sum_weight, gain, threshold);
#endif

#if REFERENCE_FRAME_COUNT > 2
    weighted_average (inputPrev2, ref_cache, false, observe_cache, restore, &sum_weight, gain, threshold);
#endif
#endif

    const int global_id_x = get_global_id (0);
    const int global_id_y = get_global_id (1);

    /* Normalise each row by the weight sum of its half and store it. */
#pragma unroll
    for (int row = 0; row < WORK_BLOCK_HEIGHT; row++) {
        const float total = (row < WORK_BLOCK_HEIGHT / 2) ? sum_weight.s0 : sum_weight.s1;
        restore[row] = restore[row] / total;
        write_imagef(output,
                     (int2)(global_id_x, mad24(WORK_BLOCK_HEIGHT, global_id_y, row)),
                     restore[row]);
    }
}
219
220