• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// PRECISION and FORMAT are template placeholders; presumably substituted by
// the shader code generator before compilation (TODO confirm).
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

// Default memory layout for buffer (storage) blocks. This shader declares no
// buffer block, so the statement currently has no effect.
layout(std430) buffer;

/*
 * Output Image
 *
 * main() loads texels from this image and stores them back with the masked
 * elements set to uBlock.value.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Textures
 *
 * The mask, fetched as integers; a component equal to 1 marks an element to
 * fill. Four channels are packed per texel, and channel texels are laid out
 * along z.
 */
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
16
/*
 * Params Buffer
 *
 * Uniform parameters read by main(). Member order and types determine the
 * buffer layout, so they presumably must match the host-side parameter
 * struct (TODO confirm against the caller).
 */
// NOTE(review): the GLSL 4.50 spec only permits memory qualifiers such as
// `restrict` on images, buffer variables and storage blocks. It is kept
// unchanged here because the toolchain in use presumably accepts it — confirm.
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // output texture size (x=width, y=height, z=depth, w=unused)
  ivec4 out_extents;
  // mask texture size (x=width, y=height, z=depth, w=unused)
  ivec4 mask_extents;
  // output extent sizes (x=batch, y=channel, z=height, w=width)
  uvec4 out_size_info;
  // mask extent sizes (x=batch, y=channel, z=height, w=width)
  uvec4 mask_size_info;
  // x: output channel count rounded up to a multiple of 4
  // y: mask channel count rounded up to a multiple of 4
  uvec2 aligned_channel_info;
  // fill value written wherever the mask is true
  float value;
}
uBlock;
36
/*
 * Local Work Group
 *
 * The work-group size is taken from specialization constants 0, 1 and 2, so
 * the host supplies it when it creates the compute pipeline.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
41
/*
 * Entry point. Each invocation handles the mask texel at
 * gl_GlobalInvocationID. It fetches the four channel values packed in that
 * texel. If any non-padding channel equals 1, it writes uBlock.value into
 * every output texel that this mask texel maps to under broadcasting.
 *
 * Channels are packed four per texel and padded up to a multiple of 4, so the
 * write path depends on whether the mask's channel dimension is broadcast:
 *  - broadcast: every real output channel at the mapped positions is filled;
 *  - not broadcast: an output component is filled only where the mask
 *    component in the same position is 1.
 */
void main() {
  const ivec3 pos_mask = ivec3(gl_GlobalInvocationID);

  // Invocations index the mask texture (pos_mask is the uInput coordinate),
  // so they are bounded by the mask extents. Bounding by out_extents would
  // let surplus invocations fetch outside uInput whenever the output is
  // larger than the mask; GLSL leaves such texelFetch results undefined.
  if (any(greaterThanEqual(pos_mask, uBlock.mask_extents.xyz))) {
    return;
  }

  const ivec4 inval = texelFetch(uInput, pos_mask, 0);

  // Does any real channel of this texel hold a 1? A channel whose index
  // within its batch (flattened index modulo the 4-aligned mask channel
  // count) reaches mask_size_info.y is padding and ends the scan.
  bool mask_has_true = false;
  for (uint i = 0u; i < 4u; ++i) {
    const uint channel =
        (uint(pos_mask.z) * 4u + i) % uBlock.aligned_channel_info.y;
    if (channel >= uBlock.mask_size_info.y) {
      break;
    }
    if (inval[i] == 1) {
      mask_has_true = true;
    }
  }
  if (!mask_has_true) {
    return;
  }

  const bool mask_channel_is_broadcast =
      uBlock.mask_size_info.y < uBlock.out_size_info.y;
  // Number of 4-channel texels that one output batch occupies along z.
  const uint tex_cnt_in_output_batch = uBlock.aligned_channel_info.x / 4u;

  // How many times the mask repeats along batch, height and width.
  const uint batch_repeats = uBlock.out_size_info.x / uBlock.mask_size_info.x;
  const uint height_repeats = uBlock.out_size_info.z / uBlock.mask_size_info.z;
  const uint width_repeats = uBlock.out_size_info.w / uBlock.mask_size_info.w;

  for (uint batch = 0u; batch < batch_repeats; ++batch) {
    for (uint height = 0u; height < height_repeats; ++height) {
      for (uint width = 0u; width < width_repeats; ++width) {
        // Each repeat of the mask is offset by the mask's own size along that
        // dimension. With standard broadcasting (mask size 1, or equal to the
        // output size so the loop runs once), this is the same as adding the
        // loop index.
        const int out_x = pos_mask.x + int(width * uBlock.mask_size_info.w);
        const int out_y = pos_mask.y + int(height * uBlock.mask_size_info.z);

        if (mask_channel_is_broadcast) {
          // NOTE(review): assumes the broadcast mask has a single channel, so
          // each mask texel covers one batch and pos_mask.z is the mask batch
          // index — confirm.
          const uint out_batch =
              batch * uBlock.mask_size_info.x + uint(pos_mask.z);
          for (uint tex_idx = 0u; tex_idx < tex_cnt_in_output_batch;
               ++tex_idx) {
            const ivec3 write_pos = ivec3(
                out_x,
                out_y,
                int(tex_cnt_in_output_batch * out_batch + tex_idx));
            vec4 out_tex = imageLoad(uOutput, write_pos);
            // Fill every real channel and leave channel padding untouched.
            for (uint i = 0u; i < 4u; ++i) {
              if (tex_idx * 4u + i >= uBlock.out_size_info.y) {
                break;
              }
              out_tex[i] = uBlock.value;
            }
            imageStore(uOutput, write_pos, out_tex);
          }
        } else {
          const ivec3 write_pos = ivec3(
              out_x,
              out_y,
              pos_mask.z +
                  int(tex_cnt_in_output_batch * uBlock.mask_size_info.x *
                      batch));
          const vec4 out_tex = imageLoad(uOutput, write_pos);
          // Select per component instead of blending as
          // m * value + (1 - m) * out. In such a blend, 0 * inf and 0 * NaN
          // are NaN, which corrupts unmasked elements when value is +/-inf
          // and keeps NaN in masked elements whose old value was inf/NaN.
          imageStore(
              uOutput,
              write_pos,
              mix(out_tex, vec4(uBlock.value), equal(inval, ivec4(1))));
        }
      }
    }
  }
}
111