#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}
// NOTE(review): ${PRECISION} and ${FORMAT} look like template placeholders
// that must be substituted before this shader is compiled - confirm against
// the build tooling that consumes this file.

/*
 * TILE_SIZE = (1, 1, 1)
 * WEIGHT_STORAGE = TEXTURE_2D
 * BIAS_STORAGE = TEXTURE_2D
 */

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 *
 * Written exactly once per in-bounds invocation, at that invocation's global
 * position.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Textures
 *
 * uInput  - read at integer texel coordinates (x, y, z4).
 * uKernel - 2D texture of weights, read at (kx, ky).
 * uBias   - 2D texture; only row 0 is read, indexed by the output z position.
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;

/*
 * Params Buffer
 */
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  // extents of the output texture; invocations at or beyond .xyz exit early
  ivec4 out_extents;
  // extents of the input texture; .xy is the upper bound of the input window
  ivec4 in_extents;
  // size of the overlay region of the kernel; .x and .y anchor the kernel
  // coordinates, .z is used as the per-input-position step along kernel x
  ivec4 overlay_region;
  // width and height of the kernel
  ivec2 kernel_size;
  // convolution parameters
  ivec2 stride;
  ivec2 padding;
  // declared for layout purposes; not read by this shader
  ivec2 dilate;
  // output clamp range: x = lower bound, y = upper bound
  vec2 clamp_thresh;
}
uBlock;

// Workgroup size is taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes one output texel per invocation.
 *
 * The accumulator starts from the bias texel for this output z position.
 * For every input position (x, y) in the window [start, end) and every z
 * slice z4 of the input, the four components of the input texel are each
 * multiplied by one kernel texel and accumulated. The result is clamped to
 * [clamp_thresh.x, clamp_thresh.y] and stored to uOutput.
 *
 * NOTE(review): the window/stride arithmetic looks like a gather-style
 * transposed convolution - confirm against the host-side dispatch code.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Return if this global position is outside output texture bounds
  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
    return;
  }

  const vec2 ksize = vec2(uBlock.kernel_size);
  const vec2 stride = vec2(uBlock.stride);

  // Output position shifted by the padding.
  ivec2 ipos = pos.xy + uBlock.padding;
  vec2 ipos_f = vec2(ipos);

  // Half-open window [start, end) of input coordinates i that satisfy
  //   i * stride <= ipos <= i * stride + ksize - 1
  // clamped below by 0 and above by the input extents.
  const ivec2 start = max(ivec2(0), ivec2(ceil((ipos_f - ksize + 1) / stride)));
  const ivec2 end =
      min(uBlock.in_extents.xy, ivec2(floor(ipos_f / stride)) + 1);

  // Accumulator is seeded with the bias for this output z position.
  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);

  const int ic4 = uBlock.overlay_region.z;

  // Kernel texture coordinates that pair with the first input position of
  // the window. The kernel row is additionally offset by
  // pos.z * kernel_size.y, i.e. one band of kernel_size.y rows per output z.
  int ky_start = uBlock.overlay_region.y - 1 -
      (ipos.y - uBlock.stride.y * start.y) + pos.z * uBlock.kernel_size.y;
  int kx_start =
      (uBlock.overlay_region.x - 1 - (ipos.x - uBlock.stride.x * start.x)) *
      ic4;
  // The z4 loop advances kx by 4 per iteration (ic4 in total when ic4 is a
  // multiple of 4); adding kx_stride afterwards brings the net step per input
  // x to ic4 * stride.x.
  int kx_stride = ic4 * (uBlock.stride.x - 1);

  for (int y = start.y, ky = ky_start; y < end.y; ++y, ky += uBlock.stride.y) {
    // kx restarts from kx_start on every row and is advanced by both the
    // z4 loop and the x loop increment.
    for (int x = start.x, kx = kx_start; x < end.x; ++x, kx += kx_stride) {
      for (int z4 = 0; z4 < ic4 / 4; ++z4, kx += 4) {
        const vec4 In = texelFetch(uInput, ivec3(x, y, z4), 0);
        const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

        // Each input component scales its own kernel texel.
        sum = fma(In.xxxx, texelFetch(uKernel, ivec2(kxs.x, ky), 0), sum);
        sum = fma(In.yyyy, texelFetch(uKernel, ivec2(kxs.y, ky), 0), sum);
        sum = fma(In.zzzz, texelFetch(uKernel, ivec2(kxs.z, ky), 0), sum);
        sum = fma(In.wwww, texelFetch(uKernel, ivec2(kxs.w, ky), 0), sum);
      }
    }
  }

  imageStore(
      uOutput, pos, clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y));
}