// Implementation of the parallel prefix sum algorithm

// Half the number of elements scanned. main() handles two elements per
// invocation, so SIZE is presumably also the number of invocations that are
// expected to run -- confirm against the dispatch configuration. The step
// count in main() is derived from log2(SIZE), which only matches the element
// count when SIZE is a power of two.
const int SIZE = 512;

// Read-only storage buffer holding the values to be scanned.
layout(set=0, binding=0) readonly buffer inputs {
    float[] in_data;
};

// Write-only storage buffer; main() copies the final contents of shared_data
// into it.
layout(set=0, binding=1) writeonly buffer outputs {
    float[] out_data;
};

// Workgroup-shared scratch array in which the scan is carried out in place.
// It holds two floats per invocation, SIZE * 2 in total.
workgroup float[SIZE * 2] shared_data;

// Writes `value` into element `i` of the workgroup-shared array shared_data.
//
// The helper is deliberately marked noinline: it exists to test that
// workgroup-shared variables are passed to user-defined functions correctly,
// so the access must stay inside a real function call.
noinline void store(uint i, float value) {
    shared_data[i] = value;
}

// Entry point: computes an inclusive prefix sum of in_data into out_data,
// using shared_data as in-place scratch storage.
//
// NOTE(review): shared_data is indexed with the global invocation ID, so this
// appears to assume a single workgroup of SIZE invocations -- confirm against
// the dispatch configuration.
void main() {
    uint id = sk_GlobalInvocationID.x;

    // Each thread is responsible for two elements of the output array, so it
    // stages elements 2*id and 2*id + 1 into shared memory.
    shared_data[id * 2] = in_data[id * 2];
    shared_data[id * 2 + 1] = in_data[id * 2 + 1];

    // Every element must be staged before any thread starts accumulating.
    workgroupBarrier();

    // One pass per doubling of the block width: log2(SIZE) + 1 equals
    // log2(SIZE * 2) when SIZE is a power of two.
    const uint steps = uint(log2(float(SIZE))) + 1;
    for (uint step = 0; step < steps; step++) {
        // At this step the array is viewed as blocks of 2^(step + 1)
        // elements, each split into a left and a right half of 2^step
        // elements. `mask` selects this thread's offset within a half.
        uint mask = (1 << step) - 1;

        // Read index: the last element of the left half of this thread's
        // block, which already holds the sum of that whole half.
        uint rd_id = ((id >> step) << (step + 1)) + mask;

        // Write index: this thread's own element in the right half.
        uint wr_id = rd_id + 1 + (id & mask);

        // Accumulate the read data into our element. No thread writes to a
        // left-half element during this step, so the read cannot race.
        store(wr_id, shared_data[wr_id] + shared_data[rd_id]);

        // All writes of this step must land before the next step reads them.
        workgroupBarrier();
    }

    // Write the final result out.
    out_data[id * 2] = shared_data[id * 2];
    out_data[id * 2 + 1] = shared_data[id * 2 + 1];
}
