#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

// Tile edge length. Kept as a macro (rather than a literal) because the
// fixed-trip-count loops below were found to be unrolled more reliably by
// the SPIR-V compiler this way.
#define FOUR 4
layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
layout(set = 0, binding = 3) uniform PRECISION restrict Block {
  // xyz: exclusive bound on gl_GlobalInvocationID; each in-bounds invocation
  //      produces one FOUR x FOUR tile of uOutput.
  // w:   trip count of the accumulation loop, i.e. how many texels are
  //      fetched along x of uM1 and along y of uM2.
  ivec4 shader_extents;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes one FOUR x FOUR output tile per invocation.
 *
 * Tile element (r, c) accumulates, over every step s < shader_extents.w,
 * dot(uM1 texel (s, FOUR*y + r, z), uM2 texel (FOUR*x + c, s, z)), and is
 * stored in the red channel of uOutput at (FOUR*x + c, FOUR*y + r, z).
 * NOTE(review): this presumes each texel packs four consecutive values of
 * the shared (reduction) dimension -- confirm against the packing code.
 */
void main() {
  const ivec3 tile = ivec3(gl_GlobalInvocationID);

  // Invocations past the requested extents do no work.
  if (!all(lessThan(tile, uBlock.shader_extents.xyz))) {
    return;
  }

  // First uM1 row and first uM2 column covered by this tile; these are also
  // the y and x origins of the tile in uOutput.
  const int row_base = FOUR * tile.y;
  const int col_base = FOUR * tile.x;

  // Scalar accumulators: mat4/vec4 were found to compile to noticeably less
  // efficient SPIR-V than a plain float array here.
  float acc[FOUR][FOUR];
  for (int r = 0; r < FOUR; ++r) {
    for (int c = 0; c < FOUR; ++c) {
      acc[r][c] = 0.0;
    }
  }

  const int depth = uBlock.shader_extents.w;
  for (int step = 0; step < depth; ++step) {
    // Cache FOUR adjacent rows of uM1 and FOUR adjacent columns of uM2 at
    // this step. Fetches may land outside the input images; NOTE(review):
    // texelFetch returns zeros out of bounds only when image robustness is
    // enabled on the device -- confirm for the targeted drivers.
    vec4 lhs[FOUR];
    vec4 rhs[FOUR];
    for (int t = 0; t < FOUR; ++t) {
      lhs[t] = texelFetch(uM1, ivec3(step, row_base + t, tile.z), 0);
      rhs[t] = texelFetch(uM2, ivec3(col_base + t, step, tile.z), 0);
    }

    // Fold this step's partial dot products into the tile.
    for (int r = 0; r < FOUR; ++r) {
      for (int c = 0; c < FOUR; ++c) {
        acc[r][c] += dot(lhs[r], rhs[c]);
      }
    }
  }

  // acc is indexed [uM1 row][uM2 column], while uOutput is addressed
  // (x = column, y = row), so the indices swap roles when storing.
  for (int r = 0; r < FOUR; ++r) {
    for (int c = 0; c < FOUR; ++c) {
      const ivec3 dst = ivec3(col_base + c, row_base + r, tile.z);
      imageStore(uOutput, dst, vec4(acc[r][c], 0.0, 0.0, 0.0));
    }
  }
}