• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#version 450 core
// PRECISION and FORMAT are `${...}` template placeholders. Presumably a
// build-time step substitutes a precision qualifier and an image-format layout
// qualifier before this file is compiled -- TODO confirm against the generator.
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

// Tile edge length used for every array extent and fixed-trip loop in main().
// It is kept as a macro because, per the original author, that is what
// convinces the SPIR-V compiler to unroll those loops optimally.
#define FOUR 4
// Default layout for buffer blocks; no buffer block is declared in the part of
// the file visible here.
layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

// binding 0: destination image, write-only; main() stores one value per texel.
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
// binding 1: first operand; main() fetches it at (j, FOUR * pos.y + k, pos.z).
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
// binding 2: second operand; main() fetches it at (FOUR * pos.x + k, j, pos.z).
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
// binding 3: dispatch parameters.
//   shader_extents.xyz - exclusive upper bound on gl_GlobalInvocationID;
//                        invocations at or beyond it do nothing.
//   shader_extents.w   - trip count of the accumulation loop in main().
layout(set = 0, binding = 3) uniform PRECISION restrict Block {
  ivec4 shader_extents;
}
uBlock;

// Work-group size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
21
// Entry point. Every invocation whose global ID lies inside
// uBlock.shader_extents.xyz produces one FOUR x FOUR tile of uOutput:
//   - for j in [0, shader_extents.w) it fetches FOUR texels of uM1 (rows
//     FOUR * pos.y + k at x = j) and FOUR texels of uM2 (columns
//     FOUR * pos.x + k at y = j), all from slice pos.z;
//   - it adds dot(uM1 texel, uM2 texel) for every row/column pair into a
//     FOUR x FOUR float accumulator;
//   - it stores each accumulated value in the first channel of its own output
//     texel, with the remaining three channels set to 0.0.
// Presumably the four channels of each fetched texel pack four consecutive
// elements of the reduction dimension -- TODO confirm against the host-side
// packing code.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  if (all(lessThan(pos, uBlock.shader_extents.xyz))) {
    // The accumulator is a plain float array rather than a mat4 / vec4[FOUR]:
    // per the original author those compile to much less efficient SPIR-V.
    float results[FOUR][FOUR];
    for (int i = 0; i < FOUR; i++) {
      for (int j = 0; j < FOUR; j++) {
        results[i][j] = 0.0;
      }
    }

    for (int j = 0; j < uBlock.shader_extents.w; j++) {
      // We may potentially read out of bounds here; the original author relies
      // on (0, 0, 0, 0) being sampled safely in that case.
      vec4 uM1_partial_rows[FOUR];
      vec4 uM2_partial_cols[FOUR];

      // read and cache 4x4 tile of uM1 (4 adjacent rows)
      for (int k = 0; k < FOUR; k++) {
        const int pos_y_offset = (FOUR * pos.y) + k;
        const ivec3 pos_rd = ivec3(j, pos_y_offset, pos.z);
        uM1_partial_rows[k] = texelFetch(uM1, pos_rd, 0);
      }
      // read and cache 4x4 tile of uM2 (4 adjacent columns)
      for (int k = 0; k < FOUR; k++) {
        const int pos_x_offset = (FOUR * pos.x) + k;
        const ivec3 pos_rd = ivec3(pos_x_offset, j, pos.z);
        uM2_partial_cols[k] = texelFetch(uM2, pos_rd, 0);
      }
      // perform partial dot products and add partial result to results;
      // the first index follows the uM1 row, the second the uM2 column
      for (int tile_row = 0; tile_row < FOUR; tile_row++) {
        for (int tile_col = 0; tile_col < FOUR; tile_col++) {
          results[tile_row][tile_col] +=
              dot(uM1_partial_rows[tile_row], uM2_partial_cols[tile_col]);
        }
      }
    }
    // results is indexed [row][column] while image coordinates are (x, y), so
    // the column index selects the output x and the row index the output y.
    for (int tile_row = 0; tile_row < FOUR; tile_row++) {
      for (int tile_col = 0; tile_col < FOUR; tile_col++) {
        const ivec3 out_pos =
            ivec3(tile_col + FOUR * pos.x, tile_row + FOUR * pos.y, pos.z);
        imageStore(
            uOutput, out_pos, vec4(results[tile_row][tile_col], 0.0, 0.0, 0.0));
      }
    }
  }
}
70