#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

// ---------------------------------------------------------------------------
// Resources (descriptor set 0)
// ---------------------------------------------------------------------------

// binding 0: destination image; this shader only ever writes to it.
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

// binding 1: the tensor being normalized, read at the invocation's own texel.
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

// bindings 2-5: normalization parameters. All four are read at the same
// texel, (0, 0, z), where z is derived from the invocation's z coordinate.
layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma;
layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta;
layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean;
layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar;

// binding 6: scalar parameters supplied by the host.
layout(set = 0, binding = 6) uniform PRECISION restrict Block {
  // out_extents.xyz: extents of the output texture.
  // out_extents.w:   number of channels divided by 4, rounded up.
  ivec4 out_extents;
  // Added to the variance before the square root is taken.
  float eps;
}
uBlock;

// Work group dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Batch normalization.
 *
 * One invocation produces one output texel:
 *
 *   out = (in - mean) / sqrt(var + eps) * gamma + beta
 *
 * evaluated component-wise on vec4 texels. Invocations whose global ID lies
 * outside out_extents.xyz write nothing.
 */
void main() {
  const ivec3 texel = ivec3(gl_GlobalInvocationID);

  // Work is done only when every coordinate is strictly inside the output
  // extents; all other invocations fall through without touching memory.
  if (all(lessThan(texel, uBlock.out_extents.xyz))) {
    // Parameter textures are indexed along z only; z wraps around every
    // out_extents.w texels.
    // NOTE(review): presumably this folds a stacked batch dimension back onto
    // the channel groups - confirm against the host-side packing code.
    const ivec3 param_texel = ivec3(0, 0, texel.z % uBlock.out_extents.w);

    const vec4 centered =
        texelFetch(uInput, texel, 0) - texelFetch(uMean, param_texel, 0);
    const vec4 std_dev = sqrt(texelFetch(uVar, param_texel, 0) + uBlock.eps);

    imageStore(
        uOutput,
        texel,
        centered / std_dev * texelFetch(uGamma, param_texel, 0) +
            texelFetch(uBeta, param_texel, 0));
  }
}