#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

/*
 * TILE_SIZE = (1, 1, 1)
 * WEIGHT_STORAGE = TEXTURE_3D
 * BIAS_STORAGE = TEXTURE_3D
 */
// NOTE(review): the storage tags in the comment above read TEXTURE_3D, yet
// uKernel and uBias are declared as sampler2D below. The tags are left
// untouched in case build tooling consumes them -- confirm which is intended.

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 */
layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput;

/*
 * Input Textures
 */
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;

/*
 * Params Buffer
 */
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // quantization scales, xyzw corresponds to output, input, kernel, bias
  // (this shader only reads .x and .y)
  vec4 scales;
  // quantization zero points, xyzw corresponds to output, input, kernel, bias
  // (this shader only reads .x and .y)
  ivec4 zero_points;
  // extents of the output texture
  ivec4 out_extents;
  // extents of the input texture
  ivec4 in_extents;
  // size of the overlay region of the kernel
  ivec4 overlay_region;
  // width and height of the kernel
  ivec2 kernel_size;
  // convolution parameters
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  // lower (x) and upper (y) bound applied to the accumulated sum before it is
  // requantized
  vec2 clamp_thresh;
}
uBlock;

/*
 * Local Work Group
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Dequantizes a texel based on a scale and zero point:
 *   result = scale * (tex - zero_point)
 */
vec4 dequantize(vec4 tex, float scale, int zero_point) {
  return scale * (tex - zero_point);
}

/*
 * Quantizes a float texel based on a scale and zero point:
 *   result = roundEven(tex / scale) + zero_point
 *
 * The result is saturated to [0, 255] before it is converted to unsigned.
 * uOutput is an rgba8ui image, so nothing outside that range can be stored,
 * and converting a negative float to an unsigned integer is undefined in GLSL.
 */
uvec4 quantize(vec4 tex, float scale, int zero_point) {
  const float q_min = 0.0;
  const float q_max = 255.0;

  const vec4 q = roundEven(tex / scale) + zero_point;
  return uvec4(clamp(q, q_min, q_max));
}

/*
 * Computes a 2D quantized convolution. Each shader invocation calculates the
 * output at a single output location. Currently this is implemented in a naive
 * way, where inputs are dequantized upon reading in, and requantized upon
 * writing out. Texels fetched from uKernel and uBias are used as floats
 * without any dequantization step.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Return if this global position is outside output texture bounds
  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Note that
  // negative indices can be produced indicating that the top-left element is in
  // a region added by padding.
  const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;

  // Kernel tap k reads the input at ipos + k * dilate. Count how many leading
  // taps along each axis land at a negative input coordinate. Padding is
  // assumed to be constant 0 padding, so those taps contribute nothing and are
  // skipped. The division rounds up so that the first tap kept is the first
  // one whose coordinate is >= 0; rounding down would pair input texels with
  // the wrong kernel taps whenever -ipos is not a multiple of the dilation.
  const ivec2 skipped =
      (max(-ipos, ivec2(0)) + uBlock.dilate - 1) / uBlock.dilate;

  // Compute the start and end of the input indices to load. start is the input
  // coordinate of the first tap that is kept; end is exclusive.
  const ivec2 start = ipos + skipped * uBlock.dilate;
  const ivec2 end = min(ipos + uBlock.overlay_region.xy, uBlock.in_extents.xy);

  // Compute the start of the kernel based on how far we are skipping ahead when
  // reading the input. Note that these are "canonical" indices.
  ivec2 kstart = skipped;
  // During prepacking, the weight tensor was rearranged in order to optimize
  // for data access linearity in this shader. Therefore we need to adjust the
  // canonical coordinates to the corresponding index in the rearranged weight
  // tensor. The x coordinate is multiplied by 4 since each group of 4 channels
  // is folded into the X axis. The y coordinate is offset based on the z
  // coordinate because the 2D planes were stacked atop each other vertically.
  kstart.x *= 4;
  kstart.y += pos.z * uBlock.kernel_size.y;

  // The accumulator starts from the bias texel for this output channel group.
  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);

  // Perform the convolution by iterating over the overlay region
  const int dil_y = uBlock.dilate.y;
  const int dil_x = uBlock.dilate.x;
  const int ic4 = uBlock.overlay_region.z / 4;
  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += uBlock.kernel_size.x * 4) {
    for (int y = start.y, ky = kstart.y; y < end.y; y += dil_y, ++ky) {
      for (int x = start.x, kx = kstart.x; x < end.x; x += dil_x, kx += 4) {
        // Read in and dequantize the input texel
        const vec4 in_tex = dequantize(
            texelFetch(uInput, ivec3(x, y, z4), 0),
            uBlock.scales.y,
            uBlock.zero_points.y);

        // To explain the calculation below, the contents of in_tex and the
        // group of 4 texels loaded from uKernel are shown:
        //
        //   in_tex                 uKernel
        //    -x->                  ---x--->
        //   +---+              +----+----+----+----+
        // ^ | w |           ^  | D0 | D1 | D2 | D3 |
        // | +---+           |  +----+----+----+----+
        // | | z |           |  | C0 | C1 | C2 | C3 |
        // z +---+           z  +----+----+----+----+
        // | | y |           |  | B0 | B1 | B2 | B3 |
        // | +---+           |  +----+----+----+----+
        //   | x |              | A0 | A1 | A2 | A3 |
        //   +---+              +----+----+----+----+
        //
        // In the uKernel graphic, cells sharing the same letter are from
        // the same batch/output channel index, and the number denotes a unique
        // channel index. To calculate the output texel, the following
        // calculation is performed:
        //
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
        //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
        //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //
        // which is what is expressed in the following calculations.

        const vec4 ktex_0 = texelFetch(uKernel, ivec2(kx + 0, ky), 0);
        sum = fma(in_tex.xxxx, ktex_0, sum);

        const vec4 ktex_1 = texelFetch(uKernel, ivec2(kx + 1, ky), 0);
        sum = fma(in_tex.yyyy, ktex_1, sum);

        const vec4 ktex_2 = texelFetch(uKernel, ivec2(kx + 2, ky), 0);
        sum = fma(in_tex.zzzz, ktex_2, sum);

        const vec4 ktex_3 = texelFetch(uKernel, ivec2(kx + 3, ky), 0);
        sum = fma(in_tex.wwww, ktex_3, sum);
      }
    }
  }

  // Apply the activation clamp in float space, then requantize with the output
  // scale and zero point.
  uvec4 out_tex = quantize(
      clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y),
      uBlock.scales.x,
      uBlock.zero_points.x);

  imageStore(uOutput, pos, out_tex);
}