1 // Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "SpirvShader.hpp"
16
17 #include <spirv/unified1/spirv.hpp>
18
19 namespace sw {
20
21 struct SpirvShader::Impl::Group
22 {
23 // Template function to perform a binary operation.
24 // |TYPE| should be the type of the binary operation (as a SIMD::<ScalarType>).
25 // |I| should be a type suitable to initialize the identity value.
26 // |APPLY| should be a callable object that takes two RValue<TYPE> parameters
27 // and returns a new RValue<TYPE> corresponding to the operation's result.
28 template<typename TYPE, typename I, typename APPLY>
BinaryOperationsw::SpirvShader::Impl::Group29 static void BinaryOperation(
30 const SpirvShader *shader,
31 const SpirvShader::InsnIterator &insn,
32 const SpirvShader::EmitState *state,
33 Intermediate &dst,
34 const I identityValue,
35 APPLY &&apply)
36 {
37 SpirvShader::Operand value(shader, state, insn.word(5));
38 auto &type = shader->getType(SpirvShader::Type::ID(insn.word(1)));
39 for(auto i = 0u; i < type.componentCount; i++)
40 {
41 auto mask = As<SIMD::UInt>(state->activeLaneMask()); // Considers helper invocations active. See b/151137030
42 auto identity = TYPE(identityValue);
43 SIMD::UInt v_uint = (value.UInt(i) & mask) | (As<SIMD::UInt>(identity) & ~mask);
44 TYPE v = As<TYPE>(v_uint);
45 switch(spv::GroupOperation(insn.word(4)))
46 {
47 case spv::GroupOperationReduce:
48 {
49 // NOTE: floating-point add and multiply are not really commutative so
50 // ensure that all values in the final lanes are identical
51 TYPE v2 = apply(v.xxzz, v.yyww); // [xy] [xy] [zw] [zw]
52 TYPE v3 = apply(v2.xxxx, v2.zzzz); // [xyzw] [xyzw] [xyzw] [xyzw]
53 dst.move(i, v3);
54 break;
55 }
56 case spv::GroupOperationInclusiveScan:
57 {
58 TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
59 TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
60 dst.move(i, v3);
61 break;
62 }
63 case spv::GroupOperationExclusiveScan:
64 {
65 TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
66 TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
67 auto v4 = Shuffle(v3, identity, 0x4012 /* [id, v3.x, v3.y, v3.z] */); // [i] [x] [xy] [xyz]
68 dst.move(i, v4);
69 break;
70 }
71 default:
72 UNSUPPORTED("EmitGroupNonUniform op: %s Group operation: %d",
73 SpirvShader::OpcodeName(type.opcode()), insn.word(4));
74 }
75 }
76 }
77 };
78
EmitGroupNonUniform(InsnIterator insn,EmitState * state) const79 SpirvShader::EmitResult SpirvShader::EmitGroupNonUniform(InsnIterator insn, EmitState *state) const
80 {
81 static_assert(SIMD::Width == 4, "EmitGroupNonUniform makes many assumptions that the SIMD vector width is 4");
82
83 auto &type = getType(Type::ID(insn.word(1)));
84 Object::ID resultId = insn.word(2);
85 auto scope = spv::Scope(GetConstScalarInt(insn.word(3)));
86 ASSERT_MSG(scope == spv::ScopeSubgroup, "Scope for Non Uniform Group Operations must be Subgroup for Vulkan 1.1");
87
88 auto &dst = state->createIntermediate(resultId, type.componentCount);
89
90 switch(insn.opcode())
91 {
92 case spv::OpGroupNonUniformElect:
93 {
94 // Result is true only in the active invocation with the lowest id
95 // in the group, otherwise result is false.
96 SIMD::Int active = state->activeLaneMask(); // Considers helper invocations active. See b/151137030
97 // TODO: Would be nice if we could write this as:
98 // elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
99 auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
100 auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
101 dst.move(0, elect);
102 break;
103 }
104
105 case spv::OpGroupNonUniformAll:
106 {
107 Operand predicate(this, state, insn.word(4));
108 dst.move(0, AndAll(predicate.UInt(0) | ~As<SIMD::UInt>(state->activeLaneMask()))); // Considers helper invocations active. See b/151137030
109 break;
110 }
111
112 case spv::OpGroupNonUniformAny:
113 {
114 Operand predicate(this, state, insn.word(4));
115 dst.move(0, OrAll(predicate.UInt(0) & As<SIMD::UInt>(state->activeLaneMask()))); // Considers helper invocations active. See b/151137030
116 break;
117 }
118
119 case spv::OpGroupNonUniformAllEqual:
120 {
121 Operand value(this, state, insn.word(4));
122 auto res = SIMD::UInt(0xffffffff);
123 SIMD::UInt active = As<SIMD::UInt>(state->activeLaneMask()); // Considers helper invocations active. See b/151137030
124 SIMD::UInt inactive = ~active;
125 for(auto i = 0u; i < type.componentCount; i++)
126 {
127 SIMD::UInt v = value.UInt(i) & active;
128 SIMD::UInt filled = v;
129 for(int j = 0; j < SIMD::Width - 1; j++)
130 {
131 filled |= filled.yzwx & inactive; // Populate inactive 'holes' with a live value
132 }
133 res &= AndAll(CmpEQ(filled.xyzw, filled.yzwx));
134 }
135 dst.move(0, res);
136 break;
137 }
138
139 case spv::OpGroupNonUniformBroadcast:
140 {
141 auto valueId = Object::ID(insn.word(4));
142 auto idId = Object::ID(insn.word(5));
143 Operand value(this, state, valueId);
144
145 // Decide between the fast path for constants and the slow path for
146 // intermediates.
147 if(getObject(idId).kind == SpirvShader::Object::Kind::Constant)
148 {
149 auto id = SIMD::Int(GetConstScalarInt(insn.word(5)));
150 auto mask = CmpEQ(id, SIMD::Int(0, 1, 2, 3));
151 for(auto i = 0u; i < type.componentCount; i++)
152 {
153 dst.move(i, OrAll(value.Int(i) & mask));
154 }
155 }
156 else
157 {
158 Operand id(this, state, idId);
159
160 SIMD::UInt active = As<SIMD::UInt>(state->activeLaneMask()); // Considers helper invocations active. See b/151137030
161 SIMD::UInt inactive = ~active;
162 SIMD::UInt filled = id.UInt(0) & active;
163
164 for(int j = 0; j < SIMD::Width - 1; j++)
165 {
166 filled |= filled.yzwx & inactive; // Populate inactive 'holes' with a live value
167 }
168
169 auto mask = CmpEQ(filled, SIMD::UInt(0, 1, 2, 3));
170
171 for(uint32_t i = 0u; i < type.componentCount; i++)
172 {
173 dst.move(i, OrAll(value.UInt(i) & mask));
174 }
175 }
176 break;
177 }
178
179 case spv::OpGroupNonUniformBroadcastFirst:
180 {
181 auto valueId = Object::ID(insn.word(4));
182 Operand value(this, state, valueId);
183 // Result is true only in the active invocation with the lowest id
184 // in the group, otherwise result is false.
185 SIMD::Int active = state->activeLaneMask(); // Considers helper invocations active. See b/151137030
186 // TODO: Would be nice if we could write this as:
187 // elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
188 auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
189 auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
190 for(auto i = 0u; i < type.componentCount; i++)
191 {
192 dst.move(i, OrAll(value.Int(i) & elect));
193 }
194 break;
195 }
196
197 case spv::OpGroupNonUniformBallot:
198 {
199 ASSERT(type.componentCount == 4);
200 Operand predicate(this, state, insn.word(4));
201 dst.move(0, SIMD::Int(SignMask(state->activeLaneMask() & predicate.Int(0)))); // Considers helper invocations active. See b/151137030
202 dst.move(1, SIMD::Int(0));
203 dst.move(2, SIMD::Int(0));
204 dst.move(3, SIMD::Int(0));
205 break;
206 }
207
208 case spv::OpGroupNonUniformInverseBallot:
209 {
210 auto valueId = Object::ID(insn.word(4));
211 ASSERT(type.componentCount == 1);
212 ASSERT(getType(getObject(valueId)).componentCount == 4);
213 Operand value(this, state, valueId);
214 auto bit = (value.Int(0) >> SIMD::Int(0, 1, 2, 3)) & SIMD::Int(1);
215 dst.move(0, -bit);
216 break;
217 }
218
219 case spv::OpGroupNonUniformBallotBitExtract:
220 {
221 auto valueId = Object::ID(insn.word(4));
222 auto indexId = Object::ID(insn.word(5));
223 ASSERT(type.componentCount == 1);
224 ASSERT(getType(getObject(valueId)).componentCount == 4);
225 ASSERT(getType(getObject(indexId)).componentCount == 1);
226 Operand value(this, state, valueId);
227 Operand index(this, state, indexId);
228 auto vecIdx = index.Int(0) / SIMD::Int(32);
229 auto bitIdx = index.Int(0) & SIMD::Int(31);
230 auto bits = (value.Int(0) & CmpEQ(vecIdx, SIMD::Int(0))) |
231 (value.Int(1) & CmpEQ(vecIdx, SIMD::Int(1))) |
232 (value.Int(2) & CmpEQ(vecIdx, SIMD::Int(2))) |
233 (value.Int(3) & CmpEQ(vecIdx, SIMD::Int(3)));
234 dst.move(0, -((bits >> bitIdx) & SIMD::Int(1)));
235 break;
236 }
237
238 case spv::OpGroupNonUniformBallotBitCount:
239 {
240 auto operation = spv::GroupOperation(insn.word(4));
241 auto valueId = Object::ID(insn.word(5));
242 ASSERT(type.componentCount == 1);
243 ASSERT(getType(getObject(valueId)).componentCount == 4);
244 Operand value(this, state, valueId);
245 switch(operation)
246 {
247 case spv::GroupOperationReduce:
248 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(15)));
249 break;
250 case spv::GroupOperationInclusiveScan:
251 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(1, 3, 7, 15)));
252 break;
253 case spv::GroupOperationExclusiveScan:
254 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(0, 1, 3, 7)));
255 break;
256 default:
257 UNSUPPORTED("GroupOperation %d", int(operation));
258 }
259 break;
260 }
261
262 case spv::OpGroupNonUniformBallotFindLSB:
263 {
264 auto valueId = Object::ID(insn.word(4));
265 ASSERT(type.componentCount == 1);
266 ASSERT(getType(getObject(valueId)).componentCount == 4);
267 Operand value(this, state, valueId);
268 dst.move(0, Cttz(value.UInt(0) & SIMD::UInt(15), true));
269 break;
270 }
271
272 case spv::OpGroupNonUniformBallotFindMSB:
273 {
274 auto valueId = Object::ID(insn.word(4));
275 ASSERT(type.componentCount == 1);
276 ASSERT(getType(getObject(valueId)).componentCount == 4);
277 Operand value(this, state, valueId);
278 dst.move(0, SIMD::UInt(31) - Ctlz(value.UInt(0) & SIMD::UInt(15), false));
279 break;
280 }
281
282 case spv::OpGroupNonUniformShuffle:
283 {
284 Operand value(this, state, insn.word(4));
285 Operand id(this, state, insn.word(5));
286 auto x = CmpEQ(SIMD::Int(0), id.Int(0));
287 auto y = CmpEQ(SIMD::Int(1), id.Int(0));
288 auto z = CmpEQ(SIMD::Int(2), id.Int(0));
289 auto w = CmpEQ(SIMD::Int(3), id.Int(0));
290 for(auto i = 0u; i < type.componentCount; i++)
291 {
292 SIMD::Int v = value.Int(i);
293 dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
294 }
295 break;
296 }
297
298 case spv::OpGroupNonUniformShuffleXor:
299 {
300 Operand value(this, state, insn.word(4));
301 Operand mask(this, state, insn.word(5));
302 auto x = CmpEQ(SIMD::Int(0), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
303 auto y = CmpEQ(SIMD::Int(1), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
304 auto z = CmpEQ(SIMD::Int(2), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
305 auto w = CmpEQ(SIMD::Int(3), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
306 for(auto i = 0u; i < type.componentCount; i++)
307 {
308 SIMD::Int v = value.Int(i);
309 dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
310 }
311 break;
312 }
313
314 case spv::OpGroupNonUniformShuffleUp:
315 {
316 Operand value(this, state, insn.word(4));
317 Operand delta(this, state, insn.word(5));
318 auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
319 auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
320 auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
321 auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
322 for(auto i = 0u; i < type.componentCount; i++)
323 {
324 SIMD::Int v = value.Int(i);
325 dst.move(i, (d0 & v.xyzw) | (d1 & v.xxyz) | (d2 & v.xxxy) | (d3 & v.xxxx));
326 }
327 break;
328 }
329
330 case spv::OpGroupNonUniformShuffleDown:
331 {
332 Operand value(this, state, insn.word(4));
333 Operand delta(this, state, insn.word(5));
334 auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
335 auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
336 auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
337 auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
338 for(auto i = 0u; i < type.componentCount; i++)
339 {
340 SIMD::Int v = value.Int(i);
341 dst.move(i, (d0 & v.xyzw) | (d1 & v.yzww) | (d2 & v.zwww) | (d3 & v.wwww));
342 }
343 break;
344 }
345
346 case spv::OpGroupNonUniformIAdd:
347 Impl::Group::BinaryOperation<SIMD::Int>(
348 this, insn, state, dst, 0,
349 [](auto a, auto b) { return a + b; });
350 break;
351
352 case spv::OpGroupNonUniformFAdd:
353 Impl::Group::BinaryOperation<SIMD::Float>(
354 this, insn, state, dst, 0.0f,
355 [](auto a, auto b) { return a + b; });
356 break;
357
358 case spv::OpGroupNonUniformIMul:
359 Impl::Group::BinaryOperation<SIMD::Int>(
360 this, insn, state, dst, 1,
361 [](auto a, auto b) { return a * b; });
362 break;
363
364 case spv::OpGroupNonUniformFMul:
365 Impl::Group::BinaryOperation<SIMD::Float>(
366 this, insn, state, dst, 1.0f,
367 [](auto a, auto b) { return a * b; });
368 break;
369
370 case spv::OpGroupNonUniformBitwiseAnd:
371 Impl::Group::BinaryOperation<SIMD::UInt>(
372 this, insn, state, dst, ~0u,
373 [](auto a, auto b) { return a & b; });
374 break;
375
376 case spv::OpGroupNonUniformBitwiseOr:
377 Impl::Group::BinaryOperation<SIMD::UInt>(
378 this, insn, state, dst, 0,
379 [](auto a, auto b) { return a | b; });
380 break;
381
382 case spv::OpGroupNonUniformBitwiseXor:
383 Impl::Group::BinaryOperation<SIMD::UInt>(
384 this, insn, state, dst, 0,
385 [](auto a, auto b) { return a ^ b; });
386 break;
387
388 case spv::OpGroupNonUniformSMin:
389 Impl::Group::BinaryOperation<SIMD::Int>(
390 this, insn, state, dst, INT32_MAX,
391 [](auto a, auto b) { return Min(a, b); });
392 break;
393
394 case spv::OpGroupNonUniformUMin:
395 Impl::Group::BinaryOperation<SIMD::UInt>(
396 this, insn, state, dst, ~0u,
397 [](auto a, auto b) { return Min(a, b); });
398 break;
399
400 case spv::OpGroupNonUniformFMin:
401 Impl::Group::BinaryOperation<SIMD::Float>(
402 this, insn, state, dst, SIMD::Float::infinity(),
403 [](auto a, auto b) { return NMin(a, b); });
404 break;
405
406 case spv::OpGroupNonUniformSMax:
407 Impl::Group::BinaryOperation<SIMD::Int>(
408 this, insn, state, dst, INT32_MIN,
409 [](auto a, auto b) { return Max(a, b); });
410 break;
411
412 case spv::OpGroupNonUniformUMax:
413 Impl::Group::BinaryOperation<SIMD::UInt>(
414 this, insn, state, dst, 0,
415 [](auto a, auto b) { return Max(a, b); });
416 break;
417
418 case spv::OpGroupNonUniformFMax:
419 Impl::Group::BinaryOperation<SIMD::Float>(
420 this, insn, state, dst, -SIMD::Float::infinity(),
421 [](auto a, auto b) { return NMax(a, b); });
422 break;
423
424 case spv::OpGroupNonUniformLogicalAnd:
425 Impl::Group::BinaryOperation<SIMD::UInt>(
426 this, insn, state, dst, ~0u,
427 [](auto a, auto b) {
428 SIMD::UInt zero = SIMD::UInt(0);
429 return CmpNEQ(a, zero) & CmpNEQ(b, zero);
430 });
431 break;
432
433 case spv::OpGroupNonUniformLogicalOr:
434 Impl::Group::BinaryOperation<SIMD::UInt>(
435 this, insn, state, dst, 0,
436 [](auto a, auto b) {
437 SIMD::UInt zero = SIMD::UInt(0);
438 return CmpNEQ(a, zero) | CmpNEQ(b, zero);
439 });
440 break;
441
442 case spv::OpGroupNonUniformLogicalXor:
443 Impl::Group::BinaryOperation<SIMD::UInt>(
444 this, insn, state, dst, 0,
445 [](auto a, auto b) {
446 SIMD::UInt zero = SIMD::UInt(0);
447 return CmpNEQ(a, zero) ^ CmpNEQ(b, zero);
448 });
449 break;
450
451 default:
452 UNSUPPORTED("EmitGroupNonUniform op: %s", OpcodeName(type.opcode()));
453 }
454 return EmitResult::Continue;
455 }
456
457 } // namespace sw
458