• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Texture
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 * input_shader_extents: dimensions of the Vulkan 3D texture in XYZ, with a
 *   zero pad in W.
 * input_tensor_dims: dimensions of the NCHW PyTorch tensor.
 * input_dim_stride: stride used to include elements along the softmax
 *   dimension in the calculation.
 * early_exit: global invocation position bound; an invocation whose position
 *   is not strictly below it in every one of x, y, z exits immediately.
 * NOTE(review): main() in this shader reads only input_tensor_dims.y and
 *   early_exit.xyz; the other members presumably keep this block's layout in
 *   sync with the host-side params struct — confirm before removing them.
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 input_shader_extents;
  ivec4 input_tensor_dims;
  ivec4 input_dim_stride;
  ivec4 early_exit;
}
uBlock;

/*
 * Local Work Group Size, taken from specialization constants 0, 1 and 2.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Softmax along the channel dimension (dim 1) of an NCHW tensor.
 *
 * Layout, as implied by the indexing below: channels are packed four per texel
 * along the texture's Z axis, and each batch occupies ceil(C / 4) consecutive
 * Z slices, where C = uBlock.input_tensor_dims.y. Lanes past C in a batch's
 * last texel are zero padding.
 *
 * Each invocation owns one (x, y, batch) column and makes three passes over
 * its channel texels: find the maximum, accumulate the denominator, then write
 * the normalized outputs.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  if (!all(lessThan(pos, uBlock.early_exit.xyz))) {
    return;
  }

  // How "wide" a batch is in terms of z. There is only one invocation per
  // batch, since one batch width holds the elements of every channel.
  // Integer form of ceil(C / 4.0).
  const int b_stride = (uBlock.input_tensor_dims.y + 3) / 4;
  if (b_stride <= 0) {
    // No channels: nothing to normalize, and the tail index below would
    // otherwise point before the batch's first texel.
    return;
  }
  const ivec3 src_pos = ivec3(pos.x, pos.y, pos.z * b_stride);
  const int last_c = b_stride - 1;
  const ivec3 tail_pos = ivec3(src_pos.x, src_pos.y, src_pos.z + last_c);

  // Number of real channel values in the last texel; the remaining lanes are
  // zero padding when the channel count is not a multiple of 4.
  int tail_case_size = uBlock.input_tensor_dims.y % 4;
  if (tail_case_size == 0) {
    tail_case_size = 4;
  }
  // Fetched once and reused by all three passes.
  const vec4 tail_texel = texelFetch(uInput, tail_pos, 0);

  // Pass 1: maximum along the channel dimension. Softmax is invariant to
  // translation, so subtracting the maximum before exponentiating avoids
  // floating point overflow. Padding lanes are skipped so they cannot raise
  // the maximum.
  float max_element = texelFetch(uInput, src_pos, 0)[0];
  for (int c = 0; c < last_c; c++) {
    const vec4 texel =
        texelFetch(uInput, ivec3(src_pos.x, src_pos.y, src_pos.z + c), 0);
    for (int t = 0; t < 4; t++) {
      if (texel[t] > max_element) {
        max_element = texel[t];
      }
    }
  }
  for (int t = 0; t < tail_case_size; t++) {
    if (tail_texel[t] > max_element) {
      max_element = tail_texel[t];
    }
  }

  // Pass 2: denominator = sum of exp(x - max) over real channel values only.
  float denominator = 0.0;
  for (int c = 0; c < last_c; c++) {
    const vec4 texel =
        texelFetch(uInput, ivec3(src_pos.x, src_pos.y, src_pos.z + c), 0);
    for (int t = 0; t < 4; t++) {
      denominator += exp(texel[t] - max_element);
    }
  }
  for (int t = 0; t < tail_case_size; t++) {
    denominator += exp(tail_texel[t] - max_element);
  }

  // Pass 3: write the normalized outputs for every full channel texel.
  for (int c = 0; c < last_c; c++) {
    const ivec3 dst_pos = ivec3(src_pos.x, src_pos.y, src_pos.z + c);
    const vec4 numerator = exp(texelFetch(uInput, dst_pos, 0) - max_element);
    imageStore(uOutput, dst_pos, numerator / denominator);
  }
  // The last texel's padding lanes are written as 0 rather than
  // exp(-max_element) / denominator, so the output keeps the same
  // zero-padding convention as the input. Assigning (instead of multiplying
  // by a 0/1 mask) avoids 0 * inf = NaN when exp(-max_element) overflows.
  vec4 tail_out = exp(tail_texel - max_element) / denominator;
  for (int t = tail_case_size; t < 4; t++) {
    tail_out[t] = 0.0;
  }
  imageStore(uOutput, tail_pos, tail_out);
}
98