#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Sampler
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 * input_shader_extents is the dimensions of the Vulkan 3D texture XYZ
 * with a zero pad at W.
 * input_tensor_dims is the dimensions of the NCHW PyTorch Tensor.
 * input_dim_stride is the stride to include elements along the softmax
 * dimension calculation.
 * early_exit is the global workgroup position-based condition for
 * unnecessary invocations to exit.
 *
 * main() below only reads input_tensor_dims.y (the channel count) and
 * early_exit.xyz.
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 input_shader_extents;
  ivec4 input_tensor_dims;
  ivec4 input_dim_stride;
  ivec4 early_exit;
}
uBlock;

/*
 * Local Work Group Size
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Softmax along the channel dimension.
 *
 * Each invocation handles one (x, y) location of one batch: it walks the
 * b_stride consecutive texels along z that hold that batch's channels
 * (four channels per texel), computes
 *   out[c] = exp(in[c] - max) / sum_k exp(in[k] - max)
 * and writes the result to the same texel positions in uOutput.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  // Only one invocation per batch is needed; surplus invocations exit.
  if (any(greaterThanEqual(pos, uBlock.early_exit.xyz))) {
    return;
  }

  const int channels = uBlock.input_tensor_dims.y;
  // How "wide" a batch is in terms of z: four channels are packed per texel,
  // so a batch spans ceil(channels / 4) texels.
  const int b_stride = (channels + 3) / 4;
  const ivec3 src_pos = ivec3(pos.x, pos.y, pos.z * b_stride);
  const int last_texel = b_stride - 1;

  // Tail case: when channels % 4 != 0 the last texel of the batch only holds
  // tail_case_size real channels; the remaining lanes are zero padding.
  int tail_case_size = channels % 4;
  if (tail_case_size == 0) {
    tail_case_size = 4;
  }

  // For numerical stability (to avoid floating point overflow in exp) we
  // leverage the translation invariance of softmax and subtract the maximum
  // element along the channel dimension from every element.
  // Pass 1: find the maximum over the real (non-padding) channels.
  float max_element = texelFetch(uInput, src_pos, 0)[0];
  for (int c = 0; c < b_stride; c++) {
    const vec4 c_texel =
        texelFetch(uInput, ivec3(src_pos.xy, src_pos.z + c), 0);
    const int valid_lanes = (c == last_texel) ? tail_case_size : 4;
    for (int t = 0; t < valid_lanes; t++) {
      if (c_texel[t] > max_element) {
        max_element = c_texel[t];
      }
    }
  }

  // Pass 2: accumulate the denominator over the real channels only, so the
  // padding lanes do not contribute exp(0 - max_element).
  float denominator = 0.0;
  for (int c = 0; c < b_stride; c++) {
    const vec4 c_texel =
        texelFetch(uInput, ivec3(src_pos.xy, src_pos.z + c), 0);
    const int valid_lanes = (c == last_texel) ? tail_case_size : 4;
    for (int t = 0; t < valid_lanes; t++) {
      denominator += exp(c_texel[t] - max_element);
    }
  }

  // Pass 3: write every final channel element.
  for (int c = 0; c < b_stride; c++) {
    const ivec3 dst_pos = ivec3(src_pos.xy, src_pos.z + c);
    vec4 result =
        exp(texelFetch(uInput, dst_pos, 0) - max_element) / denominator;
    if (c == last_texel) {
      // Keep the padding lanes of the output at zero. Without this they
      // would hold exp(0 - max_element) / denominator, which is non-zero and
      // can overflow to +inf when every real channel is strongly negative.
      // Lanes are assigned (not multiplied by a mask) so an inf cannot turn
      // into NaN.
      for (int t = tail_case_size; t < 4; t++) {
        result[t] = 0.0;
      }
    }
    imageStore(uOutput, dst_pos, result);
  }
}