#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Textures (the mask; an element equal to 1 means "fill")
 */
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // output texture size (x=width,y=height,z=depth,w=unused)
  ivec4 out_extents;
  // mask texture size (x=width,y=height,z=depth,w=unused)
  ivec4 mask_extents;
  // output extent sizes (x=batch,y=channel,z=height,w=width)
  uvec4 out_size_info;
  // mask extent sizes (x=batch,y=channel,z=height,w=width)
  uvec4 mask_size_info;
  // x: size of output channel dim up-aligned to 4
  // y: size of mask channel dim up-aligned to 4
  uvec2 aligned_channel_info;
  // value to replace
  float value;
}
uBlock;

/*
 * Local Work Group
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Masked fill. Each invocation handles one texel of the mask texture (uInput).
 * For every real (non-padding) mask element in that texel that equals 1, the
 * output elements it maps to under broadcasting are overwritten with
 * uBlock.value. All other output elements are preserved (read-modify-write on
 * uOutput).
 *
 * Texel layout, as implied by the indexing below: texture z =
 * batch * (aligned_channels / 4) + channel / 4, and the 4 lanes of a texel
 * hold 4 consecutive channels. Channels >= the real channel count are padding.
 *
 * Broadcasting: along batch, height, and width, each mask dim is expected to
 * be either 1 or equal to the output dim. That is what makes
 * `mask_pos + repeat_index` a valid output coordinate below.
 * NOTE(review): this is not validated here; it is presumably enforced on the
 * host side.
 */
void main() {
  const ivec3 pos_mask = ivec3(gl_GlobalInvocationID);

  // pos_mask addresses the *mask* texture, so it must be bounded by the mask
  // extents. out_extents is >= mask_extents, so checking it instead would let
  // overhanging invocations texelFetch outside the mask texture.
  if (any(greaterThanEqual(pos_mask, uBlock.mask_extents.xyz))) {
    return;
  }

  const ivec4 inval = texelFetch(uInput, pos_mask, 0);

  // fill_lane[i] is true when lane i holds a real mask channel (not channel
  // padding) whose value is 1. Because aligned_channel_info.y is a multiple of
  // 4, (z * 4 + i) % aligned_channels is the channel index of lane i.
  bvec4 fill_lane = bvec4(false);
  for (int i = 0; i < 4; ++i) {
    const uint channel =
        uint(pos_mask.z * 4 + i) % uBlock.aligned_channel_info.y;
    fill_lane[i] = channel < uBlock.mask_size_info.y && inval[i] == 1;
  }

  // Nothing in this mask texel asks for a fill.
  if (!any(fill_lane)) {
    return;
  }

  // Channel padding forces two write strategies, depending on whether the
  // channel dimension is broadcast or not.
  const bool mask_channel_is_broadcast =
      uBlock.mask_size_info.y < uBlock.out_size_info.y;
  const int tex_cnt_in_output_batch =
      int(uBlock.aligned_channel_info.x / 4u);
  const int out_channels = int(uBlock.out_size_info.y);

  // How many times each mask element repeats along batch / height / width.
  // This is 1 when the dim is not broadcast.
  const uint batch_reps = uBlock.out_size_info.x / uBlock.mask_size_info.x;
  const uint height_reps = uBlock.out_size_info.z / uBlock.mask_size_info.z;
  const uint width_reps = uBlock.out_size_info.w / uBlock.mask_size_info.w;

  for (uint batch = 0u; batch < batch_reps; ++batch) {
    for (uint height = 0u; height < height_reps; ++height) {
      for (uint width = 0u; width < width_reps; ++width) {
        const ivec2 out_xy = pos_mask.xy + ivec2(int(width), int(height));

        if (mask_channel_is_broadcast) {
          // The mask has fewer channels than the output. Under standard
          // broadcasting this means one mask channel, so one mask texel per
          // batch and pos_mask.z is the mask batch index. A true element fills
          // every real output channel of this (batch, height, width).
          const int z_base =
              tex_cnt_in_output_batch * (int(batch) + pos_mask.z);
          for (int tex_idx = 0; tex_idx < tex_cnt_in_output_batch; ++tex_idx) {
            const ivec3 write_pos = ivec3(out_xy, z_base + tex_idx);
            vec4 out_tex = imageLoad(uOutput, write_pos);
            for (int i = 0; i < 4; ++i) {
              // Leave output channel padding untouched.
              if (tex_idx * 4 + i >= out_channels) {
                break;
              }
              out_tex[i] = uBlock.value;
            }
            imageStore(uOutput, write_pos, out_tex);
          }
        } else {
          // Same channel count: mask and output texels line up lane-for-lane,
          // offset by whole output batches.
          const ivec3 write_pos = ivec3(
              out_xy, pos_mask.z + tex_cnt_in_output_batch * int(batch));
          const vec4 out_tex = imageLoad(uOutput, write_pos);
          // Lane-wise select instead of an arithmetic blend: the blend
          // `value * 1 + out * 0` produces NaN whenever value or an existing
          // output lane is +/-inf or NaN (e.g. masked_fill with -inf).
          imageStore(
              uOutput, write_pos, mix(out_tex, vec4(uBlock.value), fill_lane));
        }
      }
    }
  }
}