• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// NOTE(review): ${PRECISION} and ${FORMAT} look like template placeholders that
// are substituted before the shader is compiled -- confirm against the build.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Destination image: written with imageStore, one texel per invocation.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Source texture: read with texelFetch at mip level 0.
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Uniform parameters describing the slice.
 * Member order and types define the buffer layout and must not be changed.
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // Output texture extents: x = width, y = height, z = depth (w is unused).
  ivec4 size;
  // Input texture extents: x = width, y = height, z = depth (w is unused).
  ivec4 isize;
  // Size of the output tensor.
  uvec4 tensor_size;
  // Size of the input tensor.
  uvec4 itensor_size;
  // Slice arguments: [0] = dim, [1] = start, [2] = end, [3] = step.
  uvec4 args;
  // Channel counts rounded up to a multiple of 4: x = output, y = input.
  uvec2 c_info;
} uBlock;

/*
 * Work-group dimensions, taken from specialization constants 0, 1 and 2.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Produces one output texel per invocation.
 *
 * Each texel packs four consecutive values along z. For every lane the flat
 * destination index (z * 4 + lane) is split into a batch and a channel using
 * the aligned output channel count, the sliced dimension is remapped to
 * start + step * index, and the matching value is fetched from the input
 * texture. Lanes whose destination index is out of range stay 0.0.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output texture extents have nothing to write.
  if (any(greaterThanEqual(pos, uBlock.size.xyz))) {
    return;
  }

  const uint slice_dim = uBlock.args[0];
  const uint slice_start = uBlock.args[1];
  const uint slice_step = uBlock.args[3];

  // Aligned channel counts act as the per-batch stride of the flat index.
  const uint out_stride = uBlock.c_info.x;
  const uint in_stride = uBlock.c_info.y;

  // One past the last valid flat destination index.
  // NOTE(review): tensor_size[0] is presumably the output batch count -- confirm.
  const uint dst_limit = uBlock.tensor_size[0] * out_stride;

  // Flat destination index of lane 0 of this texel.
  const uint dst_base = uint(pos.z * 4);

  vec4 texel = vec4(0.0);
  for (uint lane = 0u; lane < 4u; ++lane) {
    const uint dst_index = dst_base + lane;
    if (dst_index >= dst_limit) {
      // Out of range: this lane and the remaining ones keep their 0.0 value.
      break;
    }

    // Destination coordinates; the source starts out identical.
    uint batch = dst_index / out_stride;
    uint channel = dst_index % out_stride;

    if (slice_dim == 0u) { // slicing along batch
      batch = slice_start + slice_step * batch;
    } else if (slice_dim == 1u) { // slicing along feature (channel)
      channel = slice_start + slice_step * channel;
    }

    // x and y carry over unchanged; only the packed z coordinate and the
    // component within the texel depend on the source index.
    const uint src_index = batch * in_stride + channel;
    const ivec3 src_pos = ivec3(pos.xy, int(src_index / 4u));
    const vec4 src_texel = texelFetch(uInput, src_pos, 0);
    texel[lane] = src_texel[src_index % 4u];
  }

  // Written exactly once, whether the loop completed or stopped early.
  imageStore(uOutput, pos, texel);
}
