• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * TILE_SIZE = (${OUTPUT_TILE_SIZE[0]}, ${OUTPUT_TILE_SIZE[1]}, 1)
3 * WEIGHT_STORAGE = TEXTURE_2D
4 * WEIGHT_STORAGE_LAYOUT = OC4,IC4,4ic,4oc
5 * BIAS_STORAGE = TEXTURE_2D
6 * REGISTER_FOR = ${REGISTER_FOR}
7 */
8#version 450 core
9#define PRECISION ${PRECISION}
10#define FORMAT ${FORMAT}
11
// Default layout for buffer blocks declared in this shader.
layout(std430) buffer;

/*
 * Resources bound to descriptor set 0.
 *
 *   binding 0  uOutput  write-only 3D image; main() stores one texel per
 *                       output position (x, y, z).
 *   binding 1  uInput   3D texture; main() fetches it at (x, y, z4), where z4
 *                       counts groups of four input channels.
 *   binding 2  uKernel  2D texture; main() fetches it at (ic, gpos.z), i.e.
 *                       input channel along x and output texel index along y.
 *   binding 3  uBias    2D texture; main() fetches it at (gpos.z, 0).
 */
// clang-format off
layout(set = 0, binding = 0, FORMAT)
uniform PRECISION restrict writeonly image3D uOutput;

layout(set = 0, binding = 1)
uniform PRECISION sampler3D uInput;

layout(set = 0, binding = 2)
uniform PRECISION sampler2D uKernel;

layout(set = 0, binding = 3)
uniform PRECISION sampler2D uBias;
// clang-format on
/*
 * Convolution parameters supplied by the host. Member order and types define
 * the block's memory layout and must stay in sync with the host-side struct.
 */
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // Extents of the output texture. main() uses .xyz as the exclusive upper
  // bound for every output position it writes.
  ivec4 out_extents;
  // Extents of the input texture. Not read by this shader.
  ivec4 in_extents;
  // Size of the overlay region of the kernel. main() uses .z as the number of
  // input channels to iterate over, consumed four at a time.
  ivec4 overlay_region;
  // Width and height of the kernel. Not read by this shader.
  ivec2 kernel_size;
  // Convolution parameters. main() maps an output position to an input
  // position as `out.xy * stride - padding`; dilate is not read here.
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  // Output clamp range: .x is the lower bound, .y is the upper bound.
  vec2 clamp_thresh;
} uBlock;

// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
38
/*
 * Computes a 2D pointwise (1x1 kernel) convolution. Each invocation produces a
 * tile of OUTPUT_TILE_SIZE[0] x OUTPUT_TILE_SIZE[1] output texels at a single
 * output depth index (gpos.z). Working on a tile is more efficient for a
 * pointwise convolution because the kernel is only 1x1: the four texels
 * fetched from uKernel per step are re-used for every position in the tile.
 *
 * An invocation whose first tile position lies outside the output extents
 * exits immediately; the remaining tile positions are bounds-checked
 * individually before being stored.
 */
void main() {
  // Tile dimensions are filled in by the shader code generator.
  const int tile_w = ${OUTPUT_TILE_SIZE[0]};
  const int tile_h = ${OUTPUT_TILE_SIZE[1]};
  const int tile_texels = tile_w * tile_h;

  const ivec3 gpos = ivec3(gl_GlobalInvocationID);
  const ivec3 out_limits = uBlock.out_extents.xyz;

  // Output positions covered by this invocation, in row-major order. For a
  // 2x2 tile:
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  ivec3 pos[tile_texels];
  for (int i = 0; i < tile_texels; ++i) {
    const ivec2 tile_xy = ivec2(i % tile_w, i / tile_w);
    pos[i] = ivec3(gpos.xy * ivec2(tile_w, tile_h) + tile_xy, gpos.z);
  }

  // pos[0] is the top-left position of the tile. If it is out of bounds, the
  // whole tile is, and there is nothing to do.
  if (any(greaterThanEqual(pos[0], out_limits))) {
    return;
  }

  // For each output position, derive the xy coordinate of the input texel to
  // read and seed the accumulator with the bias. The input coordinate can be
  // negative, which means it falls in the region added by padding.
  const vec4 bias = texelFetch(uBias, ivec2(gpos.z, 0), 0);
  ivec2 ipos[tile_texels];
  vec4 sum[tile_texels];
  for (int i = 0; i < tile_texels; ++i) {
    ipos[i] = pos[i].xy * uBlock.stride - uBlock.padding;
    sum[i] = bias;
  }

  // The kernel is 1x1, so only the input channel dimension is iterated.
  // z4 indexes input texels (four channels each); z is the matching channel
  // index, i.e. the x coordinate into uKernel.
  const int ic_aligned = uBlock.overlay_region.z;
  for (int z4 = 0; 4 * z4 < ic_aligned; ++z4) {
    const int z = 4 * z4;

    // uKernel is addressed with the input channel along the x axis and the
    // output texel index (gpos.z) along the y axis. Texel k below holds the
    // weights that connect input channel (z + k) to each of the four output
    // channels packed in this output texel.
    const vec4 w_ic0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0);
    const vec4 w_ic1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0);
    const vec4 w_ic2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0);
    const vec4 w_ic3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0);

    // Fetch the input texel for every tile position before accumulating.
    vec4 in_tex[tile_texels];
    for (int i = 0; i < tile_texels; ++i) {
      in_tex[i] = texelFetch(uInput, ivec3(ipos[i], z4), 0);
    }

    // Accumulate one input texel against the four kernel texels. Writing the
    // components of the input texel as (x, y, z, w) and the components of
    // kernel texel k as (Ak, Bk, Ck, Dk):
    //
    //   in_tex        w_ic0  w_ic1  w_ic2  w_ic3
    //   +---+         +----+----+----+----+
    //   | w |         | D0 | D1 | D2 | D3 |
    //   +---+         +----+----+----+----+
    //   | z |         | C0 | C1 | C2 | C3 |
    //   +---+         +----+----+----+----+
    //   | y |         | B0 | B1 | B2 | B3 |
    //   +---+         +----+----+----+----+
    //   | x |         | A0 | A1 | A2 | A3 |
    //   +---+         +----+----+----+----+
    //
    // Cells sharing a letter belong to the same output channel, and the digit
    // identifies the input channel. The update applied to the accumulator is
    //
    //   sum.x += x * A0 + y * A1 + z * A2 + w * A3
    //   sum.y += x * B0 + y * B1 + z * B2 + w * B3
    //   sum.z += x * C0 + y * C1 + z * C2 + w * C3
    //   sum.w += x * D0 + y * D1 + z * D2 + w * D3
    //
    // evaluated as four vector fused multiply-adds, one per input channel.
    for (int i = 0; i < tile_texels; ++i) {
      const vec4 texel = in_tex[i];
      vec4 acc = sum[i];
      acc = fma(vec4(texel.x), w_ic0, acc);
      acc = fma(vec4(texel.y), w_ic1, acc);
      acc = fma(vec4(texel.z), w_ic2, acc);
      acc = fma(vec4(texel.w), w_ic3, acc);
      sum[i] = acc;
    }
  }

  // Clamp and store each result, skipping tile positions that fall outside
  // the output extents (possible on the right/bottom edges).
  const float clamp_lo = uBlock.clamp_thresh.x;
  const float clamp_hi = uBlock.clamp_thresh.y;
  for (int i = 0; i < tile_texels; ++i) {
    if (any(greaterThanEqual(pos[i], out_limits))) {
      continue;
    }
    imageStore(uOutput, pos[i], clamp(sum[i], clamp_lo, clamp_hi));
  }
}
148