#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // dim_info.x: dim to compute mean (0 = N, 1 = C, 2 = H, any other value = W)
  // dim_info.y: size of dim (in the input)
  uvec2 dim_info;
  // channel: number of input channels C; each batch occupies ceil(C / 4)
  // consecutive z-slices of uInput
  int channel;
}
uBlock;

/*
 * Local Work Group Size
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Returns a new tensor with values averaged along dimension dim_info.x.
 * The averaged dimension is squeezed out of the output.
 *
 * Input layout, as addressed below: texel (x, y, n * C4 + c / 4), where
 * C4 = ceil(C / 4), holds width index x, height index y, batch n and
 * channels c .. c + 3 (one per component).
 *
 * For each output position pos:
 *  - N (dim 0): sums input texel (pos.x, pos.y, n * C4 + pos.z) over n;
 *    the output keeps the channel-packed slices.
 *  - C (dim 1): batch moves into the packed slot. Output component i of
 *    slice pos.z is batch pos.z * 4 + i, summed over all C channels.
 *  - H (dim 2) / W (other): pos.x is the remaining spatial index (W for an
 *    H-mean, H for a W-mean), pos.y is the channel, and output component i
 *    of slice pos.z is batch pos.z * 4 + i.
 * The accumulated sums are divided by dim_info.y.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The dispatch may be rounded up to a multiple of the local work group
  // size; invocations outside the output image have nothing to compute.
  if (any(greaterThanEqual(pos, imageSize(uOutput)))) {
    return;
  }

  // Number of z-slices one batch occupies in the channel-packed input
  // (exact integer ceil(C / 4)).
  const int flattened_channels = (uBlock.channel + 3) / 4;
  const uint dim = uBlock.dim_info.x;
  const int dim_size = int(uBlock.dim_info.y);

  // When N is not a multiple of 4, the padded batch lanes of the last output
  // slice map to z-slices past the end of the input; those fetches are
  // skipped instead of reading out of range (the lane stays 0).
  const int input_depth = textureSize(uInput, 0).z;

  vec4 out_texel = vec4(0, 0, 0, 0);

  if (dim == 0u) {
    // Batch: src_n = batch, src_c = pos.z (channel-packed slice).
    for (int batch = 0; batch < dim_size; ++batch) {
      const int src_z = batch * flattened_channels + pos.z;
      out_texel += texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
    }
  } else if (dim == 1u) {
    // Channel: src_n = pos.z * 4 + out_index, src_c = c.
    for (int out_index = 0; out_index < 4; ++out_index) {
      const int src_n = pos.z * 4 + out_index;
      for (int c = 0; c < dim_size; ++c) {
        const int src_z = src_n * flattened_channels + c / 4;
        if (src_z >= input_depth) {
          // src_z only grows with c, so the remaining channels are out too.
          break;
        }
        const vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
        out_texel[out_index] += v[c % 4];
      }
    }
  } else {
    // Height / Width: src_n = pos.z * 4 + out_index, src_c = pos.y.
    const int src_c = pos.y;
    for (int out_index = 0; out_index < 4; ++out_index) {
      const int src_z = (pos.z * 4 + out_index) * flattened_channels + src_c / 4;
      if (src_z >= input_depth) {
        // Later lanes map to even larger slices.
        break;
      }
      for (int hw = 0; hw < dim_size; ++hw) {
        const ivec3 src_pos = (dim == 2u)
            ? ivec3(pos.x, hw, src_z)   // Height
            : ivec3(hw, pos.x, src_z);  // Width
        const vec4 v = texelFetch(uInput, src_pos, 0);
        out_texel[out_index] += v[src_c % 4];
      }
    }
  }

  imageStore(uOutput, pos, out_texel / float(dim_size));
}