• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// PRECISION and FORMAT are template placeholders (`${...}` is not GLSL);
// they must be substituted before this shader is compiled.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Output Image
 *
 * Destination tensor as a 3D storage image. It is declared writeonly and is
 * only written via imageStore in main().
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Sampler
 *
 * Source tensor as a 3D texture. main() reads it with texelFetch at mip level 0.
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uImage;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // info.x: dimension to insert at. main() handles 1, 2 and 3 and writes
  //         nothing for any other value.
  // info.y: only used by the dim == 2 / dim == 3 paths of main(). There,
  //         output z-slice pos.z reads the 4 input rows starting at
  //         (pos.z % info.y) * 4, so info.y is the number of 4-row groups per
  //         output z-block. NOTE(review): presumably ceil(C_out / 4) of the
  //         output tensor; confirm against the host code that fills it.
  ivec2 info;
}
uBlock;

/*
 * Local Work Group Size
 *
 * Supplied at pipeline creation via specialization constants 0, 1 and 2.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
31
32/*
33 * Returns a new tensor with dimension of size one inserted at the specified
34 * position (dim)
35 */
/*
 * Returns a new tensor with a dimension of size one inserted at the specified
 * position (dim). Each invocation writes at most one output texel, at the
 * coordinate given by its global invocation id.
 *
 * Supported values of uBlock.info.x (dim) are 1, 2 and 3. For any other value
 * nothing is written and the output image is left untouched.
 * NOTE(review): presumably other dims (e.g. 0) are handled on the host; confirm.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The dispatch may be rounded up to a multiple of the work group size.
  // Invocations outside the output image have nothing to produce, and
  // writing there is not guaranteed to be harmless without robustness features.
  if (any(greaterThanEqual(pos, imageSize(uOutput)))) {
    return;
  }

  const int dim = uBlock.info.x;
  // Number of 4-row groups per output z-block (see the Params Buffer);
  // must be positive for the dim 2 / dim 3 paths.
  const int channels = uBlock.info.y;

  if (dim == 1) {
    // Output z-slice pos.z holds a single value: component (pos.z % 4) of input
    // z-slice (pos.z / 4) at the same x/y. The other three lanes stay zero.
    const vec4 v = texelFetch(uImage, ivec3(pos.x, pos.y, pos.z / 4), 0);
    imageStore(uOutput, pos, vec4(v[pos.z % 4], 0, 0, 0));
  } else if (dim == 2 || dim == 3) {
    // The dim 2 and dim 3 paths gather identically; they differ only in which
    // output axis maps to the input x coordinate.
    const int src_x = (dim == 2) ? pos.x : pos.y;

    // Loop-invariant parts of the gather: the source z-slice, the component
    // picked from each source texel, and the first source row of this group.
    const int src_z = pos.z / (channels * 4);
    const int p = (pos.z / channels) % 4;
    const int src_y_base = (pos.z % channels) * 4;
    const int src_height = textureSize(uImage, 0).y;

    vec4 out_texel = vec4(0, 0, 0, 0);
    for (int i = 0; i < 4; ++i) {
      const int src_y = src_y_base + i;
      // When the source height is not a multiple of 4, the last group reaches
      // past the image. texelFetch is undefined there, so leave those padding
      // lanes at zero instead of reading them.
      if (src_y < src_height) {
        out_texel[i] = texelFetch(uImage, ivec3(src_x, src_y, src_z), 0)[p];
      }
    }
    imageStore(uOutput, pos, out_texel);
  }
}
71