// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "SpirvShader.hpp"
#include "SpirvShaderDebug.hpp"

#include "ShaderCore.hpp"

#include <spirv/unified1/spirv.hpp>

#include <limits>

namespace sw {

SpirvShader::EmitResult SpirvShader::EmitVectorTimesScalar(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		dst.move(i, lhs.Float(i) * rhs.Float(0));
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitMatrixTimesVector(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		SIMD::Float v = lhs.Float(i) * rhs.Float(0);
		for(auto j = 1u; j < rhs.componentCount; j++)
		{
			v += lhs.Float(i + type.componentCount * j) * rhs.Float(j);
		}
		dst.move(i, v);
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitVectorTimesMatrix(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		SIMD::Float v = lhs.Float(0) * rhs.Float(i * lhs.componentCount);
		for(auto j = 1u; j < lhs.componentCount; j++)
		{
			v += lhs.Float(j) * rhs.Float(i * lhs.componentCount + j);
		}
		dst.move(i, v);
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitMatrixTimesMatrix(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto numColumns = type.definition.word(3);
	auto numRows = getType(type.definition.word(2)).definition.word(3);
	auto numAdds = getObjectType(insn.word(3)).definition.word(3);

	for(auto row = 0u; row < numRows; row++)
	{
		for(auto col = 0u; col < numColumns; col++)
		{
			SIMD::Float v = SIMD::Float(0);
			for(auto i = 0u; i < numAdds; i++)
			{
				v += lhs.Float(i * numRows + row) * rhs.Float(col * numAdds + i);
			}
			dst.move(numRows * col + row, v);
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitOuterProduct(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto numRows = lhs.componentCount;
	auto numCols = rhs.componentCount;

	for(auto col = 0u; col < numCols; col++)
	{
		for(auto row = 0u; row < numRows; row++)
		{
			dst.move(col * numRows + row, lhs.Float(row) * rhs.Float(col));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitTranspose(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto mat = Operand(this, state, insn.word(3));

	auto numCols = type.definition.word(3);
	auto numRows = getType(type.definition.word(2)).componentCount;

	for(auto col = 0u; col < numCols; col++)
	{
		for(auto row = 0u; row < numRows; row++)
		{
			dst.move(col * numRows + row, mat.Float(row * numCols + col));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitUnaryOp(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto src = Operand(this, state, insn.word(3));

	for(auto i = 0u; i < type.componentCount; i++)
	{
		switch(insn.opcode())
		{
		case spv::OpNot:
		case spv::OpLogicalNot:  // logical not == bitwise not due to all-bits boolean representation
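			// Booleans are stored as full 32-bit masks: false is 0x00000000 and true is
			// 0xFFFFFFFF, so a bitwise complement flips between the two representations.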
			dst.move(i, ~src.UInt(i));
			break;
		case spv::OpBitFieldInsert:
			{
				auto insert = Operand(this, state, insn.word(4)).UInt(i);
				auto offset = Operand(this, state, insn.word(5)).UInt(0);
				auto count = Operand(this, state, insn.word(6)).UInt(0);
				auto one = SIMD::UInt(1);
				auto v = src.UInt(i);
				auto mask = Bitmask32(offset + count) ^ Bitmask32(offset);
				dst.move(i, (v & ~mask) | ((insert << offset) & mask));
			}
			break;
		case spv::OpBitFieldSExtract:
		case spv::OpBitFieldUExtract:
			{
				auto offset = Operand(this, state, insn.word(4)).UInt(0);
				auto count = Operand(this, state, insn.word(5)).UInt(0);
				auto one = SIMD::UInt(1);
				auto v = src.UInt(i);
				SIMD::UInt out = (v >> offset) & Bitmask32(count);
				if(insn.opcode() == spv::OpBitFieldSExtract)
				{
					auto sign = out & NthBit32(count - one);
					auto sext = ~(sign - one);
					out |= sext;
				}
				dst.move(i, out);
			}
			break;
		case spv::OpBitReverse:
			{
				// TODO: Add an intrinsic to reactor. Even if there isn't a
				// single vector instruction, there may be target-dependent
				// ways to make this faster.
				// https://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel
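				// Each step swaps progressively larger groups of bits (1, 2, 4, 8, then 16),
				// so e.g. 0x00000001 becomes 0x80000000 and 0x0000FF00 becomes 0x00FF0000.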
				SIMD::UInt v = src.UInt(i);
				v = ((v >> 1) & SIMD::UInt(0x55555555)) | ((v & SIMD::UInt(0x55555555)) << 1);
				v = ((v >> 2) & SIMD::UInt(0x33333333)) | ((v & SIMD::UInt(0x33333333)) << 2);
				v = ((v >> 4) & SIMD::UInt(0x0F0F0F0F)) | ((v & SIMD::UInt(0x0F0F0F0F)) << 4);
				v = ((v >> 8) & SIMD::UInt(0x00FF00FF)) | ((v & SIMD::UInt(0x00FF00FF)) << 8);
				v = (v >> 16) | (v << 16);
				dst.move(i, v);
			}
			break;
		case spv::OpBitCount:
			dst.move(i, CountBits(src.UInt(i)));
			break;
		case spv::OpSNegate:
			dst.move(i, -src.Int(i));
			break;
		case spv::OpFNegate:
			dst.move(i, -src.Float(i));
			break;
		case spv::OpConvertFToU:
			dst.move(i, SIMD::UInt(src.Float(i)));
			break;
		case spv::OpConvertFToS:
			dst.move(i, SIMD::Int(src.Float(i)));
			break;
		case spv::OpConvertSToF:
			dst.move(i, SIMD::Float(src.Int(i)));
			break;
		case spv::OpConvertUToF:
			dst.move(i, SIMD::Float(src.UInt(i)));
			break;
		case spv::OpBitcast:
			dst.move(i, src.Float(i));
			break;
		case spv::OpIsInf:
			dst.move(i, IsInf(src.Float(i)));
			break;
		case spv::OpIsNan:
			dst.move(i, IsNan(src.Float(i)));
			break;
		case spv::OpDPdx:
		case spv::OpDPdxCoarse:
			// Derivative instructions: FS invocations are laid out like so:
			//    0 1
			//    2 3
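			// Coarse derivatives use a single difference for the whole quad:
			// dPdx = lane 1 - lane 0 and dPdy = lane 2 - lane 0, broadcast to all four lanes.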
			static_assert(SIMD::Width == 4, "All cross-lane instructions will need care when using a different width");
			dst.move(i, SIMD::Float(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)));
			break;
		case spv::OpDPdy:
		case spv::OpDPdyCoarse:
			dst.move(i, SIMD::Float(Extract(src.Float(i), 2) - Extract(src.Float(i), 0)));
			break;
		case spv::OpFwidth:
		case spv::OpFwidthCoarse:
			dst.move(i, SIMD::Float(Abs(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)) + Abs(Extract(src.Float(i), 2) - Extract(src.Float(i), 0))));
			break;
		case spv::OpDPdxFine:
			{
				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
				SIMD::Float v = SIMD::Float(firstRow);
				v = Insert(v, secondRow, 2);
				v = Insert(v, secondRow, 3);
				dst.move(i, v);
			}
			break;
		case spv::OpDPdyFine:
			{
				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
				SIMD::Float v = SIMD::Float(firstColumn);
				v = Insert(v, secondColumn, 1);
				v = Insert(v, secondColumn, 3);
				dst.move(i, v);
			}
			break;
		case spv::OpFwidthFine:
			{
				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
				SIMD::Float dpdx = SIMD::Float(firstRow);
				dpdx = Insert(dpdx, secondRow, 2);
				dpdx = Insert(dpdx, secondRow, 3);
				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
				SIMD::Float dpdy = SIMD::Float(firstColumn);
				dpdy = Insert(dpdy, secondColumn, 1);
				dpdy = Insert(dpdy, secondColumn, 3);
				dst.move(i, Abs(dpdx) + Abs(dpdy));
			}
			break;
		case spv::OpQuantizeToF16:
			{
				// Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
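				// 0.000061035f is approximately 2^-14, the smallest normal half-precision value:
				// smaller magnitudes flush to a signed zero. 65504 is the largest finite half;
				// anything larger becomes infinity, while NaN inputs stay NaN.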
				auto abs = Abs(src.Float(i));
				auto sign = src.Int(i) & SIMD::Int(0x80000000);
				auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));
				auto isInf = CmpGT(abs, SIMD::Float(65504.0f));
				auto isNaN = IsNan(abs);
				auto isInfOrNan = isInf | isNaN;
				SIMD::Int v = src.Int(i) & SIMD::Int(0xFFFFE000);
				v &= ~isZero | SIMD::Int(0x80000000);
				v = sign | (isInfOrNan & SIMD::Int(0x7F800000)) | (~isInfOrNan & v);
				v |= isNaN & SIMD::Int(0x400000);
				dst.move(i, v);
			}
			break;
		default:
			UNREACHABLE("%s", OpcodeName(insn.opcode()));
		}
	}

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitBinaryOp(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = getObjectType(insn.word(3));
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	for(auto i = 0u; i < lhsType.componentCount; i++)
	{
		switch(insn.opcode())
		{
		case spv::OpIAdd:
			dst.move(i, lhs.Int(i) + rhs.Int(i));
			break;
		case spv::OpISub:
			dst.move(i, lhs.Int(i) - rhs.Int(i));
			break;
		case spv::OpIMul:
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			break;
		case spv::OpSDiv:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
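				// CmpEQ yields an all-ones mask when true, so a zero divisor is forced to -1
				// and INT_MIN / -1 is turned into -1 / -1. Both cases are undefined anyway;
				// this only guarantees they produce some value instead of trapping.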
				dst.move(i, a / b);
			}
			break;
		case spv::OpUDiv:
			{
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) / (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpSRem:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				dst.move(i, a % b);
			}
			break;
		case spv::OpSMod:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				auto mod = a % b;
				// If a and b have opposite signs, the remainder operation takes
				// the sign from a but OpSMod is supposed to take the sign of b.
				// Adding b will ensure that the result has the correct sign and
				// that it is still congruent to a modulo b.
				//
				// See also http://mathforum.org/library/drmath/view/52343.html
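				// For example, a = -5, b = 3: a % b = -2 (sign of a), but OpSMod requires 1
				// (sign of b); the signs differ and the remainder is nonzero, so add b: -2 + 3 = 1.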
				auto signDiff = CmpNEQ(CmpGE(a, SIMD::Int(0)), CmpGE(b, SIMD::Int(0)));
				auto fixedMod = mod + (b & CmpNEQ(mod, SIMD::Int(0)) & signDiff);
				dst.move(i, As<SIMD::Float>(fixedMod));
			}
			break;
		case spv::OpUMod:
			{
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) % (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpIEqual:
		case spv::OpLogicalEqual:
			dst.move(i, CmpEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpINotEqual:
		case spv::OpLogicalNotEqual:
			dst.move(i, CmpNEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThan:
			dst.move(i, CmpGT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThan:
			dst.move(i, CmpGT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThanEqual:
			dst.move(i, CmpGE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThan:
			dst.move(i, CmpLT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThan:
			dst.move(i, CmpLT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThanEqual:
			dst.move(i, CmpLE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThanEqual:
			dst.move(i, CmpLE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpFAdd:
			dst.move(i, lhs.Float(i) + rhs.Float(i));
			break;
		case spv::OpFSub:
			dst.move(i, lhs.Float(i) - rhs.Float(i));
			break;
		case spv::OpFMul:
			dst.move(i, lhs.Float(i) * rhs.Float(i));
			break;
		case spv::OpFDiv:
			// TODO(b/169760262): Optimize using reciprocal instructions (2.5 ULP).
			// TODO(b/222218659): Optimize for RelaxedPrecision (2.5 ULP).
			dst.move(i, lhs.Float(i) / rhs.Float(i));
			break;
		case spv::OpFMod:
			// TODO(b/126873455): Inaccurate for values greater than 2^24.
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) - rhs.Float(i) * Floor(lhs.Float(i) / rhs.Float(i)));
			break;
		case spv::OpFRem:
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) % rhs.Float(i));
			break;
		case spv::OpFOrdEqual:
			dst.move(i, CmpEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordEqual:
			dst.move(i, CmpUEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdNotEqual:
			dst.move(i, CmpNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordNotEqual:
			dst.move(i, CmpUNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThan:
			dst.move(i, CmpLT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThan:
			dst.move(i, CmpULT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThan:
			dst.move(i, CmpGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThan:
			dst.move(i, CmpUGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThanEqual:
			dst.move(i, CmpLE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThanEqual:
			dst.move(i, CmpULE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThanEqual:
			dst.move(i, CmpUGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpShiftRightLogical:
			dst.move(i, lhs.UInt(i) >> rhs.UInt(i));
			break;
		case spv::OpShiftRightArithmetic:
			dst.move(i, lhs.Int(i) >> rhs.Int(i));
			break;
		case spv::OpShiftLeftLogical:
			dst.move(i, lhs.UInt(i) << rhs.UInt(i));
			break;
		case spv::OpBitwiseOr:
		case spv::OpLogicalOr:
			dst.move(i, lhs.UInt(i) | rhs.UInt(i));
			break;
		case spv::OpBitwiseXor:
			dst.move(i, lhs.UInt(i) ^ rhs.UInt(i));
			break;
		case spv::OpBitwiseAnd:
		case spv::OpLogicalAnd:
			dst.move(i, lhs.UInt(i) & rhs.UInt(i));
			break;
		case spv::OpSMulExtended:
			// Extended ops: result is a structure containing two members of the same type as lhs & rhs.
			// In our flat view then, component i is the i'th component of the first member;
			// component i + N is the i'th component of the second member.
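			// For example, with 2-component operands the result has four flat components:
			// 0..1 hold the low 32 bits of each product and 2..3 hold the high 32 bits
			// (or the carry/borrow bits for OpIAddCarry/OpISubBorrow below).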
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUMulExtended:
			dst.move(i, lhs.UInt(i) * rhs.UInt(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpIAddCarry:
			dst.move(i, lhs.UInt(i) + rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(dst.UInt(i), lhs.UInt(i)) >> 31);
			break;
		case spv::OpISubBorrow:
			dst.move(i, lhs.UInt(i) - rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(lhs.UInt(i), rhs.UInt(i)) >> 31);
			break;
		default:
			UNREACHABLE("%s", OpcodeName(insn.opcode()));
		}
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.word(2), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);

	return EmitResult::Continue;
}

SpirvShader::EmitResult SpirvShader::EmitDot(InsnIterator insn, EmitState *state) const
{
	auto &type = getType(insn.resultTypeId());
	ASSERT(type.componentCount == 1);
	auto &dst = state->createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = getObjectType(insn.word(3));
	auto lhs = Operand(this, state, insn.word(3));
	auto rhs = Operand(this, state, insn.word(4));

	auto opcode = insn.opcode();
	switch(opcode)
	{
	case spv::OpDot:
		dst.move(0, FDot(lhsType.componentCount, lhs, rhs));
		break;
	case spv::OpSDot:
		dst.move(0, SDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpUDot:
		dst.move(0, UDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSUDot:
		dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, SDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpUDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, UDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpSUDotAccSat:
		{
			auto accum = Operand(this, state, insn.word(5));
			dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	default:
		UNREACHABLE("%s", OpcodeName(opcode));
		break;
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.resultId(), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);

	return EmitResult::Continue;
}

SIMD::Float SpirvShader::FDot(unsigned numComponents, Operand const &x, Operand const &y)
{
	SIMD::Float d = x.Float(0) * y.Float(0);

	for(auto i = 1u; i < numComponents; i++)
	{
		d += x.Float(i) * y.Float(i);
	}

	return d;
}

SIMD::Int SpirvShader::SDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::Int d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
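		// A single "component" here is a 32-bit scalar holding four packed signed 8-bit
		// values (the 4x8-bit packed format). For each SIMD lane, the four byte pairs are
		// multiplied and summed into one 32-bit result.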
		numComponents = 4;
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
			Int4 ys(As<SByte4>(Extract(y.Int(0), i)));

			Int4 xy = xs * ys;
			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.Int(0) * y.Int(0);

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.Int(i) * y.Int(i);
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->Int(0));
	}

	return d;
}

SIMD::UInt SpirvShader::UDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::UInt d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
		numComponents = 4;
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<Byte4>(Extract(x.Int(0), i)));
			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

			UInt4 xy = xs * ys;
			rr::UInt sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.UInt(0) * y.UInt(0);

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.UInt(i) * y.UInt(i);
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->UInt(0));
	}

	return d;
}

SIMD::Int SpirvShader::SUDot(unsigned numComponents, Operand const &x, Operand const &y, Operand const *accum)
{
	SIMD::Int d(0);

	if(numComponents == 1)  // 4x8bit packed
	{
		numComponents = 4;
		for(auto i = 0u; i < numComponents; i++)
		{
			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

			Int4 xy = xs * ys;
			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

			d = Insert(d, sum, i);
		}
	}
	else
	{
		d = x.Int(0) * As<SIMD::Int>(y.UInt(0));

		for(auto i = 1u; i < numComponents; i++)
		{
			d += x.Int(i) * As<SIMD::Int>(y.UInt(i));
		}
	}

	if(accum)
	{
		d = AddSat(d, accum->Int(0));
	}

	return d;
}

SIMD::Int SpirvShader::AddSat(RValue<SIMD::Int> a, RValue<SIMD::Int> b)
{
	SIMD::Int sum = a + b;
	SIMD::Int sSign = sum >> 31;
	SIMD::Int aSign = a >> 31;
	SIMD::Int bSign = b >> 31;

	// Overflow happened if both numbers added have the same sign and the sum has a different sign
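	// For example, 0x7FFFFFFF + 1 wraps to 0x80000000: both inputs are non-negative but the
	// sum is negative, so the result saturates to INT_MAX. Likewise 0x80000000 + -1 wraps to
	// 0x7FFFFFFF and saturates to INT_MIN.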
	SIMD::Int oob = ~(aSign ^ bSign) & (aSign ^ sSign);
	SIMD::Int overflow = oob & sSign;
	SIMD::Int underflow = oob & aSign;

	return (overflow & std::numeric_limits<int32_t>::max()) |
	       (underflow & std::numeric_limits<int32_t>::min()) |
	       (~oob & sum);
}

SIMD::UInt SpirvShader::AddSat(RValue<SIMD::UInt> a, RValue<SIMD::UInt> b)
{
	SIMD::UInt sum = a + b;

	// Overflow happened if the sum of unsigned integers is smaller than either of the 2 numbers being added
	// Note: CmpLT()'s return value is automatically set to UINT_MAX when true
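	// For example, 0xFFFFFFFF + 2 wraps to 1; since 1 < a, CmpLT yields an all-ones mask and
	// the OR saturates the result to UINT_MAX.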
	return CmpLT(sum, a) | sum;
}

std::pair<SIMD::Float, SIMD::Int> SpirvShader::Frexp(RValue<SIMD::Float> val) const
{
	// Assumes IEEE 754
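	// Splits val into a significand with magnitude in [0.5, 1) and an integer exponent such
	// that val == significand * 2^exponent, e.g. 8.0 -> (0.5, 4). The sign bit stays on the
	// significand, and +/-0 returns (+/-0, 0).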
	auto v = As<SIMD::UInt>(val);
	auto isNotZero = CmpNEQ(v & SIMD::UInt(0x7FFFFFFF), SIMD::UInt(0));
	auto zeroSign = v & SIMD::UInt(0x80000000) & ~isNotZero;
	auto significand = As<SIMD::Float>((((v & SIMD::UInt(0x807FFFFF)) | SIMD::UInt(0x3F000000)) & isNotZero) | zeroSign);
	auto exponent = Exponent(val) & SIMD::Int(isNotZero);
	return std::make_pair(significand, exponent);
}

}  // namespace sw