/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

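/* Multiply two (possibly wrapped) matrices column by column:
 * dest[i] = sum over j of src0[j] * src1[i][j], built as one fmul followed by
 * a chain of ffma instructions. If both operands arrive transposed, we use
 * transpose(A) * transpose(B) = transpose(B * A) to avoid materializing the
 * transposes up front.
 */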
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   for (unsigned i = 0; i < src1_columns; i++) {
      /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
      dest->elems[i]->def =
         nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                  nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
      for (int j = src0_columns - 2; j >= 0; j--) {
         dest->elems[i]->def =
            nir_ffma(&b->nb, src0->elems[j]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, j),
                     dest->elems[i]->def);
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

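/* Scale every column of the matrix by a single scalar, using imul for
 * integer base types and fmul for floating-point ones.
 */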
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

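/* Narrow a 32-bit value to 16 bits for RelaxedPrecision (mediump) handling.
 * Values that are already 16-bit, and booleans, are returned unchanged.
 */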
nir_def *
vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
{
   if (def->bit_size == 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2fmp(&b->nb, def);
   case GLSL_TYPE_INT:
   case GLSL_TYPE_UINT:
      return nir_i2imp(&b->nb, def);
   /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
    * OpLogical* operations (which is forbidden by spec).
    */
   case GLSL_TYPE_BOOL:
      return def;
   default:
      unreachable("bad relaxed precision input type");
   }
}

struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
{
   if (!src)
      return src;

   struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);

   if (src->transposed) {
      srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
   } else {
      enum glsl_base_type base_type = glsl_get_base_type(src->type);

      if (glsl_type_is_vector_or_scalar(src->type)) {
         srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
      } else {
         assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
         for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
            srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
      }
   }

   return srcmp;
}

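/* Handle the ALU opcodes that can take a matrix operand. Scalar and vector
 * cases are dealt with in vtn_handle_alu below.
 */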
static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                      src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

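/* Map a SPIR-V conversion opcode to the NIR base type of its source operand;
 * convert_op_dst_type below does the same for the destination.
 */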
static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

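/* Translate a SPIR-V ALU opcode to the corresponding nir_op. *swap requests
 * that the caller swap the first two sources, and *exact requests that the
 * resulting instruction be marked exact (used for the ordered/unordered
 * float comparisons).
 */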
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
   case SpvOpIAverageINTEL: return nir_op_ihadd;
   case SpvOpUAverageINTEL: return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
   case SpvOpISubSatINTEL: return nir_op_isub_sat;
   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
    */
   case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
   case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
   case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric: return nir_op_mov;
   case SpvOpGenericCastToPtr: return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   case SpvOpIsNormal: return nir_op_fisnormal;
   case SpvOpIsFinite: return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
                      UNUSED int member, const struct vtn_decoration *dec,
                      UNUSED void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}

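/* Conversion options gathered from the FPRoundingMode and
 * SaturatedConversion decorations on a conversion instruction.
 */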
struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
                       UNUSED int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}

static void
handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
               UNUSED int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

static void
vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
                                  struct vtn_value *val, int member,
                                  const struct vtn_decoration *dec, void *void_ctx)
{
   bool *relaxed_precision = void_ctx;
   switch (dec->decoration) {
   case SpvDecorationRelaxedPrecision:
      *relaxed_precision = true;
      break;

   default:
      break;
   }
}

bool
vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
{
   bool result = false;
   vtn_foreach_decoration(b, val,
                          vtn_value_is_relaxed_precision_cb, &result);
   return result;
}

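/* Decide whether a RelaxedPrecision result should be evaluated at 16 bits.
 * Derivative opcodes are only narrowed when mediump_16bit_derivatives is set.
 */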
static bool
vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
{
   if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
      return false;

   switch (opcode) {
   case SpvOpDPdx:
   case SpvOpDPdy:
   case SpvOpDPdxFine:
   case SpvOpDPdyFine:
   case SpvOpDPdxCoarse:
   case SpvOpDPdyCoarse:
   case SpvOpFwidth:
   case SpvOpFwidthFine:
   case SpvOpFwidthCoarse:
      return b->options->mediump_16bit_derivatives;
   default:
      return true;
   }
}

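/* Widen a 16-bit mediump result back to 32 bits so it matches the SPIR-V
 * result type.
 */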
static nir_def *
vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
{
   if (def->bit_size != 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2f32(&b->nb, def);
   case GLSL_TYPE_INT:
      return nir_i2i32(&b->nb, def);
   case GLSL_TYPE_UINT:
      return nir_u2u32(&b->nb, def);
   default:
      unreachable("bad relaxed precision output type");
   }
}

void
vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
{
   enum glsl_base_type base_type = glsl_get_base_type(value->type);

   if (glsl_type_is_vector_or_scalar(value->type)) {
      value->def = vtn_mediump_upconvert(b, base_type, value->def);
   } else {
      for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
         value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
   }
}

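/* Main entry point for SPIR-V ALU instructions: gathers the SSA sources,
 * applies mediump down/upconversion when requested, dispatches matrix
 * operations to vtn_handle_matrix_alu, and lowers everything else either
 * through a dedicated case below or via vtn_nir_alu_op_for_spirv_opcode.
 */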
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   if (glsl_type_is_cmat(dest_type)) {
      vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
      return;
   }

   vtn_handle_no_contraction(b, dest_val);
   bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      if (mediump_16bit)
         vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
   }

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);

      if (mediump_16bit)
         vtn_mediump_upconvert_value(b, dest);

      vtn_push_ssa_value(b, w[2], dest);
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                          nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as !(a < b || b < a). If one or both
       * of the source are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a == b) operation. Otherwise, the optimizer can transform
       * whatever is left to !(a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_ior(&b->nb,
                 nir_feq(&b->nb, src[0], src[1]),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
      switch (op) {
      case nir_op_fge: op = nir_op_flt; break;
      case nir_op_flt: op = nir_op_fge; break;
      default: unreachable("Impossible opcode.");
      }

      dest->def =
         nir_inot(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as (a < b || b < a). If one or both
       * of the source are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a != b) operation. Otherwise, the optimizer can transform
       * whatever is left to (a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_iand(&b->nb,
                  nir_fneu(&b->nb, src[0], src[1]),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = src[0]->bit_size;
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
                                         nir_rounding_mode_undef);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
                                      opts.rounding_mode);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(!exact);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
             * supported by the NIR instructions. See discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
                          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the return
       * value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
      break;
   }

   case SpvOpSDotKHR:
   case SpvOpUDotKHR:
   case SpvOpSUDotKHR:
   case SpvOpSDotAccSatKHR:
   case SpvOpUDotAccSatKHR:
   case SpvOpSUDotAccSatKHR:
      unreachable("Should have called vtn_handle_integer_dot instead.");

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   if (mediump_16bit)
      vtn_mediump_upconvert_value(b, dest);
   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

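/* Handle the SPV_KHR_integer_dot_product opcodes. Packed 4x8 and 2x16
 * sources are lowered to the corresponding NIR dot-product instructions;
 * other vector sources are expanded to a per-component multiply/add chain.
 */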
void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Due to the optional "Packed Vector Format" field, determine number of
    * inputs from the opcode. This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "opcode %s",
                  spirv_op_to_string(opcode));
   }

   unsigned packed_bit_size = 8;
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_def *dest = NULL;

   if (src[0]->num_components > 1) {
      nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
      nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);

      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion = nir_i2iN;
         src1_conversion = nir_i2iN;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion = nir_u2uN;
         src1_conversion = nir_u2uN;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion = nir_i2iN;
         src1_conversion = nir_u2uN;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       */
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_def *const src0 =
            src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);

         nir_def *const src1 =
            src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);

         nir_def *const mul_result = nir_imul(&b->nb, src0, src1);

         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above. In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition. Since the
          * result of the dot-product cannot overflow 32-bits, this is also
          * safe to cast up.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2iN(&b->nb, dest, dest_size)
               : nir_u2uN(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_get_type(b, w[1]);
   if (type->base_type == vtn_base_type_cooperative_matrix) {
      vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
      return;
   }

   struct nir_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
               "total number of bits", w[3], w[2]);
   nir_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}