#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/*
 * Output Image (write-only destination of the unsqueezed data)
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Sampler (source texture, read with texelFetch at LOD 0)
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uImage;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // info.x: dimension to insert at
  // info.y: channels (for 3d->4d unsqueeze)
  ivec2 info;
}
uBlock;

/*
 * Local Work Group Size (supplied through specialization constants 0, 1, 2)
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Returns a new tensor with dimension of size one inserted at the specified
 * position (dim).
 *
 * Each invocation writes at most one texel of uOutput, at the coordinate given
 * by gl_GlobalInvocationID:
 *   - dim == 1: only the first component of the output texel is filled; it is
 *     component (pos.z % 4) of the input texel at (pos.x, pos.y, pos.z / 4).
 *     The remaining three components stay zero.
 *   - dim == 2 or dim == 3: all four components are gathered from four
 *     consecutive input rows, always reading the same component of each
 *     fetched texel. The two cases differ only in which output coordinate
 *     (pos.x or pos.y) is used as the input x coordinate.
 *   - any other dim: nothing is written.
 *
 * NOTE(review): the index math presumably mirrors the host-side packing of
 * four channels per texel along z -- confirm against the op that dispatches
 * this shader. No bounds check is done here; the dispatch size is assumed to
 * match the output extents.
 */
void main() {
  const ivec3 coord = ivec3(gl_GlobalInvocationID);
  const int insert_dim = uBlock.info.x;
  const int num_channels = uBlock.info.y;

  // Unsupported insertion positions leave the output image untouched.
  if (insert_dim < 1 || insert_dim > 3) {
    return;
  }

  vec4 texel = vec4(0, 0, 0, 0);

  if (insert_dim == 1) {
    // One source component lands in the x lane of the output texel.
    const vec4 src = texelFetch(uImage, ivec3(coord.x, coord.y, coord.z / 4), 0);
    texel[0] = src[coord.z % 4];
  } else {
    // dim 2 reads along the output x coordinate, dim 3 along the output y.
    const int src_x = (insert_dim == 2) ? coord.x : coord.y;
    // These terms do not depend on the loop counter, so compute them once.
    const int first_row = (coord.z % num_channels) * 4;
    const int src_z = coord.z / (num_channels * 4);
    const int lane = (coord.z / num_channels) % 4;

    for (int row = 0; row < 4; ++row) {
      const vec4 src = texelFetch(uImage, ivec3(src_x, first_row + row, src_z), 0);
      texel[row] = src[lane];
    }
  }

  imageStore(uOutput, coord, texel);
}