• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#version 450 core
2#extension GL_KHR_memory_scope_semantics : enable
3#extension GL_KHR_cooperative_matrix : enable
4#extension GL_EXT_shader_explicit_arithmetic_types : enable
5#extension GL_NV_cooperative_matrix2 : enable
6#extension GL_EXT_buffer_reference : enable
7
8layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;  // 64x1x1 workgroup; main() declares its cooperative matrices with gl_ScopeWorkgroup
9
10buffer BufType {  // storage buffer passed as the source array to every coopMatLoadTensorNV call in main(); no explicit set/binding is given
11   float16_t x[];  // runtime-sized float16_t array, referenced as Buf.x
12} Buf;
13
14layout(buffer_reference, std430, buffer_reference_align = 2) buffer fp16Buf {  // buffer-reference type taken as the first parameter of the decode callbacks
15    float16_t f;  // single 16-bit value; reference alignment of 2 bytes matches float16_t
16};
17
18
19float16_t decode0(const in fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // decode callback: returns the float16_t at reference b (coordinates unused); all three parameters are `const in`
20float16_t decode1(const fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but b is `const` without `in`
21float16_t decode2(in fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but b is `in` without `const`
22float16_t decode3(fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but b has no qualifiers
23float16_t decode4(const in fp16Buf b, const uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but blockCoords is `const` without `in`
24float16_t decode5(const in fp16Buf b, in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but blockCoords is `in` without `const`
25float16_t decode6(const in fp16Buf b, uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but blockCoords has no qualifiers
26float16_t decode7(const in fp16Buf b, const in uint32_t blockCoords[2], const uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but coordInBlock is `const` without `in`
27float16_t decode8(const in fp16Buf b, const in uint32_t blockCoords[2], in uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but coordInBlock is `in` without `const`
28float16_t decode9(const in fp16Buf b, const in uint32_t blockCoords[2], uint32_t coordInBlock[2]) { return b.f; }  // as decode0, but coordInBlock has no qualifiers
29float16_t decode10(const in uint32_t b, const in uint16_t blockCoords[2], const in uint16_t coordInBlock[2]) { return float16_t(0); }  // mismatched decode signature: b is uint32_t rather than a buffer reference, and the coordinate arrays are uint16_t; always returns 0
30float16_t decode11(const in fp16Buf b, const in uint32_t blockCoords, const in uint32_t coordInBlock) { return float16_t(0); }  // mismatched decode signature: coordinates are scalar uint32_t, not uint32_t[2] arrays as in decode0; always returns 0
31
32struct S {  // NOTE(review): not referenced anywhere in the visible code — confirm whether it is still needed
33   f16vec2 x;
34};
35
36float16_t combineSum(const in float16_t a, const in float16_t b) { return a + b; }  // coopMatReduceNV combine callback: sums two float16_t values; parameters are `const in`
37float16_t combineSum2(float16_t a, float16_t b) { return a + b; }  // same sum as combineSum, but with unqualified parameters
38
39layout(constant_id = 0) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;  // specialization constant (id 0) defaulting to the built-in gl_CooperativeMatrixClampModeConstantNV; not referenced in main()
40
41float16_t relu(const in uint32_t row, const in uint32_t col, const in float16_t x) { return max(x, float16_t(0)); }  // per-element callback (row, col, element): clamps the float16_t element at 0; takes no extra operands
42float16_t add(const in uint32_t row, const in uint32_t col, const in float16_t x, const in float16_t y) { return x+y; }  // per-element callback with exactly one extra float16_t operand y; returns element + y
43float32_t perelemf32(const in uint32_t row, const in uint32_t col, const in float16_t x) { return float32_t(x); }  // per-element callback that widens a float16_t element to a float32_t result
44
45void main()  // drives the GL_NV_cooperative_matrix2 builtins with matching and mismatching callbacks/operands; appears to be an expected-diagnostics test — keep line positions stable
46{
47    coopmat<float16_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator> A;  // float16_t 64x32 accumulator, workgroup scope
48
49    tensorLayoutNV<2> t = createTensorLayoutNV(2);  // 2-D tensor layout used by every load below
50
51    coopMatLoadTensorNV(A, Buf.x, 0, t, decode0);  // decode callback with all parameters `const in`
52    coopMatLoadTensorNV(A, Buf.x, 0, t, decode1);  // decode1..decode9: each changes the qualifiers of one parameter relative to decode0
53    coopMatLoadTensorNV(A, Buf.x, 0, t, decode2);
54    coopMatLoadTensorNV(A, Buf.x, 0, t, decode3);
55    coopMatLoadTensorNV(A, Buf.x, 0, t, decode4);
56    coopMatLoadTensorNV(A, Buf.x, 0, t, decode5);
57    coopMatLoadTensorNV(A, Buf.x, 0, t, decode6);
58    coopMatLoadTensorNV(A, Buf.x, 0, t, decode7);
59    coopMatLoadTensorNV(A, Buf.x, 0, t, decode8);
60    coopMatLoadTensorNV(A, Buf.x, 0, t, decode9);
61    coopMatLoadTensorNV(A, Buf.x, 0, t, decode10);  // decode callback with wrong parameter types (uint32_t b, uint16_t coordinates)
62    coopMatLoadTensorNV(A, Buf.x, 0, t, decode11);  // decode callback with scalar rather than array coordinates
63
64    coopmat<float32_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator> Af32;  // same shape/scope/use as A but float32_t elements, for element-type mismatches below
65
66    coopMatReduceNV(A, A, gl_CooperativeMatrixReduceRowNV, combineSum);  // row reduce of float16_t matrix with `const in` float16_t combiner
67    coopMatReduceNV(A, A, gl_CooperativeMatrixReduceRowNV, combineSum2);  // same, but the combiner's parameters are unqualified
68    coopMatReduceNV(Af32, Af32, gl_CooperativeMatrixReduceRowNV, combineSum);  // float32_t matrix with a float16_t combiner: element-type mismatch
69
70    coopMatPerElementNV(A, A, relu);  // callback taking no extra operands
71    coopMatPerElementNV(A, A, add, float16_t(1.0));  // scalar extra operand matching add's float16_t y
72    coopMatPerElementNV(A, A, add, coopmat<float16_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator>(1.0));  // matrix extra operand of A's type (presumably supplies the element at the same row/col)
73    coopMatPerElementNV(A, A, add, float32_t(1.0));  // scalar float32_t operand vs add's float16_t y
74    coopMatPerElementNV(A, A, add, coopmat<float32_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator>(1.0));  // matrix operand with float32_t elements vs add's float16_t y
75    coopMatPerElementNV(A, A, add);  // too few operands: add expects one extra argument
76    coopMatPerElementNV(A, A, add, float16_t(1.0), float16_t(1.0));  // too many operands: add takes only one extra argument
77    coopMatPerElementNV(Af32, A, perelemf32);  // float16_t input, float32_t result via widening callback
78    coopMatPerElementNV(Af32, A, relu);  // float32_t result matrix, but relu returns float16_t
79}
80