1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27
28 /*
29 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30 * definition. But for matrix multiplies, we want to do one routine for
31 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32 * with one column. So we "wrap" these things, and unwrap the result before we
33 * send it off.
34 */
35
36 static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39 if (val == NULL)
40 return NULL;
41
42 if (glsl_type_is_matrix(val->type))
43 return val;
44
45 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46 dest->type = glsl_get_bare_type(val->type);
47 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48 dest->elems[0] = val;
49
50 return dest;
51 }
52
53 static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56 if (glsl_type_is_matrix(val->type))
57 return val;
58
59 return val->elems[0];
60 }
61
62 static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder * b,struct vtn_ssa_value * _src0,struct vtn_ssa_value * _src1)63 matrix_multiply(struct vtn_builder *b,
64 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66
67 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71
72 unsigned src0_rows = glsl_get_vector_elements(src0->type);
73 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75
76 const struct glsl_type *dest_type;
77 if (src1_columns > 1) {
78 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79 src0_rows, src1_columns);
80 } else {
81 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82 }
83 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84
85 dest = wrap_matrix(b, dest);
86
87 bool transpose_result = false;
88 if (src0_transpose && src1_transpose) {
89 /* transpose(A) * transpose(B) = transpose(B * A) */
90 src1 = src0_transpose;
91 src0 = src1_transpose;
92 src0_transpose = NULL;
93 src1_transpose = NULL;
94 transpose_result = true;
95 }
96
97 if (src0_transpose && !src1_transpose &&
98 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99 /* We already have the rows of src0 and the columns of src1 available,
100 * so we can just take the dot product of each row with each column to
101 * get the result.
102 */
103
104 for (unsigned i = 0; i < src1_columns; i++) {
105 nir_ssa_def *vec_src[4];
106 for (unsigned j = 0; j < src0_rows; j++) {
107 vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108 src1->elems[i]->def);
109 }
110 dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111 }
112 } else {
113 /* We don't handle the case where src1 is transposed but not src0, since
114 * the general case only uses individual components of src1 so the
115 * optimizer should chew through the transpose we emitted for src1.
116 */
117
118 for (unsigned i = 0; i < src1_columns; i++) {
119 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
120 dest->elems[i]->def =
121 nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122 nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123 for (int j = src0_columns - 2; j >= 0; j--) {
124 dest->elems[i]->def =
125 nir_ffma(&b->nb, src0->elems[j]->def,
126 nir_channel(&b->nb, src1->elems[i]->def, j),
127 dest->elems[i]->def);
128 }
129 }
130 }
131
132 dest = unwrap_matrix(dest);
133
134 if (transpose_result)
135 dest = vtn_ssa_transpose(b, dest);
136
137 return dest;
138 }
139
140 static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_ssa_def * scalar)141 mat_times_scalar(struct vtn_builder *b,
142 struct vtn_ssa_value *mat,
143 nir_ssa_def *scalar)
144 {
145 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147 if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149 else
150 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151 }
152
153 return dest;
154 }
155
156 static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)157 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
158 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
159 {
160 switch (opcode) {
161 case SpvOpFNegate: {
162 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
163 unsigned cols = glsl_get_matrix_columns(src0->type);
164 for (unsigned i = 0; i < cols; i++)
165 dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
166 return dest;
167 }
168
169 case SpvOpFAdd: {
170 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
171 unsigned cols = glsl_get_matrix_columns(src0->type);
172 for (unsigned i = 0; i < cols; i++)
173 dest->elems[i]->def =
174 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
175 return dest;
176 }
177
178 case SpvOpFSub: {
179 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
180 unsigned cols = glsl_get_matrix_columns(src0->type);
181 for (unsigned i = 0; i < cols; i++)
182 dest->elems[i]->def =
183 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
184 return dest;
185 }
186
187 case SpvOpTranspose:
188 return vtn_ssa_transpose(b, src0);
189
190 case SpvOpMatrixTimesScalar:
191 if (src0->transposed) {
192 return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193 src1->def));
194 } else {
195 return mat_times_scalar(b, src0, src1->def);
196 }
197 break;
198
199 case SpvOpVectorTimesMatrix:
200 case SpvOpMatrixTimesVector:
201 case SpvOpMatrixTimesMatrix:
202 if (opcode == SpvOpVectorTimesMatrix) {
203 return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204 } else {
205 return matrix_multiply(b, src0, src1);
206 }
207 break;
208
209 default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
210 }
211 }
212
213 static nir_alu_type
convert_op_src_type(SpvOp opcode)214 convert_op_src_type(SpvOp opcode)
215 {
216 switch (opcode) {
217 case SpvOpFConvert:
218 case SpvOpConvertFToS:
219 case SpvOpConvertFToU:
220 return nir_type_float;
221 case SpvOpSConvert:
222 case SpvOpConvertSToF:
223 case SpvOpSatConvertSToU:
224 return nir_type_int;
225 case SpvOpUConvert:
226 case SpvOpConvertUToF:
227 case SpvOpSatConvertUToS:
228 return nir_type_uint;
229 default:
230 unreachable("Unhandled conversion op");
231 }
232 }
233
234 static nir_alu_type
convert_op_dst_type(SpvOp opcode)235 convert_op_dst_type(SpvOp opcode)
236 {
237 switch (opcode) {
238 case SpvOpFConvert:
239 case SpvOpConvertSToF:
240 case SpvOpConvertUToF:
241 return nir_type_float;
242 case SpvOpSConvert:
243 case SpvOpConvertFToS:
244 case SpvOpSatConvertUToS:
245 return nir_type_int;
246 case SpvOpUConvert:
247 case SpvOpConvertFToU:
248 case SpvOpSatConvertSToU:
249 return nir_type_uint;
250 default:
251 unreachable("Unhandled conversion op");
252 }
253 }
254
255 nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder * b,SpvOp opcode,bool * swap,bool * exact,unsigned src_bit_size,unsigned dst_bit_size)256 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
257 SpvOp opcode, bool *swap, bool *exact,
258 unsigned src_bit_size, unsigned dst_bit_size)
259 {
260 /* Indicates that the first two arguments should be swapped. This is
261 * used for implementing greater-than and less-than-or-equal.
262 */
263 *swap = false;
264
265 *exact = false;
266
267 switch (opcode) {
268 case SpvOpSNegate: return nir_op_ineg;
269 case SpvOpFNegate: return nir_op_fneg;
270 case SpvOpNot: return nir_op_inot;
271 case SpvOpIAdd: return nir_op_iadd;
272 case SpvOpFAdd: return nir_op_fadd;
273 case SpvOpISub: return nir_op_isub;
274 case SpvOpFSub: return nir_op_fsub;
275 case SpvOpIMul: return nir_op_imul;
276 case SpvOpFMul: return nir_op_fmul;
277 case SpvOpUDiv: return nir_op_udiv;
278 case SpvOpSDiv: return nir_op_idiv;
279 case SpvOpFDiv: return nir_op_fdiv;
280 case SpvOpUMod: return nir_op_umod;
281 case SpvOpSMod: return nir_op_imod;
282 case SpvOpFMod: return nir_op_fmod;
283 case SpvOpSRem: return nir_op_irem;
284 case SpvOpFRem: return nir_op_frem;
285
286 case SpvOpShiftRightLogical: return nir_op_ushr;
287 case SpvOpShiftRightArithmetic: return nir_op_ishr;
288 case SpvOpShiftLeftLogical: return nir_op_ishl;
289 case SpvOpLogicalOr: return nir_op_ior;
290 case SpvOpLogicalEqual: return nir_op_ieq;
291 case SpvOpLogicalNotEqual: return nir_op_ine;
292 case SpvOpLogicalAnd: return nir_op_iand;
293 case SpvOpLogicalNot: return nir_op_inot;
294 case SpvOpBitwiseOr: return nir_op_ior;
295 case SpvOpBitwiseXor: return nir_op_ixor;
296 case SpvOpBitwiseAnd: return nir_op_iand;
297 case SpvOpSelect: return nir_op_bcsel;
298 case SpvOpIEqual: return nir_op_ieq;
299
300 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
301 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
302 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
303 case SpvOpBitReverse: return nir_op_bitfield_reverse;
304
305 case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
306 /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
307 case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
308 case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
309 case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
310 case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
311 case SpvOpIAverageINTEL: return nir_op_ihadd;
312 case SpvOpUAverageINTEL: return nir_op_uhadd;
313 case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
314 case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
315 case SpvOpISubSatINTEL: return nir_op_isub_sat;
316 case SpvOpUSubSatINTEL: return nir_op_usub_sat;
317 case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
318 case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
319
320 /* The ordered / unordered operators need special implementation besides
321 * the logical operator to use since they also need to check if operands are
322 * ordered.
323 */
324 case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
325 case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
326 case SpvOpINotEqual: return nir_op_ine;
327 case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
328 case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
329 case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
330 case SpvOpULessThan: return nir_op_ult;
331 case SpvOpSLessThan: return nir_op_ilt;
332 case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
333 case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
334 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
335 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
336 case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
337 case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
338 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
339 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
340 case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
341 case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
342 case SpvOpUGreaterThanEqual: return nir_op_uge;
343 case SpvOpSGreaterThanEqual: return nir_op_ige;
344 case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
345 case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;
346
347 /* Conversions: */
348 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
349 case SpvOpUConvert:
350 case SpvOpConvertFToU:
351 case SpvOpConvertFToS:
352 case SpvOpConvertSToF:
353 case SpvOpConvertUToF:
354 case SpvOpSConvert:
355 case SpvOpFConvert: {
356 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
357 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
358 return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
359 }
360
361 case SpvOpPtrCastToGeneric: return nir_op_mov;
362 case SpvOpGenericCastToPtr: return nir_op_mov;
363
364 /* Derivatives: */
365 case SpvOpDPdx: return nir_op_fddx;
366 case SpvOpDPdy: return nir_op_fddy;
367 case SpvOpDPdxFine: return nir_op_fddx_fine;
368 case SpvOpDPdyFine: return nir_op_fddy_fine;
369 case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
370 case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
371
372 case SpvOpIsNormal: return nir_op_fisnormal;
373 case SpvOpIsFinite: return nir_op_fisfinite;
374
375 default:
376 vtn_fail("No NIR equivalent: %u", opcode);
377 }
378 }
379
380 static void
handle_no_contraction(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,UNUSED void * _void)381 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
382 UNUSED int member, const struct vtn_decoration *dec,
383 UNUSED void *_void)
384 {
385 vtn_assert(dec->scope == VTN_DEC_DECORATION);
386 if (dec->decoration != SpvDecorationNoContraction)
387 return;
388
389 b->nb.exact = true;
390 }
391
392 void
vtn_handle_no_contraction(struct vtn_builder * b,struct vtn_value * val)393 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
394 {
395 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
396 }
397
398 nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder * b,SpvFPRoundingMode mode)399 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
400 {
401 switch (mode) {
402 case SpvFPRoundingModeRTE:
403 return nir_rounding_mode_rtne;
404 case SpvFPRoundingModeRTZ:
405 return nir_rounding_mode_rtz;
406 case SpvFPRoundingModeRTP:
407 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
408 "FPRoundingModeRTP is only supported in kernels");
409 return nir_rounding_mode_ru;
410 case SpvFPRoundingModeRTN:
411 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
412 "FPRoundingModeRTN is only supported in kernels");
413 return nir_rounding_mode_rd;
414 default:
415 vtn_fail("Unsupported rounding mode: %s",
416 spirv_fproundingmode_to_string(mode));
417 break;
418 }
419 }
420
421 struct conversion_opts {
422 nir_rounding_mode rounding_mode;
423 bool saturate;
424 };
425
426 static void
handle_conversion_opts(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _opts)427 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
428 UNUSED int member,
429 const struct vtn_decoration *dec, void *_opts)
430 {
431 struct conversion_opts *opts = _opts;
432
433 switch (dec->decoration) {
434 case SpvDecorationFPRoundingMode:
435 opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
436 break;
437
438 case SpvDecorationSaturatedConversion:
439 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
440 "Saturated conversions are only allowed in kernels");
441 opts->saturate = true;
442 break;
443
444 default:
445 break;
446 }
447 }
448
449 static void
handle_no_wrap(UNUSED struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _alu)450 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
451 UNUSED int member,
452 const struct vtn_decoration *dec, void *_alu)
453 {
454 nir_alu_instr *alu = _alu;
455 switch (dec->decoration) {
456 case SpvDecorationNoSignedWrap:
457 alu->no_signed_wrap = true;
458 break;
459 case SpvDecorationNoUnsignedWrap:
460 alu->no_unsigned_wrap = true;
461 break;
462 default:
463 /* Do nothing. */
464 break;
465 }
466 }
467
468 void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)469 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
470 const uint32_t *w, unsigned count)
471 {
472 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
473 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
474
475 vtn_handle_no_contraction(b, dest_val);
476
477 /* Collect the various SSA sources */
478 const unsigned num_inputs = count - 3;
479 struct vtn_ssa_value *vtn_src[4] = { NULL, };
480 for (unsigned i = 0; i < num_inputs; i++)
481 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
482
483 if (glsl_type_is_matrix(vtn_src[0]->type) ||
484 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
485 vtn_push_ssa_value(b, w[2],
486 vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
487 b->nb.exact = b->exact;
488 return;
489 }
490
491 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
492 nir_ssa_def *src[4] = { NULL, };
493 for (unsigned i = 0; i < num_inputs; i++) {
494 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
495 src[i] = vtn_src[i]->def;
496 }
497
498 switch (opcode) {
499 case SpvOpAny:
500 dest->def = nir_bany(&b->nb, src[0]);
501 break;
502
503 case SpvOpAll:
504 dest->def = nir_ball(&b->nb, src[0]);
505 break;
506
507 case SpvOpOuterProduct: {
508 for (unsigned i = 0; i < src[1]->num_components; i++) {
509 dest->elems[i]->def =
510 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
511 }
512 break;
513 }
514
515 case SpvOpDot:
516 dest->def = nir_fdot(&b->nb, src[0], src[1]);
517 break;
518
519 case SpvOpIAddCarry:
520 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
521 dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
522 dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
523 break;
524
525 case SpvOpISubBorrow:
526 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
527 dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
528 dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
529 break;
530
531 case SpvOpUMulExtended: {
532 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
533 nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
534 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
535 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
536 break;
537 }
538
539 case SpvOpSMulExtended: {
540 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
541 nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
542 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
543 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
544 break;
545 }
546
547 case SpvOpFwidth:
548 dest->def = nir_fadd(&b->nb,
549 nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
550 nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
551 break;
552 case SpvOpFwidthFine:
553 dest->def = nir_fadd(&b->nb,
554 nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
555 nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
556 break;
557 case SpvOpFwidthCoarse:
558 dest->def = nir_fadd(&b->nb,
559 nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
560 nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
561 break;
562
563 case SpvOpVectorTimesScalar:
564 /* The builder will take care of splatting for us. */
565 dest->def = nir_fmul(&b->nb, src[0], src[1]);
566 break;
567
568 case SpvOpIsNan: {
569 const bool save_exact = b->nb.exact;
570
571 b->nb.exact = true;
572 dest->def = nir_fneu(&b->nb, src[0], src[0]);
573 b->nb.exact = save_exact;
574 break;
575 }
576
577 case SpvOpOrdered: {
578 const bool save_exact = b->nb.exact;
579
580 b->nb.exact = true;
581 dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
582 nir_feq(&b->nb, src[1], src[1]));
583 b->nb.exact = save_exact;
584 break;
585 }
586
587 case SpvOpUnordered: {
588 const bool save_exact = b->nb.exact;
589
590 b->nb.exact = true;
591 dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
592 nir_fneu(&b->nb, src[1], src[1]));
593 b->nb.exact = save_exact;
594 break;
595 }
596
597 case SpvOpIsInf: {
598 nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
599 dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
600 break;
601 }
602
603 case SpvOpFUnordEqual: {
604 const bool save_exact = b->nb.exact;
605
606 b->nb.exact = true;
607
608 /* This could also be implemented as !(a < b || b < a). If one or both
609 * of the source are numbers, later optimization passes can easily
610 * eliminate the isnan() checks. This may trim the sequence down to a
611 * single (a == b) operation. Otherwise, the optimizer can transform
612 * whatever is left to !(a < b || b < a). Since some applications will
613 * open-code this sequence, these optimizations are needed anyway.
614 */
615 dest->def =
616 nir_ior(&b->nb,
617 nir_feq(&b->nb, src[0], src[1]),
618 nir_ior(&b->nb,
619 nir_fneu(&b->nb, src[0], src[0]),
620 nir_fneu(&b->nb, src[1], src[1])));
621
622 b->nb.exact = save_exact;
623 break;
624 }
625
626 case SpvOpFUnordLessThan:
627 case SpvOpFUnordGreaterThan:
628 case SpvOpFUnordLessThanEqual:
629 case SpvOpFUnordGreaterThanEqual: {
630 bool swap;
631 bool unused_exact;
632 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
633 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
634 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
635 &unused_exact,
636 src_bit_size, dst_bit_size);
637
638 if (swap) {
639 nir_ssa_def *tmp = src[0];
640 src[0] = src[1];
641 src[1] = tmp;
642 }
643
644 const bool save_exact = b->nb.exact;
645
646 b->nb.exact = true;
647
648 /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
649 switch (op) {
650 case nir_op_fge: op = nir_op_flt; break;
651 case nir_op_flt: op = nir_op_fge; break;
652 default: unreachable("Impossible opcode.");
653 }
654
655 dest->def =
656 nir_inot(&b->nb,
657 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
658
659 b->nb.exact = save_exact;
660 break;
661 }
662
663 case SpvOpLessOrGreater:
664 case SpvOpFOrdNotEqual: {
665 /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
666 * from the ALU will probably already be false if the operands are not
667 * ordered so we don’t need to handle it specially.
668 */
669 const bool save_exact = b->nb.exact;
670
671 b->nb.exact = true;
672
673 /* This could also be implemented as (a < b || b < a). If one or both
674 * of the source are numbers, later optimization passes can easily
675 * eliminate the isnan() checks. This may trim the sequence down to a
676 * single (a != b) operation. Otherwise, the optimizer can transform
677 * whatever is left to (a < b || b < a). Since some applications will
678 * open-code this sequence, these optimizations are needed anyway.
679 */
680 dest->def =
681 nir_iand(&b->nb,
682 nir_fneu(&b->nb, src[0], src[1]),
683 nir_iand(&b->nb,
684 nir_feq(&b->nb, src[0], src[0]),
685 nir_feq(&b->nb, src[1], src[1])));
686
687 b->nb.exact = save_exact;
688 break;
689 }
690
691 case SpvOpUConvert:
692 case SpvOpConvertFToU:
693 case SpvOpConvertFToS:
694 case SpvOpConvertSToF:
695 case SpvOpConvertUToF:
696 case SpvOpSConvert:
697 case SpvOpFConvert:
698 case SpvOpSatConvertSToU:
699 case SpvOpSatConvertUToS: {
700 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
701 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
702 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
703 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
704
705 struct conversion_opts opts = {
706 .rounding_mode = nir_rounding_mode_undef,
707 .saturate = false,
708 };
709 vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
710
711 if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
712 opts.saturate = true;
713
714 if (b->shader->info.stage == MESA_SHADER_KERNEL) {
715 if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
716 nir_op op = nir_type_conversion_op(src_type, dst_type,
717 nir_rounding_mode_undef);
718 dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
719 } else {
720 dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
721 src_type, dst_type,
722 opts.rounding_mode, opts.saturate);
723 }
724 } else {
725 vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
726 dst_type != nir_type_float16,
727 "Rounding modes are only allowed on conversions to "
728 "16-bit float types");
729 nir_op op = nir_type_conversion_op(src_type, dst_type,
730 opts.rounding_mode);
731 dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
732 }
733 break;
734 }
735
736 case SpvOpBitFieldInsert:
737 case SpvOpBitFieldSExtract:
738 case SpvOpBitFieldUExtract:
739 case SpvOpShiftLeftLogical:
740 case SpvOpShiftRightArithmetic:
741 case SpvOpShiftRightLogical: {
742 bool swap;
743 bool exact;
744 unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
745 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
746 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
747 src0_bit_size, dst_bit_size);
748
749 assert(!exact);
750
751 assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
752 op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
753 op == nir_op_ibitfield_extract);
754
755 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
756 unsigned src_bit_size =
757 nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
758 if (src_bit_size == 0)
759 continue;
760 if (src_bit_size != src[i]->bit_size) {
761 assert(src_bit_size == 32);
762 /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
763 * supported by the NIR instructions. See discussion here:
764 *
765 * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
766 */
767 src[i] = nir_u2u32(&b->nb, src[i]);
768 }
769 }
770 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
771 break;
772 }
773
774 case SpvOpSignBitSet:
775 dest->def = nir_i2b(&b->nb,
776 nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
777 break;
778
779 case SpvOpUCountTrailingZerosINTEL:
780 dest->def = nir_umin(&b->nb,
781 nir_find_lsb(&b->nb, src[0]),
782 nir_imm_int(&b->nb, 32u));
783 break;
784
785 case SpvOpBitCount: {
786 /* bit_count always returns int32, but the SPIR-V opcode just says the return
787 * value needs to be big enough to store the number of bits.
788 */
789 dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
790 break;
791 }
792
793 case SpvOpSDotKHR:
794 case SpvOpUDotKHR:
795 case SpvOpSUDotKHR:
796 case SpvOpSDotAccSatKHR:
797 case SpvOpUDotAccSatKHR:
798 case SpvOpSUDotAccSatKHR:
799 unreachable("Should have called vtn_handle_integer_dot instead.");
800
801 default: {
802 bool swap;
803 bool exact;
804 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
805 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
806 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
807 &exact,
808 src_bit_size, dst_bit_size);
809
810 if (swap) {
811 nir_ssa_def *tmp = src[0];
812 src[0] = src[1];
813 src[1] = tmp;
814 }
815
816 switch (op) {
817 case nir_op_ishl:
818 case nir_op_ishr:
819 case nir_op_ushr:
820 if (src[1]->bit_size != 32)
821 src[1] = nir_u2u32(&b->nb, src[1]);
822 break;
823 default:
824 break;
825 }
826
827 const bool save_exact = b->nb.exact;
828
829 if (exact)
830 b->nb.exact = true;
831
832 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
833
834 b->nb.exact = save_exact;
835 break;
836 } /* default */
837 }
838
839 switch (opcode) {
840 case SpvOpIAdd:
841 case SpvOpIMul:
842 case SpvOpISub:
843 case SpvOpShiftLeftLogical:
844 case SpvOpSNegate: {
845 nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
846 vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
847 break;
848 }
849 default:
850 /* Do nothing. */
851 break;
852 }
853
854 vtn_push_ssa_value(b, w[2], dest);
855
856 b->nb.exact = b->exact;
857 }
858
859 void
vtn_handle_integer_dot(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)860 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
861 const uint32_t *w, unsigned count)
862 {
863 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
864 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
865 const unsigned dest_size = glsl_get_bit_size(dest_type);
866
867 vtn_handle_no_contraction(b, dest_val);
868
869 /* Collect the various SSA sources.
870 *
871 * Due to the optional "Packed Vector Format" field, determine number of
872 * inputs from the opcode. This differs from vtn_handle_alu.
873 */
874 const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
875 opcode == SpvOpUDotAccSatKHR ||
876 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
877
878 vtn_assert(count >= num_inputs + 3);
879
880 struct vtn_ssa_value *vtn_src[3] = { NULL, };
881 nir_ssa_def *src[3] = { NULL, };
882
883 for (unsigned i = 0; i < num_inputs; i++) {
884 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
885 src[i] = vtn_src[i]->def;
886
887 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
888 }
889
890 /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
891 * the SPV_KHR_integer_dot_product spec says:
892 *
893 * _Vector 1_ and _Vector 2_ must have the same type.
894 *
895 * The practical requirement is the same bit-size and the same number of
896 * components.
897 */
898 vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
899 glsl_get_bit_size(vtn_src[1]->type) ||
900 glsl_get_vector_elements(vtn_src[0]->type) !=
901 glsl_get_vector_elements(vtn_src[1]->type),
902 "Vector 1 and vector 2 source of opcode %s must have the same "
903 "type",
904 spirv_op_to_string(opcode));
905
906 if (num_inputs == 3) {
907 /* The SPV_KHR_integer_dot_product spec says:
908 *
909 * The type of Accumulator must be the same as Result Type.
910 *
911 * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
912 * types (far below) assumes these types have the same size.
913 */
914 vtn_fail_if(dest_type != vtn_src[2]->type,
915 "Accumulator type must be the same as Result Type for "
916 "opcode %s",
917 spirv_op_to_string(opcode));
918 }
919
920 unsigned packed_bit_size = 8;
921 if (glsl_type_is_vector(vtn_src[0]->type)) {
922 /* FINISHME: Is this actually as good or better for platforms that don't
923 * have the special instructions (i.e., one or both of has_dot_4x8 or
924 * has_sudot_4x8 is false)?
925 */
926 if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
927 glsl_get_bit_size(vtn_src[0]->type) == 8 &&
928 glsl_get_bit_size(dest_type) <= 32) {
929 src[0] = nir_pack_32_4x8(&b->nb, src[0]);
930 src[1] = nir_pack_32_4x8(&b->nb, src[1]);
931 } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
932 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
933 glsl_get_bit_size(dest_type) <= 32 &&
934 opcode != SpvOpSUDotKHR &&
935 opcode != SpvOpSUDotAccSatKHR) {
936 src[0] = nir_pack_32_2x16(&b->nb, src[0]);
937 src[1] = nir_pack_32_2x16(&b->nb, src[1]);
938 packed_bit_size = 16;
939 }
940 } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
941 glsl_type_is_32bit(vtn_src[0]->type)) {
942 /* The SPV_KHR_integer_dot_product spec says:
943 *
944 * When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
945 * Vector Format_ must be specified to select how the integers are to
946 * be interpreted as vectors.
947 *
948 * The "Packed Vector Format" value follows the last input.
949 */
950 vtn_assert(count == (num_inputs + 4));
951 const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
952 vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
953 "Unsupported vector packing format %d for opcode %s",
954 pack_format, spirv_op_to_string(opcode));
955 } else {
956 vtn_fail_with_opcode("Invalid source types.", opcode);
957 }
958
959 nir_ssa_def *dest = NULL;
960
961 if (src[0]->num_components > 1) {
962 const nir_op s_conversion_op =
963 nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
964 nir_rounding_mode_undef);
965
966 const nir_op u_conversion_op =
967 nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
968 nir_rounding_mode_undef);
969
970 nir_op src0_conversion_op;
971 nir_op src1_conversion_op;
972
973 switch (opcode) {
974 case SpvOpSDotKHR:
975 case SpvOpSDotAccSatKHR:
976 src0_conversion_op = s_conversion_op;
977 src1_conversion_op = s_conversion_op;
978 break;
979
980 case SpvOpUDotKHR:
981 case SpvOpUDotAccSatKHR:
982 src0_conversion_op = u_conversion_op;
983 src1_conversion_op = u_conversion_op;
984 break;
985
986 case SpvOpSUDotKHR:
987 case SpvOpSUDotAccSatKHR:
988 src0_conversion_op = s_conversion_op;
989 src1_conversion_op = u_conversion_op;
990 break;
991
992 default:
993 unreachable("Invalid opcode.");
994 }
995
996 /* The SPV_KHR_integer_dot_product spec says:
997 *
998 * All components of the input vectors are sign-extended to the bit
999 * width of the result's type. The sign-extended input vectors are
1000 * then multiplied component-wise and all components of the vector
1001 * resulting from the component-wise multiplication are added
1002 * together. The resulting value will equal the low-order N bits of
1003 * the correct result R, where N is the result width and R is
1004 * computed with enough precision to avoid overflow and underflow.
1005 */
1006 const unsigned vector_components =
1007 glsl_get_vector_elements(vtn_src[0]->type);
1008
1009 for (unsigned i = 0; i < vector_components; i++) {
1010 nir_ssa_def *const src0 =
1011 nir_build_alu(&b->nb, src0_conversion_op,
1012 nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
1013
1014 nir_ssa_def *const src1 =
1015 nir_build_alu(&b->nb, src1_conversion_op,
1016 nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
1017
1018 nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
1019
1020 dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1021 }
1022
1023 if (num_inputs == 3) {
1024 /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1025 *
1026 * Signed integer dot product of _Vector 1_ and _Vector 2_ and
1027 * signed saturating addition of the result with _Accumulator_.
1028 *
1029 * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1030 *
1031 * Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1032 * unsigned saturating addition of the result with _Accumulator_.
1033 *
1034 * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1035 *
1036 * Mixed-signedness integer dot product of _Vector 1_ and _Vector
1037 * 2_ and signed saturating addition of the result with
1038 * _Accumulator_.
1039 */
1040 dest = (opcode == SpvOpUDotAccSatKHR)
1041 ? nir_uadd_sat(&b->nb, dest, src[2])
1042 : nir_iadd_sat(&b->nb, dest, src[2]);
1043 }
1044 } else {
1045 assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1046 assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1047
1048 nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1049 bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1050 opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1051
1052 if (packed_bit_size == 16) {
1053 switch (opcode) {
1054 case SpvOpSDotKHR:
1055 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1056 break;
1057 case SpvOpUDotKHR:
1058 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1059 break;
1060 case SpvOpSDotAccSatKHR:
1061 if (dest_size == 32)
1062 dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1063 else
1064 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1065 break;
1066 case SpvOpUDotAccSatKHR:
1067 if (dest_size == 32)
1068 dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1069 else
1070 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1071 break;
1072 default:
1073 unreachable("Invalid opcode.");
1074 }
1075 } else {
1076 switch (opcode) {
1077 case SpvOpSDotKHR:
1078 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1079 break;
1080 case SpvOpUDotKHR:
1081 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1082 break;
1083 case SpvOpSUDotKHR:
1084 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1085 break;
1086 case SpvOpSDotAccSatKHR:
1087 if (dest_size == 32)
1088 dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1089 else
1090 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1091 break;
1092 case SpvOpUDotAccSatKHR:
1093 if (dest_size == 32)
1094 dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1095 else
1096 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1097 break;
1098 case SpvOpSUDotAccSatKHR:
1099 if (dest_size == 32)
1100 dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1101 else
1102 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1103 break;
1104 default:
1105 unreachable("Invalid opcode.");
1106 }
1107 }
1108
1109 if (dest_size != 32) {
1110 /* When the accumulator is 32-bits, a NIR dot-product with saturate
1111 * is generated above. In all other cases a regular dot-product is
1112 * generated above, and separate addition with saturate is generated
1113 * here.
1114 *
1115 * The SPV_KHR_integer_dot_product spec says:
1116 *
1117 * If any of the multiplications or additions, with the exception
1118 * of the final accumulation, overflow or underflow, the result of
1119 * the instruction is undefined.
1120 *
1121 * Therefore it is safe to cast the dot-product result down to the
1122 * size of the accumulator before doing the addition. Since the
1123 * result of the dot-product cannot overflow 32-bits, this is also
1124 * safe to cast up.
1125 */
1126 if (num_inputs == 3) {
1127 dest = is_signed
1128 ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
1129 : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
1130 } else {
1131 dest = is_signed
1132 ? nir_i2i(&b->nb, dest, dest_size)
1133 : nir_u2u(&b->nb, dest, dest_size);
1134 }
1135 }
1136 }
1137
1138 vtn_push_nir_ssa(b, w[2], dest);
1139
1140 b->nb.exact = b->exact;
1141 }
1142
1143 void
vtn_handle_bitcast(struct vtn_builder * b,const uint32_t * w,unsigned count)1144 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1145 {
1146 vtn_assert(count == 4);
1147 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1148 *
1149 * "If Result Type has the same number of components as Operand, they
1150 * must also have the same component width, and results are computed per
1151 * component.
1152 *
1153 * If Result Type has a different number of components than Operand, the
1154 * total number of bits in Result Type must equal the total number of
1155 * bits in Operand. Let L be the type, either Result Type or Operand’s
1156 * type, that has the larger number of components. Let S be the other
1157 * type, with the smaller number of components. The number of components
1158 * in L must be an integer multiple of the number of components in S.
1159 * The first component (that is, the only or lowest-numbered component)
1160 * of S maps to the first components of L, and so on, up to the last
1161 * component of S mapping to the last components of L. Within this
1162 * mapping, any single component of S (mapping to multiple components of
1163 * L) maps its lower-ordered bits to the lower-numbered components of L."
1164 */
1165
1166 struct vtn_type *type = vtn_get_type(b, w[1]);
1167 struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
1168
1169 vtn_fail_if(src->num_components * src->bit_size !=
1170 glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1171 "Source and destination of OpBitcast must have the same "
1172 "total number of bits");
1173 nir_ssa_def *val =
1174 nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1175 vtn_push_nir_ssa(b, w[2], val);
1176 }
1177