• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// ${PRECISION} and ${FORMAT} look like template placeholders, presumably
// substituted by a code generator before compilation — confirm against the
// build step that consumes this file.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/* Output image: main() writes one vec4 texel per in-bounds invocation. */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D   uOutput;
/* Input texture: main() reads two scalar channels (a, b) per output lane. */
layout(set = 0, binding = 1)         uniform PRECISION                    sampler3D uInput;
layout(set = 0, binding = 2)         uniform PRECISION restrict           Block {
  ivec3 size;  // output size; invocations with any coordinate >= size do nothing
  int ch;      // channel size of the output
} uBlock;

// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Reads one scalar channel from the input texture.
 * `c` is a flat channel index: texel depth c / 4, lane c % 4.
 */
float fetch_channel(ivec2 xy, int c) {
  return texelFetch(uInput, ivec3(xy, c / 4), 0)[c % 4];
}

/*
 * Each invocation produces one output texel (four lanes). For lane i, with
 * flat output index idx = 4 * pos.z + i, it reads
 *   a at input index 2 * (idx / ch) * ch + idx % ch
 *   b at that index + ch
 * and stores a * sigmoid(b), i.e. a * 1 / (1 + exp(-b)).
 * Invocations outside uBlock.size write nothing.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (all(lessThan(pos, uBlock.size.xyz))) {
    vec4 gated;

    for (int lane = 0; lane < 4; ++lane) {
      const int idx = 4 * pos.z + lane;

      // Every block of `ch` output channels maps to a block of 2 * ch input
      // channels: the first half supplies `a`, the second half supplies `b`.
      const int za = 2 * (idx / uBlock.ch) * uBlock.ch + (idx % uBlock.ch);
      const int zb = za + uBlock.ch;

      const float a = fetch_channel(pos.xy, za);
      const float b = fetch_channel(pos.xy, zb);

      gated[lane] = a * (1.0 / (1.0 + exp(-b)));
    }

    imageStore(uOutput, pos, gated);
  }
}
53