#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Textures
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // output texture size (x=width,y=height,z=depth,w=unused)
  ivec4 size;
  // input texture size (x=width,y=height,z=depth,w=unused)
  ivec4 isize;
  // output tensor size; [0] is used below as the batch count.
  // NOTE(review): [1] is assumed to be the unpadded output channel count
  // (N, C, H, W order) — confirm against the host code that fills this block.
  uvec4 out_tensor_size;
  // input tensor size
  uvec4 in_tensor_size;
  // Permutation: out_ndims[k] is the input dimension (0=batch, 1=channel,
  // 2=height, 3=width) from which output dimension k is read.
  uvec4 out_ndims;
  // x = output channels aligned to 4, y = input channels aligned to 4
  uvec2 ch_info;
}
uBlock;

/*
 * Local Work Group
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Writes `value` into component `dim` of `coord` (0=x, 1=y, 2=z, 3=w).
 * A `dim` outside [0, 3] is ignored and leaves `coord` unchanged.
 * A switch is used instead of dynamic indexing so that an invalid `dim` is
 * harmless.
 */
void permute_set_dim(inout uvec4 coord, uint dim, uint value) {
  switch (dim) {
    case 0u:
      coord.x = value;
      break;
    case 1u:
      coord.y = value;
      break;
    case 2u:
      coord.z = value;
      break;
    case 3u:
      coord.w = value;
      break;
  }
}

/*
 * Each invocation produces one output texel (4 consecutive channel slots).
 *
 * Texel layout, for both textures: x = W, y = H, and along z the index
 * (batch * channels_aligned + channel) is packed 4 per texel.
 * For each lane, the output (b1, c1, h1, w1) coordinate is scattered into
 * input-dimension order via out_ndims. The matching input lane is fetched.
 * Out-of-range lanes and channel-padding lanes are written as zero.
 */
void main() {
  const ivec3 posOut = ivec3(gl_GlobalInvocationID);

  if (!all(lessThan(posOut, uBlock.size.xyz))) {
    return;
  }

  const uint out_c_aligned = uBlock.ch_info.x;
  const uint in_c_aligned = uBlock.ch_info.y;
  // Total number of (batch, aligned-channel) slots in the output.
  const uint max_dst_index = uBlock.out_tensor_size[0] * out_c_aligned;
  // NOTE(review): assumes out_tensor_size[1] is the unpadded channel count.
  const uint out_channels = uBlock.out_tensor_size[1];

  vec4 outval = vec4(0.0);

  for (uint j = 0u; j < 4u; ++j) {
    const uint dst_index = uint(posOut.z) * 4u + j;
    if (dst_index >= max_dst_index) {
      // Out of range: this lane and the remaining ones stay zero.
      break;
    }

    // Output (batch, channel, height, width) coordinate of this lane.
    const uint b1 = dst_index / out_c_aligned;
    const uint c1 = dst_index % out_c_aligned;
    if (c1 >= out_channels) {
      // Channel-padding lane: keep it zero rather than fetching a source
      // texel whose coordinate may lie outside the input texture.
      continue;
    }
    const uint h1 = uint(posOut.y);
    const uint w1 = uint(posOut.x);

    // Input coordinate as (b, c, h, w) in (x, y, z, w). Zero-initialized so an
    // invalid out_ndims never leaves a component undefined.
    uvec4 src = uvec4(0u);
    permute_set_dim(src, uBlock.out_ndims[0], b1);
    permute_set_dim(src, uBlock.out_ndims[1], c1);
    permute_set_dim(src, uBlock.out_ndims[2], h1);
    permute_set_dim(src, uBlock.out_ndims[3], w1);

    const uint src_index = src.x * in_c_aligned + src.y;
    const ivec3 posIn = ivec3(int(src.w), int(src.z), int(src_index / 4u));

    const vec4 inval = texelFetch(uInput, posIn, 0);
    outval[j] = inval[src_index % 4u];
  }

  // Exactly one store per in-range invocation.
  imageStore(uOutput, posOut, outval);
}