• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
#version 450 core

// PRECISION is filled in by the shader codegen and forwarded (by name) to the
// buffer layout helpers below.
#define PRECISION ${PRECISION}

layout(std430) buffer;

// Binding 0 is the source that main() reads as `A`. When MEMTYPE is neither
// "ubo" nor "buffer", binding 0 is still declared but under the placeholder
// name `_`; for MEMTYPE == "shared" the name `A` is then taken by the shared
// array declared further down.
// NOTE(review): the layout_declare_* helpers are defined by the codegen, not in
// this file; their exact expansion should be confirmed there.
$if MEMTYPE == "ubo":
    ${layout_declare_ubo(0, "vec4", "A")}
$elif MEMTYPE == "buffer":
    ${layout_declare_buffer(0, "r", "A", DTYPE, "PRECISION", False)}
$else:
    ${layout_declare_buffer(0, "r", "_", DTYPE, "PRECISION", False)}

// Binding 1 is the destination `B` that main() stores its result into.
${layout_declare_buffer(1, "w", "B", DTYPE, "PRECISION", False)}

// The workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Specialization constants consumed by main(). The defaults of 1 are
// presumably placeholders overridden by the host at pipeline creation
// (TODO confirm against the caller).
//   niter            - trip count of the outer read loop.
//   nvec             - number of vec4 elements in the shared array `A`.
//   local_group_size - stride added to the read offset after every access.
layout(constant_id = 3) const int niter = 1;
layout(constant_id = 4) const int nvec = 1;
layout(constant_id = 5) const int local_group_size = 1;
// The address mask works as a modulo because x % 2^n == x & (2^n - 1).
// It restricts every access to a specific set of unique addresses, chosen
// according to the access size we want to measure.
layout(constant_id = 6) const int addr_mask = 1;
// Multiplied by gl_WorkGroupID[0] to derive each invocation's starting offset.
layout(constant_id = 7) const int workgroup_width = 1;

$if MEMTYPE == "shared":
    shared vec4 A[nvec];
37
/*
 * Repeatedly reads vec4 elements of `A` at masked offsets and folds them into
 * an accumulator, then stores the accumulator into `B`.
 *
 * Each invocation starts at
 *   (gl_WorkGroupID[0] * workgroup_width + gl_LocalInvocationID[0]) & addr_mask
 * and, after every read, advances by local_group_size under the same mask.
 * The outer loop runs niter times; the codegen unrolls NUNROLL reads into each
 * iteration.
 */
void main() {

    $if MEMTYPE == "shared":
        A[gl_LocalInvocationID[0]][0] = gl_LocalInvocationID[0];
        // memoryBarrierShared() only orders this invocation's shared-memory
        // accesses. barrier() is also required so that every invocation in the
        // workgroup has finished the store above before any of them starts
        // reading slots written by the others in the loop below.
        memoryBarrierShared();
        barrier();

    vec4 sum = vec4(0);
    uint offset = (gl_WorkGroupID[0] * workgroup_width  + gl_LocalInvocationID[0]) & addr_mask;

    // `i` is declared outside the loop because its final value is used below.
    int i = 0;
    for (; i < niter; ++i){
      $for j in range(int(NUNROLL)):
          sum *= A[offset];

          // On each unroll, a new unique address will be accessed through the offset,
          // limited by the address mask to a specific set of unique addresses
          offset = (offset + local_group_size) & addr_mask;
    }

    // This is to ensure no compiler optimizations occur: `i >> 31` is 0 for any
    // non-negative i, but it makes the stored value depend on the loop counter.
    vec4 zero = vec4(i>>31);

    // Storing the accumulator keeps the reads above observable.
    B[gl_LocalInvocationID[0]] = sum + zero;
}
62