/*
 * TILE_SIZE = (1, 1, 1)
 * WEIGHT_STORAGE = TEXTURE_2D
 * BIAS_STORAGE = TEXTURE_2D
 * Note that for a depthwise (DW) kernel IC = 1, so the weight layout is really
 * OC4, H, W, 4oc.
 */
#version 450 core
// PRECISION and FORMAT are template placeholders substituted before the
// shader is compiled.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

// Output image written by this shader (one texel per invocation).
// clang-format off
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOut;
// clang-format on
// Input activations, indexed as (x, y, z).
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
// Depthwise weights: row = z index of the output texel, column = flattened
// position inside the filter window.
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
// Bias, fetched at (z, 0).
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // extents of the output texture; invocations outside .xyz exit early
  ivec4 out_extents;
  // extents of the input texture
  ivec4 in_extents;
  // size of the overlay region of the kernel, i.e. the span of input
  // coordinates (in x and y) covered by one application of the filter
  ivec4 overlay_region;
  // width and height of the kernel (not read by this shader)
  ivec2 kernel_size;
  // convolution parameters
  ivec2 stride;
  ivec2 padding;
  // step between consecutive filter taps, in input texels
  ivec2 dilate;
  // output clamp range: x = lower bound, y = upper bound
  vec2 clamp_thresh;
}
uBlock;

// Workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes depthwise convolution. Each shader invocation calculates the output
 * of a single output location (one texel of uOut).
 *
 * The accumulator starts from the bias texel for this z slice. Every filter
 * tap that lands inside the input texture contributes input * weight; taps
 * that land in the zero-padding region contribute nothing. The accumulated
 * value is clamped to [clamp_thresh.x, clamp_thresh.y] before being stored.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Return if this global position is outside output texture bounds
  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Note that
  // negative indices can be produced indicating that the top-left element is in
  // a region added by padding.
  const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;

  // Range of input indices visited by the filter window. The range itself is
  // not clipped, because every tap (in or out of bounds) must advance the
  // weight index; out-of-bounds taps are rejected inside the loop instead.
  const ivec2 start = ipos;
  const ivec2 end = ipos + uBlock.overlay_region.xy;

  // Spatial size of the input, used to detect taps in the padding region.
  const ivec2 in_size = uBlock.in_extents.xy;

  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);
  const int dil_y = uBlock.dilate.y;
  const int dil_x = uBlock.dilate.x;
  int k_ind = 0;
  for (int y = start.y; y < end.y; y += dil_y) {
    const bool y_inside = (y >= 0) && (y < in_size.y);
    for (int x = start.x; x < end.x; x += dil_x) {
      // Padding is constant 0, so a tap outside the input adds nothing. Skip
      // it explicitly: an out-of-bounds texelFetch returns undefined values
      // unless robust image access is enabled on the device.
      if (y_inside && (x >= 0) && (x < in_size.x)) {
        // The weight kernel was rearranged so that every NxN filter was
        // flattened so that it fits on one row. Each filter was then stacked
        // on top of each other vertically.
        const vec4 k_tex = texelFetch(uKernel, ivec2(k_ind, pos.z), 0);
        const vec4 i_tex = texelFetch(uInput, ivec3(x, y, pos.z), 0);
        sum = fma(i_tex, k_tex, sum);
      }
      // Always advance so the weight column stays aligned with the tap.
      k_ind++;
    }
  }

  imageStore(
      uOut, pos, clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y));
}