/*
 * TILE_SIZE = (${OUTPUT_TILE_SIZE[0]}, ${OUTPUT_TILE_SIZE[1]}, 1)
 * WEIGHT_STORAGE = TEXTURE_2D
 * WEIGHT_STORAGE_LAYOUT = OC4,IC4,4ic,4oc
 * BIAS_STORAGE = TEXTURE_2D
 * REGISTER_FOR = ${REGISTER_FOR}
 */
#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

// Number of output texels each invocation produces along x and y, and the
// total per invocation. Defined once so every array size and loop bound below
// agrees.
#define TILE_SIZE_X ${OUTPUT_TILE_SIZE[0]}
#define TILE_SIZE_Y ${OUTPUT_TILE_SIZE[1]}
#define TILE_TEXELS (TILE_SIZE_X * TILE_SIZE_Y)

layout(std430) buffer;

// clang-format off
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
// clang-format on
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // extents of the output texture
  ivec4 out_extents;
  // extents of the input texture
  ivec4 in_extents;
  // size of the overlay region of the kernel
  ivec4 overlay_region;
  // width and height of the kernel
  ivec2 kernel_size;
  // convolution parameters
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  // lower (x) and upper (y) bounds applied to every output texel
  vec2 clamp_thresh;
}
uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes a 2D pointwise convolution of a TILE_SIZE_X x TILE_SIZE_Y output
 * tile. Calculating an output tile for pointwise convolution is more efficient
 * because the kernel size is only 1x1, making it much easier to re-use loaded
 * texels from uKernel: the four kernel texels fetched per step are shared by
 * every output position of the tile.
 */
void main() {
  const ivec3 gpos = ivec3(gl_GlobalInvocationID);

  // Output positions, shown for TILE_SIZE_X, TILE_SIZE_Y = 2, 2
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  ivec3 pos[TILE_TEXELS];
  for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
    for (int x = 0; x < TILE_SIZE_X; ++x) {
      pos[i] = ivec3(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z);
      i++;
    }
  }

  // If the top left position is out of bounds, then this invocation will have
  // no work to do.
  if (any(greaterThanEqual(pos[0], uBlock.out_extents.xyz))) {
    return;
  }

  // Compute the index of the input texture that needs to be loaded for each
  // output position. Negative indices can be produced, indicating that the
  // element lies in a region added by padding; likewise tile positions past the
  // edge of the output may map beyond the input. texelFetch yields undefined
  // values for such coordinates, so record which positions are valid and treat
  // all others as zero (zero padding) when loading.
  const ivec2 in_size = textureSize(uInput, 0).xy;
  ivec2 ipos[TILE_TEXELS];
  bool in_bounds[TILE_TEXELS];
  for (int i = 0; i < TILE_TEXELS; ++i) {
    ipos[i] = pos[i].xy * uBlock.stride - uBlock.padding;
    in_bounds[i] = all(greaterThanEqual(ipos[i], ivec2(0))) &&
        all(lessThan(ipos[i], in_size));
  }

  // Every accumulator of the tile starts from the same bias texel, since all
  // tile positions share the output channel group gpos.z.
  vec4 sum[TILE_TEXELS];
  sum[0] = texelFetch(uBias, ivec2(gpos.z, 0), 0);
  for (int i = 1; i < TILE_TEXELS; ++i) {
    sum[i] = sum[0];
  }

  // Since the kernel is 1x1, we only have to loop over the depth dimension.
  // z walks uKernel texel columns (four per step); z4 walks uInput depth
  // slices (one per step).
  const int ic_aligned = uBlock.overlay_region.z;
  for (int z = 0, z4 = 0; z < ic_aligned; z += 4, ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x axis, and the batch (OC) dim is along
    // the y axis of the 2D kernel texture.
    vec4 in_tex[TILE_TEXELS];
    const vec4 ktex_0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0);
    const vec4 ktex_1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0);
    const vec4 ktex_2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0);
    const vec4 ktex_3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0);

    for (int i = 0; i < TILE_TEXELS; ++i) {
      // Out-of-range positions contribute zero instead of an undefined fetch.
      in_tex[i] = in_bounds[i]
          ? texelFetch(uInput, ivec3(ipos[i], z4), 0)
          : vec4(0);
    }

    for (int i = 0; i < TILE_TEXELS; ++i) {
      // For 2x2 tile size algorithm works as follows.
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from uKernel are shown:
      //
      //   in_tex                 uKernel
      //    -x->                  ---x--->
      //   +---+              +----+----+----+----+
      // ^ | w |           ^  | D0 | D1 | D2 | D3 |
      // | +---+           |  +----+----+----+----+
      // | | z |           |  | C0 | C1 | C2 | C3 |
      // z +---+           z  +----+----+----+----+
      // | | y |           |  | B0 | B1 | B2 | B3 |
      // | +---+           |  +----+----+----+----+
      //   | x |              | A0 | A1 | A2 | A3 |
      //   +---+              +----+----+----+----+
      //
      // In the uKernel graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. To calculate the output texel, the following
      // calculation is performed:
      //
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
      sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
      sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
      sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
    }
  }

  // Only pos[0] was bounds-checked above; the remaining tile positions may
  // fall outside the output, so each store is guarded individually.
  for (int i = 0; i < TILE_TEXELS; ++i) {
    if (all(lessThan(pos[i], uBlock.out_extents.xyz))) {
      imageStore(
          uOutput,
          pos[i],
          clamp(sum[i], uBlock.clamp_thresh.x, uBlock.clamp_thresh.y));
    }
  }
}