#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Slice shader.
 *
 * Each invocation writes one texel of uOutput. The four components of a texel
 * at depth z correspond to the flattened indices z * 4 + {0, 1, 2, 3}, where a
 * flattened index is (batch * aligned_channels + channel). For every
 * component the matching source element is located in uInput, applying
 * (start + step * index) along the dimension selected by args[0].
 */

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Texture
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // output texture extents (x = width, y = height, z = depth, w = unused)
  ivec4 size;
  // input texture extents (x = width, y = height, z = depth, w = unused)
  ivec4 isize;
  // output tensor size; component 0 is multiplied by c_info.x below to bound
  // the flattened output index
  uvec4 tensor_size;
  // input tensor size
  uvec4 itensor_size;
  // slice arguments (dim, start, end, step)
  uvec4 args;
  // x = output channels aligned to 4, y = input channels aligned to 4
  uvec2 c_info;
}
uBlock;

/*
 * Local Work Group (sizes come from specialization constants 0, 1, 2)
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  const uint out_channels = uBlock.c_info.x;
  const uint in_channels = uBlock.c_info.y;
  const uint dst_limit = uBlock.tensor_size[0] * out_channels;
  const uint slice_dim = uBlock.args[0];
  const uint slice_start = uBlock.args[1];
  const uint slice_step = uBlock.args[3];

  // Invocations outside the output texture have nothing to write.
  if (any(greaterThanEqual(out_pos, uBlock.size.xyz))) {
    return;
  }

  // Components that are never filled in keep this zero value.
  vec4 texel = vec4(0.0);

  for (uint lane = 0; lane < 4; ++lane) {
    const uint dst_index = out_pos.z * 4 + lane;
    if (dst_index >= dst_limit) {
      // Past the last flattened output element: the remaining components stay
      // zero and the partially filled texel is stored below.
      break;
    }

    // Split the flattened output index into its batch / channel parts.
    const uint dst_batch = dst_index / out_channels;
    const uint dst_channel = dst_index % out_channels;

    // Only the sliced dimension is remapped (0 = batch, 1 = channel); any
    // other value of slice_dim leaves both coordinates untouched.
    const uint src_batch =
        (slice_dim == 0) ? (slice_start + slice_step * dst_batch) : dst_batch;
    const uint src_channel = (slice_dim != 0 && slice_dim == 1)
        ? (slice_start + slice_step * dst_channel)
        : dst_channel;

    // Width and height carry over unchanged; depth and component are derived
    // from the flattened source index.
    const uint src_index = src_batch * in_channels + src_channel;
    const ivec3 in_pos = ivec3(out_pos.xy, int(src_index / 4));
    const uint src_lane = src_index % 4;

    texel[lane] = texelFetch(uInput, in_pos, 0)[src_lane];
  }

  // Exactly one store per in-bounds invocation, whether the loop ran to
  // completion or stopped early.
  imageStore(uOutput, out_pos, texel);
}