• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#include <metal_stdlib>
2#include <simd/simd.h>
3using namespace metal;
// Bundle of built-in values consumed by the kernel body.
// computeMain fills the single member from its [[thread_position_in_grid]]
// parameter and reads only the .x component of it.
4struct Inputs {
5    uint3 sk_GlobalInvocationID;
6};
// Layout of the read-only device buffer that computeMain receives at
// [[buffer(0)]].
// NOTE(review): in_data is declared with a single element, yet computeMain
// indexes it at id * 2u and id * 2u + 1u. The length presumably stands in for
// a runtime-sized array -- confirm against the size of the host-side buffer.
7struct inputs {
8    float in_data[1];
9};
// Layout of the writable device buffer that computeMain receives at
// [[buffer(1)]].
// NOTE(review): out_data is declared with a single element, yet computeMain
// writes it at id * 2u and id * 2u + 1u. The length presumably stands in for
// a runtime-sized array -- confirm against the size of the host-side buffer.
10struct outputs {
11    float out_data[1];
12};
// Groups pointers to the two device buffers so the kernel body can reach both
// through one object. computeMain builds an instance from the addresses of
// its buffer(0) and buffer(1) reference parameters.
13struct Globals {
// Source buffer; const, so it is only ever read through this pointer.
14    const device inputs* _anonInterface0;
// Destination buffer; non-const, computeMain stores results through it.
15    device outputs* _anonInterface1;
16};
// Scratch storage that computeMain declares in the threadgroup address space.
// It holds 1024 floats; computeMain indexes it directly and store_vIf writes
// single elements into it.
17struct Threadgroups {
18    array<float, 1024> shared_data;
19};
// Writes `value` into element `i` of the threadgroup scratch array.
//
// _threadgroups: threadgroup-address-space storage holding shared_data.
// i:             index of the element to overwrite; not range-checked here.
// value:         float stored at that index.
void store_vIf(threadgroup Threadgroups& _threadgroups, uint i, float value) {
    // Bind the destination element first, then assign through the reference.
    threadgroup float& slot = _threadgroups.shared_data[i];
    slot = value;
}
// Compute entry point.
//
// Each invocation:
//   1. copies two adjacent floats (indices id*2 and id*2+1) from the buffer at
//      buffer(0) into the threadgroup scratch array, then waits on a
//      threadgroup barrier;
//   2. runs 10 passes; in pass `step` it adds the element at the last index of
//      the lower half of a 2^(step+1)-wide block into one element of that
//      block's upper half, with a threadgroup barrier after every pass;
//   3. copies its two scratch elements to the buffer at buffer(1).
//
// NOTE(review): this has the shape of a step-doubling inclusive scan over
// 1024 elements -- confirm the intent against the generator's source.
// NOTE(review): `id` is the position in the whole grid, not within the
// threadgroup, and is used unchecked to index the 1024-element scratch array.
// The kernel presumably expects a single threadgroup of 512 invocations --
// confirm against the dispatch.
kernel void computeMain(uint3 sk_GlobalInvocationID [[thread_position_in_grid]], const device inputs& _anonInterface0 [[buffer(0)]], device outputs& _anonInterface1 [[buffer(1)]]) {
    Globals _globals{&_anonInterface0, &_anonInterface1};
    threadgroup Threadgroups _threadgroups{{}};
    Inputs _in = { sk_GlobalInvocationID };

    const uint id = _in.sk_GlobalInvocationID.x;
    // The two adjacent element indices owned by this invocation.
    const uint evenIdx = id * 2u;
    const uint oddIdx = evenIdx + 1u;

    // Stage this invocation's pair of inputs in threadgroup memory.
    _threadgroups.shared_data[evenIdx] = _globals._anonInterface0->in_data[evenIdx];
    _threadgroups.shared_data[oddIdx] = _globals._anonInterface0->in_data[oddIdx];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint steps = 10u;
    uint step = 0u;
    while (step < steps) {
        // Low `step` bits of id select the offset inside the upper half.
        const uint mask = (1u << step) - 1u;
        // Block base is (id >> step) * 2^(step + 1); adding mask lands on the
        // last index of that block's lower half.
        const uint rd_id = ((id >> step) << (step + 1u)) + mask;
        // Destination sits in the upper half of the same block.
        const uint wr_id = rd_id + 1u + (id & mask);

        const float combined = _threadgroups.shared_data[wr_id] + _threadgroups.shared_data[rd_id];
        store_vIf(_threadgroups, wr_id, combined);

        // Every pass ends with a barrier before the next one starts.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        ++step;
    }

    // Publish this invocation's pair of results.
    _globals._anonInterface1->out_data[evenIdx] = _threadgroups.shared_data[evenIdx];
    _globals._anonInterface1->out_data[oddIdx] = _threadgroups.shared_data[oddIdx];
}
48