• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Ad-hoc test that the emulated wave-op (subgroup) lane masks work as expected.
2 #include <glm/glm.hpp>
3 #include <assert.h>
4 
5 using namespace glm;
6 
7 static uvec4 gl_SubgroupEqMask;
8 static uvec4 gl_SubgroupGeMask;
9 static uvec4 gl_SubgroupGtMask;
10 static uvec4 gl_SubgroupLeMask;
11 static uvec4 gl_SubgroupLtMask;
12 using uint4 = uvec4;
13 
test_main(unsigned wave_index)14 static void test_main(unsigned wave_index)
15 {
16 	const auto WaveGetLaneIndex = [&]() { return wave_index; };
17 
18 	gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
19 	if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
20 	if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
21 	if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
22 	if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
23 	gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
24 	if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
25 	if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
26 	if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
27 	if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
28 	if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
29 	if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
30 	uint gt_lane_index = WaveGetLaneIndex() + 1;
31 	gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
32 	if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
33 	if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
34 	if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
35 	if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
36 	if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
37 	if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
38 	if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
39 	uint le_lane_index = WaveGetLaneIndex() + 1;
40 	gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
41 	if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
42 	if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
43 	if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
44 	if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
45 	if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
46 	if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
47 	if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
48 	gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
49 	if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
50 	if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
51 	if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
52 	if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
53 	if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
54 	if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
55 }
56 
main()57 int main()
58 {
59 	for (unsigned subgroup_id = 0; subgroup_id < 128; subgroup_id++)
60 	{
61 		test_main(subgroup_id);
62 
63 		for (unsigned bit = 0; bit < 128; bit++)
64 		{
65 			assert(bool(gl_SubgroupEqMask[bit / 32] & (1u << (bit & 31))) == (bit == subgroup_id));
66 			assert(bool(gl_SubgroupGtMask[bit / 32] & (1u << (bit & 31))) == (bit > subgroup_id));
67 			assert(bool(gl_SubgroupGeMask[bit / 32] & (1u << (bit & 31))) == (bit >= subgroup_id));
68 			assert(bool(gl_SubgroupLtMask[bit / 32] & (1u << (bit & 31))) == (bit < subgroup_id));
69 			assert(bool(gl_SubgroupLeMask[bit / 32] & (1u << (bit & 31))) == (bit <= subgroup_id));
70 		}
71 	}
72 }
73 
74