#version 450 core
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
// NOTE(review): comments sit only on formerly blank lines so statements keep their line positions - presumably a diagnostics baseline depends on them; confirm.
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Buffer block whose float16_t array `Buf.x` is the source operand of every coopMatLoadTensorNV call in main().
buffer BufType {
    float16_t x[];
} Buf;
// Buffer-reference type (one float16_t member) used as the first parameter of most decode callbacks below.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer fp16Buf {
    float16_t f;
};
// decode0..decode9 share one body (return b.f) and the same parameter types; each differs from decode0 only by dropping `const`, `in`, or both from exactly one parameter.
// decode10 swaps the parameter types (uint32_t / uint16_t[2]) and decode11 takes scalar coordinates instead of arrays; both return float16_t(0).
float16_t decode0(const in fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode1(const fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode2(in fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode3(fp16Buf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode4(const in fp16Buf b, const uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode5(const in fp16Buf b, in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode6(const in fp16Buf b, uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode7(const in fp16Buf b, const in uint32_t blockCoords[2], const uint32_t coordInBlock[2]) { return b.f; }
float16_t decode8(const in fp16Buf b, const in uint32_t blockCoords[2], in uint32_t coordInBlock[2]) { return b.f; }
float16_t decode9(const in fp16Buf b, const in uint32_t blockCoords[2], uint32_t coordInBlock[2]) { return b.f; }
float16_t decode10(const in uint32_t b, const in uint16_t blockCoords[2], const in uint16_t coordInBlock[2]) { return float16_t(0); }
float16_t decode11(const in fp16Buf b, const in uint32_t blockCoords, const in uint32_t coordInBlock) { return float16_t(0); }
// Struct with a single f16vec2 member.
struct S {
    f16vec2 x;
};
// NOTE(review): S is not referenced by any code visible in this file - confirm before removing.
float16_t combineSum(const in float16_t a, const in float16_t b) { return a + b; }
float16_t combineSum2(float16_t a, float16_t b) { return a + b; }
// Specialization constant (constant_id = 0) defaulting to gl_CooperativeMatrixClampModeConstantNV; not referenced by any code visible in this file.
layout(constant_id = 0) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
// Callbacks handed to coopMatPerElementNV in main(): each takes (row, col, x[, y]); relu and add return float16_t, perelemf32 returns float32_t.
float16_t relu(const in uint32_t row, const in uint32_t col, const in float16_t x) { return max(x, float16_t(0)); }
float16_t add(const in uint32_t row, const in uint32_t col, const in float16_t x, const in float16_t y) { return x+y; }
float32_t perelemf32(const in uint32_t row, const in uint32_t col, const in float16_t x) { return float32_t(x); }
// Entry point: calls the load / reduce / per-element built-ins once per callback and argument variation.
void main()
{
    coopmat<float16_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator> A;
    // NOTE(review): statements keep their original line positions - presumably a diagnostics baseline depends on them; confirm before inserting lines.
    tensorLayoutNV<2> t = createTensorLayoutNV(2);
    // One load per decode callback. NOTE(review): presumably several of these signatures are meant to be rejected - confirm against the extension spec.
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode0);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode1);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode2);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode3);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode4);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode5);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode6);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode7);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode8);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode9);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode10);
    coopMatLoadTensorNV(A, Buf.x, 0, t, decode11);
    // float32_t matrix with the same scope, dimensions and use as A.
    coopmat<float32_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator> Af32;
    // Row reductions: A with each float16_t combiner, then Af32 with the float16_t combiner (component types differ).
    coopMatReduceNV(A, A, gl_CooperativeMatrixReduceRowNV, combineSum);
    coopMatReduceNV(A, A, gl_CooperativeMatrixReduceRowNV, combineSum2);
    coopMatReduceNV(Af32, Af32, gl_CooperativeMatrixReduceRowNV, combineSum);
    // Per-element calls: vary the type and count of extra operands passed to `add`, then pair the float32_t result matrix with each callback return type.
    coopMatPerElementNV(A, A, relu);
    coopMatPerElementNV(A, A, add, float16_t(1.0));
    coopMatPerElementNV(A, A, add, coopmat<float16_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator>(1.0));
    coopMatPerElementNV(A, A, add, float32_t(1.0));
    coopMatPerElementNV(A, A, add, coopmat<float32_t, gl_ScopeWorkgroup, 64, 32, gl_MatrixUseAccumulator>(1.0));
    coopMatPerElementNV(A, A, add);
    coopMatPerElementNV(A, A, add, float16_t(1.0), float16_t(1.0));
    coopMatPerElementNV(Af32, A, perelemf32);
    coopMatPerElementNV(Af32, A, relu);
}