/*
 * TILE_SIZE = (1, 1, 1)
 * WEIGHT_STORAGE = TEXTURE_2D
 * BIAS_STORAGE = TEXTURE_2D
 * Note that for DW kernel IC = 1 so the weight layout is really OC4, H, W, 4oc
 */
#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

// clang-format off
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOut;
// clang-format on
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // extents of the output texture
  ivec4 out_extents;
  // extents of the input texture
  ivec4 in_extents;
  // size of the overlay region of the kernel
  ivec4 overlay_region;
  // width and height of the kernel
  ivec2 kernel_size;
  // convolution parameters
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  // lower (x) and upper (y) bounds applied to the result before it is stored
  vec2 clamp_thresh;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes depthwise convolution. Each shader invocation calculates the output
 * of a single output location: the texel of uOut at gl_GlobalInvocationID.
 *
 * The accumulator starts from the bias texel at (pos.z, 0) and adds, for every
 * kernel tap, the component-wise product of the input texel and the weight
 * texel. The result is clamped to [clamp_thresh.x, clamp_thresh.y] and stored.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Return if this global position is outside output texture bounds
  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Note that
  // negative indices can be produced indicating that the top-left element is in
  // a region added by padding.
  const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;

  // Compute the start and end of the input indices to visit. Padding is
  // assumed to be constant 0 padding, so any reads from the padding region are
  // skipped (see the bounds check inside the loops).
  const ivec2 start = ipos;
  const ivec2 end = ipos + uBlock.overlay_region.xy;

  // Width and height of the input texture; taps outside [0, in_size) lie in
  // the padding region.
  const ivec2 in_size = uBlock.in_extents.xy;

  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);
  const int dil_y = uBlock.dilate.y;
  const int dil_x = uBlock.dilate.x;
  int k_ind = 0;
  for (int y = start.y; y < end.y; y += dil_y) {
    const bool y_in_bounds = (y >= 0) && (y < in_size.y);
    for (int x = start.x; x < end.x; x += dil_x) {
      // Only touch the input when the tap lands inside the texture. An
      // out-of-range texelFetch is not guaranteed to return zero, so the
      // padding contribution (which is exactly zero) is skipped explicitly
      // instead of being fetched.
      if (y_in_bounds && (x >= 0) && (x < in_size.x)) {
        // The weight kernel was rearranged so that every NxN filter was
        // flattened so that it fits on one row. Each filter was then stacked
        // on top of each other vertically.
        const vec4 k_tex = texelFetch(uKernel, ivec2(k_ind, pos.z), 0);
        const vec4 i_tex = texelFetch(uInput, ivec3(x, y, pos.z), 0);
        sum = fma(i_tex, k_tex, sum);
      }
      // The kernel index advances for every tap, including skipped ones, so
      // that in-bounds taps keep reading their matching weight.
      k_ind++;
    }
  }

  imageStore(
      uOut, pos, clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y));
}