// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "SpirvShader.hpp"
#include "SpirvShaderDebug.hpp"

#include "ShaderCore.hpp"

#include <spirv/unified1/spirv.hpp>

#include <limits>

namespace sw {

SpirvShader::EmitResult SpirvShader::EmitVectorTimesScalar(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		dst.move(i, lhs.Float(i) * rhs.Float(0));
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitMatrixTimesVector(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

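	// The matrix operand is flattened in column-major order: element (row i, column j)
	// lives at index i + j * numRows, where numRows equals the result vector's size.
	// Each result component i is therefore the dot product of matrix row i with rhs.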
	for(auto i = 0u; i < type.componentCount; i++)
	{
		SIMD::Float v = lhs.Float(i) * rhs.Float(0);
		for(auto j = 1u; j < rhs.componentCount; j++)
		{
			v += lhs.Float(i + type.componentCount * j) * rhs.Float(j);
		}
		dst.move(i, v);
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitVectorTimesMatrix(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		SIMD::Float v = lhs.Float(0) * rhs.Float(i * lhs.componentCount);
		for(auto j = 1u; j < lhs.componentCount; j++)
		{
			v += lhs.Float(j) * rhs.Float(i * lhs.componentCount + j);
		}
		dst.move(i, v);
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitMatrixTimesMatrix(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto numColumns = type.definition.word(3);
	auto numRows = getType(type.definition.word(2)).definition.word(3);
	auto numAdds = getObjectType(insn.word(3)).definition.word(3);

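	// numAdds is the shared dimension: the column count of lhs, which must equal the
	// row count of rhs. Both operands and the result are stored column-major, so
	// column col of a matrix with R rows starts at flattened index col * R.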
	for(auto row = 0u; row < numRows; row++)
	{
		for(auto col = 0u; col < numColumns; col++)
		{
			SIMD::Float v = SIMD::Float(0);
			for(auto i = 0u; i < numAdds; i++)
			{
				v += lhs.Float(i * numRows + row) * rhs.Float(col * numAdds + i);
			}
			dst.move(numRows * col + row, v);
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitOuterProduct(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto numRows = lhs.componentCount;
	auto numCols = rhs.componentCount;

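	// The result is a numRows x numCols matrix with element (row, col) = lhs[row] * rhs[col],
	// written out column-major like every other matrix in this file.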
	for(auto col = 0u; col < numCols; col++)
	{
		for(auto row = 0u; row < numRows; row++)
		{
			dst.move(col * numRows + row, lhs.Float(row) * rhs.Float(col));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitTranspose(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto mat = Operand(this, state, insn.word(3));

	auto numCols = type.definition.word(3);
	auto numRows = getType(type.definition.word(2)).componentCount;

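	// numCols and numRows describe the result matrix; the source has the transposed shape.
	// Result element (row, col), stored at col * numRows + row, reads source element
	// (col, row), stored at row * numCols + col, with both matrices laid out column-major.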
	for(auto col = 0u; col < numCols; col++)
	{
		for(auto row = 0u; row < numRows; row++)
		{
			dst.move(col * numRows + row, mat.Float(row * numCols + col));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitUnaryOp(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto src = Operand(this, state, insn.word(3));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		switch(insn.opcode())
		{
		case spv::OpNot:
		case spv::OpLogicalNot:  // logical not == bitwise not due to all-bits boolean representation
			dst.move(i, ~src.UInt(i));
			break;
		case spv::OpBitFieldInsert:
			{
				auto insert = Operand(this, state, insn.word(4)).UInt(i);
				auto offset = Operand(this, state, insn.word(5)).UInt(0);
				auto count = Operand(this, state, insn.word(6)).UInt(0);
				auto one = SIMD::UInt(1);
				auto v = src.UInt(i);
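				// Bitmask32(n) yields the n lowest bits set, so the XOR below selects exactly
				// bits [offset, offset + count). For example, offset = 4, count = 8 gives
				// 0x00000FFF ^ 0x0000000F = 0x00000FF0.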
				auto mask = Bitmask32(offset + count) ^ Bitmask32(offset);
				dst.move(i, (v & ~mask) | ((insert << offset) & mask));
			}
			break;
		case spv::OpBitFieldSExtract:
		case spv::OpBitFieldUExtract:
			{
				auto offset = Operand(this, state, insn.word(4)).UInt(0);
				auto count = Operand(this, state, insn.word(5)).UInt(0);
				auto one = SIMD::UInt(1);
				auto v = src.UInt(i);
				SIMD::UInt out = (v >> offset) & Bitmask32(count);
				if(insn.opcode() == spv::OpBitFieldSExtract)
				{
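					// Sign-extend: if bit (count - 1) of the extracted field is set, 'sign' is
					// non-zero and ~(sign - 1) sets every bit at or above that position;
					// otherwise sign - 1 is all ones and the OR below is a no-op.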
					auto sign = out & NthBit32(count - one);
					auto sext = ~(sign - one);
					out |= sext;
				}
				dst.move(i, out);
			}
			break;
		case spv::OpBitReverse:
			{
				// TODO: Add an intrinsic to reactor. Even if there isn't a
				// single vector instruction, there may be target-dependent
				// ways to make this faster.
				// https://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel
				SIMD::UInt v = src.UInt(i);
				v = ((v >> 1) & SIMD::UInt(0x55555555)) | ((v & SIMD::UInt(0x55555555)) << 1);
				v = ((v >> 2) & SIMD::UInt(0x33333333)) | ((v & SIMD::UInt(0x33333333)) << 2);
				v = ((v >> 4) & SIMD::UInt(0x0F0F0F0F)) | ((v & SIMD::UInt(0x0F0F0F0F)) << 4);
				v = ((v >> 8) & SIMD::UInt(0x00FF00FF)) | ((v & SIMD::UInt(0x00FF00FF)) << 8);
				v = (v >> 16) | (v << 16);
				dst.move(i, v);
			}
			break;
		case spv::OpBitCount:
			dst.move(i, CountBits(src.UInt(i)));
			break;
		case spv::OpSNegate:
			dst.move(i, -src.Int(i));
			break;
		case spv::OpFNegate:
			dst.move(i, -src.Float(i));
			break;
		case spv::OpConvertFToU:
			dst.move(i, SIMD::UInt(src.Float(i)));
			break;
		case spv::OpConvertFToS:
			dst.move(i, SIMD::Int(src.Float(i)));
			break;
		case spv::OpConvertSToF:
			dst.move(i, SIMD::Float(src.Int(i)));
			break;
		case spv::OpConvertUToF:
			dst.move(i, SIMD::Float(src.UInt(i)));
			break;
		case spv::OpBitcast:
			dst.move(i, src.Float(i));
			break;
		case spv::OpIsInf:
			dst.move(i, IsInf(src.Float(i)));
			break;
		case spv::OpIsNan:
			dst.move(i, IsNan(src.Float(i)));
			break;
		case spv::OpDPdx:
		case spv::OpDPdxCoarse:
			// Derivative instructions: FS invocations are laid out like so:
			//    0 1
			//    2 3
			static_assert(SIMD::Width == 4, "All cross-lane instructions will need care when using a different width");
			dst.move(i, SIMD::Float(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)));
			break;
		case spv::OpDPdy:
		case spv::OpDPdyCoarse:
			dst.move(i, SIMD::Float(Extract(src.Float(i), 2) - Extract(src.Float(i), 0)));
			break;
		case spv::OpFwidth:
		case spv::OpFwidthCoarse:
			dst.move(i, SIMD::Float(Abs(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)) + Abs(Extract(src.Float(i), 2) - Extract(src.Float(i), 0))));
			break;
		case spv::OpDPdxFine:
			{
				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
				SIMD::Float v = SIMD::Float(firstRow);
				v = Insert(v, secondRow, 2);
				v = Insert(v, secondRow, 3);
				dst.move(i, v);
			}
			break;
		case spv::OpDPdyFine:
			{
				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
				SIMD::Float v = SIMD::Float(firstColumn);
				v = Insert(v, secondColumn, 1);
				v = Insert(v, secondColumn, 3);
				dst.move(i, v);
			}
			break;
		case spv::OpFwidthFine:
			{
				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
				SIMD::Float dpdx = SIMD::Float(firstRow);
				dpdx = Insert(dpdx, secondRow, 2);
				dpdx = Insert(dpdx, secondRow, 3);
				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
				SIMD::Float dpdy = SIMD::Float(firstColumn);
				dpdy = Insert(dpdy, secondColumn, 1);
				dpdy = Insert(dpdy, secondColumn, 3);
				dst.move(i, Abs(dpdx) + Abs(dpdy));
			}
			break;
		case spv::OpQuantizeToF16:
			{
				// Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
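				// 0.000061035f is roughly 2^-14, the smallest normal half-precision value, so
				// anything smaller is flushed to a signed zero. 65504.0f is the largest finite
				// half value; larger magnitudes become infinity. Masking with 0xFFFFE000 keeps
				// only the 10 mantissa bits representable in half precision, and 0x400000 sets
				// the quiet-NaN bit for NaN inputs.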
				auto abs = Abs(src.Float(i));
				auto sign = src.Int(i) & SIMD::Int(0x80000000);
				auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));
				auto isInf = CmpGT(abs, SIMD::Float(65504.0f));
				auto isNaN = IsNan(abs);
				auto isInfOrNan = isInf | isNaN;
				SIMD::Int v = src.Int(i) & SIMD::Int(0xFFFFE000);
				v &= ~isZero | SIMD::Int(0x80000000);
				v = sign | (isInfOrNan & SIMD::Int(0x7F800000)) | (~isInfOrNan & v);
				v |= isNaN & SIMD::Int(0x400000);
				dst.move(i, v);
			}
			break;
		default:
			UNREACHABLE("%s", OpcodeName(insn.opcode()));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitBinaryOp(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = getObjectType(insn.word(3));
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < lhsType.componentCount; i++)
	{
		switch(insn.opcode())
		{
		case spv::OpIAdd:
			dst.move(i, lhs.Int(i) + rhs.Int(i));
			break;
		case spv::OpISub:
			dst.move(i, lhs.Int(i) - rhs.Int(i));
			break;
		case spv::OpIMul:
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			break;
		case spv::OpSDiv:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				dst.move(i, a / b);
			}
			break;
		case spv::OpUDiv:
			{
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) / (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpSRem:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				dst.move(i, a % b);
			}
			break;
		case spv::OpSMod:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				auto mod = a % b;
				// If a and b have opposite signs, the remainder operation takes
				// the sign from a but OpSMod is supposed to take the sign of b.
				// Adding b will ensure that the result has the correct sign and
				// that it is still congruent to a modulo b.
				//
				// See also http://mathforum.org/library/drmath/view/52343.html
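				//
				// Example: a = -3, b = 2 gives mod = -1; the signs differ and mod != 0, so the
				// result becomes -1 + 2 = 1, which matches -3 mod 2 with the sign of b.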
				auto signDiff = CmpNEQ(CmpGE(a, SIMD::Int(0)), CmpGE(b, SIMD::Int(0)));
				auto fixedMod = mod + (b & CmpNEQ(mod, SIMD::Int(0)) & signDiff);
				dst.move(i, As<SIMD::Float>(fixedMod));
			}
			break;
		case spv::OpUMod:
			{
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) % (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpIEqual:
		case spv::OpLogicalEqual:
			dst.move(i, CmpEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpINotEqual:
		case spv::OpLogicalNotEqual:
			dst.move(i, CmpNEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThan:
			dst.move(i, CmpGT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThan:
			dst.move(i, CmpGT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThanEqual:
			dst.move(i, CmpGE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThan:
			dst.move(i, CmpLT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThan:
			dst.move(i, CmpLT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThanEqual:
			dst.move(i, CmpLE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThanEqual:
			dst.move(i, CmpLE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpFAdd:
			dst.move(i, lhs.Float(i) + rhs.Float(i));
			break;
		case spv::OpFSub:
			dst.move(i, lhs.Float(i) - rhs.Float(i));
			break;
		case spv::OpFMul:
			dst.move(i, lhs.Float(i) * rhs.Float(i));
			break;
		case spv::OpFDiv:
			// TODO(b/169760262): Optimize using reciprocal instructions (2.5 ULP).
			// TODO(b/222218659): Optimize for RelaxedPrecision (2.5 ULP).
			dst.move(i, lhs.Float(i) / rhs.Float(i));
			break;
		case spv::OpFMod:
			// TODO(b/126873455): Inaccurate for values greater than 2^24.
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) - rhs.Float(i) * Floor(lhs.Float(i) / rhs.Float(i)));
			break;
		case spv::OpFRem:
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) % rhs.Float(i));
			break;
		case spv::OpFOrdEqual:
			dst.move(i, CmpEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordEqual:
			dst.move(i, CmpUEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdNotEqual:
			dst.move(i, CmpNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordNotEqual:
			dst.move(i, CmpUNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThan:
			dst.move(i, CmpLT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThan:
			dst.move(i, CmpULT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThan:
			dst.move(i, CmpGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThan:
			dst.move(i, CmpUGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThanEqual:
			dst.move(i, CmpLE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThanEqual:
			dst.move(i, CmpULE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThanEqual:
			dst.move(i, CmpUGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpShiftRightLogical:
			dst.move(i, lhs.UInt(i) >> rhs.UInt(i));
			break;
		case spv::OpShiftRightArithmetic:
			dst.move(i, lhs.Int(i) >> rhs.Int(i));
			break;
		case spv::OpShiftLeftLogical:
			dst.move(i, lhs.UInt(i) << rhs.UInt(i));
			break;
		case spv::OpBitwiseOr:
		case spv::OpLogicalOr:
			dst.move(i, lhs.UInt(i) | rhs.UInt(i));
			break;
		case spv::OpBitwiseXor:
			dst.move(i, lhs.UInt(i) ^ rhs.UInt(i));
			break;
		case spv::OpBitwiseAnd:
		case spv::OpLogicalAnd:
			dst.move(i, lhs.UInt(i) & rhs.UInt(i));
			break;
		case spv::OpSMulExtended:
			// Extended ops: result is a structure containing two members of the same type as lhs & rhs.
			// In our flat view then, component i is the i'th component of the first member;
			// component i + N is the i'th component of the second member.
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUMulExtended:
			dst.move(i, lhs.UInt(i) * rhs.UInt(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpIAddCarry:
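			// The carry is detected via unsigned wrap-around: the truncated sum is smaller
			// than an operand iff the addition overflowed. CmpLT returns all ones when true,
			// so shifting right by 31 yields the required 0 or 1 carry value.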
			dst.move(i, lhs.UInt(i) + rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(dst.UInt(i), lhs.UInt(i)) >> 31);
			break;
		case spv::OpISubBorrow:
			dst.move(i, lhs.UInt(i) - rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(lhs.UInt(i), rhs.UInt(i)) >> 31);
			break;
		default:
			UNREACHABLE("%s", OpcodeName(insn.opcode()));
		}
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.word(2), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitDot(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	ASSERT(type.componentCount == 1);
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = getObjectType(insn.word(3));
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto opcode = insn.opcode();
	switch(opcode)
	{
	case spv::OpDot:
		dst.move(0, FDot(lhsType.componentCount, lhs, rhs));
		break;
	case spv::OpSDot:
		dst.move(0, SDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpUDot:
		dst.move(0, UDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSUDot:
		dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, SDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpUDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, UDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpSUDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	default:
		UNREACHABLE("%s", OpcodeName(opcode));
		break;
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.resultId(), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);

	return EmitResult::Continue;
}

SIMD::Float SpirvShader::FDot(unsigned numComponents, Operand const &x, Operand const &y)
{
	SIMD::Float d = x.Float(0) * y.Float(0);

	for(auto i = 1u; i < numComponents; i++)
	{
		d += x.Float(i) * y.Float(i);
	}

	return d;
}

SIMD::Int SpirvShader::SDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::Int d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
		numComponents = 4;
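		// Each 32-bit scalar packs four 8-bit values. The loop below runs over the four
		// SIMD lanes: Extract() pulls one lane's packed word, As<SByte4> reinterprets it
		// as four signed bytes, and their pairwise products are summed into that lane.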
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
			Int4 ys(As<SByte4>(Extract(y.Int(0), i)));

			Int4 xy = xs * ys;
			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.Int(0) * y.Int(0);

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.Int(i) * y.Int(i);
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->Int(0));
	}

	return d;
}

SIMD::UInt SpirvShader::UDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::UInt d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
		numComponents = 4;
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<Byte4>(Extract(x.Int(0), i)));
			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

			UInt4 xy = xs * ys;
			rr::UInt sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.UInt(0) * y.UInt(0);

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.UInt(i) * y.UInt(i);
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->UInt(0));
	}

	return d;
}

SIMD::Int SpirvShader::SUDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::Int d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
		numComponents = 4;
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

			Int4 xy = xs * ys;
			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.Int(0) * As<SIMD::Int>(y.UInt(0));

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.Int(i) * As<SIMD::Int>(y.UInt(i));
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->Int(0));
	}

	return d;
}

SIMD::Int SpirvShader::AddSat(RValue<SIMD::Int> a, RValue<SIMD::Int> b)
{
	SIMD::Int sum = a + b;
	SIMD::Int sSign = sum >> 31;
	SIMD::Int aSign = a >> 31;
	SIMD::Int bSign = b >> 31;

	// Overflow happened if both numbers added have the same sign and the sum has a different sign
	SIMD::Int oob = ~(aSign ^ bSign) & (aSign ^ sSign);
	SIMD::Int overflow = oob & sSign;
	SIMD::Int underflow = oob & aSign;

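	// On overflow toward +infinity the wrapped sum is negative (sSign is set), so the result
	// clamps to INT32_MAX; on overflow toward -infinity both operands are negative (aSign is
	// set), so it clamps to INT32_MIN. Otherwise the plain sum is selected.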
	return (overflow & std::numeric_limits<int32_t>::max()) |
	       (underflow & std::numeric_limits<int32_t>::min()) |
	       (~oob & sum);
}

SIMD::UInt SpirvShader::AddSat(RValue<SIMD::UInt> a, RValue<SIMD::UInt> b)
{
	SIMD::UInt sum = a + b;

	// Overflow happened if the sum of unsigned integers is smaller than either of the 2 numbers being added
	// Note: CmpLT()'s return value is automatically set to UINT_MAX when true
	return CmpLT(sum, a) | sum;
}

std::pair<SIMD::Float, SIMD::Int> SpirvShader::Frexp(RValue<SIMD::Float> val) const
{
	// Assumes IEEE 754
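	// Bit layout: 0x80000000 is the sign, 0x7F800000 the exponent, 0x007FFFFF the mantissa.
	// For non-zero inputs, keeping the sign and mantissa (0x807FFFFF) and forcing the stored
	// exponent to 126 (0x3F000000) yields a significand in [0.5, 1); Exponent() supplies the
	// matching exponent so that val == significand * 2^exponent. Zero inputs produce a signed
	// zero significand and a zero exponent.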
	auto v = As<SIMD::UInt>(val);
	auto isNotZero = CmpNEQ(v & SIMD::UInt(0x7FFFFFFF), SIMD::UInt(0));
	auto zeroSign = v & SIMD::UInt(0x80000000) & ~isNotZero;
	auto significand = As<SIMD::Float>((((v & SIMD::UInt(0x807FFFFF)) | SIMD::UInt(0x3F000000)) & isNotZero) | zeroSign);
	auto exponent = Exponent(val) & SIMD::Int(isNotZero);
	return std::make_pair(significand, exponent);
}

}  // namespace sw