1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27
28 /*
29 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30 * definition. But for matrix multiplies, we want to do one routine for
31 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32 * with one column. So we "wrap" these things, and unwrap the result before we
33 * send it off.
34 */
35
36 static struct vtn_ssa_value *
37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39 if (val == NULL)
40 return NULL;
41
42 if (glsl_type_is_matrix(val->type))
43 return val;
44
45 struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
46 dest->type = glsl_get_bare_type(val->type);
47 dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
48 dest->elems[0] = val;
49
50 return dest;
51 }
52
53 static struct vtn_ssa_value *
54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56 if (glsl_type_is_matrix(val->type))
57 return val;
58
59 return val->elems[0];
60 }
61
62 static struct vtn_ssa_value *
63 matrix_multiply(struct vtn_builder *b,
64 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66
67 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71
72 unsigned src0_rows = glsl_get_vector_elements(src0->type);
73 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75
76 const struct glsl_type *dest_type;
77 if (src1_columns > 1) {
78 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79 src0_rows, src1_columns);
80 } else {
81 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82 }
83 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84
85 dest = wrap_matrix(b, dest);
86
87 bool transpose_result = false;
88 if (src0_transpose && src1_transpose) {
89 /* transpose(A) * transpose(B) = transpose(B * A) */
90 src1 = src0_transpose;
91 src0 = src1_transpose;
92 src0_transpose = NULL;
93 src1_transpose = NULL;
94 transpose_result = true;
95 }
96
97 for (unsigned i = 0; i < src1_columns; i++) {
98 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
99 dest->elems[i]->def =
100 nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
101 nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
102 for (int j = src0_columns - 2; j >= 0; j--) {
103 dest->elems[i]->def =
104 nir_ffma(&b->nb, src0->elems[j]->def,
105 nir_channel(&b->nb, src1->elems[i]->def, j),
106 dest->elems[i]->def);
107 }
108 }
109
110 dest = unwrap_matrix(dest);
111
112 if (transpose_result)
113 dest = vtn_ssa_transpose(b, dest);
114
115 return dest;
116 }
117
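/* Multiply every column of the matrix by the same scalar, using imul for
 * integer matrices and fmul for float matrices.
 */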
118 static struct vtn_ssa_value *
119 mat_times_scalar(struct vtn_builder *b,
120 struct vtn_ssa_value *mat,
121 nir_def *scalar)
122 {
123 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
124 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
125 if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
126 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
127 else
128 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
129 }
130
131 return dest;
132 }
133
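/* Narrow a def to 16 bits for RelaxedPrecision (mediump) handling.  Defs
 * that are already 16-bit, and booleans, are returned unchanged.
 */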
134 nir_def *
135 vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
136 {
137 if (def->bit_size == 16)
138 return def;
139
140 switch (base_type) {
141 case GLSL_TYPE_FLOAT:
142 return nir_f2fmp(&b->nb, def);
143 case GLSL_TYPE_INT:
144 case GLSL_TYPE_UINT:
145 return nir_i2imp(&b->nb, def);
146 /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
147 * OpLogical* operations (which is forbidden by spec).
148 */
149 case GLSL_TYPE_BOOL:
150 return def;
151 default:
152 unreachable("bad relaxed precision input type");
153 }
154 }
155
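/* Apply vtn_mediump_downconvert() across a vtn_ssa_value: scalars and
 * vectors convert their def, matrices convert each column, and values that
 * carry a transposed form convert that form instead.
 */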
156 struct vtn_ssa_value *
157 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
158 {
159 if (!src)
160 return src;
161
162 struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
163
164 if (src->transposed) {
165 srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
166 } else {
167 enum glsl_base_type base_type = glsl_get_base_type(src->type);
168
169 if (glsl_type_is_vector_or_scalar(src->type)) {
170 srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
171 } else {
172 assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
173 for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
174 srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
175 }
176 }
177
178 return srcmp;
179 }
180
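/* Handle the ALU opcodes that may take matrix operands.  Per-column ops
 * (FNegate, FAdd, FSub) are expanded column by column; the multiply
 * variants go through mat_times_scalar() or matrix_multiply().
 */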
181 static struct vtn_ssa_value *
182 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
183 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
184 {
185 switch (opcode) {
186 case SpvOpFNegate: {
187 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
188 unsigned cols = glsl_get_matrix_columns(src0->type);
189 for (unsigned i = 0; i < cols; i++)
190 dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
191 return dest;
192 }
193
194 case SpvOpFAdd: {
195 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
196 unsigned cols = glsl_get_matrix_columns(src0->type);
197 for (unsigned i = 0; i < cols; i++)
198 dest->elems[i]->def =
199 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
200 return dest;
201 }
202
203 case SpvOpFSub: {
204 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
205 unsigned cols = glsl_get_matrix_columns(src0->type);
206 for (unsigned i = 0; i < cols; i++)
207 dest->elems[i]->def =
208 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
209 return dest;
210 }
211
212 case SpvOpTranspose:
213 return vtn_ssa_transpose(b, src0);
214
215 case SpvOpMatrixTimesScalar:
216 if (src0->transposed) {
217 return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
218 src1->def));
219 } else {
220 return mat_times_scalar(b, src0, src1->def);
221 }
222 break;
223
224 case SpvOpVectorTimesMatrix:
225 case SpvOpMatrixTimesVector:
226 case SpvOpMatrixTimesMatrix:
227 if (opcode == SpvOpVectorTimesMatrix) {
228 return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
229 } else {
230 return matrix_multiply(b, src0, src1);
231 }
232 break;
233
234 default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
235 }
236 }
237
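/* Base NIR type of the source operand for a SPIR-V conversion opcode. */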
238 static nir_alu_type
239 convert_op_src_type(SpvOp opcode)
240 {
241 switch (opcode) {
242 case SpvOpFConvert:
243 case SpvOpConvertFToS:
244 case SpvOpConvertFToU:
245 return nir_type_float;
246 case SpvOpSConvert:
247 case SpvOpConvertSToF:
248 case SpvOpSatConvertSToU:
249 return nir_type_int;
250 case SpvOpUConvert:
251 case SpvOpConvertUToF:
252 case SpvOpSatConvertUToS:
253 return nir_type_uint;
254 default:
255 unreachable("Unhandled conversion op");
256 }
257 }
258
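/* Base NIR type of the result for a SPIR-V conversion opcode. */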
259 static nir_alu_type
260 convert_op_dst_type(SpvOp opcode)
261 {
262 switch (opcode) {
263 case SpvOpFConvert:
264 case SpvOpConvertSToF:
265 case SpvOpConvertUToF:
266 return nir_type_float;
267 case SpvOpSConvert:
268 case SpvOpConvertFToS:
269 case SpvOpSatConvertUToS:
270 return nir_type_int;
271 case SpvOpUConvert:
272 case SpvOpConvertFToU:
273 case SpvOpSatConvertSToU:
274 return nir_type_uint;
275 default:
276 unreachable("Unhandled conversion op");
277 }
278 }
279
280 nir_op
281 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
282 SpvOp opcode, bool *swap, bool *exact,
283 unsigned src_bit_size, unsigned dst_bit_size)
284 {
285 /* Indicates that the first two arguments should be swapped. This is
286 * used for implementing greater-than and less-than-or-equal.
287 */
288 *swap = false;
289
290 *exact = false;
291
292 switch (opcode) {
293 case SpvOpSNegate: return nir_op_ineg;
294 case SpvOpFNegate: return nir_op_fneg;
295 case SpvOpNot: return nir_op_inot;
296 case SpvOpIAdd: return nir_op_iadd;
297 case SpvOpFAdd: return nir_op_fadd;
298 case SpvOpISub: return nir_op_isub;
299 case SpvOpFSub: return nir_op_fsub;
300 case SpvOpIMul: return nir_op_imul;
301 case SpvOpFMul: return nir_op_fmul;
302 case SpvOpUDiv: return nir_op_udiv;
303 case SpvOpSDiv: return nir_op_idiv;
304 case SpvOpFDiv: return nir_op_fdiv;
305 case SpvOpUMod: return nir_op_umod;
306 case SpvOpSMod: return nir_op_imod;
307 case SpvOpFMod: return nir_op_fmod;
308 case SpvOpSRem: return nir_op_irem;
309 case SpvOpFRem: return nir_op_frem;
310
311 case SpvOpShiftRightLogical: return nir_op_ushr;
312 case SpvOpShiftRightArithmetic: return nir_op_ishr;
313 case SpvOpShiftLeftLogical: return nir_op_ishl;
314 case SpvOpLogicalOr: return nir_op_ior;
315 case SpvOpLogicalEqual: return nir_op_ieq;
316 case SpvOpLogicalNotEqual: return nir_op_ine;
317 case SpvOpLogicalAnd: return nir_op_iand;
318 case SpvOpLogicalNot: return nir_op_inot;
319 case SpvOpBitwiseOr: return nir_op_ior;
320 case SpvOpBitwiseXor: return nir_op_ixor;
321 case SpvOpBitwiseAnd: return nir_op_iand;
322 case SpvOpSelect: return nir_op_bcsel;
323 case SpvOpIEqual: return nir_op_ieq;
324
325 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
326 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
327 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
328 case SpvOpBitReverse: return nir_op_bitfield_reverse;
329
330 case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
331 /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
332 case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
333 case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
334 case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
335 case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
336 case SpvOpIAverageINTEL: return nir_op_ihadd;
337 case SpvOpUAverageINTEL: return nir_op_uhadd;
338 case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
339 case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
340 case SpvOpISubSatINTEL: return nir_op_isub_sat;
341 case SpvOpUSubSatINTEL: return nir_op_usub_sat;
342 case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
343 case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
344
345 /* The ordered / unordered comparison operators need more than the plain
346 * NIR comparison opcode, since they also have to check whether the
347 * operands are ordered.
348 */
349 case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
350 case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
351 case SpvOpINotEqual: return nir_op_ine;
352 case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
353 case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
354 case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
355 case SpvOpULessThan: return nir_op_ult;
356 case SpvOpSLessThan: return nir_op_ilt;
357 case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
358 case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
359 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
360 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
361 case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
362 case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
363 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
364 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
365 case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
366 case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
367 case SpvOpUGreaterThanEqual: return nir_op_uge;
368 case SpvOpSGreaterThanEqual: return nir_op_ige;
369 case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
370 case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;
371
372 /* Conversions: */
373 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
374 case SpvOpUConvert:
375 case SpvOpConvertFToU:
376 case SpvOpConvertFToS:
377 case SpvOpConvertSToF:
378 case SpvOpConvertUToF:
379 case SpvOpSConvert:
380 case SpvOpFConvert: {
381 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
382 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
383 return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
384 }
385
386 case SpvOpPtrCastToGeneric: return nir_op_mov;
387 case SpvOpGenericCastToPtr: return nir_op_mov;
388
389 case SpvOpIsNormal: return nir_op_fisnormal;
390 case SpvOpIsFinite: return nir_op_fisfinite;
391
392 default:
393 vtn_fail("No NIR equivalent: %u", opcode);
394 }
395 }
396
397 static void
398 handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
399 UNUSED int member, const struct vtn_decoration *dec,
400 UNUSED void *_void)
401 {
402 vtn_assert(dec->scope == VTN_DEC_DECORATION);
403 if (dec->decoration != SpvDecorationFPFastMathMode)
404 return;
405
406 SpvFPFastMathModeMask can_fast_math =
407 SpvFPFastMathModeAllowRecipMask |
408 SpvFPFastMathModeAllowContractMask |
409 SpvFPFastMathModeAllowReassocMask |
410 SpvFPFastMathModeAllowTransformMask;
411
412 if ((dec->operands[0] & can_fast_math) != can_fast_math)
413 b->nb.exact = true;
414
415 /* Decoration overrides defaults */
416 b->nb.fp_fast_math = 0;
417 if (!(dec->operands[0] & SpvFPFastMathModeNSZMask))
418 b->nb.fp_fast_math |=
419 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 |
420 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 |
421 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64;
422 if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask))
423 b->nb.fp_fast_math |=
424 FLOAT_CONTROLS_NAN_PRESERVE_FP16 |
425 FLOAT_CONTROLS_NAN_PRESERVE_FP32 |
426 FLOAT_CONTROLS_NAN_PRESERVE_FP64;
427 if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask))
428 b->nb.fp_fast_math |=
429 FLOAT_CONTROLS_INF_PRESERVE_FP16 |
430 FLOAT_CONTROLS_INF_PRESERVE_FP32 |
431 FLOAT_CONTROLS_INF_PRESERVE_FP64;
432 }
433
434 void
435 vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
436 {
437 /* Take the NaN/Inf/signed-zero preserve bits from the execution mode and
438 * set them on the builder, so the generated instructions can pick them up
439 * from there.  We only care about some of them; see nir_alu_instr for
440 * details.  We also copy all bit widths, because we can't easily get the
441 * correct one here.
442 */
443 #define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
444 FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
445 FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
446 static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
447 "enum float_controls and fp_fast_math out of sync!");
448 b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
449 FLOAT_CONTROLS2_BITS;
450 vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);
451 #undef FLOAT_CONTROLS2_BITS
452 }
453
454 static void
455 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
456 UNUSED int member, const struct vtn_decoration *dec,
457 UNUSED void *_void)
458 {
459 vtn_assert(dec->scope == VTN_DEC_DECORATION);
460 if (dec->decoration != SpvDecorationNoContraction)
461 return;
462
463 b->nb.exact = true;
464 }
465
466 void
467 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
468 {
469 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
470 }
471
472 nir_rounding_mode
473 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
474 {
475 switch (mode) {
476 case SpvFPRoundingModeRTE:
477 return nir_rounding_mode_rtne;
478 case SpvFPRoundingModeRTZ:
479 return nir_rounding_mode_rtz;
480 case SpvFPRoundingModeRTP:
481 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
482 "FPRoundingModeRTP is only supported in kernels");
483 return nir_rounding_mode_ru;
484 case SpvFPRoundingModeRTN:
485 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
486 "FPRoundingModeRTN is only supported in kernels");
487 return nir_rounding_mode_rd;
488 default:
489 vtn_fail("Unsupported rounding mode: %s",
490 spirv_fproundingmode_to_string(mode));
491 break;
492 }
493 }
494
495 struct conversion_opts {
496 nir_rounding_mode rounding_mode;
497 bool saturate;
498 };
499
500 static void
501 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
502 UNUSED int member,
503 const struct vtn_decoration *dec, void *_opts)
504 {
505 struct conversion_opts *opts = _opts;
506
507 switch (dec->decoration) {
508 case SpvDecorationFPRoundingMode:
509 opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
510 break;
511
512 case SpvDecorationSaturatedConversion:
513 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
514 "Saturated conversions are only allowed in kernels");
515 opts->saturate = true;
516 break;
517
518 default:
519 break;
520 }
521 }
522
523 static void
524 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
525 UNUSED int member,
526 const struct vtn_decoration *dec, void *_alu)
527 {
528 nir_alu_instr *alu = _alu;
529 switch (dec->decoration) {
530 case SpvDecorationNoSignedWrap:
531 alu->no_signed_wrap = true;
532 break;
533 case SpvDecorationNoUnsignedWrap:
534 alu->no_unsigned_wrap = true;
535 break;
536 default:
537 /* Do nothing. */
538 break;
539 }
540 }
541
542 static void
543 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
544 struct vtn_value *val, int member,
545 const struct vtn_decoration *dec, void *void_ctx)
546 {
547 bool *relaxed_precision = void_ctx;
548 switch (dec->decoration) {
549 case SpvDecorationRelaxedPrecision:
550 *relaxed_precision = true;
551 break;
552
553 default:
554 break;
555 }
556 }
557
558 bool
559 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
560 {
561 bool result = false;
562 vtn_foreach_decoration(b, val,
563 vtn_value_is_relaxed_precision_cb, &result);
564 return result;
565 }
566
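/* Whether this ALU result should be evaluated at 16 bits: requires the
 * mediump_16bit_alu option and a RelaxedPrecision destination; derivative
 * opcodes additionally require mediump_16bit_derivatives.
 */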
567 static bool
568 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
569 {
570 if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
571 return false;
572
573 switch (opcode) {
574 case SpvOpDPdx:
575 case SpvOpDPdy:
576 case SpvOpDPdxFine:
577 case SpvOpDPdyFine:
578 case SpvOpDPdxCoarse:
579 case SpvOpDPdyCoarse:
580 case SpvOpFwidth:
581 case SpvOpFwidthFine:
582 case SpvOpFwidthCoarse:
583 return b->options->mediump_16bit_derivatives;
584 default:
585 return true;
586 }
587 }
588
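/* Widen a 16-bit mediump def back to 32 bits; the inverse of
 * vtn_mediump_downconvert().
 */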
589 static nir_def *
590 vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
591 {
592 if (def->bit_size != 16)
593 return def;
594
595 switch (base_type) {
596 case GLSL_TYPE_FLOAT:
597 return nir_f2f32(&b->nb, def);
598 case GLSL_TYPE_INT:
599 return nir_i2i32(&b->nb, def);
600 case GLSL_TYPE_UINT:
601 return nir_u2u32(&b->nb, def);
602 default:
603 unreachable("bad relaxed precision output type");
604 }
605 }
606
607 void
608 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
609 {
610 enum glsl_base_type base_type = glsl_get_base_type(value->type);
611
612 if (glsl_type_is_vector_or_scalar(value->type)) {
613 value->def = vtn_mediump_upconvert(b, base_type, value->def);
614 } else {
615 for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
616 value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
617 }
618 }
619
620 static nir_def *
621 vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src)
622 {
623 /* SPV_NV_compute_shader_derivatives:
624 * In the GLCompute Execution Model:
625 * Selection of the four invocations is determined by the DerivativeGroup*NV
626 * execution mode that was specified for the entry point.
627 * If neither derivative group mode was specified, the derivatives return zero.
628 */
629 if (b->nb.shader->info.stage == MESA_SHADER_COMPUTE &&
630 b->nb.shader->info.derivative_group == DERIVATIVE_GROUP_NONE) {
631 return nir_imm_zero(&b->nb, src->num_components, src->bit_size);
632 }
633
634 switch (opcode) {
635 case SpvOpDPdx:
636 return nir_ddx(&b->nb, src);
637 case SpvOpDPdxFine:
638 return nir_ddx_fine(&b->nb, src);
639 case SpvOpDPdxCoarse:
640 return nir_ddx_coarse(&b->nb, src);
641 case SpvOpDPdy:
642 return nir_ddy(&b->nb, src);
643 case SpvOpDPdyFine:
644 return nir_ddy_fine(&b->nb, src);
645 case SpvOpDPdyCoarse:
646 return nir_ddy_coarse(&b->nb, src);
647 case SpvOpFwidth:
648 return nir_fadd(&b->nb,
649 nir_fabs(&b->nb, nir_ddx(&b->nb, src)),
650 nir_fabs(&b->nb, nir_ddy(&b->nb, src)));
651 case SpvOpFwidthFine:
652 return nir_fadd(&b->nb,
653 nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src)),
654 nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src)));
655 case SpvOpFwidthCoarse:
656 return nir_fadd(&b->nb,
657 nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src)),
658 nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src)));
659 default: unreachable("Not a derivative opcode");
660 }
661 }
662
663 void
664 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
665 const uint32_t *w, unsigned count)
666 {
667 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
668 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
669
670 if (glsl_type_is_cmat(dest_type)) {
671 vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
672 return;
673 }
674
675 vtn_handle_no_contraction(b, dest_val);
676 vtn_handle_fp_fast_math(b, dest_val);
677 bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
678
679 /* Collect the various SSA sources */
680 const unsigned num_inputs = count - 3;
681 struct vtn_ssa_value *vtn_src[4] = { NULL, };
682 for (unsigned i = 0; i < num_inputs; i++) {
683 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
684 if (mediump_16bit)
685 vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
686 }
687
688 if (glsl_type_is_matrix(vtn_src[0]->type) ||
689 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
690 struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
691
692 if (mediump_16bit)
693 vtn_mediump_upconvert_value(b, dest);
694
695 vtn_push_ssa_value(b, w[2], dest);
696 b->nb.exact = b->exact;
697 return;
698 }
699
700 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
701 nir_def *src[4] = { NULL, };
702 for (unsigned i = 0; i < num_inputs; i++) {
703 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
704 src[i] = vtn_src[i]->def;
705 }
706
707 switch (opcode) {
708 case SpvOpAny:
709 dest->def = nir_bany(&b->nb, src[0]);
710 break;
711
712 case SpvOpAll:
713 dest->def = nir_ball(&b->nb, src[0]);
714 break;
715
716 case SpvOpOuterProduct: {
717 for (unsigned i = 0; i < src[1]->num_components; i++) {
718 dest->elems[i]->def =
719 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
720 }
721 break;
722 }
723
724 case SpvOpDot:
725 dest->def = nir_fdot(&b->nb, src[0], src[1]);
726 break;
727
728 case SpvOpIAddCarry:
729 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
730 dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
731 dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
732 break;
733
734 case SpvOpISubBorrow:
735 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
736 dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
737 dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
738 break;
739
740 case SpvOpUMulExtended: {
741 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
742 if (src[0]->bit_size == 32) {
743 nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
744 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
745 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
746 } else {
747 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
748 dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
749 }
750 break;
751 }
752
753 case SpvOpSMulExtended: {
754 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
755 if (src[0]->bit_size == 32) {
756 nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
757 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
758 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
759 } else {
760 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
761 dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
762 }
763 break;
764 }
765
766 case SpvOpDPdx:
767 case SpvOpDPdxFine:
768 case SpvOpDPdxCoarse:
769 case SpvOpDPdy:
770 case SpvOpDPdyFine:
771 case SpvOpDPdyCoarse:
772 case SpvOpFwidth:
773 case SpvOpFwidthFine:
774 case SpvOpFwidthCoarse:
775 dest->def = vtn_handle_deriv(b, opcode, src[0]);
776 break;
777
778 case SpvOpVectorTimesScalar:
779 /* The builder will take care of splatting for us. */
780 dest->def = nir_fmul(&b->nb, src[0], src[1]);
781 break;
782
783 case SpvOpIsNan: {
784 const bool save_exact = b->nb.exact;
785
786 b->nb.exact = true;
787 dest->def = nir_fneu(&b->nb, src[0], src[0]);
788 b->nb.exact = save_exact;
789 break;
790 }
791
792 case SpvOpOrdered: {
793 const bool save_exact = b->nb.exact;
794
795 b->nb.exact = true;
796 dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
797 nir_feq(&b->nb, src[1], src[1]));
798 b->nb.exact = save_exact;
799 break;
800 }
801
802 case SpvOpUnordered: {
803 const bool save_exact = b->nb.exact;
804
805 b->nb.exact = true;
806 dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
807 nir_fneu(&b->nb, src[1], src[1]));
808 b->nb.exact = save_exact;
809 break;
810 }
811
812 case SpvOpIsInf: {
813 nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
814 dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
815 break;
816 }
817
818 case SpvOpFUnordEqual: {
819 const bool save_exact = b->nb.exact;
820
821 b->nb.exact = true;
822
823 /* This could also be implemented as !(a < b || b < a). If one or both
824 * of the source are numbers, later optimization passes can easily
825 * eliminate the isnan() checks. This may trim the sequence down to a
826 * single (a == b) operation. Otherwise, the optimizer can transform
827 * whatever is left to !(a < b || b < a). Since some applications will
828 * open-code this sequence, these optimizations are needed anyway.
829 */
830 dest->def =
831 nir_ior(&b->nb,
832 nir_feq(&b->nb, src[0], src[1]),
833 nir_ior(&b->nb,
834 nir_fneu(&b->nb, src[0], src[0]),
835 nir_fneu(&b->nb, src[1], src[1])));
836
837 b->nb.exact = save_exact;
838 break;
839 }
840
841 case SpvOpFUnordLessThan:
842 case SpvOpFUnordGreaterThan:
843 case SpvOpFUnordLessThanEqual:
844 case SpvOpFUnordGreaterThanEqual: {
845 bool swap;
846 bool unused_exact;
847 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
848 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
849 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
850 &unused_exact,
851 src_bit_size, dst_bit_size);
852
853 if (swap) {
854 nir_def *tmp = src[0];
855 src[0] = src[1];
856 src[1] = tmp;
857 }
858
859 const bool save_exact = b->nb.exact;
860
861 b->nb.exact = true;
862
863 /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
864 switch (op) {
865 case nir_op_fge: op = nir_op_flt; break;
866 case nir_op_flt: op = nir_op_fge; break;
867 default: unreachable("Impossible opcode.");
868 }
869
870 dest->def =
871 nir_inot(&b->nb,
872 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
873
874 b->nb.exact = save_exact;
875 break;
876 }
877
878 case SpvOpLessOrGreater:
879 case SpvOpFOrdNotEqual: {
880 /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
881 * from the ALU will probably already be false if the operands are not
882 * ordered so we don’t need to handle it specially.
883 */
884 const bool save_exact = b->nb.exact;
885
886 b->nb.exact = true;
887
888 /* This could also be implemented as (a < b || b < a). If one or both
889 * of the source are numbers, later optimization passes can easily
890 * eliminate the isnan() checks. This may trim the sequence down to a
891 * single (a != b) operation. Otherwise, the optimizer can transform
892 * whatever is left to (a < b || b < a). Since some applications will
893 * open-code this sequence, these optimizations are needed anyway.
894 */
895 dest->def =
896 nir_iand(&b->nb,
897 nir_fneu(&b->nb, src[0], src[1]),
898 nir_iand(&b->nb,
899 nir_feq(&b->nb, src[0], src[0]),
900 nir_feq(&b->nb, src[1], src[1])));
901
902 b->nb.exact = save_exact;
903 break;
904 }
905
906 case SpvOpUConvert:
907 case SpvOpConvertFToU:
908 case SpvOpConvertFToS:
909 case SpvOpConvertSToF:
910 case SpvOpConvertUToF:
911 case SpvOpSConvert:
912 case SpvOpFConvert:
913 case SpvOpSatConvertSToU:
914 case SpvOpSatConvertUToS: {
915 unsigned src_bit_size = src[0]->bit_size;
916 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
917 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
918 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
919
920 struct conversion_opts opts = {
921 .rounding_mode = nir_rounding_mode_undef,
922 .saturate = false,
923 };
924 vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
925
926 if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
927 opts.saturate = true;
928
929 if (b->shader->info.stage == MESA_SHADER_KERNEL) {
930 if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
931 dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
932 nir_rounding_mode_undef);
933 } else {
934 dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
935 src_type, dst_type,
936 opts.rounding_mode, opts.saturate);
937 }
938 } else {
939 vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
940 dst_type != nir_type_float16,
941 "Rounding modes are only allowed on conversions to "
942 "16-bit float types");
943 dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
944 opts.rounding_mode);
945 }
946 break;
947 }
948
949 case SpvOpBitFieldInsert:
950 case SpvOpBitFieldSExtract:
951 case SpvOpBitFieldUExtract:
952 case SpvOpShiftLeftLogical:
953 case SpvOpShiftRightArithmetic:
954 case SpvOpShiftRightLogical: {
955 bool swap;
956 bool exact;
957 unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
958 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
959 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
960 src0_bit_size, dst_bit_size);
961
962 assert(!exact);
963
964 assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
965 op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
966 op == nir_op_ibitfield_extract);
967
968 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
969 unsigned src_bit_size =
970 nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
971 if (src_bit_size == 0)
972 continue;
973 if (src_bit_size != src[i]->bit_size) {
974 assert(src_bit_size == 32);
975 /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
976 * supported by the NIR instructions. See discussion here:
977 *
978 * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
979 */
980 src[i] = nir_u2u32(&b->nb, src[i]);
981 }
982 }
983 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
984 break;
985 }
986
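/* SignBitSet: shift the sign bit down to bit zero and convert to a bool. */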
987 case SpvOpSignBitSet:
988 dest->def = nir_i2b(&b->nb,
989 nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
990 break;
991
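/* nir_find_lsb() returns -1 for a zero input; interpreted as unsigned that
 * is the maximum value, so the umin clamps ctz(0) to the required 32.
 */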
992 case SpvOpUCountTrailingZerosINTEL:
993 dest->def = nir_umin(&b->nb,
994 nir_find_lsb(&b->nb, src[0]),
995 nir_imm_int(&b->nb, 32u));
996 break;
997
998 case SpvOpBitCount: {
999 /* bit_count always returns int32, but the SPIR-V opcode just says the return
1000 * value needs to be big enough to store the number of bits.
1001 */
1002 dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
1003 break;
1004 }
1005
1006 case SpvOpSDotKHR:
1007 case SpvOpUDotKHR:
1008 case SpvOpSUDotKHR:
1009 case SpvOpSDotAccSatKHR:
1010 case SpvOpUDotAccSatKHR:
1011 case SpvOpSUDotAccSatKHR:
1012 unreachable("Should have called vtn_handle_integer_dot instead.");
1013
1014 default: {
1015 bool swap;
1016 bool exact;
1017 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
1018 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
1019 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
1020 &exact,
1021 src_bit_size, dst_bit_size);
1022
1023 if (swap) {
1024 nir_def *tmp = src[0];
1025 src[0] = src[1];
1026 src[1] = tmp;
1027 }
1028
1029 switch (op) {
1030 case nir_op_ishl:
1031 case nir_op_ishr:
1032 case nir_op_ushr:
1033 if (src[1]->bit_size != 32)
1034 src[1] = nir_u2u32(&b->nb, src[1]);
1035 break;
1036 default:
1037 break;
1038 }
1039
1040 const bool save_exact = b->nb.exact;
1041
1042 if (exact)
1043 b->nb.exact = true;
1044
1045 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
1046
1047 b->nb.exact = save_exact;
1048 break;
1049 } /* default */
1050 }
1051
1052 switch (opcode) {
1053 case SpvOpIAdd:
1054 case SpvOpIMul:
1055 case SpvOpISub:
1056 case SpvOpShiftLeftLogical:
1057 case SpvOpSNegate: {
1058 nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
1059 vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
1060 break;
1061 }
1062 default:
1063 /* Do nothing. */
1064 break;
1065 }
1066
1067 if (mediump_16bit)
1068 vtn_mediump_upconvert_value(b, dest);
1069 vtn_push_ssa_value(b, w[2], dest);
1070
1071 b->nb.exact = b->exact;
1072 }
1073
1074 void
1075 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1076 const uint32_t *w, unsigned count)
1077 {
1078 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1079 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1080 const unsigned dest_size = glsl_get_bit_size(dest_type);
1081
1082 vtn_handle_no_contraction(b, dest_val);
1083
1084 /* Collect the various SSA sources.
1085 *
1086 * Due to the optional "Packed Vector Format" field, determine number of
1087 * inputs from the opcode. This differs from vtn_handle_alu.
1088 */
1089 const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1090 opcode == SpvOpUDotAccSatKHR ||
1091 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1092
1093 vtn_assert(count >= num_inputs + 3);
1094
1095 struct vtn_ssa_value *vtn_src[3] = { NULL, };
1096 nir_def *src[3] = { NULL, };
1097
1098 for (unsigned i = 0; i < num_inputs; i++) {
1099 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1100 src[i] = vtn_src[i]->def;
1101
1102 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1103 }
1104
1105 /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1106 * the SPV_KHR_integer_dot_product spec says:
1107 *
1108 * _Vector 1_ and _Vector 2_ must have the same type.
1109 *
1110 * The practical requirement is the same bit-size and the same number of
1111 * components.
1112 */
1113 vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1114 glsl_get_bit_size(vtn_src[1]->type) ||
1115 glsl_get_vector_elements(vtn_src[0]->type) !=
1116 glsl_get_vector_elements(vtn_src[1]->type),
1117 "Vector 1 and vector 2 source of opcode %s must have the same "
1118 "type",
1119 spirv_op_to_string(opcode));
1120
1121 if (num_inputs == 3) {
1122 /* The SPV_KHR_integer_dot_product spec says:
1123 *
1124 * The type of Accumulator must be the same as Result Type.
1125 *
1126 * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1127 * types (far below) assumes these types have the same size.
1128 */
1129 vtn_fail_if(dest_type != vtn_src[2]->type,
1130 "Accumulator type must be the same as Result Type for "
1131 "opcode %s",
1132 spirv_op_to_string(opcode));
1133 }
1134
1135 unsigned packed_bit_size = 8;
1136 if (glsl_type_is_vector(vtn_src[0]->type)) {
1137 /* FINISHME: Is this actually as good or better for platforms that don't
1138 * have the special instructions (i.e., one or both of has_dot_4x8 or
1139 * has_sudot_4x8 is false)?
1140 */
1141 if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1142 glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1143 glsl_get_bit_size(dest_type) <= 32) {
1144 src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1145 src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1146 } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1147 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1148 glsl_get_bit_size(dest_type) <= 32 &&
1149 opcode != SpvOpSUDotKHR &&
1150 opcode != SpvOpSUDotAccSatKHR) {
1151 src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1152 src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1153 packed_bit_size = 16;
1154 }
1155 } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1156 glsl_type_is_32bit(vtn_src[0]->type)) {
1157 /* The SPV_KHR_integer_dot_product spec says:
1158 *
1159 * When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1160 * Vector Format_ must be specified to select how the integers are to
1161 * be interpreted as vectors.
1162 *
1163 * The "Packed Vector Format" value follows the last input.
1164 */
1165 vtn_assert(count == (num_inputs + 4));
1166 const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1167 vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1168 "Unsupported vector packing format %d for opcode %s",
1169 pack_format, spirv_op_to_string(opcode));
1170 } else {
1171 vtn_fail_with_opcode("Invalid source types.", opcode);
1172 }
1173
1174 nir_def *dest = NULL;
1175
1176 if (src[0]->num_components > 1) {
1177 nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
1178 nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);
1179
1180 switch (opcode) {
1181 case SpvOpSDotKHR:
1182 case SpvOpSDotAccSatKHR:
1183 src0_conversion = nir_i2iN;
1184 src1_conversion = nir_i2iN;
1185 break;
1186
1187 case SpvOpUDotKHR:
1188 case SpvOpUDotAccSatKHR:
1189 src0_conversion = nir_u2uN;
1190 src1_conversion = nir_u2uN;
1191 break;
1192
1193 case SpvOpSUDotKHR:
1194 case SpvOpSUDotAccSatKHR:
1195 src0_conversion = nir_i2iN;
1196 src1_conversion = nir_u2uN;
1197 break;
1198
1199 default:
1200 unreachable("Invalid opcode.");
1201 }
1202
1203 /* The SPV_KHR_integer_dot_product spec says:
1204 *
1205 * All components of the input vectors are sign-extended to the bit
1206 * width of the result's type. The sign-extended input vectors are
1207 * then multiplied component-wise and all components of the vector
1208 * resulting from the component-wise multiplication are added
1209 * together. The resulting value will equal the low-order N bits of
1210 * the correct result R, where N is the result width and R is
1211 * computed with enough precision to avoid overflow and underflow.
1212 */
1213 const unsigned vector_components =
1214 glsl_get_vector_elements(vtn_src[0]->type);
1215
1216 for (unsigned i = 0; i < vector_components; i++) {
1217 nir_def *const src0 =
1218 src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);
1219
1220 nir_def *const src1 =
1221 src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);
1222
1223 nir_def *const mul_result = nir_imul(&b->nb, src0, src1);
1224
1225 dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1226 }
1227
1228 if (num_inputs == 3) {
1229 /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1230 *
1231 * Signed integer dot product of _Vector 1_ and _Vector 2_ and
1232 * signed saturating addition of the result with _Accumulator_.
1233 *
1234 * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1235 *
1236 * Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1237 * unsigned saturating addition of the result with _Accumulator_.
1238 *
1239 * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1240 *
1241 * Mixed-signedness integer dot product of _Vector 1_ and _Vector
1242 * 2_ and signed saturating addition of the result with
1243 * _Accumulator_.
1244 */
1245 dest = (opcode == SpvOpUDotAccSatKHR)
1246 ? nir_uadd_sat(&b->nb, dest, src[2])
1247 : nir_iadd_sat(&b->nb, dest, src[2]);
1248 }
1249 } else {
1250 assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1251 assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1252
1253 nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1254 bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1255 opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1256
1257 if (packed_bit_size == 16) {
1258 switch (opcode) {
1259 case SpvOpSDotKHR:
1260 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1261 break;
1262 case SpvOpUDotKHR:
1263 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1264 break;
1265 case SpvOpSDotAccSatKHR:
1266 if (dest_size == 32)
1267 dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1268 else
1269 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1270 break;
1271 case SpvOpUDotAccSatKHR:
1272 if (dest_size == 32)
1273 dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1274 else
1275 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1276 break;
1277 default:
1278 unreachable("Invalid opcode.");
1279 }
1280 } else {
1281 switch (opcode) {
1282 case SpvOpSDotKHR:
1283 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1284 break;
1285 case SpvOpUDotKHR:
1286 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1287 break;
1288 case SpvOpSUDotKHR:
1289 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1290 break;
1291 case SpvOpSDotAccSatKHR:
1292 if (dest_size == 32)
1293 dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1294 else
1295 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1296 break;
1297 case SpvOpUDotAccSatKHR:
1298 if (dest_size == 32)
1299 dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1300 else
1301 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1302 break;
1303 case SpvOpSUDotAccSatKHR:
1304 if (dest_size == 32)
1305 dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1306 else
1307 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1308 break;
1309 default:
1310 unreachable("Invalid opcode.");
1311 }
1312 }
1313
1314 if (dest_size != 32) {
1315 /* When the accumulator is 32-bits, a NIR dot-product with saturate
1316 * is generated above. In all other cases a regular dot-product is
1317 * generated above, and separate addition with saturate is generated
1318 * here.
1319 *
1320 * The SPV_KHR_integer_dot_product spec says:
1321 *
1322 * If any of the multiplications or additions, with the exception
1323 * of the final accumulation, overflow or underflow, the result of
1324 * the instruction is undefined.
1325 *
1326 * Therefore it is safe to cast the dot-product result down to the
1327 * size of the accumulator before doing the addition. Since the
1328 * result of the dot-product cannot overflow 32-bits, this is also
1329 * safe to cast up.
1330 */
1331 if (num_inputs == 3) {
1332 dest = is_signed
1333 ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
1334 : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
1335 } else {
1336 dest = is_signed
1337 ? nir_i2iN(&b->nb, dest, dest_size)
1338 : nir_u2uN(&b->nb, dest, dest_size);
1339 }
1340 }
1341 }
1342
1343 vtn_push_nir_ssa(b, w[2], dest);
1344
1345 b->nb.exact = b->exact;
1346 }
1347
1348 void
1349 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1350 {
1351 vtn_assert(count == 4);
1352 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1353 *
1354 * "If Result Type has the same number of components as Operand, they
1355 * must also have the same component width, and results are computed per
1356 * component.
1357 *
1358 * If Result Type has a different number of components than Operand, the
1359 * total number of bits in Result Type must equal the total number of
1360 * bits in Operand. Let L be the type, either Result Type or Operand’s
1361 * type, that has the larger number of components. Let S be the other
1362 * type, with the smaller number of components. The number of components
1363 * in L must be an integer multiple of the number of components in S.
1364 * The first component (that is, the only or lowest-numbered component)
1365 * of S maps to the first components of L, and so on, up to the last
1366 * component of S mapping to the last components of L. Within this
1367 * mapping, any single component of S (mapping to multiple components of
1368 * L) maps its lower-ordered bits to the lower-numbered components of L."
1369 */
1370
1371 struct vtn_type *type = vtn_get_type(b, w[1]);
1372 if (type->base_type == vtn_base_type_cooperative_matrix) {
1373 vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
1374 return;
1375 }
1376
1377 struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
1378
1379 vtn_fail_if(src->num_components * src->bit_size !=
1380 glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1381 "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
1382 "total number of bits", w[3], w[2]);
1383 nir_def *val =
1384 nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1385 vtn_push_nir_ssa(b, w[2], val);
1386 }
1387