• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// ${PRECISION} / ${FORMAT} are template placeholders, presumably substituted by
// a shader codegen step before compilation — TODO confirm.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // Tensor sizes: x=width, y=height, z=channel, w=batch
  uvec4 extents;
  // Per-dimension flip flags, in the same order as extents
  // (x=width, y=height, z=channel, w=batch): 1=flip, 0=no flip
  ivec4 dims;
}
uBlock;

/*
 * Local Work Group Size (supplied via specialization constants 0, 1, 2)
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Returns a new tensor with values flipped along every dimension whose flag in
 * uBlock.dims is 1.
 *
 * The indexing below assumes (TODO confirm against the host-side packing) that
 * texel z = n * ceil(C / 4) + c / 4, and that the four components of a texel
 * hold four consecutive channels. Texel x and y map directly to width and
 * height. Each invocation produces exactly one output texel.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The dispatch may be rounded up past the output texture. Those extra
  // invocations have nothing to write, and would otherwise compute source
  // coordinates outside the input.
  if (any(greaterThanEqual(pos, imageSize(uOutput)))) {
    return;
  }

  // Signed copies, so "size - 1 - index" cannot silently wrap around.
  const ivec4 sizes = ivec4(uBlock.extents);
  const ivec4 flip = uBlock.dims;

  // Texels per batch entry along z: channels are packed four per texel.
  const int flattened_channels = (sizes.z + 3) / 4;

  // Decompose z into (batch index, channel-texel index within that batch).
  const int n = pos.z / flattened_channels;
  const int c_div4 = pos.z - n * flattened_channels;

  const int src_x = (flip.x == 1) ? sizes.x - 1 - pos.x : pos.x;
  const int src_y = (flip.y == 1) ? sizes.y - 1 - pos.y : pos.y;
  const int src_n = (flip.w == 1) ? sizes.w - 1 - n : n;
  // First z texel of the source batch entry.
  const int src_batch_z = src_n * flattened_channels;

  vec4 out_texel = vec4(0);
  if (flip.z == 1) {
    // Flipping channels moves values across texel components, so each lane
    // is gathered individually from its mirrored source channel.
    for (int p = 0; p < 4; ++p) {
      const int c = c_div4 * 4 + p;
      if (c >= sizes.z) {
        // Padding lanes of the last texel have no source channel. Leave them
        // zero instead of fetching from an out-of-range coordinate. Later
        // lanes are padding too, so stop here.
        break;
      }
      const int src_c = sizes.z - 1 - c;
      const vec4 v = texelFetch(
          uInput, ivec3(src_x, src_y, src_batch_z + src_c / 4), 0);
      out_texel[p] = v[src_c % 4];
    }
  } else {
    // Channel order unchanged: the whole texel comes from one source texel.
    out_texel =
        texelFetch(uInput, ivec3(src_x, src_y, src_batch_z + c_div4), 0);
  }

  // Single store per invocation, after all lanes are filled.
  imageStore(uOutput, pos, out_texel);
}
