• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

// ---------------------------------------------------------------------------
// Output: the normalized result, one texel written per invocation.
// ---------------------------------------------------------------------------
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

// ---------------------------------------------------------------------------
// Inputs: the activation texture, then the per-channel parameter textures
// (gamma = scale, beta = shift, mean, variance).
// ---------------------------------------------------------------------------
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma;
layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta;
layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean;
layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar;

// ---------------------------------------------------------------------------
// Uniform parameters.
// ---------------------------------------------------------------------------
layout(set = 0, binding = 6) uniform PRECISION restrict Block {
  // .xyz = output texture extents; .w = channel count / 4, rounded up
  // (i.e. the number of 4-channel groups).
  ivec4 out_extents;
  // Added to the variance before the square root to avoid dividing by zero.
  float eps;
} uBlock;

// ---------------------------------------------------------------------------
// Work-group size, supplied through specialization constants 0, 1 and 2.
// ---------------------------------------------------------------------------
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Batch normalization; each invocation produces exactly one output texel:
 *
 *   out = (x - mean) / sqrt(var + eps) * gamma + beta
 *
 * The parameter textures are read only along z. The index is the output z
 * coordinate modulo uBlock.out_extents.w (the number of 4-channel groups).
 * NOTE(review): this presumes batch images are stacked along z in the input
 * texture, so z wraps once per batch entry — confirm against the host code.
 */
void main() {
  const ivec3 texel = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output texture have no work to do.
  if (any(greaterThanEqual(texel, uBlock.out_extents.xyz))) {
    return;
  }

  // In the parameter textures, z picks the 4-channel group; x and y stay 0.
  const int channel_group = texel.z % uBlock.out_extents.w;
  const ivec3 param_texel = ivec3(0, 0, channel_group);

  const vec4 x = texelFetch(uInput, texel, 0);
  const vec4 mean = texelFetch(uMean, param_texel, 0);
  const vec4 variance = texelFetch(uVar, param_texel, 0);
  const vec4 gamma = texelFetch(uGamma, param_texel, 0);
  const vec4 beta = texelFetch(uBeta, param_texel, 0);

  // Same operation order as the reference formula: divide by the standard
  // deviation, then scale by gamma, then shift by beta.
  const vec4 centered = x - mean;
  const vec4 stddev = sqrt(variance + uBlock.eps);
  const vec4 result = centered / stddev * gamma + beta;

  imageStore(uOutput, texel, result);
}
62