• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#version 450 core
2#define PRECISION ${PRECISION}
3#define FORMAT ${FORMAT}
4
5/*
6 * TILE_SIZE = (1, 1, 1)
7 * WEIGHT_STORAGE = TEXTURE_3D
8 * BIAS_STORAGE = TEXTURE_3D
9 */
10
// Default memory layout applied to buffer blocks declared in this shader.
layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 *
 * Write-only 3D image with four 8-bit unsigned integer components per texel.
 * It receives the requantized convolution result.
 */
layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput;

/*
 * Input Textures
 *
 * uInput  - integer 3D texture holding the quantized input; texels are
 *           dequantized after being fetched.
 * uKernel - floating point 2D texture holding the prepacked weights.
 * uBias   - floating point 2D texture holding the bias, fetched at
 *           (output z, 0).
 */
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
26
/*
 * Params Buffer
 *
 * Uniform parameters shared by every invocation of the shader.
 */
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // quantization scales, xyzw corresponds to output, input, kernel, bias.
  // This shader reads .x (output) and .y (input) only.
  vec4 scales;
  // quantization zero points, xyzw corresponds to output, input, kernel, bias.
  // This shader reads .x (output) and .y (input) only.
  ivec4 zero_points;
  // extents of the output texture; .xyz bounds the invocations that do work
  ivec4 out_extents;
  // extents of the input texture; .xy bounds the input texels that are read
  ivec4 in_extents;
  // size of the overlay region of the kernel: .xy is the footprint on the
  // input plane, .z / 4 is the number of input texel layers that are visited
  ivec4 overlay_region;
  // width and height of the kernel
  ivec2 kernel_size;
  // convolution parameters, each given as (x, y)
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  // (lower, upper) bounds applied to the accumulated sum before it is
  // requantized
  vec2 clamp_thresh;
}
uBlock;
50
/*
 * Local Work Group
 *
 * The x, y and z sizes are taken from specialization constants 0, 1 and 2.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
55
/*
 * Maps a quantized texel back to real values.
 *
 *   tex        - quantized texel, already converted to floating point
 *   scale      - quantization scale
 *   zero_point - quantization zero point
 *
 * Returns scale * (tex - zero_point), evaluated per component.
 */
vec4 dequantize(vec4 tex, float scale, int zero_point) {
  const vec4 centered = tex - zero_point;
  return scale * centered;
}
62
/*
 * Quantizes a float texel based on a scale and zero point.
 *
 *   tex        - real-valued texel to quantize
 *   scale      - quantization scale
 *   zero_point - quantization zero point
 *
 * Returns roundEven(tex / scale) + zero_point per component, saturated to
 * [0, 255]. The result is stored to uOutput, whose rgba8ui format holds 8-bit
 * unsigned components.
 */
uvec4 quantize(vec4 tex, float scale, int zero_point) {
  const vec4 shifted = roundEven(tex / scale) + float(zero_point);
  // Saturate before converting: constructing a uint from a negative float is
  // undefined, and values above 255 are not representable in the output
  // image.
  return uvec4(clamp(shifted, 0.0, 255.0));
}
69
/*
 * Computes a 2D quantized convolution. Each shader invocation calculates the
 * output at a single output location. Currently this is implemented in a naive
 * way, where inputs are dequantized upon reading in, and requantized upon
 * writing out.
 *
 * The accumulator starts from the bias texel for the current output z index,
 * is clamped to uBlock.clamp_thresh once every kernel tap has been
 * accumulated, and is then quantized with the output scale and zero point
 * before being stored to uOutput.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Return if this global position is outside output texture bounds
  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Note that
  // negative indices can be produced indicating that the top-left element is in
  // a region added by padding.
  const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;

  const ivec2 dilate = uBlock.dilate;

  // Compute the start of the kernel, i.e. the number of leading kernel taps
  // that fall into the padding region and are skipped. Padding is assumed to be
  // constant 0 padding, so those taps contribute nothing. Note that these are
  // "canonical" indices.
  //
  // The division rounds up: kernel tap k reads the input at ipos + k * dilate,
  // so the first tap to use is the smallest k for which that index is >= 0. A
  // truncating division would pair input texels with the wrong kernel taps
  // whenever the padding is not a multiple of the dilation.
  ivec2 kstart = (max(ivec2(0), ipos) - ipos + dilate - 1) / dilate;

  // Compute the start and end of the input indices to load. The start is
  // derived from kstart so that it always lies on the dilation grid anchored
  // at ipos.
  const ivec2 start = ipos + kstart * dilate;
  const ivec2 end = min(ipos + uBlock.overlay_region.xy, uBlock.in_extents.xy);

  // During prepacking, the weight tensor was rearranged in order to optimize
  // for data access linearity in this shader. Therefore we need to adjust the
  // canonical coordinates to the corresponding index in the rearranged weight
  // tensor. The x coordinate is multiplied by 4 since each group of 4 channels
  // is folded into the X axis. The y coordinate is offset based on the z
  // coordinate because the 2D planes were stacked atop each other vertically.
  kstart.x *= 4;
  kstart.y += pos.z * uBlock.kernel_size.y;

  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);

  // Perform the convolution by iterating over the overlay region. Every step
  // along z4 moves to the next input texel layer and advances the kernel x
  // coordinate by one full kernel row (kernel width * 4 texels).
  const int ic4 = uBlock.overlay_region.z / 4;
  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += uBlock.kernel_size.x * 4) {
    for (int y = start.y, ky = kstart.y; y < end.y; y += dilate.y, ++ky) {
      for (int x = start.x, kx = kstart.x; x < end.x; x += dilate.x, kx += 4) {
        // Read in and dequantize the input texel
        const vec4 in_tex = dequantize(
            texelFetch(uInput, ivec3(x, y, z4), 0),
            uBlock.scales.y,
            uBlock.zero_points.y);

        // To explain the calculation below, the contents of in_tex and the
        // group of 4 texels loaded from uKernel are shown:
        //
        //   in_tex               uKernel
        //    -x->                   ---x--->
        //   +---+              +----+----+----+----+
        // ^ | w |           ^  | D0 | D1 | D2 | D3 |
        // | +---+           |  +----+----+----+----+
        // | | z |           |  | C0 | C1 | C2 | C3 |
        // z +---+           z  +----+----+----+----+
        // | | y |           |  | B0 | B1 | B2 | B3 |
        // | +---+           |  +----+----+----+----+
        //   | x |              | A0 | A1 | A2 | A3 |
        //   +---+              +----+----+----+----+
        //
        // In the uKernel graphic, cells sharing the same letter are from
        // the same batch/output channel index, and the number denotes a unique
        // channel index. To calculate the output texel, the following
        // calculation is performed:
        //
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
        //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
        //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //
        //  which is what is expressed in the following calculations.

        const vec4 ktex_0 = texelFetch(uKernel, ivec2(kx + 0, ky), 0);
        sum = fma(in_tex.xxxx, ktex_0, sum);

        const vec4 ktex_1 = texelFetch(uKernel, ivec2(kx + 1, ky), 0);
        sum = fma(in_tex.yyyy, ktex_1, sum);

        const vec4 ktex_2 = texelFetch(uKernel, ivec2(kx + 2, ky), 0);
        sum = fma(in_tex.zzzz, ktex_2, sum);

        const vec4 ktex_3 = texelFetch(uKernel, ivec2(kx + 3, ky), 0);
        sum = fma(in_tex.wwww, ktex_3, sum);
      }
    }
  }

  // Apply the clamp thresholds to the real-valued result, then requantize it
  // with the output scale and zero point.
  uvec4 out_tex = quantize(
      clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y),
      uBlock.scales.x,
      uBlock.zero_points.x);

  imageStore(uOutput, pos, out_tex);
}
174