#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;

// Built-in values captured at kernel entry.
struct Inputs {
    uint3 sk_GlobalInvocationID;
};

// Layout of the read-only source buffer (bound at buffer index 0 by computeMain).
// NOTE(review): declared with a single element but indexed with 2*id and 2*id+1
// by the kernel; presumably a stand-in for a runtime-sized array -- confirm
// against the host-side buffer allocation.
struct inputs {
    float in_data[1];
};

// Layout of the destination buffer (bound at buffer index 1 by computeMain).
// NOTE(review): same single-element declaration as `inputs`; see the note there.
struct outputs {
    float out_data[1];
};

// Bundles pointers to both device buffers so they can be reached through one object.
struct Globals {
    const device inputs* _anonInterface0;
    device outputs* _anonInterface1;
};

// Threadgroup-shared scratch storage: 1024 floats.
struct Threadgroups {
    array<float, 1024> shared_data;
};

// Writes `value` into slot `i` of the threadgroup scratch array.
// No bounds check is performed on `i`.
void store_vIf(threadgroup Threadgroups& _threadgroups, uint i, float value) {
    _threadgroups.shared_data[i] = value;
}

// Compute entry point.
//
// Each thread:
//   1. copies two adjacent floats (indices 2*id and 2*id+1) from the input
//      buffer into threadgroup scratch storage,
//   2. runs 10 combine steps, each followed by a threadgroup barrier,
//   3. copies the same two scratch slots to the output buffer.
//
// In step `s` the scratch array is viewed as blocks of 2^(s+1) elements. A
// thread adds the last element of a block's lower half to one element of that
// block's upper half.
//
// NOTE(review): this looks like an inclusive prefix sum over 1024 elements,
// assuming exactly 512 threads in a single threadgroup -- confirm against the
// dispatch configuration. The scratch array is indexed with the *grid*
// position, so any thread with id >= 512 would index past its 1024 entries;
// nothing here guards against that.
kernel void computeMain(uint3 sk_GlobalInvocationID [[thread_position_in_grid]],
                        const device inputs& _anonInterface0 [[buffer(0)]],
                        device outputs& _anonInterface1 [[buffer(1)]]) {
    Globals _globals{&_anonInterface0, &_anonInterface1};
    (void)_globals;
    threadgroup Threadgroups _threadgroups{{}};
    (void)_threadgroups;
    const Inputs _in = { sk_GlobalInvocationID };

    // The two slots owned by this thread for the load and store phases.
    const uint tid = _in.sk_GlobalInvocationID.x;
    const uint evenSlot = tid * 2u;
    const uint oddSlot = evenSlot + 1u;

    // Load phase: stage both elements in threadgroup memory.
    _threadgroups.shared_data[evenSlot] = _globals._anonInterface0->in_data[evenSlot];
    _threadgroups.shared_data[oddSlot] = _globals._anonInterface0->in_data[oddSlot];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint kSteps = 10u;
    for (uint step = 0u; step < kSteps; ++step) {
        // Low `step` bits of tid select the element inside the upper half-block.
        const uint lowBits = (1u << step) - 1u;
        // Last element of the lower half of this thread's block.
        const uint src = ((tid >> step) << (step + 1u)) + lowBits;
        // Target element in the upper half of the same block.
        const uint dst = (src + 1u) + (tid & lowBits);

        const float combined = _threadgroups.shared_data[dst] + _threadgroups.shared_data[src];
        store_vIf(_threadgroups, dst, combined);

        // Make this step's writes visible before the next step reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store phase: publish this thread's two slots.
    _globals._anonInterface1->out_data[evenSlot] = _threadgroups.shared_data[evenSlot];
    _globals._anonInterface1->out_data[oddSlot] = _threadgroups.shared_data[oddSlot];
}