• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// PRECISION and FORMAT expand to ${...} placeholders, which are not valid GLSL
// on their own; presumably a templating step substitutes them before this
// source reaches the shader compiler -- TODO confirm against the build.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Output Image
 *
 * 3D image written by main() with imageStore(); one vec4 texel is stored per
 * invocation at the invocation's global id.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Textures
 *
 * 3D texture read by main() with texelFetch() at mip level 0 (integer texel
 * coordinates, no filtering).
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 *
 * Uniform parameters consumed by main(). Member names, types and order form
 * the interface to whatever fills this buffer and must not change.
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // output texture size (x=width,y=height,z=depth,w=unused);
  // main() discards invocations whose global id is outside size.xyz
  ivec4 size;
  // input texture size (x=width,y=height,z=depth,w=unused);
  // not read by main() in this shader
  ivec4 isize;
  // output tensor size; component [0] is multiplied by ch_info.x to get the
  // number of valid flattened destination indices
  uvec4 out_tensor_size;
  // input tensor size; not read by main() in this shader
  uvec4 in_tensor_size;
  // output dims: out_ndims[k] (expected in 0..3) selects which source
  // coordinate (0=b, 1=c, 2=h, 3=w) receives destination coordinate k
  uvec4 out_ndims;
  // x = output channels aligned to 4, y = input channels aligned to 4
  uvec2 ch_info;
}
uBlock;

/*
 * Local Work Group
 *
 * The work group dimensions are not hard-coded: x, y and z are taken from the
 * specialization constants with ids 0, 1 and 2 respectively.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Permuting copy from uInput to uOutput.
 *
 * Each invocation produces one output texel (four packed values). For lane j
 * of that texel:
 *   1. dst_index = posOut.z * 4 + j is the flattened index along the output
 *      depth axis; it is split into (b1, c1) using the aligned output channel
 *      count uBlock.ch_info.x. (h1, w1) come from posOut.y / posOut.x.
 *   2. uBlock.out_ndims[k] names the source coordinate (0=b, 1=c, 2=h, 3=w)
 *      that receives destination coordinate k.
 *   3. The source (b, c) pair is re-flattened with the aligned input channel
 *      count uBlock.ch_info.y to find the input texel and the lane within it.
 *
 * Lanes whose dst_index is past out_tensor_size[0] * ch_info.x keep the
 * initial value 0.0. Exactly one imageStore is issued per in-bounds
 * invocation.
 */
void main() {
  const ivec3 posOut = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output texture extents have nothing to write.
  if (any(greaterThanEqual(posOut, uBlock.size.xyz))) {
    return;
  }

  const uint out_ch_aligned = uBlock.ch_info.x;
  const uint in_ch_aligned = uBlock.ch_info.y;
  const uint max_dst_index = uBlock.out_tensor_size[0] * out_ch_aligned;

  vec4 outval = vec4(0.0);

  for (uint j = 0u; j < 4u; ++j) {
    const uint dst_index = uint(posOut.z) * 4u + j;
    if (dst_index >= max_dst_index) {
      // out of range: this lane and all following lanes stay at 0.0
      break;
    }

    // Destination coordinates in (b, c, h, w) order.
    const uvec4 dst_coord = uvec4(
        dst_index / out_ch_aligned,
        dst_index % out_ch_aligned,
        uint(posOut.y),
        uint(posOut.x));

    // Scatter destination coordinates into source (b, c, h, w) slots.
    // Zero-initialised so that a slot never named by out_ndims holds a
    // defined value instead of garbage; selectors outside 0..3 are ignored.
    // Components are applied in order 0..3, so on duplicate selectors the
    // later one wins.
    uvec4 src_coord = uvec4(0u);
    for (uint d = 0u; d < 4u; ++d) {
      const uint src_dim = uBlock.out_ndims[d];
      if (src_dim < 4u) {
        src_coord[src_dim] = dst_coord[d];
      }
    }

    // Re-flatten source (b, c) and locate the texel plus the lane inside it.
    const uint src_index = src_coord[0] * in_ch_aligned + src_coord[1];
    const ivec3 posIn =
        ivec3(int(src_coord[3]), int(src_coord[2]), int(src_index / 4u));
    const uint lane = src_index % 4u;

    const vec4 inval = texelFetch(uInput, posIn, 0);
    outval[j] = inval[lane];
  }

  imageStore(uOutput, posOut, outval);
}
