• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#pragma clang diagnostic ignored "-Wmissing-prototypes"
2
3#include <metal_stdlib>
4#include <simd/simd.h>
5
6using namespace metal;
7
// Output record returned by the main0 fragment function: one float routed to
// colour attachment 0 via the [[color(0)]] attribute.
struct main0_out { float FragColor [[color(0)]]; };
12
// spvSubgroupBroadcast: returns `value` as held by SIMD lane `lane`, via
// simd_broadcast.
template<typename T>
inline T spvSubgroupBroadcast(T value, ushort lane)
{
    const T laneValue = simd_broadcast(value, lane);
    return laneValue;
}

// bool payloads are carried through simd_broadcast as ushort 0/1 and turned
// back into bool by testing for non-zero (presumably because the bool overload
// is unavailable or unreliable — TODO confirm against the MSL spec).
template<>
inline bool spvSubgroupBroadcast(bool value, ushort lane)
{
    const ushort asShort = (ushort)value;
    return simd_broadcast(asShort, lane) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> laneShorts = simd_broadcast(asShorts, lane);
    return (vec<bool, N>)laneShorts;
}
30
// spvSubgroupBroadcastFirst: returns `value` as held by the first active lane
// of the SIMD-group, via simd_broadcast_first.
template<typename T>
inline T spvSubgroupBroadcastFirst(T value)
{
    const T firstValue = simd_broadcast_first(value);
    return firstValue;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvSubgroupBroadcastFirst(bool value)
{
    const ushort asShort = (ushort)value;
    return simd_broadcast_first(asShort) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> firstShorts = simd_broadcast_first(asShorts);
    return (vec<bool, N>)firstShorts;
}
48
// Packs the SIMD-group vote on `value` into the uint4 ballot layout that
// SPIR-V callers expect. The 64-bit vote payload is reinterpreted as two
// 32-bit words in .x/.y; .z/.w are always zero.
inline uint4 spvSubgroupBallot(bool value)
{
    const simd_vote::vote_t voteBits = (simd_vote::vote_t)simd_ballot(value);
    const uint2 lowWords = as_type<uint2>(voteBits);
    // FIXME: bits for lanes 64..127 are never produced; revisit if Apple ever
    // supports 128 lanes in an SIMD-group.
    return uint4(lowWords, uint2(0));
}
58
// Tests bit `bit` of a 128-bit ballot stored as four 32-bit words:
// word index is bit / 32, position inside the word is bit % 32.
// NOTE(review): bit >= 128 indexes past the uint4; callers must avoid that.
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
{
    const uint word = ballot[bit / 32];
    return ((word >> (bit % 32)) & 1u) != 0u;
}
63
// Index of the lowest set bit among the first gl_SubgroupSize bits of
// `ballot`, or uint(-1) if none is set. Only .x/.y can contribute: the mask
// clears .z/.w unconditionally.
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
{
    const uint lowCount = min(gl_SubgroupSize, 32u);
    const uint highCount = (uint)max((int)gl_SubgroupSize - 32, 0);
    const uint4 live = ballot & uint4(extract_bits(0xFFFFFFFF, 0, lowCount), extract_bits(0xFFFFFFFF, 0, highCount), uint2(0));

    // Scan words from least to most significant; the first non-zero one wins.
    return live.x != 0 ? ctz(live.x)
         : live.y != 0 ? 32 + ctz(live.y)
         : live.z != 0 ? 64 + ctz(live.z)
         : live.w != 0 ? 96 + ctz(live.w)
         : uint(-1);
}
70
// Index of the highest set bit among the first gl_SubgroupSize bits of
// `ballot`, or uint(-1) if none is set. Only .x/.y can contribute: the mask
// clears .z/.w unconditionally.
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
{
    const uint lowCount = min(gl_SubgroupSize, 32u);
    const uint highCount = (uint)max((int)gl_SubgroupSize - 32, 0);
    const uint4 live = ballot & uint4(extract_bits(0xFFFFFFFF, 0, lowCount), extract_bits(0xFFFFFFFF, 0, highCount), uint2(0));

    // Scan words from most to least significant. For a non-zero word at base B,
    // the top bit index is B + 31 - clz(word).
    return live.w != 0 ? 127 - clz(live.w)
         : live.z != 0 ? 95 - clz(live.z)
         : live.y != 0 ? 63 - clz(live.y)
         : live.x != 0 ? 31 - clz(live.x)
         : uint(-1);
}
77
// Total number of set bits across all four 32-bit words of a ballot.
inline uint spvPopCount4(uint4 ballot)
{
    const uint4 perWord = popcount(ballot);
    return perWord.x + perWord.y + perWord.z + perWord.w;
}
82
// Number of set bits in `ballot` for lanes [0, gl_SubgroupSize). Bits in
// .z/.w are always discarded.
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
{
    const uint lowBits = extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u));
    const uint highBits = extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0));
    const uint4 inGroup = ballot & uint4(lowBits, highBits, 0u, 0u);
    return spvPopCount4(inGroup);
}
88
// Number of set bits in `ballot` for lanes 0..gl_SubgroupInvocationID
// (inclusive). Bits in .z/.w are always discarded.
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
    const uint upTo = gl_SubgroupInvocationID + 1;
    const uint lowBits = extract_bits(0xFFFFFFFF, 0, min(upTo, 32u));
    const uint highBits = extract_bits(0xFFFFFFFF, 0, (uint)max((int)upTo - 32, 0));
    return spvPopCount4(ballot & uint4(lowBits, highBits, 0u, 0u));
}
94
// Number of set bits in `ballot` for lanes strictly below
// gl_SubgroupInvocationID. Bits in .z/.w are always discarded.
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
    const uint below = gl_SubgroupInvocationID;
    const uint lowBits = extract_bits(0xFFFFFFFF, 0, min(below, 32u));
    const uint highBits = extract_bits(0xFFFFFFFF, 0, (uint)max((int)below - 32, 0));
    return spvPopCount4(ballot & uint4(lowBits, highBits, 0u, 0u));
}
100
// spvSubgroupAllEqual: true when every lane's `value` equals the value held by
// the first active lane (component-wise for vectors).
template<typename T>
inline bool spvSubgroupAllEqual(T value)
{
    const T leader = simd_broadcast_first(value);
    return simd_all(all(value == leader));
}

// For bool, "all equal" means either every lane is true or no lane is true.
// simd_any is consulted only when simd_all fails, matching the original order.
template<>
inline bool spvSubgroupAllEqual(bool value)
{
    if (simd_all(value))
        return true;
    return !simd_any(value);
}

// vec<bool, N>: the leader's value is fetched through ushort and compared
// component-wise.
template<uint N>
inline bool spvSubgroupAllEqual(vec<bool, N> value)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<bool, N> leader = (vec<bool, N>)simd_broadcast_first(asShorts);
    return simd_all(all(value == leader));
}
118
// spvSubgroupShuffle: reads `value` from SIMD lane `lane` via simd_shuffle.
template<typename T>
inline T spvSubgroupShuffle(T value, ushort lane)
{
    const T shuffled = simd_shuffle(value, lane);
    return shuffled;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvSubgroupShuffle(bool value, ushort lane)
{
    const ushort asShort = (ushort)value;
    return simd_shuffle(asShort, lane) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> shuffled = simd_shuffle(asShorts, lane);
    return (vec<bool, N>)shuffled;
}
136
// spvSubgroupShuffleXor: reads `value` from the lane selected by XOR-ing this
// lane's index with `mask` (simd_shuffle_xor).
template<typename T>
inline T spvSubgroupShuffleXor(T value, ushort mask)
{
    const T partnerValue = simd_shuffle_xor(value, mask);
    return partnerValue;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvSubgroupShuffleXor(bool value, ushort mask)
{
    const ushort asShort = (ushort)value;
    return simd_shuffle_xor(asShort, mask) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> partnerShorts = simd_shuffle_xor(asShorts, mask);
    return (vec<bool, N>)partnerShorts;
}
154
// spvSubgroupShuffleUp: thin wrapper over simd_shuffle_up(value, delta).
template<typename T>
inline T spvSubgroupShuffleUp(T value, ushort delta)
{
    const T shifted = simd_shuffle_up(value, delta);
    return shifted;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvSubgroupShuffleUp(bool value, ushort delta)
{
    const ushort asShort = (ushort)value;
    return simd_shuffle_up(asShort, delta) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> shifted = simd_shuffle_up(asShorts, delta);
    return (vec<bool, N>)shifted;
}
172
// spvSubgroupShuffleDown: thin wrapper over simd_shuffle_down(value, delta).
template<typename T>
inline T spvSubgroupShuffleDown(T value, ushort delta)
{
    const T shifted = simd_shuffle_down(value, delta);
    return shifted;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvSubgroupShuffleDown(bool value, ushort delta)
{
    const ushort asShort = (ushort)value;
    return simd_shuffle_down(asShort, delta) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> shifted = simd_shuffle_down(asShorts, delta);
    return (vec<bool, N>)shifted;
}
190
// spvQuadBroadcast: returns `value` from lane `lane` of the current quad via
// quad_broadcast. `lane` is accepted as uint and narrowed implicitly by the call.
template<typename T>
inline T spvQuadBroadcast(T value, uint lane)
{
    const T quadValue = quad_broadcast(value, lane);
    return quadValue;
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvQuadBroadcast(bool value, uint lane)
{
    const ushort asShort = (ushort)value;
    return quad_broadcast(asShort, lane) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)
{
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    const vec<ushort, N> quadShorts = quad_broadcast(asShorts, lane);
    return (vec<bool, N>)quadShorts;
}
208
// spvQuadSwap: exchanges `value` with a partner lane in the quad. The partner
// is chosen by XOR-ing the lane index with dir + 1, so dir 0 -> xor 1,
// dir 1 -> xor 2, dir 2 -> xor 3 (main0 uses these as horizontal, vertical
// and diagonal swaps respectively).
template<typename T>
inline T spvQuadSwap(T value, uint dir)
{
    const uint partnerMask = dir + 1;
    return quad_shuffle_xor(value, partnerMask);
}

// bool payloads travel as ushort 0/1 and are converted back by a non-zero test.
template<>
inline bool spvQuadSwap(bool value, uint dir)
{
    const uint partnerMask = dir + 1;
    const ushort asShort = (ushort)value;
    return quad_shuffle_xor(asShort, partnerMask) != 0;
}

// Component-wise form of the bool case for vec<bool, N>.
template<uint N>
inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)
{
    const uint partnerMask = dir + 1;
    const vec<ushort, N> asShorts = (vec<ushort, N>)value;
    return (vec<bool, N>)quad_shuffle_xor(asShorts, partnerMask);
}
226
// Fragment entry point that exercises every subgroup/quad helper defined in
// this file. Only the final store to out.FragColor (the .x word of
// gl_SubgroupLtMask) reaches the output; every other result is computed but
// never read.
fragment main0_out main0(uint gl_SubgroupSize [[threads_per_simdgroup]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
    main0_out out = {};

    // Per-invocation SPIR-V subgroup masks (bit i corresponds to lane i).
    // Only .x/.y are ever populated; .z/.w are always zero.
    // The Eq mask shifts an unsigned 1u so that lanes 31 and 63 never push a
    // set bit into the sign bit of a signed int.
    uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0u, 1u << (gl_SubgroupInvocationID - 32), uint2(0)) : uint4(1u << gl_SubgroupInvocationID, uint3(0));
    uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
    uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
    uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
    uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));

    // Builtin and mask reads; each store overwrites the previous one.
    out.FragColor = float(gl_SubgroupSize);
    out.FragColor = float(gl_SubgroupInvocationID);
    bool elected = simd_is_first();
    out.FragColor = float4(gl_SubgroupEqMask).x;
    out.FragColor = float4(gl_SubgroupGeMask).x;
    out.FragColor = float4(gl_SubgroupGtMask).x;
    out.FragColor = float4(gl_SubgroupLeMask).x;
    out.FragColor = float4(gl_SubgroupLtMask).x;

    // Broadcasts.
    float4 broadcasted = spvSubgroupBroadcast(float4(10.0), 8u);
    bool2 broadcasted_bool = spvSubgroupBroadcast(bool2(true), 8u);
    float3 first = spvSubgroupBroadcastFirst(float3(20.0));
    bool4 first_bool = spvSubgroupBroadcastFirst(bool4(false));

    // Ballot and ballot queries.
    uint4 ballot_value = spvSubgroupBallot(true);
    bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
    bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
    uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
    uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
    uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
    uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
    uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);

    // Shuffles.
    uint shuffled = spvSubgroupShuffle(10u, 8u);
    bool shuffled_bool = spvSubgroupShuffle(true, 9u);
    uint shuffled_xor = spvSubgroupShuffleXor(30u, 8u);
    bool shuffled_xor_bool = spvSubgroupShuffleXor(false, 9u);
    uint shuffled_up = spvSubgroupShuffleUp(20u, 4u);
    bool shuffled_up_bool = spvSubgroupShuffleUp(true, 4u);
    uint shuffled_down = spvSubgroupShuffleDown(20u, 4u);
    bool shuffled_down_bool = spvSubgroupShuffleDown(false, 4u);

    // Votes.
    bool has_all = simd_all(true);
    bool has_any = simd_any(true);
    bool has_equal = spvSubgroupAllEqual(0);
    has_equal = spvSubgroupAllEqual(true);
    has_equal = spvSubgroupAllEqual(float3(0.0, 1.0, 2.0));
    has_equal = spvSubgroupAllEqual(bool4(true, true, false, true));

    // SIMD-group reductions.
    float4 added = simd_sum(float4(20.0));
    int4 iadded = simd_sum(int4(20));
    float4 multiplied = simd_product(float4(20.0));
    int4 imultiplied = simd_product(int4(20));
    float4 lo = simd_min(float4(20.0));
    float4 hi = simd_max(float4(20.0));
    int4 slo = simd_min(int4(20));
    int4 shi = simd_max(int4(20));
    uint4 ulo = simd_min(uint4(20u));
    uint4 uhi = simd_max(uint4(20u));
    uint4 anded = simd_and(ballot_value);
    uint4 ored = simd_or(ballot_value);
    uint4 xored = simd_xor(ballot_value);

    // Prefix scans. The operand choices (e.g. exclusive sum over `multiplied`)
    // are kept exactly as they were; NOTE(review): confirm they mirror the
    // source shader.
    added = simd_prefix_inclusive_sum(added);
    iadded = simd_prefix_inclusive_sum(iadded);
    multiplied = simd_prefix_inclusive_product(multiplied);
    imultiplied = simd_prefix_inclusive_product(imultiplied);
    added = simd_prefix_exclusive_sum(multiplied);
    multiplied = simd_prefix_exclusive_product(multiplied);
    iadded = simd_prefix_exclusive_sum(imultiplied);
    imultiplied = simd_prefix_exclusive_product(imultiplied);

    // Quad-group reductions.
    added = quad_sum(added);
    multiplied = quad_product(multiplied);
    iadded = quad_sum(iadded);
    imultiplied = quad_product(imultiplied);
    lo = quad_min(lo);
    hi = quad_max(hi);
    ulo = quad_min(ulo);
    uhi = quad_max(uhi);
    slo = quad_min(slo);
    shi = quad_max(shi);
    anded = quad_and(anded);
    ored = quad_or(ored);
    xored = quad_xor(xored);

    // Quad swaps (direction 0/1/2) and quad broadcast.
    float4 swap_horiz = spvQuadSwap(float4(20.0), 0u);
    bool4 swap_horiz_bool = spvQuadSwap(bool4(true), 0u);
    float4 swap_vertical = spvQuadSwap(float4(20.0), 1u);
    bool4 swap_vertical_bool = spvQuadSwap(bool4(true), 1u);
    float4 swap_diagonal = spvQuadSwap(float4(20.0), 2u);
    bool4 swap_diagonal_bool = spvQuadSwap(bool4(true), 2u);
    float4 quad_broadcast0 = spvQuadBroadcast(float4(20.0), 3u);
    bool4 quad_broadcast_bool = spvQuadBroadcast(bool4(true), 3u);
    return out;
}
314
315