• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#version 450 core
2#define PRECISION ${PRECISION}
3#define FORMAT ${FORMAT}
4
5layout(std430) buffer;
6
7/* Qualifiers: layout - storage - precision - memory */
8layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; // destination image; main() writes one texel per invocation with imageStore
9layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; // source texture; main() reads it with texelFetch at LOD 0
10layout(set = 0, binding = 2) uniform PRECISION restrict Block {
11  // dim_info.x: selects the dimension to average over, as interpreted by
11  //             main(): 0 = batch, 1 = channel, 2 = height, any other value = width
12  // dim_info.y: size of that dimension in the input; main() uses it both as
12  //             the reduction loop bound and as the divisor of the sum
13  uvec2 dim_info;
14  int channel; // presumably the input channel count (confirm with host code); main() uses ceil(channel / 4.0) as the number of z-texels per batch
15}
16uBlock;
17
18/*
19 * Local Work Group Size
20 */
21layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; // x/y/z sizes are taken from specialization constants 0, 1 and 2
22
23/*
24 * Returns a new tensor with values averaged along dimension dim
25 * Dimension dim is squeezed
26 * For each pos:
27 *  - Iterate over the out_texel and the averaged dimension
28 *  - For H,W; rearrange pos.x, pos.y
29 *  - For C,H,W;
30 *      When CHW are averaged, batch moves into channel
31 *      The src N is determined by pos.z * 4 + out_index
32 */
33
void main() {
  // One invocation produces the output texel located at its global ID.
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Number of 4-component texels along z that one batch of the input spans.
  int texels_per_batch = int(ceil(uBlock.channel / 4.0));

  // Running sum for the four components of the output texel.
  vec4 sum = vec4(0, 0, 0, 0);

  if (uBlock.dim_info.x == 0) {
    // Batch: the source channel texel is pos.z; walk every batch and add the
    // whole texel, so all four components are reduced at once.
    for (int n = 0; n < uBlock.dim_info.y; n++) {
      int src_z = n * texels_per_batch + pos.z;
      vec4 texel = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
      sum += texel;
    }
  } else if (uBlock.dim_info.x == 1) {
    // Channel: output component `lane` corresponds to source batch
    // pos.z * 4 + lane; each source channel c lives in component c % 4 of
    // z-texel c / 4 within that batch.
    for (int lane = 0; lane < 4; lane++) {
      int src_n = pos.z * 4 + lane;
      for (int c = 0; c < uBlock.dim_info.y; c++) {
        int src_z = src_n * texels_per_batch + c / 4;
        vec4 texel = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
        sum[lane] += texel[c % 4];
      }
    }
  } else {
    // Height (dim_info.x == 2) or Width (any other value): pos.y is the
    // source channel and output component `lane` corresponds to source batch
    // pos.z * 4 + lane.
    bool over_height = (uBlock.dim_info.x == 2);
    int component = pos.y % 4;
    int channel_texel = pos.y / 4;
    for (int lane = 0; lane < 4; lane++) {
      int src_z = (pos.z * 4 + lane) * texels_per_batch + channel_texel;
      for (int i = 0; i < uBlock.dim_info.y; i++) {
        // The reduced axis takes the loop index; pos.x addresses the other one.
        ivec3 coord;
        if (over_height) {
          coord = ivec3(pos.x, i, src_z);
        } else {
          coord = ivec3(i, pos.x, src_z);
        }
        vec4 texel = texelFetch(uInput, coord, 0);
        sum[lane] += texel[component];
      }
    }
  }

  // Every mode finishes the same way: divide the sum by the reduced size.
  imageStore(uOutput, pos, sum / uBlock.dim_info.y);
}
83